import React, { createContext, useState, useRef, useCallback } from "react";
import { AutoProcessor, AutoModelForImageTextToText, RawImage, TextStreamer } from "@huggingface/transformers";
import type { LlavaProcessor, PreTrainedModel, Tensor } from "@huggingface/transformers";
import type { VLMContextValue } from "../types/vlm";

const VLMContext = createContext<VLMContextValue | null>(null);

const MODEL_ID = "onnx-community/FastVLM-0.5B-ONNX";
const MAX_NEW_TOKENS = 512;

export { VLMContext };

export const VLMProvider: React.FC<{ children: React.ReactNode }> = ({ children }) => {
  const [isLoaded, setIsLoaded] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const processorRef = useRef<LlavaProcessor | null>(null);
  const modelRef = useRef<PreTrainedModel | null>(null);
  const loadPromiseRef = useRef<Promise<void> | null>(null);
  const inferenceLock = useRef(false);
  const canvasRef = useRef<HTMLCanvasElement | null>(null);

  const loadModel = useCallback(
    async (onProgress?: (msg: string) => void) => {
      if (isLoaded) {
        onProgress?.("Model already loaded!");
        return;
      }

      // Deduplicate concurrent calls: reuse the in-flight load if one exists.
      if (loadPromiseRef.current) {
        return loadPromiseRef.current;
      }

      setIsLoading(true);
      setError(null);

      loadPromiseRef.current = (async () => {
        try {
          onProgress?.("Loading processor...");
          processorRef.current = await AutoProcessor.from_pretrained(MODEL_ID);
          onProgress?.("Processor loaded. Loading model...");
          modelRef.current = await AutoModelForImageTextToText.from_pretrained(MODEL_ID, {
            dtype: {
              embed_tokens: "fp16",
              vision_encoder: "q4",
              decoder_model_merged: "q4",
            },
            device: "webgpu",
          });
          onProgress?.("Model loaded successfully!");
          setIsLoaded(true);
        } catch (e) {
          const errorMessage = e instanceof Error ? e.message : String(e);
          setError(errorMessage);
          console.error("Error loading model:", e);
          throw e;
        } finally {
          setIsLoading(false);
          loadPromiseRef.current = null;
        }
      })();

      return loadPromiseRef.current;
    },
    [isLoaded],
  );

  const runInference = useCallback(
    async (
      media: HTMLVideoElement | HTMLImageElement,
      instruction: string,
      onTextUpdate?: (text: string) => void,
    ): Promise<string> => {
      if (inferenceLock.current) {
        console.log("Inference already running, skipping frame");
        return ""; // Return empty string to signal a skip
      }
      inferenceLock.current = true;

      try {
        if (!processorRef.current || !modelRef.current) {
          throw new Error("Model/processor not loaded");
        }

        if (!canvasRef.current) {
          canvasRef.current = document.createElement("canvas");
        }
        const canvas = canvasRef.current;

        // Support both video and image sources
        let width = 0;
        let height = 0;
        if (media instanceof HTMLVideoElement) {
          width = media.videoWidth;
          height = media.videoHeight;
        } else if (media instanceof HTMLImageElement) {
          width = media.naturalWidth;
          height = media.naturalHeight;
        } else {
          throw new Error("Unsupported media type");
        }

        canvas.width = width;
        canvas.height = height;
        const ctx = canvas.getContext("2d", { willReadFrequently: true });
        if (!ctx) throw new Error("Could not get canvas context");
        ctx.drawImage(media, 0, 0, width, height);

        const frame = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const rawImg = new RawImage(frame.data, frame.width, frame.height, 4);
        const messages = [
          {
            role: "system",
            content: `You are a helpful visual AI assistant.
Respond concisely and accurately to the user's query in one sentence.`,
          },
          { role: "user", content: `${instruction}` },
        ];
        const prompt = processorRef.current.apply_chat_template(messages, {
          add_generation_prompt: true,
        });
        const inputs = await processorRef.current(rawImg, prompt, {
          add_special_tokens: false,
        });

        // Stream partial output to the caller as tokens are generated.
        let streamed = "";
        const streamer = new TextStreamer(processorRef.current.tokenizer!, {
          skip_prompt: true,
          skip_special_tokens: true,
          callback_function: (t: string) => {
            streamed += t;
            onTextUpdate?.(streamed.trim());
          },
        });

        const outputs = (await modelRef.current.generate({
          ...inputs,
          max_new_tokens: MAX_NEW_TOKENS,
          do_sample: false,
          streamer,
          repetition_penalty: 1.2,
        })) as Tensor;

        // Decode only the newly generated tokens (everything after the prompt).
        const decoded = processorRef.current.batch_decode(
          outputs.slice(null, [inputs.input_ids.dims.at(-1), null]),
          { skip_special_tokens: true },
        );
        return decoded[0].trim();
      } finally {
        // Always release the lock, even if inference throws.
        inferenceLock.current = false;
      }
    },
    [],
  );

  return (
    <VLMContext.Provider
      value={{
        isLoaded,
        isLoading,
        error,
        loadModel,
        runInference,
      }}
    >
      {children}
    </VLMContext.Provider>
  );
};
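
// --- Usage sketch (illustrative only; a hypothetical consumer in a separate file) ---
// Assumes the VLMContextValue exposed above carries { isLoaded, isLoading, error,
// loadModel, runInference } as wired into the provider value, and that this provider
// file is importable as "./VLMContext". The hook and component names are hypothetical.
//
// import React, { useContext, useEffect, useState } from "react";
// import { VLMContext } from "./VLMContext";
// import type { VLMContextValue } from "../types/vlm";
//
// function useVLM(): VLMContextValue {
//   const ctx = useContext(VLMContext);
//   if (!ctx) throw new Error("useVLM must be used inside a <VLMProvider>");
//   return ctx;
// }
//
// const ImageCaption: React.FC<{ img: HTMLImageElement }> = ({ img }) => {
//   const { isLoaded, loadModel, runInference } = useVLM();
//   const [caption, setCaption] = useState("");
//
//   useEffect(() => {
//     (async () => {
//       if (!isLoaded) await loadModel(console.log);        // download weights, init WebGPU
//       const result = await runInference(img, "Describe this image.", setCaption);
//       if (result) setCaption(result);                     // "" means the frame was skipped
//     })();
//   }, [img, isLoaded, loadModel, runInference]);
//
//   return <p>{caption || "Loading model..."}</p>;
// };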