|
from flask import Flask, request, jsonify |
|
import torch |
|
from transformers import RobertaTokenizer, RobertaForSequenceClassification |
|
import os |
|
from functools import lru_cache |
|
|
|
# Flask application object; the route decorators in this module register
# their handlers on it.
app = Flask(__name__)

# Inference state shared by all request handlers. Each starts as None and is
# populated by the start-up code later in this module (the device is set as a
# side effect of load_model()).
model, tokenizer, device = None, None, None
|
|
|
def setup_device():
    """Pick the torch device to run inference on.

    Preference order is CUDA, then Apple MPS (only probed when this torch
    build exposes ``torch.backends.mps``), then CPU.

    Returns:
        A ``torch.device`` for the first available backend in that order.
    """
    if torch.cuda.is_available():
        backend_name = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        backend_name = 'mps'
    else:
        backend_name = 'cpu'
    return torch.device(backend_name)
|
|
|
def load_tokenizer():
    """Load the RoBERTa tokenizer used for scoring.

    The locally saved tokenizer in ``./tokenizer_vulnerability`` is preferred.
    If loading it fails for any reason, the error is printed and the
    ``microsoft/codebert-base`` tokenizer is loaded instead.

    Returns:
        A ``RobertaTokenizer`` whose ``model_max_length`` is 512, whichever
        source it came from.

    Raises:
        Exception: Whatever ``from_pretrained`` raises if the fallback load
            fails as well.
    """
    try:
        loaded = RobertaTokenizer.from_pretrained('./tokenizer_vulnerability')
    except Exception as e:
        # Best-effort fallback: a missing or unreadable local tokenizer should
        # not stop the service from starting.
        print(f"Error loading tokenizer: {e}")
        loaded = RobertaTokenizer.from_pretrained('microsoft/codebert-base')

    # Apply the length cap on both paths so the fallback tokenizer is
    # configured the same way as the local one.
    loaded.model_max_length = 512
    return loaded
|
|
|
def load_model():
    """Build the sequence classifier and load its weights from disk.

    Side effects:
        Sets the module-level ``device`` to the result of ``setup_device()``
        and prints the chosen device.

    The checkpoint ``codebert_vulnerability_scorer.pth`` may either be a bare
    state dict or a mapping with ``model_state_dict`` and, optionally,
    ``config``. When ``config`` is present the architecture is built from it;
    otherwise ``microsoft/codebert-base`` with a single label is used.

    Returns:
        The model moved to ``device`` and switched to eval mode. On CUDA the
        weights are additionally converted to half precision.

    Raises:
        Exception: Any error raised while loading is printed and re-raised.
    """
    global device
    device = setup_device()
    print(f"Using device: {device}")

    try:
        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from a trusted source.
        checkpoint = torch.load("codebert_vulnerability_scorer.pth", map_location=device)

        if 'config' not in checkpoint:
            classifier = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base',
                num_labels=1,
            )
        else:
            from transformers import RobertaConfig
            stored_config = RobertaConfig.from_dict(checkpoint['config'])
            classifier = RobertaForSequenceClassification(stored_config)

        # The weights are either nested under 'model_state_dict' or the
        # checkpoint itself is the state dict.
        weights = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        classifier.load_state_dict(weights)

        classifier.to(device)
        classifier.eval()
        if device.type == 'cuda':
            classifier.half()

        return classifier

    except Exception as e:
        print(f"Error loading model: {e}")
        raise e
|
|
|
@lru_cache(maxsize=1000)
def cached_tokenize(code_hash, max_length):
    """Tokenize a snippet, memoizing up to 1000 distinct argument pairs.

    Args:
        code_hash: Despite the name, this is the source text itself; it is
            passed unchanged to the tokenizer and is part of the cache key.
        max_length: Length the encoding is truncated and padded to.

    Returns:
        The module-level tokenizer's output with PyTorch tensors. A cache hit
        returns the same object as the first call, so callers should not
        modify it in place.
    """
    encode_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
        'return_tensors': 'pt',
    }
    return tokenizer(code_hash, **encode_options)
