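"""Gradio demo app for nanoVLM-222M.

Loads the custom VisionLanguageModel from a local clone of huggingface/nanoVLM,
pairs it with a CLIP image processor and a GPT-2 tokenizer, and serves an
image + prompt -> generated-text interface on port 7860.
"""
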
import sys
import os
from typing import Optional # For type hinting
from PIL import Image as PILImage # Use an alias to avoid conflict with gr.Image

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    print(f"DEBUG: Adding {NANOVLM_REPO_PATH} to sys.path")
    sys.path.insert(0, NANOVLM_REPO_PATH)
else:
    print(f"DEBUG: {NANOVLM_REPO_PATH} already in sys.path")

import gradio as gr
import torch
from transformers import CLIPImageProcessor, GPT2TokenizerFast

# Import the custom VisionLanguageModel class
VisionLanguageModel = None # Initialize to None
try:
    print("DEBUG: Attempting to import VisionLanguageModel from models.vision_language_model")
    from models.vision_language_model import VisionLanguageModel
    print("DEBUG: Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"CRITICAL ERROR: Error importing VisionLanguageModel from nanoVLM clone: {e}.")
    print("DEBUG: Please ensure /app/nanoVLM/models/vision_language_model.py exists and is correct.")
    # No need to exit here, the checks later will handle it.
except Exception as e:
    print(f"CRITICAL ERROR: An unexpected error occurred during VisionLanguageModel import: {e}")


# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"DEBUG: Using device: {device}")

# --- Configuration for model components ---
model_id_for_weights = "lusxvr/nanoVLM-222M"
image_processor_id = "openai/clip-vit-base-patch32"
tokenizer_id = "gpt2" # Using canonical gpt2 tokenizer

print(f"DEBUG: Configuration - model_id_for_weights: {model_id_for_weights}")
print(f"DEBUG: Configuration - image_processor_id: {image_processor_id}")
print(f"DEBUG: Configuration - tokenizer_id: {tokenizer_id}")

image_processor = None
tokenizer = None
model = None

# --- Load Processor and Model ---
if VisionLanguageModel is not None: # Only proceed if custom model class was imported
    try:
        print(f"DEBUG: Attempting to load CLIPImageProcessor from: {image_processor_id}")
        image_processor = CLIPImageProcessor.from_pretrained(image_processor_id)
        print(f"DEBUG: CLIPImageProcessor loaded: {type(image_processor)}")
        
        print(f"DEBUG: Attempting to load GPT2TokenizerFast from: {tokenizer_id}")
        tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_id)
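        # GPT-2 defines no pad token by default, but padded batches need one,
        # so reuse the EOS token as the pad token below.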
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"DEBUG: Set tokenizer pad_token to eos_token (ID: {tokenizer.eos_token_id})")
        print(f"DEBUG: GPT2TokenizerFast loaded: {type(tokenizer)}, vocab_size: {tokenizer.vocab_size}")
        
        print(f"DEBUG: Attempting to load model weights from {model_id_for_weights} using VisionLanguageModel.from_pretrained")
        # Note: The custom VisionLanguageModel.from_pretrained in nanoVLM does not take trust_remote_code
        model = VisionLanguageModel.from_pretrained(model_id_for_weights).to(device)
        print(f"DEBUG: Model loaded successfully: {type(model)}")
        model.eval()
        print("DEBUG: Model set to evaluation mode (model.eval())")

        # Optional: Print model's state_dict keys (can be very long)
        # print("DEBUG: Model state_dict keys (first 10):", list(model.state_dict().keys())[:10])
        # print(f"DEBUG: Is model on device '{device}'? {next(model.parameters()).device}")

    except Exception as e:
        print(f"CRITICAL ERROR: Error loading model or processor components: {e}")
        import traceback
        traceback.print_exc()
        # Reset to ensure generate_text_for_image knows they failed
        image_processor = None
        tokenizer = None
        model = None
else:
    print("CRITICAL ERROR: Custom VisionLanguageModel class not imported. Cannot load model.")


# --- Input Preparation Function ---
def prepare_inputs(text_list, image_input, image_processor_instance, tokenizer_instance, device_to_use):
    print(f"DEBUG (prepare_inputs): Received text_list: {text_list}")
    if image_processor_instance is None or tokenizer_instance is None:
        print("ERROR (prepare_inputs): Image processor or tokenizer not initialized.")
        raise ValueError("Image processor or tokenizer not initialized.")
    
    # Process image
    print(f"DEBUG (prepare_inputs): Processing image with {type(image_processor_instance)}")
    processed_image_output = image_processor_instance(images=image_input, return_tensors="pt")
    pixel_values = processed_image_output.pixel_values.to(device_to_use)
    print(f"DEBUG (prepare_inputs): pixel_values shape: {pixel_values.shape}, dtype: {pixel_values.dtype}")
    
    # Process text
    print(f"DEBUG (prepare_inputs): Processing text with {type(tokenizer_instance)}")
    # Using model_max_length from tokenizer, with a fallback.
    max_len = getattr(tokenizer_instance, 'model_max_length', 512) 
    print(f"DEBUG (prepare_inputs): Tokenizer max_length: {max_len}")
    processed_text_output = tokenizer_instance(
        text=text_list, return_tensors="pt", padding=True, truncation=True, max_length=max_len
    )
    input_ids = processed_text_output.input_ids.to(device_to_use)
    attention_mask = processed_text_output.attention_mask.to(device_to_use)
    print(f"DEBUG (prepare_inputs): input_ids shape: {input_ids.shape}, dtype: {input_ids.dtype}, values: {input_ids}")
    print(f"DEBUG (prepare_inputs): attention_mask shape: {attention_mask.shape}, dtype: {attention_mask.dtype}, values: {attention_mask}")
    
    return {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}


# --- Text Generation Function ---
def generate_text_for_image(image_input_pil: Optional[PILImage.Image], prompt_input_str: Optional[str]) -> str:
    print(f"DEBUG (generate_text_for_image): Received prompt: '{prompt_input_str}'")
    if model is None or image_processor is None or tokenizer is None:
        print("ERROR (generate_text_for_image): Model or processor components not loaded.")
        return "Error: Model or processor components not loaded correctly. Check application logs."

    if image_input_pil is None:
        print("WARN (generate_text_for_image): No image uploaded.")
        return "Please upload an image."
    if not prompt_input_str:
        print("WARN (generate_text_for_image): No prompt provided.")
        return "Please provide a prompt (e.g., 'a photo of a')."

    try:
        print("DEBUG (generate_text_for_image): Preparing image...")
        current_pil_image = image_input_pil # Gradio provides PIL if type="pil"
        if not isinstance(current_pil_image, PILImage.Image):
             print(f"WARN (generate_text_for_image): Input image not PIL, type: {type(current_pil_image)}. Converting.")
             current_pil_image = PILImage.fromarray(current_pil_image) # Fallback if not PIL
        if current_pil_image.mode != "RGB":
            print(f"DEBUG (generate_text_for_image): Converting image from mode {current_pil_image.mode} to RGB.")
            current_pil_image = current_pil_image.convert("RGB")
        print(f"DEBUG (generate_text_for_image): Image size: {current_pil_image.size}, mode: {current_pil_image.mode}")

        print("DEBUG (generate_text_for_image): Preparing inputs for the model...")
        inputs_dict = prepare_inputs(
            text_list=[prompt_input_str], image_input=current_pil_image,
            image_processor_instance=image_processor, tokenizer_instance=tokenizer, device_to_use=device
        )
        
        print(f"DEBUG (generate_text_for_image): Calling model.generate with input_ids (shape {inputs_dict['input_ids'].shape}), pixel_values (shape {inputs_dict['pixel_values'].shape}), attention_mask (shape {inputs_dict['attention_mask'].shape})")
        
        # Match the signature: def generate(self, input_ids, image, attention_mask=None, max_new_tokens=...)
        generated_ids_tensor = model.generate(
            inputs_dict['input_ids'],          # 1st argument: input_ids (text prompt)
            inputs_dict['pixel_values'],       # 2nd argument: image (pixel values)
            inputs_dict['attention_mask'],     # 3rd argument: attention_mask (for text)
            max_new_tokens=30,                 # Using a smaller value for quicker debugging
            temperature=0.8,                   # Slightly higher temperature to encourage diversity
            top_k=50,                          # As per nanoVLM signature default
            top_p=0.9,                         # As per nanoVLM signature default
            greedy=False                       # As per nanoVLM signature default
        )
        
        print(f"DEBUG (generate_text_for_image): Raw generated_ids tensor: {generated_ids_tensor}")
        
        # Decode the generated tokens
        print("DEBUG (generate_text_for_image): Decoding generated tokens...")
        generated_text_list_decoded = tokenizer.batch_decode(generated_ids_tensor, skip_special_tokens=True)
        print(f"DEBUG (generate_text_for_image): Decoded text list (before join/cleanup): {generated_text_list_decoded}")
        generated_text_str = generated_text_list_decoded[0] if generated_text_list_decoded else ""

        # Optional: Clean up prompt if it's echoed by the model
        cleaned_text_str = generated_text_str
        if prompt_input_str and generated_text_str.startswith(prompt_input_str):
             print("DEBUG (generate_text_for_image): Prompt found at the beginning of generation, removing it.")
             cleaned_text_str = generated_text_str[len(prompt_input_str):].lstrip(" ,.:")
        
        print(f"DEBUG (generate_text_for_image): Final cleaned text to be returned: '{cleaned_text_str}'")
        return cleaned_text_str.strip()

    except Exception as e:
        print(f"CRITICAL ERROR (generate_text_for_image): An error occurred during generation: {e}")
        import traceback
        traceback.print_exc() # Print full traceback to logs
        return f"An error occurred during text generation: {str(e)}. Check application logs."


# --- Gradio Interface Definition ---
description_md = """
## Interactive nanoVLM-222M Demo
Upload an image and provide a text prompt (e.g., "What is in this image?", "Describe the animal in detail.").
The model will attempt to generate a textual response based on the visual content and your query.
This Space uses the `lusxvr/nanoVLM-222M` model with code from the original `huggingface/nanoVLM` repository.
"""
# example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" # Not used currently

print("DEBUG: Defining Gradio interface...")
iface = None
try:
    iface = gr.Interface(
        fn=generate_text_for_image,
        inputs=[
            gr.Image(type="pil", label="Upload Image"), # type="pil" ensures PIL.Image object
            gr.Textbox(label="Your Prompt / Question", info="e.g., 'a photo of a', 'Describe this scene.'")
        ],
        outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
        title="nanoVLM-222M Interactive Demo",
        description=description_md,
        # examples=[ # Examples commented out to simplify Gradio setup
        #     [example_image_url, "a photo of a"],
        #     [example_image_url, "Describe the image in detail."],
        # ],
        # cache_examples=False, # Explicitly False, or remove argument
        allow_flagging="never" # Keep flagging disabled
    )
    print("DEBUG: Gradio interface defined successfully.")
except Exception as e:
    print(f"CRITICAL ERROR: Error defining Gradio interface: {e}")
    import traceback
    traceback.print_exc()


# --- Launch Gradio App ---
if __name__ == "__main__":
    print("DEBUG: Entered __main__ block.")
    if VisionLanguageModel is None:
        print("CRITICAL ERROR: VisionLanguageModel class was not imported. Cannot proceed.")
    elif model is None or image_processor is None or tokenizer is None:
        print("CRITICAL ERROR: Model, image_processor, or tokenizer failed to load. Gradio app might not be fully functional.")
    
    if iface is not None:
        print("DEBUG: Attempting to launch Gradio interface...")
        try:
            iface.launch(server_name="0.0.0.0", server_port=7860) # Standard for Spaces
            print("DEBUG: Gradio launch command issued.") # This might not be reached if launch blocks or errors immediately
        except Exception as e:
            print(f"CRITICAL ERROR: Error launching Gradio interface: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("CRITICAL ERROR: Gradio interface (iface) is None. Cannot launch.")