File size: 5,010 Bytes
a3e0475
b9e2074
a3e0475
b9e2074
a3e0475
 
b9e2074
6351a04
a3e0475
94ac9e7
a3e0475
b9e2074
 
 
a3e0475
 
 
 
 
94ac9e7
a3e0475
 
 
94ac9e7
 
 
 
 
 
 
 
 
 
 
a3e0475
b9e2074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68577cc
 
 
 
 
 
 
 
 
 
 
 
 
 
a3e0475
 
b9e2074
 
 
 
94ac9e7
 
 
68577cc
 
 
 
a3e0475
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9e2074
 
 
68577cc
 
 
 
b9e2074
68577cc
a3e0475
94ac9e7
 
3b1c938
94ac9e7
 
 
 
 
a3e0475
94ac9e7
 
 
 
 
 
 
 
 
 
 
 
 
 
a3e0475
94ac9e7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import re

import requests
import streamlit as st
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from transformers import pipeline  # for Sentiment Analysis

from config import NASA_API_KEY  # Import the NASA API key from the configuration file

model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# The checkpoint transformers uses by default for "sentiment-analysis", pinned
# explicitly so the label set ("POSITIVE"/"NEGATIVE") that get_response()
# compares against cannot silently change with a library upgrade.
SENTIMENT_MODEL_ID = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"


@st.cache_resource(show_spinner=False)
def _load_sentiment_analyzer():
    """Load the sentiment-analysis pipeline once per server process.

    Streamlit re-executes this whole script on every user interaction; without
    caching, the model would be re-instantiated on every rerun. The spinner is
    disabled so no Streamlit element is emitted before ``st.set_page_config``.
    """
    return pipeline("sentiment-analysis", model=SENTIMENT_MODEL_ID)


# Initialize sentiment analysis pipeline (cached across reruns)
sentiment_analyzer = _load_sentiment_analyzer()

def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """Build a LangChain ``HuggingFaceEndpoint`` client for a hosted model.

    Args:
        model_id: Hugging Face Hub repository id of the model to call.
        max_new_tokens: Upper bound on the number of tokens generated per call.
        temperature: Sampling temperature passed to the endpoint.

    Returns:
        A configured ``HuggingFaceEndpoint`` instance.
    """
    return HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # `huggingfacehub_api_token` is the endpoint's declared credential
        # field. A `token=` kwarg is not a declared field, so LangChain moves it
        # into `model_kwargs` (sent as a generation parameter), not used for auth.
        huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    )

def get_nasa_apod(timeout=10):
    """
    Fetch the Astronomy Picture of the Day (APOD) from the NASA API.

    Args:
        timeout: Seconds to wait for NASA's server before giving up.

    Returns:
        A string with the picture's title, explanation and URL, or a friendly
        apology if the request fails, returns a non-200 status, or the body is
        not valid JSON.
    """
    fallback = "I couldn't fetch data from NASA right now. Please try again later."
    try:
        # `params` lets requests URL-encode the key; `timeout` stops a stalled
        # connection from blocking the Streamlit script indefinitely.
        response = requests.get(
            "https://api.nasa.gov/planetary/apod",
            params={"api_key": NASA_API_KEY},
            timeout=timeout,
        )
        if response.status_code != 200:
            return fallback
        data = response.json()
    except (requests.RequestException, ValueError):
        # Network failures and malformed JSON degrade to the same message as a
        # bad HTTP status instead of crashing the chat.
        return fallback

    # Presumably not every APOD entry carries every field (e.g. non-image
    # media) -- avoid a KeyError and show a placeholder instead.
    missing = "(not available)"
    return (
        f"Title: {data.get('title', missing)}\n"
        f"Explanation: {data.get('explanation', missing)}\n"
        f"URL: {data.get('url', missing)}"
    )

def analyze_sentiment(user_text):
    """
    Classify the sentiment of the user's input so responses can be adjusted.

    Args:
        user_text: The raw message typed by the user.

    Returns:
        The label of the top prediction from ``sentiment_analyzer``;
        get_response() compares it against ``"NEGATIVE"``.
    """
    # truncation=True clips long messages to the model's maximum input length
    # instead of letting the forward pass fail on over-long inputs.
    top_prediction = sentiment_analyzer(user_text, truncation=True)[0]
    return top_prediction["label"]

# Keyword patterns used for routing. Matching is case-insensitive and anchored
# at a word start, so "Nasa"/"Space" are recognised while "workspace" is not
# mistaken for a space question. `space` stays open on the right so that
# "spacecraft"/"spaceship" still count.
_NASA_KEYWORDS = re.compile(r"\bnasa\b|\bspace", re.IGNORECASE)
_WEATHER_KEYWORDS = re.compile(r"\bweather", re.IGNORECASE)


def predict_action(user_text):
    """
    Predict which handler should answer the user's input.

    Args:
        user_text: The raw message typed by the user.

    Returns:
        ``"nasa_info"`` if the text mentions NASA or space, otherwise
        ``"weather_info"`` if it mentions the weather, otherwise
        ``"general_query"``. NASA/space takes precedence when both match.
    """
    if _NASA_KEYWORDS.search(user_text):
        return "nasa_info"
    if _WEATHER_KEYWORDS.search(user_text):
        return "weather_info"
    return "general_query"

