jason-moore committed on
Commit ca4db09 · 1 Parent(s): 1296b9e

Add prompt

Files changed (1)
  1. app.py +21 -5
app.py CHANGED
@@ -13,21 +13,37 @@ def load_model():
     model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)

     return model, tokenizer
-    return model, tokenizer

 # Function to generate SOAP notes
 def generate_soap_note(doctor_patient_conversation):
     if not doctor_patient_conversation.strip():
         return "Please enter a doctor-patient conversation."

-    # Tokenize and generate
-    inputs = tokenizer(doctor_patient_conversation, return_tensors="pt")
-
-    generate_ids = model.generate(inputs.input_ids, max_length=200, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
+    # Create a properly formatted prompt with instructions
+    prompt = f"""<|user|>
+Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note based on the following doctor-patient conversation:
+
+{doctor_patient_conversation}
+<|assistant|>"""
+
+    # Tokenize and generate
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    generate_ids = model.generate(
+        inputs.input_ids,
+        max_length=2048,
+        num_beams=5,
+        no_repeat_ngram_size=2,
+        early_stopping=True
+    )

     # Decode and extract the response part
     decoded_response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+    # Extract only the assistant's response (remove the prompt part)
+    if "<|assistant|>" in decoded_response:
+        decoded_response = decoded_response.split("<|assistant|>")[1].strip()
+
     logger.debug(f"Decoded response: {decoded_response}")
     return decoded_response
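
For context, a minimal standalone sketch of the generation path after this commit, assuming the same transformers model/tokenizer setup that app.py's load_model() uses. The sample conversation string is hypothetical, and the comment about skip_special_tokens is an observation about this pattern, not part of the commit.

# Minimal sketch of the post-commit flow; mirrors the diff, with a
# hypothetical sample input. Assumes the same setup as load_model().
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small", trust_remote_code=False)
model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)

conversation = "Doctor: What brings you in today?\nPatient: I've had a cough for a week."

# Wrap the raw conversation in the chat-style template the commit introduces.
prompt = f"""<|user|>
Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note based on the following doctor-patient conversation:

{conversation}
<|assistant|>"""

inputs = tokenizer(prompt, return_tensors="pt")
generate_ids = model.generate(
    inputs.input_ids,
    max_length=2048,  # total budget for prompt + completion, as in the diff
    num_beams=5,
    no_repeat_ngram_size=2,
    early_stopping=True,
)

decoded = tokenizer.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# Strip the echoed prompt, keeping only what follows the assistant tag.
# (If the tokenizer treats <|assistant|> as a special token, skip_special_tokens=True
# removes it from the output, and this guard simply leaves the full decoded text.)
if "<|assistant|>" in decoded:
    decoded = decoded.split("<|assistant|>")[1].strip()
print(decoded)

One design note: max_length=2048 caps the prompt and completion combined, so a long conversation leaves less room for the generated note; max_new_tokens is the transformers parameter that bounds only the completion, if that is the intent.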