File size: 1,730 Bytes
c0be552
 
4de6b67
c0be552
 
4de6b67
c0be552
4de6b67
 
 
 
 
6f4cff5
4de6b67
 
 
 
 
 
 
 
6f4cff5
4de6b67
 
 
 
c0be552
 
 
 
 
5783855
c0be552
133584b
5783855
 
 
 
 
6f4cff5
4de6b67
133584b
4de6b67
133584b
 
4de6b67
 
133584b
6f4cff5
4de6b67
133584b
 
6f4cff5
4de6b67
 
6f4cff5
5783855
6f4cff5
133584b
 
 
c0be552
4de6b67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig
import os

app = Flask(__name__)

# Load model and tokenizer
def load_model(checkpoint_path="codebert_readability_scorer.pth"):
    """Rebuild the sequence-classification model from a saved checkpoint.

    The checkpoint is read with ``torch.load`` and indexed with two keys:
    ``'config'`` (passed to ``RobertaConfig.from_dict``) and
    ``'model_state_dict'`` (passed to ``load_state_dict``).

    Args:
        checkpoint_path: Path of the checkpoint file. Defaults to the file
            name that was previously hard-coded, so calling ``load_model()``
            with no arguments behaves as before.

    Returns:
        A ``RobertaForSequenceClassification`` built from the stored config,
        with the stored weights loaded and switched to evaluation mode.

    Raises:
        KeyError: If the loaded checkpoint mapping has no ``'config'`` or
            ``'model_state_dict'`` entry.
        Exception: Whatever ``torch.load`` or ``load_state_dict`` raise for a
            missing, unreadable or mismatching checkpoint is propagated.
    """
    # NOTE: torch.load unpickles the file; only point this at checkpoints
    # from a trusted source.
    # map_location remaps the stored tensors onto the CPU.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    config = RobertaConfig.from_dict(checkpoint['config'])

    # Build the architecture from the config alone, then fill in the weights.
    model = RobertaForSequenceClassification(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

# Load components
# Bind both names before trying to load so they always exist at module level:
# if loading fails, whichever was not loaded stays None instead of being
# undefined.
tokenizer = None
model = None
try:
    tokenizer = RobertaTokenizer.from_pretrained("./tokenizer_readability")
    model = load_model()
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    # Best-effort startup: report the problem and let the module finish
    # importing rather than aborting.
    print(f"Error loading model: {str(e)}")

@app.route("/")
def home():
    """Answer a request for the root path with the URL that was requested."""
    requested_url = request.url
    return requested_url

@app.route("/predict", methods=["POST"])
def predict():
    """Score the code snippet sent in the JSON body of a POST request.

    The body must be a JSON object with a string ``"code"`` entry. The text
    is tokenized (truncated and padded to 512 tokens), run through the model
    without gradient tracking, and the sigmoid of the resulting logit is
    returned.

    Returns:
        On success, a JSON number: the sigmoid score rounded to 4 decimals.
        With status 400, a JSON ``{"error": ...}`` object when the body is
        not a JSON object, has no ``"code"`` key, or ``"code"`` is not a
        string.
        With status 500, a JSON ``{"error": ...}`` object carrying the
        exception text for any other failure.
    """
    try:
        # silent=True makes get_json return None for a body that is not
        # valid JSON (or not sent as JSON) instead of raising. Without it,
        # that client mistake would be caught by the generic handler below
        # and reported as a 500.
        data = request.get_json(silent=True)
        # Require a JSON object: a bare string or list would otherwise pass
        # the membership test and then fail on the lookup.
        if not isinstance(data, dict) or "code" not in data:
            return jsonify({"error": "Missing 'code' in request body"}), 400

        code = data["code"]
        if not isinstance(code, str):
            # A non-string value is a client error, not a server failure.
            return jsonify({"error": "'code' must be a string"}), 400

        # Tokenize input
        inputs = tokenizer(
            code,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )

        # Make prediction
        with torch.no_grad():
            outputs = model(**inputs)

        # Apply sigmoid and format score; .item() requires the logits tensor
        # to hold exactly one value.
        score = torch.sigmoid(outputs.logits).item()

        return jsonify(round(score, 4))

    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Run only when executed as a script: serve on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")