# hal_bot.py
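#
# HAL: a Streamlit chat app that answers NASA-related questions with
# google/flan-t5-large through LangChain. A minimal way to launch it,
# assuming HF_TOKEN and NASA_API_KEY are already exported in your shell:
#
#     streamlit run hal_bot.py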

import os
import torch
import streamlit as st
from langchain_community.llms import HuggingFaceEndpoint, HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException

# ✅ Flan-T5 model (seq2seq); loaded lazily below via the cached loader
MODEL_ID = "google/flan-t5-large"

# ✅ Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ Using device: {device}")

# ✅ Environment Variables
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN is not set. Please add it to your environment variables.")

NASA_API_KEY = os.getenv("NASA_API_KEY")
if NASA_API_KEY is None:
    raise ValueError("NASA_API_KEY is not set. Please add it to your environment variables.")

# ✅ Streamlit Setup
st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]

# ✅ Load the local Flan-T5 pipeline once and cache it across Streamlit reruns
@st.cache_resource
def load_local_llm(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    return pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )

pipe = load_local_llm(MODEL_ID)
llm = HuggingFacePipeline(pipeline=pipe)  # local fallback; responses below come from the remote endpoint

def get_llm_hf_inference(model_id=MODEL_ID, max_new_tokens=500, temperature=0.3):
    # Remote Hugging Face Inference endpoint; it runs server-side, so no local device argument applies.
    return HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        huggingfacehub_api_token=HF_TOKEN,
        task="text2text-generation",
    )

def ensure_english(text):
    """Return the text unchanged if it is English, otherwise a warning message."""
    try:
        if detect(text) != "en":
            return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
    except LangDetectException:
        return "⚠️ Language detection failed. Please ask your question again."
    return text

def get_response(system_message, chat_history, user_text, max_new_tokens=500):
    filtered_history = "\n".join(
        f"{msg['role'].capitalize()}: {msg['content']}" for msg in chat_history[-5:]
    )

    prompt = PromptTemplate.from_template(
        """
        {system_message}
        Answer concisely and clearly based on the conversation history and the user's latest message.

        Conversation History:
        {chat_history}

        User: {user_text}
        Assistant:
        """
    )

    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.3)
    chat = prompt | hf | StrOutputParser()

    response = chat.invoke(
        dict(system_message=system_message, user_text=user_text, chat_history=filtered_history)
    )
    response = ensure_english(response.strip())

    if not response:
        response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"

    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})

    return response, chat_history[-10:]

st.title("🚀 HAL - NASA AI Assistant")

st.markdown("""
    <style>
    .user-msg, .assistant-msg {
        padding: 11px;
        border-radius: 10px;
        margin-bottom: 5px;
        width: fit-content;
        max-width: 80%;
        text-align: justify;
    }
    .user-msg { background-color: #696969; color: white; }
    .assistant-msg { background-color: #333333; color: white; }
    .container { display: flex; flex-direction: column; align-items: flex-start; }
    @media (max-width: 600px) { .user-msg, .assistant-msg { font-size: 16px; max-width: 100%; } }
    </style>
""", unsafe_allow_html=True)

user_input = st.chat_input("Type your message here...")

if user_input:
    response, st.session_state.chat_history = get_response(
        system_message="You are a helpful NASA AI assistant.",
        user_text=user_input,
        chat_history=st.session_state.chat_history
    )

st.markdown("<div class='container'>", unsafe_allow_html=True)
for message in st.session_state.chat_history:
    if message["role"] == "user":
        st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
    else:
        st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)