CCockrum commited on
Commit
391ca85
Β·
verified Β·
1 Parent(s): 6b06c5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -45
app.py CHANGED
@@ -8,10 +8,52 @@ from langchain_core.prompts import PromptTemplate
8
  from langchain_core.output_parsers import StrOutputParser
9
  from transformers import pipeline
10
 
11
- # Must be the very first command!
12
  st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="πŸš€")
13
 
14
- # Appearance adjustments (if any) could be added here as well
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Initialize session state variables
17
  if "chat_history" not in st.session_state:
@@ -21,9 +63,17 @@ if "response_ready" not in st.session_state:
21
  if "follow_up" not in st.session_state:
22
  st.session_state.follow_up = ""
23
  if "saved_conversations" not in st.session_state:
24
- st.session_state.saved_conversations = {} # Dictionary: conv_id -> chat_history
25
 
26
- # --- Set Up Model & API Functions ---
 
 
 
 
 
 
 
 
27
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
28
  sentiment_analyzer = pipeline(
29
  "sentiment-analysis",
@@ -36,34 +86,31 @@ def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.7)
36
  repo_id=model_id,
37
  max_new_tokens=max_new_tokens,
38
  temperature=temperature,
39
- token=os.getenv("HF_TOKEN"),
40
  task="text-generation"
41
  )
42
 
43
  def get_nasa_apod():
44
- url = f"https://api.nasa.gov/planetary/apod?api_key={os.getenv('NASA_API_KEY')}"
45
  response = requests.get(url)
46
  if response.status_code == 200:
47
  data = response.json()
48
  return data.get("url", ""), data.get("title", ""), data.get("explanation", "")
49
  else:
50
- return "", "NASA Data Unavailable", "I couldn't fetch data from NASA right now. Please try again later."
51
 
52
  def analyze_sentiment(user_text):
53
  result = sentiment_analyzer(user_text)[0]
54
  return result['label']
55
 
56
  def predict_action(user_text):
57
- if "NASA" in user_text or "space" in user_text:
58
- return "nasa_info"
59
- return "general_query"
60
 
61
  def generate_follow_up(user_text):
62
  prompt_text = (
63
- f"Based on the user's question: '{user_text}', generate two concise, friendly follow-up questions "
64
- "that invite further discussion. For example, one variant might ask, "
65
- "'Would you like to know more about the six types of quarks?' and another might ask, "
66
- "'Would you like to explore another aspect of quantum physics?'. Do not include extra commentary."
67
  )
68
  hf = get_llm_hf_inference(max_new_tokens=80, temperature=0.9)
69
  output = hf.invoke(input=prompt_text).strip()
@@ -84,6 +131,7 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
84
  style_instruction = match.group(2).strip().capitalize()
85
  style_instruction = f" Please respond in the voice of {style_instruction}."
86
 
 
87
  if action == "nasa_info":
88
  nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
89
  response = f"**{nasa_title}**\n\n{nasa_explanation}"
@@ -92,7 +140,7 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
92
  follow_up = generate_follow_up(user_text)
93
  chat_history.append({'role': 'assistant', 'content': follow_up})
94
  return response, follow_up, chat_history, nasa_url
95
-
96
  hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
97
  filtered_history = ""
98
  for message in chat_history:
@@ -101,60 +149,51 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
101
  filtered_history += f"{message['role']}: {message['content']}\n"
102
 
103
  style_clause = style_instruction if style_instruction else ""
104
-
105
  prompt = PromptTemplate.from_template(
106
  (
107
  "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\n"
108
  "User: {user_text}.\n [/INST]\n"
109
- "AI: Please answer the user's question in depth in a friendly, conversational tone, starting with a phrase like "
110
- "'Certainly!', 'Of course!', or 'Great question!'." + style_clause +
111
  "\nHAL:"
112
  )
113
  )
114
-
115
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
116
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
117
  response = response.split("HAL:")[-1].strip()
118
-
119
  if not response:
120
  response = "Certainly, here is an in-depth explanation: [Fallback explanation]."
121
-
122
  chat_history.append({'role': 'user', 'content': user_text})
123
  chat_history.append({'role': 'assistant', 'content': response})
124
-
125
  if sentiment == "NEGATIVE" and not user_text.strip().endswith("?"):
126
  response = "I'm sorry you're feeling this way. I'm here to help. What can I do to assist you further?"
127
  chat_history[-1]['content'] = response
