File size: 1,810 Bytes
c0be552
 
4de6b67
c0be552
 
4de6b67
c0be552
4de6b67
 
 
 
 
6f4cff5
4de6b67
 
 
 
 
 
 
 
6f4cff5
4de6b67
 
 
 
c0be552
 
 
 
 
6ab96fe
c0be552
133584b
4d61b06
 
 
 
6f4cff5
4de6b67
133584b
4de6b67
133584b
 
4de6b67
 
133584b
6f4cff5
4de6b67
133584b
 
6f4cff5
4de6b67
 
6f4cff5
4de6b67
368d5f1
 
4de6b67
6f4cff5
133584b
 
 
c0be552
4de6b67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig
import os

# WSGI application object; every @app.route handler in this module registers on it.
app = Flask(import_name=__name__)

def load_model(checkpoint_path="codebert_readability_scorer.pth"):
    """Rebuild the readability classifier from a saved checkpoint.

    Args:
        checkpoint_path: File written by ``torch.save`` holding a dict with a
            ``'config'`` entry (a plain dict passed to
            ``RobertaConfig.from_dict``) and a ``'model_state_dict'`` entry
            with the weights. Defaults to the previously hard-coded file
            name; relative paths resolve against the current working
            directory.

    Returns:
        A ``RobertaForSequenceClassification`` on CPU, switched to eval mode
        (dropout disabled) for inference.
    """
    # NOTE: torch.load unpickles the file, so only point this at trusted
    # checkpoints. Tensors are mapped to CPU so a GPU-saved checkpoint loads too.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    config = RobertaConfig.from_dict(checkpoint["config"])

    # Instantiate the architecture from the saved config (random init), then
    # overwrite its parameters with the saved weights (strict key matching).
    model = RobertaForSequenceClassification(config)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model

# Load the tokenizer and model once at import time. A failure is reported but
# deliberately not fatal, so the web server still starts. Both names are
# pre-bound to None so a failed or partial load leaves them defined (and
# checkable) instead of unbound.
tokenizer = None
model = None
try:
    tokenizer = RobertaTokenizer.from_pretrained("./tokenizer_readability")
    model = load_model()
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    import traceback

    print(f"Error loading model: {str(e)}")
    # str(e) alone is often uninformative (e.g. a KeyError prints just the
    # key), so also emit the full traceback to stderr for diagnosis.
    traceback.print_exc()

@app.route("/")
def home():
    """Echo the full URL of the current request back to the client.

    The URL contains client-controlled path and query text, which may include
    characters such as ``<``. It is therefore served as ``text/plain`` rather
    than Flask's default ``text/html``, so a crafted link cannot inject markup
    or script into the page (reflected XSS).
    """
    return request.url, 200, {"Content-Type": "text/plain; charset=utf-8"}

@app.route("/predict", methods=["GET", "POST"])
def predict():
    """Return the model's readability score for a code snippet as JSON.

    Input:
        The snippet is taken from the ``code`` URL query parameter. This is
        checked first, so existing GET clients behave exactly as before.
        For POST requests without that parameter, the snippet is read from a
        JSON body (``{"code": "..."}``) or, failing that, a ``code`` form
        field. This lets clients send snippets too long to fit in a URL.

    Responses:
        200: ``{"score": float, "given_code": str}``. ``score`` is the
            sigmoid of the model's logits rounded to 4 decimal places.
            ``given_code`` echoes the input, cut to its first 500 characters
            plus ``"..."`` when longer.
        400: ``{"error": ...}`` when no non-empty snippet was supplied.
        500: ``{"error": str(exc)}`` for any other failure, including the
            tokenizer/model not having been loaded successfully at startup.
    """
    try:
        code = request.args.get("code")
        if not code and request.method == "POST":
            payload = request.get_json(silent=True)  # None if body isn't JSON
            if isinstance(payload, dict) and isinstance(payload.get("code"), str):
                code = payload["code"]
            else:
                code = request.form.get("code")
        if not code:
            return jsonify({"error": "Missing 'code' URL parameter"}), 400

        # Inputs longer than 512 tokens are truncated, so only the beginning
        # of a long snippet contributes to the score.
        inputs = tokenizer(
            code,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = model(**inputs)

        # .item() needs a single-element tensor; this presumably relies on the
        # checkpoint config having num_labels == 1 — TODO confirm.
        score = torch.sigmoid(outputs.logits).item()

        given_code = (code[:500] + "...") if len(code) > 500 else code
        return jsonify({"score": round(score, 4), "given_code": given_code})

    except Exception as e:
        # NOTE(review): this exposes internal exception text to clients.
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Bind on all interfaces so the server is reachable from outside a
    # container. The port defaults to 7860 (the original hard-coded value)
    # and can be overridden via the PORT environment variable.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))