# Copyright (c) Meta Platforms, Inc. and affiliates.

import re

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.tokenizers.constants import (
    BOE_ID,
    BOS_ID,
    BPE_ID,
    BYTE_UNITS,
    EOS_ID,
    OFFSET,
    PAD_ID,
)
from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer


def convert_to_bytes(s):
    # Check whether the piece is a byte-like token of the form <0x00>.
    if re.match(r"<0x[0-9a-fA-F]+>", s):
        return bytes.fromhex(s[3:-1])
    else:
        return bytes(s, "utf-8", errors="ignore")
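
# For example, convert_to_bytes("<0x0A>") returns the raw byte b"\n", while a
# plain piece such as convert_to_bytes("hi") returns its UTF-8 encoding b"hi".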


def text2bytes_bpe_delims(
    text: str,
    *,
    bpe_tokenizer,
    bpe_id: int,
    offsetting_special_char: int,
    add_bos: bool,
    add_eos: bool,
):
    cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos)

    # Merge the leading space tokens into a single piece.
    leading_space_tokens = []
    other_bpe_tokens = []
    leading = True
    for token in cur_bpe:
        bpe_str = bpe_tokenizer.sp_model.id_to_piece(token)
        if leading and all(c == "▁" for c in bpe_str):
            leading_space_tokens.append(bpe_str)
        else:
            leading = False
            other_bpe_tokens.append(bpe_str)
    cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens

    # Remove the '▁' characters (SentencePiece's word-boundary marker).
    bpe_strs = []
    for i, bpe_str in enumerate(cur_bpe_strs):
        if (
            len(bpe_strs) <= 1
            and all([c == " " for s in bpe_strs for c in s])
            and not all(c == "▁" for c in bpe_str)
        ):
            # Remove the leading space marker from the first non-space token.
            bpe_str = bpe_str.replace("▁", "")
        elif i == 0 and all(c == "▁" for c in bpe_str):
            # The merged leading-space piece becomes the text's actual leading spaces.
            bpe_str = " " * (len(text) - len(text.lstrip(" ")))
        else:
            bpe_str = bpe_str.replace("▁", " ")
        if len(bpe_str) > 0:
            bpe_strs.append(bpe_str)

    # Convert BPE pieces to bytes, prefixing each piece with a BPE delimiter id.
    ex_seq = []
    for s in bpe_strs:
        byte_chunk = convert_to_bytes(s)
        proc_chunk = [int(unit) for unit in byte_chunk]
        ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk)
    return ex_seq
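
# By construction, the returned sequence interleaves a marker with each piece's
# byte values: if the cleaned pieces were ["Hi", " there"], the output would be
# [m, 72, 105, m, 32, 116, 104, 101, 114, 101] with m = bpe_id - offsetting_special_char.
# The caller adds offsetting_special_char back, so the marker surfaces as bpe_id.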


class BltTokenizer(Tokenizer):
    def __init__(
        self,
        *,
        vocab_size_unit_1: int = BYTE_UNITS,
        bpe_delim: bool = False,
        bpe_tokenizer_path: str = "/home/artidoro/tokenizers/llama_v2.tokenizer.model",
        add_bos: bool = True,
        add_eos: bool = True,
    ):
        self.add_bos = add_bos
        self.add_eos = add_eos
        self.vocab_size_unit_1 = vocab_size_unit_1
        self.boe_id = BOE_ID
        self.bos_id = BOS_ID
        self.eos_id = EOS_ID
        self.pad_id = PAD_ID
        self.bpe_id = BPE_ID
        self.bpe_tokenizer_path = bpe_tokenizer_path
        if bpe_delim:
            self.bpe_tokenizer = SentencePieceTokenizer(
                model_path=self.bpe_tokenizer_path
            )
        else:
            self.bpe_tokenizer = None
        self.bpe_delim = bpe_delim
        self.offsetting_special_char = OFFSET
        self.n_words = vocab_size_unit_1 + self.offsetting_special_char
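        # Resulting id layout: byte value b is stored as b + offsetting_special_char,
        # ids below the offset stay reserved for the special tokens, and the
        # vocabulary spans n_words ids in total.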

    def encode(
        self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
    ):
        if add_bos is None:
            add_bos = self.add_bos
        if add_eos is None:
            add_eos = self.add_eos

        if self.bpe_delim:
            tokens = text2bytes_bpe_delims(
                text,
                bpe_tokenizer=self.bpe_tokenizer,
                bpe_id=self.bpe_id,
                offsetting_special_char=self.offsetting_special_char,
                add_bos=False,
                add_eos=False,
            )
        else:
            tokens = bytes(text, encoding="utf-8", errors="ignore")

        # Offsetting: shift every byte (and BPE-delimiter marker) up so that the
        # special ids below offsetting_special_char never collide with byte values.
        tokens = [int(unit) + self.offsetting_special_char for unit in tokens]

        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)
        return tokens
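
    # For example, with the defaults (bpe_delim=False, add_bos=True, add_eos=True),
    # encode("abc") yields [BOS_ID, 97 + OFFSET, 98 + OFFSET, 99 + OFFSET, EOS_ID],
    # since "abc" is the UTF-8 byte sequence 97, 98, 99.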

    def decode(self, tokens: list[int], cut_at_eos: bool = False):
        if cut_at_eos:
            for k, t in enumerate(tokens):
                if t == self.eos_id:
                    tokens = tokens[: k + 1]
                    break
        # Undo the offsetting, drop anything that falls below it (the special ids),
        # and decode the remaining bytes as UTF-8.
        return bytes(
            [
                tok - self.offsetting_special_char
                for tok in tokens
                if tok - self.offsetting_special_char >= 0
            ]
        ).decode("utf-8", errors="ignore")

    def get_token_offsets(self, text: str, tokens: list[int] | None = None):
        # TODO: Figure out what this does
        raise NotImplementedError()
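

if __name__ == "__main__":
    # Minimal usage sketch, assuming the bytelatent package is importable. The
    # byte-level path (bpe_delim=False) needs no SentencePiece model; the
    # bpe_delim=True path additionally requires a model at bpe_tokenizer_path.
    tokenizer = BltTokenizer(bpe_delim=False)
    ids = tokenizer.encode("hello world")
    print(ids)  # bos id, then one offset byte id per UTF-8 byte, then eos id
    print(tokenizer.decode(ids))  # "hello world" (ids below the offset are dropped)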