# CaptionEmotion / app.py
# Last change by HarshanaLF: "Update app.py" (commit d37ad38, verified)
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image as PILImage
# Load the image captioning model and tokenizer.
# Both objects are created once at import time from the same checkpoint name:
# the processor turns an input image into model tensors and decodes generated
# token ids back to text, and the model generates the caption token ids.
caption_model_name = "Salesforce/blip-image-captioning-large"
caption_processor = BlipProcessor.from_pretrained(caption_model_name)
caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_name)
# Load the emotion analysis model as a text-classification pipeline.
# return_all_scores=True requests a score for every label rather than only
# the top one; the handler in this file reads results[0] as that per-label
# list to build its bar chart.
# NOTE(review): return_all_scores looks deprecated in recent transformers
# releases in favour of top_k=None, which may also change ordering/nesting of
# the result — confirm against the installed version before switching.
emotion_model_name = "SamLowe/roberta-base-go_emotions"
emotion_classifier = pipeline("text-classification", model=emotion_model_name, return_all_scores=True)
def _caption_image(image):
    """Return the caption the BLIP model generates for ``image``."""
    caption_inputs = caption_processor(images=image, return_tensors="pt")
    caption_ids = caption_model.generate(**caption_inputs)
    return caption_processor.decode(caption_ids[0], skip_special_tokens=True)


def _render_emotion_chart(labels, scores):
    """Draw the per-emotion bar chart and return it as an in-memory PIL image."""
    import io

    # Use an explicit figure (not pyplot's implicit "current figure") and
    # always close it, so a failure while drawing cannot leak the figure.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(labels, scores, color='skyblue')
        ax.set_xlabel('Emotions')
        ax.set_ylabel('Scores')
        ax.set_title('Emotion Analysis')
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        # Render into memory instead of a fixed file name: a shared path in
        # the working directory would be overwritten by concurrent requests.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
    finally:
        plt.close(fig)
    buffer.seek(0)
    chart = PILImage.open(buffer)
    # PIL opens lazily; force the pixel data to be read now so the returned
    # image does not depend on the buffer afterwards.
    chart.load()
    return chart


def generate_caption_and_analyze_emotions(image=None, text=None):
    """Caption an image (or take the given text) and analyse its emotions.

    If ``image`` is provided it is captioned and the caption is analysed;
    otherwise ``text`` is analysed directly. ``image`` takes precedence when
    both are given.

    Args:
        image: Image to caption, in any form the caption processor accepts,
            or ``None``.
        text: Text to analyse when no image is given, or ``None``.

    Returns:
        A 3-tuple ``(caption_output, sentiment_output, vis_image)``:
        the caption line, a sentence naming the highest-scoring emotion, and
        a PIL image of the per-emotion bar chart. When there is nothing to
        analyse, or when any step fails, the tuple is
        ``(message, "", None)`` with ``message`` describing the problem.
    """
    # Nothing to analyse: answer with a clear prompt instead of letting the
    # classifier run on None / blank text.
    text_is_blank = text is None or (isinstance(text, str) and not text.strip())
    if image is None and text_is_blank:
        return "Please upload an image or enter some text.", "", None

    try:
        decoded_caption = _caption_image(image) if image is not None else text

        # Perform emotion analysis on the generated caption or provided text.
        results = emotion_classifier(decoded_caption)
        # The pipeline may wrap the per-label list in an outer list (one entry
        # per input); accept both the nested and the flat shape.
        emotion_scores = results[0] if isinstance(results[0], list) else results

        labels = [entry['label'] for entry in emotion_scores]
        scores = [entry['score'] for entry in emotion_scores]
        vis_image = _render_emotion_chart(labels, scores)

        # Report the highest-scoring emotion. The list is not guaranteed to be
        # sorted by score, so taking element 0 would report whichever label
        # happens to come first rather than the dominant one.
        sentiment_label = max(emotion_scores, key=lambda entry: entry['score'])['label']
        if sentiment_label == 'neutral':
            sentiment_text = "Sentiment of the text is"
        else:
            sentiment_text = "Sentiment of the text shows"

        caption_output = f"Caption: '{decoded_caption}'"
        sentiment_output = f"{sentiment_text} {sentiment_label}."
        return caption_output, sentiment_output, vis_image
    except Exception as e:
        # UI boundary: surface the failure in the caption box rather than
        # crashing the Gradio request.
        return f"An error occurred: {e}", "", None
# Input components: an optional image upload and an optional free-text box.
image_input = gr.Image(label="Upload an image")
text_input = gr.Textbox(label="Or enter text", lines=2)

# Output components, in the same order as the handler's return tuple:
# caption line, sentiment sentence, emotion bar chart.
outputs = [
    gr.Textbox(label="Generated Caption"),
    gr.Textbox(label="Sentiment Analysis"),
    gr.Image(label="Emotion Visualization"),
]

# Wire the handler to the components; the input order matches the handler's
# (image, text) parameters.
app = gr.Interface(
    fn=generate_caption_and_analyze_emotions,
    inputs=[image_input, text_input],
    outputs=outputs,
)

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()