# Source code for beta_rec.utils.alias_table

import numpy as np

class AliasTable:
    """Alias-method sampler over a discrete frequency distribution.

    The tables are built once (Vose's variant of Walker's alias method). After
    that, every draw costs O(1): pick a column uniformly at random, keep it with
    probability ``prob_arr[column]``, otherwise fall back to
    ``alias_arr[column]``. Used to draw negative samples from a (typically
    power-law) frequency distribution over labels.

    Attributes:
        vocab_size: number of entries in ``obj_freq``.
        prob_arr: float array of per-column acceptance thresholds in [0, 1].
        alias_arr: integer array of per-column fallback column indices.
        index2Label: list mapping a column index to its label (the list
            position for sequence input, the key for dict input).
    """

    def __init__(self, obj_freq):
        """Initialize an AliasTable.

        Args:
            obj_freq: Non-negative frequencies. Either a 1-dim list/tuple/ndarray
                (labels are the positions ``0..n-1``) or a dict mapping
                label -> frequency.

        Returns:
            None.

        Raises:
            ValueError: obj_freq has an unsupported type, is not 1-dim,
                contains negative or non-finite frequencies, or has a
                non-positive total (this includes empty input).
        """
        if isinstance(obj_freq, dict):
            index2Label = list(obj_freq)  # used to transform an index to label
            freq_arr = np.asarray(list(obj_freq.values()), dtype=np.float64)
        elif isinstance(obj_freq, (list, tuple, np.ndarray)):
            if np.asarray(obj_freq).ndim != 1:
                raise ValueError("Error: obj_freq is not 1-dim")
            index2Label = list(range(len(obj_freq)))
            freq_arr = np.asarray(obj_freq, dtype=np.float64)
        else:
            raise ValueError("Error: obj_freq is invalid")

        # Reject inputs that would otherwise yield nan/inf/negative probabilities.
        if not np.all(np.isfinite(freq_arr)) or np.any(freq_arr < 0):
            raise ValueError(
                "Error: obj_freq must contain finite, non-negative frequencies"
            )
        total = freq_arr.sum()
        if total <= 0:
            raise ValueError("Error: obj_freq must have a positive total frequency")

        table_size = len(index2Label)
        self.vocab_size = table_size
        # Number of labels that can actually be drawn; bounds no_repeat draws.
        self._support_size = int(np.count_nonzero(freq_arr))

        # Scale so that the average column holds exactly 1.0 of probability mass.
        prob_arr = table_size * (freq_arr / total)
        # Must be an integer dtype: its entries are used as list/array indices.
        alias_arr = np.zeros(table_size, dtype=np.int64)

        print("Filling alias table")
        # Split columns into under-full (< 1) and over-full (>= 1) ones.
        smaller = [idx for idx in range(table_size) if prob_arr[idx] < 1.0]
        larger = [idx for idx in range(table_size) if prob_arr[idx] >= 1.0]

        # Top up each under-full column with mass donated by an over-full one.
        while smaller and larger:
            small = smaller.pop()
            large = larger.pop()
            alias_arr[small] = large
            prob_arr[large] -= 1.0 - prob_arr[small]
            if prob_arr[large] < 1.0:
                smaller.append(large)
            else:
                larger.append(large)

        # Whatever remains should hold exactly 1.0 but may be off by rounding
        # error; pin it so such columns never fall back to the default alias 0.
        for leftover in smaller + larger:
            prob_arr[leftover] = 1.0

        self.prob_arr = prob_arr
        self.alias_arr = alias_arr
        self.index2Label = index2Label

    def sample(self, count, obj_num=1, no_repeat=False):
        """Generate samples.

        Args:
            count: the number of labels in a single draw.
            obj_num: the number of draws.
            no_repeat: whether labels must be distinct within a single draw.

        Returns:
            A list of ``count`` labels when ``obj_num == 1``; otherwise a list
            of ``obj_num`` such lists.

        Raises:
            ValueError: when no_repeat is True and count is larger than
                vocab_size, or larger than the number of labels with non-zero
                frequency (sampling could then never finish).
        """
        if no_repeat:
            if count > self.vocab_size:
                raise ValueError(
                    "Error: count>vocab_size!! Skip no_repeat parameter"
                )
            if count > self._support_size:
                raise ValueError(
                    "Error: count exceeds the number of labels with non-zero "
                    "frequency; no_repeat sampling would never finish"
                )

        nd_samples = []
        for _ in range(obj_num):
            samples = [self.index2Label[col] for col in self._draw_columns(count)]
            if no_repeat:
                unique = set(samples)
                # Keep drawing single labels until the draw has no duplicates.
                while len(unique) < count:
                    unique.add(self.index2Label[self._draw_columns(1)[0]])
                samples = list(unique)
            if obj_num == 1:
                return samples
            nd_samples.append(samples)
        return nd_samples

    def _draw_columns(self, size):
        """Draw ``size`` column indices with the alias method (vectorized)."""
        columns = np.random.randint(low=0, high=len(self.prob_arr), size=size)
        keep = np.random.uniform(size=size) < self.prob_arr[columns]
        return np.where(keep, columns, self.alias_arr[columns])