File size: 7,401 Bytes
e198913
 
2af1927
 
 
e198913
 
97c8139
e198913
 
 
4670dfa
 
 
16bf2d1
e198913
 
 
97c8139
e198913
16bf2d1
e198913
4670dfa
 
 
 
 
 
 
 
978a6b3
 
3253deb
978a6b3
16bf2d1
 
fbe5121
 
97c8139
e198913
978a6b3
0b8c303
16bf2d1
 
978a6b3
0b8c303
16bf2d1
3253deb
16bf2d1
 
e198913
978a6b3
0b8c303
97c8139
16bf2d1
e198913
 
16bf2d1
978a6b3
a4644a0
0b8c303
e198913
97c8139
e198913
978a6b3
2af1927
16bf2d1
 
2af1927
978a6b3
2af1927
16bf2d1
0b8c303
16bf2d1
 
 
2af1927
16bf2d1
 
fb82462
16bf2d1
 
0b8c303
 
4670dfa
 
fb82462
 
 
 
 
4670dfa
16bf2d1
fb82462
0b8c303
16bf2d1
fbe5121
2af1927
fb82462
2af1927
 
4670dfa
2af1927
 
 
 
 
 
 
 
 
4670dfa
 
16bf2d1
 
 
fbe5121
4670dfa
 
 
fb82462
4670dfa
fb82462
4670dfa
 
fb82462
 
fbe5121
4670dfa
e198913
0b8c303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4670dfa
 
16bf2d1
0b8c303
 
 
fbe5121
0b8c303
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import sys
import os
from PIL import Image as PILImage # Add at the top of your app.py if not already there
from typing import Optional


# Make the cloned nanoVLM checkout importable; it provides the
# `models.vision_language_model` module used below. The path is prepended
# (and only once) so it is searched before any other `models` package.
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    sys.path[:0] = [NANOVLM_REPO_PATH]

import gradio as gr
from PIL import Image
import torch
from transformers import CLIPImageProcessor, GPT2TokenizerFast

# Import the model class from the nanoVLM clone. On failure the name is bound
# to None so the rest of the app can report the problem instead of crashing
# at import time.
try:
    from models.vision_language_model import VisionLanguageModel
except ImportError as import_err:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {import_err}.")
    VisionLanguageModel = None
else:
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")

# The DEVICE environment variable selects the torch device. The default
# "auto" picks CUDA when it is available and falls back to CPU otherwise;
# any other value is used verbatim.
device_choice = os.environ.get("DEVICE", "auto")
device = device_choice
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Hugging Face Hub identifiers for the checkpoint and its preprocessing parts.
# NOTE(review): the image-processor and tokenizer ids are hard-coded
# independently of the checkpoint. Confirm they match the preprocessing and
# vocabulary that nanoVLM-222M was trained with. A mismatched tokenizer would
# silently corrupt both the prompt ids and the decoded output.
model_id_for_weights = "lusxvr/nanoVLM-222M"
image_processor_id = "openai/clip-vit-base-patch32"
tokenizer_id = "gpt2"

# Filled in by the model-loading code. Each stays None if loading is skipped
# or fails, which the request handler checks before generating.
image_processor = tokenizer = model = None

# Load the image processor, the tokenizer and the model weights, in that order.
# Any failure resets all three globals to None, so the app never runs with a
# half-initialised pipeline.
if not VisionLanguageModel:
    print("Custom VisionLanguageModel class not imported, cannot load model.")
else:
    try:
        print(f"Attempting to load CLIPImageProcessor from: {image_processor_id}")
        image_processor = CLIPImageProcessor.from_pretrained(image_processor_id)
        print("CLIPImageProcessor loaded.")

        print(f"Attempting to load GPT2TokenizerFast from: {tokenizer_id}")
        tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_id)
        if tokenizer.pad_token is None:
            # Prompts are tokenized with padding=True, which needs a pad
            # token; reuse EOS when the tokenizer does not define one.
            tokenizer.pad_token = tokenizer.eos_token
            print("Set tokenizer pad_token to eos_token.")
        print("GPT2TokenizerFast loaded.")

        print(f"Attempting to load model weights from {model_id_for_weights} using VisionLanguageModel.from_pretrained")
        model = VisionLanguageModel.from_pretrained(model_id_for_weights)
        model = model.to(device)
        print("Model loaded successfully.")
        model.eval()  # inference-only app: put the model in eval mode
    except Exception as load_err:
        print(f"Error loading model or processor components: {load_err}")
        import traceback
        traceback.print_exc()
        image_processor = tokenizer = model = None

