Kareem94 committed on
Commit
76cad5e
·
verified ·
1 Parent(s): 2008b00

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +224 -39
main.py CHANGED
@@ -1,65 +1,250 @@
1
  from flask import Flask, request, jsonify
2
  import torch
3
- from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig
4
  import os
 
5
 
6
  app = Flask(__name__)
7
 
8
- # Load model and tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def load_model():
10
- # Load saved config and weights
11
- checkpoint = torch.load("codebert_vulnerability_scorer.pth", map_location=torch.device('cpu'))
12
- config = RobertaConfig.from_dict(checkpoint['config'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Initialize model with loaded config
15
- model = RobertaForSequenceClassification(config)
16
- model.load_state_dict(checkpoint['model_state_dict'])
17
- model.eval()
18
- return model
 
 
 
 
 
19
 
20
- # Load components
21
  try:
22
- tokenizer = RobertaTokenizer.from_pretrained("./tokenizer_vulnerability")
 
 
 
 
23
  model = load_model()
24
- print("Model and tokenizer loaded successfully!")
 
25
  except Exception as e:
26
- print(f"Error loading model: {str(e)}")
 
 
27
 
28
- @app.route("/")
29
  def home():
30
- return request.url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- @app.route("/predict")
33
- def predict():
34
  try:
35
- # Get code from URL parameter
 
 
36
  code = request.args.get("code")
37
  if not code:
38
  return jsonify({"error": "Missing 'code' URL parameter"}), 400
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # Tokenize input
41
- inputs = tokenizer(
42
- code,
43
- truncation=True,
44
- padding='max_length',
45
- max_length=512,
46
- return_tensors='pt'
47
- )
48
-
49
- # Make prediction
50
- with torch.no_grad():
 
 
 
 
51
  outputs = model(**inputs)
 
 
 
 
 
 
 
52
 
53
- # Apply sigmoid and format score
54
- score = torch.sigmoid(outputs.logits).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- return jsonify({
57
- "score": round(score, 4),
58
- "given_code": code[:500] + "..." if len(code) > 500 else code
59
- })
 
 
 
60
 
61
- except Exception as e:
62
- return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
63
 
64
  if __name__ == "__main__":
65
- app.run(host="0.0.0.0", port=7860)
 
1
  from flask import Flask, request, jsonify
2
  import torch
3
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
4
  import os
5
+ from functools import lru_cache
6
 
7
  app = Flask(__name__)
8
 
9
# Process-wide inference state. Each name stays None until the start-up
# code assigns it; the request handlers test model/tokenizer for None.
model = tokenizer = device = None
12
+
13
def setup_device():
    """Pick the torch device used for inference.

    CUDA is preferred when available. Otherwise Apple MPS is used, but only
    when this torch build exposes ``torch.backends.mps`` and reports it as
    available. CPU is the final fallback.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # hasattr guard: the 'mps' backend attribute may be absent on some builds.
    has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    return torch.device('mps' if has_mps else 'cpu')
20
+
21
def load_tokenizer():
    """Load the tokenizer, preferring the local copy.

    Tries ``./tokenizer_vulnerability`` first and sets its
    ``model_max_length`` to 512. If anything raises, the error is printed
    and the ``microsoft/codebert-base`` tokenizer is returned instead.

    Returns:
        RobertaTokenizer: the local tokenizer, or the fallback one.
    """
    local_dir = './tokenizer_vulnerability'
    try:
        tok = RobertaTokenizer.from_pretrained(local_dir)
        tok.model_max_length = 512
    except Exception as err:
        print(f"Error loading tokenizer: {err}")
        return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
    return tok
29
+
30
def load_model():
    """Load the classifier from ``codebert_vulnerability_scorer.pth``.

    Sets the module-level ``device`` (via ``setup_device``) as a side
    effect, then builds the model:

    * If the checkpoint has a ``'config'`` entry, the model is constructed
      from that config; otherwise it starts from ``microsoft/codebert-base``
      with a single output label.
    * Weights come from ``checkpoint['model_state_dict']`` when present,
      otherwise the checkpoint itself is treated as the state dict.

    The model is moved to the device, put in eval mode, and converted to
    half precision on CUDA.

    Returns:
        RobertaForSequenceClassification: the model, ready for inference.

    Raises:
        Exception: any loading error is printed and re-raised unchanged.
    """
    global device
    device = setup_device()
    print(f"Using device: {device}")

    try:
        # NOTE(security): torch.load deserialises with pickle, which can run
        # arbitrary code. Only load checkpoint files from a trusted source.
        checkpoint = torch.load("codebert_vulnerability_scorer.pth", map_location=device)

        if 'config' in checkpoint:
            from transformers import RobertaConfig
            config = RobertaConfig.from_dict(checkpoint['config'])
            net = RobertaForSequenceClassification(config)
        else:
            net = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base',
                num_labels=1
            )

        # A checkpoint is either a wrapper dict or a bare state dict.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        net.load_state_dict(state_dict)

        net.to(device)
        net.eval()

        if device.type == 'cuda':
            net.half()

        return net

    except Exception as e:
        print(f"Error loading model: {e}")
        # Bare raise keeps the original traceback intact.
        raise
64
 
65
@lru_cache(maxsize=1000)
def cached_tokenize(code_hash, max_length):
    """Tokenize a snippet, memoising up to 1000 distinct argument pairs.

    Despite its name, ``code_hash`` is passed to the tokenizer as the text
    to encode; it is not hashed here. The cached return value is the same
    object on every cache hit.

    Args:
        code_hash: the source text to tokenize.
        max_length: length that the encoding is truncated and padded to.

    Returns:
        The tokenizer output, as PyTorch tensors.
    """
    encode_options = dict(
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt',
    )
    return tokenizer(code_hash, **encode_options)
75
 
 
76
# Start-up: load the tokenizer and then the model once, at import time.
# A failure is printed and both globals are reset to None.
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    print("Tokenizer loaded successfully!")

    print("Loading model...")
    model = load_model()
    print("Model loaded successfully!")

except Exception as init_error:
    print(f"Error during initialization: {init_error!s}")
    tokenizer = model = None
89
 
90
@app.route("/", methods=['GET'])
def home():
    """Return a JSON index describing the service.

    The response reports whether the model is loaded, which device is in
    use ("unknown" if none has been chosen), and the available endpoints.
    """
    return jsonify({
        "message": "CodeBERT Vulnerability Scorer API",
        "status": "Model loaded" if model is not None else "Model not loaded",
        "device": str(device) if device else "unknown",
        "endpoints": {
            "/predict": "POST with JSON body containing 'code' field",
            "/predict_batch": "POST with JSON body containing 'codes' array",
            "/predict_get": "GET with 'code' URL parameter",
            # /health is served by this app too, so list it in the index.
            "/health": "GET for model and tokenizer load status"
        }
    })
102
+
103
@app.route("/predict", methods=['POST'])
def predict_post():
    """Score one code snippet sent as JSON: ``{"code": "<source>"}``.

    Returns:
        200 with ``score``, ``vulnerability_level`` and ``code_preview``
        (the first 200 characters, with "..." appended when cut).
        400 when the body is not a JSON object with a non-empty string
        ``code`` field.
        500 when the model or tokenizer is not loaded, or prediction fails.
    """
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        # silent=True yields None for a missing, non-JSON or malformed body.
        # Without it the HTTP error raised by get_json would be caught by
        # the broad except below and reported as a 500 instead of a 400.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'code' not in data:
            return jsonify({"error": "Missing 'code' field in JSON body"}), 400

        code = data['code']
        if not code or not isinstance(code, str):
            return jsonify({"error": "'code' field must be a non-empty string"}), 400

        score = predict_vulnerability(code)

        return jsonify({
            "score": score,
            "vulnerability_level": get_vulnerability_level(score),
            "code_preview": code[:200] + "..." if len(code) > 200 else code
        })

    except Exception as e:
        return jsonify({"error": f"Prediction error: {str(e)}"}), 500
127
+
128
@app.route("/predict_batch", methods=['POST'])
def predict_batch():
    """Score several snippets sent as JSON: ``{"codes": ["...", ...]}``.

    Snippets are scored in chunks of at most 16. Each result carries the
    snippet's position in the request (``index``), its ``score``, its
    ``vulnerability_level`` and a 100-character ``code_preview``.

    Returns:
        200 with ``{"results": [...]}``.
        400 when the body is not a JSON object whose ``codes`` field is a
        non-empty array of strings.
        500 when the model or tokenizer is not loaded, or prediction fails.
    """
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        # silent=True: a malformed body becomes None (-> 400), not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'codes' not in data:
            return jsonify({"error": "Missing 'codes' field in JSON body"}), 400

        codes = data['codes']
        if not isinstance(codes, list) or not codes:
            return jsonify({"error": "'codes' must be a non-empty array"}), 400
        # Reject non-string entries here; otherwise they fail deep inside
        # tokenisation and surface as a 500.
        if not all(isinstance(entry, str) for entry in codes):
            return jsonify({"error": "Every entry in 'codes' must be a string"}), 400

        batch_size = 16
        results = []

        for start in range(0, len(codes), batch_size):
            batch = codes[start:start + batch_size]
            scores = predict_vulnerability_batch(batch)

            for offset, score in enumerate(scores):
                snippet = batch[offset]
                results.append({
                    "index": start + offset,
                    "score": score,
                    "vulnerability_level": get_vulnerability_level(score),
                    "code_preview": snippet[:100] + "..." if len(snippet) > 100 else snippet
                })

        return jsonify({"results": results})

    except Exception as e:
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
161
 
162
@app.route("/predict_get", methods=['GET'])
def predict_get():
    """Score the snippet passed in the ``code`` URL parameter.

    Returns:
        200 with ``score``, ``vulnerability_level`` and ``code_preview``
        (the first 200 characters, with "..." appended when cut).
        400 when the ``code`` parameter is missing or empty.
        500 when the model or tokenizer is not loaded, or prediction fails.
    """
    try:
        if tokenizer is None or model is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        code = request.args.get("code")
        if not code:
            return jsonify({"error": "Missing 'code' URL parameter"}), 400

        score = predict_vulnerability(code)
        level = get_vulnerability_level(score)
        preview = code if len(code) <= 200 else code[:200] + "..."

        return jsonify({
            "score": score,
            "vulnerability_level": level,
            "code_preview": preview
        })

    except Exception as exc:
        return jsonify({"error": f"Prediction error: {str(exc)}"}), 500
182
 
183
def predict_vulnerability(code, max_length=512):
    """Score a single code snippet.

    Args:
        code: source text to score.
        max_length: maximum number of tokens kept; longer inputs are
            truncated. Defaults to 512, the limit used elsewhere in this
            file.

    Returns:
        float: sigmoid of the model's single logit, rounded to 4 decimals.
    """
    # Truncate at the real token limit. The earlier whitespace-word * 2
    # estimate cut dense code (few spaces, many tokens) far below the limit.
    # A single sequence needs no padding.
    inputs = tokenizer(
        code,
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        if device.type == 'cuda':
            # Mixed precision applies only on CUDA.
            with torch.autocast(device_type='cuda'):
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)

    # Accept both an output object with .logits and a plain tuple.
    logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
    score = torch.sigmoid(logits).cpu().item()

    return round(score, 4)
206
 
207
def predict_vulnerability_batch(codes, max_length=512):
    """Score a list of code snippets in one forward pass.

    Args:
        codes: list of source texts.
        max_length: maximum number of tokens kept per snippet; longer
            inputs are truncated. Defaults to 512, the limit used elsewhere
            in this file.

    Returns:
        list[float]: one sigmoid score per snippet, each rounded to 4
        decimals. An empty input gives an empty list.
    """
    if not codes:
        # Nothing to score; the previous max() over an empty list raised.
        return []

    # Pad to the longest sequence in the batch and truncate at the real
    # token limit. The earlier whitespace-word * 2 estimate cut dense code
    # far below the limit.
    inputs = tokenizer(
        codes,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt'
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        if device.type == 'cuda':
            # Mixed precision applies only on CUDA.
            with torch.autocast(device_type='cuda'):
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)

    # Accept both an output object with .logits and a plain tuple.
    logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
    scores = torch.sigmoid(logits).cpu().numpy()

    return [round(float(score), 4) for score in scores.flatten()]
231
 
232
def get_vulnerability_level(score):
    """Translate a numeric score into a coarse severity label.

    Returns:
        str: "Low" for scores below 0.3, "Medium" for scores below 0.7,
        and "High" otherwise.
    """
    if score < 0.3:
        return "Low"
    return "Medium" if score < 0.7 else "High"
239
 
240
@app.route("/health", methods=['GET'])
def health_check():
    """Report whether the model and tokenizer are loaded, and the device.

    ``status`` is always "healthy"; ``device`` is "unknown" when no device
    has been chosen.
    """
    device_label = str(device) if device else "unknown"
    payload = {
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": device_label,
    }
    return jsonify(payload)
248
 
249
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, with the debugger off and
    # threaded request handling on.
    run_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "debug": False,
        "threaded": True,
    }
    app.run(**run_options)