Spaces:
Sleeping
Sleeping
File size: 2,068 Bytes
c0be552 4de6b67 c0be552 4de6b67 c0be552 6f4cff5 4de6b67 6f4cff5 4de6b67 6f4cff5 4de6b67 6f4cff5 4de6b67 c0be552 6f4cff5 c0be552 6f4cff5 6ab96fe c0be552 133584b 4de6b67 b3d0c67 e29c9f9 6f4cff5 4de6b67 6f4cff5 4de6b67 133584b 4de6b67 133584b 4de6b67 133584b 6f4cff5 4de6b67 133584b 6f4cff5 4de6b67 6f4cff5 4de6b67 6f4cff5 133584b 6f4cff5 c0be552 4de6b67 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig
import os
# Flask application instance; the view functions in this module are attached
# to it through the @app.route decorators.
app = Flask(__name__)
# Model loading (the tokenizer is loaded separately at startup).
def load_model(checkpoint_path="codebert_readability_scorer.pth"):
    """Rebuild the RoBERTa sequence classifier from a saved checkpoint.

    The checkpoint is expected to be a dict holding a ``'config'`` entry
    (a plain dict accepted by ``RobertaConfig.from_dict``) and a
    ``'model_state_dict'`` entry with the trained weights. All tensors are
    mapped onto the CPU while loading.

    Args:
        checkpoint_path: Path of the checkpoint file. The default is the file
            name this service has always used, resolved relative to the
            current working directory.

    Returns:
        A ``RobertaForSequenceClassification`` instance switched to eval mode.

    Raises:
        Any error ``torch.load`` raises for a missing or unreadable file,
        ``KeyError`` if either expected entry is absent, and ``RuntimeError``
        from ``load_state_dict`` when the weights do not match the
        architecture described by the config.
    """
    # NOTE(security): torch.load unpickles the file (unrestricted on older
    # torch versions where weights_only defaults to False) -- only ever point
    # this at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # Recreate the architecture from the saved config, then fill in weights.
    config = RobertaConfig.from_dict(checkpoint["config"])
    model = RobertaForSequenceClassification(config)
    model.load_state_dict(checkpoint["model_state_dict"])

    model.eval()  # turn off dropout etc. for inference
    return model
# Load components once, at import time.
# Both names are bound up front so that a failed load leaves them as None
# instead of undefined; code that uses them can then test `model is None`
# rather than hitting a NameError.
tokenizer = None
model = None
try:
    tokenizer = RobertaTokenizer.from_pretrained("./tokenizer_readability")
    model = load_model()
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    # The error is reported but not re-raised, so the module finishes
    # importing and the Flask app can still start; scoring will not work
    # until the load problem is fixed.
    print(f"Error loading model: {str(e)}")
@app.route("/")
def home():
    """Respond with the full URL of the request that reached this route."""
    requested_url = request.url
    return requested_url
@app.route("/predict", methods=["GET", "POST"])
def predict():
    """Score the readability of a code snippet sent as a JSON body.

    The request body must be a JSON object with a string ``"code"`` field.
    It is parsed regardless of the request's Content-Type. Both GET (the
    original behaviour) and POST are accepted.

    Returns:
        * 200 with ``{"readability_score": <sigmoid of the model logit,
          rounded to 4 places>, "processed_code": <the input, cut to 500
          characters plus "..." when longer>}``;
        * 400 when the body is not a JSON object with a string ``"code"``;
        * 503 when startup loading left the model or tokenizer as ``None``;
        * 500 with the exception text for any other failure.
    """
    try:
        # Startup loading may have failed. If the globals are None, say so
        # explicitly; if they were never bound at all, the NameError falls
        # through to the generic 500 handler below.
        if tokenizer is None or model is None:
            return jsonify({"error": "Model is not loaded"}), 503

        # force=True: ignore the Content-Type header; silent=True: return
        # None for an unparseable body instead of raising.
        data = request.get_json(force=True, silent=True)
        # Lists, strings and numbers are valid JSON but would make the
        # membership test / indexing below raise, surfacing as a 500.
        if not isinstance(data, dict) or "code" not in data:
            return jsonify({"error": f"Missing 'code' parameter. data: {data}"}), 400
        code = data["code"]
        if not isinstance(code, str):
            return jsonify({"error": "'code' parameter must be a string"}), 400

        inputs = tokenizer(
            code,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = model(**inputs)
        # .item() only works for a single logit, i.e. presumably a
        # num_labels == 1 head -- TODO confirm against the saved config.
        score = torch.sigmoid(outputs.logits).item()

        return jsonify({
            "readability_score": round(score, 4),
            "processed_code": (code[:500] + "...") if len(code) > 500 else code,
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Start Flask's built-in server on every network interface, port 7860.
    app.run(port=7860, host="0.0.0.0")