Thanush committed on
Commit
000ab02
·
1 Parent(s): 6e237a4

Refactor app.py to extract user name and age from conversation history and improve response generation logic

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces
3
  import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from langchain.memory import ConversationBufferMemory
 
6
 
7
  # Model configuration
8
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
@@ -93,6 +94,18 @@ def get_meditron_suggestions(patient_info):
93
  suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
94
  return suggestion
95
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  @spaces.GPU
97
  def generate_response(message, history):
98
  """Generate a response using both models, with full context."""
@@ -101,15 +114,21 @@ def generate_response(message, history):
101
  memory.save_context({"input": history[-1][0]}, {"output": history[-1][1]})
102
  memory.save_context({"input": message}, {"output": ""})
103
 
104
- # Use the full message sequence from memory
105
  messages = memory.chat_memory.messages
 
 
 
 
 
 
 
 
 
106
 
107
- # Build the prompt with the full message sequence
108
  prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
109
-
110
- # Add summarization instruction after 4 turns (count human messages)
111
  num_user_turns = sum(1 for m in messages if m.type == "human")
112
- if num_user_turns >= 4:
 
113
  prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
114
 
115
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -130,15 +149,10 @@ def generate_response(message, history):
130
  full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
131
  llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
132
 
133
- # After 4 turns, add medicine suggestions from Meditron
134
- if num_user_turns >= 4:
135
- # Collect full patient conversation (all user messages)
136
  full_patient_info = "\n".join([m.content for m in messages if m.type == "human"] + [message]) + "\n\nSummary: " + llama_response
137
-
138
- # Get medicine suggestions
139
  medicine_suggestions = get_meditron_suggestions(full_patient_info)
140
-
141
- # Format final response
142
  final_response = (
143
  f"{llama_response}\n\n"
144
  f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
 
3
  import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from langchain.memory import ConversationBufferMemory
6
+ import re
7
 
8
  # Model configuration
9
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
 
94
  suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
95
  return suggestion
96
 
97
def extract_name_age(messages):
    """Scan a conversation's human messages for the user's name and age.

    Later messages override earlier ones, so the most recent statement wins.

    Args:
        messages: Iterable of chat messages; each has a ``type`` attribute
            ("human"/"ai") and a string ``content`` attribute.

    Returns:
        tuple[str | None, str | None]: ``(name, age)`` — the age is kept as
        a string; either element is ``None`` when it was never stated.
    """
    name, age = None, None
    for msg in messages:
        if msg.type != "human":
            continue
        # \b after the digits prevents matching a prefix of a longer number
        # (e.g. "I am 1234" must not yield age "123"); also accept the very
        # common "<N> years old" phrasing, which the lead-in words miss.
        age_match = re.search(
            r"(?:I am|I'm|age is|aged|My age is)\s*(\d{1,3})\b"
            r"|\b(\d{1,3})\s*(?:years?\s*old|yrs?\s*old)",
            msg.content,
            re.IGNORECASE,
        )
        if age_match:
            # Whichever alternative matched supplies the digits.
            age = age_match.group(1) or age_match.group(2)
        # Name heuristic: the first word following a self-introduction.
        # NOTE(review): this will also capture non-names such as
        # "I am feeling ..." — acceptable for this heuristic, matching
        # the original behavior.
        name_match = re.search(
            r"(?:my name is|I'm|I am)\s*([A-Za-z]+)", msg.content, re.IGNORECASE
        )
        if name_match:
            name = name_match.group(1)
    return name, age
108
+
109
  @spaces.GPU
110
  def generate_response(message, history):
111
  """Generate a response using both models, with full context."""
 
114
  memory.save_context({"input": history[-1][0]}, {"output": history[-1][1]})
115
  memory.save_context({"input": message}, {"output": ""})
116
 
 
117
  messages = memory.chat_memory.messages
118
+ name, age = extract_name_age(messages)
119
+ missing_info = []
120
+ if not name:
121
+ missing_info.append("your name")
122
+ if not age:
123
+ missing_info.append("your age")
124
+ if missing_info:
125
+ ask = "Before we continue, could you please tell me " + " and ".join(missing_info) + "?"
126
+ return ask
127
 
 
128
  prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
 
 
129
  num_user_turns = sum(1 for m in messages if m.type == "human")
130
+ # Only add summarization ONCE, not on every turn after 4
131
+ if num_user_turns == 4:
132
  prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
133
 
134
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
149
  full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
150
  llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
151
 
152
+ # After 4 turns, add medicine suggestions from Meditron, but only once
153
+ if num_user_turns == 4:
 
154
  full_patient_info = "\n".join([m.content for m in messages if m.type == "human"] + [message]) + "\n\nSummary: " + llama_response
 
 
155
  medicine_suggestions = get_meditron_suggestions(full_patient_info)
 
 
156
  final_response = (
157
  f"{llama_response}\n\n"
158
  f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"