# hal_bot.py
import os
import re
import requests
import torch
import streamlit as st
from langchain_community.llms import HuggingFaceEndpoint, HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langdetect import detect

# ✅ Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ Using device: {device}")

# ✅ Switched to Flan-T5 Model
MODEL_ID = "google/flan-t5-large"


def load_local_llm(model_id):
    """Load a seq2seq model locally and wrap it in a text2text-generation pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    return pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )


pipe = load_local_llm(MODEL_ID)
llm = HuggingFacePipeline(pipeline=pipe)  # local fallback; get_response() below uses the remote endpoint

# ✅ Environment Variables
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN is not set. Please add it to your environment variables.")

NASA_API_KEY = os.getenv("NASA_API_KEY")
if NASA_API_KEY is None:
    raise ValueError("NASA_API_KEY is not set. Please add it to your environment variables.")

# ✅ Streamlit Setup
st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        {"role": "assistant", "content": "Hello! How can I assist you today?"}
    ]


def get_llm_hf_inference(model_id=MODEL_ID, max_new_tokens=500, temperature=0.3):
    """Build a remote Hugging Face inference endpoint (the hosted API takes no local device argument)."""
    return HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        huggingfacehub_api_token=HF_TOKEN,
        task="text2text-generation",
    )


def ensure_english(text):
    """Return text unchanged if it is English; otherwise return a warning message."""
    try:
        if detect(text) != "en":
            return "⚠️ Sorry, I only respond in English. Can you rephrase your question?"
    except Exception:
        return "⚠️ Language detection failed. Please ask your question again."
    return text
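
# --- NASA API helper (sketch) ---
# NASA_API_KEY is validated above but not used anywhere else in this file.
# The helper below is one plausible use, assuming NASA's public APOD endpoint
# (https://api.nasa.gov/planetary/apod); the name `fetch_apod` is illustrative,
# not part of any library.
def fetch_apod(api_key=None):
    """Fetch NASA's Astronomy Picture of the Day as a short summary string."""
    try:
        resp = requests.get(
            "https://api.nasa.gov/planetary/apod",
            params={"api_key": api_key or NASA_API_KEY},
            timeout=10,
        )
        resp.raise_for_status()
        data = resp.json()
        # The APOD response includes 'title' and 'explanation' fields.
        return f"{data.get('title', 'APOD')}: {data.get('explanation', '')}"
    except requests.RequestException as err:
        return f"⚠️ NASA API request failed: {err}"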
def get_response(system_message, chat_history, user_text, max_new_tokens=500):
    """Generate a reply via the remote endpoint and append the turn to the history."""
    filtered_history = "\n".join(
        f"{msg['role'].capitalize()}: {msg['content']}" for msg in chat_history[-5:]
    )

    prompt = PromptTemplate.from_template(
        """{system_message}
Answer concisely and clearly based on the conversation history and the user's latest message.

Conversation History:
{chat_history}

User: {user_text}
Assistant:"""
    )

    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.3)
    chat = prompt | hf | StrOutputParser()

    response = chat.invoke(
        {"system_message": system_message, "user_text": user_text, "chat_history": filtered_history}
    ).strip()
    response = ensure_english(response)

    if not response:
        response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"

    chat_history.append({"role": "user", "content": user_text})
    chat_history.append({"role": "assistant", "content": response})
    return response, chat_history[-10:]


# ✅ Streamlit UI
st.title("🚀 HAL - NASA AI Assistant")

# Basic styling for the chat bubbles (class names are placeholders; adjust to taste)
st.markdown(
    """
    <style>
    .chat-container { padding: 0.5rem; }
    .user-message, .assistant-message { padding: 0.5rem 0.75rem; border-radius: 0.5rem; margin-bottom: 0.5rem; }
    </style>
    """,
    unsafe_allow_html=True,
)

user_input = st.chat_input("Type your message here...")

if user_input:
    response, st.session_state.chat_history = get_response(
        system_message="You are a helpful AI assistant.",
        user_text=user_input,
        chat_history=st.session_state.chat_history,
    )
", unsafe_allow_html=True) for message in st.session_state.chat_history: if message["role"] == "user": st.markdown(f"
You: {message['content']}
", unsafe_allow_html=True) else: st.markdown(f"
HAL: {message['content']}
", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True)