128
-
129
  follow_up = generate_follow_up(user_text)
130
  chat_history.append({'role': 'assistant', 'content': follow_up})
131
-
132
  return response, follow_up, chat_history, None
133
 
134
- # --- Sidebar: Saved Conversations ---
135
- #st.sidebar.header("Saved Conversations")
136
- #if st.sidebar.button("Save Current Conversation"):
137
- #conv_id = f"Conv {len(st.session_state.saved_conversations) + 1}"
138
- # Save a copy of the current conversation history
139
- #st.session_state.saved_conversations[conv_id] = st.session_state.chat_history.copy()
140
- #st.sidebar.success(f"Conversation saved as {conv_id}.")
141
-
142
- # Display saved conversation links
143
- #if st.session_state.saved_conversations:
144
- #for conv_id in st.session_state.saved_conversations:
145
- #if st.sidebar.button(f"Load {conv_id}"):
146
- #st.session_state.chat_history = st.session_state.saved_conversations[conv_id].copy()
147
- #st.sidebar.info(f"Loaded {conv_id}.")
148
-
149
- # --- Chat UI Rendering ---
150
  st.title("πŸš€ HAL - Your NASA AI Assistant")
151
  st.markdown("🌌 *Ask me about space, NASA, and beyond!*")
152
 
153
- #if st.sidebar.button("Reset Chat"):
154
- #st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
155
- #st.session_state.response_ready = False
156
- #st.session_state.follow_up = ""
157
- #st.experimental_rerun()
158
 
159
  st.markdown("<div class='container'>", unsafe_allow_html=True)
160
  for message in st.session_state.chat_history:
 
8
  from langchain_core.output_parsers import StrOutputParser
9
  from transformers import pipeline
10
 
11
+ # Must be the first Streamlit command!
12
  st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="πŸš€")
13
 
14
+ # Appearance settings (optional): you can modify these as needed
15
+ user_bg_color = "#0078D7"
16
+ assistant_bg_color = "#333333"
17
+ text_color = "#FFFFFF"
18
+ font_choice = "sans serif"
19
+
20
+ # Inject custom CSS for appearance
21
+ custom_css = f"""
22
+ <style>
23
+ .user-msg {{
24
+ background-color: {user_bg_color};
25
+ color: {text_color};
26
+ padding: 10px;
27
+ border-radius: 10px;
28
+ margin-bottom: 5px;
29
+ width: fit-content;
30
+ max-width: 80%;
31
+ font-family: {font_choice};
32
+ }}
33
+ .assistant-msg {{
34
+ background-color: {assistant_bg_color};
35
+ color: {text_color};
36
+ padding: 10px;
37
+ border-radius: 10px;
38
+ margin-bottom: 5px;
39
+ width: fit-content;
40
+ max-width: 80%;
41
+ font-family: {font_choice};
42
+ }}
43
+ .container {{
44
+ display: flex;
45
+ flex-direction: column;
46
+ align-items: flex-start;
47
+ }}
48
+ @media (max-width: 600px) {{
49
+ .user-msg, .assistant-msg {{
50
+ font-size: 16px;
51
+ max-width: 100%;
52
+ }}
53
+ }}
54
+ </style>
55
+ """
56
+ st.markdown(custom_css, unsafe_allow_html=True)
57
 
58
  # Initialize session state variables
59
  if "chat_history" not in st.session_state:
 
63
  if "follow_up" not in st.session_state:
64
  st.session_state.follow_up = ""
65
  if "saved_conversations" not in st.session_state:
66
+ st.session_state.saved_conversations = {} # dict mapping conv_id -> chat_history
67
 
68
+ # Set up keys from environment variables
69
+ HF_TOKEN = os.getenv("HF_TOKEN")
70
+ if not HF_TOKEN:
71
+ raise ValueError("HF_TOKEN environment variable not set.")
72
+ NASA_API_KEY = os.getenv("NASA_API_KEY")
73
+ if not NASA_API_KEY:
74
+ raise ValueError("NASA_API_KEY environment variable not set.")
75
+
76
+ # --- Model & API functions ---
77
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
78
  sentiment_analyzer = pipeline(
79
  "sentiment-analysis",
 
86
  repo_id=model_id,
87
  max_new_tokens=max_new_tokens,
88
  temperature=temperature,
89
+ token=HF_TOKEN,
90
  task="text-generation"
91
  )
92
 
93
  def get_nasa_apod():
