Estherrr777 commited on
Commit
4d723eb
·
verified ·
1 Parent(s): 78a7529

Update backend/app/model.py

Browse files
Files changed (1) hide show
  1. backend/app/model.py +3 -2
backend/app/model.py CHANGED
@@ -4,6 +4,7 @@ import os
4
 
5
  MODEL_NAME = "google/gemma-1.1-2b-it"
6
  SAVE_PATH = "./backend/app/checkpoints"
 
7
 
8
  # Load tokenizer once
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -22,7 +23,7 @@ def predict(input_data: str):
22
  model = load_model()
23
  model.eval()
24
  with torch.no_grad():
25
- inputs = tokenizer(input_data, return_tensors="pt")
26
  outputs = model(**inputs)
27
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
28
- return predicted_class
 
4
 
5
  MODEL_NAME = "google/gemma-1.1-2b-it"
6
  SAVE_PATH = "./backend/app/checkpoints"
7
+ LABEL_MAP = {0: "low risk", 1: "medium risk", 2: "high risk"}
8
 
9
  # Load tokenizer once
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
23
  model = load_model()
24
  model.eval()
25
  with torch.no_grad():
26
+ inputs = tokenizer(input_data, return_tensors="pt", truncation=True, padding=True, max_length=256)
27
  outputs = model(**inputs)
28
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
29
+ return LABEL_MAP.get(predicted_class, "unknown")