CCockrum commited on
Commit
94ac9e7
·
verified ·
1 Parent(s): ce325db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -129
app.py CHANGED
@@ -3,109 +3,42 @@ from langchain_huggingface import HuggingFaceEndpoint
3
  import streamlit as st
4
  from langchain_core.prompts import PromptTemplate
5
  from langchain_core.output_parsers import StrOutputParser
 
 
6
 
7
- model_id="mistralai/Mistral-7B-Instruct-v0.3"
8
 
9
  def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
10
- """
11
- Returns a language model for HuggingFace inference.
12
-
13
- Parameters:
14
- - model_id (str): The ID of the HuggingFace model repository.
15
- - max_new_tokens (int): The maximum number of new tokens to generate.
16
- - temperature (float): The temperature for sampling from the model.
17
-
18
- Returns:
19
- - llm (HuggingFaceEndpoint): The language model for HuggingFace inference.
20
- """
21
  llm = HuggingFaceEndpoint(
22
  repo_id=model_id,
23
  max_new_tokens=max_new_tokens,
24
  temperature=temperature,
25
- token = os.getenv("HF_TOKEN")
26
  )
27
  return llm
28
 
29
- # Configure the Streamlit app
30
- st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
31
- st.title("Personal HuggingFace ChatBot")
32
- st.markdown(f"*This is a simple chatbot that uses the HuggingFace transformers library to generate responses to your text input. It uses the {model_id}.*")
33
-
34
- # Initialize session state for avatars
35
- if "avatars" not in st.session_state:
36
- st.session_state.avatars = {'user': None, 'assistant': None}
37
-
38
- # Initialize session state for user text input
39
- if 'user_text' not in st.session_state:
40
- st.session_state.user_text = None
41
-
42
- # Initialize session state for model parameters
43
- if "max_response_length" not in st.session_state:
44
- st.session_state.max_response_length = 256
45
-
46
- if "system_message" not in st.session_state:
47
- st.session_state.system_message = "friendly AI conversing with a human user"
48
-
49
- if "starter_message" not in st.session_state:
50
- st.session_state.starter_message = "Hello, there! How can I help you today?"
51
-
52
-
53
- # Sidebar for settings
54
- with st.sidebar:
55
- st.header("System Settings")
56
-
57
- # AI Settings
58
- st.session_state.system_message = st.text_area(
59
- "System Message", value="You are a friendly AI conversing with a human user."
60
- )
61
- st.session_state.starter_message = st.text_area(
62
- 'First AI Message', value="Hello, there! How can I help you today?"
63
- )
64
-
65
- # Model Settings
66
- st.session_state.max_response_length = st.number_input(
67
- "Max Response Length", value=128
68
- )
69
-
70
- # Avatar Selection
71
- st.markdown("*Select Avatars:*")
72
- col1, col2 = st.columns(2)
73
- with col1:
74
- st.session_state.avatars['assistant'] = st.selectbox(
75
- "AI Avatar", options=["🤗", "💬", "🤖"], index=0
76
- )
77
- with col2:
78
- st.session_state.avatars['user'] = st.selectbox(
79
- "User Avatar", options=["👤", "👱‍♂️", "👨🏾", "👩", "👧🏾"], index=0
80
- )
81
- # Reset Chat History
82
- reset_history = st.button("Reset Chat History")
83
-
84
- # Initialize or reset chat history
85
- if "chat_history" not in st.session_state or reset_history:
86
- st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
87
 
88
  def get_response(system_message, chat_history, user_text,
89
  eos_token_id=['User'], max_new_tokens=256, get_llm_hf_kws={}):
90
- """
91
- Generates a response from the chatbot model.
 
 
 
92
 
93
- Args:
94
- system_message (str): The system message for the conversation.
95
- chat_history (list): The list of previous chat messages.
96
- user_text (str): The user's input text.
97
- model_id (str, optional): The ID of the HuggingFace model to use.
98
- eos_token_id (list, optional): The list of end-of-sentence token IDs.
99
- max_new_tokens (int, optional): The maximum number of new tokens to generate.
100
- get_llm_hf_kws (dict, optional): Additional keyword arguments for the get_llm_hf function.
101
-
102
- Returns:
103
- tuple: A tuple containing the generated response and the updated chat history.
104
- """
105
- # Set up the model
106
  hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
107
 
108
- # Create the prompt template
109
  prompt = PromptTemplate.from_template(
110
  (
111
  "[INST] {system_message}"
@@ -114,55 +47,36 @@ def get_response(system_message, chat_history, user_text,
114
  "\nAI:"
115
  )
116
  )
117
- # Make the chain and bind the prompt
118
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
119
-
120
- # Generate the response
121
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
122
  response = response.split("AI:")[-1]
123
 
124
- # Update the chat history
125
  chat_history.append({'role': 'user', 'content': user_text})
126
  chat_history.append({'role': 'assistant', 'content': response})
127
  return response, chat_history
128
 
129
- # Chat interface
130
- chat_interface = st.container(border=True)
131
- with chat_interface:
132
- output_container = st.container()
133
- st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")
 
 
 
134
 
135
- # Display chat messages
136
- with output_container:
137
- # For every message in the history
 
 
 
 
 
 
 
 
 
 
 
138
  for message in st.session_state.chat_history:
139
- # Skip the system message
140
- if message['role'] == 'system':
141
- continue
142
-
143
- # Display the chat message using the correct avatar
144
- with st.chat_message(message['role'],
145
- avatar=st.session_state['avatars'][message['role']]):
146
- st.markdown(message['content'])
147
-
148
- # When the user enter new text:
149
- if st.session_state.user_text:
150
-
151
- # Display the user's new message immediately
152
- with st.chat_message("user",
153
- avatar=st.session_state.avatars['user']):
154
- st.markdown(st.session_state.user_text)
155
-
156
- # Display a spinner status bar while waiting for the response
157
- with st.chat_message("assistant",
158
- avatar=st.session_state.avatars['assistant']):
159
-
160
- with st.spinner("Thinking..."):
161
- # Call the Inference API with the system_prompt, user text, and history
162
- response, st.session_state.chat_history = get_response(
163
- system_message=st.session_state.system_message,
164
- user_text=st.session_state.user_text,
165
- chat_history=st.session_state.chat_history,
166
- max_new_tokens=st.session_state.max_response_length,
167
- )
168
- st.markdown(response)
 
3
  import streamlit as st
4
  from langchain_core.prompts import PromptTemplate
5
  from langchain_core.output_parsers import StrOutputParser
6
+ import requests
7
+ from config import NASA_API_KEY # Import the NASA API key from the configuration file
8
 
9
+ model_id = "mistralai/Mistral-7B-Instruct-v0.3"
10
 
11
  def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
 
 
 
 
 
 
 
 
 
 
 
12
  llm = HuggingFaceEndpoint(
13
  repo_id=model_id,
14
  max_new_tokens=max_new_tokens,
15
  temperature=temperature,
16
+ token=os.getenv("HF_TOKEN") # Hugging Face token from environment variable
17
  )
18
  return llm
19
 
20
+ def get_nasa_apod():
21
+ """
22
+ Fetch the Astronomy Picture of the Day (APOD) from the NASA API.
23
+ """
24
+ url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
25
+ response = requests.get(url)
26
+ if response.status_code == 200:
27
+ data = response.json()
28
+ return f"Title: {data['title']}\nExplanation: {data['explanation']}\nURL: {data['url']}"
29
+ else:
30
+ return "I couldn't fetch data from NASA right now. Please try again later."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def get_response(system_message, chat_history, user_text,
33
  eos_token_id=['User'], max_new_tokens=256, get_llm_hf_kws={}):
34
+ if "NASA" in user_text or "space" in user_text:
35
+ nasa_response = get_nasa_apod()
36
+ chat_history.append({'role': 'user', 'content': user_text})
37
+ chat_history.append({'role': 'assistant', 'content': nasa_response})
38
+ return nasa_response, chat_history
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
41
 
 
42
  prompt = PromptTemplate.from_template(
43
  (
44
  "[INST] {system_message}"
 
47
  "\nAI:"
48
  )
49
  )
 
50
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
 
 
51
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
52
  response = response.split("AI:")[-1]
53
 
 
54
  chat_history.append({'role': 'user', 'content': user_text})
55
  chat_history.append({'role': 'assistant', 'content': response})
56
  return response, chat_history
57
 
58
+ # Streamlit setup
59
+ st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
60
+ st.title("Personal Assistant")
61
+ st.markdown(f"*This chatbot uses {model_id} and NASA's APIs to provide information and responses.*")
62
+
63
+ # Initialize session state
64
+ if "chat_history" not in st.session_state:
65
+ st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
66
 
67
+ # Sidebar for settings
68
+ if st.sidebar.button("Reset Chat"):
69
+ st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
70
+
71
+ # Main chat interface
72
+ user_input = st.chat_input(placeholder="Type your message here...")
73
+ if user_input:
74
+ response, st.session_state.chat_history = get_response(
75
+ system_message="You are a helpful AI assistant.",
76
+ user_text=user_input,
77
+ chat_history=st.session_state.chat_history,
78
+ max_new_tokens=128
79
+ )
80
+ # Display messages
81
  for message in st.session_state.chat_history:
82
+ st.chat_message(message["role"]).write(message["content"])