# context_window_POC/context_window_gradio.py
import copy

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Optional 4-bit quantization (currently disabled):
# from transformers import BitsAndBytesConfig
# quantization_config = BitsAndBytesConfig(load_in_4bit=True)
# Pick the best available device and a matching dtype: bfloat16 on
# accelerators, float32 on CPU where bfloat16 kernels may be unavailable.
torch_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32
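# Note: meta-llama/Llama-3.2-1B-Instruct is a gated repo on the Hugging Face
# Hub, so loading it assumes an access token with approval for the model is
# available (e.g. via the HF_TOKEN secret on Spaces).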
llama_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    # quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=torch_device,
)
llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
# Streaming is not wired up yet; a TextStreamer could be re-imported and
# passed to the pipeline if token-by-token output is wanted later.
llama32_1b_pipe = pipeline(
    "text-generation",
    model=llama_model,
    tokenizer=llama_tokenizer,
)
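# For reference: the text-generation pipeline accepts a chat history in the
# standard messages format, e.g. [{"role": "user", "content": "Hi"}], and each
# output's "generated_text" field holds the full conversation with the model's
# reply appended as the final message.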
def context_window_limiting(history: list[dict], context_window: int):
    '''
    Drop whole messages, oldest first, until the chat-templated prompt fits
    inside `context_window` tokens. The most recent message is always kept,
    even if it alone exceeds the window.
    '''
    history_windowed = copy.deepcopy(history)
    # apply_chat_template returns token ids by default, so len() measures the
    # prompt length in tokens.
    while (len(history_windowed) > 1
           and len(llama32_1b_pipe.tokenizer.apply_chat_template(history_windowed)) >= context_window):
        del history_windowed[0]  # drop the oldest message
    print(f"number of messages in chat hist: {len(history_windowed)}")  # DEBUG
    return history_windowed
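# Illustrative usage (not executed at import time): with a small window the
# oldest turns are trimmed first, keeping the latest user message.
#   trimmed = context_window_limiting(
#       [{"role": "user", "content": "first question"},
#        {"role": "assistant", "content": "first answer"},
#        {"role": "user", "content": "follow-up"}],
#       context_window=64,
#   )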
@spaces.GPU
def llama32_1b_chat(message, history, context_window) -> str:
    "Run one chat turn and return only the model's reply text."
    input_history = copy.deepcopy(history)
    input_history.append({"role": "user", "content": message})
    # Trim the history to the user-selected context window before generating.
    input_history = context_window_limiting(input_history, context_window)
    outputs = llama32_1b_pipe(
        input_history,
        max_new_tokens=512,
    )
    # `generated_text` is the full conversation; its last message is the
    # assistant's new reply.
    return outputs[-1]['generated_text'][-1]['content']
# Create the Gradio interface: a slider sets the context window (in tokens)
# and feeds the chat function as an additional input.
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            context_window = gr.Slider(
                64, 1024, value=256,
                label="size of context window",
                info="choose context window size (tokens)",
            )
        with gr.Row():
            gr.ChatInterface(
                fn=llama32_1b_chat,
                additional_inputs=[context_window],
                type="messages",
                title="context_window",
            )
    return demo

# Launch the app
demo = create_interface()
demo.launch()