import sys
import os
from typing import Optional
from PIL import Image as PILImage

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    print(f"DEBUG: Adding {NANOVLM_REPO_PATH} to sys.path")
    sys.path.insert(0, NANOVLM_REPO_PATH)
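
# Fail fast with a clear message if the nanoVLM clone is missing (assumption:
# the Space's build step cloned huggingface/nanoVLM to this path).
if not os.path.isdir(NANOVLM_REPO_PATH):
    print(f"WARN: {NANOVLM_REPO_PATH} not found; the VisionLanguageModel import below will fail.")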

import gradio as gr
import torch
from transformers import AutoProcessor # Using AutoProcessor as in generate.py

VisionLanguageModel = None
try:
    print("DEBUG: Attempting to import VisionLanguageModel")
    from models.vision_language_model import VisionLanguageModel
    print("DEBUG: Successfully imported VisionLanguageModel.")
except ImportError as e:
    print(f"CRITICAL ERROR: Importing VisionLanguageModel: {e}")

# --- Device Setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEBUG: Using device: {device}")
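
# Optional (not part of the original generate.py flow): fix the RNG seed so the
# temperature/top-k sampling in generate() is reproducible between runs.
torch.manual_seed(0)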

# --- Configuration ---
# This will be used for both model and processor, as in generate.py
model_repo_id = "lusxvr/nanoVLM-222M"
print(f"DEBUG: Model Repository ID for model and processor: {model_repo_id}")

# --- Initialize ---
processor = None
model = None

if VisionLanguageModel: # Only proceed if the custom model class was imported
    try:
        # Load processor using AutoProcessor, like in generate.py
        print(f"DEBUG: Loading processor using AutoProcessor.from_pretrained('{model_repo_id}')")
        # Using trust_remote_code=True here as a precaution,
        # though ideally not needed if processor_config.json is complete.
        processor = AutoProcessor.from_pretrained(model_repo_id, trust_remote_code=True)
        print(f"DEBUG: AutoProcessor loaded: {type(processor)}")

        # Ensure tokenizer has pad_token set if it's GPT-2 based
        if hasattr(processor, 'tokenizer') and processor.tokenizer is not None:
            if getattr(processor.tokenizer, 'pad_token', None) is None: # Check if pad_token attribute exists and is None
                processor.tokenizer.pad_token = processor.tokenizer.eos_token
                print(f"DEBUG: Set processor.tokenizer.pad_token to eos_token (ID: {processor.tokenizer.eos_token_id})")
        else:
            print("DEBUG: Processor does not have a 'tokenizer' attribute or it is None.")
        # Load model, like in generate.py
        print(f"DEBUG: Loading model VisionLanguageModel.from_pretrained('{model_repo_id}')")
        model = VisionLanguageModel.from_pretrained(model_repo_id).to(device)
        print(f"DEBUG: VisionLanguageModel loaded: {type(model)}")
        model.eval()
        print("DEBUG: Model set to eval() mode.")
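
        # Optional diagnostic: the parameter count should be roughly 222M for
        # lusxvr/nanoVLM-222M; a very different number suggests a wrong checkpoint.
        num_params = sum(p.numel() for p in model.parameters())
        print(f"DEBUG: Model has {num_params / 1e6:.0f}M parameters.")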

    except Exception as e:
        print(f"CRITICAL ERROR loading model or processor with AutoProcessor: {e}")
        import traceback
        traceback.print_exc()
        processor = None
        model = None
else:
    print("CRITICAL ERROR: VisionLanguageModel class not imported. Cannot load model.")


# --- Text Generation Function ---
def generate_text_for_image(image_input_pil: Optional[PILImage.Image], prompt_input_str: Optional[str]) -> str:
    print(f"DEBUG (generate_text_for_image): Received prompt: '{prompt_input_str}'")
    if model is None or processor is None:
        return "Error: Model or processor not loaded. Check logs."
    if image_input_pil is None:
        return "Please upload an image."
    if not prompt_input_str:
        return "Please provide a prompt."

    try:
        current_pil_image = image_input_pil
        if not isinstance(current_pil_image, PILImage.Image):
            # Fallback: gr.Image(type="pil") should already deliver a PIL image,
            # but convert from a numpy array just in case.
            current_pil_image = PILImage.fromarray(current_pil_image)
        if current_pil_image.mode != "RGB":
            current_pil_image = current_pil_image.convert("RGB")
        print(f"DEBUG: Image prepped - size: {current_pil_image.size}, mode: {current_pil_image.mode}")

        # Prepare inputs using the AutoProcessor, as in generate.py
        print("DEBUG: Processing inputs with AutoProcessor...")
        inputs = processor(
            text=[prompt_input_str], images=current_pil_image, return_tensors="pt"
        ).to(device)
        print(f"DEBUG: Inputs from AutoProcessor - keys: {inputs.keys()}")
        print(f"DEBUG:   input_ids shape: {inputs['input_ids'].shape}, values: {inputs['input_ids']}")
        print(f"DEBUG:   pixel_values shape: {inputs['pixel_values'].shape}")
        
        # Ensure attention_mask is present; default to all 1s if not (AutoProcessor should provide it).
        attention_mask = inputs.get('attention_mask')
        if attention_mask is None:
            print("WARN: attention_mask not found in processor output; defaulting to all 1s.")
            attention_mask = torch.ones_like(inputs['input_ids'])  # ones_like inherits device and dtype
        print(f"DEBUG:   attention_mask shape: {attention_mask.shape}")
        print("DEBUG: Calling model.generate (aligning with nanoVLM's generate.py)...")
        # Signature for nanoVLM's generate: (self, input_ids, image, attention_mask, max_new_tokens, ...)
        # The `image` parameter corresponds to `pixel_values` from the processor output.
        # inference_mode() avoids building an autograd graph, in case generate()
        # does not already disable gradients internally.
        with torch.inference_mode():
            generated_ids_tensor = model.generate(
                inputs['input_ids'],     # text prompt token ids
                inputs['pixel_values'],  # image tensor
                attention_mask,
                max_new_tokens=30,
                temperature=0.7,         # match generate.py defaults, or tune to taste
                top_k=50,
                greedy=False,
                # top_p is also accepted by nanoVLM's model.generate
            )
        print(f"DEBUG: Raw generated_ids: {generated_ids_tensor}")

        generated_text_list = processor.batch_decode(generated_ids_tensor, skip_special_tokens=True)
        print(f"DEBUG: Decoded text list: {generated_text_list}")
        generated_text_str = generated_text_list[0] if generated_text_list else ""

        cleaned_text_str = generated_text_str
        if prompt_input_str and generated_text_str.startswith(prompt_input_str):
            cleaned_text_str = generated_text_str[len(prompt_input_str):].lstrip(" ,.:")
        print(f"DEBUG: Final cleaned text: '{cleaned_text_str}'")
        return cleaned_text_str.strip()

    except Exception as e:
        print(f"CRITICAL ERROR during generation: {e}")
        import traceback
        traceback.print_exc()
        return f"Error during generation: {str(e)}"
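
# Quick manual smoke test (hypothetical image path; uncomment to run locally):
#   img = PILImage.open("test.jpg")
#   print(generate_text_for_image(img, "What is in this image?"))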

# --- Gradio Interface ---
description_md = """
## Interactive nanoVLM-222M Demo (Mirroring generate.py)
This Space replicates the working `generate.py` script from `huggingface/nanoVLM`,
using `AutoProcessor` to prepare both the image and the text prompt.
"""
iface = None
if processor and model:
    try:
        iface = gr.Interface(
            fn=generate_text_for_image,
            inputs=[gr.Image(type="pil", label="Upload Image"), gr.Textbox(label="Your Prompt")],
            outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
            title="nanoVLM-222M Demo (generate.py Alignment)",
            description=description_md,
            allow_flagging="never"
        )
        print("DEBUG: Gradio interface defined.")
    except Exception as e:
        print(f"CRITICAL ERROR defining Gradio interface: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    if iface:
        print("DEBUG: Launching Gradio...")
        iface.launch(server_name="0.0.0.0", server_port=7860)
    else:
        print("CRITICAL ERROR: Gradio interface not defined or model/processor failed to load. Cannot launch.")
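
# To run outside of Spaces (assumed setup): clone huggingface/nanoVLM to
# /app/nanoVLM (or adjust NANOVLM_REPO_PATH above), install torch, transformers,
# gradio, and pillow, then launch with: python app.py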