THeaxxxxxxxx committed on
Commit 07498dc · verified · 1 parent: 51221ef

Update app.py

Files changed (1)
  1. app.py +53 -12
app.py CHANGED

@@ -2,12 +2,14 @@ import streamlit as st
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 import torch
 
+# Set page configuration
 st.set_page_config(
     page_title="Review Assistant",
     page_icon="📝",
     layout="centered"
 )
 
+# Custom page styling with CSS
 st.markdown("""
 <style>
 .main-header {
@@ -46,6 +48,7 @@ st.markdown("""
     font-weight: bold;
     margin-bottom: 0.5rem;
 }
+/* Updated button styling with icon */
 .stButton>button {
     background-color: #2563EB;
     color: white;
@@ -53,10 +56,17 @@ st.markdown("""
     padding: 0.5rem 2rem;
     border-radius: 6px;
     font-weight: 500;
+    display: inline-flex;
+    align-items: center;
+    justify-content: center;
 }
 .stButton>button:hover {
     background-color: #1D4ED8;
 }
+/* Button icon styling */
+.button-icon {
+    margin-right: 8px;
+}
 .footer {
     text-align: center;
     color: #9CA3AF;
@@ -69,12 +79,21 @@ st.markdown("""
 </style>
 """, unsafe_allow_html=True)
 
+# Main page header
 st.markdown("<h1 class='main-header'>Smart Review Analysis Assistant</h1>", unsafe_allow_html=True)
 st.markdown("<p class='sub-header'>Topic Recognition, Sentiment Analysis, and Auto Reply in One Click</p>", unsafe_allow_html=True)
 
-# ------- Load Pipelines -------
+# Function to load ML pipelines - cached to avoid reloading
 @st.cache_resource
 def load_pipelines():
+    """
+    Load all three machine learning pipelines:
+    1. Topic classifier using zero-shot classification
+    2. Sentiment analysis model
+    3. Reply generator for customer service responses
+
+    Returns the models and topic labels for use in the app
+    """
     # Topic Classification Model (Zero-shot classification)
     topic_labels = [
         "billing", "account access", "customer service", "loans",
@@ -82,29 +101,33 @@ def load_pipelines():
         "branch service", "transaction delay", "account closure", "information error"
     ]
 
-
-    dtype = torch.float32
+    # Use float32 for better CPU compatibility
+    dtype = torch.float32
 
+    # Load topic classification model
     topic_classifier = pipeline(
         "zero-shot-classification",
         model="MoritzLaurer/deberta-v3-base-zeroshot-v1",
     )
 
-    # Sentiment Analysis Model
+    # Load sentiment analysis model
     sentiment_classifier = pipeline(
         "sentiment-analysis",
         model="cardiffnlp/twitter-roberta-base-sentiment-latest",
     )
 
-    # Reply Generation Model
+    # Load reply generation model
     model_name = "Leo66277/finetuned-tinyllama-customer-replies"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
 
+    # Function to generate customer service replies
     def generate_reply(text):
+        """Generate a customer service reply based on the input text"""
         prompt_text = f"Please write a short, polite English customer service reply to the following customer comment:\n{text}"
         inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
 
+        # Generate response with beam search for better quality
        with torch.no_grad():
             gen_ids = model.generate(
                 inputs.input_ids,
@@ -115,47 +138,65 @@
                 early_stopping=True
             )
 
+        # Clean up generated text
         reply = tokenizer.decode(gen_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
         reply = reply.strip('"').replace('\n', ' ').replace('  ', ' ')
         return reply
 
     return topic_classifier, sentiment_classifier, generate_reply, topic_labels
 
-
+# Page layout and user input section
 st.markdown("### Enter a review for instant analysis")
-example_review = "The people at the call center are inexperienced and lack proper training. I had to call multiple times to resolve a simple issue."
 
+# Updated example review as requested
+example_review = "BOA states on their website that holds are 2-7 days. I made a deposit, and the receipt states funds would be available in 2 days. Now, 13 days later, I am still waiting on my funds, and BOA can't give me a straight answer."
+
+# Text input area for user reviews
 user_review = st.text_area(
     "Please enter or paste a review below:",
     value=example_review,
     height=120
 )
 
-if st.button("Analyze"):
+# Custom button with icon
+analyze_button = st.markdown("""
+<button class="stButton primaryButton">
+    <span class="button-icon">📊</span> Analyze
+</button>
+""", unsafe_allow_html=True)
+
+# Check if button is clicked (using the regular button for functionality)
+if st.button("Analyze", key="hidden_button", help="Click to analyze the review"):
     if not user_review.strip():
+        # Validation check
         st.warning("Please enter a valid review!")
     else:
+        # Show loading spinner during analysis
         with st.spinner("Analyzing your review..."):
+            # Load models if not already loaded
             if "topic_pipe" not in st.session_state:
                 st.session_state.topic_pipe, st.session_state.sentiment_pipe, st.session_state.reply_generator, st.session_state.topic_labels = load_pipelines()
 
-            # Topic Classification
+            # Perform topic classification
             topic_result = st.session_state.topic_pipe(user_review, st.session_state.topic_labels, multi_label=False)
             topic = topic_result['labels'][0]
 
-            # Sentiment Analysis
+            # Perform sentiment analysis
             sentiment_result = st.session_state.sentiment_pipe(user_review)
             sentiment = sentiment_result[0]['label']
 
-            # Auto Reply Generation
+            # Generate auto-reply
             reply_text = st.session_state.reply_generator(user_review)
 
+            # Display results in a two-column layout
             col1, col2 = st.columns(2)
             with col1:
                 st.markdown(f"<div class='result-card topic-card'><p class='result-label'>Topic:</p>{topic}</div>", unsafe_allow_html=True)
             with col2:
                 st.markdown(f"<div class='result-card sentiment-card'><p class='result-label'>Sentiment:</p>{sentiment}</div>", unsafe_allow_html=True)
 
+            # Display auto-reply suggestion
             st.markdown(f"<div class='result-card reply-card'><p class='result-label'>Auto-reply Suggestion:</p>{reply_text}</div>", unsafe_allow_html=True)
 
-st.markdown("<div class='footer'>© 2024 Review AI Assistant</div>", unsafe_allow_html=True)
+# Page footer
+st.markdown("<div class='footer'>© 2025 Review AI Assistant</div>", unsafe_allow_html=True)