CCockrum committed
Commit 1755fdf · verified · 1 Parent(s): 4dfa210

Update app.py

Files changed (1):
  1. app.py +33 -61
app.py CHANGED
@@ -7,11 +7,11 @@ from langchain_huggingface import HuggingFaceEndpoint
 from langchain_core.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from transformers import pipeline
-from langdetect import detect  # Ensure this package is installed
+from langdetect import detect
 
 # βœ… Check for GPU or Default to CPU
 device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"βœ… Using device: {device}")  # Debugging info
+print(f"βœ… Using device: {device}")
 
 # βœ… Environment Variables
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -25,7 +25,7 @@ if NASA_API_KEY is None:
 # βœ… Set Up Streamlit
 st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="πŸš€")
 
-# βœ… Initialize Session State Variables (Ensuring Chat History Persists)
+# βœ… Ensure Session State Variables (Maintains Chat History)
 if "chat_history" not in st.session_state:
     st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
 if "response_ready" not in st.session_state:
@@ -33,15 +33,15 @@ if "response_ready" not in st.session_state:
 if "follow_up" not in st.session_state:
     st.session_state.follow_up = ""
 
-# βœ… Initialize Hugging Face Model (Explicitly Set to CPU/GPU)
-def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.8):
+# βœ… Initialize Hugging Face Model (CPU/GPU Compatible)
+def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.6):
     return HuggingFaceEndpoint(
         repo_id=model_id,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
         token=HF_TOKEN,
         task="text-generation",
-        device=-1 if device == "cpu" else 0  # βœ… Force CPU (-1) or GPU (0)
+        device=-1 if device == "cpu" else 0
     )
 
 # βœ… NASA API Function
@@ -57,7 +57,7 @@ def get_nasa_apod():
 sentiment_analyzer = pipeline(
     "sentiment-analysis",
     model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
-    device=-1 if device == "cpu" else 0  # βœ… Force CPU (-1) or GPU (0)
+    device=-1 if device == "cpu" else 0
 )
 
 def analyze_sentiment(user_text):
@@ -70,17 +70,19 @@ def predict_action(user_text):
         return "nasa_info"
     return "general_query"
 
-# βœ… Ensure English Responses
+# βœ… Ensure English Responses (Fixed Detection Error)
 def ensure_english(text):
+    """Ensures the response is in English, preventing false language detection errors."""
     try:
        detected_lang = detect(text)
-        if detected_lang != "en":
-            return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
+        if detected_lang == "en":
+            return text  # βœ… It's in English, return as-is
     except:
-        return "⚠️ Language detection failed. Please ask your question again."
-    return text
+        pass  # πŸ”₯ Ignore detection errors, assume English
+
+    return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
 
-# βœ… Follow-Up Question Generation
+# βœ… Follow-Up Question Generation (Ensures Proper Formatting)
 def generate_follow_up(user_text):
     """Generates a structured follow-up question in a concise format."""
@@ -90,63 +92,59 @@ def generate_follow_up(user_text):
         "Ensure it's concise and structured exactly as requested without extra commentary."
     )
 
-    hf = get_llm_hf_inference(max_new_tokens=30, temperature=0.8)  # πŸ”₯ Lower temp for consistency
+    hf = get_llm_hf_inference(max_new_tokens=30, temperature=0.8)
     output = hf.invoke(input=prompt_text).strip()
 
-    # βœ… Extract the relevant part using regex to remove unwanted symbols or truncations
+    # βœ… Extract relevant part, removing unwanted symbols
     cleaned_output = re.sub(r"```|''|\"", "", output).strip()
 
-    # βœ… Ensure output is formatted correctly
     if "Would you like to learn more about" not in cleaned_output:
         cleaned_output = "Would you like to explore another related topic or ask about something else?"
 
     return cleaned_output
 
-# βœ… Main Response Function
+# βœ… Main Response Function (Fixed History & Language Issues)
 def get_response(system_message, user_text, max_new_tokens=800):
-    """
-    Generates a response from the chatbot, ensures conversation history is updated, and includes a follow-up question.
-    """
+    """Generates a response and ensures conversation history is updated."""
 
-    chat_history = st.session_state.chat_history  # βœ… Get Chat History Reference
+    chat_history = st.session_state.chat_history  # βœ… Get Chat History
+
+    # βœ… Store User Input in Chat History BEFORE Generating Response
+    chat_history.append({'role': 'user', 'content': user_text})
 
-    # βœ… Detect Intent (NASA query vs General AI chat)
+    # βœ… Detect Intent (NASA vs General AI chat)
     action = predict_action(user_text)
 
-    # βœ… Handle NASA-Specific Queries
     if action == "nasa_info":
         nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
         response = f"**{nasa_title}**\n\n{nasa_explanation}"
         follow_up = generate_follow_up(user_text)
 
         # βœ… Append to chat history
-        chat_history.append({'role': 'user', 'content': user_text})
         chat_history.append({'role': 'assistant', 'content': response})
         chat_history.append({'role': 'assistant', 'content': follow_up})
-        st.session_state.chat_history = chat_history  # βœ… Update Session History
+        st.session_state.chat_history = chat_history
         return response, follow_up, nasa_url
 
     # βœ… Format Conversation History for Model Input
     formatted_chat_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in chat_history)
 
     # βœ… Invoke Hugging Face Model
-    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
+    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.7)
 
-    # βœ… Define the Chat Prompt Template
     prompt = PromptTemplate.from_template(
         "[INST] You are a helpful AI assistant.\n\nCurrent Conversation:\n{chat_history}\n\n"
         "User: {user_text}.\n [/INST]\n"
-        "AI: Provide a detailed explanation with depth. Use a conversational tone. "
+        "AI: Provide a detailed explanation with depth. Use a conversational tone."
         "🚨 Answer **only in English**."
-        "Ensure a friendly, engaging tone."
         "\nHAL:"
     )
 
-    # βœ… Generate AI Response
     chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
     response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=formatted_chat_history))
     response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
 
+    # βœ… Prevent False Language Errors
     response = ensure_english(response)
 
     if not response:
@@ -154,43 +152,19 @@ def get_response(system_message, user_text, max_new_tokens=800):
     follow_up = generate_follow_up(user_text)
 
-    # βœ… Append to Chat History
-    chat_history.append({'role': 'user', 'content': user_text})
+    # βœ… Append Responses to Chat History
     chat_history.append({'role': 'assistant', 'content': response})
     chat_history.append({'role': 'assistant', 'content': follow_up})
-    st.session_state.chat_history = chat_history  # βœ… Persist History
+    st.session_state.chat_history = chat_history
 
     return response, follow_up, None
 
 # βœ… Streamlit UI
 st.title("πŸš€ HAL - NASA AI Assistant")
 
-# βœ… Justify all chatbot responses
-st.markdown("""
-    <style>
-    .user-msg, .assistant-msg {
-        padding: 10px;
-        border-radius: 10px;
-        margin-bottom: 5px;
-        width: fit-content;
-        max-width: 80%;
-        text-align: justify;
-    }
-    .user-msg { background-color: #696969; color: white; }
-    .assistant-msg { background-color: #333333; color: white; }
-    .container { display: flex; flex-direction: column; align-items: flex-start; }
-    @media (max-width: 600px) { .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; } }
-    </style>
-""", unsafe_allow_html=True)
-
 # βœ… Display Chat History
-st.markdown("<div class='container'>", unsafe_allow_html=True)
 for message in st.session_state.chat_history:
-    if message["role"] == "user":
-        st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
-    else:
-        st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
-st.markdown("</div>", unsafe_allow_html=True)
+    st.markdown(f"**{message['role'].capitalize()}**: {message['content']}")
 
 # βœ… Chat Input
 user_input = st.chat_input("Type your message here...")
@@ -199,12 +173,10 @@ if user_input:
     response, follow_up, image_url = get_response("You are a helpful AI assistant.", user_input)
 
     if response:
-        st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
+        st.markdown(f"**HAL**: {response}")
 
     if follow_up:
-        st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {follow_up}</div>", unsafe_allow_html=True)
+        st.markdown(f"**HAL**: {follow_up}")
 
     if image_url:
         st.image(image_url, caption="NASA Image of the Day")
-
-    st.session_state.response_ready = True
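
A note on the `device=-1 if device == "cpu" else 0` idiom that appears twice in the diff: for `transformers.pipeline`, `device=-1` selects CPU and `device=0` selects the first CUDA device. A minimal standalone sketch of the sentiment pipeline (model name taken from the diff; the sample sentence is illustrative):

import torch
from transformers import pipeline

# -1 = CPU, 0 = first CUDA device (the convention used in app.py)
device = "cuda" if torch.cuda.is_available() else "cpu"
sentiment_analyzer = pipeline(
    "sentiment-analysis",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
    device=-1 if device == "cpu" else 0,
)
print(sentiment_analyzer("The launch was a complete success!"))
# e.g. [{'label': 'POSITIVE', 'score': 0.99...}]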
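
The `ensure_english` rewrite is the subtlest fix in the commit: `langdetect.detect` raises `LangDetectException` on short or symbol-heavy strings, so the old version could turn a valid English reply into an error message. A sketch of that failure mode, assuming `langdetect` is installed (`is_probably_english` is an illustrative helper, not part of app.py):

from langdetect import detect, DetectorFactory
from langdetect.lang_detect_exception import LangDetectException

DetectorFactory.seed = 0  # detect() is non-deterministic unless seeded

def is_probably_english(text: str) -> bool:
    try:
        return detect(text) == "en"
    except LangDetectException:
        return True  # too little signal to classify; assume English

print(is_probably_english("Hello! How can I assist you today?"))  # True
print(is_probably_english("πŸš€ 42"))  # True: detection fails, treated as English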
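
The response path in `get_response` is a LangChain Expression Language pipe, `prompt | llm | StrOutputParser()`, followed by splitting on the `HAL:` marker. A self-contained sketch of that shape, with a `RunnableLambda` standing in for `HuggingFaceEndpoint` so it runs without a token (the stub reply is illustrative):

from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda

prompt = PromptTemplate.from_template(
    "[INST] You are a helpful AI assistant.\n\nCurrent Conversation:\n{chat_history}\n\n"
    "User: {user_text}.\n [/INST]\nHAL:"
)

# Stand-in for get_llm_hf_inference(); a real run would use HuggingFaceEndpoint
fake_llm = RunnableLambda(lambda prompt_value: "HAL: Happy to help with that!")

chain = prompt | fake_llm | StrOutputParser()
response = chain.invoke({"chat_history": "", "user_text": "Hello"})
# Same post-processing as get_response()
response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
print(response)  # Happy to help with that!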
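
Much of the +33/-61 churn reduces to one ordering change: the user turn is appended to `st.session_state.chat_history` before the model is invoked, so `formatted_chat_history` already contains it and the transcript re-renders correctly on the next run. A minimal sketch of the pattern with a stubbed generator (`fake_reply` is illustrative, not part of app.py):

import streamlit as st

if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        {"role": "assistant", "content": "Hello! How can I assist you today?"}
    ]

def fake_reply(history: list[dict]) -> str:  # stand-in for the LLM call
    return f"(echo) {history[-1]['content']}"

# Render the persisted transcript on every rerun
for msg in st.session_state.chat_history:
    st.markdown(f"**{msg['role'].capitalize()}**: {msg['content']}")

if user_input := st.chat_input("Type your message here..."):
    # Append the user turn BEFORE generating, as the commit does
    st.session_state.chat_history.append({"role": "user", "content": user_input})
    reply = fake_reply(st.session_state.chat_history)
    st.session_state.chat_history.append({"role": "assistant", "content": reply})
    st.markdown(f"**HAL**: {reply}")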