Kareem94 committed on
Commit 0d168b4 · verified · 1 Parent(s): 6d0481d

Update main.py

Files changed (1): main.py +200 -71
main.py CHANGED
@@ -2,6 +2,7 @@ from flask import Flask, request, jsonify
 import torch
 from transformers import RobertaTokenizer, RobertaForSequenceClassification
 import os
+import gc
 from functools import lru_cache
 
 app = Flask(__name__)
@@ -25,7 +26,11 @@ def load_tokenizer():
         return tokenizer
     except Exception as e:
         print(f"Error loading tokenizer: {e}")
-        return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
+        try:
+            return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
+        except Exception as e2:
+            print(f"Fallback tokenizer failed: {e2}")
+            return None
 
 def load_model():
     global device
@@ -60,27 +65,27 @@ def load_model():
 
     except Exception as e:
         print(f"Error loading model: {e}")
-        raise e
+        return None
 
-@lru_cache(maxsize=1000)
-def cached_tokenize(code_hash, max_length):
-    code = code_hash
-    return tokenizer(
-        code,
-        truncation=True,
-        padding='max_length',
-        max_length=max_length,
-        return_tensors='pt'
-    )
+def cleanup_gpu_memory():
+    if device and device.type == 'cuda':
+        torch.cuda.empty_cache()
+    gc.collect()
 
 try:
     print("Loading tokenizer...")
     tokenizer = load_tokenizer()
-    print("Tokenizer loaded successfully!")
+    if tokenizer:
+        print("Tokenizer loaded successfully!")
+    else:
+        print("Failed to load tokenizer!")
 
     print("Loading model...")
     model = load_model()
-    print("Model loaded successfully!")
+    if model:
+        print("Model loaded successfully!")
+    else:
+        print("Failed to load model!")
 
 except Exception as e:
     print(f"Error during initialization: {str(e)}")
@@ -111,75 +116,199 @@ def predict_batch():
         codes = data['codes']
         if not isinstance(codes, list) or len(codes) == 0:
             return jsonify({"error": "'codes' must be a non-empty array"}), 400
-
-        batch_size = min(len(codes), 16)
+
+        if len(codes) > 100:
+            return jsonify({"error": "Too many codes. Maximum 100 allowed."}), 400
+
+        validated_codes = []
+        for i, code in enumerate(codes):
+            if not isinstance(code, str):
+                return jsonify({"error": f"Code at index {i} must be a string"}), 400
+            if len(code.strip()) == 0:
+                validated_codes.append("# empty code")
+            elif len(code) > 50000:
+                return jsonify({"error": f"Code at index {i} too long. Maximum 50000 characters."}), 400
+            else:
+                validated_codes.append(code.strip())
+
+        if len(validated_codes) == 1:
+            score = predict_vulnerability_with_chunking(validated_codes[0])
+            cleanup_gpu_memory()
+            return jsonify({"results": [{"score": score}]})
+
+        batch_size = min(len(validated_codes), 16)
         results = []
 
-        for i in range(0, len(codes), batch_size):
-            batch = codes[i:i+batch_size]
-            scores = predict_vulnerability_batch(batch)
-
-            for j, score in enumerate(scores):
-                results.append({
-                    "score": score
-                })
+        try:
+            for i in range(0, len(validated_codes), batch_size):
+                batch = validated_codes[i:i+batch_size]
+
+                long_codes = []
+                short_codes = []
+                long_indices = []
+                short_indices = []
+
+                for idx, code in enumerate(batch):
+                    try:
+                        tokens = tokenizer.encode(code, add_special_tokens=False, max_length=1000, truncation=True)
+                        if len(tokens) > 450:
+                            long_codes.append(code)
+                            long_indices.append(i + idx)
+                        else:
+                            short_codes.append(code)
+                            short_indices.append(i + idx)
+                    except Exception as e:
+                        print(f"Tokenization error for code {i + idx}: {e}")
+                        short_codes.append(code)
+                        short_indices.append(i + idx)
+
+                batch_scores = [0.0] * len(batch)
+
+                if short_codes:
+                    try:
+                        short_scores = predict_vulnerability_batch(short_codes)
+                        for j, score in enumerate(short_scores):
+                            local_idx = short_indices[j] - i
+                            batch_scores[local_idx] = score
+                    except Exception as e:
+                        print(f"Batch prediction error: {e}")
+                        for j in range(len(short_codes)):
+                            local_idx = short_indices[j] - i
+                            batch_scores[local_idx] = 0.0
+
+                for j, code in enumerate(long_codes):
+                    try:
+                        score = predict_vulnerability_with_chunking(code)
+                        local_idx = long_indices[j] - i
+                        batch_scores[local_idx] = score
+                    except Exception as e:
+                        print(f"Chunking prediction error: {e}")
+                        local_idx = long_indices[j] - i
+                        batch_scores[local_idx] = 0.0
+
+                for score in batch_scores:
+                    results.append({"score": score})
+
+            cleanup_gpu_memory()
+
+        except Exception as e:
+            cleanup_gpu_memory()
+            raise e
 
         return jsonify({"results": results})
 
     except Exception as e:
+        cleanup_gpu_memory()
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
 
-
+def predict_vulnerability_with_chunking(code):
+    try:
+        if not code or len(code.strip()) == 0:
+            return 0.0
+
+        tokens = tokenizer.encode(code, add_special_tokens=False, max_length=2000, truncation=True)
+
+        if len(tokens) <= 450:
+            return predict_vulnerability(code)
+
+        chunk_size = 400
+        overlap = 50
+        max_score = 0.0
+
+        for start in range(0, len(tokens), chunk_size - overlap):
+            end = min(start + chunk_size, len(tokens))
+            chunk_tokens = tokens[start:end]
+
+            try:
+                chunk_code = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
+                if chunk_code.strip():
+                    score = predict_vulnerability(chunk_code)
+                    max_score = max(max_score, score)
+            except Exception as e:
+                print(f"Chunk processing error: {e}")
+                continue
+
+            if end >= len(tokens):
+                break
+
+        return max_score
+
+    except Exception as e:
+        print(f"Chunking error: {e}")
+        return 0.0
 
 def predict_vulnerability(code):
-    dynamic_length = min(max(len(code.split()) * 2, 128), 512)
-
-    inputs = tokenizer(
-        code,
-        truncation=True,
-        padding='max_length',
-        max_length=dynamic_length,
-        return_tensors='pt'
-    )
-
-    inputs = {k: v.to(device) for k, v in inputs.items()}
-
-    with torch.no_grad():
-        with torch.cuda.amp.autocast() if device.type == 'cuda' else torch.no_grad():
-            outputs = model(**inputs)
-
-    if hasattr(outputs, 'logits'):
-        score = torch.sigmoid(outputs.logits).cpu().item()
-    else:
-        score = torch.sigmoid(outputs[0]).cpu().item()
-
-    return round(score, 4)
+    try:
+        if not code or len(code.strip()) == 0:
+            return 0.0
+
+        dynamic_length = min(max(len(code.split()) * 2, 128), 512)
+
+        inputs = tokenizer(
+            code,
+            truncation=True,
+            padding='max_length',
+            max_length=dynamic_length,
+            return_tensors='pt'
+        )
+
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            if device.type == 'cuda':
+                with torch.cuda.amp.autocast():
+                    outputs = model(**inputs)
+            else:
+                outputs = model(**inputs)
+
+        if hasattr(outputs, 'logits'):
+            score = torch.sigmoid(outputs.logits).cpu().item()
+        else:
+            score = torch.sigmoid(outputs[0]).cpu().item()
+
+        return round(max(0.0, min(1.0, score)), 4)
+
+    except Exception as e:
+        print(f"Single prediction error: {e}")
+        return 0.0
 
 def predict_vulnerability_batch(codes):
-    max_len = max([len(code.split()) * 2 for code in codes])
-    dynamic_length = min(max(max_len, 128), 512)
-
-    inputs = tokenizer(
-        codes,
-        truncation=True,
-        padding='max_length',
-        max_length=dynamic_length,
-        return_tensors='pt'
-    )
-
-    inputs = {k: v.to(device) for k, v in inputs.items()}
-
-    with torch.no_grad():
-        with torch.cuda.amp.autocast() if device.type == 'cuda' else torch.no_grad():
-            outputs = model(**inputs)
-
-    if hasattr(outputs, 'logits'):
-        scores = torch.sigmoid(outputs.logits).cpu().numpy()
-    else:
-        scores = torch.sigmoid(outputs[0]).cpu().numpy()
-
-    return [round(float(score), 4) for score in scores.flatten()]
-
+    try:
+        if not codes or len(codes) == 0:
+            return []
+
+        filtered_codes = [code if code and code.strip() else "# empty" for code in codes]
+
+        max_len = max([len(code.split()) * 2 for code in filtered_codes if code])
+        dynamic_length = min(max(max_len, 128), 512)
+
+        inputs = tokenizer(
+            filtered_codes,
+            truncation=True,
+            padding='max_length',
+            max_length=dynamic_length,
+            return_tensors='pt'
+        )
+
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            if device.type == 'cuda':
+                with torch.cuda.amp.autocast():
+                    outputs = model(**inputs)
+            else:
+                outputs = model(**inputs)
+
+        if hasattr(outputs, 'logits'):
+            scores = torch.sigmoid(outputs.logits).cpu().numpy()
+        else:
+            scores = torch.sigmoid(outputs[0]).cpu().numpy()
+
+        return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
+
+    except Exception as e:
+        print(f"Batch prediction error: {e}")
+        return [0.0] * len(codes)
 
 @app.route("/health", methods=['GET'])
 def health_check():
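
Reviewer note: to exercise the new validation and chunking paths end to end, a minimal client sketch follows. The route path /predict_batch, host, and port are assumptions; the diff only shows the handler body (def predict_batch) and the /health route.

# Hypothetical smoke test for the batch endpoint touched by this commit.
# URL is an assumption; adjust to wherever the app is actually served.
import requests

payload = {"codes": [
    "int main() { char buf[8]; gets(buf); return 0; }",  # classic overflow
    "def add(a, b):\n    return a + b",                   # benign snippet
]}

resp = requests.post("http://localhost:5000/predict_batch", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json())  # expected shape: {"results": [{"score": 0.xxxx}, ...]}

Per the new validation block, more than 100 items or any item over 50000 characters should come back as a 400 rather than reaching the model.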
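The long-input path slides a 400-token window with a 50-token overlap (an effective stride of 350) and keeps the maximum chunk score. A standalone sketch of just the window arithmetic, useful for sanity-checking boundary behaviour, with the token count made up for illustration:

# Illustrative only: reproduces the window schedule from
# predict_vulnerability_with_chunking() without tokenizer or model.
chunk_size, overlap = 400, 50
n_tokens = 1000  # pretend the input encodes to 1000 tokens

windows = []
for start in range(0, n_tokens, chunk_size - overlap):
    end = min(start + chunk_size, n_tokens)
    windows.append((start, end))
    if end >= n_tokens:  # same early exit as the real loop
        break

print(windows)  # [(0, 400), (350, 750), (700, 1000)]

Because the handler takes the max over chunk scores, a vulnerable pattern that straddles a window boundary is still seen whole as long as it fits inside the 50-token overlap.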
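One forward-looking nit: torch.cuda.amp.autocast() is deprecated in recent PyTorch releases in favour of the device-agnostic torch.amp.autocast. A sketch of an equivalent call, assuming the same device, model, and inputs names as in the diff:

# Equivalent to the cuda-only branch in the diff, using the
# device-agnostic API; enabled=False makes it a no-op on CPU.
import torch

with torch.no_grad():
    with torch.amp.autocast(device_type='cuda', enabled=(device.type == 'cuda')):
        outputs = model(**inputs)

This collapses the if/else duplication around the forward pass without changing behaviour on either device.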