import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
import os

# Determine the device to use
# Using os.environ.get to allow device override from Space hardware config if needed
# Defaults to CUDA if available, else CPU.
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
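
# (On Apple Silicon, torch.backends.mps.is_available() could be checked here
# as well; Space hardware is normally CUDA or CPU, so that case is omitted.
# This is an optional extension, not part of the original setup.)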

print(f"Using device: {device}")

# Load the model and processor
model_id = "lusxvr/nanoVLM-222M"
try:
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForVision2Seq.from_pretrained(model_id).to(device)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading model/processor: {e}")
    # If loading fails, we'll have the Gradio app display an error.
    # This helps in debugging if the Space doesn't start correctly.
    processor = None
    model = None
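
# Note: if from_pretrained fails with an unrecognized-architecture error, the
# checkpoint may rely on custom modeling code; passing trust_remote_code=True
# to AutoProcessor/AutoModelForVision2Seq is one thing to try (an assumption:
# whether it is needed depends on how lusxvr/nanoVLM-222M is published).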

def generate_text_for_image(image_input, prompt_input):
    """
    Generates text based on an image and a text prompt.
    """
    if model is None or processor is None:
        return "Error: Model or processor not loaded. Check the Space logs for details."

    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt (e.g., 'Describe this image' or 'What color is the car?')."

    try:
        # Ensure the image is in PIL format and RGB
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        # Prepare inputs for the model
        # The prompt for nanoVLM is typically a question or an instruction.
        inputs = processor(text=[prompt_input], images=[pil_image], return_tensors="pt").to(device)

        # Generate text
        # You can adjust max_new_tokens, temperature, top_k, etc.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=150,      # generous budget for longer descriptions
            num_beams=3,             # beam search for more coherent output
            no_repeat_ngram_size=2,  # discourage verbatim repetition
            early_stopping=True
        )
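
        # Beam search is used above; a sampling configuration is another option
        # (illustrative values only, not tuned for this model):
        #
        #     generated_ids = model.generate(
        #         **inputs,
        #         max_new_tokens=150,
        #         do_sample=True,
        #         temperature=0.7,
        #         top_p=0.9,
        #     )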
        
        # Decode the generated tokens
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        # The output might sometimes include the prompt itself, depending on the model.
        # Simple heuristic to remove prompt if it appears at the beginning:
        if generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text

        return cleaned_text.strip()

    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}"

# Create the Gradio interface
description = """
Upload an image and provide a text prompt (e.g., "What is in this image?", "Describe the animal in detail.").
The model will generate a textual response based on the visual content and your query.
This Space uses the `lusxvr/nanoVLM-222M` model.
"""

# Example image from COCO dataset
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # Two cats on a couch with remote controls

iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question", info="e.g., 'What is this a picture of?', 'Describe the main subject.', 'How many animals are there?'")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
        [example_image_url, "What objects are on the sofa?"],
    ],
    cache_examples=True # Cache results for examples to load faster
)
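
# Note: cache_examples=True runs generate_text_for_image on every example at
# startup, so caching only succeeds when the model loaded correctly; otherwise
# the cached outputs will be the error string returned above.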

if __name__ == "__main__":
    # For Hugging Face Spaces, it's common to launch with server_name="0.0.0.0"
    # The Space infrastructure handles the public URL and port mapping.
    iface.launch(server_name="0.0.0.0", server_port=7860)
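
# The Space's requirements.txt would typically list at least the packages
# imported above (exact pins are an assumption, not taken from the repo):
#
#     gradio
#     torch
#     transformers
#     Pillow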