dashakoryakovskaya committed on
Commit
47c9424
·
verified ·
1 Parent(s): 75c019d

Create app.py

Files changed (1)
app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
+ import gradio as gr
+ import plotly.express as px
+ import pandas as pd
+ import logging
+ import whisper
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from torch.nn.functional import silu
+ from torch.nn.functional import softplus
+ from einops import rearrange, repeat, einsum
+ from transformers import AutoTokenizer, AutoModel
+ from torch import Tensor
+
+ from model import Mamba
+
+ logging.basicConfig(level=logging.INFO)
+
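+ # Plot class probabilities for raw text input. Note: model.py's Mamba is
+ # assumed to expose a scikit-learn-style predict_proba(); that method is
+ # not part of the standard nn.Module API.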
+ def plotly_plot_text(text):
+     data = pd.DataFrame()
+     data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
+     data['Probability'] = model.predict_proba([text])[0].tolist()
+     p = px.bar(data, x='Emotion', y='Probability', color="Probability")
+     return (
+         p,
+         f"🗣️ Transcription:\n{text}",
+         f"## 🏆 Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
+     )
+
+ def transcribe_audio(audio_path):
+     # The Whisper model is reloaded on every call; load it once at module
+     # level if latency matters.
+     whisper_model = whisper.load_model("base")
+     try:
+         result = whisper_model.transcribe(audio_path, fp16=False)
+         return result.get('text', '')
+     except Exception as e:
+         logging.error(f"Transcription failed: {e}")
+         return ""
+
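+ # Audio path: transcribe with Whisper first, then reuse the text classifier
+ # on the transcription. Empty transcriptions fall back to all-zero scores.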
+ def plotly_plot_audio(audio_path):
+     data = pd.DataFrame()
+     data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
+     try:
+         text = transcribe_audio(audio_path)
+         data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
+         p = px.bar(data, x='Emotion', y='Probability', color="Probability")
+         return (
+             p,
+             f"🗣️ Transcription:\n{text}",
+             f"## 🏆 Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
+         )
+     except Exception as e:
+         logging.error(f"Processing failed: {e}")
+         data['Probability'] = [0.0] * data.shape[0]
+         p = px.bar(data, x='Emotion', y='Probability', color="Probability")
+         return (
+             p,
+             "❌ Error processing audio",
+             "⚠️ Processing Error"
+         )
+
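+ # Gradio UI: audio and text inputs side by side, with the dominant-emotion
+ # banner, probability bar chart, and transcription box below.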
+ def create_demo():
+     with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection") as demo:
+         gr.Markdown("# Text-based bilingual emotion recognition")
+
+         with gr.Row():
+             with gr.Column():
+                 audio_input = gr.Audio(
+                     sources=["upload", "microphone"],
+                     type="filepath",
+                     label="Record or Upload Audio",
+                     format="wav",
+                     interactive=True
+                 )
+             with gr.Column():
+                 text_input = gr.Text(label="Write Text")
+
+         with gr.Row():
+             top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input ...",
+                                       elem_classes="dominant-emotion")
+
+         with gr.Row():
+             text_plot = gr.Plot(label="Text Analysis")
+
+         transcription = gr.Textbox(
+             label="📜 Transcription Results",
+             placeholder="Transcribed text will appear here...",
+             lines=3,
+             max_lines=6
+         )
+
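+         # Wire both inputs: changes to either the text box or the audio
+         # widget update the plot, transcription box, and emotion banner.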
+         text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, transcription, top_emotion])
+         audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion])
+     return demo
+
+
+ if __name__ == "__main__":
+     # Load the classifier checkpoint on CPU before building the UI.
+     device = torch.device('cpu')
+     model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7, model_name='jina', pooling=None).to(device)
+     checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=device)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     demo = create_demo()
+     demo.launch()