ElPremOoO committed on
Commit
6f4cff5
·
verified ·
1 Parent(s): e29c9f9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -8
main.py CHANGED
@@ -5,30 +5,34 @@ import os
5
 
6
  app = Flask(__name__)
7
 
 
8
  # Load model and tokenizer
9
  def load_model():
10
  # Load saved config and weights
11
  checkpoint = torch.load("codebert_readability_scorer.pth", map_location=torch.device('cpu'))
12
  config = RobertaConfig.from_dict(checkpoint['config'])
13
-
14
  # Initialize model with loaded config
15
  model = RobertaForSequenceClassification(config)
16
  model.load_state_dict(checkpoint['model_state_dict'])
17
  model.eval()
18
  return model
19
 
 
20
  # Load components
21
  try:
22
- tokenizer = RobertaTokenizer.from_pretrained("./tokenizer")
23
  model = load_model()
24
  print("Model and tokenizer loaded successfully!")
25
  except Exception as e:
26
  print(f"Error loading model: {str(e)}")
27
 
 
28
  @app.route("/")
29
  def home():
30
  return request.url
31
 
 
32
  @app.route("/predict")
33
  def predict():
34
  try:
@@ -39,9 +43,9 @@ def predict():
39
  data = request.get_json(force=True, silent=True)
40
  if not data or "code" not in data:
41
  return jsonify({"error": f"Missing 'code' parameter. data: {data}"}), 400
42
-
43
  code = data["code"]
44
-
45
  # Tokenize input
46
  inputs = tokenizer(
47
  code,
@@ -50,21 +54,23 @@ def predict():
50
  max_length=512,
51
  return_tensors='pt'
52
  )
53
-
 
54
  # Make prediction
55
  with torch.no_grad():
56
  outputs = model(**inputs)
57
-
58
  # Apply sigmoid and format score
59
  score = torch.sigmoid(outputs.logits).item()
60
-
61
  return jsonify({
62
  "readability_score": round(score, 4),
63
  "processed_code": code[:500] + "..." if len(code) > 500 else code
64
  })
65
-
66
  except Exception as e:
67
  return jsonify({"error": str(e)}), 500
68
 
 
69
  if __name__ == "__main__":
70
  app.run(host="0.0.0.0", port=7860)
 
5
 
6
  # Flask application object; the @app.route decorators further down register their views on it.
  app = Flask(__name__)
7
 
8
+
9
  # Load model and tokenizer
10
  def load_model():
11
  # Load saved config and weights
12
  checkpoint = torch.load("codebert_readability_scorer.pth", map_location=torch.device('cpu'))
13
  config = RobertaConfig.from_dict(checkpoint['config'])
14
+
15
  # Initialize model with loaded config
16
  model = RobertaForSequenceClassification(config)
17
  model.load_state_dict(checkpoint['model_state_dict'])
18
  model.eval()
19
  return model
20
 
21
+
22
  # Load components
23
  try:
24
+ tokenizer = RobertaTokenizer.from_pretrained("./tokenizer_readability")
25
  model = load_model()
26
  print("Model and tokenizer loaded successfully!")
27
  except Exception as e:
28
  print(f"Error loading model: {str(e)}")
29
 
30
+
31
  @app.route("/")
32
  def home():
33
  return request.url
34
 
35
+
36
  @app.route("/predict")
37
  def predict():
38
  try:
 
43
  data = request.get_json(force=True, silent=True)
44
  if not data or "code" not in data:
45
  return jsonify({"error": f"Missing 'code' parameter. data: {data}"}), 400
46
+
47
  code = data["code"]
48
+
49
  # Tokenize input
50
  inputs = tokenizer(
51
  code,
 
54
  max_length=512,
55
  return_tensors='pt'
56
  )
57
+ print("here")
58
+
59
  # Make prediction
60
  with torch.no_grad():
61
  outputs = model(**inputs)
62
+
63
  # Apply sigmoid and format score
64
  score = torch.sigmoid(outputs.logits).item()
65
+
66
  return jsonify({
67
  "readability_score": round(score, 4),
68
  "processed_code": code[:500] + "..." if len(code) > 500 else code
69
  })
70
+
71
  except Exception as e:
72
  return jsonify({"error": str(e)}), 500
73
 
74
+
75
  if __name__ == "__main__":
76
  app.run(host="0.0.0.0", port=7860)