Example class for token healing: use get_start_decoding to find the index at which to start constrained decoding over the matched tokens. A short usage sketch follows the code.
import abc
from typing import Any

from transformers import PreTrainedTokenizerFast
class BaseTokenizer(abc.ABC):
    # Cache for the token-prefix trie and the length of the longest decoded token.
    _token_prefix_map: Any = None
    max_token_length: int = 16

    def __init__(self) -> None:
        super().__init__()
        self.bos_id = 1
        self.eos_id = 2
        self.pad_id = 3
        self.n_words = 3

    @abc.abstractmethod
    def _encode(self, s: str) -> list[int]:
        pass

    def encode(self, s: str, bos: bool, eos: bool) -> list[int]:
        assert isinstance(s, str)
        t = self._encode(s)
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, tokens: list[int], cut_at_eos: bool = True) -> str:
        if cut_at_eos:
            # Drop everything after the first EOS token (the EOS itself is kept).
            for k, t in enumerate(tokens):
                if t == self.eos_id:
                    tokens = tokens[: k + 1]
                    break
        return self._decode(tokens)

    @abc.abstractmethod
    def _decode(self, tokens: list[int]) -> str:
        pass
    def _build_token_prefix_map(self):
        """
        Build a map from token string to token ids using a trie data structure.
        Taken from Microsoft's guidance library:
        https://github.com/guidance-ai/guidance/blob/23d0ba12720d09bb87b520d6c84462857f5dfcec/guidance/llms/_transformers.py#L74
        """
        import pygtrie

        token_map = pygtrie.CharTrie()
        for i in range(self.n_words):
            try:
                s = self._decode([i])
                self.max_token_length = max(self.max_token_length, len(s))
            except Exception:
                print(f"token id {i} not found in tokenizer")
                continue
            if s in token_map:
                token_map[s].append(i)  # handle duplicate token encodings
            else:
                token_map[s] = [i]
        return token_map
    def prefix_matches(self, prefix: str) -> list[int]:
        """
        Return the list of token ids whose string form starts with the given prefix.
        Raises KeyError if no token matches the prefix.
        """
        if self._token_prefix_map is None:
            self._token_prefix_map = self._build_token_prefix_map()
        return [v for arr in self._token_prefix_map.values(prefix=prefix) for v in arr]
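
    # Illustrative example (hypothetical vocabulary): prefix_matches(" wo")
    # might return the ids of tokens such as " wo", " word", and " world",
    # since each of those token strings starts with " wo".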
    def get_start_decoding(self, prompt_tokens: list[int]) -> tuple[int, list[int]]:
        """
        Given encoded tokens, return the index at which token healing should
        start and the list of token ids that match the possible healing tokens.

        The possible healing tokens are found by decoding a suffix that grows
        iteratively from the end of the prompt, up to the max token length,
        and keeping the longest suffix that still has prefix matches.
        """
        matches, subseq = [], ""
        i, out_index = len(prompt_tokens) - 1, len(prompt_tokens) - 1
        while len(subseq) < self.max_token_length and i >= 0:
            subseq = self.decode(prompt_tokens[i:])
            try:
                matches = self.prefix_matches(prefix=subseq)
                out_index = i  # longest suffix so far that still has matches
            except KeyError:
                pass  # no token starts with this suffix; keep growing it
            i -= 1
        return out_index, matches
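
    # Worked example (hypothetical tokenization): for a prompt whose text ends
    # in " wor", the loop decodes growing suffixes (" wor", "e wor", ...) and
    # remembers the longest one that still prefixes whole vocabulary tokens.
    # It returns the index where that suffix begins, plus the matching token
    # ids (e.g. those of " wor", " word", " world"); the caller can then drop
    # the tokens from that index on and constrain the next generated token to
    # one of the returned ids.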
class HuggingFaceTokenizer(BaseTokenizer):
    def __init__(self, model_path: str):
        self.tokenizer_model = PreTrainedTokenizerFast(
            tokenizer_file=model_path,
            # Keep decoded text stable (no space cleanup) so prefix matching is exact.
            clean_up_tokenization_spaces=False,
        )
        self.n_words = len(self.tokenizer_model)
        self.bos_id = self.tokenizer_model.bos_token_id
        self.eos_id = self.tokenizer_model.eos_token_id
        self.pad_id: int = -1

    def _encode(self, s: str) -> list[int]:
        return self.tokenizer_model.encode(s)

    def _decode(self, tokens: list[int]) -> str:
        return self.tokenizer_model.decode(tokens, skip_special_tokens=False)
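
A minimal usage sketch, not part of the original gist: it assumes a local tokenizer.json fast-tokenizer file (the path is a placeholder) and an arbitrary prompt. The returned index tells the caller where to restart generation, with the first sampled token constrained to the returned matches.

if __name__ == "__main__":
    # Hypothetical path to a fast-tokenizer file; substitute your own.
    tokenizer = HuggingFaceTokenizer("tokenizer.json")

    prompt = "The quick brown fox jumps over the lazy d"
    prompt_tokens = tokenizer.encode(prompt, bos=False, eos=False)

    # Tokens from `start_idx` onwards would be dropped and regenerated,
    # with the first sampled token constrained to the ids in `matches`.
    start_idx, matches = tokenizer.get_start_decoding(prompt_tokens)
    healed_suffix = tokenizer.decode(prompt_tokens[start_idx:])
    print(f"heal from index {start_idx}: {len(matches)} candidate tokens "
          f"must start with {healed_suffix!r}")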