CCockrum commited on
Commit
546ff54
Β·
verified Β·
1 Parent(s): 6fe1d37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -131
app.py CHANGED
@@ -7,21 +7,21 @@ from langchain_huggingface import HuggingFaceEndpoint
7
  from langchain_core.prompts import PromptTemplate
8
  from langchain_core.output_parsers import StrOutputParser
9
  from transformers import pipeline
10
- from langdetect import detect
11
 
12
- # Use environment variables for keys
13
  HF_TOKEN = os.getenv("HF_TOKEN")
14
  if HF_TOKEN is None:
15
- raise ValueError("HF_TOKEN environment variable not set. Please set it in your Hugging Face Space settings.")
16
 
17
  NASA_API_KEY = os.getenv("NASA_API_KEY")
18
  if NASA_API_KEY is None:
19
- raise ValueError("NASA_API_KEY environment variable not set. Please set it in your Hugging Face Space settings.")
20
 
21
- # Set up Streamlit UI
22
  st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="πŸš€")
23
 
24
- # --- Initialize Session State Variables ---
25
  if "chat_history" not in st.session_state:
26
  st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
27
  if "response_ready" not in st.session_state:
@@ -29,14 +29,10 @@ if "response_ready" not in st.session_state:
29
  if "follow_up" not in st.session_state:
30
  st.session_state.follow_up = ""
31
 
32
- # --- Set Up Model & API Functions ---
33
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
34
- sentiment_analyzer = pipeline(
35
- "sentiment-analysis",
36
- model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
37
- revision="714eb0f"
38
- )
39
 
 
40
  def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.7):
41
  return HuggingFaceEndpoint(
42
  repo_id=model_id,
@@ -46,62 +42,40 @@ def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.7)
46
  task="text-generation"
47
  )
48
 
 
49
  def get_nasa_apod():
50
  url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
51
  response = requests.get(url)
52
  if response.status_code == 200:
53
  data = response.json()
54
  return data.get("url", ""), data.get("title", ""), data.get("explanation", "")
55
- else:
56
- return "", "NASA Data Unavailable", "I couldn't fetch data from NASA right now. Please try again later."
 
 
 
 
 
57
 
58
  def analyze_sentiment(user_text):
59
  result = sentiment_analyzer(user_text)[0]
60
  return result['label']
61
 
62
- def is_apod_query(user_text):
63
- """
64
- Checks if the user's question contains keywords indicating they are asking for
65
- the Astronomy Picture of the Day.
66
- """
67
- keywords = ["apod", "image", "picture", "photo", "astronomy picture"]
68
- return any(keyword in user_text.lower() for keyword in keywords)
69
-
70
  def predict_action(user_text):
71
- """
72
- Determines the type of response for the user's query.
73
- If the query contains "NASA" or "space" and also indicates an APOD query,
74
- it returns "nasa_info". Otherwise, it returns "general_query", even if "NASA" is present.
75
- """
76
- user_text_lower = user_text.lower()
77
- if "nasa" in user_text_lower or "space" in user_text_lower:
78
- if is_apod_query(user_text):
79
- return "nasa_info"
80
  return "general_query"
81
 
 
82
  def generate_follow_up(user_text):
83
- """
84
- Generates two variant follow-up questions and randomly selects one.
85
- It also cleans up any unwanted quotation marks or extra meta commentary.
86
- """
87
- prompt_text = (
88
- f"Based on the user's question: '{user_text}', generate two concise, friendly follow-up questions "
89
- "that invite further discussion. For example, one might be 'Would you like to know more about the six types of quarks?' "
90
- "and another might be 'Would you like to explore another aspect of quantum physics?' Do not include extra commentary ."
91
- "Answer exclusively in English, and do not include extra commentary."
92
- )
93
  hf = get_llm_hf_inference(max_new_tokens=80, temperature=0.9)
94
  output = hf.invoke(input=prompt_text).strip()
95
- variants = re.split(r"\n|[;]+", output)
96
- cleaned = [v.strip(' "\'') for v in variants if v.strip()]
97
- if not cleaned:
98
- cleaned = ["Would you like to explore this topic further?"]
99
- return random.choice(cleaned)
100
-
101
- from langdetect import detect
102
 
 
103
  def ensure_english(text):
104
- """Check if the model accidentally generated a non-English response."""
105
  try:
106
  detected_lang = detect(text)
107
  if detected_lang != "en":
@@ -110,26 +84,11 @@ def ensure_english(text):
110
  return "⚠️ Language detection failed. Please ask your question again."
111
  return text
112
 
 
113
  def get_response(system_message, chat_history, user_text, max_new_tokens=512):
114
- """
115
- Generates HAL's answer with depth and a follow-up question.
116
- The prompt instructs the model to provide a detailed explanation and then generate a follow-up.
117
- If the answer comes back empty, a fallback answer is used.
118
- """
119
-
120
- # πŸ” Determine the user's intent (NASA Info or General Query)
121
- action = predict_action(user_text) # πŸ”₯ Define 'action' here
122
-
123
- # Extract style instruction if present
124
- style_instruction = ""
125
- lower_text = user_text.lower()
126
- if "in the voice of" in lower_text or "speaking as" in lower_text:
127
- match = re.search(r"(in the voice of|speaking as)(.*)", lower_text)
128
- if match:
129
- style_instruction = match.group(2).strip().capitalize()
130
- style_instruction = f" Please respond in the voice of {style_instruction}."
131
-
132
- # πŸš€ Handle NASA-specific queries
133
  if action == "nasa_info":
134
  nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
135
  response = f"**{nasa_title}**\n\n{nasa_explanation}"
@@ -139,84 +98,53 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=512):
139
  chat_history.append({'role': 'assistant', 'content': follow_up})
140
  return response, follow_up, chat_history, nasa_url
141
 
 
142
  hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
143
- filtered_history = ""
144
- for message in chat_history:
145
- if message["role"] == "assistant" and message["content"].strip() == "Hello! How can I assist you today?":
146
- continue
147
- filtered_history += f"{message['role']}: {message['content']}\n"
148
 
149
- style_clause = style_instruction if style_instruction else ""
 
150
 
151
- # Instruct the model to generate a detailed, in-depth answer.
152
  prompt = PromptTemplate.from_template(
153
- (
154
- "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\n"
155
- "User: {user_text}.\n [/INST]\n"
156
- "AI: Please provide a detailed explanation in depth. "
157
- "Ensure your response covers the topic thoroughly and is written in a friendly, conversational style, "
158
- "starting with a phrase like 'Certainly!', 'Of course!', or 'Great question!'."
159
- "🚨 IMPORTANT: Answer exclusively in **English only**. Do not generate responses in any other language."
160
- "\nHAL:"
161
- )
162
  )
163
 
 
164
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
165
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
166
  response = response.split("HAL:")[-1].strip()
167
 
168
- # 🚨 Ensure the response is in English
169
  response = ensure_english(response)
170
 
171
- # Fallback in case the generated answer is empty
172
- if not response:
173
-
174
 
 
 
175
 
 
 
176
 
177
- # --- Chat UI ---
178
- st.title("πŸš€ HAL - Your NASA AI Assistant")
 
 
179
  st.markdown("🌌 *Ask me about space, NASA, and beyond!*")
180
 
181
- #Reset Button
182
  if st.sidebar.button("Reset Chat"):
183
  st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
184
  st.session_state.response_ready = False
185
  st.session_state.follow_up = ""
186
- st.experimental_rerun()
187
-
188
- #Style and Appearance
189
- st.markdown("""
190
- <style>
191
- .user-msg {
192
- background-color: #696969;
193
- color: white;
194
- padding: 10px;
195
- border-radius: 10px;
196
- margin-bottom: 5px;
197
- width: fit-content;
198
- max-width: 80%;
199
- }
200
- .assistant-msg {
201
- background-color: #333333;
202
- color: white;
203
- padding: 10px;
204
- border-radius: 10px;
205
- margin-bottom: 5px;
206
- width: fit-content;
207
- max-width: 80%;
208
- }
209
- .container {
210
- display: flex;
211
- flex-direction: column;
212
- align-items: flex-start;
213
- }
214
- @media (max-width: 600px) {
215
- .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; }
216
- }
217
- </style>
218
- """, unsafe_allow_html=True)
219
 
 
220
  user_input = st.chat_input("Type your message here...")
221
 
222
  if user_input:
@@ -227,13 +155,10 @@ if user_input:
227
  )
228
  if image_url:
229
  st.image(image_url, caption="NASA Image of the Day")
 
230
  st.session_state.follow_up = follow_up
231
  st.session_state.response_ready = True
232
 
233
- st.markdown("<div class='container'>", unsafe_allow_html=True)
234
- for message in st.session_state.chat_history:
235
- if message["role"] == "user":
236
- st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
237
- else:
238
- st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
239
- st.markdown("</div>", unsafe_allow_html=True)
 
7
  from langchain_core.prompts import PromptTemplate
8
  from langchain_core.output_parsers import StrOutputParser
9
  from transformers import pipeline
10
+ from langdetect import detect # Ensure this package is installed
11
 
12
+ # βœ… Environment Variables
13
  HF_TOKEN = os.getenv("HF_TOKEN")
14
  if HF_TOKEN is None:
15
+ raise ValueError("HF_TOKEN is not set. Please add it to your environment variables.")
16
 
17
  NASA_API_KEY = os.getenv("NASA_API_KEY")
18
  if NASA_API_KEY is None:
19
+ raise ValueError("NASA_API_KEY is not set. Please add it to your environment variables.")
20
 
21
+ # βœ… Set Up Streamlit
22
  st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="πŸš€")
23
 
24
+ # βœ… Ensure Session State Variables
25
  if "chat_history" not in st.session_state:
26
  st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
27
  if "response_ready" not in st.session_state:
 
