File size: 4,503 Bytes
4670dfa
 
 
fbe5121
4670dfa
 
 
 
 
 
 
 
 
 
 
 
fbe5121
 
 
4670dfa
fbe5121
 
 
 
 
4670dfa
 
 
fbe5121
4670dfa
 
 
fbe5121
4670dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbe5121
4670dfa
 
fbe5121
 
4670dfa
 
 
 
 
 
fbe5121
 
4670dfa
 
 
 
 
 
 
 
fbe5121
 
4670dfa
 
 
 
 
 
 
 
fbe5121
 
 
 
1792bb4
4670dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
1792bb4
fbe5121
 
 
4670dfa
 
 
fbe5121
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq # Keep these for now
import os

# Pick the torch device. An explicit DEVICE environment variable wins;
# otherwise ("auto", the default) use CUDA when torch reports it available
# and fall back to the CPU.
device_choice = os.environ.get("DEVICE", "auto")
device = (
    device_choice
    if device_choice != "auto"
    else ("cuda" if torch.cuda.is_available() else "cpu")
)
print(f"Using device: {device}")

# Load the model and processor at import time. Both names start as None and
# stay None if loading fails; generate_text_for_image and the __main__ guard
# check for None instead of letting the import crash.
model_id = "lusxvr/nanoVLM-222M"
processor = model = None

try:
    # nanoVLM is a custom architecture, so trust_remote_code=True is presumably
    # needed to load it. NOTE(review): this executes code downloaded from the
    # Hub repository -- pin a revision if that repo is not fully trusted.
    print(f"Attempting to load processor for {model_id} with trust_remote_code=True")
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    print(f"Processor loaded. Attempting to load model for {model_id} with trust_remote_code=True")
    # Loading and moving to the device stay in one expression so `model` is
    # only bound once it is already on the target device.
    model = AutoModelForVision2Seq.from_pretrained(
        model_id, trust_remote_code=True
    ).to(device)
    print("Model and processor loaded successfully.")
except Exception as load_error:
    # Best-effort: report and continue; the None checks downstream surface it.
    print(f"Error loading model/processor: {load_error}")

def generate_text_for_image(image_input, prompt_input):
    """Generate a text response for one image/prompt pair.

    Args:
        image_input: The uploaded image. The Gradio input is declared with
            ``type="pil"``, so this is normally a ``PIL.Image.Image``; any
            other non-None value is converted with ``Image.fromarray``.
        prompt_input: The user's prompt/question text. Surrounding whitespace
            is ignored, and a blank prompt is treated as missing.

    Returns:
        The decoded model output, with a leading echo of the prompt (and any
        following " ,.:" characters) removed. Otherwise, a human-readable
        message when the model/processor are not loaded, an input is missing,
        or generation raises.
    """
    if model is None or processor is None:
        return "Error: Model or processor not loaded. Check the Space logs. This might be due to missing 'trust_remote_code=True' or model compatibility issues."

    if image_input is None:
        return "Please upload an image."

    # A whitespace-only prompt would otherwise pass the emptiness check and be
    # sent to the model as an effectively empty query.
    prompt = prompt_input.strip() if prompt_input else ""
    if not prompt:
        return "Please provide a prompt (e.g., 'Describe this image' or 'What color is the car?')."

    try:
        if isinstance(image_input, Image.Image):
            pil_image = image_input
        else:
            pil_image = Image.fromarray(image_input)

        # The processor is handed RGB images only.
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        inputs = processor(text=[prompt], images=[pil_image], return_tensors="pt").to(device)

        generated_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            num_beams=3,
            no_repeat_ngram_size=2,
            early_stopping=True,
        )

        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Some models echo the prompt at the start of their output; drop it
        # together with any separator characters that follow it.
        if generated_text.startswith(prompt):
            cleaned_text = generated_text[len(prompt):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text

        return cleaned_text.strip()

    except Exception as e:
        # UI boundary: log the failure and show it to the user instead of
        # letting the Gradio handler crash.
        print(f"Error during generation: {e}")
        return f"An error occurred during text generation: {str(e)}"

description = """
Upload an image and provide a text prompt (e.g., "What is in this image?", "Describe the animal in detail.").
The model will generate a textual response based on the visual content and your query.
This Space uses the `lusxvr/nanoVLM-222M` model.
"""
# Sample image from the COCO val2017 set (cat and remote), shared by all examples.
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Writable directory for Gradio's temporary files / example cache.
# GRADIO_TEMP_DIR is expected to come from the deployment environment (e.g. the
# Dockerfile); /tmp/gradio_tmp is the fallback.
gradio_cache_dir = os.environ.get("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")

# Each example pairs the shared sample image with one of these prompts.
_example_prompts = (
    "a photo of a",
    "Describe the image in detail.",
    "What objects are on the sofa?",
)

# Components are created in the same order as before: image input, prompt
# input, then the output textbox.
iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            label="Your Prompt/Question",
            info="e.g., 'What is this a picture of?', 'Describe the main subject.', 'How many animals are there?'",
        ),
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[[example_image_url, prompt] for prompt in _example_prompts],
    cache_examples=True,
    # Point example caching at the writable directory chosen above.
    examples_cache_folder=gradio_cache_dir,
    # NOTE(review): newer Gradio releases rename this option to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Serve the UI only when loading succeeded. Otherwise, fail loudly with a
    # non-zero exit status so the hosting environment reports an error instead
    # of treating it as a clean shutdown.
    if model is None or processor is None:
        print("CRITICAL: Model or processor failed to load. Gradio interface will not start.")
        raise SystemExit(1)
    print("Launching Gradio interface...")
    iface.launch(server_name="0.0.0.0", server_port=7860)