|
import gradio as gr |
|
import torch |
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
# Inference device: prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# BLIP base captioning checkpoint. The processor handles image/text
# pre-processing; the model generates the caption tokens.
_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(_CHECKPOINT).to(device)
|
|
|
def process_input(image, text=""):
    """Generate a description of *image*, optionally guided by *text*.

    Args:
        image: Picture to caption — a numpy array (what Gradio's
            ``type="numpy"`` component delivers) or a ``PIL.Image.Image``.
        text: Optional prompt used as BLIP's conditional prefix; when
            empty, a default prefix is used instead.

    Returns:
        The generated caption as a string, or a human-readable error
        message (this function never raises — errors are shown in the UI).
    """
    try:
        # Accept both raw numpy frames (Gradio) and ready-made PIL images.
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image)
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            return "Please provide a valid image"

        # BLIP preprocessing expects 3-channel RGB; normalize
        # grayscale/RGBA inputs instead of failing on them.
        pil_image = pil_image.convert("RGB")

        # NOTE(review): "a video of" looks like a leftover from a video
        # variant of this app — "a photo of" is the customary BLIP
        # conditional prefix for still images; confirm before changing.
        conditional_text = text if text else "a video of"

        inputs = processor(
            pil_image,
            text=conditional_text,
            return_tensors="pt"
        ).to(device)

        # Inference only: no_grad avoids building autograd state,
        # cutting memory use during beam-search generation.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=100,
                num_beams=5,          # beam search for a more fluent caption
                length_penalty=1.0,
                repetition_penalty=1.5  # discourage repeated phrases
            )

        result = processor.decode(output[0], skip_special_tokens=True)
        return result.strip()

    except Exception as e:
        # Surface the failure in the UI rather than crashing the worker.
        return f"Error processing input: {str(e)}"
|
|
|
|
|
# Gradio UI: an image plus an optional guiding prompt in, a caption out.
_image_input = gr.Image(type="numpy", label="Upload Image")
_prompt_input = gr.Textbox(
    label="Prompt (Optional)",
    placeholder="Guide the description or leave empty for automatic caption",
    lines=2,
)
_caption_output = gr.Textbox(label="Generated Description", lines=6)

demo = gr.Interface(
    fn=process_input,
    inputs=[_image_input, _prompt_input],
    outputs=_caption_output,
    title="Scene Description Generator",
    description=(
        "Upload an image and optionally add a prompt to guide the "
        "description. Created by <a href='https://justlab.ai'>Justlab.ai</a>"
    ),
)
|
|
|
# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":

    demo.launch()