import sys
import os

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    sys.path.insert(0, NANOVLM_REPO_PATH)
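
# The nanoVLM source is assumed to have been cloned into NANOVLM_REPO_PATH beforehand
# (e.g. in the Dockerfile or a startup script), roughly:
#   git clone https://github.com/huggingface/nanoVLM.git /app/nanoVLM
# so that models/vision_language_model.py is importable below.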

import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor  # the processor config is standard, so AutoProcessor should work here

# Import the custom VisionLanguageModel class from the cloned nanoVLM repository
try:
    from models.vision_language_model import VisionLanguageModel
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {e}. Check NANOVLM_REPO_PATH and ensure nanoVLM cloned correctly.")
    VisionLanguageModel = None

# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")

# Load the model and processor
model_id = "lusxvr/nanoVLM-222M"
processor = None
model = None

if VisionLanguageModel:
    try:
        print(f"Attempting to load processor for {model_id}")
        # trust_remote_code=True might be beneficial if the processor config itself refers to custom code,
        # though less likely for processors.
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        print("Processor loaded.")
        
        print(f"Attempting to load model {model_id} using VisionLanguageModel.from_pretrained")
        # The VisionLanguageModel.from_pretrained method should handle its own configuration loading
        # from the model_id repository (which includes config.json).
        # trust_remote_code=True here allows the custom VisionLanguageModel code to run.
        model = VisionLanguageModel.from_pretrained(model_id, trust_remote_code=True).to(device)
        print("Model loaded successfully.")
        model.eval() # Set to evaluation mode

    except Exception as e:
        print(f"Error loading model or processor: {e}")
        processor = None
        model = None
else:
    print("Custom VisionLanguageModel class not imported, cannot load model.")


def generate_text_for_image(image_input, prompt_input):
    if model is None or processor is None:
        return "Error: Model or processor not loaded correctly. Check logs."

    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt."

    try:
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        inputs = processor(text=[prompt_input], images=[pil_image], return_tensors="pt").to(device)
        
        # Call the custom VisionLanguageModel.generate(). If generation fails, check its exact
        # signature in nanoVLM/models/vision_language_model.py; it is expected to take
        # pixel_values, input_ids and an attention mask as below.
        with torch.no_grad():
            generated_ids = model.generate(
                pixel_values=inputs.get('pixel_values'),
                input_ids=inputs.get('input_ids'),
                attention_mask=inputs.get('attention_mask'),
                max_new_tokens=150,
                num_beams=3,
                no_repeat_ngram_size=2,
                early_stopping=True
            )
        
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        if prompt_input and generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text

        return cleaned_text.strip()

    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred during text generation: {str(e)}"

description = "Interactive demo for lusxvr/nanoVLM-222M."
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Gradio stores temporary files (e.g. downloaded example images) under GRADIO_TEMP_DIR
os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")

iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
    ],
    cache_examples=True,
    allow_flagging="never"
)

if __name__ == "__main__":
    if model is None or processor is None:
        print("CRITICAL: Model or processor failed to load. Gradio interface may not function correctly.")
    else:
        print("Launching Gradio interface...")
    iface.launch(server_name="0.0.0.0", server_port=7860)
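
# To run this demo locally (an assumed invocation, matching the server settings above):
#   pip install torch transformers gradio pillow
#   git clone https://github.com/huggingface/nanoVLM.git /app/nanoVLM   # or adjust NANOVLM_REPO_PATH
#   DEVICE=cpu python app.py
# then open http://localhost:7860 in a browser.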