import sys
import os

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    sys.path.insert(0, NANOVLM_REPO_PATH)
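# This path assumes the Space's build step has already cloned the repository,
# e.g. (illustrative) `RUN git clone https://github.com/huggingface/nanoVLM.git /app/nanoVLM`
# in the Dockerfile; adjust NANOVLM_REPO_PATH if the clone lives elsewhere.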
import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor  # AutoProcessor ships with transformers, so it loads without the clone

# Import the custom VisionLanguageModel class from the cloned nanoVLM repository
try:
    from models.vision_language_model import VisionLanguageModel
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {e}. "
          "Check NANOVLM_REPO_PATH and ensure nanoVLM was cloned correctly.")
    VisionLanguageModel = None
# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")
# Load the model and processor
model_id = "lusxvr/nanoVLM-222M"
processor = None
model = None

if VisionLanguageModel:
    try:
        print(f"Attempting to load processor for {model_id}")
        # trust_remote_code=True lets the processor pull in any custom code its
        # config references; processors rarely need it, but it is harmless here.
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        print("Processor loaded.")

        print(f"Attempting to load model {model_id} using VisionLanguageModel.from_pretrained")
        # VisionLanguageModel.from_pretrained is the custom classmethod from the
        # nanoVLM clone; it loads its own configuration (config.json) and weights
        # from the model_id repository. trust_remote_code is a transformers-specific
        # argument and is not needed here, since the custom class is already
        # imported locally.
        model = VisionLanguageModel.from_pretrained(model_id).to(device)
        print("Model loaded successfully.")
        model.eval()  # Set to evaluation mode for inference
    except Exception as e:
        print(f"Error loading model or processor: {e}")
        processor = None
        model = None
else:
    print("Custom VisionLanguageModel class not imported, cannot load model.")
def generate_text_for_image(image_input, prompt_input):
    if model is None or processor is None:
        return "Error: Model or processor not loaded correctly. Check logs."
    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt."

    try:
        # Gradio may hand over a NumPy array; normalize to an RGB PIL image
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        inputs = processor(text=[prompt_input], images=[pil_image], return_tensors="pt").to(device)

        # Call the generate method of the VisionLanguageModel instance. If this
        # fails, check the signature of generate in
        # nanoVLM/models/vision_language_model.py: the custom class may expect
        # the image tensor and token ids under different argument names.
        generated_ids = model.generate(
            pixel_values=inputs.get('pixel_values'),
            input_ids=inputs.get('input_ids'),
            attention_mask=inputs.get('attention_mask'),
            max_new_tokens=150,
            num_beams=3,
            no_repeat_ngram_size=2,
            early_stopping=True
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # The decoded text usually echoes the prompt; strip it for a cleaner answer
        if prompt_input and generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text
        return cleaned_text.strip()
    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred during text generation: {str(e)}"
description = "Interactive demo for lusxvr/nanoVLM-222M."
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Gradio honors GRADIO_TEMP_DIR for files it writes at runtime; default it to a
# location that is writable inside the Space's container.
os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")
iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
    ],
    cache_examples=True,  # gr.Interface has no examples_cache_folder argument; Gradio manages the cache location itself
    allow_flagging="never"
)
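# Note: newer Gradio releases deprecate allow_flagging in favor of flagging_mode;
# if a deprecation warning appears after upgrading, switch to flagging_mode="never".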
if __name__ == "__main__":
    if model is None or processor is None:
        print("CRITICAL: Model or processor failed to load. Gradio interface may not function correctly.")
    else:
        print("Launching Gradio interface...")
    # Launch regardless so the Space surfaces the error message in the UI;
    # 0.0.0.0:7860 is the address Hugging Face Spaces expects apps to listen on.
    iface.launch(server_name="0.0.0.0", server_port=7860)