Spaces: Running on Zero
import copy
import os

import gradio as gr
import pandas as pd
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    TextStreamer,
    pipeline,
)
# quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# Pick the best available device and a matching dtype (bf16 on GPU/MPS, fp32 on CPU).
torch_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32

llama_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    # quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=torch_device,
)
llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
# streamer = TextStreamer(llama_tokenizer)
llama32_1b_pipe = pipeline(
    "text-generation",
    model=llama_model,
    tokenizer=llama_tokenizer,
    # streamer=streamer,
)
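# Note: when given a list of chat messages ({"role": ..., "content": ...}), the
# text-generation pipeline applies the model's chat template and returns the
# conversation with the assistant's reply appended as the final message.
# A minimal (hypothetical) smoke test, kept commented out so it does not run on Space startup:
# print(llama32_1b_pipe([{"role": "user", "content": "Say hi."}], max_new_tokens=16)[0]["generated_text"][-1]["content"])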
def context_window_limiting(history: list[dict], context_window: int) -> list[dict]:
    '''
    Drop whole messages from the start of the history until the chat-templated
    prompt fits inside `context_window` tokens.
    '''
    history_windowed = copy.deepcopy(history)
    # apply_chat_template tokenizes by default, so len() gives the prompt length in tokens
    while (
        len(history_windowed) > 0
        and len(llama32_1b_pipe.tokenizer.apply_chat_template(history_windowed)) >= context_window
    ):
        del history_windowed[0]  # delete the oldest message first
    # DEBUG
    print(f"number of messages in chat history: {len(history_windowed)}")
    return history_windowed
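# Sketch of a manual check for the windowing logic (example strings are made up; the
# long history should be culled to fewer messages once it exceeds 128 prompt tokens):
# long_history = [{"role": "user", "content": "tell me a story " * 50},
#                 {"role": "assistant", "content": "once upon a time " * 50}]
# print(len(context_window_limiting(long_history, 128)))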
@spaces.GPU  # ZeroGPU only grants a GPU inside functions decorated with spaces.GPU
def llama32_1b_chat(message, history, context_window) -> str:
    "Run the pipeline on the windowed history and return only the assistant's reply."
    input_history = copy.deepcopy(history)
    input_history.append({"role": "user", "content": message})
    input_history = context_window_limiting(input_history, context_window)
    outputs = llama32_1b_pipe(
        input_history,
        max_new_tokens=512,
    )
    # The pipeline returns the whole conversation; the last message is the new assistant turn.
    return outputs[-1]['generated_text'][-1]['content']
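# Hypothetical local usage outside Gradio (the prompt below is made up):
# reply = llama32_1b_chat("What can you do?", history=[], context_window=256)
# print(reply)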
# Create the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            context_window = gr.Slider(
                64, 1024, value=256,
                label="Context window size (tokens)",
                info="Older messages are dropped until the prompt fits this many tokens",
            )
        with gr.Row():
            gr.ChatInterface(
                fn=llama32_1b_chat,
                additional_inputs=[context_window],
                type="messages",
                title="context_window",
            )
    return demo

# Launch the app
demo = create_interface()
demo.launch()