import React, { createContext, useState, useRef, useCallback } from "react";
import { AutoProcessor, AutoModelForImageTextToText, RawImage, TextStreamer } from "@huggingface/transformers";
import type { LlavaProcessor, PreTrainedModel, Tensor } from "@huggingface/transformers";
import type { VLMContextValue } from "../types/vlm";

const VLMContext = createContext<VLMContextValue | null>(null);

const MODEL_ID = "onnx-community/FastVLM-0.5B-ONNX";
const MAX_NEW_TOKENS = 512;

export { VLMContext };
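
// Provider that owns the FastVLM processor/model instances and exposes
// loadModel() and runInference() to the component tree via VLMContext.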
export const VLMProvider: React.FC<React.PropsWithChildren> = ({ children }) => {
  const [isLoaded, setIsLoaded] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const processorRef = useRef<LlavaProcessor | null>(null);
  const modelRef = useRef<PreTrainedModel | null>(null);
  const loadPromiseRef = useRef<Promise<void> | null>(null);
  const inferenceLock = useRef(false);
  const canvasRef = useRef<HTMLCanvasElement | null>(null);
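
  // Lazily loads the processor and model onto WebGPU. Concurrent calls share the
  // same in-flight promise, so the weights are fetched and initialized only once.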
  const loadModel = useCallback(
    async (onProgress?: (msg: string) => void) => {
      if (isLoaded) {
        onProgress?.("Model already loaded!");
        return;
      }
      if (loadPromiseRef.current) {
        return loadPromiseRef.current;
      }
      setIsLoading(true);
      setError(null);
      loadPromiseRef.current = (async () => {
        try {
          onProgress?.("Loading processor...");
          processorRef.current = await AutoProcessor.from_pretrained(MODEL_ID);
          onProgress?.("Processor loaded. Loading model...");
          modelRef.current = await AutoModelForImageTextToText.from_pretrained(MODEL_ID, {
            dtype: {
              embed_tokens: "fp16",
              vision_encoder: "q4",
              decoder_model_merged: "q4",
            },
            device: "webgpu",
          });
          onProgress?.("Model loaded successfully!");
          setIsLoaded(true);
        } catch (e) {
          const errorMessage = e instanceof Error ? e.message : String(e);
          setError(errorMessage);
          console.error("Error loading model:", e);
          throw e;
        } finally {
          setIsLoading(false);
          loadPromiseRef.current = null;
        }
      })();
      return loadPromiseRef.current;
    },
    [isLoaded],
  );
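
  // Captures the current frame from a <video> or <img>, runs the VLM on it, and
  // streams partial text back through onTextUpdate while generation runs.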
  const runInference = useCallback(
    async (media: HTMLVideoElement | HTMLImageElement, instruction: string, onTextUpdate?: (text: string) => void): Promise<string> => {
      if (inferenceLock.current) {
        console.log("Inference already running, skipping frame");
        return ""; // Return empty string to signal a skip
      }
      if (!processorRef.current || !modelRef.current) {
        throw new Error("Model/processor not loaded");
      }
      inferenceLock.current = true;
      try {
        if (!canvasRef.current) {
          canvasRef.current = document.createElement("canvas");
        }
        const canvas = canvasRef.current;
        // Support both video and image
        let width = 0;
        let height = 0;
        if (media instanceof HTMLVideoElement) {
          width = media.videoWidth;
          height = media.videoHeight;
        } else if (media instanceof HTMLImageElement) {
          width = media.naturalWidth;
          height = media.naturalHeight;
        } else {
          throw new Error("Unsupported media type");
        }
        canvas.width = width;
        canvas.height = height;
        const ctx = canvas.getContext("2d", { willReadFrequently: true });
        if (!ctx) throw new Error("Could not get canvas context");
        ctx.drawImage(media, 0, 0, width, height);
        const frame = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const rawImg = new RawImage(frame.data, frame.width, frame.height, 4);
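
        // Build the chat-style prompt; the <image> placeholder marks where the
        // processor inserts the visual features for this frame.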
        const messages = [
          {
            role: "system",
            content: `You are a helpful visual AI assistant. Respond concisely and accurately to the user's query in one sentence.`,
          },
          { role: "user", content: `<image>${instruction}` },
        ];
        const prompt = processorRef.current.apply_chat_template(messages, {
          add_generation_prompt: true,
        });
        const inputs = await processorRef.current(rawImg, prompt, {
          add_special_tokens: false,
        });
        let streamed = "";
        const streamer = new TextStreamer(processorRef.current.tokenizer!, {
          skip_prompt: true,
          skip_special_tokens: true,
          callback_function: (t: string) => {
            streamed += t;
            onTextUpdate?.(streamed.trim());
          },
        });
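
        // Greedy decoding (do_sample: false) with a mild repetition penalty; the
        // streamer above pushes tokens to the UI as they are generated.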
        const outputs = (await modelRef.current.generate({
          ...inputs,
          max_new_tokens: MAX_NEW_TOKENS,
          do_sample: false,
          streamer,
          repetition_penalty: 1.2,
        })) as Tensor;
        // Decode only the newly generated tokens (everything after the prompt).
        const decoded = processorRef.current.batch_decode(outputs.slice(null, [inputs.input_ids.dims.at(-1), null]), {
          skip_special_tokens: true,
        });
        return decoded[0].trim();
      } finally {
        // Release the lock even if inference throws, so later frames are not blocked.
        inferenceLock.current = false;
      }
    },
    [],
  );
  return (
    <VLMContext.Provider
      value={{
        isLoaded,
        isLoading,
        error,
        loadModel,
        runInference,
      }}
    >
      {children}
    </VLMContext.Provider>
  );
};
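
// Consumer sketch (an assumption, not part of this file's original code; the app
// likely defines a similar hook elsewhere): reads VLMContext and fails fast when
// used outside the provider.
export function useVLMContext(): VLMContextValue {
  const context = React.useContext(VLMContext);
  if (!context) {
    throw new Error("useVLMContext must be used within a VLMProvider");
  }
  return context;
}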