Kareem94 committed on
Commit 621fa3a · verified · 1 Parent(s): 368d5f1

Update main.py

Files changed (1)
  1. main.py +295 -37
main.py CHANGED
@@ -1,65 +1,323 @@
  from flask import Flask, request, jsonify
  import torch
- from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig
  import os

  app = Flask(__name__)

- # Load model and tokenizer
  def load_model():
-     # Load saved config and weights
-     checkpoint = torch.load("codebert_readability_scorer.pth", map_location=torch.device('cpu'))
-     config = RobertaConfig.from_dict(checkpoint['config'])

-     # Initialize model with loaded config
-     model = RobertaForSequenceClassification(config)
-     model.load_state_dict(checkpoint['model_state_dict'])
-     model.eval()
-     return model

- # Load components
  try:
-     tokenizer = RobertaTokenizer.from_pretrained("./tokenizer_readability")
      model = load_model()
-     print("Model and tokenizer loaded successfully!")
  except Exception as e:
-     print(f"Error loading model: {str(e)}")

- @app.route("/")
  def home():
-     return request.url

- @app.route("/predict")
- def predict():
      try:
-         # Get code from URL parameter
-         code = request.args.get("code")
-         if not code:
-             return jsonify({"error": "Missing 'code' URL parameter"}), 400

-         # Tokenize input
          inputs = tokenizer(
              code,
              truncation=True,
              padding='max_length',
-             max_length=512,
              return_tensors='pt'
          )
-
-         # Make prediction
          with torch.no_grad():
-             outputs = model(**inputs)
-
-         # Apply sigmoid and format score
-         score = torch.sigmoid(outputs.logits).item()
-
-         return jsonify({
-             "score": round(score, 4),
-             "given_code": code[:500] + "..." if len(code) > 500 else code
-         })

      except Exception as e:
-         return jsonify({"error": str(e)}), 500

  if __name__ == "__main__":
-     app.run(host="0.0.0.0", port=7860)
 
  from flask import Flask, request, jsonify
  import torch
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
  import os
+ import gc
+ from functools import lru_cache

  app = Flask(__name__)

+ model = None
+ tokenizer = None
+ device = None
+
+ def setup_device():
+     if torch.cuda.is_available():
+         return torch.device('cuda')
+     elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+         return torch.device('mps')
+     else:
+         return torch.device('cpu')
+
+ def load_tokenizer():
+     try:
+         tokenizer = RobertaTokenizer.from_pretrained('./tokenizer_readability')
+         tokenizer.model_max_length = 512
+         return tokenizer
+     except Exception as e:
+         print(f"Error loading tokenizer: {e}")
+         try:
+             return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
+         except Exception as e2:
+             print(f"Fallback tokenizer failed: {e2}")
+             return None
+
  def load_model():
+     global device
+     device = setup_device()
+     print(f"Using device: {device}")
+
+     try:
+         checkpoint = torch.load("codebert_readability_scorer.pth", map_location=device)
+
+         if 'config' in checkpoint:
+             from transformers import RobertaConfig
+             config = RobertaConfig.from_dict(checkpoint['config'])
+             model = RobertaForSequenceClassification(config)
+         else:
+             model = RobertaForSequenceClassification.from_pretrained(
+                 'microsoft/codebert-base',
+                 num_labels=1
+             )
+
+         if 'model_state_dict' in checkpoint:
+             model.load_state_dict(checkpoint['model_state_dict'])
+         else:
+             model.load_state_dict(checkpoint)
+
+         model.to(device)
+         model.eval()
+
+         if device.type == 'cuda':
+             model.half()
+
+         return model
+
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         return None

+ def cleanup_gpu_memory():
+     if device and device.type == 'cuda':
+         torch.cuda.empty_cache()
+     gc.collect()

  try:
+     print("Loading tokenizer...")
+     tokenizer = load_tokenizer()
+     if tokenizer:
+         print("Tokenizer loaded successfully!")
+     else:
+         print("Failed to load tokenizer!")
+
+     print("Loading model...")
      model = load_model()
+     if model:
+         print("Model loaded successfully!")
+     else:
+         print("Failed to load model!")
+
  except Exception as e:
+     print(f"Error during initialization: {str(e)}")
+     tokenizer = None
+     model = None

+ @app.route("/", methods=['GET'])
  def home():
+     return jsonify({
+         "message": "CodeBERT Readability Evaluator API",
+         "status": "Model loaded" if model is not None else "Model not loaded",
+         "device": str(device) if device else "unknown",
+         "endpoints": {
+             "/predict": "POST with JSON body containing 'codes' array"
+         }
+     })
+
+ @app.route("/predict", methods=['POST'])
+ def predict_batch():
+     try:
+         if model is None or tokenizer is None:
+             return jsonify({"error": "Model not loaded properly"}), 500
+
+         data = request.get_json()
+         if not data or 'codes' not in data:
+             return jsonify({"error": "Missing 'codes' field in JSON body"}), 400
+
+         codes = data['codes']
+         if not isinstance(codes, list) or len(codes) == 0:
+             return jsonify({"error": "'codes' must be a non-empty array"}), 400
+
+         if len(codes) > 100:
+             return jsonify({"error": "Too many codes. Maximum 100 allowed."}), 400
+
+         validated_codes = []
+         for i, code in enumerate(codes):
+             if not isinstance(code, str):
+                 return jsonify({"error": f"Code at index {i} must be a string"}), 400
+             if len(code.strip()) == 0:
+                 validated_codes.append("# empty code")
+             elif len(code) > 50000:
+                 return jsonify({"error": f"Code at index {i} too long. Maximum 50000 characters."}), 400
+             else:
+                 validated_codes.append(code.strip())
+
+         if len(validated_codes) == 1:
+             score = predict_readability_with_chunking(validated_codes[0])
+             cleanup_gpu_memory()
+             return jsonify({"results": [{"score": score}]})
+
+         batch_size = min(len(validated_codes), 16)
+         results = []
+
+         try:
+             for i in range(0, len(validated_codes), batch_size):
+                 batch = validated_codes[i:i+batch_size]
+
+                 long_codes = []
+                 short_codes = []
+                 long_indices = []
+                 short_indices = []
+
+                 for idx, code in enumerate(batch):
+                     try:
+                         tokens = tokenizer.encode(code, add_special_tokens=False, max_length=1000, truncation=True)
+                         if len(tokens) > 450:
+                             long_codes.append(code)
+                             long_indices.append(i + idx)
+                         else:
+                             short_codes.append(code)
+                             short_indices.append(i + idx)
+                     except Exception as e:
+                         print(f"Tokenization error for code {i + idx}: {e}")
+                         short_codes.append(code)
+                         short_indices.append(i + idx)
+
+                 batch_scores = [0.0] * len(batch)
+
+                 if short_codes:
+                     try:
+                         short_scores = predict_readability_batch(short_codes)
+                         for j, score in enumerate(short_scores):
+                             local_idx = short_indices[j] - i
+                             batch_scores[local_idx] = score
+                     except Exception as e:
+                         print(f"Batch prediction error: {e}")
+                         for j in range(len(short_codes)):
+                             local_idx = short_indices[j] - i
+                             batch_scores[local_idx] = 0.0
+
+                 for j, code in enumerate(long_codes):
+                     try:
+                         score = predict_readability_with_chunking(code)
+                         local_idx = long_indices[j] - i
+                         batch_scores[local_idx] = score
+                     except Exception as e:
+                         print(f"Chunking prediction error: {e}")
+                         local_idx = long_indices[j] - i
+                         batch_scores[local_idx] = 0.0
+
+                 for score in batch_scores:
+                     results.append({"score": score})
+
+             cleanup_gpu_memory()
+
+         except Exception as e:
+             cleanup_gpu_memory()
+             raise e
+
+         return jsonify({"results": results})
+
+     except Exception as e:
+         cleanup_gpu_memory()
+         return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500

+ def predict_readability_with_chunking(code):
      try:
+         if not code or len(code.strip()) == 0:
+             return 0.0
+
+         tokens = tokenizer.encode(code, add_special_tokens=False, max_length=2000, truncation=True)
+
+         if len(tokens) <= 450:
+             return predict_readability(code)
+
+         chunk_size = 400
+         overlap = 50
+         max_score = 0.0
+
+         for start in range(0, len(tokens), chunk_size - overlap):
+             end = min(start + chunk_size, len(tokens))
+             chunk_tokens = tokens[start:end]
+
+             try:
+                 chunk_code = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
+                 if chunk_code.strip():
+                     score = predict_readability(chunk_code)
+                     max_score = max(max_score, score)
+             except Exception as e:
+                 print(f"Chunk processing error: {e}")
+                 continue
+
+             if end >= len(tokens):
+                 break
+
+         return max_score
+
+     except Exception as e:
+         print(f"Chunking error: {e}")
+         return 0.0

+ def predict_readability(code):
+     try:
+         if not code or len(code.strip()) == 0:
+             return 0.0
+
+         dynamic_length = min(max(len(code.split()) * 2, 128), 512)
+
          inputs = tokenizer(
              code,
              truncation=True,
              padding='max_length',
+             max_length=dynamic_length,
              return_tensors='pt'
          )
+
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
          with torch.no_grad():
+             if device.type == 'cuda':
+                 with torch.cuda.amp.autocast():
+                     outputs = model(**inputs)
+             else:
+                 outputs = model(**inputs)
+
+         if hasattr(outputs, 'logits'):
+             score = torch.sigmoid(outputs.logits).cpu().item()
+         else:
+             score = torch.sigmoid(outputs[0]).cpu().item()
+
+         return round(max(0.0, min(1.0, score)), 4)
+
+     except Exception as e:
+         print(f"Single prediction error: {e}")
+         return 0.0

+ def predict_readability_batch(codes):
+     try:
+         if not codes or len(codes) == 0:
+             return []
+
+         filtered_codes = [code if code and code.strip() else "# empty" for code in codes]
+
+         max_len = max([len(code.split()) * 2 for code in filtered_codes if code])
+         dynamic_length = min(max(max_len, 128), 512)
+
+         inputs = tokenizer(
+             filtered_codes,
+             truncation=True,
+             padding='max_length',
+             max_length=dynamic_length,
+             return_tensors='pt'
+         )
+
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             if device.type == 'cuda':
+                 with torch.cuda.amp.autocast():
+                     outputs = model(**inputs)
+             else:
+                 outputs = model(**inputs)
+
+         if hasattr(outputs, 'logits'):
+             scores = torch.sigmoid(outputs.logits).cpu().numpy()
+         else:
+             scores = torch.sigmoid(outputs[0]).cpu().numpy()
+
+         return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
+
      except Exception as e:
+         print(f"Batch prediction error: {e}")
+         return [0.0] * len(codes)
+
+ @app.route("/health", methods=['GET'])
+ def health_check():
+     return jsonify({
+         "status": "healthy",
+         "model_loaded": model is not None,
+         "tokenizer_loaded": tokenizer is not None,
+         "device": str(device) if device else "unknown"
+     })

  if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)
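
For reference, a minimal client for the new batch endpoint might look like the sketch below. It is illustrative only and not part of this commit; it assumes the server is running locally on port 7860 and that the `requests` library is installed.

import requests

# POST a batch of snippets to the /predict route added in this commit.
resp = requests.post(
    "http://localhost:7860/predict",
    json={"codes": [
        "def add(a, b):\n    return a + b",
        "x=1;y=2;print(x+y)",
    ]},
    timeout=60,
)
resp.raise_for_status()
# Expected shape: {"results": [{"score": <float in [0, 1]>}, ...]},
# one result per submitted snippet, in input order.
for i, result in enumerate(resp.json()["results"]):
    print(f"snippet {i}: readability score {result['score']}")

# The new /health route can be polled the same way:
print(requests.get("http://localhost:7860/health", timeout=10).json())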
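
The chunking path in predict_readability_with_chunking slides a 400-token window with a 50-token overlap (an effective stride of 350) across long inputs and keeps the maximum chunk score. A standalone sketch of just the window arithmetic, illustrative only and not part of the commit:

def chunk_spans(n_tokens, chunk_size=400, overlap=50):
    # Reproduces the (start, end) windows the commit iterates over.
    spans = []
    for start in range(0, n_tokens, chunk_size - overlap):
        end = min(start + chunk_size, n_tokens)
        spans.append((start, end))
        if end >= n_tokens:
            break
    return spans

print(chunk_spans(1000))  # [(0, 400), (350, 750), (700, 1000)]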
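
The reworked load_model accepts either a bare state dict or a checkpoint dict carrying 'config' and 'model_state_dict' keys. A hedged sketch of how a compatible checkpoint could be written at training time; the training side is not shown in this commit, and the assumption that the model was fine-tuned from microsoft/codebert-base with a single regression label is inferred from the fallback branch:

import torch
from transformers import RobertaForSequenceClassification

model = RobertaForSequenceClassification.from_pretrained(
    "microsoft/codebert-base", num_labels=1  # single readability score head
)
# ... fine-tune on readability labels here (not part of this commit) ...
torch.save(
    {
        "config": model.config.to_dict(),        # read back via RobertaConfig.from_dict
        "model_state_dict": model.state_dict(),  # read back via load_state_dict
    },
    "codebert_readability_scorer.pth",
)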