File size: 6,234 Bytes
a3e0475
fc5f1c7
d6f5773
fc5f1c7
a3e0475
 
fc5f1c7
2825ff1
073538f
94ac9e7
a3e0475
fc5f1c7
 
 
d6f5773
 
a3e0475
 
 
d6f5773
a3e0475
d6f5773
a3e0475
94ac9e7
d6f5773
 
 
94ac9e7
 
 
 
d6f5773
94ac9e7
d6f5773
073538f
fc5f1c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6f5773
 
fc5f1c7
 
 
 
d6f5773
2a239ae
d6f5773
fc5f1c7
 
 
 
2a239ae
d6f5773
2a239ae
 
ad0b8d6
d6f5773
 
 
 
ad0b8d6
2a239ae
 
 
d6f5773
2a239ae
 
 
fc5f1c7
 
 
 
 
 
 
 
 
2a239ae
d6f5773
 
 
 
2a239ae
d6f5773
 
 
 
 
2a239ae
 
f543f0b
d6f5773
 
ad0b8d6
d6f5773
ad0b8d6
 
d6f5773
 
ad0b8d6
d6f5773
 
 
b8a80ad
fc5f1c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import requests
import streamlit as st
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from transformers import pipeline  # for Sentiment Analysis
# NASA API key read from the environment; None when the variable is unset.
NASA_API_KEY = os.getenv("NASA_API_KEY")

# Hugging Face Hub repo id of the chat model used for all LLM calls.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# Sentiment model pinned explicitly: get_response compares the predicted label
# against "NEGATIVE", which is this SST-2 model's label set. Without a model
# argument transformers picks a library default (and logs a warning), so the
# label set could change silently with a library upgrade.
SENTIMENT_MODEL_ID = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"


# show_spinner=False: a spinner is a page element, and set_page_config (called
# further down) must remain the first Streamlit command on the page.
@st.cache_resource(show_spinner=False)
def _load_sentiment_analyzer():
    """Build the sentiment-analysis pipeline once and reuse it across reruns.

    Streamlit re-executes this whole script on every user interaction; without
    caching, the model would be reloaded on each rerun.
    """
    return pipeline("sentiment-analysis", model=SENTIMENT_MODEL_ID)


# Initialize sentiment analysis pipeline (cached across Streamlit reruns)
sentiment_analyzer = _load_sentiment_analyzer()

def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """Create a Hugging Face Inference endpoint client for text generation.

    Args:
        model_id: Hub repo id of the model to query.
        max_new_tokens: Upper bound on the number of tokens generated per call.
        temperature: Sampling temperature forwarded to the endpoint.

    Returns:
        A configured ``HuggingFaceEndpoint``. The access token comes from the
        ``HF_TOKEN`` environment variable, read each time this is called.
    """
    hf_token = os.getenv("HF_TOKEN")
    return HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        token=hf_token,
    )

def get_nasa_apod(timeout=10.0):
    """Fetch NASA's Astronomy Picture of the Day (APOD) as display text.

    Args:
        timeout: Seconds to wait for the NASA API. Without a timeout,
            ``requests`` may block indefinitely and freeze the app.

    Returns:
        "Title: ...\\nExplanation: ...\\nURL: ..." on success, otherwise a
        friendly fallback message. This function never raises for network,
        HTTP, decoding or missing-field problems.
    """
    failure = "I couldn't fetch data from NASA right now. Please try again later."
    # Interpolating an unset key would send the literal "None" as api_key;
    # NASA's public (rate-limited) DEMO_KEY is used as a fallback instead.
    api_key = NASA_API_KEY or "DEMO_KEY"
    try:
        response = requests.get(
            "https://api.nasa.gov/planetary/apod",
            params={"api_key": api_key},  # requests URL-encodes the key
            timeout=timeout,
        )
    except requests.RequestException:
        # Connection errors / timeouts degrade to the fallback message.
        return failure
    if response.status_code != 200:
        return failure
    try:
        data = response.json()
        return f"Title: {data['title']}\nExplanation: {data['explanation']}\nURL: {data['url']}"
    except (ValueError, KeyError, TypeError):
        # Non-JSON body, or an entry missing one of the fields (e.g. no "url").
        return failure

