llm_engine.py  CHANGED  (+16 -11)
@@ -4,24 +4,29 @@ from prompts import SYMPTOM_PROMPT_TEMPLATE, QUESTION_PROMPT_TEMPLATE
 
 MODEL_NAME = "GEMINI-Lab/MedicalGPT-LLAMA2-7B"
 
-# Load
-
-
-
-
-
-
+# Load model & tokenizer
+try:
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto"
+    )
+    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
+except Exception as e:
+    print(f"Model loading failed: {e}")
+    generator = None
 
-# Create generation pipeline
-generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-# Handler functions
 def handle_symptoms(symptoms: str) -> str:
+    if generator is None:
+        return "Model failed to load."
     prompt = SYMPTOM_PROMPT_TEMPLATE.format(symptoms=symptoms)
     output = generator(prompt, max_new_tokens=512, do_sample=True)[0]["generated_text"]
     return output[len(prompt):].strip()
 
 def handle_question(question: str) -> str:
+    if generator is None:
+        return "Model failed to load."
     prompt = QUESTION_PROMPT_TEMPLATE.format(question=question)
     output = generator(prompt, max_new_tokens=512, do_sample=True)[0]["generated_text"]
     return output[len(prompt):].strip()
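
Note on usage: the handlers depend on prompts.py exposing SYMPTOM_PROMPT_TEMPLATE and QUESTION_PROMPT_TEMPLATE with {symptoms} and {question} placeholders for str.format. The sketch below is illustrative only; the template wording and the caller are assumptions, not the repository's actual prompts or app code.

# prompts.py -- illustrative sketch; the real template text is not shown in this diff
SYMPTOM_PROMPT_TEMPLATE = (
    "You are a medical assistant. A patient reports these symptoms:\n"
    "{symptoms}\n"
    "Suggest possible causes and when to seek in-person care.\nAnswer:"
)

QUESTION_PROMPT_TEMPLATE = (
    "You are a medical assistant. Answer the following question clearly:\n"
    "{question}\nAnswer:"
)

# Hypothetical caller, e.g. from an app entry point
from llm_engine import handle_symptoms, handle_question

print(handle_symptoms("persistent dry cough and mild fever for three days"))
print(handle_question("How much water should an adult drink per day?"))

The output[len(prompt):] slice works because a transformers text-generation pipeline includes the prompt in generated_text by default; passing return_full_text=False in the pipeline call would be an alternative way to keep only the continuation.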