File size: 6,197 Bytes
e198913
 
 
 
97c8139
e198913
 
 
4670dfa
 
 
16bf2d1
 
e198913
16bf2d1
e198913
 
97c8139
e198913
16bf2d1
e198913
4670dfa
 
 
 
 
 
 
 
 
978a6b3
 
 
a4644a0
 
978a6b3
16bf2d1
 
fbe5121
 
97c8139
e198913
978a6b3
 
16bf2d1
 
978a6b3
 
16bf2d1
a4644a0
16bf2d1
 
e198913
978a6b3
16bf2d1
978a6b3
 
16bf2d1
97c8139
16bf2d1
e198913
 
16bf2d1
978a6b3
a4644a0
16bf2d1
 
e198913
 
97c8139
e198913
978a6b3
16bf2d1
 
 
978a6b3
16bf2d1
 
978a6b3
16bf2d1
 
 
 
 
 
4670dfa
16bf2d1
 
4670dfa
 
 
 
e198913
4670dfa
 
 
 
 
 
 
 
 
 
16bf2d1
978a6b3
 
16bf2d1
 
 
 
fbe5121
4670dfa
16bf2d1
 
 
fbe5121
 
4670dfa
16bf2d1
978a6b3
4670dfa
 
16bf2d1
 
 
fbe5121
4670dfa
 
 
 
 
 
 
 
16bf2d1
978a6b3
fbe5121
4670dfa
e198913
 
fbe5121
4670dfa
 
 
 
e198913
4670dfa
 
 
 
 
 
 
 
a4644a0
fbe5121
4670dfa
 
 
16bf2d1
 
fbe5121
 
e198913
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import sys
import os

# Make the cloned nanoVLM repository importable (it provides the
# `models.vision_language_model` package used below). The location defaults to
# the container path "/app/nanoVLM" but can be overridden with the
# NANOVLM_REPO_PATH environment variable for local runs outside the container.
NANOVLM_REPO_PATH = os.environ.get("NANOVLM_REPO_PATH", "/app/nanoVLM")
if NANOVLM_REPO_PATH not in sys.path:
    # Prepend so the clone wins over any same-named installed package.
    sys.path.insert(0, NANOVLM_REPO_PATH)

import gradio as gr
from PIL import Image
import torch
# Import specific processor components
from transformers import CLIPImageProcessor, GPT2TokenizerFast

# Import the model class from the nanoVLM clone. A failed import is not fatal:
# the name is bound to None so later code can detect it and skip loading.
try:
    from models.vision_language_model import VisionLanguageModel
except ImportError as import_error:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {import_error}.")
    VisionLanguageModel = None
else:
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")

# Pick the torch device. The DEVICE environment variable is used verbatim
# unless it is unset or "auto", in which case CUDA is preferred when present.
device_choice = os.environ.get("DEVICE", "auto")
device = device_choice
if device_choice == "auto":
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
print(f"Using device: {device}")

# --- Configuration for model components ---
model_id_for_weights = "lusxvr/nanoVLM-222M"
image_processor_id = "openai/clip-vit-base-patch32"
# The tokenizer is loaded from the upstream "gpt2" repo (not from
# model_id_for_weights) so that all tokenizer files are guaranteed to exist.
# NOTE(review): the image processor (CLIP) and tokenizer (GPT-2) are picked
# independently of the checkpoint; confirm they match the vision encoder and
# language model lusxvr/nanoVLM-222M was trained with, otherwise the model is
# fed out-of-distribution pixel values / token ids.
tokenizer_id = "gpt2"

# Module-level handles read by the Gradio callback. They are all reset to None
# if any loading step fails, so the UI reports the problem instead of crashing.
image_processor = None
tokenizer = None
model = None

if VisionLanguageModel is not None:
    try:
        print(f"Attempting to load CLIPImageProcessor from: {image_processor_id}")
        # No trust_remote_code here: this is a concrete transformers class, so
        # no custom Hub code is resolved and the flag would only be misleading.
        image_processor = CLIPImageProcessor.from_pretrained(image_processor_id)
        print("CLIPImageProcessor loaded.")

        print(f"Attempting to load GPT2TokenizerFast from: {tokenizer_id}")
        tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_id)
        if tokenizer.pad_token is None:
            # GPT-2 ships without a pad token; padding=True in prepare_inputs
            # needs one, so reuse EOS.
            tokenizer.pad_token = tokenizer.eos_token
            print("Set tokenizer pad_token to eos_token.")
        print("GPT2TokenizerFast loaded.")

        print(f"Attempting to load model weights from {model_id_for_weights} using VisionLanguageModel.from_pretrained")
        # NOTE(review): this is nanoVLM's own from_pretrained, not the
        # transformers one; confirm it accepts (or ignores) trust_remote_code.
        model = VisionLanguageModel.from_pretrained(
            model_id_for_weights,
            trust_remote_code=True
        ).to(device)
        print("Model loaded successfully.")
        model.eval()  # disable dropout etc. for inference

    except Exception as e:
        # Top-level boundary: log the full traceback, then leave every handle
        # as None so partial state (e.g. tokenizer loaded, model not) is never used.
        print(f"Error loading model or processor components: {e}")
        import traceback
        traceback.print_exc()
        image_processor = None
        tokenizer = None
        model = None
