import copy

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextStreamer,
    pipeline,
)

# Optional 4-bit quantization (requires bitsandbytes):
# quantization_config = BitsAndBytesConfig(load_in_4bit=True)

torch_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32

llama_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    # quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=torch_device,
)
llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Optional token streaming to stdout:
# streamer = TextStreamer(llama_tokenizer)

llama32_1b_pipe = pipeline(
    "text-generation",
    model=llama_model,
    tokenizer=llama_tokenizer,
    # streamer=streamer,
)


def context_window_limiting(history: list[dict], context_window: int) -> list[dict]:
    """Drop whole messages from the start of the history until the
    tokenized chat fits within `context_window` tokens."""
    history_windowed = copy.deepcopy(history)
    # apply_chat_template returns the token ids for the whole conversation,
    # so its length is the prompt size in tokens
    while (
        len(history_windowed) > 0  # has to be a non-empty list
        and len(llama32_1b_pipe.tokenizer.apply_chat_template(history_windowed)) >= context_window
    ):
        del history_windowed[0]  # delete the oldest message first

    print(f"number of messages in chat history: {len(history_windowed)}")  # debug
    return history_windowed


@spaces.GPU
def llama32_1b_chat(message, history, context_window) -> str:
    """Generate a reply and return only the newly generated assistant text."""
    input_history = copy.deepcopy(history)
    input_history.append({"role": "user", "content": message})

    # Trim the conversation to the user-selected context window before generating
    input_history = context_window_limiting(input_history, context_window)

    outputs = llama32_1b_pipe(
        input_history,
        max_new_tokens=512,
    )
    # The pipeline returns the full conversation; the last message is the new assistant reply
    return outputs[-1]["generated_text"][-1]["content"]


# Create the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            context_window = gr.Slider(
                64, 1024, value=256,
                label="size of context window",
                info="choose context window size (in tokens)",
            )
        with gr.Row():
            gr.ChatInterface(
                fn=llama32_1b_chat,
                additional_inputs=[context_window],
                type="messages",
                title="context_window",
            )
    return demo


# Launch the app
demo = create_interface()
demo.launch()