"""Flask service that scores code snippets with a fine-tuned CodeBERT
vulnerability model.

/predict takes a JSON body with a 'codes' array and returns one score per
snippet; /health reports component status. Example request (the server
listens on port 7860 by default):

    curl -X POST http://localhost:7860/predict \
         -H "Content-Type: application/json" \
         -d '{"codes": ["int main() { return 0; }"]}'
"""
from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import gc

app = Flask(__name__)

# Populated once at startup; request handlers only read these.
model = None
tokenizer = None
device = None

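# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.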
def setup_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

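# Prefer the fine-tuned local tokenizer; fall back to the base CodeBERT
# tokenizer from the Hub if the local copy is missing or unreadable.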
def load_tokenizer():
    try:
        tokenizer = RobertaTokenizer.from_pretrained('./tokenizer_vulnerability')
        tokenizer.model_max_length = 512
        return tokenizer
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        try:
            return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
        except Exception as e2:
            print(f"Fallback tokenizer failed: {e2}")
            return None

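# Build the classifier from the checkpoint's embedded config when present,
# otherwise from the base CodeBERT weights, then load the fine-tuned state.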
def load_model():
    global device
    device = setup_device()
    print(f"Using device: {device}")

    try:
        checkpoint = torch.load("codebert_vulnerability_scorer.pth", map_location=device)

        if 'config' in checkpoint:
            from transformers import RobertaConfig
            config = RobertaConfig.from_dict(checkpoint['config'])
            model = RobertaForSequenceClassification(config)
        else:
            model = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base',
                num_labels=1
            )

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model.to(device)
        model.eval()

        # FP16 halves memory use on CUDA; skipped on MPS/CPU where half
        # precision support is spottier.
        if device.type == 'cuda':
            model.half()

        return model

    except Exception as e:
        print(f"Error loading model: {e}")
        return None

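# Release cached CUDA blocks between requests to limit memory fragmentation.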
def cleanup_gpu_memory():
    if device and device.type == 'cuda':
        torch.cuda.empty_cache()
    gc.collect()

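# Load everything at import time so the first request does not pay the cost.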
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    if tokenizer:
        print("Tokenizer loaded successfully!")
    else:
        print("Failed to load tokenizer!")

    print("Loading model...")
    model = load_model()
    if model:
        print("Model loaded successfully!")
    else:
        print("Failed to load model!")

except Exception as e:
    print(f"Error during initialization: {str(e)}")
    tokenizer = None
    model = None

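# Service metadata plus a quick view of whether the model came up.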
@app.route("/", methods=['GET']) |
|
def home(): |
|
return jsonify({ |
|
"message": "CodeBERT Vulnerability Evalutor API", |
|
"status": "Model loaded" if model is not None else "Model not loaded", |
|
"device": str(device) if device else "unknown", |
|
"endpoints": { |
|
"/predict": "POST with JSON body containing 'codes' array" |
|
} |
|
}) |
|
|
|
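# Score up to 100 snippets per request. Long snippets (> 450 tokens) go
# through the sliding-window path; the rest share padded mini-batches of 16.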
@app.route("/predict", methods=['POST']) |
|
def predict_batch(): |
|
try: |
|
if model is None or tokenizer is None: |
|
return jsonify({"error": "Model not loaded properly"}), 500 |
|
|
|
data = request.get_json() |
|
if not data or 'codes' not in data: |
|
return jsonify({"error": "Missing 'codes' field in JSON body"}), 400 |
|
|
|
codes = data['codes'] |
|
if not isinstance(codes, list) or len(codes) == 0: |
|
return jsonify({"error": "'codes' must be a non-empty array"}), 400 |
|
|
|
if len(codes) > 100: |
|
return jsonify({"error": "Too many codes. Maximum 100 allowed."}), 400 |
|
|
|
validated_codes = [] |
|
for i, code in enumerate(codes): |
|
if not isinstance(code, str): |
|
return jsonify({"error": f"Code at index {i} must be a string"}), 400 |
|
if len(code.strip()) == 0: |
|
validated_codes.append("# empty code") |
|
elif len(code) > 50000: |
|
return jsonify({"error": f"Code at index {i} too long. Maximum 50000 characters."}), 400 |
|
else: |
|
validated_codes.append(code.strip()) |
|
|
|
if len(validated_codes) == 1: |
|
score = predict_vulnerability_with_chunking(validated_codes[0]) |
|
cleanup_gpu_memory() |
|
return jsonify({"results": [{"score": 1.0 - score}]}) |
|
|
|
batch_size = min(len(validated_codes), 16) |
|
results = [] |
|
|
|
try: |
|
for i in range(0, len(validated_codes), batch_size): |
|
batch = validated_codes[i:i+batch_size] |
|
|
|
long_codes = [] |
|
short_codes = [] |
|
long_indices = [] |
|
short_indices = [] |
|
|
|
for idx, code in enumerate(batch): |
|
try: |
|
tokens = tokenizer.encode(code, add_special_tokens=False, max_length=1000, truncation=True) |
|
if len(tokens) > 450: |
|
long_codes.append(code) |
|
long_indices.append(i + idx) |
|
else: |
|
short_codes.append(code) |
|
short_indices.append(i + idx) |
|
except Exception as e: |
|
print(f"Tokenization error for code {i + idx}: {e}") |
|
short_codes.append(code) |
|
short_indices.append(i + idx) |
|
|
|
batch_scores = [0.0] * len(batch) |
|
|
|
if short_codes: |
|
try: |
|
short_scores = predict_vulnerability_batch(short_codes) |
|
for j, score in enumerate(short_scores): |
|
local_idx = short_indices[j] - i |
|
batch_scores[local_idx] = score |
|
except Exception as e: |
|
print(f"Batch prediction error: {e}") |
|
for j in range(len(short_codes)): |
|
local_idx = short_indices[j] - i |
|
batch_scores[local_idx] = 0.0 |
|
|
|
for j, code in enumerate(long_codes): |
|
try: |
|
score = predict_vulnerability_with_chunking(code) |
|
local_idx = long_indices[j] - i |
|
batch_scores[local_idx] = score |
|
except Exception as e: |
|
print(f"Chunking prediction error: {e}") |
|
local_idx = long_indices[j] - i |
|
batch_scores[local_idx] = 0.0 |
|
|
|
for score in batch_scores: |
|
results.append({"score": round(1.0 - score,4)}) |
|
|
|
cleanup_gpu_memory() |
|
|
|
except Exception as e: |
|
cleanup_gpu_memory() |
|
raise e |
|
|
|
return jsonify({"results": results}) |
|
|
|
except Exception as e: |
|
cleanup_gpu_memory() |
|
return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500 |
|
|
|
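# Sliding-window scoring for snippets longer than the model's context:
# score overlapping 400-token chunks and keep the maximum chunk score.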
def predict_vulnerability_with_chunking(code):
    try:
        if not code or len(code.strip()) == 0:
            return 0.0

        # Tokenization is capped at 2000 tokens; anything beyond is ignored.
        tokens = tokenizer.encode(code, add_special_tokens=False, max_length=2000, truncation=True)

        if len(tokens) <= 450:
            return predict_vulnerability(code)

        chunk_size = 400
        overlap = 50
        max_score = 0.0

        # Step by chunk_size - overlap so adjacent windows share 50 tokens.
        for start in range(0, len(tokens), chunk_size - overlap):
            end = min(start + chunk_size, len(tokens))
            chunk_tokens = tokens[start:end]

            try:
                chunk_code = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
                if chunk_code.strip():
                    score = predict_vulnerability(chunk_code)
                    max_score = max(max_score, score)
            except Exception as e:
                print(f"Chunk processing error: {e}")
                continue

            if end >= len(tokens):
                break

        return max_score

    except Exception as e:
        print(f"Chunking error: {e}")
        return 0.0

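# Score one snippet. The sequence length is sized to the input (roughly two
# tokens per whitespace-separated word, clamped to [128, 512]) so short
# requests are not padded all the way to 512 tokens.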
def predict_vulnerability(code):
    try:
        if not code or len(code.strip()) == 0:
            return 0.0

        dynamic_length = min(max(len(code.split()) * 2, 128), 512)

        inputs = tokenizer(
            code,
            truncation=True,
            padding='max_length',
            max_length=dynamic_length,
            return_tensors='pt'
        )

        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            if device.type == 'cuda':
                # torch.cuda.amp.autocast is deprecated in recent PyTorch
                # releases in favor of torch.amp.autocast('cuda').
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        # Sharpen the sigmoid by scaling logits before squashing to [0, 1].
        amplified_logits = 2.0 * outputs.logits
        score = torch.sigmoid(amplified_logits).cpu().item()
        return round(max(0.0, min(1.0, score)), 4)

    except Exception as e:
        print(f"Single prediction error: {e}")
        return 0.0

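# Batched variant: pads all snippets to a shared dynamic length and runs a
# single forward pass, which is much faster than scoring them one by one.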
def predict_vulnerability_batch(codes):
    try:
        if not codes or len(codes) == 0:
            return []

        # Replace empty inputs with a placeholder so the tokenizer never sees "".
        filtered_codes = [code if code and code.strip() else "# empty" for code in codes]

        max_len = max(len(code.split()) * 2 for code in filtered_codes)
        dynamic_length = min(max(max_len, 128), 512)

        inputs = tokenizer(
            filtered_codes,
            truncation=True,
            padding='max_length',
            max_length=dynamic_length,
            return_tensors='pt'
        )

        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            if device.type == 'cuda':
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        amplified_logits = 2.0 * outputs.logits
        scores = torch.sigmoid(amplified_logits).cpu().numpy()

        return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]

    except Exception as e:
        print(f"Batch prediction error: {e}")
        return [0.0] * len(codes)

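# Lightweight liveness probe; reports component status without running inference.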
@app.route("/health", methods=['GET']) |
|
def health_check(): |
|
return jsonify({ |
|
"status": "healthy", |
|
"model_loaded": model is not None, |
|
"tokenizer_loaded": tokenizer is not None, |
|
"device": str(device) if device else "unknown" |
|
}) |
|
|
|
if __name__ == "__main__": |
|
app.run(host="0.0.0.0", port=7860, debug=False, threaded=True) |