Commit ca4db09 · Add prompt
Parent(s): 1296b9e

app.py CHANGED
@@ -13,21 +13,37 @@ def load_model():
     model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)
 
     return model, tokenizer
-    return model, tokenizer
 
 # Function to generate SOAP notes
 def generate_soap_note(doctor_patient_conversation):
     if not doctor_patient_conversation.strip():
         return "Please enter a doctor-patient conversation."
-
-    # Tokenize and generate
-    inputs = tokenizer(doctor_patient_conversation, return_tensors="pt")
 
-    generate_ids = model.generate(...)
+    # Create a properly formatted prompt with instructions
+    prompt = f"""<|user|>
+Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note based on the following doctor-patient conversation:
 
+{doctor_patient_conversation}
+<|assistant|>"""
+
+    # Tokenize and generate
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    generate_ids = model.generate(
+        inputs.input_ids,
+        max_length=2048,
+        num_beams=5,
+        no_repeat_ngram_size=2,
+        early_stopping=True
+    )
 
     # Decode and extract the response part
     decoded_response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+    # Extract only the assistant's response (remove the prompt part)
+    if "<|assistant|>" in decoded_response:
+        decoded_response = decoded_response.split("<|assistant|>")[1].strip()
+
     logger.debug(f"Decoded response: {decoded_response}")
     return decoded_response
 
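For context, here is a minimal sketch of how the updated generate_soap_note could sit inside a runnable app.py. Everything outside the hunk is an assumption, not part of the commit: in particular, that load_model() also builds the tokenizer with AutoTokenizer.from_pretrained on the same "omi-health/sum-small" checkpoint, and that logger comes from Python's standard logging module.

import logging

from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def load_model():
    # Assumed counterpart to the diff's load_model(): tokenizer and model
    # loaded from the same checkpoint referenced in the hunk.
    tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small")
    model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)
    return model, tokenizer


model, tokenizer = load_model()


def generate_soap_note(doctor_patient_conversation):
    if not doctor_patient_conversation.strip():
        return "Please enter a doctor-patient conversation."

    # Wrap the raw conversation in the chat-style prompt added by this commit.
    prompt = f"""<|user|>
Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note based on the following doctor-patient conversation:

{doctor_patient_conversation}
<|assistant|>"""

    # Tokenize and generate with beam search, using the parameters from the diff.
    inputs = tokenizer(prompt, return_tensors="pt")
    generate_ids = model.generate(
        inputs.input_ids,
        max_length=2048,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    # Decode, then keep only the text after the <|assistant|> marker if it is present.
    decoded_response = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    if "<|assistant|>" in decoded_response:
        decoded_response = decoded_response.split("<|assistant|>")[1].strip()

    logger.debug(f"Decoded response: {decoded_response}")
    return decoded_response

The if guard around the split is what keeps the two halves of the change consistent: with skip_special_tokens=True, the <|assistant|> marker only survives decoding when the tokenizer does not register it as a special token, so when the marker is stripped the function simply returns the full decoded text instead of raising on a missing split.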