CCockrum committed on
Commit
f009e2d
·
verified ·
1 Parent(s): 9cb1276

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -59
app.py CHANGED
@@ -7,11 +7,11 @@ from langchain_huggingface import HuggingFaceEndpoint
7
  from langchain_core.prompts import PromptTemplate
8
  from langchain_core.output_parsers import StrOutputParser
9
  from transformers import pipeline
10
- from langdetect import detect
11
 
12
  # βœ… Check for GPU or Default to CPU
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
- print(f"βœ… Using device: {device}")
15
 
16
  # βœ… Environment Variables
17
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -25,23 +25,21 @@ if NASA_API_KEY is None:
25
  # βœ… Set Up Streamlit
26
  st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="πŸš€")
27
 
28
- # βœ… Ensure Session State Variables (Maintains Chat History)
29
  if "chat_history" not in st.session_state:
30
  st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
31
  if "response_ready" not in st.session_state:
32
  st.session_state.response_ready = False
33
- if "follow_up" not in st.session_state:
34
- st.session_state.follow_up = ""
35
 
36
- # βœ… Initialize Hugging Face Model (CPU/GPU Compatible)
37
- def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.3):
38
  return HuggingFaceEndpoint(
39
  repo_id=model_id,
40
  max_new_tokens=max_new_tokens,
41
  temperature=temperature,
42
  token=HF_TOKEN,
43
  task="text-generation",
44
- device=-1 if device == "cpu" else 0
45
  )
46
 
47
  # βœ… NASA API Function
@@ -57,7 +55,7 @@ def get_nasa_apod():
57
  sentiment_analyzer = pipeline(
58
  "sentiment-analysis",
59
  model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
60
- device=-1 if device == "cpu" else 0
61
  )
62
 
63
  def analyze_sentiment(user_text):
@@ -70,73 +68,56 @@ def predict_action(user_text):
70
  return "nasa_info"
71
  return "general_query"
72
 
73
- # βœ… Ensure English Responses (Fixed Detection Error)
74
  def ensure_english(text):
75
- """Ensures the response is in English, preventing false language detection errors."""
76
  try:
77
  detected_lang = detect(text)
78
- if detected_lang == "en":
79
- return text # βœ… It's in English, return as-is
80
  except:
81
- pass # πŸ”₯ Ignore detection errors, assume English
82
-
83
- return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
84
 
85
- # βœ… Main Response Function (Fixed History & Language Issues)
86
- def get_response(system_message, user_text, max_new_tokens=800):
87
- """Generates a response and ensures conversation history is updated."""
88
-
89
- chat_history = st.session_state.chat_history # βœ… Get Chat History
90
-
91
- # βœ… Store User Input in Chat History BEFORE Generating Response
92
- chat_history.append({'role': 'user', 'content': user_text})
93
-
94
- # βœ… Detect Intent (NASA vs General AI chat)
95
  action = predict_action(user_text)
96
 
 
97
  if action == "nasa_info":
98
  nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
99
  response = f"**{nasa_title}**\n\n{nasa_explanation}"
100
- follow_up = generate_follow_up(user_text)
101
-
102
- # βœ… Append to chat history
103
  chat_history.append({'role': 'assistant', 'content': response})
104
- chat_history.append({'role': 'assistant', 'content': follow_up})
105
- st.session_state.chat_history = chat_history
106
- return response, follow_up, nasa_url
107
-
108
- # βœ… Format Conversation History for Model Input
109
- formatted_chat_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in chat_history)
110
 
111
  # βœ… Invoke Hugging Face Model
112
- hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.3)
 
 
113
 
114
  prompt = PromptTemplate.from_template(
115
  "[INST] You are a helpful AI assistant.\n\nCurrent Conversation:\n{chat_history}\n\n"
116
  "User: {user_text}.\n [/INST]\n"
117
- "AI: Provide a detailed explanation with depth. Use a conversational tone."
118
  "🚨 Answer **only in English**."
 
119
  "\nHAL:"
120
  )
