"""Streamlit chat app that answers with a Hugging Face hosted LLM and, for
NASA/space questions, with NASA's Astronomy Picture of the Day."""

import os
from langchain_huggingface import HuggingFaceEndpoint
import streamlit as st
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
import requests
from config import NASA_API_KEY  # Import the NASA API key from the configuration file

# Hugging Face Hub repository id of the default chat model; used as the default
# for get_llm_hf_inference() and shown in the page description.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """
    Build a Hugging Face Inference endpoint client for text generation.

    Args:
        model_id: Hub repository id of the model to query. Defaults to the
            module-level ``model_id``.
        max_new_tokens: Upper bound on the number of tokens generated per call.
        temperature: Sampling temperature forwarded to the endpoint.

    Returns:
        A configured ``HuggingFaceEndpoint`` instance.
    """
    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # ``HuggingFaceEndpoint`` declares no ``token`` field; unknown keyword
        # arguments are moved into ``model_kwargs`` and forwarded as generation
        # parameters. The credential belongs in ``huggingfacehub_api_token``.
        # A missing HF_TOKEN yields None, which lets the library fall back to
        # its own credential lookup.
        huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    )
    return llm

def get_nasa_apod():
    """
    Fetch the Astronomy Picture of the Day (APOD) from the NASA API.

    Returns:
        A string with the picture's title, explanation and URL, or a fixed
        apology message when the request fails, returns a non-200 status, or
        the payload cannot be parsed into the expected fields.
    """
    fallback = "I couldn't fetch data from NASA right now. Please try again later."

    try:
        response = requests.get(
            "https://api.nasa.gov/planetary/apod",
            # Let requests build and URL-encode the query string.
            params={"api_key": NASA_API_KEY},
            # Without a timeout a stalled connection would block the app forever.
            timeout=10,
        )
    except requests.RequestException:
        # DNS failure, refused connection, timeout, etc.
        return fallback

    if response.status_code != 200:
        return fallback

    try:
        data = response.json()
        return f"Title: {data['title']}\nExplanation: {data['explanation']}\nURL: {data['url']}"
    except (ValueError, KeyError, TypeError):
        # Body was not valid JSON, or lacked one of the expected fields.
        return fallback

def get_response(system_message, chat_history, user_text,
                 eos_token_id=("User",), max_new_tokens=256, get_llm_hf_kws=None):
    """
    Produce the assistant's reply to ``user_text`` and record the exchange.

    Messages mentioning "NASA" or "space" (case-insensitive) are answered with
    NASA's Astronomy Picture of the Day; everything else is sent to the
    Hugging Face endpoint.

    Args:
        system_message: Instruction text placed at the start of the prompt.
        chat_history: List of ``{'role': ..., 'content': ...}`` dicts. It is
            appended to in place with the new user and assistant messages.
        user_text: The user's latest message.
        eos_token_id: Currently unused; kept for backward compatibility.
        max_new_tokens: Generation cap passed to ``get_llm_hf_inference``.
        get_llm_hf_kws: Optional dict of extra keyword arguments for
            ``get_llm_hf_inference``; entries override the defaults used here.

    Returns:
        A ``(response, chat_history)`` tuple, where ``chat_history`` is the
        same list object that was passed in.
    """
    # Match keywords regardless of how the user capitalised them.
    lowered_text = user_text.lower()
    if "nasa" in lowered_text or "space" in lowered_text:
        nasa_response = get_nasa_apod()
        chat_history.append({'role': 'user', 'content': user_text})
        chat_history.append({'role': 'assistant', 'content': nasa_response})
        return nasa_response, chat_history

    # Defaults first, then caller-supplied overrides.
    llm_kwargs = {"max_new_tokens": max_new_tokens, "temperature": 0.1}
    llm_kwargs.update(get_llm_hf_kws or {})
    hf = get_llm_hf_inference(**llm_kwargs)

    prompt = PromptTemplate.from_template(
        (
            "[INST] {system_message}"
            "\nCurrent Conversation:\n{chat_history}\n\n"
            "\nUser: {user_text}.\n [/INST]"
            "\nAI:"
        )
    )
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
    # Keep only the text after the last "AI:" marker in case the prompt is echoed.
    response = response.split("AI:")[-1]

    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history

# Streamlit setup
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
st.title("Personal Assistant")
st.markdown(f"*This chatbot uses {model_id} and NASA's APIs to provide information and responses.*")


def _new_chat_history():
    """Return a fresh history list holding only the assistant's greeting."""
    # A new list each time: get_response() appends to the history in place.
    return [{"role": "assistant", "content": "Hello! How can I assist you today?"}]


# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state.chat_history = _new_chat_history()

# Sidebar for settings
if st.sidebar.button("Reset Chat"):
    st.session_state.chat_history = _new_chat_history()

# Main chat interface
user_input = st.chat_input(placeholder="Type your message here...")
if user_input:
    _, st.session_state.chat_history = get_response(
        system_message="You are a helpful AI assistant.",
        user_text=user_input,
        chat_history=st.session_state.chat_history,
        max_new_tokens=128
    )

# Display messages on every rerun, not only when new input arrived, so the
# greeting shows on first load and the conversation survives a "Reset Chat"
# click or any other rerun without input.
for message in st.session_state.chat_history:
    st.chat_message(message["role"]).write(message["content"])