import string
from unittest import TestCase

from hypothesis import given, settings
from hypothesis import strategies as st
from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import DialogMessage, tokenize_dialogue
class TestTokenizeDialog(TestCase):
    """Property-based tests for ``tokenize_dialogue``.

    Covers single- and multi-turn dialogues, with no truncation, right
    truncation, and left truncation. For the GPT-2 tokenizer used here,
    ``bos_token_id == eos_token_id`` (both 50256), which is why some
    assertions use the two interchangeably.
    """

    # NOTE(review): the @given strategies below are a reconstruction — the
    # decorators were lost from this file, but the hypothesis imports and the
    # extra test parameters show they existed. Confirm against upstream trlx.
    # Short lowercase words keep tokenization fast while still producing
    # multi-token responses once joined with spaces.
    _words = st.lists(st.text(alphabet=string.ascii_lowercase, min_size=1, max_size=8), min_size=1, max_size=8)
    # max_length >= 2 leaves room for the BOS prompt plus at least one
    # output token after truncation (the tests slice with max_length - 1).
    _max_lengths = st.integers(min_value=2, max_value=16)
    # (user_words, response_words) pairs for multi-turn dialogues.
    _word_pairs = st.lists(st.tuples(_words, _words), min_size=1, max_size=4)

    def setUp(self):
        # Cached after first download; loaded fresh for each test.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def test_tokenize_dialogue_truncation_basic(self):
        """Left truncation to max_length=2 keeps BOS prompt + final EOS only."""
        dialogue = ["this will be truncated", "."]
        self.tokenizer.truncation_side = "left"
        dialog = tokenize_dialogue(dialogue, self.tokenizer, max_length=2)
        assert len(dialog) == 2
        user_dm, bot_dm = dialog
        assert len(user_dm.tokens) == 1
        assert len(bot_dm.tokens) == 1
        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
        assert bot_dm == DialogMessage(is_output=True, tokens=(self.tokenizer.eos_token_id,))

    @settings(deadline=None)  # tokenizer calls can exceed hypothesis's default deadline
    @given(_words)
    def test_tokenize_dialogue_single_turn(self, response_words):
        """A bare response becomes a BOS user prompt plus the response + EOS."""
        response = " ".join(response_words)  # space-separate to make it multiple tokens
        tokenized_response = tuple(self.tokenizer(response, add_special_tokens=False).input_ids)
        tokenized_response = tokenized_response + (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer)
        assert len(dialog) == 2
        user_dm, bot_dm = dialog
        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
        assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response)

    @settings(deadline=None)
    @given(_words, _max_lengths)
    def test_tokenize_dialogue_single_turn_truncation_right(self, response_words, max_length):
        """Right truncation drops tokens from the END of the response."""
        response = " ".join(response_words)  # space-separate to make it multiple tokens
        self.tokenizer.truncation_side = "right"
        tokenized_response = tuple(self.tokenizer(response, add_special_tokens=False).input_ids)
        tokenized_response = tokenized_response + (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer, max_length=max_length)
        assert len(dialog) == 2
        user_dm, bot_dm = dialog
        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
        # BOS occupies one slot, so the response keeps at most max_length - 1 tokens.
        assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response[: max_length - 1])
        all_tokens = sum((dm.tokens for dm in dialog), ())
        assert len(all_tokens) <= max_length

    @settings(deadline=None)
    @given(_words, _max_lengths)
    def test_tokenize_dialogue_single_turn_truncation_left(self, response_words, max_length):
        """Left truncation drops tokens from the START of the response."""
        response = " ".join(response_words)  # space-separate to make it multiple tokens
        self.tokenizer.truncation_side = "left"
        tokenized_response = tuple(self.tokenizer(response, add_special_tokens=False).input_ids)
        tokenized_response += (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer, max_length=max_length)
        # whether or not truncation has happened, user BOS prompt should be present
        assert len(dialog) == 2
        user_dm, bot_dm = dialog
        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
        if len(tokenized_response) < max_length:
            assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response)
        else:
            # BOS occupies one slot, so only the trailing max_length - 1 tokens survive.
            assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response[-max_length + 1 :])
        all_tokens = sum((dm.tokens for dm in dialog), ())
        assert len(all_tokens) <= max_length

    @settings(deadline=None)
    @given(_word_pairs)
    def test_tokenize_dialogue_multi_turn(self, user_response_pairs):
        """Alternating user/bot turns round-trip through tokenize_dialogue."""
        convo = [[" ".join(user_words), " ".join(response_words)] for user_words, response_words in user_response_pairs]
        flat_convo = sum(convo, [])
        tokenized_flat_convo = tuple(
            tuple(self.tokenizer(turn, add_special_tokens=False).input_ids) for turn in flat_convo
        )
        # EOS is appended to the final bot turn only.
        tokenized_flat_convo = (*tokenized_flat_convo[:-1], (*tokenized_flat_convo[-1], self.tokenizer.eos_token_id))
        dialog = tokenize_dialogue(flat_convo, self.tokenizer)
        dm_convo = [DialogMessage(is_output=i % 2 == 1, tokens=tokens) for i, tokens in enumerate(tokenized_flat_convo)]
        nonempty_dm_convo = [dm for dm in dm_convo if dm.tokens]
        # If the first surviving message is a bot turn, a single-token user
        # prompt is prepended (GPT-2: bos_token_id == eos_token_id).
        if nonempty_dm_convo[0].is_output:
            nonempty_dm_convo.insert(0, DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)))
        assert dialog == nonempty_dm_convo

    @settings(deadline=None)
    @given(_word_pairs, _max_lengths)
    def test_tokenize_dialogue_multi_turn_truncation_right(self, user_response_pairs, max_length):
        """Right truncation of a multi-turn dialogue keeps the leading tokens."""
        convo = [[" ".join(user_words), " ".join(response_words)] for user_words, response_words in user_response_pairs]
        flat_convo = sum(convo, [])
        self.tokenizer.truncation_side = "right"
        tokenized_flat_convo = tuple(
            tuple(self.tokenizer(turn, add_special_tokens=False).input_ids) for turn in flat_convo
        )
        tokenized_flat_convo = (*tokenized_flat_convo[:-1], (*tokenized_flat_convo[-1], self.tokenizer.eos_token_id))
        dialog = tokenize_dialogue(flat_convo, self.tokenizer, max_length=max_length)
        all_tokens = sum((dm.tokens for dm in dialog), ())
        should_be_tokens = sum(tokenized_flat_convo, ())[:max_length]
        # A prepended single-token user prompt pushes the budget down by one.
        if dialog[0] == DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)):
            should_be_tokens = (self.tokenizer.eos_token_id, *should_be_tokens[: max_length - 1])
        assert all_tokens == should_be_tokens
        assert len(all_tokens) <= max_length

    @settings(deadline=None)
    @given(_word_pairs, _max_lengths)
    def test_tokenize_dialogue_multi_turn_truncation_left(self, user_response_pairs, max_length):
        """Left truncation of a multi-turn dialogue keeps the trailing tokens."""
        convo = [[" ".join(user_words), " ".join(response_words)] for user_words, response_words in user_response_pairs]
        flat_convo = sum(convo, [])
        self.tokenizer.truncation_side = "left"
        tokenized_flat_convo = tuple(
            tuple(self.tokenizer(turn, add_special_tokens=False).input_ids) for turn in flat_convo
        )
        tokenized_flat_convo = (*tokenized_flat_convo[:-1], (*tokenized_flat_convo[-1], self.tokenizer.eos_token_id))
        dialog = tokenize_dialogue(flat_convo, self.tokenizer, max_length=max_length)
        all_tokens = sum((dm.tokens for dm in dialog), ())
        should_be_tokens = sum(tokenized_flat_convo, ())[-max_length:]
        # A prepended single-token user prompt pushes the budget down by one.
        if dialog[0] == DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)):
            should_be_tokens = (self.tokenizer.eos_token_id, *should_be_tokens[-max_length + 1 :])
        assert all_tokens == should_be_tokens
        assert len(all_tokens) <= max_length