CCockrum committed on
Commit a9e4498 · verified · 1 Parent(s): b108c4f

Update app.py

Files changed (1):
  1. app.py +98 -125
app.py CHANGED
@@ -1,113 +1,95 @@
  import os
- import streamlit as st
- from langdetect import detect
+ import re
+ import requests
  import torch
+ import streamlit as st
+ from langchain_huggingface import HuggingFaceEndpoint
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from transformers import pipeline
+ from langdetect import detect  # Ensure this package is installed

- # Check if GPU is available but don't load anything yet
+ # ✅ Check for GPU or Default to CPU
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"✅ Using device: {device}")  # Debugging info
+
+ # ✅ Environment Variables
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ if HF_TOKEN is None:
+     raise ValueError("HF_TOKEN is not set. Please add it to your environment variables.")
+
+ NASA_API_KEY = os.getenv("NASA_API_KEY")
+ if NASA_API_KEY is None:
+     raise ValueError("NASA_API_KEY is not set. Please add it to your environment variables.")
+
+ # ✅ Set Up Streamlit
  st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")

- # Initialize session state variables
+ # ✅ Initialize Session State Variables (Ensuring Chat History Persists)
  if "chat_history" not in st.session_state:
-     st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you with NASA-related information today?"}]
+     st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]

- if "model_loaded" not in st.session_state:
-     st.session_state.model_loaded = False
+ # ✅ Initialize Hugging Face Model (Explicitly Set to CPU/GPU)
+ def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.3):
+     return HuggingFaceEndpoint(
+         repo_id=model_id,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,  # 🔥 Lowered temperature for more factual and structured responses
+         token=HF_TOKEN,
+         task="text-generation",
+         device=-1 if device == "cpu" else 0  # ✅ Force CPU (-1) or GPU (0)
+     )

- # Load environment variables
- def load_api_keys():
-     hf_token = os.getenv("HF_TOKEN")
-     nasa_api_key = os.getenv("NASA_API_KEY")
-
-     missing_keys = []
-     if not hf_token:
-         missing_keys.append("HF_TOKEN")
-     if not nasa_api_key:
-         missing_keys.append("NASA_API_KEY")
-
-     return hf_token, nasa_api_key, missing_keys
-
- # Lazy-load the model only when needed
- def load_model():
-     with st.spinner("Loading AI model... This may take a moment."):
-         try:
-             from langchain_huggingface import HuggingFaceEndpoint
-             from langchain_core.prompts import PromptTemplate
-             from langchain_core.output_parsers import StrOutputParser
-
-             hf_token, _, _ = load_api_keys()
-
-             # Use a smaller model if you're having resource issues
-             llm = HuggingFaceEndpoint(
-                 repo_id="meta-llama/Llama-2-7b-chat-hf",  # Consider a smaller model like "distilroberta-base"
-                 max_new_tokens=800,
-                 temperature=0.3,
-                 token=hf_token,
-                 task="text-generation",
-                 device=-1 if device == "cpu" else 0
-             )
-             st.session_state.model_loaded = True
-             st.session_state.llm = llm
-             st.session_state.prompt = PromptTemplate.from_template(
-                 "[INST] You are HAL, a NASA AI assistant with deep knowledge of space, astronomy, and NASA missions. "
-                 "Answer concisely and accurately.\n\n"
-                 "CONTEXT:\n{chat_history}\n"
-                 "\nLATEST USER INPUT:\nUser: {user_text}\n"
-                 "[END CONTEXT]\n"
-                 "Assistant:"
-             )
-             return True
-         except Exception as e:
-             st.error(f"Error loading model: {str(e)}")
-             return False
-
- # Ensure English responses
+ # ✅ Ensure English Responses
  def ensure_english(text):
      try:
-         if text and len(text) > 5:  # Only check if there's meaningful text
-             detected_lang = detect(text)
-             if detected_lang != "en":
-                 return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
-         return text
+         detected_lang = detect(text)
+         if detected_lang != "en":
+             return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
      except:
-         return text  # Return original if detection fails
+         return "⚠️ Language detection failed. Please ask your question again."
+     return text
+
+ # ✅ Main Response Function (Fixing Repetition & Context)
+ def get_response(system_message, chat_history, user_text, max_new_tokens=800):
+     # ✅ Ensure conversation history is included correctly
+     filtered_history = "\n".join(
+         f"{msg['role'].capitalize()}: {msg['content']}"
+         for msg in chat_history[-5:]  # ✅ Only keep the last 5 exchanges to prevent overflow
+     )
+
+     prompt = PromptTemplate.from_template(
+         "[INST] You are a highly knowledgeable AI assistant. Answer concisely, avoid repetition, and structure responses well."
+         "\n\nCONTEXT:\n{chat_history}\n"
+         "\nLATEST USER INPUT:\nUser: {user_text}\n"
+         "\n[END CONTEXT]\n"
+         "Assistant:"
+     )

- # Get response from the model
- def get_response(user_text):
-     if not st.session_state.model_loaded:
-         if not load_model():
-             return "Sorry, I'm having trouble loading. Please try again or check your environment setup."
+     # ✅ Invoke Hugging Face Model
+     hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.3)  # 🔥 Lowered temperature
+     chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
+
+     response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))

-     try:
-         # Prepare conversation history
-         filtered_history = "\n".join(
-             f"{msg['role'].capitalize()}: {msg['content']}"
-             for msg in st.session_state.chat_history[-5:]
-         )
-
-         from langchain_core.output_parsers import StrOutputParser
-
-         # Create and invoke the chat pipeline
-         chat = st.session_state.prompt | st.session_state.llm.bind(skip_prompt=True) | StrOutputParser()
-
-         response = chat.invoke({
-             "user_text": user_text,
-             "chat_history": filtered_history
-         })
-
-         # Clean up response
-         response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
-         response = ensure_english(response)
-
-         if not response:
-             response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
-
-         return response
-
-     except Exception as e:
-         return f"I encountered an error: {str(e)}. Please try again with a different question."
-
- # UI Styling
+     # Clean up the response - remove any "HAL:" prefix if present
+     response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
+     response = ensure_english(response)
+
+     if not response:
+         response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
+
+     # ✅ Update conversation history
+     chat_history.append({'role': 'user', 'content': user_text})
+     chat_history.append({'role': 'assistant', 'content': response})
+
+     # ✅ Keep only last 10 exchanges to prevent unnecessary repetition
+     return response, chat_history[-10:]
+
+ # ✅ Streamlit UI
+ st.title("🚀 HAL - NASA AI Assistant")
+
+ # ✅ Justify all chatbot responses
  st.markdown("""
  <style>
  .user-msg, .assistant-msg {
@@ -118,38 +100,29 @@ st.markdown("""
      max-width: 80%;
      text-align: justify;
  }
- .user-msg { background-color: #696969; color: white; margin-left: auto; }
+ .user-msg { background-color: #696969; color: white; }
  .assistant-msg { background-color: #333333; color: white; }
- .container { display: flex; flex-direction: column; }
+ .container { display: flex; flex-direction: column; align-items: flex-start; }
  @media (max-width: 600px) { .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; } }
  </style>
  """, unsafe_allow_html=True)

- # Main UI
- st.title("🚀 HAL - NASA AI Assistant")
+ # ✅ Chat UI
+ user_input = st.chat_input("Type your message here...")

- # Check for API keys before allowing interaction
- hf_token, nasa_api_key, missing_keys = load_api_keys()
- if missing_keys:
-     st.error(f"Missing environment variables: {', '.join(missing_keys)}. Please set them to use this application.")
- else:
-     # Chat interface
-     user_input = st.chat_input("Ask me about NASA, space missions, or astronomy...")
-
-     if user_input:
-         # Add user message to history
-         st.session_state.chat_history.append({"role": "user", "content": user_input})
-
-         # Get AI response
-         with st.spinner("Thinking..."):
-             response = get_response(user_input)
-             st.session_state.chat_history.append({"role": "assistant", "content": response})
-
-     # Display chat history
-     st.markdown("<div class='container'>", unsafe_allow_html=True)
-     for message in st.session_state.chat_history:
-         if message["role"] == "user":
-             st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
-         else:
-             st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
-     st.markdown("</div>", unsafe_allow_html=True)
+ if user_input:
+     # Get response and update chat history
+     response, st.session_state.chat_history = get_response(
+         system_message="You are a helpful AI assistant.",
+         user_text=user_input,
+         chat_history=st.session_state.chat_history
+     )
+
+     # ✅ Display chat history (ONLY display from history, not separately)
+     st.markdown("<div class='container'>", unsafe_allow_html=True)
+     for message in st.session_state.chat_history:
+         if message["role"] == "user":
+             st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
+         else:
+             st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
+     st.markdown("</div>", unsafe_allow_html=True)
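
The new get_response() builds its chain with LangChain's pipe operator (prompt | model | parser) instead of caching a model in session state. Below is a minimal offline sketch of that pattern, assuming only langchain_core is installed; RunnableLambda is a hypothetical stand-in for HuggingFaceEndpoint (which needs HF_TOKEN and network access), and the template is abbreviated from the committed one.

    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.runnables import RunnableLambda

    # Abbreviated version of the committed prompt template.
    prompt = PromptTemplate.from_template(
        "CONTEXT:\n{chat_history}\nUser: {user_text}\nAssistant:"
    )

    # Hypothetical stub model: returns a fixed string instead of calling the
    # Hugging Face Inference API, so the chain runs without HF_TOKEN.
    stub_llm = RunnableLambda(lambda prompt_value: "HAL: Artemis is NASA's lunar exploration program.")

    chat = prompt | stub_llm | StrOutputParser()
    response = chat.invoke({"chat_history": "Assistant: Hello!", "user_text": "What is Artemis?"})
    print(response.split("HAL:")[-1].strip())  # same "HAL:" prefix cleanup as app.py

To exercise the real app, export HF_TOKEN and NASA_API_KEY (the script raises ValueError without them) and launch it with streamlit run app.py.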