import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "distilbert-base-uncased"
SAVE_PATH = "./data/best_model"
LABEL_MAP = {0: "low risk", 1: "medium risk", 2: "high risk"}

# Load the tokenizer from the saved fine-tuned model if available,
# otherwise fall back to the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    SAVE_PATH if os.path.isdir(SAVE_PATH) else MODEL_NAME
)

_model = None  # module-level cache: keeps the model in memory across calls


def load_model():
    global _model
    if _model is not None:
        return _model  # ✅ already loaded

    # Recent transformers versions save weights as model.safetensors rather
    # than pytorch_model.bin, so accept either file as a valid checkpoint.
    has_weights = any(
        os.path.exists(os.path.join(SAVE_PATH, name))
        for name in ("pytorch_model.bin", "model.safetensors")
    )
    if has_weights:
        print("✅ Loading fine-tuned model...")
        _model = AutoModelForSequenceClassification.from_pretrained(SAVE_PATH)
    else:
        print("⚠️ Fine-tuned model not found. Falling back to base model:", MODEL_NAME)
        _model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME, num_labels=3
        )
    _model.eval()
    return _model


def predict(input_data: str):
    model = load_model()
    inputs = tokenizer(
        input_data,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return LABEL_MAP.get(predicted_class, "unknown")
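

# --- Illustrative usage (added sketch, not part of the original module) ---
# Assumes a fine-tuned checkpoint was previously written to SAVE_PATH via
# model.save_pretrained(SAVE_PATH) / tokenizer.save_pretrained(SAVE_PATH);
# otherwise predict() falls back to the base model, whose classification
# head is untrained and will produce arbitrary labels.
if __name__ == "__main__":
    sample = "Example input text to classify."  # made-up placeholder input
    print(predict(sample))  # one of: "low risk", "medium risk", "high risk"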