Spaces: Running on Zero
Thanush committed · Commit 000ab02 · 1 Parent(s): 6e237a4

Refactor app.py to extract user name and age from conversation history and improve response generation logic
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from langchain.memory import ConversationBufferMemory
+import re
 
 # Model configuration
 LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
@@ -93,6 +94,18 @@ def get_meditron_suggestions(patient_info):
     suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
     return suggestion
 
+def extract_name_age(messages):
+    name, age = None, None
+    for msg in messages:
+        if msg.type == "human":
+            age_match = re.search(r"(?:I am|I'm|age is|aged|My age is)\s*(\d{1,3})", msg.content, re.IGNORECASE)
+            if age_match:
+                age = age_match.group(1)
+            name_match = re.search(r"(?:my name is|I'm|I am)\s*([A-Za-z]+)", msg.content, re.IGNORECASE)
+            if name_match:
+                name = name_match.group(1)
+    return name, age
+
 @spaces.GPU
 def generate_response(message, history):
     """Generate a response using both models, with full context."""
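A quick standalone check of the two patterns introduced in extract_name_age; the sample strings below are hypothetical and not part of the commit. Note that the name pattern shares the "I'm"/"I am" prefixes with the age pattern, so a phrase like "I am feeling dizzy" yields a spurious name capture, something to keep in mind when extending the regexes.

import re

# The two patterns added in extract_name_age, tested in isolation.
AGE_RE = re.compile(r"(?:I am|I'm|age is|aged|My age is)\s*(\d{1,3})", re.IGNORECASE)
NAME_RE = re.compile(r"(?:my name is|I'm|I am)\s*([A-Za-z]+)", re.IGNORECASE)

samples = [
    "Hi, my name is Thanush and I'm 23",  # both fields present
    "I am feeling dizzy",                 # no age; "feeling" is captured as a name
    "My age is 41",                       # age only
]
for text in samples:
    name = NAME_RE.search(text)
    age = AGE_RE.search(text)
    print(text, "->", name.group(1) if name else None, "/", age.group(1) if age else None)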
@@ -101,15 +114,21 @@ def generate_response(message, history):
         memory.save_context({"input": history[-1][0]}, {"output": history[-1][1]})
     memory.save_context({"input": message}, {"output": ""})
 
-    # Use the full message sequence from memory
     messages = memory.chat_memory.messages
+    name, age = extract_name_age(messages)
+    missing_info = []
+    if not name:
+        missing_info.append("your name")
+    if not age:
+        missing_info.append("your age")
+    if missing_info:
+        ask = "Before we continue, could you please tell me " + " and ".join(missing_info) + "?"
+        return ask
 
-    # Build the prompt with the full message sequence
     prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
-
-    # Add summarization instruction after 4 turns (count human messages)
     num_user_turns = sum(1 for m in messages if m.type == "human")
-    if num_user_turns >= 4:
+    # Only add summarization ONCE, not on every turn after 4
+    if num_user_turns == 4:
         prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
 
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
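The new gate returns early, skipping both models, until name and age are both known. A minimal mirror of that branch (the helper name here is hypothetical; the commit inlines this logic in generate_response):

def ask_for_missing(name, age):
    # Collect whichever fields are still unknown, then ask for them.
    missing_info = []
    if not name:
        missing_info.append("your name")
    if not age:
        missing_info.append("your age")
    if missing_info:
        return ("Before we continue, could you please tell me "
                + " and ".join(missing_info) + "?")
    return None  # both known, generation proceeds

print(ask_for_missing(None, None))       # ...tell me your name and your age?
print(ask_for_missing("Thanush", None))  # ...tell me your age?
print(ask_for_missing("Thanush", "23"))  # None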
@@ -130,15 +149,10 @@ def generate_response(message, history):
     full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
     llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
 
-    # After 4 turns, add medicine suggestions from Meditron
-    if num_user_turns >= 4:
-        # Collect full patient conversation (all user messages)
+    # After 4 turns, add medicine suggestions from Meditron, but only once
+    if num_user_turns == 4:
         full_patient_info = "\n".join([m.content for m in messages if m.type == "human"] + [message]) + "\n\nSummary: " + llama_response
-
-        # Get medicine suggestions
         medicine_suggestions = get_meditron_suggestions(full_patient_info)
-
-        # Format final response
         final_response = (
             f"{llama_response}\n\n"
             f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
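Assuming the removed condition was ">= 4" (which the new comment "not on every turn after 4" suggests), the behavioral change in both hunks is that the summary instruction and the Meditron suggestions are appended exactly once, on the fourth user turn, rather than on every turn from the fourth onward:

# Old gate vs. new gate across hypothetical user-turn counts.
for num_user_turns in range(1, 7):
    old_fires = num_user_turns >= 4  # re-appended suggestions every later turn
    new_fires = num_user_turns == 4  # fires exactly once
    print(f"turn {num_user_turns}: old={old_fires} new={new_fires}")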