CCockrum commited on
Commit
6b56b5d
Β·
verified Β·
1 Parent(s): 93a4e54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -65
app.py CHANGED
@@ -5,26 +5,25 @@ from langchain_huggingface import HuggingFaceEndpoint
5
  from langchain_core.prompts import PromptTemplate
6
  from langchain_core.output_parsers import StrOutputParser
7
  from transformers import pipeline
8
- from config import NASA_API_KEY # Import the NASA API key
9
 
 
10
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
11
 
12
  # Initialize sentiment analysis pipeline
13
  sentiment_analyzer = pipeline("sentiment-analysis")
14
 
 
15
  def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
16
- llm = HuggingFaceEndpoint(
17
  repo_id=model_id,
18
  max_new_tokens=max_new_tokens,
19
  temperature=temperature,
20
- token=os.getenv("HF_TOKEN") # Hugging Face token from environment variable
21
  )
22
- return llm
23
 
 
24
  def get_nasa_apod():
25
- """
26
- Fetch NASA Astronomy Picture of the Day (APOD).
27
- """
28
  url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
29
  response = requests.get(url)
30
  if response.status_code == 200:
@@ -33,35 +32,27 @@ def get_nasa_apod():
33
  else:
34
  return "", "NASA Data Unavailable", "I couldn't fetch data from NASA right now. Please try again later."
35
 
 
36
  def analyze_sentiment(user_text):
37
- """
38
- Analyze sentiment of user input.
39
- """
40
  result = sentiment_analyzer(user_text)[0]
41
  return result['label']
42
 
 
43
  def predict_action(user_text):
44
- """
45
- Predicts user's intent based on input.
46
- """
47
  if "NASA" in user_text or "space" in user_text:
48
  return "nasa_info"
49
  return "general_query"
50
 
 
51
  def generate_follow_up(user_text):
52
- """
53
- Generates a follow-up question to continue the conversation.
54
- """
55
  prompt_text = (
56
  f"Based on the user's message: '{user_text}', suggest a natural follow-up question "
57
  "to keep the conversation engaging."
58
  )
59
-
60
  hf = get_llm_hf_inference(max_new_tokens=64, temperature=0.7)
61
- chat = hf.invoke(input=prompt_text)
62
-
63
- return chat.strip()
64
 
 
65
  def get_response(system_message, chat_history, user_text, max_new_tokens=256):
66
  sentiment = analyze_sentiment(user_text)
67
  action = predict_action(user_text)
@@ -81,7 +72,6 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
81
  prompt = PromptTemplate.from_template(
82
  "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
83
  )
84
-
85
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
86
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
87
  response = response.split("AI:")[-1]
@@ -97,67 +87,53 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
97
 
98
  return response, follow_up, chat_history, None
99
 
100
- # Streamlit UI Setup
101
  st.set_page_config(page_title="NASA ChatBot", page_icon="πŸš€")
102
 
103
- st.title("πŸš€ HAL")
104
- # Chat Display with Updated Styling
105
- st.markdown("<div class='container'>", unsafe_allow_html=True)
106
 
107
- for message in st.session_state.chat_history:
108
- if message["role"] == "user":
109
- st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
110
- else:
111
- st.markdown(f"<div class='assistant-msg'><strong>Bot:</strong> {message['content']}</div>", unsafe_allow_html=True)
112
 
113
- st.markdown("</div>", unsafe_allow_html=True)
 
 
 
114
 
115
- # Custom CSS for chat styling
116
  st.markdown("""
117
  <style>
118
- /* Style for chat messages */
119
  .user-msg {
120
  background-color: #0078D7; /* Dark Blue */
121
- color: white; /* White text for contrast */
122
  padding: 10px;
123
  border-radius: 10px;
124
  margin-bottom: 5px;
125
  width: fit-content;
126
  max-width: 80%;
127
  }
128
-
129
  .assistant-msg {
130
  background-color: #333333; /* Dark Gray */
131
- color: white; /* White text for contrast */
132
  padding: 10px;
133
  border-radius: 10px;
134
  margin-bottom: 5px;
135
  width: fit-content;
136
  max-width: 80%;
137
  }
138
-
139
- /* Center messages for better appearance */
140
  .container {
141
  display: flex;
142
  flex-direction: column;
143
  align-items: flex-start;
144
  }
145
-
146
- /* Adjust messages on mobile */
147
  @media (max-width: 600px) {
148
- .user-msg, .assistant-msg {
149
- font-size: 16px;
150
- max-width: 100%;
151
- }
152
  }
153
  </style>
154
  """, unsafe_allow_html=True)
155
 
156
- # Initialize chat history
157
- # Initialize chat history in session state
158
- if "chat_history" not in st.session_state:
159
- st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
160
-
161
  # Chat Display
162
  st.markdown("<div class='container'>", unsafe_allow_html=True)
163
 
@@ -169,21 +145,7 @@ for message in st.session_state.chat_history:
169
 
170
  st.markdown("</div>", unsafe_allow_html=True)
171
 