29
  if "follow_up" not in st.session_state:
30
  st.session_state.follow_up = ""
31
 
32
+ # βœ… Model Configuration
33
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
 
 
 
 
 
34
 
35
+ # βœ… Initialize Hugging Face Model
36
  def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.7):
37
  return HuggingFaceEndpoint(
38
  repo_id=model_id,
 
42
  task="text-generation"
43
  )
44
 
45
+ # βœ… NASA API Function
46
  def get_nasa_apod():
47
  url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
48
  response = requests.get(url)
49
  if response.status_code == 200:
50
  data = response.json()
51
  return data.get("url", ""), data.get("title", ""), data.get("explanation", "")
52
+ return "", "NASA Data Unavailable", "I couldn't fetch data from NASA right now."
53
+
54
+ # βœ… Sentiment Analysis
55
+ sentiment_analyzer = pipeline(
56
+ "sentiment-analysis",
57
+ model="distilbert/distilbert-base-uncased-finetuned-sst-2-english"
58
+ )
59
 
60
  def analyze_sentiment(user_text):
61
  result = sentiment_analyzer(user_text)[0]
62
  return result['label']
63
 
64
+ # βœ… Intent Detection
 
 
 
 
 
 
 
65
  def predict_action(user_text):
66
+ if "NASA" in user_text or "space" in user_text:
67
+ return "nasa_info"
 
 
 
 
 
 
 
68
  return "general_query"
69
 
70
+ # βœ… Follow-Up Question Generation
71
  def generate_follow_up(user_text):
72
+ prompt_text = f"Based on: '{user_text}', generate a concise, friendly follow-up."
 
 
 
 
 
 
 
 
 
73
  hf = get_llm_hf_inference(max_new_tokens=80, temperature=0.9)
74
  output = hf.invoke(input=prompt_text).strip()
75
+ return output if output else "Would you like to explore this topic further?"
 
 
 
 
 
 
76
 
77
+ # βœ… Ensure English Responses
78
  def ensure_english(text):
 
79
  try:
80
  detected_lang = detect(text)
81
  if detected_lang != "en":
 
84
  return "⚠️ Language detection failed. Please ask your question again."
85
  return text
86
 
87
+ # βœ… Main Response Function
88
  def get_response(system_message, chat_history, user_text, max_new_tokens=512):
89
+ action = predict_action(user_text) # πŸ”₯ Fix: Define 'action'
90
+
91
+ # βœ… Handle NASA-Specific Queries
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  if action == "nasa_info":
93
  nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
94
  response = f"**{nasa_title}**\n\n{nasa_explanation}"
 
98
  chat_history.append({'role': 'assistant', 'content': follow_up})
99
  return response, follow_up, chat_history, nasa_url
100
 
101
+ # βœ… Set Up LLM Request
102
  hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
 
 
 
 
 
103
 
104
+ # βœ… Format Chat History
105
+ filtered_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in chat_history)
106
 
107
+ # βœ… Prompt Engineering
108
  prompt = PromptTemplate.from_template(
109
+ "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\n"
110
+ "User: {user_text}.\n [/INST]\n"
111
+ "AI: Provide a detailed explanation with depth. "
112
+ "Use a conversational style, starting with 'Certainly!', 'Of course!', or 'Great question!'."
113
+ "🚨 Answer **only in English**."
114
+ "\nHAL:"
 
 
 
115
  )
116
 
117
+ # βœ… Invoke LLM Model
118
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
119
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
120
  response = response.split("HAL:")[-1].strip()
121
 
122
+ # βœ… Ensure English
123
  response = ensure_english(response)
124
 
125
+ # βœ… Fallback Response
126
+ if not response:
127
+ response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
128
 
129
+ chat_history.append({'role': 'user', 'content': user_text})
130
+ chat_history.append({'role': 'assistant', 'content': response})
131
 
132
+ follow_up = generate_follow_up(user_text)
133
+ chat_history.append({'role': 'assistant', 'content': follow_up})
134
 
135
+ return response, follow_up, chat_history, None
136
+
137
+ # βœ… Streamlit UI
138
+ st.title("πŸš€ HAL - NASA AI Assistant")
139
  st.markdown("🌌 *Ask me about space, NASA, and beyond!*")
140
 
141
+ # βœ… Reset Chat Button
142
  if st.sidebar.button("Reset Chat"):
143
  st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
144
  st.session_state.response_ready = False
145
  st.session_state.follow_up = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ # βœ… Chat UI
148
  user_input = st.chat_input("Type your message here...")
149
 
150
  if user_input:
 
155
  )
156
  if image_url:
157
  st.image(image_url, caption="NASA Image of the Day")
158
+
159
  st.session_state.follow_up = follow_up
160
  st.session_state.response_ready = True
161
 
162
+ if st.session_state.response_ready and st.session_state.follow_up:
163
+ st.write(f"**HAL:** {st.session_state.follow_up}")
164
+ st.session_state.response_ready = False