CCockrum commited on
Commit
5c095c6
·
verified ·
1 Parent(s): 2825ff1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -89
app.py CHANGED
@@ -4,160 +4,184 @@ import streamlit as st
4
  from langchain_huggingface import HuggingFaceEndpoint
5
  from langchain_core.prompts import PromptTemplate
6
  from langchain_core.output_parsers import StrOutputParser
7
- from transformers import pipeline # for Sentiment Analysis
8
- NASA_API_KEY = os.getenv("NASA_API_KEY")
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
11
 
12
- # Initialize sentiment analysis pipeline
13
- sentiment_analyzer = pipeline("sentiment-analysis")
 
 
 
 
14
 
15
- def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
16
- llm = HuggingFaceEndpoint(
 
17
  repo_id=model_id,
18
  max_new_tokens=max_new_tokens,
19
  temperature=temperature,
20
- token=os.getenv("HF_TOKEN") # Hugging Face token from environment variable
 
21
  )
22
- return llm
23
 
24
  def get_nasa_apod():
25
- """
26
- Fetch the Astronomy Picture of the Day (APOD) from the NASA API.
27
- """
28
  url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
29
  response = requests.get(url)
30
  if response.status_code == 200:
31
  data = response.json()
32
- return f"Title: {data['title']}\nExplanation: {data['explanation']}\nURL: {data['url']}"
33
  else:
34
- return "I couldn't fetch data from NASA right now. Please try again later."
35
 
36
  def analyze_sentiment(user_text):
37
- """
38
- Analyzes the sentiment of the user's input to adjust responses.
39
- """
40
  result = sentiment_analyzer(user_text)[0]
41
- sentiment = result['label']
42
- return sentiment
43
 
44
  def predict_action(user_text):
45
- """
46
- Predicts actions based on user input (e.g., fetch space info or general knowledge).
47
- """
48
  if "NASA" in user_text or "space" in user_text:
49
  return "nasa_info"
50
- if "weather" in user_text:
51
- return "weather_info"
52
  return "general_query"
53
 
54
  def generate_follow_up(user_text):
55
  """
56
- Generates a relevant follow-up question based on the user's input.
57
  """
58
  prompt_text = (
59
- f"Given the user's message: '{user_text}', ask one natural follow-up question "
60
- "that suggests a related topic or offers user the opportunity to go in a new direction."
 
61
  )
 
 
62
 
63
- hf = get_llm_hf_inference(max_new_tokens=64, temperature=0.7)
64
- chat = hf.invoke(input=prompt_text)
65
-
66
- return chat.strip()
67
-
68
- def get_response(system_message, chat_history, user_text,
69
- eos_token_id=['User'], max_new_tokens=256, get_llm_hf_kws={}):
70
  sentiment = analyze_sentiment(user_text)
71
  action = predict_action(user_text)
72
 
73
  if action == "nasa_info":
74
- nasa_response = get_nasa_apod()
 
75
  chat_history.append({'role': 'user', 'content': user_text})
76
- chat_history.append({'role': 'assistant', 'content': nasa_response})
77
 
78
  follow_up = generate_follow_up(user_text)
79
  chat_history.append({'role': 'assistant', 'content': follow_up})
80
- return f"{nasa_response}\n\n{follow_up}", chat_history
81
 
82
- hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
83
 
84
  prompt = PromptTemplate.from_template(
85
  (
86
- "[INST] {system_message}"
87
- "\nCurrent Conversation:\n{chat_history}\n\n"
88
- "\nUser: {user_text}.\n [/INST]"
89
- "\nAI:"
90
  )
91
  )
92
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
93
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
94
- response = response.split("AI:")[-1]
95
 
96
  chat_history.append({'role': 'user', 'content': user_text})
97
  chat_history.append({'role': 'assistant', 'content': response})
98
 
99
- # Modify response based on sentiment analysis (e.g., offer help for negative sentiments)
100
  if sentiment == "NEGATIVE":
101
- response += "\nI'm sorry to hear that. How can I assist you further?"
102
 
103
  follow_up = generate_follow_up(user_text)
104
  chat_history.append({'role': 'assistant', 'content': follow_up})
105
 
106
- return f"{response}\n\n{follow_up}", chat_history
107
 
108
- # Streamlit setup
109
- st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
110
- st.title("NASA Personal Assistant")
111
- st.markdown(f"*This chatbot uses {model_id} and NASA's APIs to provide information and responses.*")
112
 
