import os
import streamlit as st
from langdetect import detect
import torch
# Check if GPU is available but don't load anything yet
device = "cuda" if torch.cuda.is_available() else "cpu"
st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")
# Initialize session state variables
if "chat_history" not in st.session_state:
st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you with NASA-related information today?"}]
if "model_loaded" not in st.session_state:
st.session_state.model_loaded = False
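# st.session_state persists values across Streamlit's script reruns, so the
# chat history and the loaded model survive each user interaction instead of
# being rebuilt on every message.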
# Load environment variables
def load_api_keys():
    hf_token = os.getenv("HF_TOKEN")
    nasa_api_key = os.getenv("NASA_API_KEY")
    missing_keys = []
    if not hf_token:
        missing_keys.append("HF_TOKEN")
    if not nasa_api_key:
        missing_keys.append("NASA_API_KEY")
    return hf_token, nasa_api_key, missing_keys
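# Expected environment setup (placeholder values, variable names from above):
#   export HF_TOKEN="hf_..."      # Hugging Face token for the Inference API
#   export NASA_API_KEY="..."     # checked at startup; not used elsewhere in this file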
# Lazy-load the model only when needed
def load_model():
    with st.spinner("Loading AI model... This may take a moment."):
        try:
            from langchain_huggingface import HuggingFaceEndpoint
            from langchain_core.prompts import PromptTemplate
            hf_token, _, _ = load_api_keys()
            # Consider a smaller instruct model (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0)
            # if you're having resource issues.
            # The endpoint runs on the hosted Inference API, so there is no local
            # `device` argument; the token is passed as `huggingfacehub_api_token`.
            llm = HuggingFaceEndpoint(
                repo_id="meta-llama/Llama-2-7b-chat-hf",
                max_new_tokens=800,
                temperature=0.3,
                huggingfacehub_api_token=hf_token,
                task="text-generation",
            )
            st.session_state.model_loaded = True
            st.session_state.llm = llm
            # Close the Llama-2 instruction block with [/INST] so the model
            # generates the assistant turn instead of continuing the prompt.
            st.session_state.prompt = PromptTemplate.from_template(
                "[INST] You are HAL, a NASA AI assistant with deep knowledge of space, astronomy, and NASA missions. "
                "Answer concisely and accurately.\n\n"
                "CONTEXT:\n{chat_history}\n"
                "\nLATEST USER INPUT:\nUser: {user_text}\n"
                "[END CONTEXT] [/INST]\n"
                "Assistant:"
            )
            return True
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return False
# Ensure English responses
def ensure_english(text):
    try:
        if text and len(text) > 5:  # Only check if there's meaningful text
            detected_lang = detect(text)
            if detected_lang != "en":
                return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
        return text
    except Exception:
        return text  # Return original if detection fails
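# langdetect is statistical and unreliable on very short strings, which is why
# ensure_english() skips detection for inputs of five characters or fewer.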
# Get response from the model
def get_response(user_text):
    if not st.session_state.model_loaded:
        if not load_model():
            return "Sorry, I'm having trouble loading. Please try again or check your environment setup."
    try:
        # Prepare conversation history (last five messages only)
        filtered_history = "\n".join(
            f"{msg['role'].capitalize()}: {msg['content']}"
            for msg in st.session_state.chat_history[-5:]
        )
        from langchain_core.output_parsers import StrOutputParser
        # Create and invoke the chat pipeline. HuggingFaceEndpoint has no
        # `skip_prompt` option (that belongs to local streamers), so the LLM is
        # piped in directly and any echoed prompt is stripped below.
        chat = st.session_state.prompt | st.session_state.llm | StrOutputParser()
        response = chat.invoke({
            "user_text": user_text,
            "chat_history": filtered_history
        })
        # Clean up response: keep only the text after the assistant marker in
        # case the model echoed the prompt back.
        response = response.split("Assistant:")[-1].strip() if "Assistant:" in response else response.strip()
        response = ensure_english(response)
        if not response:
            response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
        return response
    except Exception as e:
        return f"I encountered an error: {str(e)}. Please try again with a different question."
# UI Styling
st.markdown("""
<style>
.user-msg, .assistant-msg {
padding: 11px;
border-radius: 10px;
margin-bottom: 5px;
width: fit-content;
max-width: 80%;
text-align: justify;
}
.user-msg { background-color: #696969; color: white; margin-left: auto; }
.assistant-msg { background-color: #333333; color: white; }
.container { display: flex; flex-direction: column; }
@media (max-width: 600px) { .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; } }
</style>
""", unsafe_allow_html=True)
# Main UI
st.title("🚀 HAL - NASA AI Assistant")
# Check for API keys before allowing interaction
hf_token, nasa_api_key, missing_keys = load_api_keys()
if missing_keys:
    st.error(f"Missing environment variables: {', '.join(missing_keys)}. Please set them to use this application.")
else:
    # Chat interface
    user_input = st.chat_input("Ask me about NASA, space missions, or astronomy...")
    if user_input:
        # Add user message to history
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        # Get AI response
        with st.spinner("Thinking..."):
            response = get_response(user_input)
            st.session_state.chat_history.append({"role": "assistant", "content": response})
    # Display chat history
    st.markdown("<div class='container'>", unsafe_allow_html=True)
    for message in st.session_state.chat_history:
        if message["role"] == "user":
            st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
        else:
            st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
    st.markdown("</div>", unsafe_allow_html=True)