import sys
import os

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    sys.path.insert(0, NANOVLM_REPO_PATH)
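
# Assumption: the nanoVLM repository has already been cloned to /app/nanoVLM,
# e.g. during the Space's Docker build; this script does not clone it itself.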

import gradio as gr
from PIL import Image
import torch
from transformers import CLIPImageProcessor, GPT2TokenizerFast

try:
    from models.vision_language_model import VisionLanguageModel
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {e}.")
    VisionLanguageModel = None

device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")

model_id_for_weights = "lusxvr/nanoVLM-222M"
image_processor_id = "openai/clip-vit-base-patch32"
tokenizer_id = "gpt2"

image_processor = None
tokenizer = None
model = None

if VisionLanguageModel:
    try:
        print(f"Attempting to load CLIPImageProcessor from: {image_processor_id}")
        image_processor = CLIPImageProcessor.from_pretrained(image_processor_id)
        print("CLIPImageProcessor loaded.")

        print(f"Attempting to load GPT2TokenizerFast from: {tokenizer_id}")
        tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_id)
        if tokenizer.pad_token is None:
            # GPT-2 has no pad token by default; reuse EOS so padding works.
            tokenizer.pad_token = tokenizer.eos_token
            print("Set tokenizer pad_token to eos_token.")
        print("GPT2TokenizerFast loaded.")

        print(f"Attempting to load model weights from {model_id_for_weights} using VisionLanguageModel.from_pretrained")
        # This from_pretrained is nanoVLM's own loader (defined in the cloned
        # repo), which fetches the weights from the Hub.
        model = VisionLanguageModel.from_pretrained(model_id_for_weights).to(device)
        print("Model loaded successfully.")
        model.eval()
    except Exception as e:
        print(f"Error loading model or processor components: {e}")
        import traceback
        traceback.print_exc()
        image_processor = None
        tokenizer = None
        model = None
else:
    print("Custom VisionLanguageModel class not imported, cannot load model.")

def prepare_inputs(text_list, image_input, image_processor_instance, tokenizer_instance, device_to_use):
    """Preprocess one image and a batch of prompts into model-ready tensors."""
    if image_processor_instance is None or tokenizer_instance is None:
        raise ValueError("Image processor or tokenizer not initialized.")
    processed_image = image_processor_instance(images=image_input, return_tensors="pt").pixel_values.to(device_to_use)
    processed_text = tokenizer_instance(
        text=text_list, return_tensors="pt", padding=True, truncation=True,
        max_length=getattr(tokenizer_instance, "model_max_length", 512),
    )
    input_ids = processed_text.input_ids.to(device_to_use)
    attention_mask = processed_text.attention_mask.to(device_to_use)
    return {"pixel_values": processed_image, "input_ids": input_ids, "attention_mask": attention_mask}

def generate_text_for_image(image_input, prompt_input):
    if model is None or image_processor is None or tokenizer is None:
        return "Error: Model or processor components not loaded correctly. Check logs."
    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt."
    try:
        # Gradio may hand us a numpy array; normalize to an RGB PIL image.
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        inputs = prepare_inputs(
            text_list=[prompt_input], image_input=pil_image,
            image_processor_instance=image_processor, tokenizer_instance=tokenizer,
            device_to_use=device,
        )
        generated_ids = model.generate(
            pixel_values=inputs["pixel_values"], input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"], max_new_tokens=150, num_beams=3,
            no_repeat_ngram_size=2, early_stopping=True, pad_token_id=tokenizer.pad_token_id,
        )
        generated_text_list = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = generated_text_list[0] if generated_text_list else ""
        # The decoded sequence usually echoes the prompt; strip it from the output.
        if prompt_input and generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text
        return cleaned_text.strip()
    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        return f"An error occurred during text generation: {str(e)}"
| description = "Interactive demo for lusxvr/nanoVLM-222M." | |
| print("Defining Gradio interface...") | |
| try: | |
| iface = gr.Interface( | |
| fn=generate_text_for_image, | |
| inputs=[ | |
| gr.Image(type="pil", label="Upload Image"), | |
| gr.Textbox(label="Your Prompt/Question") | |
| ], | |
| outputs=gr.Textbox(label="Generated Text", show_copy_button=True), | |
| title="Interactive nanoVLM-222M Demo", | |
| description=description, | |
| # examples=[ # <<<< REMOVED EXAMPLES | |
| # [example_image_url, "a photo of a"], | |
| # [example_image_url, "Describe the image in detail."], | |
| # ], | |
| allow_flagging="never" | |
| ) | |
| print("Gradio interface defined.") | |
| except Exception as e: | |
| print(f"Error defining Gradio interface: {e}") | |
| import traceback; traceback.print_exc() | |
| iface = None | |

if __name__ == "__main__":
    if model is None or image_processor is None or tokenizer is None:
        print("CRITICAL: Model or processor components failed to load. Gradio might not work.")
    if iface is not None:
        print("Launching Gradio interface...")
        try:
            iface.launch(server_name="0.0.0.0", server_port=7860)
        except Exception as e:
            print(f"Error launching Gradio interface: {e}")
            import traceback
            traceback.print_exc()
            # A "ValueError: When localhost is not accessible..." at this point
            # usually means an earlier error already broke the interface setup.
    else:
        print("Gradio interface could not be defined due to earlier errors.")