kovacsvi commited on
Commit
3a6eb20
·
1 Parent(s): 0a394ee

config jit issue

Browse files
interfaces/emotion.py CHANGED
@@ -7,7 +7,7 @@ from transformers import AutoModelForSequenceClassification
7
  from transformers import AutoTokenizer
8
  from huggingface_hub import HfApi
9
 
10
- from label_dicts import MANIFESTO_LABEL_NAMES
11
 
12
  from .utils import is_disk_full, release_model
13
 
@@ -54,7 +54,7 @@ def predict(text, model_id, tokenizer_id):
54
  release_model(model, model_id)
55
 
56
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
57
- output_pred = {model.config.id2label[i]: probs[i] for i in np.argsort(probs)[::-1]}
58
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
59
  return output_pred, output_info
60
 
 
7
  from transformers import AutoTokenizer
8
  from huggingface_hub import HfApi
9
 
10
+ from label_dicts import EMOTION6_LABEL_NAMES
11
 
12
  from .utils import is_disk_full, release_model
13
 
 
54
  release_model(model, model_id)
55
 
56
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
57
+ output_pred = {EMOTION6_LABEL_NAMES[i]: probs[i] for i in np.argsort(probs)[::-1]}
58
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
59
  return output_pred, output_info
60
 
interfaces/manifesto.py CHANGED
@@ -7,7 +7,7 @@ from transformers import AutoModelForSequenceClassification
7
  from transformers import AutoTokenizer
8
  from huggingface_hub import HfApi
9
 
10
- from label_dicts import MANIFESTO_LABEL_NAMES
11
 
12
  from .utils import is_disk_full, release_model
13
 
@@ -53,7 +53,7 @@ def predict(text, model_id, tokenizer_id):
53
  release_model(model, model_id)
54
 
55
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
56
- output_pred = {f"[{model.config.id2label[i]}] {MANIFESTO_LABEL_NAMES[int(model.config.id2label[i])]}": probs[i] for i in np.argsort(probs)[::-1]}
57
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
58
  return output_pred, output_info
59
 
 
7
  from transformers import AutoTokenizer
8
  from huggingface_hub import HfApi
9
 
10
+ from label_dicts import MANIFESTO_LABEL_NAMES, MANIFESTO_NUM_DICT
11
 
12
  from .utils import is_disk_full, release_model
13
 
 
53
  release_model(model, model_id)
54
 
55
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
56
+ output_pred = {f"[{model.config.id2label[i]}] {MANIFESTO_LABEL_NAMES[int(MANIFESTO_NUM_DICT[i])]}": probs[i] for i in np.argsort(probs)[::-1]}
57
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
58
  return output_pred, output_info
59
 
label_dicts.py CHANGED
@@ -549,6 +549,66 @@ CAP_MIN_MEDIA_NUM_DICT = {0: 100,
549
  217: 30,
550
  218: 31,
551
  219: 99}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
  CAP_MIN_LABEL_NAMES = {
554
  # 1. Macroeconomics
@@ -905,6 +965,16 @@ ONTOLISST_LABEL_NAMES = {
905
  14: 'Administration',
906
  15: 'COVID19'
907
  }
 
 
 
 
 
 
 
 
 
 
908
  EMOTION9_LABEL_NAMES = {
909
  0: "Anger",
910
  1: "Fear",
 
549
  217: 30,
550
  218: 31,
551
  219: 99}
552
+
553
+ MANIFESTO_NUM_DICT = {
554
+ 0: 0,
555
+ 1: 101,
556
+ 2: 102,
557
+ 3: 103,
558
+ 4: 104,
559
+ 5: 105,
560
+ 6: 106,
561
+ 7: 107,
562
+ 8: 108,
563
+ 9: 109,
564
+ 10: 110,
565
+ 11: 201,
566
+ 12: 202,
567
+ 13: 203,
568
+ 14: 204,
569
+ 15: 301,
570
+ 16: 302,
571
+ 17: 303,
572
+ 18: 304,
573
+ 19: 305,
574
+ 20: 401,
575
+ 21: 402,
576
+ 22: 403,
577
+ 23: 404,
578
+ 24: 405,
579
+ 25: 406,
580
+ 26: 407,
581
+ 27: 408,
582
+ 28: 409,
583
+ 29: 410,
584
+ 30: 411,
585
+ 31: 412,
586
+ 32: 413,
587
+ 33: 414,
588
+ 34: 415,
589
+ 35: 416,
590
+ 36: 501,
591
+ 37: 502,
592
+ 38: 503,
593
+ 39: 504,
594
+ 40: 505,
595
+ 41: 506,
596
+ 42: 507,
597
+ 43: 601,
598
+ 44: 602,
599
+ 45: 603,
600
+ 46: 604,
601
+ 47: 605,
602
+ 48: 606,
603
+ 49: 607,
604
+ 50: 608,
605
+ 51: 701,
606
+ 52: 702,
607
+ 53: 703,
608
+ 54: 704,
609
+ 55: 705,
610
+ 56: 706
611
+ }
612
 
613
  CAP_MIN_LABEL_NAMES = {
614
  # 1. Macroeconomics
 
965
  14: 'Administration',
966
  15: 'COVID19'
967
  }
968
+
969
+ EMOTION6_LABEL_NAMES = {
970
+ 0: "Anger",
971
+ 1: "Fear",
972
+ 2: "Disgust",
973
+ 3: "Sadness",
974
+ 4: "Joy",
975
+ 5: "None of Them"
976
+ }
977
+
978
  EMOTION9_LABEL_NAMES = {
979
  0: "Anger",
980
  1: "Fear",