ZealPyae committed
Commit 609af13 · verified · 1 Parent(s): 30479ed

Update app.py

Files changed (1)
  1. app.py +9 -20
app.py CHANGED
@@ -8,7 +8,7 @@

 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 import torch

 app = FastAPI()
@@ -19,29 +19,17 @@ if torch.cuda.is_available():
 else:
     device = torch.device("cpu")

-# Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("kmack/malicious-url-detection")
-model = AutoModelForSequenceClassification.from_pretrained("kmack/malicious-url-detection")
-model = model.to(device)
-
 # Define the request model
 class URLRequest(BaseModel):
     url: str

-# Prediction function
-def get_prediction(input_text: str) -> dict:
-    label2id = model.config.label2id
-    inputs = tokenizer(input_text, return_tensors='pt', truncation=True)
-    inputs = inputs.to(device)
-    outputs = model(**inputs)
-    logits = outputs.logits
-    sigmoid = torch.nn.Sigmoid()
-    probs = sigmoid(logits.squeeze().cpu())
-    probs = probs.detach().numpy()
-    for i, k in enumerate(label2id.keys()):
-        label2id[k] = probs[i]
-    label2id = {k: float(v) for k, v in sorted(label2id.items(), key=lambda item: item[1].item(), reverse=True)}
-    return label2id
+# Load the tokenizer and model using pipeline
+pipe = pipeline("text-classification", model="kmack/malicious-url-detection", device=device.index if torch.cuda.is_available() else -1)
+
+# Define the prediction function
+def get_prediction(url_to_check: str):
+    result = pipe(url_to_check)
+    return result

 # Define the API endpoint for URL prediction
 @app.post("/predict")
@@ -54,3 +42,4 @@ async def predict(url_request: URLRequest):
 @app.get("/")
 async def read_root():
     return {"message": "API is up and running"}
+
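
For quick local verification of the new pipeline-based path, here is a minimal standalone sketch that mirrors the logic introduced in this commit. The checkpoint name comes from the diff; the explicit integer device id, the device_id variable, and the test URL are assumptions for illustration, not part of app.py. A standard transformers text-classification pipeline returns a list of {"label", "score"} dicts.

# Standalone sketch mirroring the new pipeline-based prediction (illustrative only).
# Assumes the "kmack/malicious-url-detection" checkpoint can be fetched from the Hub.
import torch
from transformers import pipeline

# pipeline() accepts an integer device id: 0 = first GPU, -1 = CPU.
device_id = 0 if torch.cuda.is_available() else -1
pipe = pipeline("text-classification", model="kmack/malicious-url-detection", device=device_id)

def get_prediction(url_to_check: str):
    # Returns a list such as [{"label": "...", "score": 0.97}]
    return pipe(url_to_check)

if __name__ == "__main__":
    print(get_prediction("http://example.com/account/verify"))  # hypothetical test URL

The sketch passes a plain integer device id rather than device.index, since torch.device("cuda").index is None when the device is created without an explicit index (the CUDA branch itself is not shown in this diff). A client of the running service would POST a JSON body such as {"url": "http://example.com"} to /predict, matching the URLRequest model above.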