CCockrum committed on
Commit
7e6badf
·
verified ·
1 Parent(s): fb39f22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -64
app.py CHANGED
@@ -20,6 +20,9 @@ if "response_ready" not in st.session_state:
20
  if "follow_up" not in st.session_state:
21
  st.session_state.follow_up = "" # Stores follow-up question
22
 
 
 
 
23
  # --- Set Up Model & API Functions ---
24
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
25
 
@@ -48,9 +51,17 @@ def analyze_sentiment(user_text):
48
  return result['label']
49
 
50
  def predict_action(user_text):
 
 
 
51
  if "NASA" in user_text or "space" in user_text:
52
  return "nasa_info"
53
- return "general_query"
 
 
 
 
 
54
 
55
  def generate_follow_up(user_text):
56
  """
@@ -69,7 +80,7 @@ def generate_follow_up(user_text):
69
 
70
  def get_response(system_message, chat_history, user_text, max_new_tokens=256):
71
  """
72
- Generates HAL's response, making it more conversational and engaging.
73
  """
74
  sentiment = analyze_sentiment(user_text)
75
  action = predict_action(user_text)
@@ -77,38 +88,43 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
77
  if action == "nasa_info":
78
  nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
79
  response = f"**{nasa_title}**\n\n{nasa_explanation}"
80
- chat_history.append({'role': 'user', 'content': user_text})
81
  chat_history.append({'role': 'assistant', 'content': response})
82
-
83
  follow_up = generate_follow_up(user_text)
84
- chat_history.append({'role': 'assistant', 'content': follow_up})
85
  return response, follow_up, chat_history, nasa_url
86
 
87
  hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
88
 
89
  prompt = PromptTemplate.from_template(
90
- (
91
- "[INST] {system_message}"
92
- "\nCurrent Conversation:\n{chat_history}\n\n"
93
- "\nUser: {user_text}.\n [/INST]"
94
- "\nAI: Keep responses conversational and engaging. Start with a friendly phrase like "
95
- "'Certainly!', 'Of course!', or 'Great question!' before answering."
96
- " Keep responses concise but engaging."
97
- "\nHAL:"
98
- )
99
  )
100
 
101
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
102
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
103
  response = response.split("HAL:")[-1].strip()
104
 
105
- chat_history.append({'role': 'user', 'content': user_text})
 
 
 
106
  chat_history.append({'role': 'assistant', 'content': response})
107
 
108
- follow_up = generate_follow_up(user_text)
109
- chat_history.append({'role': 'assistant', 'content': follow_up})
 
 
 
 
 
 
 
 
110
 
111
- return response, follow_up, chat_history, None
112
 
113
  # --- Chat UI ---
114
  st.title("πŸš€ HAL - Your NASA AI Assistant")
@@ -120,41 +136,9 @@ if st.sidebar.button("Reset Chat"):
120
  st.session_state.response_ready = False
121
  st.session_state.follow_up = ""
122
  st.session_state.last_topic = ""
123
- st.rerun() # βœ… Correct method to reset the app in newer Streamlit versions
124
-
125
- # Custom Chat Styling
126
- st.markdown("""
127
- <style>
128
- .user-msg {
129
- background-color: #0078D7;
130
- color: white;
131
- padding: 10px;
132
- border-radius: 10px;
133
- margin-bottom: 5px;
134
- width: fit-content;
135
- max-width: 80%;
136
- }
137
- .assistant-msg {
138
- background-color: #333333;
139
- color: white;
140
- padding: 10px;
141
- border-radius: 10px;
142
- margin-bottom: 5px;
143
- width: fit-content;
144
- max-width: 80%;
145
- }
146
- .container {
147
- display: flex;
148
- flex-direction: column;
149
- align-items: flex-start;
150
- }
151
- @media (max-width: 600px) {
152
- .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; }
153
- }
154
- </style>
155
- """, unsafe_allow_html=True)
156
-
157
- # --- Chat History Display (Ensures All Messages Are Visible) ---
158
  st.markdown("<div class='container'>", unsafe_allow_html=True)
159
 
160
  for message in st.session_state.chat_history:
@@ -166,11 +150,12 @@ for message in st.session_state.chat_history:
166
  st.markdown("</div>", unsafe_allow_html=True)
167
 
168
  # --- Single Input Box for Both Initial and Follow-Up Messages ---
169
- user_input = st.chat_input("Type your message here...") # Uses Enter to submit
170
 
171
  if user_input:
172
- # Save user message in chat history
173
- st.session_state.chat_history.append({'role': 'user', 'content': user_input})
 
174
 
175
  # Generate HAL's response
176
  response, follow_up, st.session_state.chat_history, image_url = get_response(
@@ -180,21 +165,14 @@ if user_input:
180
  )
181
 
182
  st.session_state.chat_history.append({'role': 'assistant', 'content': response})
183
-
184
  st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
185
 
186
  if image_url:
187
  st.image(image_url, caption="NASA Image of the Day")
188
 
189
  st.session_state.follow_up = follow_up
190
- st.session_state.response_ready = True # Enables follow-up response cycle
191
 
192
  if st.session_state.response_ready and st.session_state.follow_up:
193
- # βœ… Save HAL's follow-up question in chat history
194
- st.session_state.chat_history.append({'role': 'assistant', 'content': st.session_state.follow_up})
195
-
196
- # βœ… Display the follow-up question
197
  st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {st.session_state.follow_up}</div>", unsafe_allow_html=True)
198
-
199
- # βœ… Reset response state so user can type next input
200
  st.session_state.response_ready = False
 
20
  if "follow_up" not in st.session_state:
21
  st.session_state.follow_up = "" # Stores follow-up question
22
 
23
+ if "last_topic" not in st.session_state:
24
+ st.session_state.last_topic = "" # Tracks last discussed topic
25
+
26
  # --- Set Up Model & API Functions ---
27
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
28
 
 
51
  return result['label']
52
 
53
def predict_action(user_text):
    """
    Classify the topic of the user's message via keyword matching.

    Matching is case-insensitive and anchored at a word boundary, so that
    "nasa", "Space travel", or "Quantum" are recognized, while words that
    merely contain "ai" (e.g. "said") do not false-positive as AI topics.

    Parameters:
        user_text (str): Raw user message (None is treated as empty).

    Returns:
        str: One of "nasa_info", "physics", "AI", or "general_query".
    """
    import re  # local import keeps this fix self-contained

    text = user_text or ""
    # Leading \b only for nasa/space and quark/physics/quantum so that
    # derived forms ("spaceship", "quarks", "NASA's") still match, as the
    # original substring test allowed.
    if re.search(r"\b(nasa|space)", text, re.IGNORECASE):
        return "nasa_info"
    if re.search(r"\b(quark|physics|quantum)", text, re.IGNORECASE):
        return "physics"
    # Full word boundaries here: "ai" is too short for prefix matching.
    if re.search(r"\b(ai|machine learning)\b", text, re.IGNORECASE):
        return "AI"
    return "general_query"
65
 
66
  def generate_follow_up(user_text):
67
  """
 
