import os
import streamlit as st
from langdetect import detect
import torch

# Check if GPU is available but don't load anything yet
device = "cuda" if torch.cuda.is_available() else "cpu"
st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")
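# Note: st.set_page_config must be the first Streamlit command in the script,
# which is why it runs here before any other st.* call.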

# Initialize session state variables
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you with NASA-related information today?"}]

if "model_loaded" not in st.session_state:
    st.session_state.model_loaded = False

# Read API keys from environment variables (NASA_API_KEY is only checked for presence here)
def load_api_keys():
    hf_token = os.getenv("HF_TOKEN")
    nasa_api_key = os.getenv("NASA_API_KEY")
    
    missing_keys = []
    if not hf_token:
        missing_keys.append("HF_TOKEN")
    if not nasa_api_key:
        missing_keys.append("NASA_API_KEY")
    
    return hf_token, nasa_api_key, missing_keys
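
# A minimal sketch of supplying the two keys above (assuming a local shell;
# a deployed Streamlit app would more commonly use st.secrets / secrets.toml):
#   export HF_TOKEN="hf_..."
#   export NASA_API_KEY="..."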

# Lazy-load the model only when needed
def load_model():
    with st.spinner("Loading AI model... This may take a moment."):
        try:
            from langchain_huggingface import HuggingFaceEndpoint
            from langchain_core.prompts import PromptTemplate
            from langchain_core.output_parsers import StrOutputParser
            
            hf_token, _, _ = load_api_keys()
            
            # Hosted inference endpoint (it runs remotely, so the local `device`
            # check above does not affect it). Swap repo_id for a smaller
            # instruct-tuned model if you hit rate or resource limits.
            llm = HuggingFaceEndpoint(
                repo_id="meta-llama/Llama-2-7b-chat-hf",
                max_new_tokens=800,
                temperature=0.3,
                huggingfacehub_api_token=hf_token,
                task="text-generation"
            )
            st.session_state.model_loaded = True
            st.session_state.llm = llm
            st.session_state.prompt = PromptTemplate.from_template(
                "[INST] You are HAL, a NASA AI assistant with deep knowledge of space, astronomy, and NASA missions. "
                "Answer concisely and accurately.\n\n"
                "CONTEXT:\n{chat_history}\n"
                "\nLATEST USER INPUT:\nUser: {user_text}\n"
                "[END CONTEXT]\n"
                "Assistant:"
            )
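            # For illustration only, a rendered prompt looks roughly like
            # (history and question values are made up):
            #   [INST] You are HAL, a NASA AI assistant ... Answer concisely and accurately.
            #   CONTEXT:
            #   Assistant: Hello! How can I assist you with NASA-related information today?
            #   User: When did Apollo 11 land on the Moon?
            #   LATEST USER INPUT:
            #   User: When did Apollo 11 land on the Moon?
            #   [END CONTEXT] [/INST]
            #   Assistant: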
            return True
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return False

# Ensure English responses
def ensure_english(text):
    try:
        if text and len(text) > 5:  # Only check if there's meaningful text
            detected_lang = detect(text)
            if detected_lang != "en":
                return "โš ๏ธ Sorry, I only respond in English. Can you rephrase your question?"
        return text
    except Exception:
        return text  # Fall back to the original text if language detection fails
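
# Illustrative behaviour of ensure_english (assumed examples, not test output):
#   ensure_english("The ISS orbits at roughly 400 km altitude.")      # returned unchanged
#   ensure_english("La station orbite à environ 400 km d'altitude.")  # replaced by the English-only warning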

# Get response from the model
def get_response(user_text):
    if not st.session_state.model_loaded:
        if not load_model():
            return "Sorry, I'm having trouble loading. Please try again or check your environment setup."
    
    try:
        # Prepare conversation history
        filtered_history = "\n".join(
            f"{msg['role'].capitalize()}: {msg['content']}"
            for msg in st.session_state.chat_history[-5:]
        )
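        # filtered_history is plain "Role: content" lines from the last five turns, e.g.
        #   User: Tell me about the Artemis program
        #   Assistant: Artemis is NASA's program to return astronauts to the Moon...
        # (example values only)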
        
        from langchain_core.output_parsers import StrOutputParser
        
        # Compose the LCEL pipeline: prompt template -> hosted endpoint -> plain string output
        chat = st.session_state.prompt | st.session_state.llm | StrOutputParser()
        
        response = chat.invoke({
            "user_text": user_text,
            "chat_history": filtered_history
        })
        
        # Clean up: keep only the text after an echoed "Assistant:" or "HAL:" cue, if present
        response = response.split("Assistant:")[-1].split("HAL:")[-1].strip()
        response = ensure_english(response)
        
        if not response:
            response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
        
        return response
        
    except Exception as e:
        return f"I encountered an error: {str(e)}. Please try again with a different question."

# UI Styling
st.markdown("""
    <style>
    .user-msg, .assistant-msg {
        padding: 11px;
        border-radius: 10px;
        margin-bottom: 5px;
        width: fit-content;
        max-width: 80%;
        text-align: justify;
    }
    .user-msg { background-color: #696969; color: white; margin-left: auto; }
    .assistant-msg { background-color: #333333; color: white; }
    .container { display: flex; flex-direction: column; }
    @media (max-width: 600px) { .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; } }
    </style>
""", unsafe_allow_html=True)

# Main UI
st.title("๐Ÿš€ HAL - NASA AI Assistant")

# Check for API keys before allowing interaction
hf_token, nasa_api_key, missing_keys = load_api_keys()
if missing_keys:
    st.error(f"Missing environment variables: {', '.join(missing_keys)}. Please set them to use this application.")
else:
    # Chat interface
    user_input = st.chat_input("Ask me about NASA, space missions, or astronomy...")
    
    if user_input:
        # Add user message to history
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        
        # Get AI response
        with st.spinner("Thinking..."):
            response = get_response(user_input)
            st.session_state.chat_history.append({"role": "assistant", "content": response})
    
    # Display chat history
    st.markdown("<div class='container'>", unsafe_allow_html=True)
    for message in st.session_state.chat_history:
        if message["role"] == "user":
            st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
        else:
            st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
    st.markdown("</div>", unsafe_allow_html=True)