from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import os
import gc
from functools import lru_cache
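
# Flask service that scores source-code readability with a fine-tuned
# CodeBERT (RoBERTa) regression head; scores are clamped to the range [0.0, 1.0].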
app = Flask(__name__)
model = None
tokenizer = None
device = None


def setup_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')


def load_tokenizer():
    # Prefer the locally fine-tuned tokenizer; fall back to the base
    # CodeBERT tokenizer if it cannot be loaded.
    try:
        tokenizer = RobertaTokenizer.from_pretrained('./tokenizer_readability')
        tokenizer.model_max_length = 512
        return tokenizer
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        try:
            return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
        except Exception as e2:
            print(f"Fallback tokenizer failed: {e2}")
            return None


def load_model():
    global device
    device = setup_device()
    print(f"Using device: {device}")
    try:
        checkpoint = torch.load("codebert_readability_scorer.pth", map_location=device)
        # Rebuild the architecture from the config stored in the checkpoint
        # when available; otherwise start from the base CodeBERT model with
        # a single regression label.
        if 'config' in checkpoint:
            from transformers import RobertaConfig
            config = RobertaConfig.from_dict(checkpoint['config'])
            model = RobertaForSequenceClassification(config)
        else:
            model = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base',
                num_labels=1
            )
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()
        if device.type == 'cuda':
            model.half()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


def cleanup_gpu_memory():
    if device and device.type == 'cuda':
        torch.cuda.empty_cache()
    gc.collect()


# Load the tokenizer and model once at startup so they are shared by all requests.
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    if tokenizer:
        print("Tokenizer loaded successfully!")
    else:
        print("Failed to load tokenizer!")
    print("Loading model...")
    model = load_model()
    if model:
        print("Model loaded successfully!")
    else:
        print("Failed to load model!")
except Exception as e:
    print(f"Error during initialization: {str(e)}")
    tokenizer = None
    model = None


@app.route("/", methods=['GET'])
def home():
    return jsonify({
        "message": "CodeBERT Readability Evaluator API",
        "status": "Model loaded" if model is not None else "Model not loaded",
        "device": str(device) if device else "unknown",
        "endpoints": {
            "/predict": "POST with JSON body containing 'codes' array"
        }
    })


@app.route("/predict", methods=['POST'])
def predict_batch():
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500
        data = request.get_json()
        if not data or 'codes' not in data:
            return jsonify({"error": "Missing 'codes' field in JSON body"}), 400
        codes = data['codes']
        if not isinstance(codes, list) or len(codes) == 0:
            return jsonify({"error": "'codes' must be a non-empty array"}), 400
        if len(codes) > 100:
            return jsonify({"error": "Too many codes. Maximum 100 allowed."}), 400

        # Normalize input: trim whitespace, replace empty snippets with a
        # placeholder, and reject non-strings and oversized snippets.
        validated_codes = []
        for i, code in enumerate(codes):
            if not isinstance(code, str):
                return jsonify({"error": f"Code at index {i} must be a string"}), 400
            if len(code.strip()) == 0:
                validated_codes.append("# empty code")
            elif len(code) > 50000:
                return jsonify({"error": f"Code at index {i} too long. Maximum 50000 characters."}), 400
            else:
                validated_codes.append(code.strip())

        if len(validated_codes) == 1:
            score = predict_readability_with_chunking(validated_codes[0])
            cleanup_gpu_memory()
            return jsonify({"results": [{"score": score}]})

        batch_size = min(len(validated_codes), 16)
        results = []
        try:
            for i in range(0, len(validated_codes), batch_size):
                batch = validated_codes[i:i+batch_size]
                # Route long snippets (> 450 tokens) through the chunking
                # path and score the short ones together in a single batch.
                long_codes = []
                short_codes = []
                long_indices = []
                short_indices = []
                for idx, code in enumerate(batch):
                    try:
                        tokens = tokenizer.encode(code, add_special_tokens=False, max_length=1000, truncation=True)
                        if len(tokens) > 450:
                            long_codes.append(code)
                            long_indices.append(i + idx)
                        else:
                            short_codes.append(code)
                            short_indices.append(i + idx)
                    except Exception as e:
                        print(f"Tokenization error for code {i + idx}: {e}")
                        short_codes.append(code)
                        short_indices.append(i + idx)

                batch_scores = [0.0] * len(batch)
                if short_codes:
                    try:
                        short_scores = predict_readability_batch(short_codes)
                        for j, score in enumerate(short_scores):
                            local_idx = short_indices[j] - i
                            batch_scores[local_idx] = score
                    except Exception as e:
                        print(f"Batch prediction error: {e}")
                        for j in range(len(short_codes)):
                            local_idx = short_indices[j] - i
                            batch_scores[local_idx] = 0.0
                for j, code in enumerate(long_codes):
                    try:
                        score = predict_readability_with_chunking(code)
                        local_idx = long_indices[j] - i
                        batch_scores[local_idx] = score
                    except Exception as e:
                        print(f"Chunking prediction error: {e}")
                        local_idx = long_indices[j] - i
                        batch_scores[local_idx] = 0.0

                for score in batch_scores:
                    results.append({"score": score})
                cleanup_gpu_memory()
        except Exception:
            cleanup_gpu_memory()
            raise
        return jsonify({"results": results})
    except Exception as e:
        cleanup_gpu_memory()
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500


def predict_readability_with_chunking(code):
    try:
        if not code or len(code.strip()) == 0:
            return 0.0
        tokens = tokenizer.encode(code, add_special_tokens=False, max_length=2000, truncation=True)
        if len(tokens) <= 450:
            return predict_readability(code)
        # Slide a 400-token window with 50 tokens of overlap over the code
        # and keep the best-scoring chunk.
        chunk_size = 400
        overlap = 50
        max_score = 0.0
        for start in range(0, len(tokens), chunk_size - overlap):
            end = min(start + chunk_size, len(tokens))
            chunk_tokens = tokens[start:end]
            try:
                chunk_code = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
                if chunk_code.strip():
                    score = predict_readability(chunk_code)
                    max_score = max(max_score, score)
            except Exception as e:
                print(f"Chunk processing error: {e}")
                continue
            if end >= len(tokens):
                break
        return max_score
    except Exception as e:
        print(f"Chunking error: {e}")
        return 0.0


def predict_readability(code):
    try:
        if not code or len(code.strip()) == 0:
            return 0.0
        # Pad to a length proportional to the snippet size (128-512 tokens)
        # instead of always padding to the model maximum.
        dynamic_length = min(max(len(code.split()) * 2, 128), 512)
        inputs = tokenizer(
            code,
            truncation=True,
            padding='max_length',
            max_length=dynamic_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            if device.type == 'cuda':
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
        # Amplify the regression logit before the sigmoid to spread scores
        # away from 0.5, then clamp to [0, 1].
        amplified_logits = 3.0 * outputs.logits
        score = torch.sigmoid(amplified_logits).cpu().item()
        return round(max(0.0, min(1.0, score)), 4)
    except Exception as e:
        print(f"Single prediction error: {e}")
        return 0.0


def predict_readability_batch(codes):
    try:
        if not codes or len(codes) == 0:
            return []
        filtered_codes = [code if code and code.strip() else "# empty" for code in codes]
        max_len = max([len(code.split()) * 2 for code in filtered_codes if code])
        dynamic_length = min(max(max_len, 128), 512)
        inputs = tokenizer(
            filtered_codes,
            truncation=True,
            padding='max_length',
            max_length=dynamic_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            if device.type == 'cuda':
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
        amplified_logits = 3.0 * outputs.logits
        scores = torch.sigmoid(amplified_logits).cpu().numpy()
        return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
    except Exception as e:
        print(f"Batch prediction error: {e}")
        return [0.0] * len(codes)


@app.route("/health", methods=['GET'])
def health_check():
    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": str(device) if device else "unknown"
    })


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)