113
- # Initialize session state
114
- if "chat_history" not in st.session_state:
115
- st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
116
-
117
- # Sidebar for settings
118
  if st.sidebar.button("Reset Chat"):
119
  st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # Main chat interface
122
- user_input = st.chat_input(placeholder="Type your message here...")
123
  if user_input:
124
- response, st.session_state.chat_history = get_response(
125
  system_message="You are a helpful AI assistant.",
126
  user_text=user_input,
127
- chat_history=st.session_state.chat_history,
128
- max_new_tokens=128
129
  )
130
- # Display messages
131
- for message in st.session_state.chat_history:
132
- st.chat_message(message["role"]).write(message["content"])
133
 
 
134
 
 
 
135
 
 
 
136
 
137
- if st.button("Send"):
138
- if user_input:
139
- response, follow_up, st.session_state.chat_history, image_url = get_response(
140
- system_message="You are a helpful AI assistant.",
141
- user_text=user_input,
142
- chat_history=st.session_state.chat_history
143
- )
144
-
145
- # Display response
146
- st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
147
-
148
- # Display NASA image if available
149
- if image_url:
150
- st.image(image_url, caption="NASA Image of the Day")
151
-
152
- # Follow-up question suggestions
153
- follow_up_options = [follow_up, "Explain differently", "Give me an example"]
154
- selected_option = st.radio("What would you like to do next?", follow_up_options)
155
-
156
- if st.button("Continue"):
157
- if selected_option:
158
- response, _, st.session_state.chat_history, _ = get_response(
159
- system_message="You are a helpful AI assistant.",
160
- user_text=selected_option,
161
- chat_history=st.session_state.chat_history
162
- )
163
- st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
 
4
  from langchain_huggingface import HuggingFaceEndpoint
5
  from langchain_core.prompts import PromptTemplate
6
  from langchain_core.output_parsers import StrOutputParser
7
+ from transformers import pipeline
8
+ from config import NASA_API_KEY # Ensure this file exists with your NASA API Key
9
 
10
+ # Set up Streamlit UI
11
+ st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")
12
+
13
+ # --- Ensure Session State Variables are Initialized ---
14
+ if "chat_history" not in st.session_state:
15
+ st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
16
+
17
+ if "response_ready" not in st.session_state:
18
+ st.session_state.response_ready = False # Tracks whether HAL has responded
19
+
20
+ if "follow_up" not in st.session_state:
21
+ st.session_state.follow_up = "" # Stores follow-up question
22
+
23
+ # --- Set Up Model & API Functions ---
24
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
25
 
26
+ # Initialize sentiment analysis pipeline with explicit model specification
27
+ sentiment_analyzer = pipeline(
28
+ "sentiment-analysis",
29
+ model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
30
+ revision="714eb0f"
31
+ )
32
 
33
+ def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.7):
34
+ # Explicitly specify task="text-generation" so that the endpoint knows which task to run
35
+ return HuggingFaceEndpoint(
36
  repo_id=model_id,
37
  max_new_tokens=max_new_tokens,
38
  temperature=temperature,
39
+ token=os.getenv("HF_TOKEN"),
40
+ task="text-generation"
41
  )
 
42
 
43
  def get_nasa_apod():
 
 
 
44
  url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
45
  response = requests.get(url)
46
  if response.status_code == 200:
47
  data = response.json()
48
+ return data.get("url", ""), data.get("title", ""), data.get("explanation", "")
49
  else:
50
+ return "", "NASA Data Unavailable", "I couldn't fetch data from NASA right now. Please try again later."
51
 
52
  def analyze_sentiment(user_text):
 
 
 
53
  result = sentiment_analyzer(user_text)[0]
54
+ return result['label']
 
55
 
56
  def predict_action(user_text):
 
 
 
57
  if "NASA" in user_text or "space" in user_text:
58
  return "nasa_info"
 
 
59
  return "general_query"
60
 
61
  def generate_follow_up(user_text):
62
  """
63
+ Generates a concise and conversational follow-up question related to the user's input.
64
  """
65
  prompt_text = (
66
+ f"Given the user's question: '{user_text}', generate a SHORT and SIMPLE follow-up question. "
67
+ "Make it conversational and friendly. Example: 'Would you like to learn more about the six types of quarks?' "
68
+ "Do NOT provide long explanations—just ask a friendly follow-up question."
69
  )