172
-
173
-
174
- # Sidebar for chat reset
175
- if st.sidebar.button("Reset Chat"):
176
- st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
177
- st.experimental_rerun()
178
-
179
- # Chat display
180
- for message in st.session_state.chat_history:
181
- if message["role"] == "user":
182
- st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
183
- else:
184
- st.markdown(f"<div class='assistant-msg'><strong>Bot:</strong> {message['content']}</div>", unsafe_allow_html=True)
185
-
186
- # User input
187
  user_input = st.text_area("Type your message:", height=100)
188
 
189
  if st.button("Send"):
@@ -201,7 +163,7 @@ if st.button("Send"):
201
  if image_url:
202
  st.image(image_url, caption="NASA Image of the Day")
203
 
204
- # Follow-up options
205
  follow_up_options = [follow_up, "Explain differently", "Give me an example"]
206
  selected_option = st.radio("What would you like to do next?", follow_up_options)
207
 
 
5
  from langchain_core.prompts import PromptTemplate
6
  from langchain_core.output_parsers import StrOutputParser
7
  from transformers import pipeline
8
+ from config import NASA_API_KEY # Ensure this file exists with your NASA API Key
9
 
10
+ # Model settings
11
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
12
 
13
  # Initialize sentiment analysis pipeline
14
  sentiment_analyzer = pipeline("sentiment-analysis")
15
 
16
+ # Function to initialize Hugging Face model
17
  def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
18
+ return HuggingFaceEndpoint(
19
  repo_id=model_id,
20
  max_new_tokens=max_new_tokens,
21
  temperature=temperature,
22
+ token=os.getenv("HF_TOKEN") # Hugging Face API Token
23
  )
 
24
 
25
+ # Function to get NASA Astronomy Picture of the Day
26
  def get_nasa_apod():
 
 
 
27
  url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
28
  response = requests.get(url)
29
  if response.status_code == 200:
 
32
  else:
33
  return "", "NASA Data Unavailable", "I couldn't fetch data from NASA right now. Please try again later."
34
 
35
+ # Function to analyze sentiment of user input
36
  def analyze_sentiment(user_text):
 
 
 
37
  result = sentiment_analyzer(user_text)[0]
38
  return result['label']
39
 
40
+ # Function to predict user intent
41
  def predict_action(user_text):
 
 
 
42
  if "NASA" in user_text or "space" in user_text:
43
  return "nasa_info"
44
  return "general_query"
45
 
46
+ # Function to generate a follow-up question
47
  def generate_follow_up(user_text):
 
 
 
48
  prompt_text = (
49
  f"Based on the user's message: '{user_text}', suggest a natural follow-up question "
50
  "to keep the conversation engaging."
51
  )
 
52
  hf = get_llm_hf_inference(max_new_tokens=64, temperature=0.7)
53
+ return hf.invoke(input=prompt_text).strip()
 
 
54
 
55
+ # Function to process user input and generate a response
56
  def get_response(system_message, chat_history, user_text, max_new_tokens=256):
57
  sentiment = analyze_sentiment(user_text)
58
  action = predict_action(user_text)
 
72
  prompt = PromptTemplate.from_template(
73
  "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
74
  )
 
75
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
76
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
77
  response = response.split("AI:")[-1]
 
87
 
88
  return response, follow_up, chat_history, None
89
 
90
+ # --- Streamlit UI Setup ---
91
  st.set_page_config(page_title="NASA ChatBot", page_icon="πŸš€")
92
 
93
+ st.title("πŸš€ HAL - Your NASA AI Assistant")
94
+ st.markdown("🌌 *Ask me about space, NASA, and beyond!*")
 
95
 
96
+ # Ensure chat history is initialized
97
+ if "chat_history" not in st.session_state:
98
+ st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
 
 
99
 
100
+ # Sidebar for chat reset
101
+ if st.sidebar.button("Reset Chat"):
102
+ st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
103
+ st.experimental_rerun()
104
 
105
+ # Chat Display Styling
106
  st.markdown("""
107
  <style>
 
108
  .user-msg {
109
  background-color: #0078D7; /* Dark Blue */
110
+ color: white;
111
  padding: 10px;
112
  border-radius: 10px;
113
  margin-bottom: 5px;
114
  width: fit-content;
115
  max-width: 80%;
116
  }
 
117
  .assistant-msg {
118
  background-color: #333333; /* Dark Gray */
119
+ color: white;
120
  padding: 10px;
121
  border-radius: 10px;
122
  margin-bottom: 5px;
123
  width: fit-content;
124
  max-width: 80%;
125
  }
 
 
126
  .container {
127
  display: flex;
128
  flex-direction: column;
129
  align-items: flex-start;
130
  }
 
 
131
  @media (max-width: 600px) {
132
+ .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; }
 
 
 
133
  }
134
  </style>
135
  """, unsafe_allow_html=True)
136
 
 
 
 
 
 
137
  # Chat Display
138
  st.markdown("<div class='container'>", unsafe_allow_html=True)
139
 
 
145
 
146
  st.markdown("</div>", unsafe_allow_html=True)
147
 
148
+ # User Input Section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  user_input = st.text_area("Type your message:", height=100)
150
 
151
  if st.button("Send"):
 
163
  if image_url:
164
  st.image(image_url, caption="NASA Image of the Day")
165
 
166
+ # Follow-up question suggestions
167
  follow_up_options = [follow_up, "Explain differently", "Give me an example"]
168
  selected_option = st.radio("What would you like to do next?", follow_up_options)
169