CCockrum committed
Commit 627a6b9 · verified · 1 Parent(s): ebc9593

Update app.py

Files changed (1)
  1. app.py +29 -35
app.py CHANGED
```diff
@@ -1,3 +1,5 @@
+# hal_bot.py
+
 import os
 import re
 import requests
@@ -7,11 +9,11 @@ from langchain_huggingface import HuggingFaceEndpoint
 from langchain_core.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from transformers import pipeline
-from langdetect import detect  # Ensure this package is installed
+from langdetect import detect
 
-# ✅ Check for GPU or Default to CPU
+# ✅ Device setup
 device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"✅ Using device: {device}")  # Debugging info
+print(f"✅ Using device: {device}")
 
 # ✅ Environment Variables
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -22,25 +24,25 @@ NASA_API_KEY = os.getenv("NASA_API_KEY")
 if NASA_API_KEY is None:
     raise ValueError("NASA_API_KEY is not set. Please add it to your environment variables.")
 
-# ✅ Set Up Streamlit
+# ✅ Streamlit Setup
 st.set_page_config(page_title="HAL - NASA ChatBot", page_icon="🚀")
 
-# ✅ Initialize Session State Variables (Ensuring Chat History Persists)
 if "chat_history" not in st.session_state:
     st.session_state.chat_history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]
 
-# ✅ Initialize Hugging Face Model (Explicitly Set to CPU/GPU)
-def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.3):
+# ✅ Switched to Flan-T5 Model
+MODEL_ID = "google/flan-t5-large"
+
+def get_llm_hf_inference(model_id=MODEL_ID, max_new_tokens=500, temperature=0.3):
     return HuggingFaceEndpoint(
         repo_id=model_id,
         max_new_tokens=max_new_tokens,
-        temperature=temperature,  # 🔥 Lowered temperature for more factual and structured responses
+        temperature=temperature,
         token=HF_TOKEN,
-        task="text-generation",
-        device=-1 if device == "cpu" else 0  # ✅ Force CPU (-1) or GPU (0)
+        task="text2text-generation",
+        device=-1 if device == "cpu" else 0
     )
 
-# ✅ Ensure English Responses
 def ensure_english(text):
     try:
         detected_lang = detect(text)
@@ -50,46 +52,41 @@ def ensure_english(text):
         return "⚠️ Language detection failed. Please ask your question again."
     return text
 
-# ✅ Main Response Function (Fixing Repetition & Context)
-def get_response(system_message, chat_history, user_text, max_new_tokens=800):
-    # ✅ Ensure conversation history is included correctly
+def get_response(system_message, chat_history, user_text, max_new_tokens=500):
     filtered_history = "\n".join(
-        f"{msg['role'].capitalize()}: {msg['content']}"
-        for msg in chat_history[-5:]  # ✅ Only keep the last 5 exchanges to prevent overflow
+        f"{msg['role'].capitalize()}: {msg['content']}" for msg in chat_history[-5:]
     )
-
+
     prompt = PromptTemplate.from_template(
-        "[INST] You are a highly knowledgeable AI assistant. Answer concisely, avoid repetition, and structure responses well."
-        "\n\nCONTEXT:\n{chat_history}\n"
-        "\nLATEST USER INPUT:\nUser: {user_text}\n"
-        "\n[END CONTEXT]\n"
-        "Assistant:"
+        """
+        You are a helpful NASA AI assistant.
+        Answer concisely and clearly based on the conversation history and the user's latest message.
+
+        Conversation History:
+        {chat_history}
+
+        User: {user_text}
+        Assistant:
+        """
    )
 
-    # ✅ Invoke Hugging Face Model
-    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.3)  # 🔥 Lowered temperature
+    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.3)
     chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
 
     response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
-
-    # Clean up the response - remove any "HAL:" prefix if present
-    response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
+    response = response.strip()
    response = ensure_english(response)
 
     if not response:
         response = "I'm sorry, but I couldn't generate a response. Can you rephrase your question?"
 
-    # ✅ Update conversation history
     chat_history.append({'role': 'user', 'content': user_text})
     chat_history.append({'role': 'assistant', 'content': response})
 
-    # ✅ Keep only last 10 exchanges to prevent unnecessary repetition
     return response, chat_history[-10:]
 
-# ✅ Streamlit UI
 st.title("🚀 HAL - NASA AI Assistant")
 
-# ✅ Justify all chatbot responses
 st.markdown("""
 <style>
 .user-msg, .assistant-msg {
@@ -107,22 +104,19 @@ st.markdown("""
 </style>
 """, unsafe_allow_html=True)
 
-# ✅ Chat UI
 user_input = st.chat_input("Type your message here...")
 
 if user_input:
-    # Get response and update chat history
     response, st.session_state.chat_history = get_response(
         system_message="You are a helpful AI assistant.",
         user_text=user_input,
         chat_history=st.session_state.chat_history
     )
 
-    # ✅ Display chat history (ONLY display from history, not separately)
     st.markdown("<div class='container'>", unsafe_allow_html=True)
     for message in st.session_state.chat_history:
         if message["role"] == "user":
             st.markdown(f"<div class='user-msg'><strong>You:</strong> {message['content']}</div>", unsafe_allow_html=True)
         else:
             st.markdown(f"<div class='assistant-msg'><strong>HAL:</strong> {message['content']}</div>", unsafe_allow_html=True)
-st.markdown("</div>", unsafe_allow_html=True)
+    st.markdown("</div>", unsafe_allow_html=True)
```
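Note on the model change: google/flan-t5-large is an encoder-decoder model, which is why the endpoint `task` moves from `text-generation` to `text2text-generation`. Below is a quick local sanity check of the new model (a sketch only; it uses the `pipeline` import already present in app.py rather than the hosted endpoint, and it needs enough memory to load flan-t5-large, roughly 3 GB):

```python
# Hypothetical local check; not part of app.py. Requires `transformers`
# (plus torch) and enough memory to load flan-t5-large.
from transformers import pipeline

flan = pipeline("text2text-generation", model="google/flan-t5-large")
result = flan(
    "Answer concisely: What is the largest planet in our solar system?",
    max_new_tokens=50,
)
print(result[0]["generated_text"])  # expected: something like "Jupiter"
```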
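To exercise the updated `get_response` chain outside Streamlit, a minimal console smoke test, assuming HF_TOKEN (and NASA_API_KEY, which the module checks at import time) are exported and the functions above are importable; the question and history values are illustrative:

```python
# Minimal smoke test (sketch); mirrors how the Streamlit handler calls get_response.
history = [{"role": "assistant", "content": "Hello! How can I assist you today?"}]

reply, history = get_response(
    system_message="You are a helpful AI assistant.",
    chat_history=history,
    user_text="When did Apollo 11 land on the Moon?",  # illustrative question
)

print(reply)         # stripped, language-checked model output
print(len(history))  # get_response caps the returned history at 10 entries
```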
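The `ensure_english` guard builds on langdetect, whose `detect()` raises when text is empty or undetectable; the visible hunk shows only the failure message, so this sketch illustrates the general gate pattern (`is_english` is a hypothetical helper, not code from app.py):

```python
# Sketch of the language-gate pattern (not the literal app.py implementation).
from langdetect import detect, LangDetectException

def is_english(text: str) -> bool:
    try:
        return detect(text) == "en"  # detect() returns an ISO 639-1 code
    except LangDetectException:      # raised for empty/undetectable input
        return False

print(is_english("Houston, we have a problem."))  # True
print(is_english("Bonjour tout le monde."))       # False
```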
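Finally, what the prompt's `{chat_history}` slot actually receives: `get_response` flattens the last five history entries into "Role: content" lines. The same expression, run standalone on a toy history:

```python
# Same flattening expression as in get_response, applied to a toy history.
history = [
    {"role": "user", "content": "Hi HAL."},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
]
filtered_history = "\n".join(
    f"{msg['role'].capitalize()}: {msg['content']}" for msg in history[-5:]
)
print(filtered_history)
# User: Hi HAL.
# Assistant: Hello! How can I assist you today?
```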