import sys
import os

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    sys.path.insert(0, NANOVLM_REPO_PATH)

import gradio as gr
from PIL import Image
import torch

# Import specific processor components
from transformers import CLIPImageProcessor, GPT2TokenizerFast

# Import the custom VisionLanguageModel class
try:
    from models.vision_language_model import VisionLanguageModel
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {e}")
    VisionLanguageModel = None
# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")
# Load the model and processor components
model_id = "lusxvr/nanoVLM-222M"
image_processor = None
tokenizer = None
model = None

if VisionLanguageModel:
    try:
        print(f"Attempting to load specific processor components for {model_id}")

        # Load the image processor
        image_processor = CLIPImageProcessor.from_pretrained(model_id, trust_remote_code=True)
        print("CLIPImageProcessor loaded.")

        # Load the tokenizer
        tokenizer = GPT2TokenizerFast.from_pretrained(model_id, trust_remote_code=True)
        # Add a padding token if it's not already there (common for GPT2)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print("Set tokenizer pad_token to eos_token.")
        print("GPT2TokenizerFast loaded.")

        print(f"Attempting to load model {model_id} using VisionLanguageModel.from_pretrained")
        model = VisionLanguageModel.from_pretrained(
            model_id,
            trust_remote_code=True  # Allows custom model code to run
            # The VisionLanguageModel might need image_processor and tokenizer passed during init,
            # or it might retrieve them from its config. Check its __init__ if issues persist.
            # For now, assume it gets them from config or they are not strictly needed at init.
        ).to(device)
        print("Model loaded successfully.")
        model.eval()
    except Exception as e:
        print(f"Error loading model or processor components: {e}")
        image_processor = None
        tokenizer = None
        model = None
else:
    print("Custom VisionLanguageModel class not imported, cannot load model.")
# Define a simple processor-like function for preparing inputs
def prepare_inputs(text, image, image_processor_instance, tokenizer_instance, device_to_use):
    if image_processor_instance is None or tokenizer_instance is None:
        raise ValueError("Image processor or tokenizer not initialized.")

    # Process the image
    processed_image = image_processor_instance(images=image, return_tensors="pt").pixel_values.to(device_to_use)

    # Process the text
    # Ensure padding is handled correctly for batching (even if the batch size is 1)
    processed_text = tokenizer_instance(
        text=text, return_tensors="pt", padding=True, truncation=True
    )
    input_ids = processed_text.input_ids.to(device_to_use)
    attention_mask = processed_text.attention_mask.to(device_to_use)

    return {"pixel_values": processed_image, "input_ids": input_ids, "attention_mask": attention_mask}
def generate_text_for_image(image_input, prompt_input):
    if model is None or image_processor is None or tokenizer is None:
        return "Error: Model or processor components not loaded correctly. Check logs."
    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt."

    try:
        # Make sure we are working with an RGB PIL image
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        # Use our custom input preparation function
        inputs = prepare_inputs(
            text=[prompt_input],  # expects a list of text prompts
            image=pil_image,      # expects a single PIL image or a list
            image_processor_instance=image_processor,
            tokenizer_instance=tokenizer,
            device_to_use=device
        )

        # Generate text using the model's generate method
        generated_ids = model.generate(
            pixel_values=inputs['pixel_values'],
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=150,
            num_beams=3,
            no_repeat_ngram_size=2,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id  # important for generation
        )

        # Decode the generated tokens
        # skip_special_tokens=True removes special tokens like <|endoftext|>
        generated_text_list = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = generated_text_list[0] if generated_text_list else ""

        # Basic cleaning of the prompt if the model includes it in the output
        if prompt_input and generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text

        return cleaned_text.strip()
    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()  # print the full traceback for debugging
        return f"An error occurred during text generation: {str(e)}"
description = "Interactive demo for lusxvr/nanoVLM-222M."
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
gradio_cache_dir = os.environ.get("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")

iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
    ],
    cache_examples=True,
    examples_cache_folder=gradio_cache_dir,
    allow_flagging="never"
)
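# gradio_cache_dir is read from GRADIO_TEMP_DIR (defaulting to /tmp/gradio_tmp), the
# intent being that cached example files land in a directory that should be writable
# inside the Space container.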
if __name__ == "__main__":
    if model is None or image_processor is None or tokenizer is None:
        print("CRITICAL: Model or processor components failed to load.")
    else:
        print("Launching Gradio interface...")
        iface.launch(server_name="0.0.0.0", server_port=7860)
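# server_name="0.0.0.0" and server_port=7860 match what a Docker-based Hugging Face
# Space expects: the app must listen on all interfaces on port 7860.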