asmashayea commited on
Commit
681d8ae
·
1 Parent(s): f4612ba
Files changed (1) hide show
  1. inference.py +3 -1
inference.py CHANGED
@@ -68,7 +68,9 @@ def infer_araberta(text):
68
  predicted_ids = outputs['logits'][0].cpu().tolist()
69
 
70
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
71
- predicted_labels = [model.config.id2label.get(p, 'O') for p in predicted_ids]
 
 
72
 
73
  clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
74
  clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]
 
68
  predicted_ids = outputs['logits'][0].cpu().tolist()
69
 
70
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
71
+ # predicted_labels = [model.config.id2label.get(p, 'O') for p in predicted_ids]
72
+ predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]
73
+
74
 
75
  clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
76
  clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]