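"""Flask API that scores code readability with a fine-tuned CodeBERT model.

Exposes a JSON endpoint that accepts an array of code snippets and returns a
readability score in [0, 1] for each one. Long snippets are scored with a
sliding-window chunking strategy; batches are split into short and long codes
so each group takes the most efficient path through the model.
"""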
from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import gc

app = Flask(__name__)

model = None
tokenizer = None
device = None

def setup_device():
    """Pick the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

def load_tokenizer():
    """Load the fine-tuned tokenizer, falling back to the base CodeBERT tokenizer."""
    try:
        tokenizer = RobertaTokenizer.from_pretrained('./tokenizer_readability')
        tokenizer.model_max_length = 512
        return tokenizer
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        try:
            return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
        except Exception as e2:
            print(f"Fallback tokenizer failed: {e2}")
            return None

def load_model():
    """Load the fine-tuned readability model from a local checkpoint."""
    global device
    device = setup_device()
    print(f"Using device: {device}")
    try:
        checkpoint = torch.load("codebert_readability_scorer.pth", map_location=device)
        # Rebuild the architecture from the saved config when available,
        # otherwise fall back to the base CodeBERT with a single regression head.
        if 'config' in checkpoint:
            from transformers import RobertaConfig
            config = RobertaConfig.from_dict(checkpoint['config'])
            model = RobertaForSequenceClassification(config)
        else:
            model = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base',
                num_labels=1
            )
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()
        if device.type == 'cuda':
            # FP16 weights halve GPU memory use for inference.
            model.half()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

def cleanup_gpu_memory():
    if device and device.type == 'cuda':
        torch.cuda.empty_cache()
        gc.collect()

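# Load the tokenizer and model once at import time so every request reuses
# the same objects instead of reloading them per call.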
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    if tokenizer:
        print("Tokenizer loaded successfully!")
    else:
        print("Failed to load tokenizer!")
    print("Loading model...")
    model = load_model()
    if model:
        print("Model loaded successfully!")
    else:
        print("Failed to load model!")
except Exception as e:
    print(f"Error during initialization: {str(e)}")
    tokenizer = None
    model = None

@app.route('/')
def home():
    return jsonify({
        "message": "CodeBERT Readability Evaluator API",
        "status": "Model loaded" if model is not None else "Model not loaded",
        "device": str(device) if device else "unknown",
        "endpoints": {
            "/predict": "POST with JSON body containing 'codes' array"
        }
    })

@app.route('/predict', methods=['POST'])
def predict_batch():
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        data = request.get_json()
        if not data or 'codes' not in data:
            return jsonify({"error": "Missing 'codes' field in JSON body"}), 400

        codes = data['codes']
        if not isinstance(codes, list) or len(codes) == 0:
            return jsonify({"error": "'codes' must be a non-empty array"}), 400
        if len(codes) > 100:
            return jsonify({"error": "Too many codes. Maximum 100 allowed."}), 400

        # Validate and normalize each snippet before scoring.
        validated_codes = []
        for i, code in enumerate(codes):
            if not isinstance(code, str):
                return jsonify({"error": f"Code at index {i} must be a string"}), 400
            if len(code.strip()) == 0:
                validated_codes.append("# empty code")
            elif len(code) > 50000:
                return jsonify({"error": f"Code at index {i} too long. Maximum 50000 characters."}), 400
            else:
                validated_codes.append(code.strip())

        # Single snippet: score it directly (with chunking if it is long).
        if len(validated_codes) == 1:
            score = predict_readability_with_chunking(validated_codes[0])
            cleanup_gpu_memory()
            return jsonify({"results": [{"score": score}]})

        batch_size = min(len(validated_codes), 16)
        results = []
        try:
            for i in range(0, len(validated_codes), batch_size):
                batch = validated_codes[i:i+batch_size]
                # Split the batch: short snippets go through one batched
                # forward pass, long ones are scored individually via chunking.
                long_codes = []
                short_codes = []
                long_indices = []
                short_indices = []
                for idx, code in enumerate(batch):
                    try:
                        tokens = tokenizer.encode(code, add_special_tokens=False, max_length=1000, truncation=True)
                        if len(tokens) > 450:
                            long_codes.append(code)
                            long_indices.append(i + idx)
                        else:
                            short_codes.append(code)
                            short_indices.append(i + idx)
                    except Exception as e:
                        print(f"Tokenization error for code {i + idx}: {e}")
                        short_codes.append(code)
                        short_indices.append(i + idx)

                batch_scores = [0.0] * len(batch)
                if short_codes:
                    try:
                        short_scores = predict_readability_batch(short_codes)
                        for j, score in enumerate(short_scores):
                            local_idx = short_indices[j] - i
                            batch_scores[local_idx] = score
                    except Exception as e:
                        print(f"Batch prediction error: {e}")
                        for j in range(len(short_codes)):
                            local_idx = short_indices[j] - i
                            batch_scores[local_idx] = 0.0
                for j, code in enumerate(long_codes):
                    try:
                        score = predict_readability_with_chunking(code)
                        local_idx = long_indices[j] - i
                        batch_scores[local_idx] = score
                    except Exception as e:
                        print(f"Chunking prediction error: {e}")
                        local_idx = long_indices[j] - i
                        batch_scores[local_idx] = 0.0

                for score in batch_scores:
                    results.append({"score": score})
                cleanup_gpu_memory()
        except Exception as e:
            cleanup_gpu_memory()
            raise e
        return jsonify({"results": results})
    except Exception as e:
        cleanup_gpu_memory()
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500

def predict_readability_with_chunking(code):
    """Score long code by sliding a 400-token window (50-token overlap) over
    it and returning the maximum chunk score."""
    try:
        if not code or len(code.strip()) == 0:
            return 0.0
        tokens = tokenizer.encode(code, add_special_tokens=False, max_length=2000, truncation=True)
        if len(tokens) <= 450:
            return predict_readability(code)

        chunk_size = 400
        overlap = 50
        max_score = 0.0
        for start in range(0, len(tokens), chunk_size - overlap):
            end = min(start + chunk_size, len(tokens))
            chunk_tokens = tokens[start:end]
            try:
                chunk_code = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
                if chunk_code.strip():
                    score = predict_readability(chunk_code)
                    max_score = max(max_score, score)
            except Exception as e:
                print(f"Chunk processing error: {e}")
                continue
            if end >= len(tokens):
                break
        return max_score
    except Exception as e:
        print(f"Chunking error: {e}")
        return 0.0

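# With chunk_size=400 and overlap=50 the window advances 350 tokens per step,
# so a 1000-token input is scored over token ranges [0, 400), [350, 750), and
# [700, 1000), and the best-scoring window wins.
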
def predict_readability(code):
    """Score a single snippet, returning a readability value in [0, 1]."""
    try:
        if not code or len(code.strip()) == 0:
            return 0.0
        # Pad/truncate to a length proportional to the snippet size,
        # clamped to [128, 512] tokens.
        dynamic_length = min(max(len(code.split()) * 2, 128), 512)
        inputs = tokenizer(
            code,
            truncation=True,
            padding='max_length',
            max_length=dynamic_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            if device.type == 'cuda':
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
        amplified_logits = 3.0 * outputs.logits
        score = torch.sigmoid(amplified_logits).cpu().item()
        return round(max(0.0, min(1.0, score)), 4)
    except Exception as e:
        print(f"Single prediction error: {e}")
        return 0.0

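# Note: the 3.0x logit amplification used above and in the batch variant below
# sharpens the sigmoid, presumably to spread mid-range outputs toward the
# extremes: a raw logit of 0.0 still maps to 0.50, but 0.5 maps to
# sigmoid(1.5) ~= 0.82 and 1.0 to sigmoid(3.0) ~= 0.95.
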
def predict_readability_batch(codes):
    """Score a batch of snippets in one forward pass."""
    try:
        if not codes or len(codes) == 0:
            return []
        # Replace empty snippets with a placeholder so tokenization never fails.
        filtered_codes = [code if code and code.strip() else "# empty" for code in codes]
        # Size the padded batch to the longest snippet, clamped to [128, 512].
        max_len = max([len(code.split()) * 2 for code in filtered_codes if code])
        dynamic_length = min(max(max_len, 128), 512)
        inputs = tokenizer(
            filtered_codes,
            truncation=True,
            padding='max_length',
            max_length=dynamic_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            if device.type == 'cuda':
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
        amplified_logits = 3.0 * outputs.logits
        scores = torch.sigmoid(amplified_logits).cpu().numpy()
        return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
    except Exception as e:
        print(f"Batch prediction error: {e}")
        return [0.0] * len(codes)

@app.route('/health')
def health_check():
    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": str(device) if device else "unknown"
    })

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)
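
# Example request against a locally running instance (port 7860 as configured
# above); the exact score depends on the loaded checkpoint:
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"codes": ["def add(a, b):\n    return a + b"]}'
#
# Response shape: {"results": [{"score": 0.xxxx}]}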