70
+ hf = get_llm_hf_inference(max_new_tokens=32, temperature=0.7)
71
+ return hf.invoke(input=prompt_text).strip()
72
 
73
+ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
74
+ """
75
+ Generates HAL's response, making it more conversational and engaging.
76
+ """
 
 
 
77
  sentiment = analyze_sentiment(user_text)
78
  action = predict_action(user_text)
79
 
80
  if action == "nasa_info":
81
+ nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
82
+ response = f"**{nasa_title}**\n\n{nasa_explanation}"
83
  chat_history.append({'role': 'user', 'content': user_text})
84
+ chat_history.append({'role': 'assistant', 'content': response})
85
 
86
  follow_up = generate_follow_up(user_text)
87
  chat_history.append({'role': 'assistant', 'content': follow_up})
88
+ return response, follow_up, chat_history, nasa_url
89
 
90
+ hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
91
 
92
  prompt = PromptTemplate.from_template(
93
  (
94
+ "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\n"
95
+ "AI: Keep responses conversational and engaging. Start with a friendly phrase like "
96
+ "'Certainly!', 'Of course!', or 'Great question!' before answering. "
97
+ "Keep responses concise but engaging.\nHAL:"
98
  )
99
  )
100
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
101
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
102
+ response = response.split("HAL:")[-1].strip()
103
 
104
  chat_history.append({'role': 'user', 'content': user_text})
105
  chat_history.append({'role': 'assistant', 'content': response})
106
 
 
107
  if sentiment == "NEGATIVE":
108
+ response = "I'm here to help. Let me know what I can do for you. 😊"
109
 
110
  follow_up = generate_follow_up(user_text)
111
  chat_history.append({'role': 'assistant', 'content': follow_up})
112
 
113
+ return response, follow_up, chat_history, None
114
 
115
+ # --- Chat UI ---
116
+ st.title("🚀 HAL - Your NASA AI Assistant")
117
+ st.markdown("🌌 *Ask me about space, NASA, and beyond!*")
 
118
 
119
+ # Sidebar: Reset Chat
 
 
 
 
120
  if st.sidebar.button("Reset Chat"):
121
  st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
122
+ st.session_state.response_ready = False
123
+ st.session_state.follow_up = ""
124
+ st.experimental_rerun()
125
+
126
+ # Custom Chat Styling
127
+ st.markdown("""
128
+ <style>
129
+ .user-msg {
130
+ background-color: #0078D7;
131
+ color: white;
132
+ padding: 10px;
133
+ border-radius: 10px;
134
+ margin-bottom: 5px;
135
+ width: fit-content;
136
+ max-width: 80%;
137
+ }
138
+ .assistant-msg {
139
+ background-color: #333333;
140
+ color: white;
141
+ padding: 10px;
142
+ border-radius: 10px;
143
+ margin-bottom: 5px;
144
+ width: fit-content;
145
+ max-width: 80%;
146
+ }
147
+ .container {
148
+ display: flex;
149
+ flex-direction: column;
150
+ align-items: flex-start;
151
+ }
152
+ @media (max-width: 600px) {
153
+ .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; }
154
+ }
155
+ </style>
156
+ """, unsafe_allow_html=True)
157
+
158
+ # Chat History Display
159
+ st.markdown("<div class='container'>", unsafe_allow_html=True)
160
+ for message in st.session_state.chat_history:
161
+ if message["role"] == "user":
162
+ st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
163
+ else:
164
+ st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
165
+ st.markdown("</div>", unsafe_allow_html=True)
166
+
167
+ # --- Single Input Box for Both Initial and Follow-Up Messages ---
168
+ user_input = st.chat_input("Type your message here...") # Only ONE chat_input()
169
 
 
 
170
  if user_input:
171
+ response, follow_up, st.session_state.chat_history, image_url = get_response(
172
  system_message="You are a helpful AI assistant.",
173
  user_text=user_input,
174
+ chat_history=st.session_state.chat_history
 
175
  )
 
 
 
176
 
177
+ st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
178
 
179
+ if image_url:
180
+ st.image(image_url, caption="NASA Image of the Day")
181
 
182
+ st.session_state.follow_up = follow_up
183
+ st.session_state.response_ready = True
184
 
185
+ if st.session_state.response_ready and st.session_state.follow_up:
186
+ st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {st.session_state.follow_up}</div>", unsafe_allow_html=True)
187
+ st.session_state.response_ready = False