import transformers
from utils import printf
import copy


class prompt:
    def __init__(self, tokenizer, max_len, add_eos=True):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.add_eos = add_eos

class instruct_prompt(prompt):
    prompt = (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    )
    prompt_input = (
        "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
        "### Instruction:{instruction}\n\n### Input:{input}\n\n### Response:"
    )
    prompt_history = "User:{input}\n\nAssistant:{output}\n\n"
    prompt_post = "User:{input}\n\nAssistant:"

    def preprocess_gen(self, data_point):
        if 'history' not in data_point:
            # single-instruction format: {'instruction': ..., 'input': ...}
            if 'input' in data_point:
                user_prompt = self.prompt_input.format_map(data_point)
            else:
                user_prompt = self.prompt.format_map(data_point)
        else:
            # multi-turn format: {'history': [{'input': ..., 'output': ...}, ...], 'input': ...}
            user_prompt = "\n".join(
                ["User:" + i['input'] + "\n" + "Assistant:" + i['output'] for i in data_point['history']]
            ) + "\nUser:" + data_point['input'] + "\nAssistant:"
            # character-level (not token-level) truncation, keeping the most recent turns
            user_prompt = user_prompt[-self.max_len:]
            user_prompt = self.prompt.format_map({'instruction': user_prompt})
        input_ids = self.tokenizer(user_prompt)["input_ids"]
        return input_ids
    def preprocess_train(self, data_point):
        # single-instruction format: {'instruction': ..., 'input': ..., 'output': ...}
        if 'instruction' in data_point:
            if 'input' in data_point:
                user_prompt = self.prompt_input.format_map(data_point)
            else:
                user_prompt = self.prompt.format_map(data_point)
            output = data_point["output"]
        # multi-turn format: {'input': [...], 'output': [...]}
        else:
            user_prompt = ''
            lens = len(data_point['input'])
            for i in range(lens - 1):
                user_prompt += self.prompt_history.format_map(
                    {'input': data_point['input'][i], 'output': data_point['output'][i]}
                )
            user_prompt += self.prompt_post.format_map({'input': data_point['input'][-1]})
            user_prompt = self.prompt.format_map({'instruction': user_prompt})
            output = data_point['output'][-1]
        len_user_prompt_tokens = len(self.tokenizer(
            user_prompt,
            truncation=True,
            max_length=self.max_len + 1,
        )["input_ids"]) - 1  # no eos token
        full_tokens = self.tokenizer(
            user_prompt + output,
            truncation=True,
            max_length=self.max_len + 1,
            padding="max_length",
        )["input_ids"][:-1]
        return {
            "input_ids": full_tokens,
            "labels": [-100] * len_user_prompt_tokens
            + full_tokens[len_user_prompt_tokens:],
            "attention_mask": [1] * len(full_tokens),
        }
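
    # Note on the label mask above: -100 is PyTorch CrossEntropyLoss's default
    # ignore_index, so the prompt tokens contribute nothing to the training loss
    # and only the response tokens are learned.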

    def data_collator(self):
        return transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False)

    def postprocess(self, text, render=True):
        # import pdb; pdb.set_trace()
        printf(text)
        output = text.split("### Response:")[1].strip()
        output = output.replace("Belle", "Vicuna")
        printf(output)
        if '###' in output:
            output = output.split("###")[0]
        if 'User' in output:
            output = output.split("User")[0]
        output = output.replace('�', '').replace('</s>', '')
        if render:
            # work around the gradio chatbot's markdown code-block rendering bug
            lines = output.split("\n")
            for i, line in enumerate(lines):
                if "```" in line:
                    if line != "```":
                        # opening fence: keep the language tag after the backticks
                        lines[i] = f'<pre><code class="language-{lines[i][3:]}">'
                    else:
                        lines[i] = '</code></pre>'
                else:
                    if i > 0:
                        lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;").replace("__", '\\_\\_')
            output = "".join(lines)
        # output = output.replace('<br/><pre>','\n<pre>')  # works for html, but not for gradio
        return output
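
# Usage sketch for instruct_prompt (hedged: the checkpoint name and max_len are
# illustrative assumptions, not values prescribed by this file):
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("yahma/llama-7b-hf")
#   p = instruct_prompt(tokenizer, max_len=2048)
#   ids = p.preprocess_gen({'instruction': 'Summarize the text.', 'input': 'LoRA is ...'})
#   features = p.preprocess_train({'instruction': 'Summarize the text.',
#                                  'input': 'LoRA is ...', 'output': 'A summary.'})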

class chat_prompt(prompt):
    prompt_pre = (
        "The following is a conversation between an AI assistant called Assistant and a human user called User. "
        "The assistant is intelligent, knowledgeable and polite to answer questions of user.\n\n"
    )
    prompt_history = "User:{input}\n\nAssistant:{output}\n\n"
    prompt_post = "User:{input}\n\nAssistant:"

    def preprocess_gen(self, data_point):
        user_prompt = self.prompt_pre
        len_avail = self.max_len - len(self.tokenizer(user_prompt, add_special_tokens=False)['input_ids'])
        input_prompt = self.prompt_post.format_map({'input': data_point['input']})
        len_avail -= len(self.tokenizer(input_prompt, add_special_tokens=False)['input_ids'])
        lens = len(data_point['history'])
        tokenized_lens = []
        for i in range(lens):
            tmp_prompt = self.prompt_history.format_map(data_point['history'][i])
            tokenized_lens.append(len(self.tokenizer(tmp_prompt, add_special_tokens=False)["input_ids"]))
        # heuristic: halve the longer side of each turn, starting from the earliest turns
        i = 0
        while sum(tokenized_lens) > len_avail and i < lens:
            history = data_point['history'][i]
            tmp_len1 = len(history['input'])
            tmp_len2 = len(history['output'])
            if tmp_len2 > tmp_len1:
                history['output'] = history['output'][:tmp_len2 // 2]
            else:
                history['input'] = history['input'][:tmp_len1 // 2]
            tmp_prompt = self.prompt_history.format_map(history)
            single_len = len(self.tokenizer(tmp_prompt, add_special_tokens=False)["input_ids"])
            tokenized_lens[i] = single_len
            i += 1
        total_len = sum(tokenized_lens)
        # if that is still not enough, drop the oldest turns outright
        while total_len > len_avail and i < lens - 1:
            total_len -= tokenized_lens[i]
            data_point['history'] = data_point['history'][1:]
            i += 1
        # finally, assemble the prompt from whatever history remains
        # (iterate over the possibly-shortened history, not the original lens,
        # to avoid an IndexError after truncation)
        for i in range(len(data_point['history'])):
            user_prompt += self.prompt_history.format_map(data_point['history'][i])
        user_prompt += input_prompt
        printf({'real_input:': user_prompt})
        inputs = self.tokenizer(user_prompt)["input_ids"]
        return inputs
    def preprocess_train(self, data_point):
        user_prompt = self.prompt_pre
        lens = len(data_point['input'])
        # print("Length of data_point['input']: ", len(data_point['input']))
        # print("Length of data_point['output']: ", len(data_point['output']))
        # for i in range(lens - 1):
        #     user_prompt += self.prompt_history.format_map({'input': data_point['input'][i].strip(), 'output': data_point['output'][i].strip()})
        user_prompt += self.prompt_post.format_map({'input': data_point['input'].strip()})
        len_user_prompt_tokens = len(self.tokenizer(
            user_prompt,
            truncation=True,
            max_length=self.max_len,
        )["input_ids"]) - 1  # remove extra eos
        if self.add_eos:
            full_tokens = self.tokenizer(
                user_prompt + data_point["output"].strip(),
                truncation=True,
                padding=False,
                max_length=self.max_len,
            )["input_ids"]  # need eos
        else:
            full_tokens = self.tokenizer(
                user_prompt + data_point["output"].strip(),
                truncation=True,
                padding=False,
                max_length=self.max_len + 1,
            )["input_ids"][:-1]  # delete eos
        return {
            "input_ids": full_tokens,
            "labels": [-100] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:],
            "attention_mask": [1] * len(full_tokens),
        }

    def data_collator(self):
        return transformers.DataCollatorForSeq2Seq(self.tokenizer)
    def postprocess(self, text, render=False):
        output = text.split("Assistant:")[-1].strip()
        if 'User:' in output:
            output = output.split("User:")[0]
        output = output.replace('�', '')
        if render:
            # work around the gradio chatbot's markdown code-block rendering bug
            lines = output.split("\n")
            for i, line in enumerate(lines):
                if "```" in line:
                    if line != "```":
                        lines[i] = f'<pre><code class="language-{lines[i][3:]}">'
                    else:
                        lines[i] = '</code></pre>'
                else:
                    if i > 0:
                        lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;").replace("__", '\\_\\_')
            output = "".join(lines)
        # output = output.replace('<br/><pre>','\n<pre>')  # works for html, but not for gradio
        return output
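
# Usage sketch for chat_prompt (hedged: the checkpoint name and max_len are
# illustrative assumptions; any HF tokenizer with eos handling should work):
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("yahma/llama-7b-hf")
#   cp = chat_prompt(tokenizer, max_len=2048)
#   ids = cp.preprocess_gen({'history': [{'input': 'Hi', 'output': 'Hello!'}],
#                            'input': 'What can you do?'})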

def get_data_collator():
    # returns the collator class itself (not an instance); callers instantiate it
    return transformers.DataCollatorForLanguageModeling
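
# A minimal training hookup sketch (assumption: a standard transformers.Trainer
# loop; the model, dataset, and output_dir below are placeholders, not part of
# this file):
#
#   model = transformers.AutoModelForCausalLM.from_pretrained(...)
#   cp = chat_prompt(tokenizer, max_len=2048)
#   train_ds = raw_dataset.map(cp.preprocess_train)
#   trainer = transformers.Trainer(
#       model=model,
#       train_dataset=train_ds,
#       data_collator=cp.data_collator(),
#       args=transformers.TrainingArguments(output_dir="out"),
#   )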