CCockrum committed
Commit e673788 · verified · 1 Parent(s): 4f642ae

Update app.py

Files changed (1): app.py (+125, -98)
app.py CHANGED
@@ -1,95 +1,113 @@
 import os
-import re
-import requests
-import torch
 import streamlit as st
-from langchain_huggingface import HuggingFaceEndpoint
-from langchain_core.prompts import PromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-from transformers import pipeline
-from langdetect import detect  # Ensure this package is installed
+from langdetect import detect
+import torch
 
-# ✅ Check for GPU or Default to CPU
+# Check if GPU is available but don't load anything yet
 device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"✅ Using device: {device}")  # Debugging info
-
-# ✅ Environment Variables
-HF_TOKEN = os.getenv("HF_TOKEN")
-if HF_TOKEN is None:
-    raise ValueError("HF_TOKEN is not set. Please add it to your environment variables.")
-
-NASA_API_KEY = os.getenv("NASA_API_KEY")
-if NASA_API_KEY is None:
-    raise ValueError("NASA_API_KEY is not set. Please add it to your environment variables.")
-
-# ✅ Set Up Streamlit
 st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")
 
-# ✅ Initialize Session State Variables (Ensuring Chat History Persists)
+# Initialize session state variables
 if "chat_history" not in st.session_state:
-    st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
+    st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you with NASA-related information today?"}]
 
-# ✅ Initialize Hugging Face Model (Explicitly Set to CPU/GPU)
-def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.3):
-    return HuggingFaceEndpoint(
-        repo_id=model_id,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,  # 🔥 Lowered temperature for more factual and structured responses
-        token=HF_TOKEN,
-        task="text-generation",
-        device=-1 if device == "cpu" else 0  # ✅ Force CPU (-1) or GPU (0)
-    )
+if "model_loaded" not in st.session_state:
+    st.session_state.model_loaded = False
 
-# ✅ Ensure English Responses
+# Load environment variables
+def load_api_keys():
+    hf_token = os.getenv("HF_TOKEN")
+    nasa_api_key = os.getenv("NASA_API_KEY")
+
+    missing_keys = []
+    if not hf_token:
+        missing_keys.append("HF_TOKEN")
+    if not nasa_api_key:
+        missing_keys.append("NASA_API_KEY")
+
+    return hf_token, nasa_api_key, missing_keys
+
+# Lazy-load the model only when needed
+def load_model():
+    with st.spinner("Loading AI model... This may take a moment."):
+        try:
+            from langchain_huggingface import HuggingFaceEndpoint
+            from langchain_core.prompts import PromptTemplate
+            from langchain_core.output_parsers import StrOutputParser
+
+            hf_token, _, _ = load_api_keys()
+
+            # Use a smaller model if you're having resource issues
+            llm = HuggingFaceEndpoint(
+                repo_id="meta-llama/Llama-2-7b-chat-hf",  # Consider a smaller model like "distilroberta-base"
+                max_new_tokens=800,
+                temperature=0.3,
+                token=hf_token,
+                task="text-generation",
+                device=-1 if device == "cpu" else 0
+            )
+            st.session_state.model_loaded = True
+            st.session_state.llm = llm
+            st.session_state.prompt = PromptTemplate.from_template(
+                "[INST] You are HAL, a NASA AI assistant with deep knowledge of space, astronomy, and NASA missions. "
+                "Answer concisely and accurately.\n\n"
+                "CONTEXT:\n{chat_history}\n"
+                "\nLATEST USER INPUT:\nUser: {user_text}\n"
+                "[END CONTEXT]\n"
+                "Assistant:"
+            )
+            return True
+        except Exception as e:
+            st.error(f"Error loading model: {str(e)}")
+            return False
+
+# Ensure English responses
 def ensure_english(text):
     try:
-        detected_lang = detect(text)
-        if detected_lang != "en":
-            return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
+        if text and len(text) > 5:  # Only check if there's meaningful text
+            detected_lang = detect(text)
+            if detected_lang != "en":
+                return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
+        return text
     except:
-        return "⚠️ Language detection failed. Please ask your question again."
-    return text
-
-# ✅ Main Response Function (Fixing Repetition & Context)
-def get_response(system_message, chat_history, user_text, max_new_tokens=800):
-    # ✅ Ensure conversation history is included correctly
-    filtered_history = "\n".join(
-        f"{msg['role'].capitalize()}: {msg['content']}"
-        for msg in chat_history[-5:]  # ✅ Only keep the last 5 exchanges to prevent overflow
-    )
-
-    prompt = PromptTemplate.from_template(
-        "[INST] You are a highly knowledgeable AI assistant. Answer concisely, avoid repetition, and structure responses well."
-        "\n\nCONTEXT:\n{chat_history}\n"
-        "\nLATEST USER INPUT:\nUser: {user_text}\n"
-        "\n[END CONTEXT]\n"
-        "Assistant:"
-    )
+        return text  # Return original if detection fails
 
-    # ✅ Invoke Hugging Face Model
-    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.3)  # 🔥 Lowered temperature
-    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
-
-    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
+# Get response from the model
+def get_response(user_text):
+    if not st.session_state.model_loaded:
+        if not load_model():
+            return "Sorry, I'm having trouble loading. Please try again or check your environment setup."
 
-    # Clean up the response - remove any "HAL:" prefix if present
-    response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
-    response = ensure_english(response)
-
-    if not response:
-        response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
-
-    # ✅ Update conversation history
-    chat_history.append({'role': 'user', 'content': user_text})
-    chat_history.append({'role': 'assistant', 'content': response})
-
-    # ✅ Keep only last 10 exchanges to prevent unnecessary repetition
-    return response, chat_history[-10:]
-
-# ✅ Streamlit UI
-st.title("🚀 HAL - NASA AI Assistant")
-
-# ✅ Justify all chatbot responses
+    try:
+        # Prepare conversation history
+        filtered_history = "\n".join(
+            f"{msg['role'].capitalize()}: {msg['content']}"
+            for msg in st.session_state.chat_history[-5:]
+        )
+
+        from langchain_core.output_parsers import StrOutputParser
+
+        # Create and invoke the chat pipeline
+        chat = st.session_state.prompt | st.session_state.llm.bind(skip_prompt=True) | StrOutputParser()
+
+        response = chat.invoke({
+            "user_text": user_text,
+            "chat_history": filtered_history
+        })
+
+        # Clean up response
+        response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
+        response = ensure_english(response)
+
+        if not response:
+            response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
+
+        return response
+
+    except Exception as e:
+        return f"I encountered an error: {str(e)}. Please try again with a different question."
+
+# UI Styling
 st.markdown("""
     <style>
     .user-msg, .assistant-msg {
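The key change in this hunk is lazy initialization: the endpoint is no longer built at import time; `load_model()` runs on the first user message and caches the client and prompt in `st.session_state`, which survives Streamlit's top-to-bottom script reruns. A minimal sketch of that pattern, with a hypothetical `build_client()` standing in for the `HuggingFaceEndpoint` setup in app.py:

```python
import streamlit as st

def build_client():
    # Hypothetical placeholder for the expensive setup step
    # (the HuggingFaceEndpoint construction in app.py).
    return {"name": "llm-client"}

def get_client():
    # st.session_state persists across reruns within one browser session,
    # so build_client() executes at most once per session.
    if "client" not in st.session_state:
        st.session_state.client = build_client()
    return st.session_state.client
```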
@@ -100,29 +118,38 @@ st.markdown("""
         max-width: 80%;
         text-align: justify;
     }
-    .user-msg { background-color: #696969; color: white; }
+    .user-msg { background-color: #696969; color: white; margin-left: auto; }
     .assistant-msg { background-color: #333333; color: white; }
-    .container { display: flex; flex-direction: column; align-items: flex-start; }
+    .container { display: flex; flex-direction: column; }
     @media (max-width: 600px) { .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; } }
     </style>
 """, unsafe_allow_html=True)
 
-# ✅ Chat UI
-user_input = st.chat_input("Type your message here...")
-
-if user_input:
-    # Get response and update chat history
-    response, st.session_state.chat_history = get_response(
-        system_message="You are a helpful AI assistant.",
-        user_text=user_input,
-        chat_history=st.session_state.chat_history
-    )
+# Main UI
+st.title("🚀 HAL - NASA AI Assistant")
 
-    # ✅ Display chat history (ONLY display from history, not separately)
-    st.markdown("<div class='container'>", unsafe_allow_html=True)
-    for message in st.session_state.chat_history:
-        if message["role"] == "user":
-            st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
-        else:
-            st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
-    st.markdown("</div>", unsafe_allow_html=True)
+# Check for API keys before allowing interaction
+hf_token, nasa_api_key, missing_keys = load_api_keys()
+if missing_keys:
+    st.error(f"Missing environment variables: {', '.join(missing_keys)}. Please set them to use this application.")
+else:
+    # Chat interface
+    user_input = st.chat_input("Ask me about NASA, space missions, or astronomy...")
+
+    if user_input:
+        # Add user message to history
+        st.session_state.chat_history.append({"role": "user", "content": user_input})
+
+        # Get AI response
+        with st.spinner("Thinking..."):
+            response = get_response(user_input)
+        st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+        # Display chat history
+        st.markdown("<div class='container'>", unsafe_allow_html=True)
+        for message in st.session_state.chat_history:
+            if message["role"] == "user":
+                st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
+            else:
+                st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
+        st.markdown("</div>", unsafe_allow_html=True)