Thanush committed on
Commit
a0597d0
·
1 Parent(s): c4447f4

Refactor prompt building in app.py to utilize full message sequence and enhance response generation context

Browse files
Files changed (1) hide show
  1. app.py +18 -27
app.py CHANGED
@@ -55,17 +55,15 @@ print("Meditron model loaded successfully!")
55
  # Initialize LangChain memory
56
  memory = ConversationBufferMemory(return_messages=True)
57
 
58
- def build_llama2_prompt(system_prompt, history, user_input):
59
- """Format the conversation history and user input for Llama-2 chat models."""
60
  prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
61
-
62
- # Add conversation history
63
- for user_msg, assistant_msg in history:
64
- prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
65
-
66
- # Add the current user input
67
  prompt += f"{user_input} [/INST] "
68
-
69
  return prompt
70
 
71
  def get_meditron_suggestions(patient_info):
@@ -88,28 +86,21 @@ def get_meditron_suggestions(patient_info):
88
 
89
  @spaces.GPU
90
  def generate_response(message, history):
91
- """Generate a response using both models."""
92
  # Save the latest user message and last assistant response to memory
93
  if history and len(history[-1]) == 2:
94
  memory.save_context({"input": history[-1][0]}, {"output": history[-1][1]})
95
  memory.save_context({"input": message}, {"output": ""})
96
 
97
- # Build conversation history from memory
98
- lc_history = []
99
- user_msg = None
100
- for msg in memory.chat_memory.messages:
101
- if msg.type == "human":
102
- user_msg = msg.content
103
- elif msg.type == "ai" and user_msg is not None:
104
- assistant_msg = msg.content
105
- lc_history.append((user_msg, assistant_msg))
106
- user_msg = None
107
 
108
- # Build the prompt with LangChain memory history
109
- prompt = build_llama2_prompt(SYSTEM_PROMPT, lc_history, message)
110
 
111
- # Add summarization instruction after 4 turns
112
- if len(lc_history) >= 4:
 
113
  prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
114
 
115
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -131,9 +122,9 @@ def generate_response(message, history):
131
  llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
132
 
133
  # After 4 turns, add medicine suggestions from Meditron
134
- if len(lc_history) >= 4:
135
- # Collect full patient conversation
136
- full_patient_info = "\n".join([h[0] for h in lc_history] + [message]) + "\n\nSummary: " + llama_response
137
 
138
  # Get medicine suggestions
139
  medicine_suggestions = get_meditron_suggestions(full_patient_info)
 
55
  # Initialize LangChain memory
56
  memory = ConversationBufferMemory(return_messages=True)
57
 
58
def build_llama2_prompt(system_prompt, messages, user_input):
    """Format a conversation for Llama-2 chat models.

    Builds the standard Llama-2 chat layout:
    ``<s>[INST] <<SYS>> system <</SYS>> user [/INST] assistant </s><s>[INST] ...``

    Args:
        system_prompt: Text placed inside the ``<<SYS>>`` block.
        messages: Sequence of message objects exposing ``.type`` (``"human"``
            or ``"ai"``) and ``.content`` — e.g. LangChain
            ``memory.chat_memory.messages``.
        user_input: The current user message, appended as the final turn.

    Returns:
        The fully formatted prompt string, ending with ``"[/INST] "`` so the
        model's answer continues from there.
    """
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"

    # The caller (generate_response) calls save_context with output="" for
    # the current turn *before* building the prompt, so `messages` may end
    # with the current human message plus an empty AI placeholder.  Emitting
    # those verbatim would duplicate the current message and insert an empty
    # assistant turn, corrupting the chat format — drop them here.
    history = [m for m in messages if not (m.type == "ai" and not m.content)]
    if history and history[-1].type == "human" and history[-1].content == user_input:
        history = history[:-1]

    for msg in history:
        if msg.type == "human":
            prompt += f"{msg.content} [/INST] "
        elif msg.type == "ai":
            prompt += f"{msg.content} </s><s>[INST] "

    # Current user turn; the model's reply is generated after "[/INST] ".
    prompt += f"{user_input} [/INST] "
    return prompt
68
 
69
  def get_meditron_suggestions(patient_info):
 
86
 
87
  @spaces.GPU
88
  def generate_response(message, history):
89
+ """Generate a response using both models, with full context."""
90
  # Save the latest user message and last assistant response to memory
91
  if history and len(history[-1]) == 2:
92
  memory.save_context({"input": history[-1][0]}, {"output": history[-1][1]})
93
  memory.save_context({"input": message}, {"output": ""})
94
 
95
+ # Use the full message sequence from memory
96
+ messages = memory.chat_memory.messages
 
 
 
 
 
 
 
 
97
 
98
+ # Build the prompt with the full message sequence
99
+ prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
100
 
101
+ # Add summarization instruction after 4 turns (count human messages)
102
+ num_user_turns = sum(1 for m in messages if m.type == "human")
103
+ if num_user_turns >= 4:
104
  prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
105
 
106
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
122
  llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
123
 
124
  # After 4 turns, add medicine suggestions from Meditron
125
+ if num_user_turns >= 4:
126
+ # Collect full patient conversation (all user messages)
127
+ full_patient_info = "\n".join([m.content for m in messages if m.type == "human"] + [message]) + "\n\nSummary: " + llama_response
128
 
129
  # Get medicine suggestions
130
  medicine_suggestions = get_meditron_suggestions(full_patient_info)