ZealPyae commited on
Commit
e644b76
·
verified ·
1 Parent(s): 503d93b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -6
app.py CHANGED
@@ -1,25 +1,50 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import pipeline
 
 
4
 
5
- # Initialize the FastAPI app
6
  app = FastAPI()
7
 
8
- # Initialize the model pipeline
9
- pipe = pipeline("text-classification", model="kmack/malicious-url-detection")
 
 
 
 
 
 
 
 
10
 
11
  # Define the request model
12
  class URLRequest(BaseModel):
13
  url: str
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Define the API endpoint for URL prediction
16
  @app.post("/predict")
17
  async def predict(url_request: URLRequest):
18
  url_to_check = url_request.url
19
- result = pipe(url_to_check)
20
  return {"prediction": result}
21
 
22
  # Health check endpoint
23
  @app.get("/")
24
  async def read_root():
25
  return {"message": "API is up and running"}
 
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import numpy as np
5
+ import torch
6
 
 
7
  app = FastAPI()
8
 
9
+ # Check if CUDA is available
10
+ if torch.cuda.is_available():
11
+ device = torch.device("cuda:0")
12
+ else:
13
+ device = torch.device("cpu")
14
+
15
+ # Load the tokenizer and model
16
+ tokenizer = AutoTokenizer.from_pretrained("kmack/malicious-url-detection")
17
+ model = AutoModelForSequenceClassification.from_pretrained("kmack/malicious-url-detection")
18
+ model = model.to(device)
19
 
20
  # Define the request model
21
  class URLRequest(BaseModel):
22
  url: str
23
 
24
+ # Prediction function
25
+ def get_prediction(input_text: str) -> dict:
26
+ label2id = model.config.label2id
27
+ inputs = tokenizer(input_text, return_tensors='pt', truncation=True)
28
+ inputs = inputs.to(device)
29
+ outputs = model(**inputs)
30
+ logits = outputs.logits
31
+ sigmoid = torch.nn.Sigmoid()
32
+ probs = sigmoid(logits.squeeze().cpu())
33
+ probs = probs.detach().numpy()
34
+ for i, k in enumerate(label2id.keys()):
35
+ label2id[k] = probs[i]
36
+ label2id = {k: float(v) for k, v in sorted(label2id.items(), key=lambda item: item[1].item(), reverse=True)}
37
+ return label2id
38
+
39
  # Define the API endpoint for URL prediction
40
  @app.post("/predict")
41
  async def predict(url_request: URLRequest):
42
  url_to_check = url_request.url
43
+ result = get_prediction(url_to_check)
44
  return {"prediction": result}
45
 
46
  # Health check endpoint
47
  @app.get("/")
48
  async def read_root():
49
  return {"message": "API is up and running"}
50
+