import gradio as gr
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

# Load the BLIP model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Generate a caption for the image and check whether it mentions dal
def generate_caption(image):
    # Process the image and generate a caption
    inputs = processor(images=image, return_tensors="pt")
    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)

    # Keywords related to dal
    dal_keywords = ["dal", "dals", "lent", "pulses", "legume", "lents",
                    "lentils", "beans", "bean", "peanuts", "peanut"]

    # Check if any of the keywords appear in the caption
    if any(keyword in caption.lower() for keyword in dal_keywords):
        return "The image is related to dal"
    else:
        return None  # Return None if the image is not related to dal

# Define the function for the Gradio interface
def captioning_interface(image):
    # Generate a caption and check whether it is related to dal
    result = generate_caption(image)
    if result:
        return image, result  # If related to dal, show the image and the message
    else:
        # Trigger an error message in the UI for unrelated images
        raise gr.Error("This image is not related to dal. Please upload an image related to dal.")

# Create the Gradio interface
interface = gr.Interface(
    fn=captioning_interface,
    inputs=gr.Image(type="pil", label="Upload Image", image_mode="RGB"),
    outputs=[gr.Image(type="pil", label="Image Preview"), gr.Textbox(label="Description")],  # Shown only when the image is related to dal
    title="Dal Detection",
    description="Only images related to dal will be accepted!",
    allow_flagging="never",  # Disable flagging for rejected images
)

# Launch the interface
interface.launch(share=True)
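
# Optional: a minimal sketch for trying generate_caption() without launching the UI.
# "dal_sample.jpg" is a hypothetical placeholder; point it at a real local image
# and uncomment the lines below before running.
# sample = Image.open("dal_sample.jpg").convert("RGB")
# print(generate_caption(sample) or "Not related to dal")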