import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
import os

# Determine the device to use
# Using os.environ.get to allow device override from Space hardware config if needed
# Defaults to CUDA if available, else CPU.
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice

print(f"Using device: {device}")

# Load the model and processor
model_id = "lusxvr/nanoVLM-222M"

try:
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForVision2Seq.from_pretrained(model_id).to(device)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading model/processor: {e}")
    # If loading fails, we'll have the Gradio app display an error.
    # This helps in debugging if the Space doesn't start correctly.
    processor = None
    model = None


def generate_text_for_image(image_input, prompt_input):
    """
    Generates text based on an image and a text prompt.
    """
    if model is None or processor is None:
        return "Error: Model or processor not loaded. Check the Space logs for details."
    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt (e.g., 'Describe this image' or 'What color is the car?')."

    try:
        # Ensure the image is in PIL format and RGB
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        # Prepare inputs for the model
        # The prompt for nanoVLM is typically a question or an instruction.
        inputs = processor(text=[prompt_input], images=[pil_image], return_tensors="pt").to(device)

        # Generate text
        # You can adjust max_new_tokens, temperature, top_k, etc.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=150,  # Increased for potentially longer descriptions
            num_beams=3,         # Example of adding beam search
            no_repeat_ngram_size=2,
            early_stopping=True
        )

        # Decode the generated tokens
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # The output might sometimes include the prompt itself, depending on the model.
        # Simple heuristic to remove the prompt if it appears at the beginning:
        if generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text

        return cleaned_text.strip()
    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}"


# Create the Gradio interface
description = """
Upload an image and provide a text prompt (e.g., "What is in this image?", "Describe the animal in detail.").
The model will generate a textual response based on the visual content and your query.
This Space uses the `lusxvr/nanoVLM-222M` model.
"""

# Example image from the COCO dataset
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # A cat and a remote

iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question", info="e.g., 'What is this a picture of?', 'Describe the main subject.', 'How many animals are there?'")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
        [example_image_url, "What objects are on the sofa?"],
    ],
    cache_examples=True  # Cache results for examples to load faster
)

if __name__ == "__main__":
    # For Hugging Face Spaces, it's common to launch with server_name="0.0.0.0"
    # The Space infrastructure handles the public URL and port mapping.
    iface.launch(server_name="0.0.0.0", server_port=7860)
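# Note: when deploying this app as a Hugging Face Space, a requirements.txt is
# typically placed next to app.py so the Space installs the imported packages.
# A minimal sketch based on the imports above (unpinned versions are an assumption):
#
#     gradio
#     torch
#     transformers
#     Pillow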