import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "distilbert-base-uncased"
SAVE_PATH = "./data/best_model"
LABEL_MAP = {0: "low risk", 1: "medium risk", 2: "high risk"}

# Load tokenizer from saved fine-tuned model if available
tokenizer = AutoTokenizer.from_pretrained(SAVE_PATH if os.path.isdir(SAVE_PATH) else MODEL_NAME)

_model = None  # module-level cache so the model is only loaded once per process

def load_model():
    """Load the classifier once and cache it in _model for subsequent calls."""
    global _model
    if _model is not None:
        return _model  # ✅ already loaded

    # Depending on the transformers version, weights are saved as either
    # pytorch_model.bin or model.safetensors, so check for both before
    # falling back to the base model.
    has_weights = any(
        os.path.exists(os.path.join(SAVE_PATH, fname))
        for fname in ("pytorch_model.bin", "model.safetensors")
    )
    if has_weights:
        print("✅ Loading fine-tuned model...")
        _model = AutoModelForSequenceClassification.from_pretrained(SAVE_PATH)
    else:
        print("⚠️ Fine-tuned model not found. Falling back to base model:", MODEL_NAME)
        _model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)

    _model.eval()
    return _model

def predict(input_data: str):
    """Classify a single text string and return its human-readable risk label."""
    model = load_model()
    with torch.no_grad():  # inference only, so skip gradient tracking
        inputs = tokenizer(input_data, return_tensors="pt", truncation=True, padding=True, max_length=256)
        outputs = model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()
        return LABEL_MAP.get(predicted_class, "unknown")

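# Quick sanity check when run as a script. The sample string below is a
# hypothetical illustration; if no fine-tuned checkpoint exists under
# ./data/best_model, this exercises the (untrained) base-model fallback.
if __name__ == "__main__":
    sample = "Account shows three missed payments in the last six months."
    print(predict(sample))  # one of: "low risk", "medium risk", "high risk"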


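# --- Earlier version of this module (Gemma-based), kept commented out for reference ---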
# import torch
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# import os

# MODEL_NAME = "google/gemma-1.1-2b-it"
# SAVE_PATH = "./data/best_model"
# LABEL_MAP = {0: "low risk", 1: "medium risk", 2: "high risk"}

# # Load tokenizer from saved model if available
# if os.path.exists(SAVE_PATH):
#     tokenizer = AutoTokenizer.from_pretrained(SAVE_PATH)
# else:
#     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# def load_model():
#     model_path = os.path.join(SAVE_PATH, "best_model")
#     if os.path.exists(model_path):
#         print("✅ Loading fine-tuned model from:", model_path)
#         model = AutoModelForSequenceClassification.from_pretrained(model_path)
#     else:
#         print("🔁 Loading base model (not yet fine-tuned):", MODEL_NAME)
#         model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
#     return model

# def predict(input_data: str):
#     model = load_model()
#     model.eval()
#     with torch.no_grad():
#         inputs = tokenizer(input_data, return_tensors="pt", truncation=True, padding=True, max_length=256)
#         outputs = model(**inputs)
#         predicted_class = torch.argmax(outputs.logits, dim=1).item()
#         return LABEL_MAP.get(predicted_class, "unknown")