def analyze_sentiment(user_text):
    """Classify the sentiment of the user's message.

    Args:
        user_text: The raw user message.

    Returns:
        The label of the top prediction from the module-level
        ``sentiment_analyzer`` pipeline. NOTE(review): presumably
        "POSITIVE"/"NEGATIVE"; get_response compares against "NEGATIVE".
    """
    # truncation=True: the classifier has a fixed maximum input length, and a
    # long message would otherwise raise instead of being classified.
    predictions = sentiment_analyzer(user_text, truncation=True)
    return predictions[0]['label']

def predict_action(user_text):
    """Route a user message to an action keyword.

    Matching is case-insensitive. "NASA" must appear as a whole word, so
    "nasal" does not trigger it. "space" and "weather" match anywhere in the
    text, e.g. "Spacecraft", "WEATHER".

    Args:
        user_text: The raw user message (may be empty).

    Returns:
        "nasa_info", "weather_info" or "general_query". Rules are checked in
        that order and the first match wins.
    """
    import re

    if not user_text:
        return "general_query"
    lowered = user_text.casefold()
    if re.search(r"\bnasa\b", lowered) or "space" in lowered:
        return "nasa_info"
    if "weather" in lowered:
        return "weather_info"
    return "general_query"

def generate_follow_up(user_text):
    """Ask the LLM for a single follow-up question related to ``user_text``.

    Uses a short, more exploratory generation than the main answer
    (64 new tokens, temperature 0.7).

    Returns:
        The model's output with surrounding whitespace removed.
    """
    llm = get_llm_hf_inference(max_new_tokens=64, temperature=0.7)
    instruction = (
        f"Given the user's message: '{user_text}', ask one natural follow-up question "
        "that suggests a related topic or offers user the opportunity to go in a new direction."
    )
    raw_reply = llm.invoke(input=instruction)
    return raw_reply.strip()

def _format_chat_history(chat_history):
    """Render chat messages as "User: ..." / "AI: ..." lines for the prompt.

    Passing the raw list into the template would show the model Python dict
    reprs. These speaker labels match the ones the prompt template uses.
    """
    speakers = {"user": "User", "assistant": "AI"}
    return "\n".join(
        f"{speakers.get(message.get('role'), message.get('role'))}: {message.get('content', '')}"
        for message in chat_history
    )


def get_response(system_message, chat_history, user_text,
                 eos_token_id=None, max_new_tokens=256, get_llm_hf_kws=None):
    """Answer ``user_text`` and record the exchange in ``chat_history``.

    Messages routed to "nasa_info" by predict_action are answered with NASA's
    APOD. Everything else, including "weather_info", which has no dedicated
    handler yet, goes to the LLM. In every case a follow-up question is
    generated and stored as an extra assistant message.

    Args:
        system_message: Instruction placed at the start of the LLM prompt.
        chat_history: List of {"role", "content"} dicts; mutated in place.
        user_text: The user's message.
        eos_token_id: Accepted for backward compatibility; currently unused.
        max_new_tokens: Generation budget for the main LLM answer.
        get_llm_hf_kws: Optional extra keyword arguments for
            get_llm_hf_inference; they override the defaults set here.

    Returns:
        ``(text, chat_history)``, where ``text`` is the answer, a blank line,
        and the follow-up question.
    """
    action = predict_action(user_text)

    if action == "nasa_info":
        nasa_response = get_nasa_apod()
        chat_history.append({'role': 'user', 'content': user_text})
        chat_history.append({'role': 'assistant', 'content': nasa_response})

        follow_up = generate_follow_up(user_text)
        chat_history.append({'role': 'assistant', 'content': follow_up})
        return f"{nasa_response}\n\n{follow_up}", chat_history

    # Sentiment only affects the LLM path, so it is computed only here.
    sentiment = analyze_sentiment(user_text)

    llm_kwargs = {"max_new_tokens": max_new_tokens, "temperature": 0.1}
    llm_kwargs.update(get_llm_hf_kws or {})
    hf = get_llm_hf_inference(**llm_kwargs)

    prompt = PromptTemplate.from_template(
        (
            "[INST] {system_message}"
            "\nCurrent Conversation:\n{chat_history}\n\n"
            "\nUser: {user_text}.\n [/INST]"
            "\nAI:"
        )
    )
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    # The history is formatted BEFORE this turn is appended, so the prompt
    # does not contain the current user message twice.
    response = chat.invoke(input=dict(
        system_message=system_message,
        user_text=user_text,
        chat_history=_format_chat_history(chat_history),
    ))
    # Keep only the text after the last "AI:" marker (the model may echo it).
    response = response.split("AI:")[-1].strip()

    # Modify the reply BEFORE recording it: the UI renders chat_history, so a
    # note appended afterwards would never be displayed.
    if sentiment == "NEGATIVE":
        response += "\nI'm sorry to hear that. How can I assist you further?"

    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})

    follow_up = generate_follow_up(user_text)
    chat_history.append({'role': 'assistant', 'content': follow_up})

    return f"{response}\n\n{follow_up}", chat_history