else:
    print("Custom VisionLanguageModel class not imported, cannot load model.")

def prepare_inputs(text_list, image_input, image_processor_instance, tokenizer_instance, device_to_use):
    """Turn an image and a batch of prompts into model-ready tensors on a device.

    Args:
        text_list: List of prompt strings; padded and truncated to the
            tokenizer's ``model_max_length``.
        image_input: Image passed straight to ``image_processor_instance``.
        image_processor_instance: Callable image processor returning an object
            with ``pixel_values``.
        tokenizer_instance: Callable tokenizer returning an object with
            ``input_ids`` and ``attention_mask``.
        device_to_use: Target passed to each tensor's ``.to()``.

    Returns:
        Dict with keys ``pixel_values``, ``input_ids`` and ``attention_mask``.

    Raises:
        ValueError: If the image processor or the tokenizer is None.
    """
    components_missing = image_processor_instance is None or tokenizer_instance is None
    if components_missing:
        raise ValueError("Image processor or tokenizer not initialized.")

    vision_batch = image_processor_instance(images=image_input, return_tensors="pt")
    text_batch = tokenizer_instance(
        text=text_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=tokenizer_instance.model_max_length,
    )

    return {
        "pixel_values": vision_batch.pixel_values.to(device_to_use),
        "input_ids": text_batch.input_ids.to(device_to_use),
        "attention_mask": text_batch.attention_mask.to(device_to_use),
    }

def generate_text_for_image(image_input, prompt_input):
    """Run the loaded VLM on one image/prompt pair and return the generated text.

    This is the Gradio callback, so every failure mode is reported as a
    user-facing string instead of being raised.

    Args:
        image_input: A ``PIL.Image.Image`` or an array accepted by
            ``Image.fromarray``. Non-RGB images are converted to RGB.
        prompt_input: The prompt/question text. Empty or whitespace-only
            prompts are rejected.

    Returns:
        The generated text with the echoed prompt (and leading " ,.:"
        separators) removed and surrounding whitespace stripped, or a
        validation/error message.
    """
    if model is None or image_processor is None or tokenizer is None:
        return "Error: Model or processor components not loaded correctly. Check logs."

    if image_input is None:
        return "Please upload an image."
    # A prompt of only spaces/newlines carries no content; treat it as missing.
    if not prompt_input or not prompt_input.strip():
        return "Please provide a prompt."

    try:
        pil_image = image_input if isinstance(image_input, Image.Image) else Image.fromarray(image_input)
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        inputs = prepare_inputs(
            text_list=[prompt_input],
            image_input=pil_image,
            image_processor_instance=image_processor,
            tokenizer_instance=tokenizer,
            device_to_use=device
        )

        # Pure inference: disable autograd so no computation graph or saved
        # activations are kept alive during (beam-search) generation.
        # NOTE(review): these are HF-style generate kwargs; confirm nanoVLM's
        # VisionLanguageModel.generate accepts pixel_values/num_beams/etc.
        with torch.no_grad():
            generated_ids = model.generate(
                pixel_values=inputs['pixel_values'],
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=150,
                num_beams=3,
                no_repeat_ngram_size=2,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id
            )

        generated_text_list = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = generated_text_list[0] if generated_text_list else ""

        # The decoded sequence may start with the prompt itself; drop it plus
        # any separator punctuation so only the continuation is shown.
        if generated_text.startswith(prompt_input):
            generated_text = generated_text[len(prompt_input):].lstrip(" ,.:")

        return generated_text.strip()

    except Exception as e:
        # UI boundary: log the traceback for operators, show a message to users.
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        return f"An error occurred during text generation: {str(e)}"

description = "Interactive demo for lusxvr/nanoVLM-222M."
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Input/output components, created in the same order the Interface lists them.
_demo_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(label="Your Prompt/Question"),
]
_demo_output = gr.Textbox(label="Generated Text", show_copy_button=True)

# Each example row is [image, prompt], matching the order of _demo_inputs.
_demo_examples = [
    [example_image_url, "a photo of a"],
    [example_image_url, "Describe the image in detail."],
]

# NOTE: cache_examples is intentionally left unset so Gradio starts with a
# minimal configuration.
iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=_demo_inputs,
    outputs=_demo_output,
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=_demo_examples,
    allow_flagging="never"
)

if __name__ == "__main__":
    # The server is launched even when loading failed; in that case the
    # callback returns an error string for every request.
    _all_components_loaded = all(
        part is not None for part in (model, image_processor, tokenizer)
    )
    print(
        "Launching Gradio interface..."
        if _all_components_loaded
        else "CRITICAL: Model or processor components failed to load."
    )
    iface.launch(server_name="0.0.0.0", server_port=7860)