def prepare_inputs(text_list, image_input, image_processor_instance, tokenizer_instance, device_to_use):
    """Turn an image and a batch of prompts into tensors on ``device_to_use``.

    Args:
        text_list: Prompts to tokenize as one padded, truncated batch.
        image_input: Image(s) forwarded unchanged to ``image_processor_instance``.
        image_processor_instance: Callable whose result exposes ``pixel_values``.
        tokenizer_instance: Callable tokenizer whose result exposes
            ``input_ids`` and ``attention_mask``. Its ``model_max_length``
            attribute (default 512 when absent) is used as the truncation limit.
        device_to_use: Device that every returned tensor is moved to.

    Returns:
        A dict with the keys ``"pixel_values"``, ``"input_ids"`` and
        ``"attention_mask"``.

    Raises:
        ValueError: If the image processor or the tokenizer is None.
    """
    if any(part is None for part in (image_processor_instance, tokenizer_instance)):
        raise ValueError("Image processor or tokenizer not initialized.")

    pixels = image_processor_instance(images=image_input, return_tensors="pt").pixel_values
    pixels = pixels.to(device_to_use)

    encoded = tokenizer_instance(
        text=text_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=getattr(tokenizer_instance, 'model_max_length', 512),
    )

    return {
        "pixel_values": pixels,
        "input_ids": encoded.input_ids.to(device_to_use),
        "attention_mask": encoded.attention_mask.to(device_to_use),
    }

def _strip_prompt_echo(generated_text: str, prompt: str) -> str:
    """Drop a leading copy of *prompt* (plus trailing " ,.:" separators) from *generated_text*."""
    if prompt and generated_text.startswith(prompt):
        return generated_text[len(prompt):].lstrip(" ,.:")
    return generated_text


def generate_text_for_image(image_input: Optional[PILImage.Image], prompt_input: Optional[str]) -> str:
    """Gradio callback: run the loaded model on one image and prompt.

    Every failure is reported as a user-facing string rather than raised, so
    the UI always gets something to display.

    Args:
        image_input: The uploaded image. A PIL image, or an array accepted by
            ``PIL.Image.fromarray``. It is converted to RGB before preprocessing.
        prompt_input: The user's prompt or question. Empty or whitespace-only
            prompts are rejected.

    Returns:
        The generated text, with any echoed prompt removed and surrounding
        whitespace stripped, or a validation/error message.
    """
    if model is None or image_processor is None or tokenizer is None:
        return "Error: Model or processor components not loaded correctly. Check logs."
    if image_input is None:
        return "Please upload an image."
    # A whitespace-only prompt carries no instruction; treat it like an empty one.
    if not prompt_input or not prompt_input.strip():
        return "Please provide a prompt."

    try:
        pil_image = image_input
        if not isinstance(pil_image, PILImage.Image):
            pil_image = PILImage.fromarray(pil_image)
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        inputs = prepare_inputs(
            text_list=[prompt_input], image_input=pil_image,
            image_processor_instance=image_processor, tokenizer_instance=tokenizer, device_to_use=device,
        )

        print(f"Debug: Shapes before model.generate: pixel_values={inputs['pixel_values'].shape}, input_ids={inputs['input_ids'].shape}, attention_mask={inputs['attention_mask'].shape}")

        # Positional order follows generate(input_ids, image, attention_mask=None,
        # max_new_tokens=...). Optional sampling kwargs (top_k, top_p,
        # temperature, greedy) can be added here. inference_mode turns off
        # autograd tracking, which is never needed for serving; this is
        # harmless if generate() already disables gradients itself.
        with torch.inference_mode():
            generated_ids = model.generate(
                inputs['input_ids'],
                inputs['pixel_values'],
                inputs['attention_mask'],
                max_new_tokens=150,
            )

        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = decoded[0] if decoded else ""
        return _strip_prompt_echo(generated_text, prompt_input).strip()

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        return f"An error occurred during text generation: {str(e)}"

description = "Interactive demo for lusxvr/nanoVLM-222M."

# Build the Gradio UI: an image and a prompt in, generated text out. If
# construction fails for any reason, ``iface`` stays None so the launcher can
# report it instead of the module crashing at import time.
print("Defining Gradio interface...")
iface = None
try:
    iface = gr.Interface(
        fn=generate_text_for_image,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Textbox(label="Your Prompt/Question"),
        ],
        outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
        title="Interactive nanoVLM-222M Demo",
        description=description,
        allow_flagging="never",
    )
except Exception as ui_err:
    print(f"Error defining Gradio interface: {ui_err}")
    import traceback
    traceback.print_exc()
    iface = None
else:
    print("Gradio interface defined.")


if __name__ == "__main__":
    # Warn loudly (but still try to serve) if any pipeline component is missing;
    # the request handler then answers with an error message.
    if any(component is None for component in (model, image_processor, tokenizer)):
        print("CRITICAL: Model or processor components failed to load. Gradio might not work.")

    if iface is None:
        print("Gradio interface could not be defined due to earlier errors.")
    else:
        print("Launching Gradio interface...")
        try:
            # Bind on all interfaces so the app is reachable from outside a container.
            iface.launch(server_name="0.0.0.0", server_port=7860)
        except Exception as launch_err:
            # Launch errors are often a side effect of problems during
            # interface construction; the traceback shows the root cause.
            print(f"Error launching Gradio interface: {launch_err}")
            import traceback
            traceback.print_exc()