import os
import re
import logging

import gradio as gr
from openai import OpenAI

from gateway import request_generation
from utils import LATEX_DELIMS
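
# Runtime configuration is read from the environment (API credentials, model
# name, token and concurrency limits), so the script can be deployed without
# code changes.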
openai_api_key = os.getenv("API_KEY")
openai_api_base = os.getenv("API_ENDPOINT")
MODEL = os.getenv("MODEL_NAME", "")

client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 1024))
CONCURRENCY_LIMIT = int(os.getenv("CONCURRENCY_LIMIT", 20))
QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", CONCURRENCY_LIMIT * 4))

logging.basicConfig(level=logging.INFO)
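

# The model streams its reasoning between the literal markers "analysis" and
# "assistantfinal"; this helper splits them so the UI can show the analysis
# and the final answer as separate sections.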
def format_analysis_response(text):
    m = re.search(r"analysis(.*?)assistantfinal", text, re.DOTALL)
    if m:
        reasoning = m.group(1).strip()
        response = text.split("assistantfinal", 1)[-1].strip()
        return (
            f"**🤔 Analysis:**\n\n*{reasoning}*\n\n---\n\n"
            f"**💬 Response:**\n\n{response}"
        )
    return text.strip()
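

# Streaming generator for gr.ChatInterface. Only the system prompt and
# temperature are exposed as UI inputs below, so the remaining sampling
# parameters fall back to defaults here.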
def generate(message, history,
             system_prompt, temperature,
             frequency_penalty=0.0, presence_penalty=0.0,
             max_new_tokens=MAX_NEW_TOKENS):

    if not message.strip():
        yield "Please enter a prompt."
        return
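
    # Gradio may hand us history either as OpenAI-style message dicts
    # ("messages" mode) or as (user, assistant) pairs; normalise both.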
    msgs = []
    for h in history:
        if isinstance(h, dict):
            msgs.append(h)
        elif isinstance(h, (list, tuple)) and len(h) == 2:
            u, a = h
            if u:
                msgs.append({"role": "user", "content": u})
            if a:
                msgs.append({"role": "assistant", "content": a})

    logging.info(f"[User] {message}")
    logging.info(f"[System] {system_prompt} | Temp={temperature}")

    collected, buffer = "", ""
    yielded_once = False
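
    # Stream deltas from the gateway. The first delta is yielded immediately so
    # output appears right away; after that, the accumulated text is re-yielded
    # in larger chunks (on a newline or roughly every 150 characters) to limit
    # UI updates.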
    try:
        for delta in request_generation(
            api_key=openai_api_key, api_base=openai_api_base,
            message=message, system_prompt=system_prompt,
            model_name=MODEL, chat_history=msgs,
            temperature=temperature,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            max_new_tokens=max_new_tokens,
        ):
            if not delta:
                continue

            collected += delta
            buffer += delta

            if not yielded_once:
                yield delta
                buffer = ""
                yielded_once = True
                continue

            if "\n" in buffer or len(buffer) > 150:
                yield collected
                buffer = ""
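
        # Once the stream ends, split the analysis from the final answer and
        # close an unbalanced "$" so inline LaTeX still renders.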
        final = format_analysis_response(collected)
        if final.count("$") % 2:
            final += "$"
        yield final

    except Exception as e:
        logging.exception("Stream failed")
        yield f"❌ Error: {e}"
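

# Chat UI: a messages-style Chatbot with LaTeX rendering, wired to generate().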
chatbot_ui = gr.ChatInterface(
    fn=generate,
    type="messages",
    chatbot=gr.Chatbot(
        label="OSS vLLM Chatbot",
        type="messages",
        scale=2,
        height=600,
        latex_delimiters=LATEX_DELIMS,
    ),
    stop_btn=True,
    additional_inputs=[
        gr.Textbox(label="System prompt", value="You are a helpful assistant.", lines=2),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.7),
    ],
    examples=[
        ["Explain the difference between supervised and unsupervised learning."],
        ["Summarize the plot of Inception in two sentences."],
        ["Show me the LaTeX for the quadratic formula."],
        ["What are the advantages of the AMD Instinct MI300X GPU?"],
        ["Derive the gradient of the softmax cross-entropy loss."],
        ["Explain why ∂/∂x xⁿ = n·xⁿ⁻¹ holds."],
    ],
    title="XYZ on AMD MI300X",
    description="This Space is an alpha release that demonstrates the X model running on AMD MI300 infrastructure. The Space is built with the XYX License.",
)
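

# Queue incoming requests so at most CONCURRENCY_LIMIT generations run at
# once, with up to QUEUE_SIZE requests waiting.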
if __name__ == "__main__":
    chatbot_ui.queue(max_size=QUEUE_SIZE,
                     default_concurrency_limit=CONCURRENCY_LIMIT).launch()