def generate_follow_up(user_text):
    """
    Generate a relevant follow-up question based on the user's input.

    Args:
        user_text: The raw message typed by the user.

    Returns:
        The model's suggested follow-up question with surrounding whitespace
        stripped.
    """
    # Wrap the request in Mistral's [INST] ... [/INST] instruction markers (the
    # same format get_response() uses) so the instruct model answers the
    # request instead of treating it as text to continue.
    prompt_text = (
        f"[INST] Given the user's message: '{user_text}', suggest a natural follow-up question "
        "that continues the conversation and encourages engagement. "
        "Reply with the question only. [/INST]"
    )

    # Higher temperature than the main answer (0.1) for more varied suggestions.
    hf = get_llm_hf_inference(max_new_tokens=64, temperature=0.7)
    return hf.invoke(input=prompt_text).strip()

def _format_chat_history(chat_history):
    """Render ``{"role", "content"}`` messages as a ``User:``/``AI:`` transcript."""
    speakers = {"user": "User", "assistant": "AI"}
    return "\n".join(
        f"{speakers.get(message['role'], message['role'])}: {message['content']}"
        for message in chat_history
    )


def get_response(system_message, chat_history, user_text,
                 eos_token_id=None, max_new_tokens=256, get_llm_hf_kws=None):
    """
    Answer the user's message and record the exchange in ``chat_history``.

    Space/NASA questions are answered with NASA's APOD; everything else goes to
    the LLM. In both cases a generated follow-up question is appended.

    Args:
        system_message: Instruction placed at the top of the LLM prompt.
        chat_history: List of ``{"role": ..., "content": ...}`` dicts. It is
            mutated in place and also returned.
        user_text: The user's new message.
        eos_token_id: Currently unused; kept for backward compatibility.
        max_new_tokens: Generation budget for the main LLM answer.
        get_llm_hf_kws: Optional extra keyword arguments for
            ``get_llm_hf_inference``; they override ``max_new_tokens`` and the
            default temperature of 0.1.

    Returns:
        ``(reply_text, chat_history)`` where ``reply_text`` is the answer
        followed by a blank line and the follow-up question.
    """
    action = predict_action(user_text)

    if action == "nasa_info":
        nasa_response = get_nasa_apod()
        chat_history.append({'role': 'user', 'content': user_text})
        chat_history.append({'role': 'assistant', 'content': nasa_response})

        follow_up = generate_follow_up(user_text)
        chat_history.append({'role': 'assistant', 'content': follow_up})
        return f"{nasa_response}\n\n{follow_up}", chat_history

    # Sentiment only influences the LLM path, so classify only here.
    sentiment = analyze_sentiment(user_text)

    llm_kwargs = {"max_new_tokens": max_new_tokens, "temperature": 0.1}
    llm_kwargs.update(get_llm_hf_kws or {})
    hf = get_llm_hf_inference(**llm_kwargs)

    prompt = PromptTemplate.from_template(
        (
            "[INST] {system_message}"
            "\nCurrent Conversation:\n{chat_history}\n\n"
            "\nUser: {user_text}.\n [/INST]"
            "\nAI:"
        )
    )
    # HuggingFaceEndpoint has no `skip_prompt` option (that belongs to
    # HuggingFacePipeline) and StrOutputParser has no `output_key`, so neither
    # is passed; unknown endpoint kwargs would be forwarded to the server.
    chat = prompt | hf | StrOutputParser()
    response = chat.invoke(
        input=dict(
            system_message=system_message,
            user_text=user_text,
            # A readable transcript instead of the Python repr of a list.
            chat_history=_format_chat_history(chat_history),
        )
    )
    # Keep only the text after the last "AI:" marker, if the model echoed one.
    response = response.split("AI:")[-1].strip()

    # Adjust for sentiment BEFORE recording the reply: the UI renders messages
    # from chat_history, so a note added afterwards would never be shown.
    if sentiment == "NEGATIVE":
        response += "\nI'm sorry to hear that. How can I assist you further?"

    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})

    follow_up = generate_follow_up(user_text)
    chat_history.append({'role': 'assistant', 'content': follow_up})

    return f"{response}\n\n{follow_up}", chat_history

# Streamlit setup
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
st.title("NASA Personal Assistant")
st.markdown(f"*This chatbot uses {model_id} and NASA's APIs to provide information and responses.*")


def _initial_chat_history():
    """Return a fresh chat history holding only the assistant's greeting."""
    return [{"role": "assistant", "content": "Hello! How can I assist you today?"}]


# Streamlit re-runs this script on every interaction; seed the history only
# once per browser session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = _initial_chat_history()

# Sidebar for settings
if st.sidebar.button("Reset Chat"):
    st.session_state.chat_history = _initial_chat_history()

# Main chat interface
user_input = st.chat_input(placeholder="Type your message here...")
if user_input:
    # get_response mutates and returns the history; the combined reply string
    # is not needed because every message is rendered from the history below.
    _, st.session_state.chat_history = get_response(
        system_message="You are a helpful AI assistant.",
        user_text=user_input,
        chat_history=st.session_state.chat_history,
        max_new_tokens=128,
    )

# Display messages on every rerun (not only after new input), so the greeting
# appears on first load and the transcript survives reruns such as "Reset Chat".
for message in st.session_state.chat_history:
    st.chat_message(message["role"]).write(message["content"])