Update main.py
Browse files
main.py
CHANGED
@@ -90,42 +90,15 @@ except Exception as e:
|
|
90 |
@app.route("/", methods=['GET'])
|
91 |
def home():
|
92 |
return jsonify({
|
93 |
-
"message": "CodeBERT Vulnerability
|
94 |
"status": "Model loaded" if model is not None else "Model not loaded",
|
95 |
"device": str(device) if device else "unknown",
|
96 |
"endpoints": {
|
97 |
-
"/predict": "POST with JSON body containing '
|
98 |
-
"/predict_batch": "POST with JSON body containing 'codes' array",
|
99 |
-
"/predict_get": "GET with 'code' URL parameter"
|
100 |
}
|
101 |
})
|
102 |
|
103 |
@app.route("/predict", methods=['POST'])
|
104 |
-
def predict_post():
|
105 |
-
try:
|
106 |
-
if model is None or tokenizer is None:
|
107 |
-
return jsonify({"error": "Model not loaded properly"}), 500
|
108 |
-
|
109 |
-
data = request.get_json()
|
110 |
-
if not data or 'code' not in data:
|
111 |
-
return jsonify({"error": "Missing 'code' field in JSON body"}), 400
|
112 |
-
|
113 |
-
code = data['code']
|
114 |
-
if not code or not isinstance(code, str):
|
115 |
-
return jsonify({"error": "'code' field must be a non-empty string"}), 400
|
116 |
-
|
117 |
-
score = predict_vulnerability(code)
|
118 |
-
|
119 |
-
return jsonify({
|
120 |
-
"score": score,
|
121 |
-
"vulnerability_level": get_vulnerability_level(score),
|
122 |
-
"code_preview": code[:200] + "..." if len(code) > 200 else code
|
123 |
-
})
|
124 |
-
|
125 |
-
except Exception as e:
|
126 |
-
return jsonify({"error": f"Prediction error: {str(e)}"}), 500
|
127 |
-
|
128 |
-
@app.route("/predict_batch", methods=['POST'])
|
129 |
def predict_batch():
|
130 |
try:
|
131 |
if model is None or tokenizer is None:
|
@@ -148,10 +121,7 @@ def predict_batch():
|
|
148 |
|
149 |
for j, score in enumerate(scores):
|
150 |
results.append({
|
151 |
-
"
|
152 |
-
"score": score,
|
153 |
-
"vulnerability_level": get_vulnerability_level(score),
|
154 |
-
"code_preview": batch[j][:100] + "..." if len(batch[j]) > 100 else batch[j]
|
155 |
})
|
156 |
|
157 |
return jsonify({"results": results})
|
@@ -159,26 +129,7 @@ def predict_batch():
|
|
159 |
except Exception as e:
|
160 |
return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
|
161 |
|
162 |
-
|
163 |
-
def predict_get():
|
164 |
-
try:
|
165 |
-
if model is None or tokenizer is None:
|
166 |
-
return jsonify({"error": "Model not loaded properly"}), 500
|
167 |
-
|
168 |
-
code = request.args.get("code")
|
169 |
-
if not code:
|
170 |
-
return jsonify({"error": "Missing 'code' URL parameter"}), 400
|
171 |
-
|
172 |
-
score = predict_vulnerability(code)
|
173 |
-
|
174 |
-
return jsonify({
|
175 |
-
"score": score,
|
176 |
-
"vulnerability_level": get_vulnerability_level(score),
|
177 |
-
"code_preview": code[:200] + "..." if len(code) > 200 else code
|
178 |
-
})
|
179 |
-
|
180 |
-
except Exception as e:
|
181 |
-
return jsonify({"error": f"Prediction error: {str(e)}"}), 500
|
182 |
|
183 |
def predict_vulnerability(code):
|
184 |
dynamic_length = min(max(len(code.split()) * 2, 128), 512)
|
@@ -229,13 +180,6 @@ def predict_vulnerability_batch(codes):
|
|
229 |
|
230 |
return [round(float(score), 4) for score in scores.flatten()]
|
231 |
|
232 |
-
def get_vulnerability_level(score):
|
233 |
-
if score < 0.3:
|
234 |
-
return "Low"
|
235 |
-
elif score < 0.7:
|
236 |
-
return "Medium"
|
237 |
-
else:
|
238 |
-
return "High"
|
239 |
|
240 |
@app.route("/health", methods=['GET'])
|
241 |
def health_check():
|
|
|
90 |
@app.route("/", methods=['GET'])
|
91 |
def home():
|
92 |
return jsonify({
|
93 |
+
"message": "CodeBERT Vulnerability Evalutor API",
|
94 |
"status": "Model loaded" if model is not None else "Model not loaded",
|
95 |
"device": str(device) if device else "unknown",
|
96 |
"endpoints": {
|
97 |
+
"/predict": "POST with JSON body containing 'codes' array"
|
|
|
|
|
98 |
}
|
99 |
})
|
100 |
|
101 |
@app.route("/predict", methods=['POST'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
def predict_batch():
|
103 |
try:
|
104 |
if model is None or tokenizer is None:
|
|
|
121 |
|
122 |
for j, score in enumerate(scores):
|
123 |
results.append({
|
124 |
+
"score": score
|
|
|
|
|
|
|
125 |
})
|
126 |
|
127 |
return jsonify({"results": results})
|
|
|
129 |
except Exception as e:
|
130 |
return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
|
131 |
|
132 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
def predict_vulnerability(code):
|
135 |
dynamic_length = min(max(len(code.split()) * 2, 128), 512)
|
|
|
180 |
|
181 |
return [round(float(score), 4) for score in scores.flatten()]
|
182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
@app.route("/health", methods=['GET'])
|
185 |
def health_check():
|