Kareem94 commited on
Commit
94752fb
·
verified ·
1 Parent(s): 621fa3a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +5 -10
main.py CHANGED
@@ -261,11 +261,8 @@ def predict_readability(code):
261
  else:
262
  outputs = model(**inputs)
263
 
264
- if hasattr(outputs, 'logits'):
265
- score = torch.sigmoid(outputs.logits).cpu().item()
266
- else:
267
- score = torch.sigmoid(outputs[0]).cpu().item()
268
-
269
  return round(max(0.0, min(1.0, score)), 4)
270
 
271
  except Exception as e:
@@ -299,11 +296,9 @@ def predict_readability_batch(codes):
299
  else:
300
  outputs = model(**inputs)
301
 
302
- if hasattr(outputs, 'logits'):
303
- scores = torch.sigmoid(outputs.logits).cpu().numpy()
304
- else:
305
- scores = torch.sigmoid(outputs[0]).cpu().numpy()
306
-
307
  return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
308
 
309
  except Exception as e:
 
261
  else:
262
  outputs = model(**inputs)
263
 
264
+ amplified_logits = 4.0 * outputs.logits
265
+ score = torch.sigmoid(amplified_logits).cpu().item()
 
 
 
266
  return round(max(0.0, min(1.0, score)), 4)
267
 
268
  except Exception as e:
 
296
  else:
297
  outputs = model(**inputs)
298
 
299
+ amplified_logits = 4.0 * outputs.logits
300
+ scores = torch.sigmoid(amplified_logits).cpu().numpy()
301
+
 
 
302
  return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
303
 
304
  except Exception as e: