File size: 1,244 Bytes
3c1b5ae
fe68135
 
 
 
cc0e29d
4d723eb
fe68135
c161fc7
 
 
 
 
fe68135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d723eb
fe68135
 
4d723eb
c161fc7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os

# Base checkpoint on the hub, local fine-tuned artifact directory, and the
# mapping from classifier indices to human-readable risk labels.
MODEL_NAME = "google/gemma-1.1-2b-it"
SAVE_PATH = "./data/best_model"
LABEL_MAP = {0: "low risk", 1: "medium risk", 2: "high risk"}

# Prefer the tokenizer saved with the fine-tuned model; fall back to the hub.
_tokenizer_source = SAVE_PATH if os.path.exists(SAVE_PATH) else MODEL_NAME
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_source)

# Lazily-loaded model instance, populated on the first load_model() call so
# that repeated predictions do not re-read the checkpoint from disk each time.
_model = None

def load_model():
    """Return the sequence-classification model, loading it once and caching it.

    Loads the fine-tuned checkpoint from ``SAVE_PATH/best_model`` when that
    directory exists; otherwise falls back to the base ``MODEL_NAME`` with a
    fresh 3-label classification head.

    Returns:
        The (cached) ``AutoModelForSequenceClassification`` instance.
    """
    global _model
    if _model is None:
        # NOTE(review): SAVE_PATH already ends in "best_model", so this joins
        # to "./data/best_model/best_model" — confirm this matches the path
        # the training script writes to (the tokenizer loads SAVE_PATH itself).
        model_path = os.path.join(SAVE_PATH, "best_model")
        if os.path.exists(model_path):
            print("✅ Loading fine-tuned model from:", model_path)
            _model = AutoModelForSequenceClassification.from_pretrained(model_path)
        else:
            print("🔁 Loading base model (not yet fine-tuned):", MODEL_NAME)
            _model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_NAME, num_labels=3
            )
    return _model

def predict(input_data: str):
    """Classify *input_data* and return its risk label.

    Tokenizes the text (truncated/padded to 256 tokens), runs a forward pass
    without gradient tracking, and maps the argmax class index through
    LABEL_MAP ("unknown" for any unexpected index).
    """
    classifier = load_model()
    classifier.eval()
    # Tokenization builds no grad tensors, so it can live outside no_grad().
    encoded = tokenizer(
        input_data,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    with torch.no_grad():
        logits = classifier(**encoded).logits
    label_id = logits.argmax(dim=1).item()
    return LABEL_MAP.get(label_id, "unknown")