# Streamlit setup
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
st.title("NASA Personal Assistant")
st.markdown(f"*This chatbot uses {model_id} and NASA's APIs to provide information and responses.*")


def _initial_chat_history():
    """Return a fresh conversation containing only the assistant greeting."""
    return [{"role": "assistant", "content": "Hello! How can I assist you today?"}]


# Initialize session state (it survives Streamlit's reruns of this script)
if "chat_history" not in st.session_state:
    st.session_state.chat_history = _initial_chat_history()

# Sidebar for settings
if st.sidebar.button("Reset Chat"):
    st.session_state.chat_history = _initial_chat_history()

# Main chat interface. st.chat_input returns the submitted text only on the
# rerun triggered by that submission, and None on every other rerun.
user_input = st.chat_input(placeholder="Type your message here...")
if user_input:
    # get_response appends the user turn, the reply and a follow-up question
    # to the history itself; the history is rendered below.
    response, st.session_state.chat_history = get_response(
        system_message="You are a helpful AI assistant.",
        user_text=user_input,
        chat_history=st.session_state.chat_history,
        max_new_tokens=128,
    )

# Display messages on every rerun. Rendering only when new input arrived made
# the whole conversation (even the greeting) disappear on any other rerun,
# e.g. right after clicking "Reset Chat".
for message in st.session_state.chat_history:
    st.chat_message(message["role"]).write(message["content"])




# Follow-up suggestions.
# A Streamlit button is True only on the rerun triggered by its own click. A
# "Send" button that reads st.chat_input (None on that rerun), or a "Continue"
# button nested inside another button's branch, can therefore never fire.
# These widgets live at top level and read the follow-up from session state.
# get_response returns (text, chat_history), i.e. two values, and exposes no
# image URL, so no APOD image is shown here.
# TODO: surface the APOD image if get_nasa_apod starts returning its URL.
if len(st.session_state.chat_history) > 1:
    # get_response stores the generated follow-up question as the last message.
    last_follow_up = st.session_state.chat_history[-1]["content"]
    follow_up_options = [last_follow_up, "Explain differently", "Give me an example"]
    selected_option = st.radio("What would you like to do next?", follow_up_options)

    if st.button("Continue") and selected_option:
        st.chat_message("user").write(selected_option)
        response, st.session_state.chat_history = get_response(
            system_message="You are a helpful AI assistant.",
            user_text=selected_option,
            chat_history=st.session_state.chat_history,
        )
        # st.chat_message renders Markdown instead of injecting model output
        # into the page as raw HTML (unsafe_allow_html).
        st.chat_message("assistant").write(response)