|
|
|
# Start-up: load the tokenizer first, then the model. Any failure is reported
# and leaves both globals as None, so the app still starts and the handlers
# can answer with "Model not loaded properly".
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    print("Tokenizer loaded successfully!")

    print("Loading model...")
    model = load_model()
    print("Model loaded successfully!")
except Exception as init_error:
    print(f"Error during initialization: {str(init_error)}")
    tokenizer = model = None
|
|
|
@app.route("/", methods=['GET'])
def home():
    """Describe the service and its endpoints.

    Returns:
        A JSON response with the service name, whether the model is loaded,
        the active device (``"unknown"`` when none is set), and a map from
        each route to a short usage hint.
    """
    return jsonify({
        "message": "CodeBERT Vulnerability Scorer API",
        "status": "Model loaded" if model is not None else "Model not loaded",
        "device": str(device) if device else "unknown",
        "endpoints": {
            "/predict": "POST with JSON body containing 'code' field",
            "/predict_batch": "POST with JSON body containing 'codes' array",
            "/predict_get": "GET with 'code' URL parameter",
            "/health": "GET service health status",
        },
    })
|
|
|
@app.route("/predict", methods=['POST'])
def predict_post():
    """Score one code snippet sent as a JSON object.

    The request body must be a JSON object with a non-empty string ``code``
    field.

    Returns:
        200 with ``score``, ``vulnerability_level`` and ``code_preview`` (the
        first 200 characters, followed by ``...`` when the code is longer).
        400 when the body is not a JSON object, or ``code`` is missing or not
        a non-empty string.
        500 when the model or tokenizer is not loaded, or scoring fails.
    """
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        # silent=True returns None for a malformed or non-JSON body instead of
        # raising, so the client gets the 400 below rather than a 500 from the
        # catch-all handler.
        data = request.get_json(silent=True)
        # Valid JSON that is not an object (list, number, string) cannot carry
        # a 'code' key and is rejected the same way.
        if not isinstance(data, dict) or 'code' not in data:
            return jsonify({"error": "Missing 'code' field in JSON body"}), 400

        code = data['code']
        if not code or not isinstance(code, str):
            return jsonify({"error": "'code' field must be a non-empty string"}), 400

        score = predict_vulnerability(code)

        return jsonify({
            "score": score,
            "vulnerability_level": get_vulnerability_level(score),
            "code_preview": code[:200] + "..." if len(code) > 200 else code
        })

    except Exception as e:
        return jsonify({"error": f"Prediction error: {str(e)}"}), 500
|
|
|
@app.route("/predict_batch", methods=['POST'])
def predict_batch():
    """Score several code snippets sent as a JSON object.

    The request body must be a JSON object whose ``codes`` field is a
    non-empty array of strings. Snippets are scored in chunks of at most 16.

    Returns:
        200 with ``results``: one entry per input, in input order, holding
        ``index``, ``score``, ``vulnerability_level`` and ``code_preview``
        (the first 100 characters, followed by ``...`` when longer).
        400 when the body is not a JSON object, ``codes`` is missing, is not
        a non-empty array, or contains a non-string item.
        500 when the model or tokenizer is not loaded, or scoring fails.
    """
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        # silent=True returns None for a malformed or non-JSON body instead of
        # raising, so the client gets a 400 rather than a 500 from the
        # catch-all handler.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'codes' not in data:
            return jsonify({"error": "Missing 'codes' field in JSON body"}), 400

        codes = data['codes']
        if not isinstance(codes, list) or len(codes) == 0:
            return jsonify({"error": "'codes' must be a non-empty array"}), 400

        # Reject bad items up front; otherwise they fail deep inside scoring
        # and surface as a 500.
        if not all(isinstance(item, str) for item in codes):
            return jsonify({"error": "Every item in 'codes' must be a string"}), 400

        batch_size = 16
        results = []

        for start in range(0, len(codes), batch_size):
            batch = codes[start:start + batch_size]
            scores = predict_vulnerability_batch(batch)

            for offset, (snippet, score) in enumerate(zip(batch, scores)):
                results.append({
                    "index": start + offset,
                    "score": score,
                    "vulnerability_level": get_vulnerability_level(score),
                    "code_preview": snippet[:100] + "..." if len(snippet) > 100 else snippet
                })

        return jsonify({"results": results})

    except Exception as e:
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
|
|
|
@app.route("/predict_get", methods=['GET'])
def predict_get():
    """Score one code snippet passed in the ``code`` query parameter.

    Returns:
        200 with ``score``, ``vulnerability_level`` and ``code_preview`` (the
        first 200 characters, followed by ``...`` when the code is longer).
        400 when ``code`` is absent or empty.
        500 when the model or tokenizer is not loaded, or scoring fails.
    """
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        snippet = request.args.get("code")
        if not snippet:
            return jsonify({"error": "Missing 'code' URL parameter"}), 400

        risk = predict_vulnerability(snippet)

        if len(snippet) > 200:
            preview = snippet[:200] + "..."
        else:
            preview = snippet

        payload = {
            "score": risk,
            "vulnerability_level": get_vulnerability_level(risk),
            "code_preview": preview,
        }
        return jsonify(payload)

    except Exception as err:
        return jsonify({"error": f"Prediction error: {str(err)}"}), 500
