CCockrum commited on
Commit
70ccc87
Β·
verified Β·
1 Parent(s): f009e2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -47
app.py CHANGED
@@ -32,42 +32,16 @@ if "response_ready" not in st.session_state:
32
  st.session_state.response_ready = False
33
 
34
  # βœ… Initialize Hugging Face Model (Explicitly Set to CPU/GPU)
35
- def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.8):
36
  return HuggingFaceEndpoint(
37
  repo_id=model_id,
38
  max_new_tokens=max_new_tokens,
39
- temperature=temperature,
40
  token=HF_TOKEN,
41
  task="text-generation",
42
  device=-1 if device == "cpu" else 0 # βœ… Force CPU (-1) or GPU (0)
43
  )
44
 
45
- # βœ… NASA API Function
46
- def get_nasa_apod():
47
- url = f"https://api.nasa.gov/planetary/apod?api_key={NASA_API_KEY}"
48
- response = requests.get(url)
49
- if response.status_code == 200:
50
- data = response.json()
51
- return data.get("url", ""), data.get("title", ""), data.get("explanation", "")
52
- return "", "NASA Data Unavailable", "I couldn't fetch data from NASA right now."
53
-
54
- # βœ… Sentiment Analysis (Now Uses Explicit Device)
55
- sentiment_analyzer = pipeline(
56
- "sentiment-analysis",
57
- model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
58
- device=-1 if device == "cpu" else 0 # βœ… Force CPU (-1) or GPU (0)
59
- )
60
-
61
- def analyze_sentiment(user_text):
62
- result = sentiment_analyzer(user_text)[0]
63
- return result['label']
64
-
65
- # βœ… Intent Detection
66
- def predict_action(user_text):
67
- if "NASA" in user_text.lower() or "space" in user_text.lower():
68
- return "nasa_info"
69
- return "general_query"
70
-
71
  # βœ… Ensure English Responses
72
  def ensure_english(text):
73
  try:
@@ -78,33 +52,25 @@ def ensure_english(text):
78
  return "⚠️ Language detection failed. Please ask your question again."
79
  return text
80
 
81
- # βœ… Main Response Function (Follow-Up Question Removed)
82
  def get_response(system_message, chat_history, user_text, max_new_tokens=800):
83
- action = predict_action(user_text)
84
-
85
- # βœ… Handle NASA-Specific Queries
86
- if action == "nasa_info":
87
- nasa_url, nasa_title, nasa_explanation = get_nasa_apod()
88
- response = f"**{nasa_title}**\n\n{nasa_explanation}"
89
- chat_history.append({'role': 'user', 'content': user_text})
90
- chat_history.append({'role': 'assistant', 'content': response})
91
- return response, chat_history, nasa_url
92
-
93
- # βœ… Invoke Hugging Face Model
94
- hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.9)
95
-
96
  filtered_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in chat_history)
97
 
98
  prompt = PromptTemplate.from_template(
99
- "[INST] You are a helpful AI assistant.\n\nCurrent Conversation:\n{chat_history}\n\n"
 
 
100
  "User: {user_text}.\n [/INST]\n"
101
- "AI: Provide a detailed explanation with depth. Use a conversational tone. "
102
- "🚨 Answer **only in English**."
103
- "Ensure a friendly, engaging tone."
104
  "\nHAL:"
105
  )
106
 
 
 
107
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
 
108
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
109
  response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
110
 
@@ -117,7 +83,7 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=800):
117
  st.session_state.chat_history.append({'role': 'user', 'content': user_text})
118
  st.session_state.chat_history.append({'role': 'assistant', 'content': response})
119
 
120
- return response, chat_history
121
 
122
  # βœ… Streamlit UI
123
  st.title("πŸš€ HAL - NASA AI Assistant")
 
32
  st.session_state.response_ready = False
33
 
34
  # βœ… Initialize Hugging Face Model (Explicitly Set to CPU/GPU)
35
+ def get_llm_hf_inference(model_id="meta-llama/Llama-2-7b-chat-hf", max_new_tokens=800, temperature=0.3):
36
  return HuggingFaceEndpoint(
37
  repo_id=model_id,
38
  max_new_tokens=max_new_tokens,
39
+ temperature=temperature, # πŸ”₯ Lowered temperature for more factual and structured responses
40
  token=HF_TOKEN,
41
  task="text-generation",
42
  device=-1 if device == "cpu" else 0 # βœ… Force CPU (-1) or GPU (0)
43
  )
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  # βœ… Ensure English Responses
46
  def ensure_english(text):
47
  try:
 
52
  return "⚠️ Language detection failed. Please ask your question again."
53
  return text
54
 
55
+ # βœ… Main Response Function (Fixing Repetition & Context)
56
  def get_response(system_message, chat_history, user_text, max_new_tokens=800):
57
+ # βœ… Ensure conversation history is included correctly
 
 
 
 
 
 
 
 
 
 
 
 
58
  filtered_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in chat_history)
59
 
60
  prompt = PromptTemplate.from_template(
61
+ "[INST] You are a knowledgeable and formal AI assistant. Please provide detailed, structured answers "
62
+ "without unnecessary enthusiasm or emojis.\n\n"
63
+ "Current Conversation:\n{chat_history}\n\n"
64
  "User: {user_text}.\n [/INST]\n"
65
+ "AI: Provide a structured and informative response while maintaining a neutral and professional tone."
66
+ "Ensure your response is engaging yet clear."
 
67
  "\nHAL:"
68
  )
69
 
70
+ # βœ… Invoke Hugging Face Model
71
+ hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.3) # πŸ”₯ Lowered temperature
72
  chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
73
+
74
  response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=filtered_history))
75
  response = response.split("HAL:")[-1].strip() if "HAL:" in response else response.strip()
76
 
 
83
  st.session_state.chat_history.append({'role': 'user', 'content': user_text})
84
  st.session_state.chat_history.append({'role': 'assistant', 'content': response})
85
 
86
+ return response, st.session_state.chat_history
87
 
88
  # βœ… Streamlit UI
89
  st.title("πŸš€ HAL - NASA AI Assistant")