CCockrum committed
Commit fb983b2 · verified · parent: 6eb97b8

Update app.py

Files changed (1): app.py (+85 −65)
app.py CHANGED
@@ -1,24 +1,28 @@
 import os
 import re
 import random
+import subprocess
 import requests
 import streamlit as st
+import spacy  # for additional NLP processing
 from langchain_huggingface import HuggingFaceEndpoint
 from langchain_core.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from transformers import pipeline

-# Use environment variables for keys
-HF_TOKEN = os.getenv("HF_TOKEN")
-if HF_TOKEN is None:
-    raise ValueError("HF_TOKEN environment variable not set. Please set it in your Hugging Face Space settings.")
+# Must be the first Streamlit command!
+st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")

-NASA_API_KEY = os.getenv("NASA_API_KEY")
-if NASA_API_KEY is None:
-    raise ValueError("NASA_API_KEY environment variable not set. Please set it in your Hugging Face Space settings.")
+# --- Helper to load spaCy model with fallback ---
+def load_spacy_model():
+    try:
+        return spacy.load("en_core_web_sm")
+    except OSError:
+        st.warning("Downloading spaCy model en_core_web_sm... This may take a moment.")
+        subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
+        return spacy.load("en_core_web_sm")

-# Set up Streamlit UI
-st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")
+nlp_spacy = load_spacy_model()

 # --- Initialize Session State Variables ---
 if "chat_history" not in st.session_state:
@@ -28,6 +32,38 @@ if "response_ready" not in st.session_state:
 if "follow_up" not in st.session_state:
     st.session_state.follow_up = ""

+# --- Appearance CSS ---
+st.markdown("""
+<style>
+.user-msg {
+    background-color: #696969;
+    color: white;
+    padding: 10px;
+    border-radius: 10px;
+    margin-bottom: 5px;
+    width: fit-content;
+    max-width: 80%;
+}
+.assistant-msg {
+    background-color: #333333;
+    color: white;
+    padding: 10px;
+    border-radius: 10px;
+    margin-bottom: 5px;
+    width: fit-content;
+    max-width: 80%;
+}
+.container {
+    display: flex;
+    flex-direction: column;
+    align-items: flex-start;
+}
+@media (max-width: 600px) {
+    .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; }
+}
+</style>
+""", unsafe_allow_html=True)
+
 # --- Set Up Model & API Functions ---
 model_id = "mistralai/Mistral-7B-Instruct-v0.3"
 sentiment_analyzer = pipeline(
@@ -41,12 +77,12 @@ def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.7)
         repo_id=model_id,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
-        token=HF_TOKEN,
+        token=os.getenv("HF_TOKEN"),
         task="text-generation"
     )

 def get_nasa_apod():
-    url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
+    url = f"https://api.nasa.gov/planetary/apod?api_key={os.getenv('NASA_API_KEY')}"
     response = requests.get(url)
     if response.status_code == 200:
         data = response.json()
@@ -59,19 +95,24 @@ def analyze_sentiment(user_text):
     return result['label']

 def predict_action(user_text):
-    if "NASA" in user_text or "space" in user_text:
+    if "nasa" in user_text.lower() or "space" in user_text.lower():
         return "nasa_info"
     return "general_query"

+def extract_context(text):
+    """Extract key entities using spaCy."""
+    doc = nlp_spacy(text)
+    entities = [ent.text for ent in doc.ents]
+    return ", ".join(entities) if entities else ""
+
 def generate_follow_up(user_text):
     """
     Generates two variant follow-up questions and randomly selects one.
-    It also cleans up any unwanted quotation marks or extra meta commentary.
     """
     prompt_text = (
-        f"Based on the user's question: '{user_text}', generate two concise, friendly follow-up questions "
-        "that invite further discussion. For example, one might be 'Would you like to know more about the six types of quarks?' "
-        "and another might be 'Would you like to explore another aspect of quantum physics?' Do not include extra commentary."
+        f"Based on the user's question: '{user_text}', generate two concise, friendly follow-up questions that invite further discussion. "
+        "For example, one might be 'Would you like to know more about the six types of quarks?' and another 'Would you like to explore another aspect of quantum physics?'. "
+        "Answer exclusively in English, and do not include extra commentary."
     )
     hf = get_llm_hf_inference(max_new_tokens=80, temperature=0.9)
     output = hf.invoke(input=prompt_text).strip()
@@ -83,14 +124,17 @@ def generate_follow_up(user_text):

 def get_response(system_message, chat_history, user_text, max_new_tokens=1024):
     """
-    Generates HAL's answer with depth and a follow-up question.
-    The prompt instructs the model to provide a detailed explanation and then generate a follow-up.
-    If the answer comes back empty, a fallback answer is used.
+    Generates HAL's detailed, in-depth response and a follow-up question.
+    It incorporates sentiment analysis, additional NLP context, and style instructions.
     """
     sentiment = analyze_sentiment(user_text)
     action = predict_action(user_text)

-    # Extract style instruction if present
+    # Extract extra context (e.g., named entities)
+    context_info = extract_context(user_text)
+    context_clause = f" The key topics here are: {context_info}." if context_info else ""
+
+    # Extract style instruction if provided.
     style_instruction = ""
     lower_text = user_text.lower()
     if "in the voice of" in lower_text or "speaking as" in lower_text:
@@ -99,6 +143,9 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=1024):
         style_instruction = match.group(2).strip().capitalize()
         style_instruction = f" Please respond in the voice of {style_instruction}."

+    # Force output in English.
+    language_clause = " Answer exclusively in English."
+
     if action == "nasa_info":
         nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
         response = f"**{nasa_title}**\n\n{nasa_explanation}"
@@ -117,23 +164,25 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=1024):

     style_clause = style_instruction if style_instruction else ""

-    # Instruct the model to generate a detailed, in-depth answer.
     prompt = PromptTemplate.from_template(
         (
             "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\n"
             "User: {user_text}.\n [/INST]\n"
-            "AI: Please provide a detailed explanation in depth. "
-            "Ensure your response covers the topic thoroughly and is written in a friendly, conversational style, "
-            "starting with a phrase like 'Certainly!', 'Of course!', or 'Great question!'." + style_clause +
+            "AI: Please provide a detailed, in-depth answer in a friendly, conversational tone that thoroughly covers the topic."
+            + style_clause + context_clause + language_clause +
             "\nHAL:"
         )
     )

+    st.write("DEBUG: Prompt sent to model:")
+    st.write(prompt.format(system_message=system_message, chat_history=filtered_history, user_text=user_text))
+
     chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
-    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
-    # Remove any extra markers if present.
+    raw_output = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
+    st.write("DEBUG: Raw model output:")
+    st.write(raw_output)

-    # Fallback in case the generated answer is empty
+    response = raw_output  # Using the full raw output without splitting
     if not response:
         response = "Certainly, here is an in-depth explanation: [Fallback explanation]."

@@ -147,6 +196,8 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=1024):
     follow_up = generate_follow_up(user_text)
     chat_history.append({'role': 'assistant', 'content': follow_up})

+    st.write("DEBUG: Generated follow-up question:", follow_up)
+
     return response, follow_up, chat_history, None

 # --- Chat UI ---
@@ -159,36 +210,13 @@ if st.sidebar.button("Reset Chat"):
     st.session_state.follow_up = ""
     st.experimental_rerun()

-st.markdown("""
-<style>
-.user-msg {
-    background-color: #696969;
-    color: white;
-    padding: 10px;
-    border-radius: 10px;
-    margin-bottom: 5px;
-    width: fit-content;
-    max-width: 80%;
-}
-.assistant-msg {
-    background-color: #333333;
-    color: white;
-    padding: 10px;
-    border-radius: 10px;
-    margin-bottom: 5px;
-    width: fit-content;
-    max-width: 80%;
-}
-.container {
-    display: flex;
-    flex-direction: column;
-    align-items: flex-start;
-}
-@media (max-width: 600px) {
-    .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; }
-}
-</style>
-""", unsafe_allow_html=True)
+st.markdown("<div class='container'>", unsafe_allow_html=True)
+for message in st.session_state.chat_history:
+    if message["role"] == "user":
+        st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
+    else:
+        st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
+st.markdown("</div>", unsafe_allow_html=True)

 user_input = st.chat_input("Type your message here...")

@@ -202,11 +230,3 @@ if user_input:
         st.image(image_url, caption="NASA Image of the Day")
     st.session_state.follow_up = follow_up
     st.session_state.response_ready = True
-
-st.markdown("<div class='container'>", unsafe_allow_html=True)
-for message in st.session_state.chat_history:
-    if message["role"] == "user":
-        st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
-    else:
-        st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
-st.markdown("</div>", unsafe_allow_html=True)
 
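With this change the fail-fast ValueError guards are gone and both secrets are read inline via os.getenv() at call time, so a missing HF_TOKEN or NASA_API_KEY only surfaces when a request actually fails. A minimal sketch of a softer startup check that keeps the old guard's intent without crashing the Space at import time (not part of this commit; require_secrets is an illustrative name, and the sketch assumes the os and streamlit imports already at the top of app.py):

def require_secrets(*names):
    # Illustrative sketch, not in the commit: surface missing secrets in the UI
    # and halt rendering instead of raising ValueError at import time.
    missing = [name for name in names if not os.getenv(name)]
    if missing:
        st.error(f"Missing environment variables: {', '.join(missing)}. "
                 "Set them in your Hugging Face Space settings.")
        st.stop()

require_secrets("HF_TOKEN", "NASA_API_KEY")

Such a check would have to run after st.set_page_config(...), since, as the commit's own comment notes, that must be the first Streamlit command.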