Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image | |
import torch | |
from transformers import BlipProcessor, BlipForConditionalGeneration | |
# BLIP captioning checkpoint shared by the processor and the model. Both are
# loaded once at import time so every request reuses the same instances.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
# Caption words that mark an image as dal-related. Matching is done on whole
# words, so singular/plural forms are listed explicitly rather than relying on
# substring prefixes such as "lent" or "bean".
DAL_KEYWORDS = frozenset({
    "dal", "dals",
    "lent", "lents", "lentil", "lentils",
    "pulses",
    "legume", "legumes",
    "bean", "beans",
    "peanut", "peanuts",
})


def generate_caption(image):
    """Caption *image* with BLIP and report whether the caption mentions dal.

    Args:
        image: The image to caption, passed straight to the BLIP ``processor``
            (the Gradio input in this file supplies a ``PIL.Image`` in RGB mode).

    Returns:
        ``"The image is related to dal"`` if any whole word of the generated
        caption is in ``DAL_KEYWORDS``, otherwise ``None``.
    """
    inputs = processor(images=image, return_tensors="pt")
    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)

    # Compare whole words, not substrings: a plain `keyword in caption` test
    # wrongly accepted captions like "a pair of sandals" ("dal") or
    # "a silent street" ("lent"). Non-alphanumerics become word separators,
    # so "dal-makhani" or "peanut's" still match.
    normalized = "".join(ch if ch.isalnum() else " " for ch in caption.lower())
    if DAL_KEYWORDS.intersection(normalized.split()):
        return "The image is related to dal"
    return None  # Caption does not mention dal.
# Gradio callback: validate the upload and echo it back with the verdict.
def captioning_interface(image):
    """Accept *image* only if its generated caption mentions dal.

    Args:
        image: The uploaded image, or ``None`` when the form is submitted
            without one.

    Returns:
        An ``(image, message)`` tuple feeding the image-preview and
        description outputs, when the image is dal-related.

    Raises:
        gr.Error: If no image was uploaded, or if the image is not related to
            dal; Gradio displays the message to the user.
    """
    # An empty submission arrives as None; report it clearly instead of letting
    # the BLIP processor fail with an opaque error.
    if image is None:
        raise gr.Error("Please upload an image.")

    result = generate_caption(image)
    if not result:
        raise gr.Error("This image is not related to dal. Please upload an image related to dal.")
    return image, result
# Gradio UI: one RGB image in; the same image plus a verdict out. Rejected
# images never reach the outputs because captioning_interface raises gr.Error.
interface = gr.Interface(
    fn=captioning_interface,
    inputs=gr.Image(type="pil", label="Upload Image", image_mode='RGB'),
    outputs=[
        gr.Image(type="pil", label="Image Preview"),
        gr.Textbox(label="Description"),
    ],
    title="Dal Detection",
    description="Only images related to dal will be accepted!",
    # Flagging is disabled entirely.
    # NOTE(review): Gradio 5 deprecates `allow_flagging` in favor of
    # `flagging_mode`; confirm against the Gradio version this Space pins.
    allow_flagging="never",
)

# Launch only when run as a script (as Spaces does with `python app.py`), so
# importing this module does not start a server as a side effect.
if __name__ == "__main__":
    interface.launch(share=True)