80
 
81
def get_response(system_message, chat_history, user_text, max_new_tokens=256):
    """
    Generate HAL's reply (and possibly a follow-up question) for one user turn.

    Parameters:
        system_message: System prompt text interpolated into the model prompt.
        chat_history (list[dict]): Mutable list of {'role', 'content'} messages.
            Mutated in place and also returned.
        user_text (str): The user's latest message.
        max_new_tokens (int): Token budget for the HF inference call.

    Returns:
        tuple: (response, follow_up, chat_history, image_url) where image_url
        is the NASA APOD URL on the "nasa_info" branch and None otherwise.
        follow_up may be "" when the topic changed since the previous turn.

    Side effects:
        Appends to chat_history; reads/writes st.session_state.follow_up and
        st.session_state.last_topic.
    """
    # NOTE(review): sentiment is computed but never used below — confirm
    # whether it was meant to influence the prompt or the follow-up.
    sentiment = analyze_sentiment(user_text)
    action = predict_action(user_text)

    # NASA branch short-circuits: answer comes from the APOD API, not the LLM.
    if action == "nasa_info":
        nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
        response = f"**{nasa_title}**\n\n{nasa_explanation}"
        # NOTE(review): unlike the LLM branch below, this branch does not
        # append the user message, and the follow-up is returned without
        # being stored in chat_history or session state — verify intended.
        chat_history.append({'role': 'assistant', 'content': response})
        follow_up = generate_follow_up(user_text)
        return response, follow_up, chat_history, nasa_url

    # High temperature (0.9) favors varied, conversational phrasing.
    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)

    # Mistral-instruct style prompt; trailing "HAL:" cues the model and is
    # the split marker used on the raw output below.
    prompt = PromptTemplate.from_template(
        "[INST] {system_message}\n\n"
        "Current Conversation:\n{chat_history}\n\n"
        "User: {user_text}.\n [/INST]\n"
        "AI: Keep responses conversational and engaging. Start with a friendly phrase like "
        "'Certainly!', 'Of course!', or 'Great question!' before answering."
        " Keep responses concise but engaging.\nHAL:"
    )

    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
    # Keep only the text after the last "HAL:" marker (strips echoed prompt).
    response = response.split("HAL:")[-1].strip()

    # ✅ Avoid duplicate user messages in history (the caller may have
    # already appended user_text before invoking this function).
    if not chat_history or chat_history[-1]["content"] != user_text:
        chat_history.append({'role': 'user', 'content': user_text})

    chat_history.append({'role': 'assistant', 'content': response})

    # ✅ Avoid repeating follow-ups when topic changes: only generate a
    # follow-up while the user stays on the same topic as last turn.
    current_topic = action
    if current_topic != st.session_state.last_topic:
        st.session_state.follow_up = ""
    else:
        follow_up = generate_follow_up(user_text)
        chat_history.append({'role': 'assistant', 'content': follow_up})
        st.session_state.follow_up = follow_up

    st.session_state.last_topic = current_topic

    return response, st.session_state.follow_up, chat_history, None
