Kareem94 commited on
Commit
6d0481d
·
verified ·
1 Parent(s): 76cad5e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -60
main.py CHANGED
@@ -90,42 +90,15 @@ except Exception as e:
90
  @app.route("/", methods=['GET'])
91
  def home():
92
  return jsonify({
93
- "message": "CodeBERT Vulnerability Scorer API",
94
  "status": "Model loaded" if model is not None else "Model not loaded",
95
  "device": str(device) if device else "unknown",
96
  "endpoints": {
97
- "/predict": "POST with JSON body containing 'code' field",
98
- "/predict_batch": "POST with JSON body containing 'codes' array",
99
- "/predict_get": "GET with 'code' URL parameter"
100
  }
101
  })
102
 
103
  @app.route("/predict", methods=['POST'])
104
- def predict_post():
105
- try:
106
- if model is None or tokenizer is None:
107
- return jsonify({"error": "Model not loaded properly"}), 500
108
-
109
- data = request.get_json()
110
- if not data or 'code' not in data:
111
- return jsonify({"error": "Missing 'code' field in JSON body"}), 400
112
-
113
- code = data['code']
114
- if not code or not isinstance(code, str):
115
- return jsonify({"error": "'code' field must be a non-empty string"}), 400
116
-
117
- score = predict_vulnerability(code)
118
-
119
- return jsonify({
120
- "score": score,
121
- "vulnerability_level": get_vulnerability_level(score),
122
- "code_preview": code[:200] + "..." if len(code) > 200 else code
123
- })
124
-
125
- except Exception as e:
126
- return jsonify({"error": f"Prediction error: {str(e)}"}), 500
127
-
128
- @app.route("/predict_batch", methods=['POST'])
129
  def predict_batch():
130
  try:
131
  if model is None or tokenizer is None:
@@ -148,10 +121,7 @@ def predict_batch():
148
 
149
  for j, score in enumerate(scores):
150
  results.append({
151
- "index": i + j,
152
- "score": score,
153
- "vulnerability_level": get_vulnerability_level(score),
154
- "code_preview": batch[j][:100] + "..." if len(batch[j]) > 100 else batch[j]
155
  })
156
 
157
  return jsonify({"results": results})
@@ -159,26 +129,7 @@ def predict_batch():
159
  except Exception as e:
160
  return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
161
 
162
- @app.route("/predict_get", methods=['GET'])
163
- def predict_get():
164
- try:
165
- if model is None or tokenizer is None:
166
- return jsonify({"error": "Model not loaded properly"}), 500
167
-
168
- code = request.args.get("code")
169
- if not code:
170
- return jsonify({"error": "Missing 'code' URL parameter"}), 400
171
-
172
- score = predict_vulnerability(code)
173
-
174
- return jsonify({
175
- "score": score,
176
- "vulnerability_level": get_vulnerability_level(score),
177
- "code_preview": code[:200] + "..." if len(code) > 200 else code
178
- })
179
-
180
- except Exception as e:
181
- return jsonify({"error": f"Prediction error: {str(e)}"}), 500
182
 
183
  def predict_vulnerability(code):
184
  dynamic_length = min(max(len(code.split()) * 2, 128), 512)
@@ -229,13 +180,6 @@ def predict_vulnerability_batch(codes):
229
 
230
  return [round(float(score), 4) for score in scores.flatten()]
231
 
232
- def get_vulnerability_level(score):
233
- if score < 0.3:
234
- return "Low"
235
- elif score < 0.7:
236
- return "Medium"
237
- else:
238
- return "High"
239
 
240
  @app.route("/health", methods=['GET'])
241
  def health_check():
 
90
  @app.route("/", methods=['GET'])
91
  def home():
92
  return jsonify({
93
+ "message": "CodeBERT Vulnerability Evalutor API",
94
  "status": "Model loaded" if model is not None else "Model not loaded",
95
  "device": str(device) if device else "unknown",
96
  "endpoints": {
97
+ "/predict": "POST with JSON body containing 'codes' array"
 
 
98
  }
99
  })
100
 
101
  @app.route("/predict", methods=['POST'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def predict_batch():
103
  try:
104
  if model is None or tokenizer is None:
 
121
 
122
  for j, score in enumerate(scores):
123
  results.append({
124
+ "score": score
 
 
 
125
  })
126
 
127
  return jsonify({"results": results})
 
129
  except Exception as e:
130
  return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
131
 
132
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  def predict_vulnerability(code):
135
  dynamic_length = min(max(len(code.split()) * 2, 128), 512)
 
180
 
181
  return [round(float(score), 4) for score in scores.flatten()]
182
 
 
 
 
 
 
 
 
183
 
184
  @app.route("/health", methods=['GET'])
185
  def health_check():