import os

import gradio as gr
import torch
import torchvision.datasets as datasets
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load SigLIP for image embeddings (local module shipped with this repo)
from model.siglip import SigLIPModel


def get_cifar_examples():
    # Load the CIFAR10 test set
    cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)

    # CIFAR10 class names, indexed by label
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

    # Collect one example image from each class
    os.makedirs("examples", exist_ok=True)  # ensure the output directory exists
    examples = []
    used_classes = set()

    for idx in range(len(cifar10_test)):
        img, label = cifar10_test[idx]
        if classes[label] not in used_classes:
            # Save the image temporarily so Gradio can display it
            img_path = f"examples/{classes[label]}_example.jpg"
            img.save(img_path)
            examples.append(img_path)
            used_classes.add(classes[label])

            if len(used_classes) == 10:  # one example from each class collected
                break

    return examples


def load_models():
    # Load the SigLIP image encoder
    siglip = SigLIPModel()

    # Load the base Phi model
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-4k-instruct",
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float32,
    )

    # Load our fine-tuned LoRA adapter on top of the base model
    model = PeftModel.from_pretrained(
        base_model,
        "jatingocodeo/phi-vlm",  # the uploaded adapter
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained("jatingocodeo/phi-vlm")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # needed since we tokenize with padding=True

    return siglip, model, tokenizer


def generate_description(image, siglip, model, tokenizer):
    # Convert the image to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize the image to match SigLIP's expected input size
    image = image.resize((32, 32))

    # Get the image embedding from SigLIP
    image_embedding = siglip.encode_image(image)

    # Prepare the prompt
    prompt = """Below is an image. Please describe it in detail.

Image:
Description: """

    # Tokenize the input
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(model.device)

    # Generate the description. Sampling arguments such as max_new_tokens
    # belong to generate(), not to the model's forward pass;
    # image_embeddings is the custom keyword our fine-tuned wrapper
    # forwards to the underlying model.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            image_embeddings=image_embedding.unsqueeze(0),  # add batch dimension
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )

    # Decode and return only the generated description
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text.split("Description: ")[-1].strip()


# Load models once at startup
print("Loading models...")
siglip, model, tokenizer = load_models()


# Gradio callback: generate a description for the uploaded image
def process_image(image):
    description = generate_description(image, siglip, model, tokenizer)
    return description


# Get CIFAR10 examples
examples = get_cifar_examples()

# Define the interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Description"),
    title="Image Description Generator",
    description="""Upload an image and get a detailed description generated by our fine-tuned VLM.
    Below are sample images from the CIFAR10 dataset that you can try.""",
    examples=[[ex] for ex in examples],  # Gradio expects a list of input lists
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()
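
# --- Usage note (assumptions, not part of the original app) ---
# A minimal sketch of how to run this locally, assuming the file is saved as
# app.py, a local `model/siglip.py` provides SigLIPModel, and the following
# dependencies are installed:
#   pip install gradio torch torchvision transformers peft pillow
#   python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.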