128
 
129
  # --- Chat UI ---
130
  st.title("πŸš€ HAL - Your NASA AI Assistant")
 
136
  st.session_state.response_ready = False
137
  st.session_state.follow_up = ""
138
  st.session_state.last_topic = ""
139
+ st.rerun()
140
+
141
+ # --- Chat History Display ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  st.markdown("<div class='container'>", unsafe_allow_html=True)
143
 
144
  for message in st.session_state.chat_history:
 
150
  st.markdown("</div>", unsafe_allow_html=True)
151
 
152
  # --- Single Input Box for Both Initial and Follow-Up Messages ---
153
+ user_input = st.chat_input("Type your message here...")
154
 
155
  if user_input:
156
+ # βœ… Prevent duplicate user messages
157
+ if not st.session_state.chat_history or st.session_state.chat_history[-1]["content"] != user_input:
158
+ st.session_state.chat_history.append({'role': 'user', 'content': user_input})
159
 
160
  # Generate HAL's response
161
  response, follow_up, st.session_state.chat_history, image_url = get_response(
 
165
  )
166
 
167
  st.session_state.chat_history.append({'role': 'assistant', 'content': response})
 
168
  st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
169
 
170
  if image_url:
171
  st.image(image_url, caption="NASA Image of the Day")
172
 
173
  st.session_state.follow_up = follow_up
174
+ st.session_state.response_ready = True
175
 
176
  if st.session_state.response_ready and st.session_state.follow_up:
 
 
 
 
177
  st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {st.session_state.follow_up}</div>", unsafe_allow_html=True)
 
 
178
  st.session_state.response_ready = False