Labels: enhancement (New feature or request)
Description
Goal: Implement the capability of computing this probability.
Description: It is possible to compute this probability without increasing the method's complexity exponentially. However, we do need to do extra work at each iteration for each sequence: instead of 1 model call per sequence, we need to execute |b| model calls (where |b| is the length of the prefix at the beginning of the conditioning).
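To make the per-sequence cost concrete, here is a toy sketch (not the issue's implementation): scoring a prefix against a history via the chain rule requires one conditional distribution per prefix position, i.e. |P| model calls without caching. The `model` callable and its dict-of-probabilities return type are assumptions for illustration only.

```python
import math

def prefix_log_prob(model, history, prefix):
    """Chain rule: log P(P|H) = sum_i log P(P_i | H, P_0..P_{i-1}).

    `model(tokens)` is a hypothetical callable returning a dict mapping
    each candidate next token to its probability given `tokens`.
    """
    logp = 0.0
    for i, tok in enumerate(prefix):
        # One call per prefix position: |P| calls in total.
        dist = model(history + prefix[:i])
        logp += math.log(dist[tok])
    return logp
```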
Here is a candidate pseudo-algorithm. We will first assume a batch containing a single example:
Input: prefix P, History H, max_seq_len K, model M
1. If |P| > K, then return 1.
2. Sample next token:
2.1. Augment the input with |P|-1 sequences: [H, P_0], [H, P_0, P_1], ...
2.2. Feed into model M and obtain conditional probabilities.
2.3. Compute P(P|H) = P(P_0 | H) P(P_1 | H, P_0) P(P_2 | H, P_0, P_1) ...
2.4. Before sampling, compute the normalized probability P'(P_0 | H) = P(P_0 | H) * (1 - P(P_1 | H, P_0)) * (1 - P(P_2 | H, P_0, P_1)) * ...
3. Sample the next token based on these probabilities.
4. If the sampled token is not along the prefix, repeat from step 1.
5. If the sampled token is along the prefix:
5.1. No additional model calls are needed; we can reuse previous model calls (indexed by their conditioning).
5.2. Repeat the logic with renormalization from step 2.3, but using the remaining prefix P_{i:}.

We will use the following data structures to aid in the implementation of the algorithm:
- `cache`: a dictionary mapping `<tokens, logits_dist>`.
- `prefix_index`: a vector of integers stating which position of the prefix we are currently on. For example, at iteration `t`, `prefix_index[i]` will hold the index on the prefix array that the path is currently exploring. This structure allows us to end generations sooner.
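As a minimal sketch (under the assumption of a hypothetical `model(tokens)` callable returning a next-token probability dict, not the issue's final implementation), the renormalization of step 2.4 and the reuse of model calls in step 5.1 could look like:

```python
def renormalized_first_prob(conds):
    """Step 2.4 as written:
    P'(P_0|H) = P(P_0|H) * prod_{i>=1} (1 - P(P_i | H, P_0..P_{i-1})).

    `conds` holds the conditionals obtained in steps 2.2-2.3:
    [P(P_0|H), P(P_1|H,P_0), P(P_2|H,P_0,P_1), ...].
    """
    p = conds[0]
    for c in conds[1:]:
        p *= 1.0 - c
    return p

def cached_dist(model, cache, tokens):
    """Step 5.1: reuse previous model calls, indexed by the conditioning tokens."""
    key = tuple(tokens)
    if key not in cache:
        cache[key] = model(tokens)  # model is only called on a cache miss
    return cache[key]
```

For example, with conditionals [0.5, 0.2, 0.4], the renormalized probability is 0.5 * 0.8 * 0.6 = 0.24; a repeated call to `cached_dist` with the same conditioning tokens hits the cache instead of the model.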
TODO
- Elaborate on what exactly will be the final probabilities for each sample.