kovacsvi committed
Commit 70be539 · Parent: 6796ab1

minor - media (hierarchical)

Files changed (1):
  1. interfaces/cap_minor_media.py +66 -27
interfaces/cap_minor_media.py CHANGED
@@ -7,8 +7,10 @@ import pandas as pd
 from transformers import AutoModelForSequenceClassification
 from transformers import AutoTokenizer
 from huggingface_hub import HfApi
+from collections import defaultdict
 
-from label_dicts import CAP_MEDIA_NUM_DICT, CAP_MEDIA_LABEL_NAMES
+from label_dicts import (CAP_MEDIA_NUM_DICT, CAP_MEDIA_LABEL_NAMES,
+                         CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES)
 
 from .utils import is_disk_full
 
@@ -31,46 +33,83 @@ def check_huggingface_path(checkpoint_path: str):
     return False
 
 def build_huggingface_path(language: str, domain: str):
-    return "poltextlab/xlm-roberta-large-pooled-cap-media"
+    return ("poltextlab/xlm-roberta-large-pooled-cap-media", "poltextlab/xlm-roberta-large-pooled-cap-minor-v3")
 
-def predict(text, model_id, tokenizer_id):
+def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
     device = torch.device("cpu")
-    model = AutoModelForSequenceClassification.from_pretrained(model_id, low_cpu_mem_usage=True, device_map="auto", offload_folder="offload", token=HF_TOKEN)
+
+    # Load major and minor models + tokenizer
+    major_model = AutoModelForSequenceClassification.from_pretrained(
+        major_model_id,
+        low_cpu_mem_usage=True,
+        device_map="auto",
+        offload_folder="offload",
+        token=HF_TOKEN
+    ).to(device)
+
+    minor_model = AutoModelForSequenceClassification.from_pretrained(
+        minor_model_id,
+        low_cpu_mem_usage=True,
+        device_map="auto",
+        offload_folder="offload",
+        token=HF_TOKEN
+    ).to(device)
+
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
 
-    inputs = tokenizer(text,
-                       max_length=256,
-                       truncation=True,
-                       padding="do_not_pad",
-                       return_tensors="pt").to(device)
-    model.eval()
+    # Tokenize input
+    inputs = tokenizer(text, max_length=256, truncation=True, padding="do_not_pad", return_tensors="pt").to(device)
 
+    # Predict major topic
+    major_model.eval()
     with torch.no_grad():
-        logits = model(**inputs).logits
+        major_logits = major_model(**inputs).logits
+        major_probs = F.softmax(major_logits, dim=-1)
+    major_probs_np = major_probs.cpu().numpy().flatten()
+    top_major_index = int(np.argmax(major_probs_np))
+    top_major_id = major_index_to_id[top_major_index]
+
+    # Default: show major topic predictions
+    output_pred = {
+        f"[{major_index_to_id[i]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[i]]}": float(major_probs_np[i])
+        for i in np.argsort(major_probs_np)[::-1]
+    }
+
+    # If eligible for minor prediction
+    if top_major_id in major_to_minor_map:
+        valid_minor_ids = major_to_minor_map[top_major_id]
+        minor_model.eval()
+        with torch.no_grad():
+            minor_logits = minor_model(**inputs).logits
+            minor_probs = F.softmax(minor_logits, dim=-1)
+
+        # Restrict to valid minor codes
+        valid_indices = [minor_id_to_index[mid] for mid in valid_minor_ids if mid in minor_id_to_index]
+        filtered_probs = {minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices}
+        output_pred = {
+            f"[{k}] {CAP_MIN_LABEL_NAMES[k]}": v
+            for k, v in sorted(filtered_probs.items(), key=lambda item: item[1], reverse=True)
+        }
+
+    output_info = f'<p style="text-align: center; display: block">Prediction used <a href="https://huggingface.co/{major_model_id}">{major_model_id}</a> and <a href="https://huggingface.co/{minor_model_id}">{minor_model_id}</a>.</p>'
 
-    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
-    output_pred = {f"[{CAP_MEDIA_NUM_DICT[i]}] {CAP_MEDIA_LABEL_NAMES[CAP_MEDIA_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
-    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
     return output_pred, output_info
 
 def predict_cap(text, language, domain):
     domain = domains[domain]
-    model_id = build_huggingface_path(language, domain)
+    major_model_id, minor_model_id = build_huggingface_path(language, domain)
     tokenizer_id = "xlm-roberta-large"
 
     if is_disk_full():
         os.system('rm -rf /data/models*')
         os.system('rm -r ~/.cache/huggingface/hub')
 
-    return predict(text, model_id, tokenizer_id)
-
-#demo = gr.Interface(
-#    title="CAP Media Topics Babel Demo",
-#    fn=predict_cap,
-#    inputs=[gr.Textbox(lines=6, label="Input"),
-#            gr.Dropdown(languages, label="Language"),
-#            gr.Dropdown(domains.keys(), label="Domain")],
-#    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
-with gr.Blocks() as demo:
-    gr.Markdown("""# CAP Media Topics Babel Demo
-    ## 🚧 Coming Soon 🚧""")
+    return predict(text, major_model_id, minor_model_id, tokenizer_id)
+
+demo = gr.Interface(
+    title="CAP Media Topics Babel Demo",
+    fn=predict_cap,
+    inputs=[gr.Textbox(lines=6, label="Input"),
+            gr.Dropdown(languages, label="Language"),
+            gr.Dropdown(domains.keys(), label="Domain")],
+    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
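Two reading notes on the new predict: it calls F.softmax, so an import torch.nn.functional as F alias presumably exists elsewhere in the module, and it relies on four lookup tables (major_index_to_id, minor_id_to_index, minor_index_to_id, major_to_minor_map) that this hunk never defines; the newly added defaultdict import suggests they are built at module level. A minimal sketch of how such tables could be derived from the imported label dictionaries; reusing the NUM_DICTs as index-to-code maps and the // 100 rule linking a CAP minor code to its major topic are assumptions, not something the diff confirms:

import torch.nn.functional as F  # assumed alias behind the F.softmax calls
from collections import defaultdict

from label_dicts import CAP_MEDIA_NUM_DICT, CAP_MIN_NUM_DICT

# Model output index -> CAP code (and the reverse) for both hierarchy levels.
major_index_to_id = CAP_MEDIA_NUM_DICT
minor_index_to_id = CAP_MIN_NUM_DICT
minor_id_to_index = {code: idx for idx, code in CAP_MIN_NUM_DICT.items()}

# Group minor codes under their major topic; the // 100 convention
# (e.g. minor 105 belongs to major 1) is a guess at the CAP numbering.
major_to_minor_map = defaultdict(list)
for minor_code in CAP_MIN_NUM_DICT.values():
    major_to_minor_map[minor_code // 100].append(minor_code)
major_to_minor_map = dict(major_to_minor_map)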
 
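One behavioral note on the restriction step: the minor-level softmax is computed over the full minor label set and only afterwards filtered to the codes valid under the winning major topic, so the confidences in output_pred no longer sum to 1. If normalized values were wanted for the gr.Label display, a renormalization over the surviving codes would suffice; a sketch, not part of the commit:

# `filtered_probs` maps each valid minor CAP code to its raw softmax mass,
# exactly as built in the diff; rescale so the displayed values sum to 1.
total = sum(filtered_probs.values())
filtered_probs = {code: p / total for code, p in filtered_probs.items()}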
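is_disk_full is imported from .utils and its implementation sits outside this diff; judging by the cleanup commands that follow it, it guards against the Space's disk filling up with cached model weights. A hypothetical reconstruction of such a check, with the mount point and the 95% threshold both being illustrative guesses:

import shutil

def is_disk_full(path: str = "/data", threshold: float = 0.95) -> bool:
    # Illustrative stand-in for the real .utils implementation, which may
    # check different paths or limits: True once `path` is ~95% used.
    usage = shutil.disk_usage(path)
    return usage.used / usage.total >= threshold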
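Finally, a hedged example of exercising the re-enabled gr.Interface entry point end to end; the languages list and domains dict live elsewhere in the module, so the literal choices below are placeholders rather than known values:

# Hypothetical call: "English" and "media" stand in for whatever the
# Space's `languages` and `domains` globals actually contain.
pred, info = predict_cap("The central bank raised interest rates today.",
                         "English", "media")
print(max(pred, key=pred.get))  # highest-confidence CAP label

demo.launch()  # serve the Gradio demo defined above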