import sys
import os
from typing import Optional  # For type hinting
from PIL import Image as PILImage  # Use an alias to avoid conflict with gr.Image

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    print(f"DEBUG: Adding {NANOVLM_REPO_PATH} to sys.path")
    sys.path.insert(0, NANOVLM_REPO_PATH)
else:
    print(f"DEBUG: {NANOVLM_REPO_PATH} already in sys.path")

import gradio as gr
import torch
from transformers import CLIPImageProcessor, GPT2TokenizerFast

# Import the custom VisionLanguageModel class
VisionLanguageModel = None  # Initialize to None
try:
    print("DEBUG: Attempting to import VisionLanguageModel from models.vision_language_model")
    from models.vision_language_model import VisionLanguageModel
    print("DEBUG: Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"CRITICAL ERROR: Error importing VisionLanguageModel from nanoVLM clone: {e}.")
    print("DEBUG: Please ensure /app/nanoVLM/models/vision_language_model.py exists and is correct.")
    # No need to exit here; the checks later will handle it.
except Exception as e:
    print(f"CRITICAL ERROR: An unexpected error occurred during VisionLanguageModel import: {e}")

# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"DEBUG: Using device: {device}")

# --- Configuration for model components ---
model_id_for_weights = "lusxvr/nanoVLM-222M"
image_processor_id = "openai/clip-vit-base-patch32"
tokenizer_id = "gpt2"  # Using the canonical gpt2 tokenizer

print(f"DEBUG: Configuration - model_id_for_weights: {model_id_for_weights}")
print(f"DEBUG: Configuration - image_processor_id: {image_processor_id}")
print(f"DEBUG: Configuration - tokenizer_id: {tokenizer_id}")

image_processor = None
tokenizer = None
model = None

# --- Load Processor and Model ---
if VisionLanguageModel is not None:  # Only proceed if the custom model class was imported
    try:
        print(f"DEBUG: Attempting to load CLIPImageProcessor from: {image_processor_id}")
        image_processor = CLIPImageProcessor.from_pretrained(image_processor_id)
        print(f"DEBUG: CLIPImageProcessor loaded: {type(image_processor)}")

        print(f"DEBUG: Attempting to load GPT2TokenizerFast from: {tokenizer_id}")
        tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_id)
        if tokenizer.pad_token is None:
            # GPT-2 has no dedicated pad token; reuse EOS so padded batches work
            tokenizer.pad_token = tokenizer.eos_token
            print(f"DEBUG: Set tokenizer pad_token to eos_token (ID: {tokenizer.eos_token_id})")
        print(f"DEBUG: GPT2TokenizerFast loaded: {type(tokenizer)}, vocab_size: {tokenizer.vocab_size}")

        print(f"DEBUG: Attempting to load model weights from {model_id_for_weights} using VisionLanguageModel.from_pretrained")
        # Note: the custom VisionLanguageModel.from_pretrained in nanoVLM does not take trust_remote_code
        model = VisionLanguageModel.from_pretrained(model_id_for_weights).to(device)
        print(f"DEBUG: Model loaded successfully: {type(model)}")
        model.eval()
        print("DEBUG: Model set to evaluation mode (model.eval())")

        # Optional: Print the model's state_dict keys (can be very long)
        # print("DEBUG: Model state_dict keys (first 10):", list(model.state_dict().keys())[:10])
        # print(f"DEBUG: Is model on device '{device}'? {next(model.parameters()).device}")
    except Exception as e:
        print(f"CRITICAL ERROR: Error loading model or processor components: {e}")
        import traceback
        traceback.print_exc()
        # Reset to ensure generate_text_for_image knows they failed
        image_processor = None
        tokenizer = None
        model = None
else:
    print("CRITICAL ERROR: Custom VisionLanguageModel class not imported. Cannot load model.")

# --- Input Preparation Function ---
def prepare_inputs(text_list, image_input, image_processor_instance, tokenizer_instance, device_to_use):
    print(f"DEBUG (prepare_inputs): Received text_list: {text_list}")
    if image_processor_instance is None or tokenizer_instance is None:
        print("ERROR (prepare_inputs): Image processor or tokenizer not initialized.")
        raise ValueError("Image processor or tokenizer not initialized.")

    # Process the image
    print(f"DEBUG (prepare_inputs): Processing image with {type(image_processor_instance)}")
    processed_image_output = image_processor_instance(images=image_input, return_tensors="pt")
    pixel_values = processed_image_output.pixel_values.to(device_to_use)
    print(f"DEBUG (prepare_inputs): pixel_values shape: {pixel_values.shape}, dtype: {pixel_values.dtype}")

    # Process the text
    print(f"DEBUG (prepare_inputs): Processing text with {type(tokenizer_instance)}")
    # Use model_max_length from the tokenizer, with a fallback of 512
    max_len = getattr(tokenizer_instance, 'model_max_length', 512)
print(f"DEBUG (prepare_inputs): Tokenizer max_length: {max_len}") | |
    processed_text_output = tokenizer_instance(
        text=text_list, return_tensors="pt", padding=True, truncation=True, max_length=max_len
    )
    input_ids = processed_text_output.input_ids.to(device_to_use)
    attention_mask = processed_text_output.attention_mask.to(device_to_use)
    print(f"DEBUG (prepare_inputs): input_ids shape: {input_ids.shape}, dtype: {input_ids.dtype}, values: {input_ids}")
    print(f"DEBUG (prepare_inputs): attention_mask shape: {attention_mask.shape}, dtype: {attention_mask.dtype}, values: {attention_mask}")
    return {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}

# --- Text Generation Function ---
def generate_text_for_image(image_input_pil: Optional[PILImage.Image], prompt_input_str: Optional[str]) -> str:
    print(f"DEBUG (generate_text_for_image): Received prompt: '{prompt_input_str}'")
    if model is None or image_processor is None or tokenizer is None:
        print("ERROR (generate_text_for_image): Model or processor components not loaded.")
        return "Error: Model or processor components not loaded correctly. Check application logs."
    if image_input_pil is None:
        print("WARN (generate_text_for_image): No image uploaded.")
        return "Please upload an image."
    if not prompt_input_str:
        print("WARN (generate_text_for_image): No prompt provided.")
        return "Please provide a prompt (e.g., 'a photo of a')."

    try:
        print("DEBUG (generate_text_for_image): Preparing image...")
        current_pil_image = image_input_pil  # Gradio provides a PIL image when type="pil"
        if not isinstance(current_pil_image, PILImage.Image):
            print(f"WARN (generate_text_for_image): Input image not PIL, type: {type(current_pil_image)}. Converting.")
            current_pil_image = PILImage.fromarray(current_pil_image)  # Fallback for array inputs
        if current_pil_image.mode != "RGB":
            print(f"DEBUG (generate_text_for_image): Converting image from mode {current_pil_image.mode} to RGB.")
            current_pil_image = current_pil_image.convert("RGB")
        print(f"DEBUG (generate_text_for_image): Image size: {current_pil_image.size}, mode: {current_pil_image.mode}")

        print("DEBUG (generate_text_for_image): Preparing inputs for the model...")
        inputs_dict = prepare_inputs(
            text_list=[prompt_input_str], image_input=current_pil_image,
            image_processor_instance=image_processor, tokenizer_instance=tokenizer, device_to_use=device
        )

        print(f"DEBUG (generate_text_for_image): Calling model.generate with input_ids (shape {inputs_dict['input_ids'].shape}), pixel_values (shape {inputs_dict['pixel_values'].shape}), attention_mask (shape {inputs_dict['attention_mask'].shape})")
        # Match the nanoVLM signature: generate(self, input_ids, image, attention_mask=None, max_new_tokens=...)
        generated_ids_tensor = model.generate(
            inputs_dict['input_ids'],       # 1st argument: input_ids (text prompt)
            inputs_dict['pixel_values'],    # 2nd argument: image (pixel values)
            inputs_dict['attention_mask'],  # 3rd argument: attention_mask (for text)
            max_new_tokens=30,              # A smaller value for quicker debugging
            temperature=0.8,                # Slightly higher temperature to encourage diversity
            top_k=50,                       # As per the nanoVLM signature default
            top_p=0.9,                      # As per the nanoVLM signature default
            greedy=False                    # Sample rather than decode greedily
        )
print(f"DEBUG (generate_text_for_image): Raw generated_ids tensor: {generated_ids_tensor}") | |
# Decode the generated tokens | |
print("DEBUG (generate_text_for_image): Decoding generated tokens...") | |
generated_text_list_decoded = tokenizer.batch_decode(generated_ids_tensor, skip_special_tokens=True) | |
print(f"DEBUG (generate_text_for_image): Decoded text list (before join/cleanup): {generated_text_list_decoded}") | |
generated_text_str = generated_text_list_decoded[0] if generated_text_list_decoded else "" | |
# Optional: Clean up prompt if it's echoed by the model | |
cleaned_text_str = generated_text_str | |
if prompt_input_str and generated_text_str.startswith(prompt_input_str): | |
print("DEBUG (generate_text_for_image): Prompt found at the beginning of generation, removing it.") | |
cleaned_text_str = generated_text_str[len(prompt_input_str):].lstrip(" ,.:") | |
print(f"DEBUG (generate_text_for_image): Final cleaned text to be returned: '{cleaned_text_str}'") | |
return cleaned_text_str.strip() | |
except Exception as e: | |
print(f"CRITICAL ERROR (generate_text_for_image): An error occurred during generation: {e}") | |
import traceback | |
traceback.print_exc() # Print full traceback to logs | |
return f"An error occurred during text generation: {str(e)}. Check application logs." | |

# --- Gradio Interface Definition ---
description_md = """
## Interactive nanoVLM-222M Demo
Upload an image and provide a text prompt (e.g., "What is in this image?", "Describe the animal in detail.").
The model will attempt to generate a textual response based on the visual content and your query.
This Space uses the `lusxvr/nanoVLM-222M` model with code from the original `huggingface/nanoVLM` repository.
"""

# example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # Not used currently

print("DEBUG: Defining Gradio interface...")
iface = None
try:
    iface = gr.Interface(
        fn=generate_text_for_image,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),  # type="pil" ensures a PIL.Image object
            gr.Textbox(label="Your Prompt / Question", info="e.g., 'a photo of a', 'Describe this scene.'")
        ],
        outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
        title="nanoVLM-222M Interactive Demo",
        description=description_md,
        # examples=[  # Examples commented out to simplify the Gradio setup
        #     [example_image_url, "a photo of a"],
        #     [example_image_url, "Describe the image in detail."],
        # ],
        # cache_examples=False,  # Explicitly False, or remove the argument
        allow_flagging="never"  # Keep flagging disabled
    )
    print("DEBUG: Gradio interface defined successfully.")
except Exception as e:
    print(f"CRITICAL ERROR: Error defining Gradio interface: {e}")
    import traceback
    traceback.print_exc()
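
# To re-enable the commented-out examples above (a sketch; it also requires
# uncommenting example_image_url, and assumes the Space can fetch the remote
# COCO URL at startup):
#   examples=[[example_image_url, "a photo of a"]],
#   cache_examples=False,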

# --- Launch Gradio App ---
if __name__ == "__main__":
    print("DEBUG: Entered __main__ block.")
    if VisionLanguageModel is None:
        print("CRITICAL ERROR: VisionLanguageModel class was not imported. Cannot proceed.")
    elif model is None or image_processor is None or tokenizer is None:
        print("CRITICAL ERROR: Model, image_processor, or tokenizer failed to load. Gradio app might not be fully functional.")

    if iface is not None:
        print("DEBUG: Attempting to launch Gradio interface...")
        try:
            iface.launch(server_name="0.0.0.0", server_port=7860)  # Standard host/port for Spaces
            print("DEBUG: Gradio launch command issued.")  # May not be reached if launch() blocks or errors immediately
        except Exception as e:
            print(f"CRITICAL ERROR: Error launching Gradio interface: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("CRITICAL ERROR: Gradio interface (iface) is None. Cannot launch.")