import sys
import os

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    sys.path.insert(0, NANOVLM_REPO_PATH)

import gradio as gr
from PIL import Image
import torch
# Import specific processor components
from transformers import CLIPImageProcessor, GPT2TokenizerFast

# Import the custom VisionLanguageModel class
try:
    from models.vision_language_model import VisionLanguageModel
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {e}.")
    VisionLanguageModel = None

# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")

# Load the model and processor components
model_id = "lusxvr/nanoVLM-222M"
image_processor = None
tokenizer = None
model = None

if VisionLanguageModel:
    try:
        print(f"Attempting to load specific processor components for {model_id}")
        # Load the image processor
        image_processor = CLIPImageProcessor.from_pretrained(model_id, trust_remote_code=True)
        print("CLIPImageProcessor loaded.")
        
        # Load the tokenizer
        tokenizer = GPT2TokenizerFast.from_pretrained(model_id, trust_remote_code=True)
        # GPT-2 has no pad token by default; fall back to the EOS token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print("Set tokenizer pad_token to eos_token.")
        print("GPT2TokenizerFast loaded.")
        
        print(f"Attempting to load model {model_id} using VisionLanguageModel.from_pretrained")
        model = VisionLanguageModel.from_pretrained(
            model_id,
            trust_remote_code=True # Allows custom model code to run
            # The VisionLanguageModel might need image_processor and tokenizer passed during init,
            # or it might retrieve them from its config. Check its __init__ if issues persist.
            # For now, assume it gets them from config or they are not strictly needed at init.
        ).to(device)
        print("Model loaded successfully.")
        model.eval()

    except Exception as e:
        print(f"Error loading model or processor components: {e}")
        image_processor = None
        tokenizer = None
        model = None
else:
    print("Custom VisionLanguageModel class not imported, cannot load model.")

# Define a simple processor-like function for preparing inputs
def prepare_inputs(text, image, image_processor_instance, tokenizer_instance, device_to_use):
    if image_processor_instance is None or tokenizer_instance is None:
        raise ValueError("Image processor or tokenizer not initialized.")
    
    # Process image
    processed_image = image_processor_instance(images=image, return_tensors="pt").pixel_values.to(device_to_use)
    
    # Process text
    # Ensure padding is handled correctly for batching (even if batch size is 1)
    processed_text = tokenizer_instance(
        text=text, return_tensors="pt", padding=True, truncation=True
    )
    input_ids = processed_text.input_ids.to(device_to_use)
    attention_mask = processed_text.attention_mask.to(device_to_use)
    
    return {"pixel_values": processed_image, "input_ids": input_ids, "attention_mask": attention_mask}


def generate_text_for_image(image_input, prompt_input):
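    """Gradio callback: run nanoVLM on one uploaded image and prompt, return the decoded text."""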
    if model is None or image_processor is None or tokenizer is None:
        return "Error: Model or processor components not loaded correctly. Check logs."

    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt."

    try:
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        
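        # The CLIP image processor expects 3-channel RGB input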
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        # Use our custom input preparation function
        inputs = prepare_inputs(
            text=[prompt_input],  # Expects a list of text prompts
            image=pil_image,      # Expects a single PIL image or list
            image_processor_instance=image_processor,
            tokenizer_instance=tokenizer,
            device_to_use=device
        )
        
        # Generate text using the model's generate method
        generated_ids = model.generate(
            pixel_values=inputs['pixel_values'],
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=150,
            num_beams=3,
            no_repeat_ngram_size=2,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id # Important for generation
        )
        
        # Decode the generated tokens
        # skip_special_tokens=True removes special tokens like <|endoftext|>
        generated_text_list = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = generated_text_list[0] if generated_text_list else ""

        # Strip the prompt from the output if the model echoes it back
        if prompt_input and generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text

        return cleaned_text.strip()

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc() # Print full traceback for debugging
        return f"An error occurred during text generation: {str(e)}"

description = "Interactive demo for lusxvr/nanoVLM-222M."
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Gradio reads GRADIO_TEMP_DIR from the environment for its temporary/cache files
os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")

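# Build the Gradio UI: image upload + text prompt in, generated text out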
iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
    ],
    cache_examples=True,
    allow_flagging="never"
)

if __name__ == "__main__":
    if model is None or image_processor is None or tokenizer is None:
        print("CRITICAL: Model or processor components failed to load.")
    else:
        print("Launching Gradio interface...")
    iface.launch(server_name="0.0.0.0", server_port=7860)