Kareem94 commited on
Commit
9681739
·
verified ·
1 Parent(s): 0e1e673

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -4
main.py CHANGED
@@ -261,9 +261,9 @@ def predict_vulnerability(code):
261
  else:
262
  outputs = model(**inputs)
263
 
264
- amplified_logits = 3.0 * outputs.logits
265
  score = torch.sigmoid(amplified_logits).cpu().item()
266
- return round(max(0.0, min(1.0, score)), 4)
267
 
268
  except Exception as e:
269
  print(f"Single prediction error: {e}")
@@ -296,10 +296,10 @@ def predict_vulnerability_batch(codes):
296
  else:
297
  outputs = model(**inputs)
298
 
299
- amplified_logits = 3.0 * outputs.logits
300
  scores = torch.sigmoid(amplified_logits).cpu().numpy()
301
 
302
- return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
303
 
304
  except Exception as e:
305
  print(f"Batch prediction error: {e}")
 
261
  else:
262
  outputs = model(**inputs)
263
 
264
+ amplified_logits = 5.0 * outputs.logits
265
  score = torch.sigmoid(amplified_logits).cpu().item()
266
+ return 1.0 - round(max(0.0, min(1.0, score)), 4)
267
 
268
  except Exception as e:
269
  print(f"Single prediction error: {e}")
 
296
  else:
297
  outputs = model(**inputs)
298
 
299
+ amplified_logits = 5.0 * outputs.logits
300
  scores = torch.sigmoid(amplified_logits).cpu().numpy()
301
 
302
+ return 1.0 - [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
303
 
304
  except Exception as e:
305
  print(f"Batch prediction error: {e}")