Spaces:
Runtime error
Runtime error
Update backend/app/model.py
Browse files- backend/app/model.py +3 -2
backend/app/model.py
CHANGED
@@ -4,6 +4,7 @@ import os
|
|
4 |
|
5 |
MODEL_NAME = "google/gemma-1.1-2b-it"
|
6 |
SAVE_PATH = "./backend/app/checkpoints"
|
|
|
7 |
|
8 |
# Load tokenizer once
|
9 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
@@ -22,7 +23,7 @@ def predict(input_data: str):
|
|
22 |
model = load_model()
|
23 |
model.eval()
|
24 |
with torch.no_grad():
|
25 |
-
inputs = tokenizer(input_data, return_tensors="pt")
|
26 |
outputs = model(**inputs)
|
27 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
28 |
-
return predicted_class
|
|
|
4 |
|
5 |
MODEL_NAME = "google/gemma-1.1-2b-it"
|
6 |
SAVE_PATH = "./backend/app/checkpoints"
|
7 |
+
LABEL_MAP = {0: "low risk", 1: "medium risk", 2: "high risk"}
|
8 |
|
9 |
# Load tokenizer once
|
10 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
|
23 |
model = load_model()
|
24 |
model.eval()
|
25 |
with torch.no_grad():
|
26 |
+
inputs = tokenizer(input_data, return_tensors="pt", truncation=True, padding=True, max_length=256)
|
27 |
outputs = model(**inputs)
|
28 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
29 |
+
return LABEL_MAP.get(predicted_class, "unknown")
|