File size: 3,336 Bytes
a3e0475
b9e2074
1fddcd6
a3e0475
 
1fddcd6
290623e
a3e0475
94ac9e7
a3e0475
ed7c9b0
290623e
a3e0475
 
 
290623e
a3e0475
290623e
a3e0475
94ac9e7
290623e
 
 
94ac9e7
 
 
 
290623e
94ac9e7
290623e
a3e0475
290623e
 
1fddcd6
290623e
2a239ae
290623e
1fddcd6
2a239ae
290623e
2a239ae
 
 
 
 
 
290623e
2a239ae
 
 
 
290623e
2a239ae
 
 
1fddcd6
2a239ae
290623e
 
 
 
2a239ae
290623e
 
 
 
 
2a239ae
 
 
290623e
 
2a239ae
290623e
2a239ae
 
290623e
 
2a239ae
290623e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import re

import requests
import streamlit as st
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint

from config import NASA_API_KEY  # Import the NASA API key from the configuration file

# Hugging Face model id; used as the default `repo_id` for the endpoint below
# and shown in the page header.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """Create a LangChain ``HuggingFaceEndpoint`` LLM.

    Args:
        model_id: Hugging Face repository id passed as ``repo_id``.
        max_new_tokens: Maximum number of tokens the model may generate.
        temperature: Sampling temperature.

    Returns:
        The configured ``HuggingFaceEndpoint`` instance. The access token is
        read from the ``HF_TOKEN`` environment variable (``None`` if unset).
    """
    endpoint_kwargs = dict(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        token=os.getenv("HF_TOKEN"),
    )
    return HuggingFaceEndpoint(**endpoint_kwargs)

def get_nasa_apod(timeout=10):
    """
    Fetch the Astronomy Picture of the Day (APOD) from the NASA API.

    Args:
        timeout: Seconds to wait for the NASA API. Without a timeout,
            ``requests`` may block indefinitely and freeze the chat app.

    Returns:
        A display string with the APOD title, explanation and URL on success.
        On any failure (network error, timeout, non-200 status, or a body that
        is not a JSON object) a friendly apology message is returned instead
        of raising.
    """
    fallback = "I couldn't fetch data from NASA right now. Please try again later."
    try:
        # Let requests build and URL-encode the query string.
        response = requests.get(
            "https://api.nasa.gov/planetary/apod",
            params={"api_key": NASA_API_KEY},
            timeout=timeout,
        )
        if response.status_code != 200:
            return fallback
        data = response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers invalid JSON bodies across requests versions.
        return fallback

    if not isinstance(data, dict):
        return fallback
    # Use .get so an entry missing a field still yields a usable answer.
    title = data.get("title", "N/A")
    explanation = data.get("explanation", "N/A")
    url = data.get("url", "N/A")
    return f"Title: {title}\nExplanation: {explanation}\nURL: {url}"

# Matches "NASA" as a whole word and words starting with "space" (space,
# spacecraft, spaceship), case-insensitively, but not embedded substrings such
# as "workspace" or "whitespace".
_NASA_TOPIC_PATTERN = re.compile(r"\bnasa\b|\bspace", re.IGNORECASE)


def _is_nasa_query(user_text):
    """Return True if ``user_text`` should be answered with NASA's APOD."""
    return _NASA_TOPIC_PATTERN.search(user_text) is not None


def _format_chat_history(chat_history):
    """Render chat messages as ``User:`` / ``AI:`` lines matching the prompt format."""
    labels = {"user": "User", "assistant": "AI"}
    return "\n".join(
        f"{labels.get(message['role'], message['role'])}: {message['content']}"
        for message in chat_history
    )


def get_response(system_message, chat_history, user_text,
                 eos_token_id=None, max_new_tokens=256, get_llm_hf_kws=None):
    """Answer ``user_text`` and record the exchange in ``chat_history``.

    Messages that mention NASA or space are answered with NASA's Astronomy
    Picture of the Day; everything else is sent to the Hugging Face LLM.

    Args:
        system_message: Instruction text placed at the start of the prompt.
        chat_history: List of ``{"role": ..., "content": ...}`` dicts. It is
            mutated in place: the user message and the reply are appended.
        user_text: The new user message.
        eos_token_id: Currently unused; kept for backward compatibility.
        max_new_tokens: Generation limit passed to the LLM.
        get_llm_hf_kws: Extra keyword arguments for ``get_llm_hf_inference``
            (e.g. ``temperature`` or ``model_id``). ``max_new_tokens`` above
            always takes precedence.

    Returns:
        ``(response, chat_history)``: the reply text and the same list object.
    """
    if _is_nasa_query(user_text):
        response = get_nasa_apod()
    else:
        llm_kwargs = {"temperature": 0.1, **(get_llm_hf_kws or {})}
        llm_kwargs["max_new_tokens"] = max_new_tokens
        hf = get_llm_hf_inference(**llm_kwargs)

        prompt = PromptTemplate.from_template(
            "[INST] {system_message}"
            "\nCurrent Conversation:\n{chat_history}\n\n"
            # No trailing "." after the user text, so questions don't end in "?.".
            "\nUser: {user_text}\n [/INST]"
            "\nAI:"
        )
        chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser()
        raw_output = chat.invoke(input={
            "system_message": system_message,
            "user_text": user_text,
            # History excludes the current turn, which is passed as user_text.
            "chat_history": _format_chat_history(chat_history),
        })
        # Keep only the text after the last "AI:" label, in case the model echoes it.
        response = raw_output.split("AI:")[-1].strip()

    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history

# Streamlit setup
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
st.title("NASA Personal Assistant")
st.markdown(f"*This chatbot uses {model_id} and NASA's APIs to provide information and responses.*")


def _initial_chat_history():
    """Return a fresh conversation containing only the assistant's greeting."""
    return [{"role": "assistant", "content": "Hello! How can I assist you today?"}]


# Initialize session state only once; later reruns keep the existing history.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = _initial_chat_history()

# Sidebar for settings
if st.sidebar.button("Reset Chat"):
    st.session_state.chat_history = _initial_chat_history()

# Main chat interface
user_input = st.chat_input(placeholder="Type your message here...")
if user_input:
    _, st.session_state.chat_history = get_response(
        system_message="You are a helpful AI assistant.",
        user_text=user_input,
        chat_history=st.session_state.chat_history,
        max_new_tokens=128,
    )

# Display messages on every run, not only when a new message arrives, so the
# greeting appears on first load and after "Reset Chat".
for message in st.session_state.chat_history:
    st.chat_message(message["role"]).write(message["content"])