ElPremOoO committed on
Commit
133584b
·
verified ·
1 Parent(s): 7aa3527

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -39
main.py CHANGED
@@ -4,12 +4,6 @@ from transformers import RobertaTokenizer
4
  import os
5
  from transformers import RobertaForSequenceClassification
6
  import torch.serialization
7
- import torch
8
- from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
9
- from torch.utils.data import Dataset
10
- import pandas as pd
11
- from sklearn.model_selection import train_test_split
12
- import numpy as np
13
  # Initialize Flask app
14
  app = Flask(__name__)
15
 
@@ -31,39 +25,38 @@ def home():
31
  # @app.route("/predict", methods=["POST"])
32
  @app.route("/predict")
33
  def predict():
34
- print("Received code:", request.get_json()["code"])
35
- code = request.get_json()["code"]
36
- # Load saved weights and config
37
- checkpoint = torch.load("codebert_readability_scorer.pth")
38
- config = RobertaConfig.from_dict(checkpoint['config'])
39
-
40
- # Rebuild the model with correct architecture
41
- model = RobertaForSequenceClassification(config)
42
- model.load_state_dict(checkpoint['model_state_dict'])
43
- model.eval()
44
-
45
- # Load tokenizer
46
- tokenizer = RobertaTokenizer.from_pretrained('./tokenizer')
47
-
48
- # Prepare input
49
- inputs = tokenizer(
50
- code,
51
- truncation=True,
52
- padding='max_length',
53
- max_length=512,
54
- return_tensors='pt'
55
- )
56
-
57
- # Make prediction
58
- with torch.no_grad():
59
- outputs = model(**inputs)
60
-
61
- score = torch.sigmoid(outputs.logits).item()
62
- return score
63
-
64
-
65
-
66
-
67
  # Run the Flask app
68
  if __name__ == "__main__":
69
  app.run(host="0.0.0.0", port=7860)
 
4
  import os
5
  from transformers import RobertaForSequenceClassification
6
  import torch.serialization
 
 
 
 
 
 
7
  # Initialize Flask app
8
  app = Flask(__name__)
9
 
 
25
  # @app.route("/predict", methods=["POST"])
26
  @app.route("/predict")
27
  def predict():
28
+ try:
29
+ # Debugging: print input code to check if the request is received correctly
30
+ print("Received code:", request.get_json()["code"])
31
+
32
+ data = request.get_json()
33
+ if "code" not in data:
34
+ return jsonify({"error": "Missing 'code' parameter"}), 400
35
+
36
+ code_input = data["code"]
37
+
38
+ # Tokenize the input code using the CodeBERT tokenizer
39
+ inputs = tokenizer(
40
+ code_input,
41
+ return_tensors='pt',
42
+ truncation=True,
43
+ padding='max_length',
44
+ max_length=512
45
+ )
46
+
47
+ # Make prediction using the model
48
+ with torch.no_grad():
49
+ outputs = model(**inputs)
50
+ prediction = outputs.logits.squeeze().item() # Extract the predicted score (single float)
51
+
52
+ print(f"Predicted score: {prediction}") # Debugging: Print prediction
53
+
54
+ return jsonify({"predicted_score": prediction})
55
+
56
+ except Exception as e:
57
+ return jsonify({"error": str(e)}), 500
58
+
59
+
 
60
  # Run the Flask app
61
  if __name__ == "__main__":
62
  app.run(host="0.0.0.0", port=7860)