File size: 1,253 Bytes
ad03d38
 
 
 
 
 
 
 
 
 
 
 
6cd8eab
 
ad03d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cd8eab
ad03d38
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Run one greedy-decoding chat completion against Writer/Palmyra-Med-70B-32k
# and print only the newly generated text.
model_id = "Writer/Palmyra-Med-70B-32k"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map="auto" lets transformers/accelerate spread the weights over the
# available devices. flash_attention_2 requires fp16/bf16 weights (fp16 here)
# and the optional flash-attn package.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

messages = [
    {
        "role": "system",
        "content": "You are a highly knowledgeable and experienced expert in the healthcare and biomedical field, possessing extensive medical knowledge and practical expertise.",
    },
    {
        "role": "user",
        "content": "Does danzhi Xiaoyao San ameliorate depressive-like behavior by shifting toward serotonin via the downregulation of hippocampal indoleamine 2,3-dioxygenase?",
    },
]

# Render the conversation with the tokenizer's chat template.
# add_generation_prompt=True appends the assistant-turn prefix so the model
# starts answering. Move the prompt onto the same device as the model's
# first parameters.
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# A single, unpadded sequence: every position is a real token. Passing the
# mask explicitly avoids generate()'s "attention mask is not set" warning.
attention_mask = torch.ones_like(input_ids)

# Stop on either the tokenizer's EOS token or the chat end-of-turn token.
# convert_tokens_to_ids can yield None for a token the vocabulary lacks,
# which would be an invalid stop id, so such entries are dropped.
eos_token_ids = [
    tok_id
    for tok_id in (
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    )
    if tok_id is not None
]

gen_conf = {
    "max_new_tokens": 256,
    "eos_token_id": eos_token_ids,
    # No padding is involved. Setting this explicitly mirrors generate()'s
    # own fallback and silences its "pad_token_id not set" warning.
    "pad_token_id": tokenizer.eos_token_id,
    # Greedy decoding. Passing temperature=0.0 instead is not a valid
    # sampling temperature (it must be > 0), and top_p only applies when
    # sampling, so neither is passed.
    "do_sample": False,
}

with torch.inference_mode():
    output_id = model.generate(input_ids, attention_mask=attention_mask, **gen_conf)

# generate() returns prompt + completion. Slice off the prompt tokens and
# drop special tokens (e.g. the closing <|eot_id|>) from the printed text.
output_text = tokenizer.decode(
    output_id[0][input_ids.shape[1] :], skip_special_tokens=True
)

print(output_text)