from flask import Flask, request, jsonify
import torch
from transformers import (
    RobertaConfig,
    RobertaTokenizer,
    RobertaForSequenceClassification,
)
from contextlib import nullcontext
from functools import lru_cache

app = Flask(__name__)

# Populated at startup by load_tokenizer() / load_model().
model = None
tokenizer = None
device = None


def setup_device():
    """Pick the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')


def load_tokenizer():
    """Load the fine-tuned tokenizer, falling back to the base CodeBERT tokenizer."""
    try:
        tok = RobertaTokenizer.from_pretrained('./tokenizer_vulnerability')
        tok.model_max_length = 512
        return tok
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        return RobertaTokenizer.from_pretrained('microsoft/codebert-base')


def load_model():
    """Load the fine-tuned scoring model from a local checkpoint."""
    global device
    device = setup_device()
    print(f"Using device: {device}")
    try:
        checkpoint = torch.load("codebert_vulnerability_scorer.pth", map_location=device)
        # The checkpoint may bundle its own model config; otherwise rebuild
        # the architecture from base CodeBERT with a single output label.
        if 'config' in checkpoint:
            config = RobertaConfig.from_dict(checkpoint['config'])
            net = RobertaForSequenceClassification(config)
        else:
            net = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base', num_labels=1
            )
        # Support both wrapped ({'model_state_dict': ...}) and raw state dicts.
        if 'model_state_dict' in checkpoint:
            net.load_state_dict(checkpoint['model_state_dict'])
        else:
            net.load_state_dict(checkpoint)
        net.to(device)
        net.eval()
        if device.type == 'cuda':
            net.half()  # fp16 weights on GPU to cut memory use and latency
        return net
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


@lru_cache(maxsize=1000)
def cached_tokenize(code, max_length):
    """Memoized tokenization keyed on the exact code string.

    Not currently wired into the endpoints; available as a drop-in for
    workloads that score the same snippets repeatedly.
    """
    return tokenizer(
        code,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt'
    )


# Load everything once at import time so the first request is not penalized.
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    print("Tokenizer loaded successfully!")
    print("Loading model...")
    model = load_model()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error during initialization: {str(e)}")
    tokenizer = None
    model = None


@app.route("/", methods=['GET'])
def home():
    return jsonify({
        "message": "CodeBERT Vulnerability Scorer API",
        "status": "Model loaded" if model is not None else "Model not loaded",
        "device": str(device) if device else "unknown",
        "endpoints": {
            "/predict": "POST with JSON body containing 'code' field",
            "/predict_batch": "POST with JSON body containing 'codes' array",
            "/predict_get": "GET with 'code' URL parameter"
        }
    })


@app.route("/predict", methods=['POST'])
def predict_post():
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500
        data = request.get_json()
        if not data or 'code' not in data:
            return jsonify({"error": "Missing 'code' field in JSON body"}), 400
        code = data['code']
        if not code or not isinstance(code, str):
            return jsonify({"error": "'code' field must be a non-empty string"}), 400
        score = predict_vulnerability(code)
        return jsonify({
            "score": score,
            "vulnerability_level": get_vulnerability_level(score),
            "code_preview": code[:200] + "..." if len(code) > 200 else code
        })
    except Exception as e:
        return jsonify({"error": f"Prediction error: {str(e)}"}), 500


@app.route("/predict_batch", methods=['POST'])
def predict_batch():
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500
        data = request.get_json()
        if not data or 'codes' not in data:
            return jsonify({"error": "Missing 'codes' field in JSON body"}), 400
        codes = data['codes']
        if not isinstance(codes, list) or len(codes) == 0:
            return jsonify({"error": "'codes' must be a non-empty array"}), 400
        # Score at most 16 snippets per forward pass to bound memory use.
        batch_size = min(len(codes), 16)
        results = []
        for i in range(0, len(codes), batch_size):
            batch = codes[i:i + batch_size]
            scores = predict_vulnerability_batch(batch)
            for j, score in enumerate(scores):
                results.append({
                    "index": i + j,
                    "score": score,
                    "vulnerability_level": get_vulnerability_level(score),
                    "code_preview": batch[j][:100] + "..." if len(batch[j]) > 100 else batch[j]
                })
        return jsonify({"results": results})
    except Exception as e:
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500


@app.route("/predict_get", methods=['GET'])
def predict_get():
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500
        code = request.args.get("code")
        if not code:
            return jsonify({"error": "Missing 'code' URL parameter"}), 400
        score = predict_vulnerability(code)
        return jsonify({
            "score": score,
            "vulnerability_level": get_vulnerability_level(score),
            "code_preview": code[:200] + "..." if len(code) > 200 else code
        })
    except Exception as e:
        return jsonify({"error": f"Prediction error: {str(e)}"}), 500


def predict_vulnerability(code):
    # Rough heuristic: ~2 subword tokens per whitespace-separated word,
    # clamped to [128, 512], so short snippets avoid full-length padding.
    dynamic_length = min(max(len(code.split()) * 2, 128), 512)
    inputs = tokenizer(
        code,
        truncation=True,
        padding='max_length',
        max_length=dynamic_length,
        return_tensors='pt'
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Mixed-precision inference on CUDA; no-op context elsewhere.
    amp_ctx = torch.cuda.amp.autocast() if device.type == 'cuda' else nullcontext()
    with torch.no_grad(), amp_ctx:
        outputs = model(**inputs)
    logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
    score = torch.sigmoid(logits).cpu().item()
    return round(score, 4)


def predict_vulnerability_batch(codes):
    # Size the whole batch to its longest snippet, clamped to [128, 512].
    max_len = max(len(code.split()) * 2 for code in codes)
    dynamic_length = min(max(max_len, 128), 512)
    inputs = tokenizer(
        codes,
        truncation=True,
        padding='max_length',
        max_length=dynamic_length,
        return_tensors='pt'
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    amp_ctx = torch.cuda.amp.autocast() if device.type == 'cuda' else nullcontext()
    with torch.no_grad(), amp_ctx:
        outputs = model(**inputs)
    logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
    scores = torch.sigmoid(logits).cpu().numpy()
    return [round(float(score), 4) for score in scores.flatten()]


def get_vulnerability_level(score):
    if score < 0.3:
        return "Low"
    elif score < 0.7:
        return "Medium"
    else:
        return "High"


@app.route("/health", methods=['GET'])
def health_check():
    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": str(device) if device else "unknown"
    })


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)
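
# Example client calls (illustrative sketch only; assumes the server is
# running on localhost:7860 and the third-party `requests` package is
# installed -- neither is part of this file):
#
#   import requests
#
#   # Score a single snippet
#   r = requests.post("http://localhost:7860/predict",
#                     json={"code": "char buf[8]; strcpy(buf, user_input);"})
#   print(r.json())  # e.g. {"score": ..., "vulnerability_level": ..., "code_preview": ...}
#
#   # Score several snippets in one request
#   r = requests.post("http://localhost:7860/predict_batch",
#                     json={"codes": ["eval(user_input)", "print('hi')"]})
#   for item in r.json()["results"]:
#       print(item["index"], item["score"], item["vulnerability_level"])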