from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch

app = FastAPI()

# Run on the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model once at import time so every request reuses them.
MODEL_NAME = "kmack/malicious-url-detection"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model = model.to(device)
# Inference only: make sure dropout and similar training-time layers are disabled.
model.eval()


class URLRequest(BaseModel):
    """Request body for ``POST /predict``."""

    url: str  # the URL string to classify


def get_prediction(input_text: str) -> dict:
    """Score ``input_text`` against every label of the model.

    Args:
        input_text: Text (a URL) to classify. Inputs longer than the
            tokenizer's maximum length are truncated.

    Returns:
        A new dict mapping each label name to its independent sigmoid
        probability (a Python float), ordered from most to least likely.
    """
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)

    # No gradients are needed for serving; skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single input gives a batch of one, so take row 0 rather than
    # squeeze(): squeeze() would collapse a one-label output to a 0-d tensor.
    probs = torch.sigmoid(logits[0]).cpu().tolist()

    # Build a fresh dict keyed through id2label instead of writing into
    # model.config.label2id, which is shared by every request and must not
    # be overwritten with per-request scores.
    id2label = model.config.id2label
    scores = {id2label[i]: float(p) for i, p in enumerate(probs)}
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))


@app.post("/predict")
async def predict(url_request: URLRequest):
    """Classify the submitted URL and return its per-label scores."""
    # NOTE(review): get_prediction runs model inference synchronously, so each
    # request blocks the event loop while the model runs.
    result = get_prediction(url_request.url)
    return {"prediction": result}


@app.get("/")
async def read_root():
    """Health check: report that the service is up."""
    return {"message": "API is up and running"}