...
llm_engine.py
CHANGED (+3 -0)
@@ -4,6 +4,7 @@ from prompts import SYMPTOM_PROMPT_TEMPLATE, QUESTION_PROMPT_TEMPLATE
 
 MODEL_NAME = "GEMINI-Lab/MedicalGPT-LLAMA2-7B"
 
+# Load tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
@@ -11,8 +12,10 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map="auto"
 )
 
+# Create generation pipeline
 generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
+# Handler functions
 def handle_symptoms(symptoms: str) -> str:
     prompt = SYMPTOM_PROMPT_TEMPLATE.format(symptoms=symptoms)
     output = generator(prompt, max_new_tokens=512, do_sample=True)[0]["generated_text"]