ElPremOoO commited on
Commit
1b77b3b
·
verified ·
1 Parent(s): dbef60c

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +65 -0
main.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import torch
3
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig
4
+ import os
5
+
6
+ app = Flask(__name__)
7
+
8
+ # Load model and tokenizer
9
+ def load_model():
10
+ # Load saved config and weights
11
+ checkpoint = torch.load("codebert_readability_scorer.pth", map_location=torch.device('cpu'))
12
+ config = RobertaConfig.from_dict(checkpoint['config'])
13
+
14
+ # Initialize model with loaded config
15
+ model = RobertaForSequenceClassification(config)
16
+ model.load_state_dict(checkpoint['model_state_dict'])
17
+ model.eval()
18
+ return model
19
+
20
+ # Load components
21
+ try:
22
+ tokenizer = RobertaTokenizer.from_pretrained("./tokenizer_readability")
23
+ model = load_model()
24
+ print("Model and tokenizer loaded successfully!")
25
+ except Exception as e:
26
+ print(f"Error loading model: {str(e)}")
27
+
28
+ @app.route("/")
29
+ def home():
30
+ return request.url
31
+
32
+ @app.route("/predict")
33
+ def predict():
34
+ try:
35
+ # Get code from URL parameter
36
+ code = request.args.get("code")
37
+ if not code:
38
+ return jsonify({"error": "Missing 'code' URL parameter"}), 400
39
+
40
+ # Tokenize input
41
+ inputs = tokenizer(
42
+ code,
43
+ truncation=True,
44
+ padding='max_length',
45
+ max_length=512,
46
+ return_tensors='pt'
47
+ )
48
+
49
+ # Make prediction
50
+ with torch.no_grad():
51
+ outputs = model(**inputs)
52
+
53
+ # Apply sigmoid and format score
54
+ score = torch.sigmoid(outputs.logits).item()
55
+
56
+ return jsonify({
57
+ "readability_score": round(score, 4),
58
+ "processed_code": code[:500] + "..." if len(code) > 500 else code
59
+ })
60
+
61
+ except Exception as e:
62
+ return jsonify({"error": str(e)}), 500
63
+
64
+ if __name__ == "__main__":
65
+ app.run(host="0.0.0.0", port=7860)