abidkh committed
Commit fed900c · 1 Parent(s): 4f9669e
Files changed (1)
llm_engine.py +16 -11
llm_engine.py CHANGED
@@ -4,24 +4,29 @@ from prompts import SYMPTOM_PROMPT_TEMPLATE, QUESTION_PROMPT_TEMPLATE
 
 MODEL_NAME = "GEMINI-Lab/MedicalGPT-LLAMA2-7B"
 
-# Load tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto"
-)
+# Load model & tokenizer
+try:
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto"
+    )
+    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
+except Exception as e:
+    print(f"Model loading failed: {e}")
+    generator = None
 
-# Create generation pipeline
-generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-# Handler functions
 def handle_symptoms(symptoms: str) -> str:
+    if generator is None:
+        return "Model failed to load."
     prompt = SYMPTOM_PROMPT_TEMPLATE.format(symptoms=symptoms)
     output = generator(prompt, max_new_tokens=512, do_sample=True)[0]["generated_text"]
     return output[len(prompt):].strip()
 
 def handle_question(question: str) -> str:
+    if generator is None:
+        return "Model failed to load."
     prompt = QUESTION_PROMPT_TEMPLATE.format(question=question)
     output = generator(prompt, max_new_tokens=512, do_sample=True)[0]["generated_text"]
     return output[len(prompt):].strip()
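
For reference, a minimal usage sketch of the updated handlers (assumptions: it is run from the repo root so llm_engine.py and prompts.py resolve, the model weights are available, and the input strings are illustrative only):

# Usage sketch; the example inputs below are hypothetical.
from llm_engine import handle_symptoms, handle_question

# After this commit, a failed model load no longer raises at import time;
# both handlers return "Model failed to load." instead.
print(handle_symptoms("persistent cough and mild fever for three days"))
print(handle_question("What are common side effects of ibuprofen?"))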