ElPremOoO commited on
Commit
c0be552
·
verified ·
1 Parent(s): 45214cd

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +69 -0
main.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import torch
3
+ from transformers import RobertaTokenizer
4
+ import os
5
+ from transformers import RobertaForSequenceClassification
6
+ import torch.serialization
7
+ import torch
8
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
9
+ from torch.utils.data import Dataset
10
+ import pandas as pd
11
+ from sklearn.model_selection import train_test_split
12
+ import numpy as np
13
+ # Initialize Flask app
14
+ app = Flask(__name__)
15
+
16
+ # Load the trained model and tokenizer
17
+ tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
18
+ torch.serialization.add_safe_globals([RobertaForSequenceClassification])
19
+
20
+ model = torch.load("model.pth", map_location=torch.device('cpu'), weights_only=False) # Load the trained model
21
+
22
+ # Ensure the model is in evaluation mode
23
+ model.eval()
24
+
25
+
26
+ @app.route("/")
27
+ def home():
28
+ return request.url
29
+
30
+
31
+ # @app.route("/predict", methods=["POST"])
32
+ @app.route("/predict")
33
+ def predict():
34
+ print("Received code:", request.get_json()["code"])
35
+ code = request.get_json()["code"]
36
+ # Load saved weights and config
37
+ checkpoint = torch.load("codebert_vulnerability_scorer.pth")
38
+ config = RobertaConfig.from_dict(checkpoint['config'])
39
+
40
+ # Rebuild the model with correct architecture
41
+ model = RobertaForSequenceClassification(config)
42
+ model.load_state_dict(checkpoint['model_state_dict'])
43
+ model.eval()
44
+
45
+ # Load tokenizer
46
+ tokenizer = RobertaTokenizer.from_pretrained('./tokenizer')
47
+
48
+ # Prepare input
49
+ inputs = tokenizer(
50
+ code,
51
+ truncation=True,
52
+ padding='max_length',
53
+ max_length=512,
54
+ return_tensors='pt'
55
+ )
56
+
57
+ # Make prediction
58
+ with torch.no_grad():
59
+ outputs = model(**inputs)
60
+
61
+ score = torch.sigmoid(outputs.logits).item()
62
+ return score
63
+
64
+
65
+
66
+
67
+ # Run the Flask app
68
+ if __name__ == "__main__":
69
+ app.run(host="0.0.0.0", port=7860)