import { useState, useRef, useCallback, useEffect } from "react";
import {
  AutoModelForCausalLM,
  AutoTokenizer,
  TextStreamer,
} from "@huggingface/transformers";

const MODEL_ID = "shreyask/Maincoder-1B-ONNX-web";

interface LLMState {
  isLoading: boolean;
  isReady: boolean;
  error: string | null;
  progress: number;
}

interface LLMInstance {
  model: any;
  tokenizer: any;
}

// Module-level cache so every instance of the hook shares one model, and
// concurrent loadModel() calls await the same in-flight promise instead of
// downloading the weights twice.
let cachedInstance: LLMInstance | null = null;
let loadingPromise: Promise<LLMInstance> | null = null;

export const useLLM = () => {
  const [state, setState] = useState<LLMState>({
    isLoading: false,
    isReady: false,
    error: null,
    progress: 0,
  });

  const instanceRef = useRef<LLMInstance | null>(null);
  // KV cache carried across generate() calls for multi-turn chat.
  const pastKeyValuesRef = useRef<any>(null);

  const loadModel = useCallback(async () => {
    // Reuse an already-loaded instance (from this hook or another mount).
    const existing = instanceRef.current ?? cachedInstance;
    if (existing) {
      instanceRef.current = existing;
      cachedInstance = existing;
      setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
      return existing;
    }

    // Another caller is already loading; await its promise rather than
    // starting a second download.
    if (loadingPromise) {
      const instance = await loadingPromise;
      instanceRef.current = instance;
      cachedInstance = instance;
      setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
      return instance;
    }

    setState((prev) => ({
      ...prev,
      isLoading: true,
      error: null,
      progress: 0,
    }));

    loadingPromise = (async () => {
      try {
        // Report download progress only for the large ONNX weight files.
        const progress_callback = (progress: any) => {
          if (
            progress.status === "progress" &&
            (progress.file?.endsWith(".onnx") ||
              progress.file?.endsWith(".onnx_data"))
          ) {
            const percentage = Math.round(
              (progress.loaded / progress.total) * 100,
            );
            setState((prev) => ({ ...prev, progress: percentage }));
          }
        };

        const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
          progress_callback,
        });

        const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
          dtype: "q4",
          device: "webgpu",
          progress_callback,
        });

        const instance = { model, tokenizer };
        instanceRef.current = instance;
        cachedInstance = instance;
        loadingPromise = null;

        setState({
          isLoading: false,
          isReady: true,
          error: null,
          progress: 100,
        });

        return instance;
      } catch (error) {
        loadingPromise = null;
        const message =
          error instanceof Error ? error.message : "Failed to load model";
        setState((prev) => ({
          ...prev,
          isLoading: false,
          error: message,
        }));
        throw error;
      }
    })();

    return loadingPromise;
  }, []);
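  /**
   * Generates an assistant reply for the given chat messages, optionally
   * streaming decoded tokens through `onToken`. The KV cache is kept in
   * `pastKeyValuesRef` between calls, so call `clearHistory()` before
   * starting a fresh conversation. Note: the hard-coded EOS ids below
   * (151643 and 151645) look like Qwen-family special tokens; that is an
   * assumption about this model's tokenizer, so adjust them if the model's
   * special tokens differ.
   */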
  const generateResponse = useCallback(
    async (
      messages: Array<{ role: string; content: string }>,
      onToken?: (token: string) => void,
    ): Promise<string> => {
      const instance = instanceRef.current;
      if (!instance) {
        throw new Error("Model not loaded. Call loadModel() first.");
      }

      const { model, tokenizer } = instance;

      const input = tokenizer.apply_chat_template(messages, {
        add_generation_prompt: true,
        return_dict: true,
      });

      // Stream decoded tokens to the caller as they are generated.
      const streamer = onToken
        ? new TextStreamer(tokenizer, {
            skip_prompt: true,
            skip_special_tokens: true,
            callback_function: onToken,
          })
        : undefined;

      const { sequences, past_key_values } = await model.generate({
        ...input,
        past_key_values: pastKeyValuesRef.current,
        max_new_tokens: 1024,
        do_sample: false,
        repetition_penalty: 1.2,
        eos_token_id: [151643, 151645], // <|endoftext|> and <|im_end|>
        streamer,
        return_dict_in_generate: true,
      });

      pastKeyValuesRef.current = past_key_values;

      // Decode only the newly generated tokens (everything after the prompt).
      const response = tokenizer.batch_decode(
        sequences.slice(null, [input.input_ids.dims[1], null]),
        { skip_special_tokens: true },
      )[0];

      return response;
    },
    [],
  );

  // Drop the KV cache so the next generation starts a fresh conversation.
  const clearHistory = useCallback(() => {
    pastKeyValuesRef.current = null;
  }, []);

  // If another hook instance already loaded the model, adopt it on mount.
  useEffect(() => {
    if (cachedInstance) {
      instanceRef.current = cachedInstance;
      setState((prev) => ({ ...prev, isReady: true }));
    }
  }, []);

  return {
    ...state,
    loadModel,
    generateResponse,
    clearHistory,
  };
};
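// A minimal usage sketch, not part of this module. Assumptions: it lives in a
// .tsx file with `useState` imported from "react" and `useLLM` imported from
// this file; the `ChatDemo` name and one-button UI are purely illustrative.
//
//   const ChatDemo = () => {
//     const { isLoading, progress, loadModel, generateResponse } = useLLM();
//     const [output, setOutput] = useState("");
//
//     const ask = async () => {
//       await loadModel();
//       await generateResponse(
//         [{ role: "user", content: "Write a haiku about WebGPU." }],
//         (token) => setOutput((prev) => prev + token),
//       );
//     };
//
//     return (
//       <div>
//         <button onClick={ask} disabled={isLoading}>
//           {isLoading ? `Loading ${progress}%` : "Ask"}
//         </button>
//         <pre>{output}</pre>
//       </div>
//     );
//   };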