|
|
|
def predict_vulnerability(code):
    """Return the model's vulnerability score for one code snippet.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        code: Source text to score.

    Returns:
        The sigmoid of the model's output logit, rounded to 4 decimal places.
        The output must contain exactly one value for ``.item()`` to succeed.
    """
    # Truncate only at the 512-token cap. Deriving max_length from a
    # whitespace word count underestimates token-dense code (few spaces, many
    # tokens) and used to cut such inputs to as few as 128 tokens. A single
    # sequence needs no padding.
    inputs = tokenizer(
        code,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        if device.type == 'cuda':
            # Mixed precision is only enabled on CUDA.
            with torch.cuda.amp.autocast():
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)

    # Accept both output objects with a .logits attribute and plain tuples.
    logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
    score = torch.sigmoid(logits).cpu().item()

    return round(score, 4)
|
|
|
def predict_vulnerability_batch(codes):
    """Return vulnerability scores for a batch of code snippets.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        codes: Sequence of source texts to score together.

    Returns:
        A list of floats, one per flattened model output value, each the
        sigmoid of the corresponding logit rounded to 4 decimal places. An
        empty input yields an empty list.
    """
    # Nothing to score; also avoids handing the tokenizer an empty batch.
    if not codes:
        return []

    # Pad to the longest sequence in the batch and truncate only at the
    # 512-token cap. Deriving max_length from whitespace word counts
    # underestimates token-dense code and used to truncate it needlessly.
    inputs = tokenizer(
        codes,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors='pt'
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        if device.type == 'cuda':
            # Mixed precision is only enabled on CUDA.
            with torch.cuda.amp.autocast():
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)

    # Accept both output objects with a .logits attribute and plain tuples.
    logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
    scores = torch.sigmoid(logits).cpu().numpy()

    return [round(float(score), 4) for score in scores.flatten()]
|
|
|
def get_vulnerability_level(score):
    """Map a numeric score to a severity label.

    Args:
        score: Vulnerability score to classify.

    Returns:
        ``"Low"`` for scores below 0.3, ``"Medium"`` for scores below 0.7,
        and ``"High"`` for everything else.
    """
    # Bands are checked in ascending order; the first upper bound the score
    # falls under decides the label.
    for upper_bound, label in ((0.3, "Low"), (0.7, "Medium")):
        if score < upper_bound:
            return label
    return "High"
|
|
|
@app.route("/health", methods=['GET'])
def health_check():
    """Report whether the model and tokenizer are loaded.

    Returns:
        A JSON response with ``status`` (always ``"healthy"``),
        ``model_loaded``, ``tokenizer_loaded`` and ``device`` (``"unknown"``
        when no device is set).
    """
    report = {"status": "healthy"}
    report["model_loaded"] = model is not None
    report["tokenizer_loaded"] = tokenizer is not None
    report["device"] = str(device) if device else "unknown"
    return jsonify(report)
|
|
|
if __name__ == "__main__":
    # Run the built-in server on all interfaces at port 7860, with debug mode
    # off and threaded request handling on.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "debug": False,
        "threaded": True,
    }
    app.run(**server_options)