Source code for easygraph.utils.alias
# Names exported by ``from <module> import *``: two alias-method
# implementations, each a table builder paired with its sampler.
__all__ = ["create_alias_table", "alias_sample", "alias_setup", "alias_draw"]
def create_alias_table(area_ratio):
    """
    Build the accept/alias tables used by the alias sampling method.

    Parameters
    ----------
    area_ratio : sequence of float
        Probability of each outcome; expected to satisfy
        ``sum(area_ratio) = 1``.

    Returns
    -------
    accept : list
        ``accept[i]`` is the probability of keeping column ``i`` once it
        has been picked uniformly at random.
    alias : list
        ``alias[i]`` is the outcome returned when column ``i`` is picked
        but not kept.
    """
    import numpy as np

    count = len(area_ratio)
    accept = [0] * count
    alias = [0] * count

    # Scale so that the average column height is exactly 1.
    scaled = np.array(area_ratio) * count

    # Columns below height 1 need topping up; the rest can donate.
    small = [idx for idx, height in enumerate(scaled) if height < 1.0]
    large = [idx for idx, height in enumerate(scaled) if not height < 1.0]

    while small and large:
        short_col = small.pop()
        tall_col = large.pop()

        accept[short_col] = scaled[short_col]
        alias[short_col] = tall_col

        # The tall column gives away exactly what the short one lacks.
        scaled[tall_col] -= 1 - scaled[short_col]

        target = small if scaled[tall_col] < 1.0 else large
        target.append(tall_col)

    # Anything still unpaired is kept with probability 1.
    for leftover in large + small:
        accept[leftover] = 1

    return accept, alias
def alias_sample(accept, alias):
    """
    Draw one outcome index from a pair of accept/alias tables.

    Parameters
    ----------
    accept : sequence of float
        Probability of keeping each column when it is picked.
    alias : sequence of int
        Outcome returned for each column when it is not kept.

    Returns
    -------
    sample index
        Either the uniformly picked column or its alias.
    """
    import numpy as np

    size = len(accept)
    # First draw picks a column uniformly, second draw decides keep vs. alias.
    column = int(np.random.random() * size)
    coin = np.random.random()
    return column if coin < accept[column] else alias[column]
def alias_draw(J, q):
    """
    Draw sample from a non-uniform discrete distribution using alias sampling.

    Parameters
    ----------
    J : sequence of int
        Alias table: ``J[k]`` is the outcome returned when column ``k`` is
        picked but rejected.
    q : sequence of float
        Acceptance table: ``q[k]`` is the probability of returning ``k``
        itself when column ``k`` is picked.

    Returns
    -------
    int
        The sampled outcome index, either the picked column or ``J`` of it.
    """
    # The docstring must be the first statement; placed after the import it
    # would be a discarded string expression and ``__doc__`` would be None.
    import numpy as np

    K = len(J)

    # Pick a column uniformly, then keep it with probability q[kk].
    kk = int(np.floor(np.random.rand() * K))
    if np.random.rand() < q[kk]:
        return kk
    return J[kk]
def alias_setup(probs):
    """
    Compute utility lists for non-uniform sampling from discrete distributions.
    Refer to https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
    for details

    Parameters
    ----------
    probs : sequence of float
        Probability of each outcome.

    Returns
    -------
    J : numpy.ndarray of int
        Alias table: ``J[k]`` is the outcome used when column ``k`` is
        picked but rejected.
    q : numpy.ndarray of float
        Acceptance table: ``q[k]`` is the probability of keeping column
        ``k`` when it is picked.
    """
    # The docstring must be the first statement; placed after the import it
    # would be a discarded string expression and ``__doc__`` would be None.
    import numpy as np

    K = len(probs)
    q = np.zeros(K)
    J = np.zeros(K, dtype=int)

    # Scale each probability by K and split columns around the mean height 1.
    smaller = []
    larger = []
    for kk, prob in enumerate(probs):
        q[kk] = K * prob
        if q[kk] < 1.0:
            smaller.append(kk)
        else:
            larger.append(kk)

    # Pair each short column with a tall one that donates the missing mass.
    while smaller and larger:
        small = smaller.pop()
        large = larger.pop()

        J[small] = large
        q[large] = q[large] + q[small] - 1.0

        if q[large] < 1.0:
            smaller.append(large)
        else:
            larger.append(large)

    # Columns that were never paired have no meaningful alias (J is still the
    # default 0). This happens through floating-point rounding or when
    # ``probs`` does not sum exactly to 1. Keep them with probability 1 so a
    # draw can never be redirected to outcome 0 by accident.
    for kk in smaller + larger:
        q[kk] = 1.0

    return J, q