# Spaces: Running (page-status text captured with this source; not part of the program)
import os
import traceback

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq  # Keep these for now
# Determine the device to use.
# DEVICE may name an explicit torch device (e.g. "cpu", "cuda", "cuda:0").
# Unset, blank, or "auto" selects CUDA when torch reports it as available and
# falls back to CPU otherwise. The value is stripped and lower-cased so that
# inputs such as " CUDA " or an empty string do not reach torch verbatim.
device_choice = os.environ.get("DEVICE", "auto").strip().lower() or "auto"
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")
# Load the model and processor.
# Both names are left as None when loading fails, so code that checks them
# for None can report the failure instead of the module crashing on import.
model_id = "lusxvr/nanoVLM-222M"
processor = None
model = None
try:
    print(f"Attempting to load processor for {model_id} with trust_remote_code=True")
    # For custom models like nanoVLM, trust_remote_code=True is often needed.
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    print(f"Processor loaded. Attempting to load model for {model_id} with trust_remote_code=True")
    model = AutoModelForVision2Seq.from_pretrained(model_id, trust_remote_code=True).to(device)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading model/processor: {e}")
    # The one-line message rarely identifies the failing call; keep the full
    # stack trace in the logs as well.
    traceback.print_exc()
    # Keep the pair consistent: a processor without its model is unusable.
    processor = None
    model = None
def generate_text_for_image(image_input, prompt_input):
    """Generate a text response for an image and a text prompt.

    Args:
        image_input: The uploaded image, either a ``PIL.Image.Image`` or an
            object accepted by ``Image.fromarray``; ``None`` when no image
            was provided.
        prompt_input: The user's prompt or question. Leading and trailing
            whitespace is ignored; a blank prompt is treated as missing.

    Returns:
        The generated text with a leading echo of the prompt removed, or a
        human-readable message when the model/processor is not loaded, an
        input is missing, or an exception occurs during generation.
    """
    if model is None or processor is None:
        return "Error: Model or processor not loaded. Check the Space logs. This might be due to missing 'trust_remote_code=True' or model compatibility issues."
    if image_input is None:
        return "Please upload an image."

    # A whitespace-only prompt is truthy but carries no content; normalise
    # once and use the same text for generation and for echo removal.
    prompt = prompt_input.strip() if prompt_input else ""
    if not prompt:
        return "Please provide a prompt (e.g., 'Describe this image' or 'What color is the car?')."

    try:
        if isinstance(image_input, Image.Image):
            pil_image = image_input
        else:
            pil_image = Image.fromarray(image_input)
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        inputs = processor(text=[prompt], images=[pil_image], return_tensors="pt").to(device)
        # Inference only: no gradients are needed, so state that explicitly
        # rather than relying on generate() to disable autograd itself.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=3,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Basic cleaning of the prompt if the model includes it in the output.
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):].lstrip(" ,.:")
        return generated_text.strip()
    except Exception as e:
        print(f"Error during generation: {e}")
        # The returned message points users at the logs, so record the full
        # stack trace there.
        traceback.print_exc()
        return f"An error occurred during text generation: {e}"
# Text passed to the interface as its description. The literal starts and
# ends with a newline and contains exactly three content lines.
description = """
Upload an image and provide a text prompt (e.g., "What is in this image?", "Describe the animal in detail.").
The model will generate a textual response based on the visual content and your query.
This Space uses the `lusxvr/nanoVLM-222M` model.
"""
# Image referenced by the example rows of the interface.
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # A cat and a remote
# Get the pre-defined writable directory for Gradio's temporary files/cache.
# This environment variable is set in your Dockerfile.
gradio_cache_dir = os.environ.get("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")

# Use the writable directory for caching examples. Gradio takes the example
# cache location from the GRADIO_EXAMPLES_CACHE environment variable;
# `examples_cache_folder` is not a documented gr.Interface parameter, so it is
# no longer passed. setdefault keeps any value the deployment already set.
# NOTE(review): confirm the installed Gradio version honours this variable.
os.environ.setdefault("GRADIO_EXAMPLES_CACHE", gradio_cache_dir)

iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            label="Your Prompt/Question",
            info="e.g., 'What is this a picture of?', 'Describe the main subject.', 'How many animals are there?'",
        ),
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
        [example_image_url, "What objects are on the sofa?"],
    ],
    cache_examples=True,
    allow_flagging="never",
)
if __name__ == "__main__":
    # Script entry point: serve the interface only when loading succeeded.
    if model is None or processor is None:
        print("CRITICAL: Model or processor failed to load. Gradio interface will not start.")
        # Exit with a non-zero status so the failed start is reported as a
        # failure rather than looking like a normal, successful exit.
        raise SystemExit(1)
    print("Launching Gradio interface...")
    # Bind to all interfaces on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)