import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so plotting works in a headless server
import matplotlib.pyplot as plt
from PIL import Image as PILImage

# Load the image captioning model and tokenizer
caption_model_name = "Salesforce/blip-image-captioning-large"
caption_processor = BlipProcessor.from_pretrained(caption_model_name)
caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_name)

# Load the emotion analysis model (top_k=None replaces the deprecated
# return_all_scores=True and returns a score for every label, sorted by score)
emotion_model_name = "SamLowe/roberta-base-go_emotions"
emotion_classifier = pipeline("text-classification", model=emotion_model_name, top_k=None)
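# Calling the classifier on a list of texts returns one inner list of
# label/score dicts per input, e.g. (illustrative values only):
#   [[{'label': 'joy', 'score': 0.98}, {'label': 'love', 'score': 0.01}, ...]]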

def generate_caption_and_analyze_emotions(image=None, text=None):
    try:
        if image is not None:
            # Preprocess the image for caption generation
            caption_inputs = caption_processor(images=image, return_tensors="pt")

            # Generate a caption with the BLIP model
            caption_ids = caption_model.generate(**caption_inputs, max_new_tokens=50)

            # Decode the generated token IDs into text
            decoded_caption = caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        elif text:
            decoded_caption = text
        else:
            # Guard against the case where neither input was provided
            return "Please upload an image or enter some text.", "", None
        
        # Score the caption (or the provided text) against all emotion labels
        results = emotion_classifier([decoded_caption])
        emotion_scores = results[0]

        # Prepare data for visualization
        labels = [result['label'] for result in emotion_scores]
        scores = [result['score'] for result in emotion_scores]

        # Plot the emotion visualization
        plt.figure(figsize=(10, 5))
        plt.bar(labels, scores, color='skyblue')
        plt.xlabel('Emotions')
        plt.ylabel('Scores')
        plt.title('Emotion Analysis')
        plt.xticks(rotation=45)
        plt.tight_layout()
        
        # Save the plot as an image
        plt_path = "emotion_visualization.png"
        plt.savefig(plt_path)
        plt.close()

        # Load the saved image for Gradio
        vis_image = PILImage.open(plt_path)
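        # (Alternative design: return the Matplotlib figure itself and use a
        # gr.Plot output component, avoiding the round trip through a PNG file.)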

        # Pick the highest-scoring emotion; max() keeps this correct even if
        # the pipeline output is not sorted
        sentiment_label = max(emotion_scores, key=lambda r: r['score'])['label']
        if sentiment_label == 'neutral':
            sentiment_text = "Sentiment of the text is"
        else:
            sentiment_text = "Sentiment of the text shows"

        caption_output = f"Caption: '{decoded_caption}'"
        sentiment_output = f"{sentiment_text} {sentiment_label}."
        
        return caption_output, sentiment_output, vis_image
    except Exception as e:
        return f"An error occurred: {e}", "", None

# Define the Gradio input components
image_input = gr.Image(label="Upload an image", type="pil")  # hand the function a PIL image
text_input = gr.Textbox(label="Or enter text", lines=2)

outputs = [
    gr.Textbox(label="Generated Caption"),
    gr.Textbox(label="Sentiment Analysis"),
    gr.Image(label="Emotion Visualization"),
]

# Create the Gradio app
app = gr.Interface(fn=generate_caption_and_analyze_emotions, inputs=[image_input, text_input], outputs=outputs)

# Launch the Gradio app
if __name__ == "__main__":
    app.launch()
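
# A minimal smoke test of the core function without the UI (a sketch; it
# assumes a local image file "example.jpg", which is not part of this app):
#
#   caption, emotion, chart = generate_caption_and_analyze_emotions(
#       image=PILImage.open("example.jpg")
#   )
#   print(caption)
#   print(emotion)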