94
+ url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
95
  response = requests.get(url)
96
  if response.status_code == 200:
97
  data = response.json()
98
  return data.get("url", ""), data.get("title", ""), data.get("explanation", "")
99
  else:
100
+ return "", "NASA Data Unavailable", "I couldn't fetch data from NASA right now."
101
 
102
  def analyze_sentiment(user_text):
103
  result = sentiment_analyzer(user_text)[0]
104
  return result['label']
105
 
106
  def predict_action(user_text):
107
+ return "nasa_info" if ("nasa" in user_text.lower() or "space" in user_text.lower()) else "general_query"
 
 
108
 
109
  def generate_follow_up(user_text):
110
  prompt_text = (
111
+ f"Based on the user's question: '{user_text}', generate two concise, friendly follow-up questions that invite further discussion. "
112
+ "For example, one could be 'Would you like to know more about the six types of quarks?' and another 'Would you like to explore another aspect of quantum physics?'. "
113
+ "Return only the questions, separated by a newline."
 
114
  )
115
  hf = get_llm_hf_inference(max_new_tokens=80, temperature=0.9)
116
  output = hf.invoke(input=prompt_text).strip()
 
131
  style_instruction = match.group(2).strip().capitalize()
132
  style_instruction = f" Please respond in the voice of {style_instruction}."
133
 
134
+ # Handle NASA queries separately
135
  if action == "nasa_info":
136
  nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
137
  response = f"**{nasa_title}**\n\n{nasa_explanation}"
 
140
  follow_up = generate_follow_up(user_text)
141
  chat_history.append({'role': 'assistant', 'content': follow_up})
142
  return response, follow_up, chat_history, nasa_url
143
+
144
  hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
145
  filtered_history = ""
146
  for message in chat_history:
 
149
  filtered_history += f"{message['role']}: {message['content']}\n"
150
 
151
  style_clause = style_instruction if style_instruction else ""
 
152
  prompt = PromptTemplate.from_template(
153
  (
154
  "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\n"
155
  "User: {user_text}.\n [/INST]\n"
156
+ "AI: Please provide a detailed, in-depth answer in a friendly, conversational tone. "
157
+ "Begin with a phrase like 'Certainly!', 'Of course!', or 'Great question!'." + style_clause +
158
  "\nHAL:"
159
  )
160
  )
 
161
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
162
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
163
  response = response.split("HAL:")[-1].strip()
 
164
  if not response:
165
  response = "Certainly, here is an in-depth explanation: [Fallback explanation]."
 
166
  chat_history.append({'role': 'user', 'content': user_text})
167
  chat_history.append({'role': 'assistant', 'content': response})
 
168
  if sentiment == "NEGATIVE" and not user_text.strip().endswith("?"):
169
  response = "I'm sorry you're feeling this way. I'm here to help. What can I do to assist you further?"
170
  chat_history[-1]['content'] = response
 
171
  follow_up = generate_follow_up(user_text)
172
  chat_history.append({'role': 'assistant', 'content': follow_up})
 
173
  return response, follow_up, chat_history, None
174
 
175
+ # --- Sidebar: Save/Load Conversations ---
176
+ st.sidebar.header("Saved Conversations")
177
+ if st.sidebar.button("Save Current Conversation"):
178
+ conv_id = f"Conv {len(st.session_state.saved_conversations) + 1}"
179
+ st.session_state.saved_conversations[conv_id] = st.session_state.chat_history.copy()
180
+ st.sidebar.success(f"Conversation saved as {conv_id}.")
181
+
182
+ if st.session_state.saved_conversations:
183
+ for conv_id in st.session_state.saved_conversations:
184
+ if st.sidebar.button(f"Load {conv_id}"):
185
+ st.session_state.chat_history = st.session_state.saved_conversations[conv_id].copy()
186
+ st.sidebar.info(f"Loaded {conv_id}.")
187
+
188
+ # --- Main Chat UI ---
 
 
189
  st.title("πŸš€ HAL - Your NASA AI Assistant")
190
  st.markdown("🌌 *Ask me about space, NASA, and beyond!*")
191
 
192
+ if st.sidebar.button("Reset Chat"):
193
+ st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
194
+ st.session_state.response_ready = False
195
+ st.session_state.follow_up = ""
196
+ st.experimental_rerun()
197
 
198
  st.markdown("<div class='container'>", unsafe_allow_html=True)
199
  for message in st.session_state.chat_history: