CCockrum commited on
Commit
ad0b8d6
·
verified ·
1 Parent(s): 8438304

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -45
app.py CHANGED
@@ -23,16 +23,21 @@ if "follow_up" not in st.session_state:
23
  # --- Set Up Model & API Functions ---
24
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
25
 
26
- # Initialize sentiment analysis pipeline
27
- sentiment_analyzer = pipeline("sentiment-analysis")
28
-
29
- def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
 
 
 
 
 
30
  return HuggingFaceEndpoint(
31
  repo_id=model_id,
32
- task="text-generation", # Specify the task explicitly
33
  max_new_tokens=max_new_tokens,
34
  temperature=temperature,
35
- token=os.getenv("HF_TOKEN") # Hugging Face API Token
 
36
  )
37
 
38
  def get_nasa_apod():
@@ -54,14 +59,21 @@ def predict_action(user_text):
54
  return "general_query"
55
 
56
  def generate_follow_up(user_text):
 
 
 
57
  prompt_text = (
58
- f"Based on the user's message: '{user_text}', suggest a natural follow-up question "
59
- "to keep the conversation engaging."
 
60
  )
61
- hf = get_llm_hf_inference(max_new_tokens=64, temperature=0.7)
62
  return hf.invoke(input=prompt_text).strip()
63
 
64
  def get_response(system_message, chat_history, user_text, max_new_tokens=256):
 
 
 
65
  sentiment = analyze_sentiment(user_text)
66
  action = predict_action(user_text)
67
 
@@ -75,20 +87,25 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
75
  chat_history.append({'role': 'assistant', 'content': follow_up})
76
  return response, follow_up, chat_history, nasa_url
77
 
78
- hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
79
 
80
  prompt = PromptTemplate.from_template(
81
- "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
 
 
 
 
 
82
  )
83
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
84
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
85
- response = response.split("AI:")[-1]
86
 
87
  chat_history.append({'role': 'user', 'content': user_text})
88
  chat_history.append({'role': 'assistant', 'content': response})
89
 
90
  if sentiment == "NEGATIVE":
91
- response += "\n😞 I'm sorry to hear that. How can I assist you further?"
92
 
93
  follow_up = generate_follow_up(user_text)
94
  chat_history.append({'role': 'assistant', 'content': follow_up})
@@ -147,42 +164,24 @@ for message in st.session_state.chat_history:
147
  st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
148
  st.markdown("</div>", unsafe_allow_html=True)
149
 
150
- # --- Input & Button Handling ---
151
- user_input = st.text_area("Type your message:", height=100)
152
-
153
- send_button_placeholder = st.empty()
154
 
155
- if not st.session_state.response_ready:
156
- if send_button_placeholder.button("Send"):
157
- if user_input:
158
- response, follow_up, st.session_state.chat_history, image_url = get_response(
159
- system_message="You are a helpful AI assistant.",
160
- user_text=user_input,
161
- chat_history=st.session_state.chat_history
162
- )
163
 
164
- st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
165
 
166
- if image_url:
167
- st.image(image_url, caption="NASA Image of the Day")
168
 
169
- # Store follow-up question
170
- st.session_state.follow_up = follow_up
171
- st.session_state.response_ready = True # Hide Send button after response
172
 
173
- # Conversational Follow-up
174
  if st.session_state.response_ready and st.session_state.follow_up:
175
  st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {st.session_state.follow_up}</div>", unsafe_allow_html=True)
176
-
177
- next_input = st.text_input("HAL is waiting for your response...")
178
-
179
- if next_input:
180
- response, _, st.session_state.chat_history, _ = get_response(
181
- system_message="You are a helpful AI assistant.",
182
- user_text=next_input,
183
- chat_history=st.session_state.chat_history
184
- )
185
- st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
186
-
187
- st.session_state.response_ready = False
188
- st.session_state.follow_up = ""
 
23
  # --- Set Up Model & API Functions ---
24
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
25
 
26
+ # Initialize sentiment analysis pipeline with explicit model specification
27
+ sentiment_analyzer = pipeline(
28
+ "sentiment-analysis",
29
+ model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
30
+ revision="714eb0f"
31
+ )
32
+
33
+ def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.7):
34
+ # Explicitly specify task="text-generation" so that the endpoint knows which task to run
35
  return HuggingFaceEndpoint(
36
  repo_id=model_id,
 
37
  max_new_tokens=max_new_tokens,
38
  temperature=temperature,
39
+ token=os.getenv("HF_TOKEN"),
40
+ task="text-generation"
41
  )
42
 
43
  def get_nasa_apod():
 
59
  return "general_query"
60
 
61
  def generate_follow_up(user_text):
62
+ """
63
+ Generates a concise and conversational follow-up question related to the user's input.
64
+ """
65
  prompt_text = (
66
+ f"Given the user's question: '{user_text}', generate a SHORT and SIMPLE follow-up question. "
67
+ "Make it conversational and friendly. Example: 'Would you like to learn more about the six types of quarks?' "
68
+ "Do NOT provide long explanations—just ask a friendly follow-up question."
69
  )
70
+ hf = get_llm_hf_inference(max_new_tokens=32, temperature=0.7)
71
  return hf.invoke(input=prompt_text).strip()
72
 
73
  def get_response(system_message, chat_history, user_text, max_new_tokens=256):
74
+ """
75
+ Generates HAL's response, making it more conversational and engaging.
76
+ """
77
  sentiment = analyze_sentiment(user_text)
78
  action = predict_action(user_text)
79
 
 
87
  chat_history.append({'role': 'assistant', 'content': follow_up})
88
  return response, follow_up, chat_history, nasa_url
89
 
90
+ hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
91
 
92
  prompt = PromptTemplate.from_template(
93
+ (
94
+ "[INST] {system_message}\n\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\n"
95
+ "AI: Keep responses conversational and engaging. Start with a friendly phrase like "
96
+ "'Certainly!', 'Of course!', or 'Great question!' before answering. "
97
+ "Keep responses concise but engaging.\nHAL:"
98
+ )
99
  )
100
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
101
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
102
+ response = response.split("HAL:")[-1].strip()
103
 
104
  chat_history.append({'role': 'user', 'content': user_text})
105
  chat_history.append({'role': 'assistant', 'content': response})
106
 
107
  if sentiment == "NEGATIVE":
108
+ response = "I'm here to help. Let me know what I can do for you. 😊"
109
 
110
  follow_up = generate_follow_up(user_text)
111
  chat_history.append({'role': 'assistant', 'content': follow_up})
 
164
  st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
165
  st.markdown("</div>", unsafe_allow_html=True)
166
 
167
+ # --- Single Input Box for Both Initial and Follow-Up Messages ---
168
+ user_input = st.chat_input("Type your message here...") # Only ONE chat_input()
 
 
169
 
170
+ if user_input:
171
+ response, follow_up, st.session_state.chat_history, image_url = get_response(
172
+ system_message="You are a helpful AI assistant.",
173
+ user_text=user_input,
174
+ chat_history=st.session_state.chat_history
175
+ )
 
 
176
 
177
+ st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {response}</div>", unsafe_allow_html=True)
178
 
179
+ if image_url:
180
+ st.image(image_url, caption="NASA Image of the Day")
181
 
182
+ st.session_state.follow_up = follow_up
183
+ st.session_state.response_ready = True
 
184
 
 
185
  if st.session_state.response_ready and st.session_state.follow_up:
186
  st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {st.session_state.follow_up}</div>", unsafe_allow_html=True)
187
+ st.session_state.response_ready = False