121
 
122
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
123
- response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=formatted_chat_history))
124
  response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
125
 
126
- # βœ… Prevent False Language Errors
127
  response = ensure_english(response)
128
 
129
  if not response:
130
  response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
131
 
132
- follow_up = generate_follow_up(user_text)
133
-
134
- # βœ… Append Responses to Chat History
135
- chat_history.append({'role': 'assistant', 'content': response})
136
- chat_history.append({'role': 'assistant', 'content': follow_up})
137
- st.session_state.chat_history = chat_history
138
 
139
- return response, follow_up, None
140
 
141
  # βœ… Streamlit UI
142
  st.title("πŸš€ HAL - NASA AI Assistant")
@@ -145,7 +126,7 @@ st.title("πŸš€ HAL - NASA AI Assistant")
145
  st.markdown("""
146
  <style>
147
  .user-msg, .assistant-msg {
148
- padding: 10px;
149
  border-radius: 10px;
150
  margin-bottom: 5px;
151
  width: fit-content;
@@ -159,22 +140,29 @@ st.markdown("""
159
  </style>
160
  """, unsafe_allow_html=True)
161
 
 
 
 
 
162
 
163
- # βœ… Display Chat History
164
- for message in st.session_state.chat_history:
165
- st.markdown(f"**{message['role'].capitalize()}**: {message['content']}")
166
-
167
- # βœ… Chat Input
168
  user_input = st.chat_input("Type your message here...")
169
 
170
  if user_input:
171
- response, follow_up, image_url = get_response("You are a helpful AI assistant.", user_input)
 
 
 
 
172
 
173
  if response:
174
- st.markdown(f"**HAL**: {response}")
175
-
176
- if follow_up:
177
- st.markdown(f"**HAL**: {follow_up}")
178
 
179
- if image_url:
180
- st.image(image_url, caption="NASA Image of the Day")
 
 
 
 
 
 
 
7
  from langchain_core.prompts import PromptTemplate
8
  from langchain_core.output_parsers import StrOutputParser
9
  from transformers import pipeline
10
+ from langdetect import detect # Ensure this package is installed
11
 
12
  # βœ… Check for GPU or Default to CPU
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ print(f"βœ… Using device: {device}") # Debugging info
15
 
16
  # βœ… Environment Variables
17
  HF_TOKEN = os.getenv("HF_TOKEN")
 
25
# Streamlit page configuration.
st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="πŸš€")

# Seed session-state defaults exactly once per session so the chat history
# survives Streamlit's script reruns.
_session_defaults = {
    "chat_history": [{"role": "assistant", "content": "Hello! How can I assist you today?"}],
    "response_ready": False,
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
 
 
33
 
34
# Factory for the Hugging Face inference client used for text generation.
def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.8):
    """Build a HuggingFaceEndpoint configured for text generation.

    Uses the module-level `device` flag to request CPU (-1) or the first
    GPU (0).  NOTE(review): `device` on a remote *Endpoint* client looks
    unusual — confirm HuggingFaceEndpoint actually accepts it.
    """
    endpoint_config = {
        "repo_id": model_id,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "token": HF_TOKEN,
        "task": "text-generation",
        "device": -1 if device == "cpu" else 0,
    }
    return HuggingFaceEndpoint(**endpoint_config)
44
 
45
  # βœ… NASA API Function
 
55
# Sentiment-analysis pipeline; -1 pins it to CPU, 0 to the first GPU,
# mirroring the module-level `device` detection above.
_sentiment_device = -1 if device == "cpu" else 0
sentiment_analyzer = pipeline(
    task="sentiment-analysis",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
    device=_sentiment_device,
)
60
 
61
  def analyze_sentiment(user_text):
 
68
  return "nasa_info"
69
  return "general_query"
70
 
71
# Ensure English Responses
def ensure_english(text):
    """Return *text* unchanged when it is detected as English.

    Returns a user-facing warning string when the text is detected as a
    different language, or when language detection itself fails (langdetect
    raises on empty or very short/ambiguous input).
    """
    try:
        detected_lang = detect(text)
        if detected_lang != "en":
            return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
    # Narrowed from a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt.
    except Exception:
        return "⚠️ Language detection failed. Please ask your question again."
    return text
 
80
 
81
# Main Response Function (Follow-Up Question Removed)
def get_response(system_message, chat_history, user_text, max_new_tokens=800):
    """Generate HAL's reply to *user_text* and record the turn in history.

    Parameters
    ----------
    system_message : str
        System instruction passed through to the prompt inputs.
    chat_history : list[dict]
        Conversation so far as {'role', 'content'} dicts; mutated in place.
    user_text : str
        The user's latest message.
    max_new_tokens : int
        Generation budget for the model.

    Returns
    -------
    tuple[str, list[dict]]
        (response, chat_history).  Always a 2-tuple: the call site unpacks
        exactly two values, so the NASA branch must not return a third
        element (previously it returned (response, chat_history, nasa_url),
        which raised ValueError on every NASA query).
    """
    action = predict_action(user_text)

    # Handle NASA-specific queries via the APOD API.
    if action == "nasa_info":
        nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
        response = f"**{nasa_title}**\n\n{nasa_explanation}"
        if nasa_url:
            # Embed the image in the markdown response instead of returning
            # it as a third tuple element the caller cannot unpack.
            response += f"\n\n![NASA Image of the Day]({nasa_url})"
        chat_history.append({'role': 'user', 'content': user_text})
        chat_history.append({'role': 'assistant', 'content': response})
        return response, chat_history

    # Invoke Hugging Face model for general queries.
    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)

    filtered_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in chat_history)

    prompt = PromptTemplate.from_template(
        "[INST] You are a helpful AI assistant.\n\nCurrent Conversation:\n{chat_history}\n\n"
        "User: {user_text}.\n [/INST]\n"
        "AI: Provide a detailed explanation with depth. Use a conversational tone. "
        "🚨 Answer **only in English**."
        "Ensure a friendly, engaging tone."
        "\nHAL:"
    )

    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
    # Strip any echoed "HAL:" prefix the model may emit.
    response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()

    response = ensure_english(response)

    if not response:
        response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"

    # Preserve conversation history.  Append to the list the caller passed in
    # (the call site passes st.session_state.chat_history, so session state
    # and the returned list stay the same object) — previously this branch
    # appended to st.session_state directly while the NASA branch did not.
    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})

    return response, chat_history
121
 
122
  # βœ… Streamlit UI
123
  st.title("πŸš€ HAL - NASA AI Assistant")
 
126
  st.markdown("""
127
  <style>
128
  .user-msg, .assistant-msg {
129
+ padding: 11px;
130
  border-radius: 10px;
131
  margin-bottom: 5px;
132
  width: fit-content;
 
140
  </style>
141
  """, unsafe_allow_html=True)
142
 
143
# Reset Chat Button: wipe the conversation and start over.
if st.sidebar.button("Reset Chat"):
    st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
    st.session_state.response_ready = False

# Chat UI
user_input = st.chat_input("Type your message here...")

if user_input:
    # get_response appends both the user turn and HAL's reply to the
    # history, so rendering happens once in the history loop below.
    # (Previously the fresh response was also rendered here immediately,
    # which displayed it twice.)
    response, st.session_state.chat_history = get_response(
        system_message="You are a helpful AI assistant.",
        user_text=user_input,
        chat_history=st.session_state.chat_history
    )

# Display chat history (includes the newest turn, if any).
st.markdown("<div class='container'>", unsafe_allow_html=True)
for message in st.session_state.chat_history:
    if message["role"] == "user":
        st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
    else:
        st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)