import sys
import os
# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    sys.path.insert(0, NANOVLM_REPO_PATH)
import gradio as gr
from PIL import Image
import torch
# Import specific processor components
from transformers import CLIPImageProcessor, GPT2TokenizerFast
# Import the custom VisionLanguageModel class
try:
    from models.vision_language_model import VisionLanguageModel
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {e}.")
    VisionLanguageModel = None
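# Falling back to None lets the app start anyway; the failure is then reported
# through the UI (see generate_text_for_image) rather than crashing at import time.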
# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")
# Load the model and processor components
model_id = "lusxvr/nanoVLM-222M"
image_processor = None
tokenizer = None
model = None
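# These stay None if anything below fails; generate_text_for_image checks for that.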
if VisionLanguageModel:
    try:
        print(f"Attempting to load specific processor components for {model_id}")
        # Load the image processor
        image_processor = CLIPImageProcessor.from_pretrained(model_id, trust_remote_code=True)
        print("CLIPImageProcessor loaded.")
        # Load the tokenizer
        tokenizer = GPT2TokenizerFast.from_pretrained(model_id, trust_remote_code=True)
        # Add a padding token if it's not already there (common for GPT2)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print("Set tokenizer pad_token to eos_token.")
        print("GPT2TokenizerFast loaded.")
        print(f"Attempting to load model {model_id} using VisionLanguageModel.from_pretrained")
        model = VisionLanguageModel.from_pretrained(
            model_id,
            trust_remote_code=True  # Allows custom model code to run
            # The VisionLanguageModel might need image_processor and tokenizer passed during init,
            # or it might retrieve them from its config. Check its __init__ if issues persist.
            # For now, assume it gets them from config or they are not strictly needed at init.
        ).to(device)
        print("Model loaded successfully.")
        model.eval()
    except Exception as e:
        print(f"Error loading model or processor components: {e}")
        image_processor = None
        tokenizer = None
        model = None
else:
    print("Custom VisionLanguageModel class not imported, cannot load model.")
# Define a simple processor-like function for preparing inputs
def prepare_inputs(text, image, image_processor_instance, tokenizer_instance, device_to_use):
    if image_processor_instance is None or tokenizer_instance is None:
        raise ValueError("Image processor or tokenizer not initialized.")
    # Process image
    processed_image = image_processor_instance(images=image, return_tensors="pt").pixel_values.to(device_to_use)
    # Process text
    # Ensure padding is handled correctly for batching (even if batch size is 1)
    processed_text = tokenizer_instance(
        text=text, return_tensors="pt", padding=True, truncation=True
    )
    input_ids = processed_text.input_ids.to(device_to_use)
    attention_mask = processed_text.attention_mask.to(device_to_use)
    return {"pixel_values": processed_image, "input_ids": input_ids, "attention_mask": attention_mask}
def generate_text_for_image(image_input, prompt_input):
    if model is None or image_processor is None or tokenizer is None:
        return "Error: Model or processor components not loaded correctly. Check logs."
    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt."
    try:
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        # Use our custom input preparation function
        inputs = prepare_inputs(
            text=[prompt_input],  # Expects a list of text prompts
            image=pil_image,      # Expects a single PIL image or list
            image_processor_instance=image_processor,
            tokenizer_instance=tokenizer,
            device_to_use=device
        )
        # Generate text using the model's generate method
        generated_ids = model.generate(
            pixel_values=inputs['pixel_values'],
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=150,
            num_beams=3,
            no_repeat_ngram_size=2,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id  # Important for generation
        )
        # Decode the generated tokens
        # skip_special_tokens=True removes special tokens like <|endoftext|>
        generated_text_list = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = generated_text_list[0] if generated_text_list else ""
        # Basic cleaning of the prompt if the model includes it in the output
        if prompt_input and generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text
        return cleaned_text.strip()
    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()  # Print full traceback for debugging
        return f"An error occurred during text generation: {str(e)}"
description = "Interactive demo for lusxvr/nanoVLM-222M."
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Point Gradio at a writable temp/cache directory; Gradio reads GRADIO_TEMP_DIR itself.
os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")
iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
    ],
    cache_examples=True,  # Gradio manages the example-cache directory itself
    allow_flagging="never"
)
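# Note: with cache_examples=True, Gradio runs the example inputs through the model once
# at startup to pre-compute their outputs, which adds to the Space's startup time on CPU.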
if __name__ == "__main__":
    if model is None or image_processor is None or tokenizer is None:
        print("CRITICAL: Model or processor components failed to load.")
    else:
        print("Launching Gradio interface...")
    iface.launch(server_name="0.0.0.0", server_port=7860)
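# To run locally (assuming the nanoVLM repo is cloned to /app/nanoVLM and gradio, torch,
# transformers, and pillow are installed): `python app.py`, then open http://localhost:7860.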