Sina1138 committed on
Commit
6fe7180
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. .gitattributes +36 -0
  2. .gitignore +2 -0
  3. .gitmodules +3 -0
  4. README.md +13 -0
  5. glimpse-ui/.gitignore +362 -0
  6. glimpse-ui/LICENSE +21 -0
  7. glimpse-ui/alternative_polarity/deberta/deberta_v3_base_polarity.py +95 -0
  8. glimpse-ui/alternative_polarity/deberta/deberta_v3_base_polarity_train.py +98 -0
  9. glimpse-ui/alternative_polarity/manual_polarity_tester.py +65 -0
  10. glimpse-ui/alternative_polarity/scideberta/scideberta_full_polarity.py +79 -0
  11. glimpse-ui/alternative_polarity/scideberta/scideberta_full_polarity_train.py +108 -0
  12. glimpse-ui/alternative_topic/debetra/deberta_topic.py +92 -0
  13. glimpse-ui/alternative_topic/debetra/deberta_topic_train.py +80 -0
  14. glimpse-ui/alternative_topic/scideberta/scideberta_topic.py +92 -0
  15. glimpse-ui/alternative_topic/scideberta/scideberta_topic_train.py +80 -0
  16. glimpse-ui/data/ExtractDISAPEREData.py +106 -0
  17. glimpse-ui/glimpse/.gitignore +203 -0
  18. glimpse-ui/glimpse/Readme.md +69 -0
  19. glimpse-ui/glimpse/examples/RSA Sum tests.ipynb +189 -0
  20. glimpse-ui/glimpse/examples/reviews/reviews_app.py +274 -0
  21. glimpse-ui/glimpse/examples/reviews/reviews_latex_generation.py +272 -0
  22. glimpse-ui/glimpse/glimpse/baselines/generate_llm_summaries.py +112 -0
  23. glimpse-ui/glimpse/glimpse/baselines/sumy_baselines.py +129 -0
  24. glimpse-ui/glimpse/glimpse/data_loading/Glimpse_tokenizer.py +74 -0
  25. glimpse-ui/glimpse/glimpse/data_loading/data_processing.py +15 -0
  26. glimpse-ui/glimpse/glimpse/data_loading/generate_abstractive_candidates.py +230 -0
  27. glimpse-ui/glimpse/glimpse/data_loading/generate_extractive_candidates.py +129 -0
  28. glimpse-ui/glimpse/glimpse/evaluate/Evaluate informativeness.ipynb +258 -0
  29. glimpse-ui/glimpse/glimpse/evaluate/evaluate_bartbert_metrics.py +110 -0
  30. glimpse-ui/glimpse/glimpse/evaluate/evaluate_common_metrics_samples.py +122 -0
  31. glimpse-ui/glimpse/glimpse/evaluate/evaluate_seahorse_metrics_samples.py +150 -0
  32. glimpse-ui/glimpse/glimpse/src/beam_rsa_decoding.py +207 -0
  33. glimpse-ui/glimpse/glimpse/src/compute_rsa.py +137 -0
  34. glimpse-ui/glimpse/glimpse/src/rsa_merge_into_single.py +52 -0
  35. glimpse-ui/glimpse/glimpse/src/rsa_reranking.py +127 -0
  36. glimpse-ui/glimpse/mds/Single summaries expes.ipynb +587 -0
  37. glimpse-ui/glimpse/mds/Template summaries.ipynb +531 -0
  38. glimpse-ui/glimpse/mds/discriminative_classification.py +113 -0
  39. glimpse-ui/glimpse/pyproject.toml +21 -0
  40. glimpse-ui/glimpse/requirements +10 -0
  41. glimpse-ui/glimpse/rsasumm/__init__.py +0 -0
  42. glimpse-ui/glimpse/rsasumm/beam_search.py +430 -0
  43. glimpse-ui/glimpse/rsasumm/rsa_reranker.py +280 -0
  44. glimpse-ui/glimpse/scripts/abstractive.sh +37 -0
  45. glimpse-ui/glimpse/scripts/extractive.sh +31 -0
  46. glimpse-ui/glimpse_pk_csv_converter.py +92 -0
  47. glimpse-ui/interface/Demo.py +800 -0
  48. glimpse-ui/scibert/scibert_polarity/final_model/config.json +35 -0
  49. glimpse-ui/scibert/scibert_polarity/final_model/model.safetensors +3 -0
  50. glimpse-ui/scibert/scibert_polarity/final_model/special_tokens_map.json +7 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ glimpse-ui/data/preprocessed_scored_reviews.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ ./scibert/*
+ ./alternative_*/
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "glimpse-ui"]
+ 	path = glimpse-ui
+ 	url = https://github.com/Sina1138/glimpse-ui.git
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: GlimpSys
+ emoji: 📊
+ colorFrom: yellow
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.35.0
+ app_file: glimpse-ui/interface/Demo.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
glimpse-ui/.gitignore ADDED
@@ -0,0 +1,362 @@
+ # project ignores
+ glimpse/
+ data/DISAPERE-main/
+ *checkpoints/
+ .gradio/
+ test.py
+ data/*
+ final_model/
+ alternative_polarity/llama/
+ !data/ExtractDISAPEREData.py
+ !data/preprocessed_scored_reviews.csv
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+ data/DISAPERE_test.py
glimpse-ui/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Sina Salmannia
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
glimpse-ui/alternative_polarity/deberta/deberta_v3_base_polarity.py ADDED
@@ -0,0 +1,95 @@
+ import pandas as pd
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from pathlib import Path
+ import nltk
+ from tqdm import tqdm
+ import sys, os.path
+ from torch.nn import functional as F
+
+ nltk.download('punkt')
+
+ BASE_DIR = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+ from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
+
+ # === CONFIGURATION ===
+
+ MODEL_DIR = BASE_DIR / "alternative_polarity" / "deberta" / "deberta_v3_base_polarity_final_model"
+ DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
+ OUTPUT_DIR = BASE_DIR / "data" / "polarity_scored"
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+ # === Load model and tokenizer ===
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
+ model.eval()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # === Tokenize like GLIMPSE ===
+ # def tokenize_sentences(text: str) -> list:
+ #     # same tokenization as in the original glimpse code
+ #     text = text.replace('-----', '\n')
+ #     sentences = nltk.sent_tokenize(text)
+ #     sentences = [sentence for sentence in sentences if sentence != ""]
+ #     return sentences
+
+
+ # def predict_polarity(sentences):
+ #     inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+ #     with torch.no_grad():
+ #         outputs = model(**inputs)
+ #     logits = outputs.logits
+ #     temperature = 2.7  # Adjust temperature for scaling logits
+ #     probs = F.softmax(logits / temperature, dim=-1)
+ #     # Get probability of positive class
+ #     polarity_scores = probs[:, 1]
+ #     # Rescale: 0 → -1 (very negative), 1 → +1 (very positive)
+ #     polarity_scores = (polarity_scores * 2) - 1
+ #     return polarity_scores.cpu().tolist()
+
+ def predict_polarity(sentences):
+     inputs = tokenizer(
+         sentences,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=512
+     ).to(device)
+
+     with torch.no_grad():
+         logits = model(**inputs).logits  # (batch, 2)
+     logit_diff = logits[:, 1] - logits[:, 0]
+     alpha = 2.1  # tweak
+     scores = torch.tanh(alpha * logit_diff)  # in [-1, 1]
+     return scores.cpu().tolist()
+
+
+ def find_polarity(start_year=2017, end_year=2021):
+     for year in range(start_year, end_year + 1):
+         print(f"Processing {year}...")
+         input_path = DATA_DIR / f"all_reviews_{year}.csv"
+         output_path = OUTPUT_DIR / f"polarity_scored_reviews_{year}.csv"
+
+         df = pd.read_csv(input_path)
+
+         all_rows = []
+         for _, row in tqdm(df.iterrows(), total=len(df)):
+             review_id = row["id"]
+             text = row["text"]
+             sentences = glimpse_tokenizer(text)
+             if not sentences:
+                 continue
+             labels = predict_polarity(sentences)
+             for sentence, polarity in zip(sentences, labels):
+                 all_rows.append({"id": review_id, "sentence": sentence, "polarity": polarity})
+
+         output_df = pd.DataFrame(all_rows)
+         output_df.to_csv(output_path, index=False)
+         print(f"Saved polarity-scored data to {output_path}")
+
+
+ if __name__ == "__main__":
+     find_polarity()
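
The scoring above maps the positive-minus-negative logit gap through `tanh(alpha * diff)`. A standalone sketch of that scaling, with invented logits and illustrative `alpha` values (no model or checkpoint needed):

```python
# Standalone sketch of the tanh-based polarity scaling used above.
# No model is loaded; the logits are made-up examples.
import torch

def scale_polarity(logits: torch.Tensor, alpha: float = 2.1) -> torch.Tensor:
    """Map (batch, 2) class logits to a polarity score in [-1, 1]."""
    logit_diff = logits[:, 1] - logits[:, 0]  # positive-minus-negative margin
    return torch.tanh(alpha * logit_diff)

logits = torch.tensor([[2.0, -1.0],   # confidently negative
                       [0.1, 0.2],    # near-neutral
                       [-1.5, 2.5]])  # confidently positive
for a in (0.5, 2.1, 5.0):
    print(a, scale_polarity(logits, alpha=a).tolist())
# Larger alpha saturates faster, pushing borderline sentences toward ±1.
```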
glimpse-ui/alternative_polarity/deberta/deberta_v3_base_polarity_train.py ADDED
@@ -0,0 +1,98 @@
+ import pandas as pd
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+
+ from transformers import Trainer
+
+
+ # Load data
+ train_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_train.csv")
+ dev_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_dev.csv")
+ test_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_test.csv")
+
+ # Convert to HuggingFace Datasets
+ train_ds = Dataset.from_pandas(train_df)
+ dev_ds = Dataset.from_pandas(dev_df)
+ test_ds = Dataset.from_pandas(test_df)
+
+ # Tokenize
+ model_name = "microsoft/deberta-v3-base"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ def tokenize(batch):
+     return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=512)
+
+ train_ds = train_ds.map(tokenize, batched=True)
+ dev_ds = dev_ds.map(tokenize, batched=True)
+ test_ds = test_ds.map(tokenize, batched=True)
+
+ # Set format for PyTorch
+ train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ dev_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+
+ # Load model
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
+
+ # Compute class weights
+ label_counts = train_df['label'].value_counts()
+ total_samples = len(train_df)
+ class_weights = torch.tensor([total_samples / (len(label_counts) * count) for count in label_counts.sort_index().values])
+ class_weights = class_weights.to(dtype=torch.float32)
+ print("Class weights:", class_weights)
+
+ class WeightedTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+         labels = inputs.pop("labels")
+         outputs = model(**inputs)
+         logits = outputs.logits
+         weights = class_weights.to(logits.device)
+         loss = F.cross_entropy(logits, labels, weight=weights)
+         return (loss, outputs) if return_outputs else loss
+
+
+ # Metrics
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     preds = np.argmax(logits, axis=1)
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
+     acc = accuracy_score(labels, preds)
+     return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
+
+ # Training arguments
+ args = TrainingArguments(
+     output_dir="./alternative_polarity/deberta/checkpoints",
+     eval_strategy="epoch",
+     save_strategy="epoch",
+     learning_rate=2e-5,
+     per_device_train_batch_size=4,
+     per_device_eval_batch_size=8,
+     num_train_epochs=4,
+     weight_decay=0.01,
+     load_best_model_at_end=True,
+     metric_for_best_model="f1"
+ )
+
+ # Trainer
+ trainer = WeightedTrainer(
+     model=model,
+     args=args,
+     train_dataset=train_ds,
+     eval_dataset=dev_ds,
+     tokenizer=tokenizer,
+     compute_metrics=compute_metrics
+ )
+
+ # Train
+ trainer.train()
+
+ # Evaluate on test
+ results = trainer.evaluate(test_ds)
+ print("Test results:", results)
+
+ # Save the model and tokenizer
+ model.save_pretrained("./alternative_polarity/deberta/deberta_v3_base_polarity_final_model")
+ tokenizer.save_pretrained("./alternative_polarity/deberta/deberta_v3_base_polarity_final_model")
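
The class weights used by `WeightedTrainer` above are plain inverse-frequency weights. A toy reproduction of the same formula with an invented label distribution, to make the computation concrete:

```python
# Toy reproduction of the class-weight computation above; the counts are invented.
import pandas as pd
import torch

train_df = pd.DataFrame({"label": [0] * 120 + [1] * 500 + [2] * 80})  # hypothetical imbalance
label_counts = train_df["label"].value_counts().sort_index()
total = len(train_df)
weights = torch.tensor(
    [total / (len(label_counts) * c) for c in label_counts.values],
    dtype=torch.float32,
)
print(weights)  # rare classes (0 and 2) receive weights > 1, the majority class < 1
```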
glimpse-ui/alternative_polarity/manual_polarity_tester.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ from pathlib import Path
5
+ import sys, os
6
+
7
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
8
+ from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
9
+
10
+ # === CONFIGURATION ===
11
+ BASE_DIR = Path(__file__).resolve().parent.parent
12
+ MODEL_DIR = BASE_DIR / "alternative_polarity" / "deberta" / "deberta_v3_large_polarity_final_model"
13
+ # MODEL_DIR = BASE_DIR / "alternative_polarity" / "llama" / "final_model"
14
+ # MODEL_DIR = BASE_DIR / "alternative_polarity" / "scideberta" / "scideberta_full_polarity_final_model"
15
+
16
+ # --> Best so far: deberta_v3 (passes "pros" test)
17
+
18
+
19
+ # === Load model and tokenizer ===
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
21
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
22
+ model.eval()
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ model.to(device)
25
+
26
+ # === Prediction function with confidence ===
27
+ def predict_polarity(sentences):
28
+ inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
29
+ inputs = {k: v.to(device) for k, v in inputs.items()}
30
+
31
+ with torch.no_grad():
32
+ outputs = model(**inputs)
33
+ probs = F.softmax(outputs.logits, dim=1)
34
+ confidences, preds = torch.max(probs, dim=1)
35
+
36
+ results = []
37
+ for sentence, pred, conf, prob in zip(sentences, preds, confidences, probs):
38
+ results.append({
39
+ "sentence": sentence,
40
+ "label": "Positive" if pred.item() == 1 else "Negative",
41
+ "confidence": conf.item(),
42
+ "probs": prob.cpu().numpy().tolist()
43
+ })
44
+ return results
45
+
46
+ # === Example: test a multi-sentence peer review ===
47
+ if __name__ == "__main__":
48
+ # Replace this with your review
49
+ full_review = """
50
+ Pros:
51
+ Con: The experiments lack comparison with prior work.
52
+ The authors clearly explain their methodology, which is a strong point.
53
+ """
54
+
55
+ # Use glimpse tokenizer to split into sentences
56
+ sentences = glimpse_tokenizer(full_review)
57
+
58
+ # Run polarity prediction
59
+ results = predict_polarity(sentences)
60
+
61
+ # Display results
62
+ for res in results:
63
+ print(f"\nSentence: {res['sentence']}")
64
+ print(f" → Prediction: {res['label']} (Confidence: {res['confidence']:.3f})")
65
+ print(f" Probabilities: [Negative: {res['probs'][0]:.3f}, Positive: {res['probs'][1]:.3f}]")
glimpse-ui/alternative_polarity/scideberta/scideberta_full_polarity.py ADDED
@@ -0,0 +1,79 @@
+ import pandas as pd
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from pathlib import Path
+ import nltk
+ from tqdm import tqdm
+ import sys, os.path
+ from torch.nn import functional as F
+
+ nltk.download('punkt')
+
+ BASE_DIR = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+ from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
+
+ # === CONFIGURATION ===
+
+ MODEL_DIR = BASE_DIR / "alternative_polarity" / "scideberta" / "scideberta_full_polarity_final_model"
+ DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
+ OUTPUT_DIR = BASE_DIR / "data" / "polarity_scored"
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+ # === Load model and tokenizer ===
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
+ model.eval()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # === Tokenize like GLIMPSE ===
+ # def tokenize_sentences(text: str) -> list:
+ #     # same tokenization as in the original glimpse code
+ #     text = text.replace('-----', '\n')
+ #     sentences = nltk.sent_tokenize(text)
+ #     sentences = [sentence for sentence in sentences if sentence != ""]
+ #     return sentences
+
+
+ def predict_polarity(sentences):
+     inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits = outputs.logits
+     temperature = 2.7
+     probs = F.softmax(logits / temperature, dim=-1)
+     # Get probability of positive class
+     polarity_scores = probs[:, 1]
+     # Rescale: 0 → -1 (very negative), 1 → +1 (very positive)
+     polarity_scores = (polarity_scores * 2) - 1
+     return polarity_scores.cpu().tolist()
+
+
+ def find_polarity(start_year=2017, end_year=2021):
+     for year in range(start_year, end_year + 1):
+         print(f"Processing {year}...")
+         input_path = DATA_DIR / f"all_reviews_{year}.csv"
+         output_path = OUTPUT_DIR / f"polarity_scored_reviews_{year}.csv"
+
+         df = pd.read_csv(input_path)
+
+         all_rows = []
+         for _, row in tqdm(df.iterrows(), total=len(df)):
+             review_id = row["id"]
+             text = row["text"]
+             sentences = glimpse_tokenizer(text)
+             if not sentences:
+                 continue
+             labels = predict_polarity(sentences)
+             for sentence, polarity in zip(sentences, labels):
+                 all_rows.append({"id": review_id, "sentence": sentence, "polarity": polarity})
+
+         output_df = pd.DataFrame(all_rows)
+         output_df.to_csv(output_path, index=False)
+         print(f"Saved polarity-scored data to {output_path}")
+
+
+ if __name__ == "__main__":
+     find_polarity()
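
Unlike the tanh variant earlier in this commit, this script scores polarity as a temperature-softened softmax probability rescaled to [-1, 1]. A standalone illustration of what the `temperature = 2.7` setting does to the scores (the logits are invented):

```python
# Standalone illustration of the temperature-scaled softmax rescaling above.
import torch
import torch.nn.functional as F

logits = torch.tensor([[3.0, -2.0], [0.2, 0.3]])  # made-up (batch, 2) logits
for temperature in (1.0, 2.7):
    probs = F.softmax(logits / temperature, dim=-1)
    scores = probs[:, 1] * 2 - 1  # map P(positive) from [0, 1] to [-1, 1]
    print(temperature, scores.tolist())
# Higher temperature flattens the distribution, pulling scores toward 0.
```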
glimpse-ui/alternative_polarity/scideberta/scideberta_full_polarity_train.py ADDED
@@ -0,0 +1,108 @@
+ import pandas as pd
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+
+ from transformers import Trainer
+
+ class WeightedTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+         labels = inputs.pop("labels")
+         outputs = model(**inputs)
+         logits = outputs.logits
+         weights = class_weights.to(logits.device)
+         loss = F.cross_entropy(logits, labels, weight=weights)
+         return (loss, outputs) if return_outputs else loss
+
+
+
+ # Load data
+ train_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_train.csv")
+ dev_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_dev.csv")
+ test_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_test.csv")
+
+ # Convert to HuggingFace Datasets
+ train_ds = Dataset.from_pandas(train_df)
+ dev_ds = Dataset.from_pandas(dev_df)
+ test_ds = Dataset.from_pandas(test_df)
+
+ model_name = "KISTI-AI/Scideberta-full"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def tokenize(batch):
+     return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=512)
+
+ train_ds = train_ds.map(tokenize, batched=True)
+ dev_ds = dev_ds.map(tokenize, batched=True)
+ test_ds = test_ds.map(tokenize, batched=True)
+
+ # Set format for PyTorch
+ train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ dev_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+
+ # Load model
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
+
+ # Compute class weights
+ label_counts = train_df['label'].value_counts()
+ total_samples = len(train_df)
+ class_weights = torch.tensor([total_samples / (len(label_counts) * count) for count in label_counts.sort_index().values])
+ class_weights = class_weights.to(dtype=torch.float32)
+ print("Class weights:", class_weights)
+
+ class WeightedTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+         labels = inputs.pop("labels")
+         outputs = model(**inputs)
+         logits = outputs.logits
+         weights = class_weights.to(logits.device)
+         loss = F.cross_entropy(logits, labels, weight=weights)
+         return (loss, outputs) if return_outputs else loss
+
+
+ # Metrics
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     preds = np.argmax(logits, axis=1)
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
+     acc = accuracy_score(labels, preds)
+     return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
+
+ # Training arguments
+ args = TrainingArguments(
+     output_dir="./alternative_polarity/scideberta/checkpoints",
+     eval_strategy="epoch",
+     save_strategy="epoch",
+     learning_rate=2e-5,
+     per_device_train_batch_size=4,
+     per_device_eval_batch_size=8,
+     num_train_epochs=4,
+     weight_decay=0.01,
+     load_best_model_at_end=True,
+     metric_for_best_model="f1"
+ )
+
+ # Trainer
+ trainer = WeightedTrainer(
+     model=model,
+     args=args,
+     train_dataset=train_ds,
+     eval_dataset=dev_ds,
+     tokenizer=tokenizer,
+     compute_metrics=compute_metrics
+ )
+
+ # Train
+ trainer.train()
+
+ # Evaluate on test
+ results = trainer.evaluate(test_ds)
+ print("Test results:", results)
+
+ # Save the model and tokenizer
+ model.save_pretrained("./alternative_polarity/scideberta/scideberta_full_polarity_final_model")
+ tokenizer.save_pretrained("./alternative_polarity/scideberta/scideberta_full_polarity_final_model")
glimpse-ui/alternative_topic/debetra/deberta_topic.py ADDED
@@ -0,0 +1,92 @@
+ import pandas as pd
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from pathlib import Path
+ import nltk
+ from tqdm import tqdm
+ import sys, os.path
+
+ nltk.download('punkt')
+
+ BASE_DIR = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+ from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
+
+ # === CONFIGURATION ===
+
+ MODEL_DIR = BASE_DIR / "alternative_topic" / "deberta" / "final_model"
+ DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
+ OUTPUT_DIR = BASE_DIR / "data" / "topic_scored"
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+ # === Load model and tokenizer ===
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
+ model.eval()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # === Tokenize like GLIMPSE ===
+ # def tokenize_sentences(text: str) -> list:
+ #     # same tokenization as in the original glimpse code
+ #     text = text.replace('-----', '\n')
+ #     sentences = nltk.sent_tokenize(text)
+ #     sentences = [sentence for sentence in sentences if sentence != ""]
+ #     return sentences
+
+
+ # === Label map (optional: for human-readable output) ===
+ id2label = {
+     # 0: "Evaluative",
+     # 1: "Structuring",
+     # 2: "Request",
+     # 3: "Fact",
+     # 4: "Social",
+     # 5: "Other",
+     0: "Substance",
+     1: "Clarity",
+     2: "Soundness/Correctness",
+     3: "Originality",
+     4: "Motivation/Impact",
+     5: "Meaningful Comparison",
+     6: "Replicability",
+     7: "NONE"  # This is used for sentences that do not match any specific topic
+ }
+
+ def predict_topic(sentences):
+     inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     predictions = torch.argmax(outputs.logits, dim=1).cpu().tolist()
+     # Convert predictions to human-readable labels
+     predictions = [id2label[pred] for pred in predictions]
+     return predictions
+
+
+ def find_topic(start_year=2017, end_year=2021):
+     for year in range(start_year, end_year + 1):
+         print(f"Processing {year}...")
+         input_path = DATA_DIR / f"all_reviews_{year}.csv"
+         output_path = OUTPUT_DIR / f"topic_scored_reviews_{year}.csv"
+
+         df = pd.read_csv(input_path)
+
+         all_rows = []
+         for _, row in tqdm(df.iterrows(), total=len(df)):
+             review_id = row["id"]
+             text = row["text"]
+             sentences = glimpse_tokenizer(text)
+             if not sentences:
+                 continue
+             labels = predict_topic(sentences)
+             for sentence, topic in zip(sentences, labels):
+                 all_rows.append({"id": review_id, "sentence": sentence, "topic": topic})
+
+         output_df = pd.DataFrame(all_rows)
+         output_df.to_csv(output_path, index=False)
+         print(f"Saved topic-scored data to {output_path}")
+
+
+ if __name__ == "__main__":
+     find_topic()
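
`predict_topic` receives every sentence of a review as a single batch, so very long reviews create correspondingly large batches. If memory is a concern, a small chunking wrapper around the function above is enough; the chunk size of 32 below is an arbitrary choice, not something this script uses:

```python
# Optional chunking wrapper around predict_topic (defined above); the chunk
# size of 32 is an arbitrary choice, not part of the committed script.
def predict_topic_chunked(sentences, chunk_size=32):
    labels = []
    for i in range(0, len(sentences), chunk_size):
        labels.extend(predict_topic(sentences[i:i + chunk_size]))
    return labels
```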
glimpse-ui/alternative_topic/debetra/deberta_topic_train.py ADDED
@@ -0,0 +1,80 @@
+ import pandas as pd
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ import numpy as np
+
+ # Load data
+ dev_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_dev.csv")
+ train_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_train.csv")
+ test_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_test.csv")
+
+ # Convert to HuggingFace Datasets
+ train_ds = Dataset.from_pandas(train_df)
+ dev_ds = Dataset.from_pandas(dev_df)
+ test_ds = Dataset.from_pandas(test_df)
+
+ # Tokenize
+ model_name = "microsoft/deberta-v3-base"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def tokenize(batch):
+     return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=256)
+
+ train_ds = train_ds.map(tokenize, batched=True)
+ dev_ds = dev_ds.map(tokenize, batched=True)
+ test_ds = test_ds.map(tokenize, batched=True)
+
+ # Set format for PyTorch
+ train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ dev_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+
+ print(train_df['label'].value_counts().sort_index())
+
+
+ # Load model
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
+
+ # Metrics
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     preds = np.argmax(logits, axis=1)
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
+     acc = accuracy_score(labels, preds)
+     return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
+
+ # Training arguments
+ args = TrainingArguments(
+     output_dir="./alternative_topic/deberta/checkpoints",
+     eval_strategy="epoch",
+     save_strategy="epoch",
+     learning_rate=2e-5,
+     per_device_train_batch_size=8,
+     per_device_eval_batch_size=16,
+     num_train_epochs=4,
+     weight_decay=0.01,
+     load_best_model_at_end=True,
+     metric_for_best_model="f1"
+ )
+
+ # Trainer
+ trainer = Trainer(
+     model=model,
+     args=args,
+     train_dataset=train_ds,
+     eval_dataset=dev_ds,
+     tokenizer=tokenizer,
+     compute_metrics=compute_metrics
+ )
+
+ # Train
+ trainer.train()
+
+ # Evaluate on test
+ results = trainer.evaluate(test_ds)
+ print("Test results:", results)
+
+ # Save the model and tokenizer
+ model.save_pretrained("./alternative_topic/deberta/final_model")
+ tokenizer.save_pretrained("./alternative_topic/deberta/final_model")
glimpse-ui/alternative_topic/scideberta/scideberta_topic.py ADDED
@@ -0,0 +1,92 @@
+ import pandas as pd
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from pathlib import Path
+ import nltk
+ from tqdm import tqdm
+ import sys, os.path
+
+ nltk.download('punkt')
+
+ BASE_DIR = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+ from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
+
+ # === CONFIGURATION ===
+
+ MODEL_DIR = BASE_DIR / "alternative_topic" / "scideberta" / "final_model"
+ DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
+ OUTPUT_DIR = BASE_DIR / "data" / "topic_scored"
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+ # === Load model and tokenizer ===
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
+ model.eval()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # === Tokenize like GLIMPSE ===
+ # def tokenize_sentences(text: str) -> list:
+ #     # same tokenization as in the original glimpse code
+ #     text = text.replace('-----', '\n')
+ #     sentences = nltk.sent_tokenize(text)
+ #     sentences = [sentence for sentence in sentences if sentence != ""]
+ #     return sentences
+
+
+ # === Label map (optional: for human-readable output) ===
+ id2label = {
+     # 0: "Evaluative",
+     # 1: "Structuring",
+     # 2: "Request",
+     # 3: "Fact",
+     # 4: "Social",
+     # 5: "Other",
+     0: "Substance",
+     1: "Clarity",
+     2: "Soundness/Correctness",
+     3: "Originality",
+     4: "Motivation/Impact",
+     5: "Meaningful Comparison",
+     6: "Replicability",
+     7: "NONE"  # This is used for sentences that do not match any specific topic
+ }
+
+ def predict_topic(sentences):
+     inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     predictions = torch.argmax(outputs.logits, dim=1).cpu().tolist()
+     # Convert predictions to human-readable labels
+     predictions = [id2label[pred] for pred in predictions]
+     return predictions
+
+
+ def find_topic(start_year=2017, end_year=2021):
+     for year in range(start_year, end_year + 1):
+         print(f"Processing {year}...")
+         input_path = DATA_DIR / f"all_reviews_{year}.csv"
+         output_path = OUTPUT_DIR / f"topic_scored_reviews_{year}.csv"
+
+         df = pd.read_csv(input_path)
+
+         all_rows = []
+         for _, row in tqdm(df.iterrows(), total=len(df)):
+             review_id = row["id"]
+             text = row["text"]
+             sentences = glimpse_tokenizer(text)
+             if not sentences:
+                 continue
+             labels = predict_topic(sentences)
+             for sentence, topic in zip(sentences, labels):
+                 all_rows.append({"id": review_id, "sentence": sentence, "topic": topic})
+
+         output_df = pd.DataFrame(all_rows)
+         output_df.to_csv(output_path, index=False)
+         print(f"Saved topic-scored data to {output_path}")
+
+
+ if __name__ == "__main__":
+     find_topic()
glimpse-ui/alternative_topic/scideberta/scideberta_topic_train.py ADDED
@@ -0,0 +1,80 @@
+ import pandas as pd
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ import numpy as np
+
+ # Load data
+ dev_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_dev.csv")
+ train_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_train.csv")
+ test_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_test.csv")
+
+ # Convert to HuggingFace Datasets
+ train_ds = Dataset.from_pandas(train_df)
+ dev_ds = Dataset.from_pandas(dev_df)
+ test_ds = Dataset.from_pandas(test_df)
+
+ # Tokenize
+ model_name = "KISTI-AI/Scideberta-full"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def tokenize(batch):
+     return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=256)
+
+ train_ds = train_ds.map(tokenize, batched=True)
+ dev_ds = dev_ds.map(tokenize, batched=True)
+ test_ds = test_ds.map(tokenize, batched=True)
+
+ # Set format for PyTorch
+ train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ dev_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+ test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+
+ print(train_df['label'].value_counts().sort_index())
+
+
+ # Load model
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
+
+ # Metrics
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     preds = np.argmax(logits, axis=1)
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
+     acc = accuracy_score(labels, preds)
+     return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
+
+ # Training arguments
+ args = TrainingArguments(
+     output_dir="./alternative_topic/scideberta/checkpoints",
+     eval_strategy="epoch",
+     save_strategy="epoch",
+     learning_rate=2e-5,
+     per_device_train_batch_size=8,
+     per_device_eval_batch_size=16,
+     num_train_epochs=4,
+     weight_decay=0.01,
+     load_best_model_at_end=True,
+     metric_for_best_model="f1"
+ )
+
+ # Trainer
+ trainer = Trainer(
+     model=model,
+     args=args,
+     train_dataset=train_ds,
+     eval_dataset=dev_ds,
+     tokenizer=tokenizer,
+     compute_metrics=compute_metrics
+ )
+
+ # Train
+ trainer.train()
+
+ # Evaluate on test
+ results = trainer.evaluate(test_ds)
+ print("Test results:", results)
+
+ # Save the model and tokenizer
+ model.save_pretrained("./alternative_topic/scideberta/final_model")
+ tokenizer.save_pretrained("./alternative_topic/scideberta/final_model")
glimpse-ui/data/ExtractDISAPEREData.py ADDED
@@ -0,0 +1,106 @@
+ import os
+ import json
+ import pandas as pd
+ from pathlib import Path
+
+ BASE_DIR = Path(__file__).resolve().parent.parent
+ base_path = BASE_DIR / "data" / "DISAPERE-main" / "DISAPERE" / "final_dataset"
+ output_path = BASE_DIR / "data" / "DISAPERE-main" / "SELFExtractedData"
+
+ ###################################################################################
+ ###################################################################################
+
+ # EXTRACTING POLARITY SENTENCES FROM DISAPERE DATASET
+
+ # def extract_polarity_sentences(json_dir):
+ #     data = []
+ #     for filename in os.listdir(json_dir):
+ #         if filename.endswith(".json"):
+ #             with open(os.path.join(json_dir, filename), "r") as f:
+ #                 thread = json.load(f)
+ #             for sentence in thread.get("review_sentences", []):
+ #                 text = sentence.get("text", "").strip()
+ #                 polarity = sentence.get("polarity")
+ #                 if text:
+ #                     if polarity == "pol_positive":
+ #                         label = 2
+ #                     elif polarity == "pol_negative":
+ #                         label = 0
+ #                     else:
+ #                         label = 1
+ #                     data.append({"text": text, "label": label})
+ #     return pd.DataFrame(data)
+
+ # # Extract and save each split
+ # for split in ["train", "dev", "test"]:
+ #     df = extract_polarity_sentences(os.path.join(base_path, split))
+ #     out_file = os.path.join(output_path, f"disapere_polarity_{split}.csv")
+ #     df.to_csv(out_file, index=False)
+ #     print(f"{split.capitalize()} saved to {out_file}: {len(df)} samples")
+
+
+ ###################################################################################
+ ###################################################################################
+
+ # 2. EXTRACTING TOPIC SENTENCES FROM DISAPERE DATASET
+ #
+ # === Topic Label Mapping ===
+ # 1: "Structuring"
+ # 0: "Evaluative"
+ # 2: "Request"
+ # 3: "Fact"
+ # 4: "Social"
+ # 5: "Other"
+ # 6: "Substance"
+ # 7: "Clarity"
+ # 8: "Soundness/Correctness"
+ # 9: "Originality"
+ # 10: "Motivation/Impact"
+ # 11: "Meaningful Comparison"
+ # 12: "Replicability"
+
+ # Final topic classes
+ topic_classes = [
+     "asp_substance",
+     "asp_clarity",
+     "asp_soundness-correctness",
+     "asp_originality",
+     "asp_impact",
+     "asp_comparison",
+     "asp_replicability",
+     "None",  # This is used for sentences that do not match any specific topic
+     # "arg-structuring_summary"
+ ]
+
+ label_map = {label: idx for idx, label in enumerate(topic_classes)}
+
+ def extract_topic_sentences(json_dir):
+     data = []
+     for filename in os.listdir(json_dir):
+         if filename.endswith(".json"):
+             with open(os.path.join(json_dir, filename), "r") as f:
+                 thread = json.load(f)
+             for sentence in thread.get("review_sentences", []):
+                 text = sentence.get("text", "").strip()
+                 aspect = sentence.get("aspect", "")
+                 # fine_action = sentence.get("fine_review_action", "")
+
+                 # Decide label source
+                 topic = aspect if aspect in label_map else "None"
+
+                 if text and topic in label_map:
+                     label = label_map[topic]
+                     data.append({"text": text, "label": label})
+     return pd.DataFrame(data)
+
+ # Extract and save each split
+ for split in ["train", "dev", "test"]:
+     df = extract_topic_sentences(os.path.join(base_path, split))
+     out_file = os.path.join(output_path, f"disapere_topic_{split}.csv")
+     df.to_csv(out_file, index=False)
+     print(f"{split.capitalize()} saved to {out_file}: {len(df)} samples")
+
+ ###################################################################################
+ ###################################################################################
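
The aspect-to-id mapping above collapses every aspect outside `topic_classes` to the catch-all "None" class (id 7). A toy walk-through with invented aspect strings:

```python
# Toy walk-through of the aspect -> label mapping above; the inputs are invented.
topic_classes = ["asp_substance", "asp_clarity", "asp_soundness-correctness",
                 "asp_originality", "asp_impact", "asp_comparison",
                 "asp_replicability", "None"]
label_map = {label: idx for idx, label in enumerate(topic_classes)}

for aspect in ("asp_clarity", "asp_unknown_tag", ""):
    topic = aspect if aspect in label_map else "None"
    print(aspect or "<empty>", "->", label_map[topic])
# asp_clarity -> 1; anything unrecognized -> 7 (the catch-all "None" class)
```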
glimpse-ui/glimpse/.gitignore ADDED
@@ -0,0 +1,203 @@
+ # Created by .ignore support plugin (hsz.mobi)
+ ### Python template
+
+ # GLIMPSE
+ # Ignore all the data except original files
+ data/*
+ !data/
+ summaries/
+ output/
+ slurm*
+ !scripts/
+ .gradio/
+ .test/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *,cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # IPython Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # dotenv
+ .env
+
+ # virtualenv
+ venv/
+ ENV/
+
+ # Spyder project settings
+ .spyderproject
+
+ # Rope project settings
+ .ropeproject
+ ### VirtualEnv template
+ # Virtualenv
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+ [Bb]in
+ [Ii]nclude
+ [Ll]ib
+ [Ll]ib64
+ [Ll]ocal
+ [Ss]cripts
+ pyvenv.cfg
+ .venv
+ pip-selfcheck.json
+
+ ### JetBrains template
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ # User-specific stuff
+ .idea/**/workspace.xml
+ .idea/**/tasks.xml
+ .idea/**/usage.statistics.xml
+ .idea/**/dictionaries
+ .idea/**/shelf
+
+ # AWS User-specific
+ .idea/**/aws.xml
+
+ # Generated files
+ .idea/**/contentModel.xml
+
+ # Sensitive or high-churn files
+ .idea/**/dataSources/
+ .idea/**/dataSources.ids
+ .idea/**/dataSources.local.xml
+ .idea/**/sqlDataSources.xml
+ .idea/**/dynamic.xml
+ .idea/**/uiDesigner.xml
+ .idea/**/dbnavigator.xml
+
+ # Gradle
+ .idea/**/gradle.xml
+ .idea/**/libraries
+
+ # Gradle and Maven with auto-import
+ # When using Gradle or Maven with auto-import, you should exclude module files,
+ # since they will be recreated, and may cause churn. Uncomment if using
+ # auto-import.
+ # .idea/artifacts
+ # .idea/compiler.xml
+ # .idea/jarRepositories.xml
+ # .idea/modules.xml
+ # .idea/*.iml
+ # .idea/modules
+ # *.iml
+ # *.ipr
+
+ # CMake
+ cmake-build-*/
+
+ # Mongo Explorer plugin
+ .idea/**/mongoSettings.xml
+
+ # File-based project format
+ *.iws
+
+ # IntelliJ
+ out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Cursive Clojure plugin
+ .idea/replstate.xml
+
+ # SonarLint plugin
+ .idea/sonarlint/
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ # Editor-based Rest Client
+ .idea/httpRequests
+
+ # Android studio 3.1+ serialized cache file
+ .idea/caches/build_file_checksums.ser
+
+ # idea folder, uncomment if you don't need it
+ # .idea
+ share/man/man1/isympy.1
+ share/man/man1/ttx.1
+
+ # IDEs
+ .idea/
+ .vscode/
glimpse-ui/glimpse/Readme.md ADDED
@@ -0,0 +1,69 @@
1
+
2
+ This is the repository of GLIMPSE: Pragmatically Informative Multi-Document Summarization for Scholarly Reviews.
3
+ [Paper](https://arxiv.org/abs/2406.07359) | [Code](https://github.com/icannos/glimpse-mds)
4
+
5
+
6
+ ### Installation
7
+
8
+ - We use Python 3.10 and CUDA 12.1:
9
+ ``` bash
10
+ module load miniconda/3
11
+ module load cuda12
12
+ ```
13
+ - First, create a virtual environment using:
14
+ ``` bash
15
+ conda create -n glimpse python=3.10
16
+ ```
17
+ - Second, activate the environment and install PyTorch:
18
+ ``` bash
19
+ conda activate glimpse
20
+ conda install pytorch==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia
21
+ ```
22
+
23
+ - Finally, the remaining required packages can be installed from the requirements file:
24
+
25
+ ``` bash
26
+ pip install -r requirements
27
+ ```
28
+ ### Data Loading
29
+
30
+ Step 1: Start by processing the input files from the `data/` directory.
31
+
32
+ ``` bash
33
+ python glimpse/data_loading/data_processing.py
34
+ ```
35
+
36
+ ### Generating Summaries and Computing RSA Scores
37
+ Step 2: Now, generate candidate summaries and compute RSA scores for each candidate:
38
+ - for extractive candidates, use the following command:
39
+ ``` bash
40
+ sbatch scripts/extractive.sh Path_of_Your_Processed_Dataset_Step1.csv
41
+ ```
42
+ - for abstractive candidates, use either of the following commands:
43
+ - In case the last batch is incomplete, you can add padding using the `--add-padding` argument to complete it:
44
+ ``` bash
45
+ sbatch scripts/abstractive.sh Path_of_Your_Processed_Dataset_Step1.csv --add-padding
46
+ ```
47
+ - If you would rather drop the last incomplete batch, run the script without the argument:
48
+ ``` bash
49
+ sbatch scripts/abstractive.sh Path_of_Your_Processed_Dataset_Step1.csv
50
+ ```
51
+
52
+ `rsasumm/` provides a python package with an implementation of RSA incremental decoding and RSA reranking of candidates.
53
+ `mds/` provides the experiment scripts and analysis for the Multi-Document Summarization task.
54
+
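+ As a minimal sketch, the reranker can also be called directly from Python. This mirrors the example apps in `examples/reviews/`; treat the argument names and the exact return signature below as indicative rather than a stable API:
+ ``` python
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from rsasumm.rsa_reranker import RSAReranking
+
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
+ tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
+
+ reviews = ["first review ...", "second review ...", "third review ..."]
+ candidates = ["a candidate summary sentence", "another candidate sentence"]
+
+ # score every candidate against every source review, running t RSA iterations
+ reranker = RSAReranking(model, tokenizer, candidates=candidates, source_texts=reviews, device="cpu", rationality=1.0)
+ best_rsa, best_base, speaker_df, listener_df, *rest = reranker.rerank(t=2)
+ ```
+ `speaker_df` and `listener_df` are the probability tables the example Gradio app visualizes; `rest` also carries the consensuality scores.
+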
55
+
56
+ ## Citation
57
+
58
+ If you use this code, please cite the following paper:
59
+
60
+ ```
+ @misc{darrin2024glimpsepragmaticallyinformativemultidocument,
61
+ title={GLIMPSE: Pragmatically Informative Multi-Document Summarization for Scholarly Reviews},
62
+ author={Maxime Darrin and Ines Arous and Pablo Piantanida and Jackie CK Cheung},
63
+ year={2024},
64
+ eprint={2406.07359},
65
+ archivePrefix={arXiv},
66
+ primaryClass={cs.CL},
67
+ url={https://arxiv.org/abs/2406.07359},
68
+ }
69
+ ```
glimpse-ui/glimpse/examples/RSA Sum tests.ipynb ADDED
@@ -0,0 +1,189 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "initial_id",
7
+ "metadata": {
8
+ "collapsed": true,
9
+ "ExecuteTime": {
10
+ "end_time": "2024-01-12T16:31:17.690349522Z",
11
+ "start_time": "2024-01-12T16:31:15.472874479Z"
12
+ }
13
+ },
14
+ "outputs": [],
15
+ "source": [
16
+ "import torch"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "outputs": [],
23
+ "source": [
24
+ "%reload_ext autoreload\n",
25
+ "%autoreload 2"
26
+ ],
27
+ "metadata": {
28
+ "collapsed": false,
29
+ "ExecuteTime": {
30
+ "end_time": "2024-01-12T16:31:17.717430741Z",
31
+ "start_time": "2024-01-12T16:31:17.695066680Z"
32
+ }
33
+ },
34
+ "id": "ecefdad828c7daa3"
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 3,
39
+ "outputs": [
40
+ {
41
+ "name": "stderr",
42
+ "output_type": "stream",
43
+ "text": [
44
+ "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-cnn and are newly initialized: ['model.shared.weight']\n",
45
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "\n",
51
+ "from transformers import AutoTokenizer, BartForConditionalGeneration\n",
52
+ "\n",
53
+ "tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-large-cnn\")\n",
54
+ "model = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")\n"
55
+ ],
56
+ "metadata": {
57
+ "collapsed": false,
58
+ "ExecuteTime": {
59
+ "end_time": "2024-01-12T16:31:26.058437142Z",
60
+ "start_time": "2024-01-12T16:31:17.720106168Z"
61
+ }
62
+ },
63
+ "id": "8c32b182fbcac2b6"
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 4,
68
+ "outputs": [],
69
+ "source": [
70
+ "from rsasumm.beam_search import RSAContextualDecoding"
71
+ ],
72
+ "metadata": {
73
+ "collapsed": false,
74
+ "ExecuteTime": {
75
+ "end_time": "2024-01-12T16:31:26.097766981Z",
76
+ "start_time": "2024-01-12T16:31:26.056626187Z"
77
+ }
78
+ },
79
+ "id": "cb33d902fe736c25"
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 5,
84
+ "outputs": [],
85
+ "source": [
86
+ "\n",
87
+ "\n",
88
+ "texts = ['The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. I believe the authors missed Jane and al 2021. In addition, I think, there is a mistake in the math.',\n",
89
+ " 'The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. However, some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.',\n",
90
+ " 'The paper gives really interesting insights on the topic of transfer learning. It is not well presented and lack experiments. In addition, some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.'\n",
91
+ " ]\n",
92
+ "\n",
93
+ "# texts = [texts[2], texts[1], texts[0]]\n",
94
+ "\n"
95
+ ],
96
+ "metadata": {
97
+ "collapsed": false,
98
+ "ExecuteTime": {
99
+ "end_time": "2024-01-12T16:31:26.127922110Z",
100
+ "start_time": "2024-01-12T16:31:26.098805312Z"
101
+ }
102
+ },
103
+ "id": "436ef1482c361159"
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 6,
108
+ "outputs": [],
109
+ "source": [
110
+ "source_texts = tokenizer(texts, return_tensors=\"pt\", padding=True)\n",
111
+ "\n",
112
+ "rsa = RSAContextualDecoding(model, tokenizer, 'cpu')\n",
113
+ "\n"
114
+ ],
115
+ "metadata": {
116
+ "collapsed": false,
117
+ "ExecuteTime": {
118
+ "end_time": "2024-01-12T16:31:26.169520864Z",
119
+ "start_time": "2024-01-12T16:31:26.125283164Z"
120
+ }
121
+ },
122
+ "id": "84b9943cac6cd7b2"
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 7,
127
+ "outputs": [],
128
+ "source": [
129
+ "output = rsa.generate(target_id=1, source_texts_ids=source_texts.input_ids, source_text_attention_mask=source_texts.attention_mask, max_length=50, top_p=0.95, do_sample=True, rationality=8.0, temperature=1.0, process_logits_before_rsa=True)"
130
+ ],
131
+ "metadata": {
132
+ "collapsed": false,
133
+ "ExecuteTime": {
134
+ "end_time": "2024-01-12T16:32:14.857034731Z",
135
+ "start_time": "2024-01-12T16:31:26.164578792Z"
136
+ }
137
+ },
138
+ "id": "620e54a63dd2099c"
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 8,
143
+ "outputs": [
144
+ {
145
+ "data": {
146
+ "text/plain": "['Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.']"
147
+ },
148
+ "execution_count": 8,
149
+ "metadata": {},
150
+ "output_type": "execute_result"
151
+ }
152
+ ],
153
+ "source": [
154
+ "\n",
155
+ "tokenizer.batch_decode(output[0], skip_special_tokens=True)\n",
156
+ "\n"
157
+ ],
158
+ "metadata": {
159
+ "collapsed": false,
160
+ "ExecuteTime": {
161
+ "end_time": "2024-01-12T16:32:14.858531480Z",
162
+ "start_time": "2024-01-12T16:32:14.856763396Z"
163
+ }
164
+ },
165
+ "id": "fb3a5a9a8f9990ee"
166
+ }
167
+ ],
168
+ "metadata": {
169
+ "kernelspec": {
170
+ "display_name": "Python 3",
171
+ "language": "python",
172
+ "name": "python3"
173
+ },
174
+ "language_info": {
175
+ "codemirror_mode": {
176
+ "name": "ipython",
177
+ "version": 2
178
+ },
179
+ "file_extension": ".py",
180
+ "mimetype": "text/x-python",
181
+ "name": "python",
182
+ "nbconvert_exporter": "python",
183
+ "pygments_lexer": "ipython2",
184
+ "version": "2.7.6"
185
+ }
186
+ },
187
+ "nbformat": 4,
188
+ "nbformat_minor": 5
189
+ }
glimpse-ui/glimpse/examples/reviews/reviews_app.py ADDED
@@ -0,0 +1,274 @@
1
+ import math
2
+ from typing import List, Tuple
3
+
4
+ import nltk
5
+ import numpy as np
6
+ import seaborn as sns
7
+
8
+ import sys, os.path
9
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
10
+
11
+ from rsasumm.rsa_reranker import RSAReranking
12
+ import gradio as gr
13
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
14
+ import seaborn as sns
15
+ import pandas as pd
16
+ import matplotlib.pyplot as plt
17
+
18
+ MODEL = "facebook/bart-large-cnn"
19
+
20
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
21
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
22
+
23
+
24
+ latex_template = r"""
25
+ \begin{subfigure}[b]{0.48\textwidth}
26
+ \resizebox{\textwidth}{!}{
27
+ \begin{coloredbox}{darkgray}{Review 1}
28
+ [REVIEW 1]
29
+
30
+ \end{coloredbox}}
31
+ \end{subfigure}
32
+ \begin{subfigure}[b]{0.48\textwidth}
33
+ \resizebox{\textwidth}{!}{
34
+ \begin{coloredbox}{darkgray}{Review 2}
35
+ [REVIEW 2]
36
+ \end{coloredbox}}
37
+ \end{subfigure}
38
+ \begin{subfigure}[b]{0.48\textwidth}
39
+ \resizebox{\textwidth}{!}{
40
+ \begin{coloredbox}{darkgray}{Review 3}
41
+ [REVIEW 3]
42
+ \end{coloredbox}}
43
+ \end{subfigure}
44
+ """
45
+
46
+ EXAMPLES = [
47
+ "The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. I believe the authors missed Jane and al 2021. In addition, I think, there is a mistake in the math.",
48
+ "The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
49
+ "The paper gives really interesting insights on the topic of transfer learning. It is not well presented and lack experiments. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
50
+ ]
51
+
52
+
53
+ def make_colored_text_to_latex(scored_texts : List[Tuple[str, float]]):
54
+ """
55
+ Make a latex string from a list of scored texts.
56
+ """
57
+
58
+ # cast scores between 0 and 1
59
+ scores = np.array([score for _, score in scored_texts])
60
+ scores = (scores - scores.min()) / (scores.max() - scores.min())
61
+
62
+ # map scores to RGB colors using a diverging palette
63
+ cmap = sns.diverging_palette(250, 30, l=50, center="dark", as_cmap=True)
64
+ hex_colors = [cmap(score)[0:3] for score in scores]
65
+ # format each color as a comma-separated rgb string for LaTeX
66
+ hex_colors = [",".join([str(round(x, 2)) for x in color]) for color in hex_colors]
67
+ # make latex string
68
+ latex_string = ""
69
+ for (text, score), hex_color in zip(scored_texts, hex_colors):
70
+ latex_string += "\\textcolor[rgb]{" + str(hex_color) + "}{" + text + "} "
71
+
72
+ return latex_string
73
+
74
+
75
+
76
+
77
+ def summarize(text1, text2, text3, iterations, rationality=1.0):
78
+ # get sentences for each text
79
+
80
+ text1_sentences = nltk.sent_tokenize(text1)
81
+ text2_sentences = nltk.sent_tokenize(text2)
82
+ text3_sentences = nltk.sent_tokenize(text3)
83
+
84
+
85
+ # remove empty sentences
86
+ text1_sentences = [sentence for sentence in text1_sentences if sentence != ""]
87
+ text2_sentences = [sentence for sentence in text2_sentences if sentence != ""]
88
+ text3_sentences = [sentence for sentence in text3_sentences if sentence != ""]
89
+
90
+ sentences = list(set(text1_sentences + text2_sentences + text3_sentences))
91
+
92
+ rsa_reranker = RSAReranking(
93
+ model,
94
+ tokenizer,
95
+ candidates=sentences,
96
+ source_texts=[text1, text2, text3],
97
+ device="cpu",
98
+ rationality=rationality,
99
+ )
100
+ (
101
+ best_rsa,
102
+ best_base,
103
+ speaker_df,
104
+ listener_df,
105
+ initial_listener,
106
+ language_model_proba_df,
107
+ initial_consensuality_scores,
108
+ consensuality_scores,
109
+ ) = rsa_reranker.rerank(t=iterations)
110
+
111
+ # apply exp to the probabilities
112
+ speaker_df = speaker_df.applymap(lambda x: math.exp(x))
113
+
114
+ text_1_summaries = speaker_df.loc[text1][text1_sentences]
115
+ text_1_summaries = text_1_summaries / text_1_summaries.sum()
116
+
117
+ text_2_summaries = speaker_df.loc[text2][text2_sentences]
118
+ text_2_summaries = text_2_summaries / text_2_summaries.sum()
119
+
120
+ text_3_summaries = speaker_df.loc[text3][text3_sentences]
121
+ text_3_summaries = text_3_summaries / text_3_summaries.sum()
122
+
123
+ # make list of tuples
124
+ text_1_summaries = [(sentence, text_1_summaries[sentence]) for sentence in text1_sentences]
125
+ text_2_summaries = [(sentence, text_2_summaries[sentence]) for sentence in text2_sentences]
126
+ text_3_summaries = [(sentence, text_3_summaries[sentence]) for sentence in text3_sentences]
127
+
128
+ # normalize consensuality scores between -1 and 1
129
+
130
+ consensuality_scores = 2 * (consensuality_scores - consensuality_scores.min()) / (consensuality_scores.max() - consensuality_scores.min()) - 1
131
+ consensuality_scores_01 = (consensuality_scores - consensuality_scores.min()) / (consensuality_scores.max() - consensuality_scores.min())
132
+
133
+
134
+ most_consensual = consensuality_scores.sort_values(ascending=True).head(3).index.tolist()
135
+ least_consensual = consensuality_scores.sort_values(ascending=False).head(3).index.tolist()
136
+
137
+ most_consensual = [(sentence, consensuality_scores[sentence]) for sentence in most_consensual]
138
+ least_consensual = [(sentence, consensuality_scores[sentence]) for sentence in least_consensual]
139
+
140
+ text_1_consensuality = consensuality_scores.loc[text1_sentences]
141
+ text_2_consensuality = consensuality_scores.loc[text2_sentences]
142
+ text_3_consensuality = consensuality_scores.loc[text3_sentences]
143
+
144
+ # rescale between -1 and 1
145
+ # text_1_consensuality = (text_1_consensuality - (text_1_consensuality.max() - text_1_consensuality.min()) / 2) / (text_1_consensuality.max() - text_1_consensuality.min()) / 2
146
+ # text_2_consensuality = (text_2_consensuality - (text_2_consensuality.max() - text_2_consensuality.min()) / 2) / (text_2_consensuality.max() - text_2_consensuality.min()) / 2
147
+ # text_3_consensuality = (text_3_consensuality - (text_3_consensuality.max() - text_3_consensuality.min()) / 2) / (text_3_consensuality.max() - text_3_consensuality.min()) / 2
148
+
149
+ text_1_consensuality = [(sentence, text_1_consensuality[sentence]) for sentence in text1_sentences]
150
+ text_2_consensuality = [(sentence, text_2_consensuality[sentence]) for sentence in text2_sentences]
151
+ text_3_consensuality = [(sentence, text_3_consensuality[sentence]) for sentence in text3_sentences]
152
+
153
+ fig1 = plt.figure(figsize=(20, 10))
154
+ ax = fig1.add_subplot(111)
155
+ sns.heatmap(
156
+ listener_df,
157
+ ax=ax,
158
+ cmap="Blues",
159
+ annot=True,
160
+ fmt=".2f",
161
+ cbar=False,
162
+ annot_kws={"size": 10},
163
+ )
164
+ ax.set_title("Listener probabilities")
165
+ ax.set_xlabel("Candidate sentences")
166
+ ax.set_ylabel("Source texts")
167
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
168
+ fig1.tight_layout()
169
+
170
+ fig2 = plt.figure(figsize=(20, 10))
171
+ ax = fig2.add_subplot(111)
172
+ sns.heatmap(
173
+ speaker_df,
174
+ ax=ax,
175
+ cmap="Blues",
176
+ annot=True,
177
+ fmt=".2f",
178
+ cbar=False,
179
+ annot_kws={"size": 10},
180
+ )
181
+ ax.set_title("Speaker probabilities")
182
+ ax.set_xlabel("Candidate sentences")
183
+ ax.set_ylabel("Source texts")
184
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
185
+ fig2.tight_layout()
186
+
187
+ latex_text_1 = make_colored_text_to_latex(text_1_summaries)
188
+ latex_text_2 = make_colored_text_to_latex(text_2_summaries)
189
+ latex_text_3 = make_colored_text_to_latex(text_3_summaries)
190
+
191
+ text_1_consensuality_ = consensuality_scores_01.loc[text1_sentences]
192
+ text_2_consensuality_ = consensuality_scores_01.loc[text2_sentences]
193
+ text_3_consensuality_ = consensuality_scores_01.loc[text3_sentences]
194
+
195
+ text_1_consensuality_ = [(sentence, text_1_consensuality_[sentence]) for sentence in text1_sentences]
196
+ text_2_consensuality_ = [(sentence, text_2_consensuality_[sentence]) for sentence in text2_sentences]
197
+ text_3_consensuality_ = [(sentence, text_3_consensuality_[sentence]) for sentence in text3_sentences]
198
+
199
+ latex_text_1_consensuality = make_colored_text_to_latex(text_1_consensuality_)
200
+ latex_text_2_consensuality = make_colored_text_to_latex(text_2_consensuality_)
201
+ latex_text_3_consensuality = make_colored_text_to_latex(text_3_consensuality_)
202
+
203
+ latex = latex_template.replace("[REVIEW 1]", latex_text_1)
204
+ latex = latex.replace("[REVIEW 2]", latex_text_2)
205
+ latex = latex.replace("[REVIEW 3]", latex_text_3)
206
+
207
+
208
+ return text_1_summaries, text_2_summaries, text_3_summaries, text_1_consensuality, text_2_consensuality, text_3_consensuality, most_consensual, least_consensual, fig1, fig2, latex
209
+
210
+
211
+ # build the Gradio HighlightedText components
212
+
213
+
214
+ iface = gr.Interface(
215
+ fn=summarize,
216
+ inputs=[
217
+ gr.Textbox(lines=10, value=EXAMPLES[0]),
218
+ gr.Textbox(lines=10, value=EXAMPLES[1]),
219
+ gr.Textbox(lines=10, value=EXAMPLES[2]),
220
+ gr.Number(value=1, label="Iterations"),
221
+ gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=1.0, label="Rationality"),
222
+ ],
223
+ outputs=[
224
+ gr.Highlightedtext(
225
+ show_legend=True,
226
+ label="Uniqueness score for each sentence in text 1",
227
+ ),
228
+ gr.Highlightedtext(
229
+ show_legend=True,
230
+ label="Uniqueness score for each sentence in text 2",
231
+ ),
232
+ gr.Highlightedtext(
233
+ show_legend=True,
234
+ label="Uniqueness score for each sentence in text 3",
235
+ ),
236
+ gr.Highlightedtext(
237
+ show_legend=True,
238
+ label="Consensuality score for each sentence in text 1",
239
+
240
+ ),
241
+ gr.Highlightedtext(
242
+ show_legend=True,
243
+ label="Consensuality score for each sentence in text 2",
244
+ ),
245
+ gr.Highlightedtext(
246
+ show_legend=True,
247
+ label="Consensuality score for each sentence in text 3",
248
+ ),
249
+ gr.Highlightedtext(
250
+ show_legend=True,
251
+ label="Most consensual sentences",
252
+
253
+ ),
254
+ gr.Highlightedtext(
255
+ show_legend=True,
256
+ label="Least consensual sentences",
257
+ ),
258
+ gr.Plot(
259
+ label="Listener probabilities",
260
+ ),
261
+ gr.Plot(
262
+ label="Speaker probabilities",
263
+ ),
264
+
265
+ gr.Textbox(lines=10, label="Latex Consensuality scores"),
266
+
267
+
268
+
269
+ ],
270
+ title="RSA Summarizer",
271
+ description="Summarize 3 texts using RSA",
272
+ )
273
+
274
+ iface.launch(share=True)
glimpse-ui/glimpse/examples/reviews/reviews_latex_generation.py ADDED
@@ -0,0 +1,272 @@
1
+ import math
2
+ from typing import List, Tuple
3
+
4
+ import nltk
5
+ import numpy as np
6
+ import seaborn as sns
7
+
8
+ from rsasumm.rsa_reranker import RSAReranking
9
+ import gradio as gr
10
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
11
+ import seaborn as sns
12
+ import pandas as pd
13
+ import matplotlib.pyplot as plt
14
+
15
+ MODEL = "facebook/bart-large-cnn"
16
+
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
18
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
19
+
20
+
21
+ latex_template = r"""
22
+ \begin{subfigure}[b]{0.48\textwidth}
23
+ \resizebox{\textwidth}{!}{
24
+ \begin{coloredbox}{darkgray}{Review 1}
25
+ [REVIEW 1]
26
+
27
+ \end{coloredbox}}
28
+ \end{subfigure}
29
+ \begin{subfigure}[b]{0.48\textwidth}
30
+ \resizebox{\textwidth}{!}{
31
+ \begin{coloredbox}{darkgray}{Review 2}
32
+ [REVIEW 2]
33
+ \end{coloredbox}}
34
+ \end{subfigure}
35
+ \begin{subfigure}[b]{0.48\textwidth}
36
+ \resizebox{\textwidth}{!}{
37
+ \begin{coloredbox}{darkgray}{Review 3}
38
+ [REVIEW 3]
39
+ \end{coloredbox}}
40
+ \end{subfigure}
41
+ """
42
+
43
+ EXAMPLES = [
44
+ "The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. I believe the authors missed Jane and al 2021. In addition, I think, there is a mistake in the math.",
45
+ "The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
46
+ "The paper gives really interesting insights on the topic of transfer learning. It is not well presented and lack experiments. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
47
+ ]
48
+
49
+
50
+ def make_colored_text_to_latex(scored_texts : List[Tuple[str, float]]):
51
+ """
52
+ Make a latex string from a list of scored texts.
53
+ """
54
+
55
+ # cast scores between 0 and 1
56
+ scores = np.array([score for _, score in scored_texts])
57
+ scores = (scores - scores.min()) / (scores.max() - scores.min())
58
+
59
+ # map scores to RGB colors using a diverging palette
60
+ cmap = sns.diverging_palette(250, 30, l=50, center="dark", as_cmap=True)
61
+ hex_colors = [cmap(score)[0:3] for score in scores]
62
+ # format each color as a comma-separated rgb string for LaTeX
63
+ hex_colors = [",".join([str(round(x, 2)) for x in color]) for color in hex_colors]
64
+ # make latex string
65
+ latex_string = ""
66
+ for (text, score), hex_color in zip(scored_texts, hex_colors):
67
+ #latex_string += "\\textcolor[rgb]{" + str(hex_color) + "}{" + text + "} "
68
+ latex_string += "\\hlc{" + str(hex_color)[1:-1] + "}{" + text + "} "
69
+
70
+ return latex_string
71
+
72
+
73
+
74
+
75
+ def summarize(text1, text2, text3, iterations, rationality=1.0):
76
+ # get sentences for each text
77
+
78
+ text1_sentences = nltk.sent_tokenize(text1)
79
+ text2_sentences = nltk.sent_tokenize(text2)
80
+ text3_sentences = nltk.sent_tokenize(text3)
81
+
82
+
83
+ # remove empty sentences
84
+ text1_sentences = [sentence for sentence in text1_sentences if sentence != ""]
85
+ text2_sentences = [sentence for sentence in text2_sentences if sentence != ""]
86
+ text3_sentences = [sentence for sentence in text3_sentences if sentence != ""]
87
+
88
+ sentences = list(set(text1_sentences + text2_sentences + text3_sentences))
89
+
90
+ rsa_reranker = RSAReranking(
91
+ model,
92
+ tokenizer,
93
+ candidates=sentences,
94
+ source_texts=[text1, text2, text3],
95
+ device="cpu",
96
+ rationality=rationality,
97
+ )
98
+ (
99
+ best_rsa,
100
+ best_base,
101
+ speaker_df,
102
+ listener_df,
103
+ initial_listener,
104
+ language_model_proba_df,
105
+ initial_consensuality_scores,
106
+ consensuality_scores,
107
+ ) = rsa_reranker.rerank(t=iterations)
108
+
109
+ # apply exp to the probabilities
110
+ speaker_df = speaker_df.applymap(lambda x: math.exp(x))
111
+
112
+ text_1_summaries = speaker_df.loc[text1][text1_sentences]
113
+ text_1_summaries = text_1_summaries / text_1_summaries.sum()
114
+
115
+ text_2_summaries = speaker_df.loc[text2][text2_sentences]
116
+ text_2_summaries = text_2_summaries / text_2_summaries.sum()
117
+
118
+ text_3_summaries = speaker_df.loc[text3][text3_sentences]
119
+ text_3_summaries = text_3_summaries / text_3_summaries.sum()
120
+
121
+ # make list of tuples
122
+ text_1_summaries = [(sentence, text_1_summaries[sentence]) for sentence in text1_sentences]
123
+ text_2_summaries = [(sentence, text_2_summaries[sentence]) for sentence in text2_sentences]
124
+ text_3_summaries = [(sentence, text_3_summaries[sentence]) for sentence in text3_sentences]
125
+
126
+ # normalize consensuality scores between -1 and 1
127
+
128
+ consensuality_scores = 2 * (consensuality_scores - consensuality_scores.min()) / (consensuality_scores.max() - consensuality_scores.min()) - 1
129
+ consensuality_scores_01 = (consensuality_scores - consensuality_scores.min()) / (consensuality_scores.max() - consensuality_scores.min())
130
+
131
+
132
+ most_consensual = consensuality_scores.sort_values(ascending=True).head(3).index.tolist()
133
+ least_consensual = consensuality_scores.sort_values(ascending=False).head(3).index.tolist()
134
+
135
+ most_consensual = [(sentence, consensuality_scores[sentence]) for sentence in most_consensual]
136
+ least_consensual = [(sentence, consensuality_scores[sentence]) for sentence in least_consensual]
137
+
138
+ text_1_consensuality = consensuality_scores.loc[text1_sentences]
139
+ text_2_consensuality = consensuality_scores.loc[text2_sentences]
140
+ text_3_consensuality = consensuality_scores.loc[text3_sentences]
141
+
142
+ # rescale between -1 and 1
143
+ # text_1_consensuality = (text_1_consensuality - (text_1_consensuality.max() - text_1_consensuality.min()) / 2) / (text_1_consensuality.max() - text_1_consensuality.min()) / 2
144
+ # text_2_consensuality = (text_2_consensuality - (text_2_consensuality.max() - text_2_consensuality.min()) / 2) / (text_2_consensuality.max() - text_2_consensuality.min()) / 2
145
+ # text_3_consensuality = (text_3_consensuality - (text_3_consensuality.max() - text_3_consensuality.min()) / 2) / (text_3_consensuality.max() - text_3_consensuality.min()) / 2
146
+
147
+ text_1_consensuality = [(sentence, text_1_consensuality[sentence]) for sentence in text1_sentences]
148
+ text_2_consensuality = [(sentence, text_2_consensuality[sentence]) for sentence in text2_sentences]
149
+ text_3_consensuality = [(sentence, text_3_consensuality[sentence]) for sentence in text3_sentences]
150
+
151
+ fig1 = plt.figure(figsize=(20, 10))
152
+ ax = fig1.add_subplot(111)
153
+ sns.heatmap(
154
+ listener_df,
155
+ ax=ax,
156
+ cmap="Blues",
157
+ annot=True,
158
+ fmt=".2f",
159
+ cbar=False,
160
+ annot_kws={"size": 10},
161
+ )
162
+ ax.set_title("Listener probabilities")
163
+ ax.set_xlabel("Candidate sentences")
164
+ ax.set_ylabel("Source texts")
165
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
166
+ fig1.tight_layout()
167
+
168
+ fig2 = plt.figure(figsize=(20, 10))
169
+ ax = fig2.add_subplot(111)
170
+ sns.heatmap(
171
+ speaker_df,
172
+ ax=ax,
173
+ cmap="Blues",
174
+ annot=True,
175
+ fmt=".2f",
176
+ cbar=False,
177
+ annot_kws={"size": 10},
178
+ )
179
+ ax.set_title("Speaker probabilities")
180
+ ax.set_xlabel("Candidate sentences")
181
+ ax.set_ylabel("Source texts")
182
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
183
+ fig2.tight_layout()
184
+
185
+ latex_text_1 = make_colored_text_to_latex(text_1_summaries)
186
+ latex_text_2 = make_colored_text_to_latex(text_2_summaries)
187
+ latex_text_3 = make_colored_text_to_latex(text_3_summaries)
188
+
189
+ text_1_consensuality_ = consensuality_scores_01.loc[text1_sentences]
190
+ text_2_consensuality_ = consensuality_scores_01.loc[text2_sentences]
191
+ text_3_consensuality_ = consensuality_scores_01.loc[text3_sentences]
192
+
193
+ text_1_consensuality_ = [(sentence, text_1_consensuality_[sentence]) for sentence in text1_sentences]
194
+ text_2_consensuality_ = [(sentence, text_2_consensuality_[sentence]) for sentence in text2_sentences]
195
+ text_3_consensuality_ = [(sentence, text_3_consensuality_[sentence]) for sentence in text3_sentences]
196
+
197
+ latex_text_1_consensuality = make_colored_text_to_latex(text_1_consensuality_)
198
+ latex_text_2_consensuality = make_colored_text_to_latex(text_2_consensuality_)
199
+ latex_text_3_consensuality = make_colored_text_to_latex(text_3_consensuality_)
200
+
201
+ latex = latex_template.replace("[REVIEW 1]", latex_text_1)
202
+ latex = latex.replace("[REVIEW 2]", latex_text_2)
203
+ latex = latex.replace("[REVIEW 3]", latex_text_3)
204
+
205
+
206
+ return text_1_summaries, text_2_summaries, text_3_summaries, text_1_consensuality, text_2_consensuality, text_3_consensuality, most_consensual, least_consensual, fig1, fig2, latex
207
+
208
+
209
+ # build the Gradio HighlightedText components
210
+
211
+
212
+ iface = gr.Interface(
213
+ fn=summarize,
214
+ inputs=[
215
+ gr.Textbox(lines=10, value=EXAMPLES[0]),
216
+ gr.Textbox(lines=10, value=EXAMPLES[1]),
217
+ gr.Textbox(lines=10, value=EXAMPLES[2]),
218
+ gr.Number(value=1, label="Iterations"),
219
+ gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=1.0, label="Rationality"),
220
+ ],
221
+ outputs=[
222
+ gr.Highlightedtext(
223
+ show_legend=True,
224
+ label="Uniqueness score for each sentence in text 1",
225
+ ),
226
+ gr.Highlightedtext(
227
+ show_legend=True,
228
+ label="Uniqueness score for each sentence in text 2",
229
+ ),
230
+ gr.Highlightedtext(
231
+ show_legend=True,
232
+ label="Uniqueness score for each sentence in text 3",
233
+ ),
234
+ gr.Highlightedtext(
235
+ show_legend=True,
236
+ label="Consensuality score for each sentence in text 1",
237
+
238
+ ),
239
+ gr.Highlightedtext(
240
+ show_legend=True,
241
+ label="Consensuality score for each sentence in text 2",
242
+ ),
243
+ gr.Highlightedtext(
244
+ show_legend=True,
245
+ label="Consensuality score for each sentence in text 3",
246
+ ),
247
+ gr.Highlightedtext(
248
+ show_legend=True,
249
+ label="Most consensual sentences",
250
+
251
+ ),
252
+ gr.Highlightedtext(
253
+ show_legend=True,
254
+ label="Least consensual sentences",
255
+ ),
256
+ gr.Plot(
257
+ label="Listener probabilities",
258
+ ),
259
+ gr.Plot(
260
+ label="Speaker probabilities",
261
+ ),
262
+
263
+ gr.Textbox(lines=10, label="Latex Consensuality scores"),
264
+
265
+
266
+
267
+ ],
268
+ title="RSA Summarizer",
269
+ description="Summarize 3 texts using RSA",
270
+ )
271
+
272
+ iface.launch()
glimpse-ui/glimpse/glimpse/baselines/generate_llm_summaries.py ADDED
@@ -0,0 +1,112 @@
1
+ import pandas as pd
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import re
7
+ import argparse
8
+ from tqdm import tqdm
9
+
10
+
11
+ def parse_args():
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--dataset", default="")
14
+ parser.add_argument("--batch_size", type=int, default=4)
15
+ parser.add_argument("--device", type=str, default="cuda")
16
+ parser.add_argument("--output", type=Path, default="")
17
+
18
+ args = parser.parse_args()
19
+ return args
20
+
21
+ def prepare_dataset(dataset_name, dataset_path="rsasumm/data/processed/"):
22
+ dataset_path = Path(dataset_path)
23
+ if dataset_name == "amazon":
24
+ dataset = pd.read_csv(dataset_path / "amazon_test.csv")
25
+ elif dataset_name == "space":
26
+ dataset = pd.read_csv(dataset_path / "space.csv")
27
+ elif dataset_name == "yelp":
28
+ dataset = pd.read_csv(dataset_path / "yelp_test.csv")
29
+ elif dataset_name == "reviews":
30
+ dataset = pd.read_csv(dataset_path / "test_metareviews.csv")
31
+ else:
32
+ raise ValueError(f"Unknown dataset {dataset_name}")
33
+
34
+
35
+ return dataset
36
+
37
+
38
+ # group text by sample id and concatenate text
39
+
40
+ def group_text_by_id(df: pd.DataFrame) -> pd.DataFrame:
41
+ """
42
+ Group the text by the sample id and concatenate the text.
43
+ :param df: The dataframe
44
+ :return: The dataframe with the text grouped by the sample id
45
+ """
46
+ texts = df.groupby("id")["text"].apply(lambda x: " ".join(x))
47
+
48
+ # retrieve first gold by id
49
+ gold = df.groupby("id")["gold"].first()
50
+
51
+ # create new dataframe
52
+ df = pd.DataFrame({"text": texts, "gold": gold}, index=texts.index)
53
+
54
+ return df
55
+
56
+
57
+ def generate_summaries(model, tokenizer, df, batch_size, device):
58
+
59
+ # df columns = id, text, gold
60
+ # make instruction:
61
+
62
+ def make_instruction(text):
63
+ return f"[INST]\n{text}\n Summarize the previous text:[/INST]\n\n"
64
+
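+ # e.g. make_instruction("Review A. Review B.") returns
+ # "[INST]\nReview A. Review B.\n Summarize the previous text:[/INST]\n\n"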
65
+ df["instruction"] = df["text"].apply(make_instruction)
66
+
67
+ # make data loader
68
+ dataset = df[["instruction"]].values.tolist()
69
+
70
+
71
+ model = model.to(device).eval()
72
+
73
+ summaries = []
74
+ with torch.no_grad():
75
+ for batch in tqdm(dataset):
76
+ print(batch)
77
+ inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
78
+ inputs = {k: v.to(device) for k, v in inputs.items()}
79
+ outputs = model.generate(**inputs, temperature=0.7, top_p=0.7, top_k=50, max_new_tokens=500)
80
+ summaries.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
81
+
82
+ # remove the instruction from the summaries
83
+ df["summary"] = [re.sub(r"\[INST\]\n.*\[/INST\]\n\n", "", summary) for summary in summaries]
84
+
85
+ return df
86
+
87
+ def main():
88
+
89
+ args = parse_args()
90
+ model = "togethercomputer/Llama-2-7B-32K-Instruct"
91
+ tokenizer = AutoTokenizer.from_pretrained(model)
92
+ model = AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True, torch_dtype=torch.float16)
93
+
94
+ df = prepare_dataset(args.dataset)
95
+
96
+ df = group_text_by_id(df)
97
+
98
+ df = generate_summaries(model, tokenizer, df, args.batch_size, args.device)
99
+ df['metadata/Method'] = "LLM"
100
+ df['metadata/Model'] = model_name
101
+
102
+ name = f"{args.dataset}-_-{model.replace('/', '-')}-_-llm_summaries.csv"
103
+ path = Path(args.output) / name
104
+
105
+ Path(args.output).mkdir(exist_ok=True, parents=True)
106
+ df.to_csv(path, index=True)
107
+
108
+
109
+ if __name__ == "__main__":
+ main()
110
+
111
+
112
+
glimpse-ui/glimpse/glimpse/baselines/sumy_baselines.py ADDED
@@ -0,0 +1,129 @@
1
+ from sumy.parsers.plaintext import PlaintextParser
2
+ from sumy.parsers.html import HtmlParser
3
+ from sumy.nlp.tokenizers import Tokenizer
4
+ from sumy.nlp.stemmers import Stemmer
5
+ from sumy.utils import get_stop_words
6
+
7
+ import argparse
8
+
9
+ import pandas as pd
10
+ from pathlib import Path
11
+
12
+ import nltk
13
+
14
+
15
+ def summarize(method, language, sentence_count, input_type, input_):
16
+ if method == 'LSA':
17
+ from sumy.summarizers.lsa import LsaSummarizer as Summarizer
18
+ if method == 'text-rank':
19
+ from sumy.summarizers.text_rank import TextRankSummarizer as Summarizer
20
+ if method == 'lex-rank':
21
+ from sumy.summarizers.lex_rank import LexRankSummarizer as Summarizer
22
+ if method == 'edmundson':
23
+ from sumy.summarizers.edmundson import EdmundsonSummarizer as Summarizer
24
+ if method == 'luhn':
25
+ from sumy.summarizers.luhn import LuhnSummarizer as Summarizer
26
+ if method == 'kl-sum':
27
+ from sumy.summarizers.kl import KLSummarizer as Summarizer
28
+ if method == 'random':
29
+ from sumy.summarizers.random import RandomSummarizer as Summarizer
30
+ if method == 'reduction':
31
+ from sumy.summarizers.reduction import ReductionSummarizer as Summarizer
32
+
33
+ if input_type == "URL":
34
+ parser = HtmlParser.from_url(input_, Tokenizer(language))
35
+ if input_type == "text":
36
+ parser = PlaintextParser.from_string(input_, Tokenizer(language))
37
+
38
+ stemmer = Stemmer(language)
39
+ summarizer = Summarizer(stemmer)
40
+ stop_words = get_stop_words(language)
41
+
42
+ if method == 'edmundson':
43
+ summarizer.null_words = stop_words
44
+ summarizer.bonus_words = parser.significant_words
45
+ summarizer.stigma_words = parser.stigma_words
46
+ else:
47
+ summarizer.stop_words = stop_words
48
+
49
+ summary_sentences = summarizer(parser.document, sentence_count)
50
+ summary = ' '.join([str(sentence) for sentence in summary_sentences])
51
+
52
+ return summary
53
+
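+ # Illustrative usage (assumes the NLTK data sumy needs is installed):
+ # summarize("text-rank", "english", 2, "text", some_long_text)
+ # -> a string joining the two highest-ranked sentences of `some_long_text`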
54
+
55
+ # methods = ['LSA', 'text-rank', 'lex-rank', 'edmundson', 'luhn', 'kl-sum', 'random', 'reduction']
56
+
57
+
58
+ def parse_args():
59
+ parser = argparse.ArgumentParser()
60
+ parser.add_argument("--dataset", default="")
61
+ # method
62
+ parser.add_argument("--method", type=str, choices=['LSA', 'text-rank', 'lex-rank', 'edmundson', 'luhn', 'kl-sum', 'random', 'reduction'], default="LSA")
63
+ parser.add_argument("--batch_size", type=int, default=4)
64
+ parser.add_argument("--device", type=str, default="cuda")
65
+ parser.add_argument("--output", type=Path, default="")
66
+
67
+ args = parser.parse_args()
68
+ return args
69
+
70
+ def prepare_dataset(dataset_name, dataset_path="rsasumm/data/processed/"):
71
+ dataset_path = Path(dataset_path)
72
+ if dataset_name == "amazon":
73
+ dataset = pd.read_csv(dataset_path / "amazon_test.csv")
74
+ elif dataset_name == "space":
75
+ dataset = pd.read_csv(dataset_path / "space.csv")
76
+ elif dataset_name == "yelp":
77
+ dataset = pd.read_csv(dataset_path / "yelp_test.csv")
78
+ elif dataset_name == "reviews":
79
+ dataset = pd.read_csv(dataset_path / "test_metareviews.csv")
80
+ else:
81
+ raise ValueError(f"Unknown dataset {dataset_name}")
82
+
83
+
84
+ return dataset
85
+
86
+
87
+ # group text by sample id and concatenate text
88
+
89
+ def group_text_by_id(df: pd.DataFrame) -> pd.DataFrame:
90
+ """
91
+ Group the text by the sample id and concatenate the text.
92
+ :param df: The dataframe
93
+ :return: The dataframe with the text grouped by the sample id
94
+ """
95
+ texts = df.groupby("id")["text"].apply(lambda x: " ".join(x))
96
+
97
+ # retrieve first gold by id
98
+ gold = df.groupby("id")["gold"].first()
99
+
100
+ # create new dataframe
101
+ df = pd.DataFrame({"text": texts, "gold": gold}, index=texts.index)
102
+
103
+ return df
104
+
105
+
106
+ def main():
107
+ args = parse_args()
108
+ for N in [1]:
109
+ dataset = prepare_dataset(args.dataset)
110
+ # dataset = group_text_by_id(dataset)
111
+
112
+ summaries = []
113
+ for text in dataset.text:
114
+ summary = summarize(args.method, "english", N, "text", text)
115
+ summaries.append(summary)
116
+
117
+ dataset['summary'] = summaries
118
+ dataset['metadata/dataset'] = args.dataset
119
+ dataset["metadata/method"] = args.method
120
+ dataset["metadata/sentence_count"] = N
121
+
122
+ name = f"{args.dataset}-_-{args.method}-_-sumy_{N}.csv"
123
+ path = Path(args.output) / name
124
+
125
+ Path(args.output).mkdir(exist_ok=True, parents=True)
126
+ dataset.to_csv(path, index=True)
127
+
128
+
129
+ if __name__ == "__main__":
+ main()
glimpse-ui/glimpse/glimpse/data_loading/Glimpse_tokenizer.py ADDED
@@ -0,0 +1,74 @@
1
+ import re
2
+ import spacy
3
+ import importlib
4
+ import nltk
5
+
6
+ ############################################
7
+ ### CHANGE THIS LINE TO CHOOSE TOKENIZER ###
8
+ ORIGINAL_TOKENIZER = False
9
+ ############################################
10
+
11
+ try:
12
+ importlib.util.find_spec("en_core_web_sm")
13
+ nlp = spacy.load("en_core_web_sm")
14
+ except Exception:
15
+ import spacy.cli
16
+ spacy.cli.download("en_core_web_sm")
17
+ nlp = spacy.load("en_core_web_sm")
18
+
19
+ def glimpse_tokenizer(text: str) -> list:
20
+
21
+ # If the original tokenizer is set to True, use the original tokenizer
22
+ if ORIGINAL_TOKENIZER:
23
+ return original_tokenizer(text)
24
+
25
+ # else, use the new tokenizer
26
+ else:
27
+
28
+ # More general-purpose tokenizer that handles both natural paragraph text and structured reviews.
29
+
30
+ # Normalize long dashes
31
+ text = re.sub(r"[-]{2,}", "\n", text)
32
+
33
+ # Keep line breaks meaningful (but fallback to sentence splitting)
34
+ chunks = re.split(r"\n+", text)
35
+ sentences = []
36
+
37
+ for chunk in chunks:
38
+ chunk = chunk.strip()
39
+ if not chunk:
40
+ continue
41
+
42
+ # Section headers and bullets become single “sentences”
43
+ if re.match(r"^(Summary|Strengths?|Weaknesses?|Minor)\s*:?", chunk, re.IGNORECASE):
44
+ sentences.append(chunk)
45
+ continue
46
+
47
+ if re.match(r"^(\d+(\.\d+)*\.|-)\s+.+", chunk):
48
+ sentences.append(chunk)
49
+ continue
50
+
51
+ # Otherwise, apply SpaCy sentence splitting
52
+ doc = nlp(chunk)
53
+ sentences.extend([sent.text.strip() for sent in doc.sents if sent.text.strip()])
54
+
55
+ return sentences
56
+
57
+ # reuse the original glimpse tokenizer
58
+ # def glimpse_tokenizer(text: str) -> list:
59
+ # return tokenize_sentences(text)
60
+
61
+ # Default glimpse tokenizer from the original code
62
+ def original_tokenizer(text: str) -> list:
63
+ """
64
+ Tokenizes the input text into sentences.
65
+
66
+ @param text: The input text to be tokenized
67
+ @return: A list of tokenized sentences
68
+ """
69
+ text = text.replace('-----', '\n')
70
+ sentences = nltk.sent_tokenize(text)
71
+ # remove empty sentences
72
+ sentences = [sentence for sentence in sentences if sentence != ""]
73
+
74
+ return sentences
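+
+ # Illustrative example (not executed anywhere in the pipeline): on a structured
+ # review such as "Summary: Solid paper.\nStrengths:\n1. Clear writing. Easy to follow."
+ # glimpse_tokenizer keeps the "Summary:"/"Strengths:" header and the numbered item
+ # as single "sentences", and falls back to SpaCy sentence splitting for plain prose.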
glimpse-ui/glimpse/glimpse/data_loading/data_processing.py ADDED
@@ -0,0 +1,15 @@
1
+ import pandas as pd
2
+ import os
3
+
4
+ data_glimpse = "data/processed/"
5
+ if not os.path.exists(data_glimpse):
6
+ os.makedirs(data_glimpse)
7
+
8
+ for year in range(2017, 2021 + 1):
9
+ dataset = pd.read_csv(f"data/all_reviews_{year}.csv")
10
+ sub_dataset = dataset[['id', 'review', 'metareview']].copy()
11
+ sub_dataset.rename(columns={"review": "text", "metareview": "gold"}, inplace=True)
12
+
13
+ sub_dataset.to_csv(f"{data_glimpse}all_reviews_{year}.csv", index=False)
14
+
15
+
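+ # Note: per the code above, each data/all_reviews_{year}.csv is expected to
+ # contain at least the columns `id`, `review`, and `metareview`; the processed
+ # output keeps `id` and renames `review` -> `text`, `metareview` -> `gold`.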
glimpse-ui/glimpse/glimpse/data_loading/generate_abstractive_candidates.py ADDED
@@ -0,0 +1,230 @@
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+ from torch.utils.data import DataLoader
6
+ from datasets import Dataset
7
+ from tqdm import tqdm
8
+ import datetime
9
+ import torch
10
+
11
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
+
13
+ GENERATION_CONFIGS = {
14
+ "top_p_sampling": {
15
+ "max_new_tokens": 200,
16
+ "do_sample": True,
17
+ "top_p": 0.95,
18
+ "temperature": 1.0,
19
+ "num_return_sequences": 8,
20
+ "num_beams" : 1,
21
+
22
+ #"num_beam_groups" : 4,
23
+ },
24
+
25
+ **{
26
+ f"sampling_topp_{str(topp).replace('.', '')}": {
27
+ "max_new_tokens": 200,
28
+ "do_sample": True,
29
+ "num_return_sequences": 8,
30
+ "top_p": 0.95,
31
+ }
32
+ for topp in [0.5, 0.8, 0.95, 0.99]
33
+ },
34
+ }
35
+
36
+ # add base.csv config to all configs
37
+ for key, value in GENERATION_CONFIGS.items():
38
+ GENERATION_CONFIGS[key] = {
39
+ # "max_length": 2048,
40
+ "min_length": 0,
41
+ "early_stopping": True,
42
+ **value,
43
+ }
44
+
45
+
46
+ def parse_args():
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument("--model_name", type=str, default="facebook/bart-large-cnn")
49
+ parser.add_argument("--dataset_path", type=Path, default="data/processed/all_reviews_2017.csv")
50
+ parser.add_argument("--decoding_config", type=str, default="top_p_sampling", choices=GENERATION_CONFIGS.keys())
51
+
52
+ parser.add_argument("--batch_size", type=int, default=16)
53
+ parser.add_argument("--device", type=str, default="cuda")
54
+ parser.add_argument("--trimming", action=argparse.BooleanOptionalAction, default=True)
55
+
56
+ parser.add_argument("--output_dir", type=str, default="data/candidates")
57
+
58
+ # if ran in a scripted way, the output path will be printed
59
+ parser.add_argument("--scripted-run", action=argparse.BooleanOptionalAction, default=False)
60
+
61
+ # limit the number of samples to generate
62
+ parser.add_argument("--limit", type=int, default=None)
63
+
64
+ args = parser.parse_args()
65
+
66
+ return args
67
+
68
+
69
+ def prepare_dataset(dataset_path) -> Dataset:
70
+ try:
71
+ dataset = pd.read_csv(dataset_path)
72
+ except Exception:
73
+ raise ValueError(f"Unknown dataset {dataset_path}")
74
+
75
+ # make a dataset from the dataframe
76
+ dataset = Dataset.from_pandas(dataset)
77
+
78
+ return dataset
79
+
80
+
81
+ def evaluate_summarizer(
82
+ model, tokenizer, dataset: Dataset, decoding_config, batch_size: int,
83
+ device: str, trimming: bool
84
+ ) -> Dataset:
85
+ """
86
+ @param model: The model used to generate the summaries
87
+ @param tokenizer: The tokenizer used to tokenize the text and the summary
88
+ @param dataset: A dataset with the text
89
+ @param decoding_config: Dictionary with the decoding config
90
+ @param batch_size: The batch size used to generate the summaries
91
+ @return: The same dataset with the summaries added
92
+ """
93
+ # create a dataset with the text and the summary
94
+
95
+ # create a dataloader
96
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=trimming)
97
+
98
+ # generate summaries
99
+ summaries = []
100
+ print("Generating summaries...")
101
+
102
+ for batch in tqdm(dataloader):
103
+ text = batch["text"]
104
+
105
+ inputs = tokenizer(
106
+ text,
107
+ max_length=1024,
108
+ padding="max_length",
109
+ truncation=True,
110
+ return_tensors="pt",
111
+ )
112
+
113
+ # move inputs to device
114
+ inputs = {key: value.to(device) for key, value in inputs.items()}
115
+
116
+ # generate summaries
117
+ outputs = model.generate(
118
+ **inputs,
119
+ **decoding_config,
120
+ )
121
+
122
+ total_size = outputs.numel() # Total number of elements in the tensor
123
+ target_size = batch_size * outputs.shape[-1] # Target size of the last dimension
124
+ pad_size = (target_size - (total_size % target_size)) % target_size # Calculate the required padding size to make the total number of elements divisible by the target size
125
+
126
+ # Pad the tensor with zeros to make the total number of elements divisible by the target size
127
+ if not trimming and pad_size != 0: outputs = torch.nn.functional.pad(outputs, (0, 0, 0, pad_size // outputs.shape[-1]))
128
+
129
+ # output : (batch_size * num_return_sequences, max_length)
130
+ try:
131
+ outputs = outputs.reshape(batch_size, -1, outputs.shape[-1])
132
+ except Exception as e:
133
+ print(f"Error reshaping outputs: {e}")
134
+ raise ValueError(f"Cannot reshape tensor of size {outputs.numel()} into shape "
135
+ f"({batch_size}, -1, {outputs.shape[-1]}).")
136
+
137
+ # decode summaries
138
+ for b in range(batch_size):
139
+ summaries.append(
140
+ [
141
+ tokenizer.decode(
142
+ outputs[b, i],
143
+ skip_special_tokens=True,
144
+ )
145
+ for i in range(outputs.shape[1])
146
+ ]
147
+ )
148
+
149
+ # if trimming the last batch, remove them from the dataset
150
+ if trimming: dataset = dataset.select(range(len(summaries)))
151
+
152
+ # add summaries to the huggingface dataset
153
+ dataset = dataset.map(lambda example: {"summary": summaries.pop(0)})
154
+
155
+ return dataset
156
+
157
+
158
+ def sanitize_model_name(model_name: str) -> str:
159
+ """
160
+ Sanitize the model name to be used as a folder name.
161
+ @param model_name: The model name
162
+ @return: The sanitized model name
163
+ """
164
+ return model_name.replace("/", "_")
165
+
166
+
167
+ def main():
168
+ args = parse_args()
169
+
170
+ # load the model
171
+ model = AutoModelForSeq2SeqLM.from_pretrained(
172
+ args.model_name
173
+ )
174
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
175
+
176
+ tokenizer.pad_token = tokenizer.unk_token
177
+ tokenizer.pad_token_id = tokenizer.unk_token_id
178
+
179
+ # move model to device
180
+ model = model.to(args.device)
181
+
182
+ # load the dataset
183
+ print("Loading dataset...")
184
+ dataset = prepare_dataset(args.dataset_path)
185
+
186
+ # limit the number of samples
187
+ if args.limit is not None:
188
+ _lim = min(args.limit, len(dataset))
189
+ dataset = dataset.select(range(_lim))
190
+
191
+ # generate summaries
192
+ dataset = evaluate_summarizer(
193
+ model,
194
+ tokenizer,
195
+ dataset,
196
+ GENERATION_CONFIGS[args.decoding_config],
197
+ args.batch_size,
198
+ args.device,
199
+ args.trimming,
200
+ )
201
+
202
+ df_dataset = dataset.to_pandas()
203
+ df_dataset = df_dataset.explode('summary')
204
+ df_dataset = df_dataset.reset_index()
205
+ # add an idx with the id of the summary for each example
206
+ df_dataset['id_candidate'] = df_dataset.groupby(['index']).cumcount()
207
+
208
+ # save the dataset
209
+ # add unique date in name
210
+ now = datetime.datetime.now()
211
+ date = now.strftime("%Y-%m-%d-%H-%M-%S")
212
+ model_name = sanitize_model_name(args.model_name)
213
+ padding_status = "trimmed" if args.trimming else "padded"
214
+ output_path = (
215
+ Path(args.output_dir)
216
+ / f"{model_name}-_-{args.dataset_path.stem}-_-{args.decoding_config}-_-{padding_status}-_-{date}.csv"
217
+ )
218
+
219
+ # create output dir if it doesn't exist
220
+ if not output_path.parent.exists():
221
+ output_path.parent.mkdir(parents=True, exist_ok=True)
222
+
223
+ df_dataset.to_csv(output_path, index=False, encoding="utf-8")
224
+
225
+ # in case of scripted run, print the output path
226
+ if args.scripted_run: print(output_path)
227
+
228
+
229
+ if __name__ == "__main__":
230
+ main()
glimpse-ui/glimpse/glimpse/data_loading/generate_extractive_candidates.py ADDED
@@ -0,0 +1,129 @@
1
+ import argparse
2
+ import datetime
3
+ from pathlib import Path
4
+
5
+ import pandas as pd
6
+ from datasets import Dataset
7
+ from tqdm import tqdm
8
+
9
+ import nltk
10
+
11
+ import sys, os.path
12
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
13
+
14
+ from glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
15
+
16
+ # def tokenize_sentences(text: str) -> list:
17
+ # """
18
+ # Tokenizes the input text into sentences.
19
+
20
+ # @param text: The input text to be tokenized
21
+ # @return: A list of tokenized sentences
22
+ # """
23
+ # text = text.replace('-----', '\n')
24
+ # sentences = nltk.sent_tokenize(text)
25
+ # # remove empty sentences
26
+ # sentences = [sentence for sentence in sentences if sentence != ""]
27
+
28
+ # return sentences
29
+
30
+ def parse_args():
31
+ parser = argparse.ArgumentParser()
32
+
33
+ parser.add_argument("--dataset_path", type=Path, default="glimpse/data/processed/all_reviews_2017.csv")
34
+ parser.add_argument("--output_dir", type=str, default="glimpse/data/candidates")
35
+
36
+ # if ran in a scripted way, the output path will be printed
37
+ parser.add_argument("--scripted-run", action=argparse.BooleanOptionalAction, default=False)
38
+
39
+ # limit the number of samples to generate
40
+ parser.add_argument("--limit", type=int, default=None)
41
+
42
+ args = parser.parse_args()
43
+
44
+ return args
45
+
46
+
47
+ def prepare_dataset(dataset_path) -> Dataset:
48
+
49
+ try:
50
+ dataset = pd.read_csv(dataset_path)
51
+ except Exception:
52
+ raise ValueError(f"Unknown dataset {dataset_path}")
53
+
54
+ # make a dataset from the dataframe
55
+ dataset = Dataset.from_pandas(dataset)
56
+
57
+ return dataset
58
+
59
+
60
+ def evaluate_summarizer(dataset: Dataset) -> Dataset:
61
+ """
62
+ @param dataset: A dataset with the text
63
+ @return: The same dataset with the summaries added
64
+ """
65
+ # create a dataset with the text and the summary
66
+
67
+ # create a dataloader
68
+
69
+ # generate summaries
70
+ summaries = []
71
+ print("Generating summaries...")
72
+
73
+ # (tqdm library for progress bar)
74
+ for sample in tqdm(dataset):
75
+ text = sample["text"]
76
+
77
+ sentences = glimpse_tokenizer(text)
78
+
79
+ summaries.append(sentences)
80
+
81
+ # add summaries to the huggingface dataset
82
+ dataset = dataset.map(lambda example: {"summary": summaries.pop(0)})
83
+
84
+ return dataset
85
+
86
+
87
+ def main():
88
+ args = parse_args()
89
+ # load the dataset
90
+ print("Loading dataset...")
91
+ dataset = prepare_dataset(args.dataset_path)
92
+
93
+ # limit the number of samples
94
+ if args.limit is not None:
95
+ _lim = min(args.limit, len(dataset))
96
+ dataset = dataset.select(range(_lim))
97
+
98
+ # generate summaries
99
+ dataset = evaluate_summarizer(
100
+ dataset,
101
+ )
102
+
103
+ df_dataset = dataset.to_pandas()
104
+ df_dataset = df_dataset.explode("summary")
105
+ df_dataset = df_dataset.reset_index()
106
+ # add an idx with the id of the summary for each example
107
+ df_dataset["id_candidate"] = df_dataset.groupby(["index"]).cumcount()
108
+
109
+ # save the dataset
110
+ # add unique date in name
111
+ now = datetime.datetime.now()
112
+ date = now.strftime("%Y-%m-%d-%H-%M-%S")
113
+ output_path = (
114
+ Path(args.output_dir)
115
+ / f"extractive_sentences-_-{args.dataset_path.stem}-_-none-_-{date}.csv"
116
+ )
117
+
118
+ # create output dir if it doesn't exist
119
+ if not output_path.parent.exists():
120
+ output_path.parent.mkdir(parents=True, exist_ok=True)
121
+
122
+ df_dataset.to_csv(output_path, index=False, encoding="utf-8")
123
+
124
+ # in case of scripted run, print the output path
125
+ if args.scripted_run: print(output_path)
126
+
127
+
128
+ if __name__ == "__main__":
129
+ main()
glimpse-ui/glimpse/glimpse/evaluate/Evaluate informativeness.ipynb ADDED
@@ -0,0 +1,258 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "38be7bab-8a42-49dd-8976-2755ee84edbe",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import pandas as pd\n",
13
+ "import numpy as np\n",
14
+ "from pathlib import Path\n",
15
+ "import pickle as pk\n",
16
+ "import nltk\n",
17
+ "import seaborn as sns\n",
18
+ "import matplotlib.pyplot as plt\n",
19
+ "\n",
20
+ "\n",
21
+ "\n",
22
+ "export_summaries_path = Path('output/summaries/methods_per_text')\n",
23
+ "export_summaries_path.mkdir(parents=True, exist_ok=True)"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "id": "1c24ddc1-978e-4a24-8fdb-e48b2848edd0",
30
+ "metadata": {
31
+ "tags": []
32
+ },
33
+ "outputs": [],
34
+ "source": [
35
+ "\n",
36
+ "dfs = []\n",
37
+ "for file in export_summaries_path.glob('*.csv'):\n",
38
+ " df = pd.read_csv(file)\n",
39
+ " generation_method, dataset = file.stem.split('-_-')[:2]\n",
40
+ " \n",
41
+ " df['metadata/Generation'] = generation_method\n",
42
+ " df['metadata/Dataset'] = dataset\n",
43
+ " \n",
44
+ " dfs.append(df)\n",
45
+ " \n",
46
+ "df = pd.concat(dfs)\n",
47
+ "\n",
48
+ "df = df.drop([c for c in df.columns if \"Unnamed:\" in c], axis=1)\n",
49
+ "\n",
50
+ "del dfs\n",
51
+ "\n",
52
+ "def replace_abstractive(x):\n",
53
+ " if \"abstractive\" in x:\n",
54
+ " return \"extractive_sentences\"\n",
55
+ " else:\n",
56
+ " return x\n",
57
+ "\n",
58
+ "df['metadata/Generation'] = df['metadata/Generation'].apply(replace_abstractive)\n",
59
+ "\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "582af03f-ac60-4644-88d7-28caf84bb552",
66
+ "metadata": {
67
+ "tags": []
68
+ },
69
+ "outputs": [],
70
+ "source": [
71
+ "df = df[~(df['Method'].str.contains('Lead')).fillna(False)]\n",
72
+ "\n"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "099fda3b-137c-4e3b-bb70-933ad183d378",
79
+ "metadata": {
80
+ "tags": []
81
+ },
82
+ "outputs": [],
83
+ "source": []
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "id": "7d9babd3-a409-409c-848e-db6a3f846d6f",
89
+ "metadata": {
90
+ "tags": []
91
+ },
92
+ "outputs": [],
93
+ "source": [
94
+ "ddf = df.copy()\n",
95
+ "ddf['proba_of_success'] = ddf['proba_of_success'].apply(np.exp)\n",
96
+ "\n",
97
+ "discriminativity = ddf.groupby(['metadata/Generation', 'metadata/reranking_model', 'Method'])[['proba_of_success', 'LM Perplexity']].agg(['mean']).droplevel(1, axis=1).sort_values('proba_of_success', ascending=False).reset_index()\n",
98
+ "discriminativity"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "id": "d1bffd58-0eed-4007-bcca-cd0a6dfdcd1d",
105
+ "metadata": {
106
+ "tags": []
107
+ },
108
+ "outputs": [],
109
+ "source": [
110
+ "\n",
111
+ "ddf = ddf.sort_values('proba_of_success', ascending=False)\n",
112
+ "sns.catplot(data=ddf, y='proba_of_success', x='Method', hue=\"metadata/Generation\", kind='bar', col='metadata/reranking_model')"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "6684f99b-68d8-4a4f-a1f7-8943f9755bb7",
119
+ "metadata": {
120
+ "tags": []
121
+ },
122
+ "outputs": [],
123
+ "source": [
124
+ "\n",
125
+ "fig, ax = plt.subplots(1, 1)\n",
126
+ "\n",
127
+ "paretto = get_pareto_points(discriminativity[['proba_of_success', 'LM Perplexity']].values)\n",
128
+ "\n",
129
+ "ax.plot(paretto[:, 1], paretto[:, 0], c='purple', linewidth=5, linestyle=\"--\", label=\"Pareto front\")\n",
130
+ "sns.scatterplot(data=discriminativity, y='proba_of_success', x='LM Perplexity', hue=\"Method\", s=200, alpha=0.8, style='metadata/reranking_model')\n",
131
+ "plt.xlim(-70, 0)\n",
132
+ "\n",
133
+ " \n",
134
+ "\n",
135
+ "def get_pareto_points(data):\n",
136
+ " # data : [N, 2]\n",
137
+ " \n",
138
+ " optima = []\n",
139
+ " for p in data:\n",
140
+ " x, y = p\n",
141
+ " if len([p2 for p2 in data if p2[0] > p[0] and p2[1] > p[1]]) == 0:\n",
142
+ " optima.append(p)\n",
143
+ " \n",
144
+ " return np.array(optima)\n",
145
+ "\n",
146
+ "\n",
147
+ "\n",
148
+ " \n",
149
+ "\n",
150
+ " \n",
151
+ " \n",
152
+ " "
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "id": "6f6d6a10-e13e-4d44-ace3-dae5de0f362c",
159
+ "metadata": {
160
+ "tags": []
161
+ },
162
+ "outputs": [],
163
+ "source": []
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": null,
168
+ "id": "6f2b4cc9-8584-463e-b130-6c48c85e2665",
169
+ "metadata": {
170
+ "tags": []
171
+ },
172
+ "outputs": [],
173
+ "source": []
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "id": "d49c5df4-23a4-473c-81a6-b8928ffdf8af",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "\n",
183
+ "ddf = df.copy().drop('level_0', axis=1)\n",
184
+ "ddf['proba_of_success'] = -ddf['proba_of_success']\n",
185
+ "ddf['LM Perplexity'] = -ddf['LM Perplexity']\n",
186
+ "ddf = ddf.reset_index()\n",
187
+ "\n",
188
+ "\n",
189
+ "\n",
190
+ "\n",
191
+ "sns.displot(data=ddf, x='LM Perplexity', y='proba_of_success', hue=\"Method\", kind='kde')\n",
192
+ "# plt.xscale('log')\n",
193
+ "#plt.yscale('log')\n",
194
+ "\n",
195
+ "plt.xlim(0, 100)\n",
196
+ "plt.ylim(-2, 3)\n",
197
+ "\n"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "id": "d1d1235a-7bff-4e51-ad43-20dd5d1a0734",
204
+ "metadata": {
205
+ "tags": []
206
+ },
207
+ "outputs": [],
208
+ "source": [
209
+ "sns.scatterplot(data=df, x='LM Perplexity', y='proba_of_success', hue=\"Method\")\n"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "id": "faf6ec08-feca-4106-bf51-7a3ffb438267",
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": []
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "id": "e0805bfe-660f-4d1f-a28a-cf09220b5da3",
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": []
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "5235f917-81b5-40ca-9c62-cb63a8aaaffb",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": []
235
+ }
236
+ ],
237
+ "metadata": {
238
+ "kernelspec": {
239
+ "display_name": "pytorch-gpu-2.0.0_py3.10.9",
240
+ "language": "python",
241
+ "name": "module-conda-env-pytorch-gpu-2.0.0_py3.10.9"
242
+ },
243
+ "language_info": {
244
+ "codemirror_mode": {
245
+ "name": "ipython",
246
+ "version": 3
247
+ },
248
+ "file_extension": ".py",
249
+ "mimetype": "text/x-python",
250
+ "name": "python",
251
+ "nbconvert_exporter": "python",
252
+ "pygments_lexer": "ipython3",
253
+ "version": "3.10.9"
254
+ }
255
+ },
256
+ "nbformat": 4,
257
+ "nbformat_minor": 5
258
+ }
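The notebook above, like the rest of the repo, parses run metadata out of filenames joined with a '-_-' separator. A minimal sketch of that convention (the filename here is illustrative, not a real output file):

from pathlib import Path

file = Path("extractive_sentences-_-all_reviews_2017-_-none-_-2025-05-20-20-22-18.csv")
generation_method, dataset = file.stem.split("-_-")[:2]
# -> "extractive_sentences", "all_reviews_2017"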
glimpse-ui/glimpse/glimpse/evaluate/evaluate_bartbert_metrics.py ADDED
@@ -0,0 +1,110 @@
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ from bert_score import BERTScorer
7
+
8
+ def sanitize_model_name(model_name: str) -> str:
9
+ """
10
+ Sanitize the model name to be used as a folder name.
11
+ @param model_name: The model name
12
+ @return: The sanitized model name
13
+ """
14
+ return model_name.replace("/", "_")
15
+
16
+ # logging.basicConfig(stream=stdout, level=logging.)
17
+ def parse_args():
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--summaries", type=Path, default="")
20
+
21
+ # device
22
+ parser.add_argument("--device", type=str, default="cuda")
23
+
24
+ args = parser.parse_args()
25
+ return args
26
+
27
+
28
+
29
+ def parse_summaries(path: Path):
30
+ """
31
+ :return: a pandas dataframe with at least the columns 'gold' and 'summary'
32
+ """
33
+ # read csv file
34
+
35
+ df = pd.read_csv(path).dropna()
36
+
37
+
38
+ # check if the csv file has the correct columns
39
+ if not all([col in df.columns for col in ["gold", "summary"]]):
40
+ raise ValueError("The csv file must have the columns 'text' and 'summary'.")
41
+
42
+ return df
43
+
44
+
45
+ def evaluate_bartbert(df, device="cuda"):
46
+ # make a list of the tuples (text, summary)
47
+
48
+ # texts = df.text.tolist()
49
+ texts = df.gold.tolist()
50
+ summaries = df.summary.tolist()
51
+
52
+ scorer = BERTScorer(lang="en", rescale_with_baseline=True, device=device)
53
+
54
+ # normalize whitespace, then score all (summary, text) pairs in one batched call
+ texts = [text.replace("\n", " ") for text in texts]
+ summaries = [summary.replace("\n", " ") for summary in summaries]
+
+ P, R, F1 = scorer.score(summaries, texts)
+ metrics = {'BERTScore': F1.tolist()}
62
+
63
+ # compute the mean of the metrics
64
+ # metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
65
+
66
+ return metrics
67
+
68
+
69
+ def main():
70
+ args = parse_args()
71
+
72
+ path = args.summaries
73
+ path.parent.mkdir(parents=True, exist_ok=True)
74
+
75
+ # load the summaries
76
+ df = parse_summaries(args.summaries)
77
+
78
+ metrics = evaluate_bartbert(df)
79
+
80
+ # make a dataframe with the metric
81
+ df = pd.DataFrame(metrics)
82
+
83
+ # Add the model name in the metrics names
84
+ df = df.add_prefix(f"common/")
85
+
86
+ # save the dataframe
87
+
88
+ # check if exists already, if it does load it and add the new columns
89
+
90
+ print(df)
91
+
92
+ if path.exists():
93
+ df_old = pd.read_csv(path, index_col=0)
94
+
95
+ # create the colums if they do not exist
96
+ for col in df.columns:
97
+ if col not in df_old.columns:
98
+ df_old[col] = float("nan")
99
+
100
+ # add entry to the dataframe
101
+ for col in df.columns:
102
+ df_old[col] = df[col]
103
+
104
+ df = df_old
105
+
106
+ df.to_csv(path)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ main()
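A minimal standalone sketch of the BERTScorer call used above. With rescale_with_baseline=True the scores are linearly rescaled against a random-pairing baseline, so values near 0 mean unrelated text rather than the raw ~0.85 cosine floor:

from bert_score import BERTScorer

scorer = BERTScorer(lang="en", rescale_with_baseline=True)
# score(candidates, references) returns per-pair precision, recall, F1 tensors
P, R, F1 = scorer.score(["the cat sat on the mat"], ["a cat was sitting on the mat"])
print(F1.item())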
glimpse-ui/glimpse/glimpse/evaluate/evaluate_common_metrics_samples.py ADDED
@@ -0,0 +1,122 @@
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+ from rouge_score import rouge_scorer
6
+
7
+
8
+ def sanitize_model_name(model_name: str) -> str:
9
+ """
10
+ Sanitize the model name to be used as a folder name.
11
+ @param model_name: The model name
12
+ @return: The sanitized model name
13
+ """
14
+ return model_name.replace("/", "_")
15
+
16
+
17
+ # logging.basicConfig(stream=stdout, level=logging.)
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--summaries", type=Path, default="")
21
+
22
+ args = parser.parse_args()
23
+ return args
24
+
25
+
26
+
27
+ def parse_summaries(path: Path):
28
+ """
29
+ :return: a pandas dataframe with at least the columns 'gold' and 'summary'
30
+ """
31
+ # read csv file
32
+
33
+ df = pd.read_csv(path).dropna()
34
+
35
+ # check if the csv file has the correct columns
36
+ if not all([col in df.columns for col in ["gold", "summary"]]):
37
+ raise ValueError("The csv file must have the columns 'text' and 'summary'.")
38
+
39
+ return df
40
+
41
+
42
+ def evaluate_rouge(
43
+ df,
44
+ ):
45
+ # make a list of the tuples (text, summary)
46
+
47
+ texts = df.gold.tolist()
48
+ summaries = df.summary.tolist()
49
+
50
+ # rouges
51
+ metrics = {"rouge1": [], "rouge2": [], "rougeL": [], "rougeLsum": []}
52
+
53
+ rouges = rouge_scorer.RougeScorer(
54
+ ["rouge1", "rouge2", "rougeL", "rougeLsum"], use_stemmer=True
55
+ )
56
+
57
+ metrics["rouge1"].extend(
58
+ [
59
+ rouges.score(summary, text)["rouge1"].fmeasure
60
+ for summary, text in zip(summaries, texts)
61
+ ]
62
+ )
63
+ metrics["rouge2"].extend(
64
+ [
65
+ rouges.score(summary, text)["rouge2"].fmeasure
66
+ for summary, text in zip(summaries, texts)
67
+ ]
68
+ )
69
+ metrics["rougeL"].extend(
70
+ [
71
+ rouges.score(summary, text)["rougeL"].fmeasure
72
+ for summary, text in zip(summaries, texts)
73
+ ]
74
+ )
75
+ metrics["rougeLsum"].extend(
76
+ [
77
+ rouges.score(summary, text)["rougeLsum"].fmeasure
78
+ for summary, text in zip(summaries, texts)
79
+ ]
80
+ )
81
+
82
+ # compute the mean of the metrics
83
+ # metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
84
+
85
+ return metrics
86
+
87
+
88
+ def main():
89
+ args = parse_args()
90
+
91
+ # load the summaries
92
+ df = parse_summaries(args.summaries)
93
+
94
+ metrics = evaluate_rouge(df)
95
+
96
+
97
+ # # add index to the metrics
98
+ # metrics["index"] = [i for i in range(len(df))]
99
+
100
+ df = pd.DataFrame.from_dict(metrics)
101
+ df = df.add_prefix(f"common/")
102
+
103
+ # merge the metrics with the summaries
104
+
105
+ if args.summaries.exists():
106
+ df_old = parse_summaries(args.summaries)
107
+
108
+ for col in df.columns:
109
+ if col not in df_old.columns:
110
+ df_old[col] = float("nan")
111
+
112
+ # add entry to the dataframe
113
+ for col in df.columns:
114
+ df_old[col] = df[col]
115
+
116
+ df = df_old
117
+
118
+ df.to_csv(args.summaries, index=False)
119
+
120
+
121
+ if __name__ == "__main__":
122
+ main()
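A minimal sketch of the rouge_score API used above. RougeScorer.score takes (target, prediction); since exchanging the arguments only swaps precision and recall, the F-measure reported by the script is unaffected by its (summary, gold) ordering:

from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
scores = scorer.score("the cat sat on the mat", "a cat sat on a mat")
print(scores["rouge1"].fmeasure, scores["rougeL"].fmeasure)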
glimpse-ui/glimpse/glimpse/evaluate/evaluate_seahorse_metrics_samples.py ADDED
@@ -0,0 +1,150 @@
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
+
11
+ map_questionnumber_to_question = {
12
+ "question1": "SHMetric/Comprehensible",
13
+ "question2": "SHMetric/Repetition",
14
+ "question3": "SHMetric/Grammar",
15
+ "question4": "SHMetric/Attribution",
16
+ "question5": "SHMetric/Main ideas",
17
+ "question6": "SHMetric/Conciseness",
18
+ }
19
+
20
+ def sanitize_model_name(model_name: str) -> str:
21
+ """
22
+ Sanitize the model name to be used as a folder name.
23
+ @param model_name: The model name
24
+ @return: The sanitized model name
25
+ """
26
+ return model_name.replace("/", "_")
27
+
28
+ # logging.basicConfig(stream=stdout, level=logging.)
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument(
32
+ "--question",
33
+ type=str,
34
+ default="repetition",
35
+ )
36
+ parser.add_argument("--summaries", type=Path, default="")
37
+ parser.add_argument("--select", type=str, default="*")
38
+ parser.add_argument("--batch_size", type=int, default=16)
39
+ parser.add_argument("--device", type=str, default="cuda")
40
+
41
+
42
+ args = parser.parse_args()
43
+ return args
44
+
45
+ def parse_summaries(path: Path):
46
+ """
47
+ :return: a pandas dataframe with at least the columns 'text' and 'summary'
48
+ """
49
+ # read csv file
50
+
51
+ df = pd.read_csv(path).dropna()
52
+
53
+ # check if the csv file has the correct columns
54
+ if not all([col in df.columns for col in ["text", "summary"]]):
55
+ raise ValueError("The csv file must have the columns 'text' and 'summary'.")
56
+
57
+ return df
58
+
59
+
60
+ def evaluate_classification_task(model, tokenizer, question, df, batch_size):
61
+
62
+ texts = df.text.tolist()
63
+ summaries = df.summary.tolist()
64
+
65
+ template = "premise: {premise} hypothesis: {hypothesis}"
66
+ ds = [template.format(premise=text[:20*1024], hypothesis=summary) for text, summary in zip(texts, summaries)]
67
+
68
+
69
+ eval_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size)
70
+
71
+ metrics = {f"{question}/proba_1": [], f"{question}/proba_0": [], f"{question}/guess": []}
72
+
73
+ with torch.no_grad():
74
+ for batch in tqdm(eval_loader):
75
+ # tokenize the batch
76
+ inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
77
+ # move the inputs to the device
78
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
79
+
80
+ N_inputs = inputs["input_ids"].shape[0]
81
+ # make decoder inputs to be <pad>
82
+ decoder_input_ids = torch.full((N_inputs, 1), tokenizer.pad_token_id, dtype=torch.long, device=model.device)
83
+
84
+ outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
85
+ logits = outputs.logits
86
+ # retrieve logits for the last token and the scores for 0 and 1
87
+ logits = logits[:, -1, [497, 333]]
88
+
89
+ # compute the probabilities
90
+ probs = F.softmax(logits, dim=-1)
91
+
92
+ # compute the guess
93
+ guess = probs.argmax(dim=-1)
94
+
95
+ # append the metrics
96
+ metrics[f"{question}/proba_1"].extend(probs[:, 1].tolist())
97
+ metrics[f"{question}/proba_0"].extend(probs[:, 0].tolist())
98
+ metrics[f"{question}/guess"].extend(guess.tolist())
99
+
100
+ # average the metrics
101
+
102
+ # metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
103
+
104
+ return metrics
105
+
106
+ def main():
107
+ args = parse_args()
108
+
109
+ model_name = f"google/seahorse-large-q{args.question}"
110
+ question = map_questionnumber_to_question[f"question{args.question}"]
111
+
112
+ # load the model
113
+ # load in float16 to save memory
114
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.float16)
115
+
116
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
117
+
118
+ df = parse_summaries(args.summaries)
119
+
120
+ metrics = evaluate_classification_task(model, tokenizer, question, df, args.batch_size)
121
+
122
+ # make a dataframe with the metric
123
+ df_metrics = pd.DataFrame(metrics)
124
+
125
+ # merge the metrics with the summaries
126
+ df = parse_summaries(args.summaries)
127
+ df = pd.concat([df, df_metrics], axis=1)
128
+
129
+ path = Path(args.summaries)
130
+
131
+ if path.exists():
132
+ df_old = pd.read_csv(path, index_col=0)
133
+
134
+ # create the colums if they do not exist
135
+ for col in df.columns:
136
+ if col not in df_old.columns:
137
+ df_old[col] = float("nan")
138
+
139
+ # add entry to the dataframe
140
+ for col in df.columns:
141
+ df_old[col] = df[col]
142
+
143
+ df = df_old
144
+
145
+ # save the dataframe
146
+ df.to_csv(args.summaries)
147
+
148
+
149
+ if __name__ == "__main__":
150
+ main()
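The SEAHORSE checkpoints above are used as seq2seq classifiers: the script feeds a single <pad> decoder token and reads the first-step logits at the vocabulary ids of the "0" and "1" answer tokens (497 and 333 for these checkpoints, per the indexing in the script). A minimal sketch of that readout, with a dummy tensor standing in for the model output:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 1, 32128)          # (batch, decoder_len, vocab); dummy values
class_logits = logits[:, -1, [497, 333]]   # scores of the "0" and "1" tokens
probs = F.softmax(class_logits, dim=-1)    # (batch, 2)
proba_1 = probs[:, 1]                      # probability that the model answers "1"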
glimpse-ui/glimpse/glimpse/src/beam_rsa_decoding.py ADDED
@@ -0,0 +1,207 @@
1
+ import argparse
2
+ import datetime
3
+ from pathlib import Path
4
+
5
+ import pandas as pd
6
+ from datasets import Dataset
7
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8
+
9
+ from rsasumm.beam_search import RSAContextualDecoding
10
+ from tqdm import tqdm
11
+
12
+ GENERATION_CONFIGS = {
13
+ "top_p_sampling": {
14
+ "max_new_tokens": 200,
15
+ "do_sample": True,
16
+ "top_p": 0.95,
17
+ "temperature": 1.0,
18
+ "num_return_sequences": 8,
19
+ "num_beams": 1,
20
+ # "num_beam_groups" : 4,
21
+ },
22
+ **{
23
+ f"sampling_topp_{str(topp).replace('.', '')}": {
24
+ "max_new_tokens": 200,
25
+ "do_sample": True,
26
+ "num_return_sequences": 8,
27
+ "top_p": 0.95,
28
+ }
29
+ for topp in [0.5, 0.8, 0.95, 0.99]
30
+ },
31
+ }
32
+
33
+ # fold the shared base config into every named decoding config
34
+ for key, value in GENERATION_CONFIGS.items():
35
+ GENERATION_CONFIGS[key] = {
36
+ # "max_length": 2048,
37
+ "min_length": 0,
38
+ "early_stopping": True,
39
+ **value,
40
+ }
41
+
42
+
43
+ def parse_args():
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument("--model_name", type=str, default="facebook/bart-large-cnn")
46
+ parser.add_argument("--dataset_name", type=str, default="amazon")
47
+ parser.add_argument("--dataset_path", type=str, default=None)
48
+ parser.add_argument(
49
+ "--decoding_config",
50
+ type=str,
51
+ default="top_p_sampling",
52
+ choices=GENERATION_CONFIGS.keys(),
53
+ )
54
+
55
+ parser.add_argument("--batch_size", type=int, default=16)
56
+ parser.add_argument("--device", type=str, default="cuda")
57
+
58
+ parser.add_argument("--output_dir", type=str, default="output")
59
+
60
+ # limit the number of samples to generate
61
+ parser.add_argument("--limit", type=int, default=None)
62
+
63
+ args = parser.parse_args()
64
+
65
+ return args
66
+
67
+
68
+ def prepare_dataset(dataset_name, dataset_path=None) -> Dataset:
69
+ dataset_path = Path(dataset_path)
70
+ if dataset_name == "amazon":
71
+ dataset = pd.read_csv(dataset_path / "amazon_test.csv")
72
+ elif dataset_name == "space":
73
+ dataset = pd.read_csv(dataset_path / "space.csv")
74
+ elif dataset_name == "yelp":
75
+ dataset = pd.read_csv(dataset_path / "yelp_test.csv")
76
+ elif dataset_name == "reviews":
77
+ dataset = pd.read_csv(dataset_path / "test_metareviews.csv")
78
+ elif dataset_name == "multi_news":
79
+ dataset = pd.read_csv(dataset_path / "multi_news.csv")
80
+ else:
81
+ raise ValueError(f"Unknown dataset {dataset_name}")
82
+
83
+ # make a dataset from the dataframe
84
+ dataset = Dataset.from_pandas(dataset)
85
+
86
+ return dataset
87
+
88
+
89
+ def evaluate_summarizer(model, tokenizer, dataset: Dataset, decoding_config) -> Dataset:
90
+ """
91
+ @param model: The model used to generate the summaries
92
+ @param tokenizer: The tokenizer used to tokenize the text and the summary
93
+ @param dataset: A dataset with the text
94
+ @param decoding_config: Dictionary with the decoding config (currently unused; the RSA generate call below hardcodes its sampling parameters)
95
+ @return: The same dataset with the summaries added
96
+ """
97
+
98
+ rsa = RSAContextualDecoding(model, tokenizer, device=model.device)
99
+
100
+ # generate summaries
101
+ summaries = []
102
+
103
+ print("Generating summaries...")
104
+
105
+ for id, batch in tqdm(dataset.to_pandas().groupby("id")):
106
+ text = batch["text"].tolist()
107
+
108
+ inputs = tokenizer(
109
+ text,
110
+ max_length=1024,
111
+ padding="max_length",
112
+ truncation=True,
113
+ return_tensors="pt",
114
+ )
115
+ batch_size = inputs["input_ids"].shape[0]
116
+
117
+ # move the inputs to the model's device once, before decoding each target in turn
+ inputs = {key: value.to(model.device) for key, value in inputs.items()}
+
+ for k in tqdm(range(len(text))):
120
+
121
+ output = rsa.generate(
122
+ target_id=k,
123
+ source_texts_ids=inputs["input_ids"],
124
+ source_text_attention_mask=inputs["attention_mask"],
125
+ max_length=50,
126
+ top_p=0.95,
127
+ do_sample=True,
128
+ rationality=8.0,
129
+ temperature=1.0,
130
+ process_logits_before_rsa=True,
131
+ )
132
+ # output : (batch_size * num_return_sequences, max_length)
133
+ outputs = output[0]
134
+ summaries.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
135
+
136
+ # decode summaries
137
+
138
+ # add summaries to the huggingface dataset
139
+ dataset = dataset.add_column("summary", summaries)
140
+
141
+ return dataset
142
+
143
+
144
+ def sanitize_model_name(model_name: str) -> str:
145
+ """
146
+ Sanitize the model name to be used as a folder name.
147
+ @param model_name: The model name
148
+ @return: The sanitized model name
149
+ """
150
+ return model_name.replace("/", "_")
151
+
152
+
153
+ def main():
154
+ args = parse_args()
155
+
156
+ # load the model
157
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
158
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
159
+
160
+ tokenizer.pad_token = tokenizer.unk_token
161
+ tokenizer.pad_token_id = tokenizer.unk_token_id
162
+
163
+ # move model to device
164
+ model = model.to(args.device)
165
+
166
+ # load the dataset
167
+ print("Loading dataset...")
168
+ dataset = prepare_dataset(args.dataset_name, args.dataset_path)
169
+
170
+ # limit the number of samples
171
+ if args.limit is not None:
172
+ _lim = min(args.limit, len(dataset))
173
+ dataset = dataset.select(range(_lim))
174
+
175
+ # generate summaries
176
+ dataset = evaluate_summarizer(
177
+ model,
178
+ tokenizer,
179
+ dataset,
180
+ GENERATION_CONFIGS[args.decoding_config],
181
+ )
182
+
183
+ df_dataset = dataset.to_pandas()
184
+ df_dataset = df_dataset.explode("summary")
185
+ df_dataset = df_dataset.reset_index()
186
+ # add an idx with the id of the summary for each example
187
+ # df_dataset["id_candidate"] = df_dataset.groupby(["index"]).cumcount()
188
+
189
+ # save the dataset
190
+ # add unique date in name
191
+ now = datetime.datetime.now()
192
+ date = now.strftime("%Y-%m-%d-%H-%M-%S")
193
+ model_name = sanitize_model_name(args.model_name)
194
+ output_path = (
195
+ Path(args.output_dir)
196
+ / f"{model_name}-_-{args.dataset_name}-_-{args.decoding_config}-_-{date}.csv"
197
+ )
198
+
199
+ # create output dir if it doesn't exist
200
+ if not output_path.parent.exists():
201
+ output_path.parent.mkdir(parents=True, exist_ok=True)
202
+
203
+ df_dataset.to_csv(output_path, index=False, encoding="utf-8")
204
+
205
+
206
+ if __name__ == "__main__":
207
+ main()
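The loop at the top of the script folds shared defaults into every named decoding config. A minimal sketch of that dict-merge pattern; per-config entries are unpacked last, so they override the defaults:

defaults = {"min_length": 0, "early_stopping": True}
configs = {"top_p_sampling": {"do_sample": True, "top_p": 0.95}}

# entries already in a config win because they are unpacked after the defaults
configs = {name: {**defaults, **cfg} for name, cfg in configs.items()}
# configs["top_p_sampling"] -> {"min_length": 0, "early_stopping": True, "do_sample": True, "top_p": 0.95}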
glimpse-ui/glimpse/glimpse/src/compute_rsa.py ADDED
@@ -0,0 +1,137 @@
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PegasusTokenizer
5
+ import argparse
6
+ from tqdm import tqdm
7
+
8
+ from pickle import dump
9
+
10
+ import sys, os.path
11
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
12
+
13
+ from rsasumm.rsa_reranker import RSAReranking
14
+
15
+
16
+ DESC = """
17
+ Compute the RSA matrices for all the set of multi-document samples and dump these along with additional information in a pickle file.
18
+ """
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("--model_name", type=str, default="facebook/bart-large-cnn")
23
+ parser.add_argument("--summaries", type=Path, default="glimpse/data/candidates/extractive_sentences-_-all_reviews_2017-_-none-_-2025-05-20-20-22-18.csv")
24
+ parser.add_argument("--output_dir", type=str, default="glimpse/output")
25
+
26
+ parser.add_argument("--filter", type=str, default=None)
27
+
28
+ # if run in a scripted way, the output path will be printed
29
+ parser.add_argument("--scripted-run", action=argparse.BooleanOptionalAction, default=False)
30
+
31
+ parser.add_argument("--device", type=str, default="cuda")
32
+
33
+ return parser.parse_args()
34
+
35
+
36
+ def parse_summaries(path: Path) -> pd.DataFrame:
37
+
38
+ try:
39
+ summaries = pd.read_csv(path)
40
+ except Exception as e:
+ raise ValueError(f"Could not read summaries file {path}") from e
42
+
43
+ # check if the dataframe has the right columns
44
+ if not all(
45
+ col in summaries.columns for col in ["index", "id", "text", "gold", "summary", "id_candidate"]
46
+ ):
47
+ raise ValueError(
48
+ "The dataframe must have columns ['index', 'id', 'text', 'gold', 'summary', 'id_candidate']"
49
+ )
50
+
51
+ return summaries
52
+
53
+
54
+ def compute_rsa(summaries: pd.DataFrame, model, tokenizer, device):
55
+ results = []
56
+ for name, group in tqdm(summaries.groupby(["id"])):
57
+ rsa_reranker = RSAReranking(
58
+ model,
59
+ tokenizer,
60
+ device=device,
61
+ candidates=group.summary.unique().tolist(),
62
+ source_texts=group.text.unique().tolist(),
63
+ rationality=1,
64
+ )
65
+ (
66
+ best_rsa,
67
+ best_base,
68
+ speaker_df,
69
+ listener_df,
70
+ initial_listener,
71
+ language_model_proba_df,
72
+ initial_consensuality_scores,
73
+ consensuality_scores,
74
+ ) = rsa_reranker.rerank(t=1)
75
+
76
+ gold = group['gold'].tolist()[0]
77
+
78
+ results.append(
79
+ {
80
+ "id": name,
81
+ "best_rsa": best_rsa, # best speaker score
82
+ "best_base": best_base, # naive baseline
83
+ "speaker_df": speaker_df, # all speaker results
84
+ "listener_df": listener_df, # all listener results (chances of guessing correctly)
85
+ "initial_listener": initial_listener,
86
+ "language_model_proba_df": language_model_proba_df,
87
+ "initial_consensuality_scores": initial_consensuality_scores,
88
+ "consensuality_scores": consensuality_scores, # consensuality scores
89
+ "gold": gold,
90
+ "rationality": 1, # hyperparameter
91
+ "text_candidates" : group
92
+ }
93
+ )
94
+
95
+ return results
96
+
97
+
98
+ def main():
99
+ args = parse_args()
100
+
101
+ if args.filter is not None:
102
+ if args.filter not in args.summaries.stem:
103
+ return
104
+
105
+ # load the model and the tokenizer
106
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
107
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
108
+
109
+ model = model.to(args.device)
110
+
111
+ # load the summaries
112
+ summaries = parse_summaries(args.summaries)
113
+
114
+ # rerank the summaries
115
+ results = compute_rsa(summaries, model, tokenizer, args.device)
116
+ results = {"results": results}
117
+
118
+ results["metadata/reranking_model"] = args.model_name
119
+ results["metadata/rsa_iterations"] = 1
120
+
121
+ # save the summaries
122
+ # make the output directory if it does not exist
123
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
124
+ output_path = Path(args.output_dir) / f"{args.summaries.stem}-_-r3-_-rsa_reranked-{args.model_name.replace('/', '-')}.pk"
125
+ output_path_base = (
126
+ Path(args.output_dir) / f"{args.summaries.stem}-_-base_reranked.pk"
127
+ )
128
+
129
+ with open(output_path, "wb") as f:
130
+ dump(results, f)
131
+
132
+ # in case of scripted run, print the output path
133
+ if args.scripted_run: print(output_path)
134
+
135
+
136
+ if __name__ == "__main__":
137
+ main()
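A minimal sketch of reading the pickle that compute_rsa.py writes. The path is illustrative, and the snippet assumes consensuality_scores is a pandas Series indexed by candidate summary, as the rerank() outputs suggest:

from pickle import load

with open("glimpse/output/run-_-r3-_-rsa_reranked-facebook-bart-large-cnn.pk", "rb") as f:
    results = load(f)

first = results["results"][0]
print(first["id"], first["gold"][:80])
# assumed to be a pandas Series: highest-consensus candidates first
print(first["consensuality_scores"].sort_values(ascending=False).head())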
glimpse-ui/glimpse/glimpse/src/rsa_merge_into_single.py ADDED
@@ -0,0 +1,52 @@
1
+ import argparse
2
+ import datetime
3
+ from pathlib import Path
4
+
5
+ import pandas as pd
6
+ from datasets import Dataset
7
+ from tqdm import tqdm
8
+
9
+
10
+ def parse_args():
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument("--summaries", type=Path)
13
+
14
+ parser.add_argument("--output_dir", type=str, default="output")
15
+
16
+ # limit the number of samples to generate
17
+ parser.add_argument("--limit", type=int, default=None)
18
+
19
+ args = parser.parse_args()
20
+
21
+ return args
22
+
23
+
24
+ def main():
25
+ args = parse_args()
26
+
27
+ path = Path(args.summaries)
28
+
29
+ for file in path.glob("*.csv"):
30
+ model_name, dataset, decoding_config, date, reranking_type = file.stem.split('-_-')
31
+ df = pd.read_csv(file)
32
+ df = df.drop(["Unnamed: 0.1", "Unnamed: 0", ], axis=1)
33
+
34
+ # df = df[['id', 'id_text', 'text', 'summary', 'gold']]
35
+
36
+
37
+ merged_summaries = df.groupby("id").agg({"summary": " ".join}).reset_index()
38
+
39
+ # add gold and text
40
+
41
+ merged_summaries = merged_summaries.merge(df[["id", "gold", "text"]].drop_duplicates(subset="id"), on="id")  # dedupe so the merge stays one row per id
42
+
+ # the original script computed `merged_summaries` but never wrote it out;
+ # this save is a minimal completion (the '-_-merged' filename is an assumption
+ # following the repo's naming convention)
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ output_path = Path(args.output_dir) / f"{file.stem}-_-merged.csv"
+ merged_summaries.to_csv(output_path, index=False)
+
51
+ if __name__ == "__main__":
52
+ main()
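A toy illustration (data invented for the example) of the groupby-concatenation used above to fuse all candidate summaries of one document set into a single string:

import pandas as pd

df = pd.DataFrame({"id": [1, 1, 2], "summary": ["A.", "B.", "C."]})
merged = df.groupby("id").agg({"summary": " ".join}).reset_index()
# -> id=1: "A. B.", id=2: "C."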
glimpse-ui/glimpse/glimpse/src/rsa_reranking.py ADDED
@@ -0,0 +1,127 @@
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ import argparse
6
+ from tqdm import tqdm
7
+
8
+
9
+ from rsasumm.rsa_reranker import RSAReranking
10
+
11
+ def parse_args():
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--model_name", type=str, default="facebook/bart-large-cnn")
14
+ parser.add_argument("--summaries", type=Path, default="")
15
+ parser.add_argument("--output_dir", type=str, default="output")
16
+
17
+ parser.add_argument("--filter", type=str, default=None)
18
+
19
+ parser.add_argument("--device", type=str, default="cuda")
20
+
21
+ return parser.parse_args()
22
+
23
+
24
+ def parse_summaries(path : Path) -> pd.DataFrame:
25
+ summaries = pd.read_csv(path)
26
+
27
+ # check if the dataframe has the right columns
28
+ if not all(col in summaries.columns for col in ["id", "text", "id_candidate", "summary"]):
29
+ raise ValueError("The dataframe must have columns ['id', 'text', 'id_candidate', 'summary']")
30
+
31
+ return summaries
32
+
33
+ def reranking_rsa(summaries : pd.DataFrame, model, tokenizer, device):
34
+
35
+ best_summaries = []
36
+ best_bases = []
37
+ for name, group in tqdm(summaries.groupby(["id"])):
38
+ rsa_reranker = RSAReranking(model, tokenizer, device, group.summary.unique().tolist(), group.text.unique().tolist())
39
+ # rerank() also returns the two consensuality-score frames (cf. compute_rsa.py); they are unused here
+ best_rsa, best_base, speaker_df, listener_df, initial_listener, language_model_proba_df, _, _ = rsa_reranker.rerank(t=3)
40
+
41
+ group = group.set_index("summary")
42
+ group_lines = group.loc[best_rsa]
43
+ group_lines['speaker_proba'] = 0
44
+ group_lines['listener_proba'] = 0
45
+ group_lines['language_model_proba'] = 0
46
+ group_lines['initial_listener_proba'] = 0
47
+
48
+ group_lines = group_lines.reset_index()
49
+
50
+ for i, (idx, line) in enumerate(group_lines.iterrows()):
51
+ summary = line['summary']
52
+ text = line['text']
53
+
54
+ # label-based assignment: the chained form df[col].loc[i] = v may write to a copy
+ group_lines.loc[i, 'speaker_proba'] = speaker_df.loc[text, summary]
+ group_lines.loc[i, 'listener_proba'] = listener_df.loc[text, summary]
+ group_lines.loc[i, 'language_model_proba'] = language_model_proba_df.loc[text, summary]
+ group_lines.loc[i, 'initial_listener_proba'] = initial_listener.loc[text, summary]
58
+
59
+
60
+ group_lines["id"] = name
61
+ best_summaries.append(group_lines)
62
+
63
+ best_base_lines = group.loc[best_base]
64
+ best_base_lines = best_base_lines.reset_index()
65
+
66
+ best_base_lines['speaker_proba'] = 0
67
+ best_base_lines['listener_proba'] = 0
68
+ best_base_lines['language_model_proba'] = 0
69
+ best_base_lines['initial_listener_proba'] = 0
70
+
71
+ for i, (idx, line) in enumerate(best_base_lines.iterrows()):
72
+ summary = line['summary']
73
+ text = line['text']
74
+
75
+ best_base_lines.loc[i, 'speaker_proba'] = speaker_df.loc[text, summary]
+ best_base_lines.loc[i, 'listener_proba'] = listener_df.loc[text, summary]
+ best_base_lines.loc[i, 'language_model_proba'] = language_model_proba_df.loc[text, summary]
+ best_base_lines.loc[i, 'initial_listener_proba'] = initial_listener.loc[text, summary]
79
+
80
+
81
+ best_base_lines["id"] = name
82
+ best_bases.append(best_base_lines)
83
+
84
+ best_summaries = pd.concat(best_summaries)
85
+ best_bases = pd.concat(best_bases)
86
+
87
+ return best_summaries, best_bases
88
+
89
+
90
+ def main():
91
+ args = parse_args()
92
+
93
+ if args.filter is not None:
94
+ if args.filter not in args.summaries.stem:
95
+ return
96
+
97
+ # load the model and the tokenizer
98
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
99
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
100
+
101
+ model = model.to(args.device)
102
+
103
+ # load the summaries
104
+ summaries = parse_summaries(args.summaries)
105
+
106
+ # rerank the summaries
107
+ best_summaries, best_base = reranking_rsa(summaries, model, tokenizer, device=args.device)
108
+
109
+ best_summaries['metadata/reranking_model'] = args.model_name
110
+ best_summaries['metadata/rsa_iterations'] = 3
111
+
112
+ best_base['metadata/reranking_model'] = args.model_name
113
+ best_base['metadata/rsa_iterations'] = 3
114
+
115
+
116
+ # save the summaries
117
+ # make the output directory if it does not exist
118
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
119
+ output_path = Path(args.output_dir) / f"{args.summaries.stem}-_-rsa_reranked.csv"
120
+ output_path_base = Path(args.output_dir) / f"{args.summaries.stem}-_-base_reranked.csv"
121
+
122
+ best_summaries.to_csv(output_path)
123
+ best_base.to_csv(output_path_base)
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
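For reference, a minimal sketch (toy frame) of why the probability columns above are filled with df.loc[row, col] = value rather than the chained df[col].loc[row] = value, which pandas may apply to a temporary copy:

import pandas as pd

df = pd.DataFrame({"proba": [0.0, 0.0]})
df.loc[0, "proba"] = 0.7       # writes into df itself
# df["proba"].loc[0] = 0.7     # chained indexing: the write may land on a copy and be lost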
glimpse-ui/glimpse/mds/Single summaries expes.ipynb ADDED
@@ -0,0 +1,587 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "64c5b118-5a32-4220-89f2-4e3ccd7a28d2",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import matplotlib as mpl\n",
13
+ "# Use the pgf backend (must be set before pyplot imported)\n",
14
+ "mpl.use('pgf')\n",
15
+ "\n",
16
+ "import pandas as pd\n",
17
+ "import numpy as np\n",
18
+ "import matplotlib.pyplot as plt\n",
19
+ "import seaborn as sns\n",
20
+ "import re\n",
21
+ "from pathlib import Path"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "outputs": [],
27
+ "source": [],
28
+ "metadata": {
29
+ "collapsed": false
30
+ },
31
+ "id": "7893116c24574642"
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "outputs": [],
36
+ "source": [
37
+ "# use pgf backend\n",
38
+ "plt.style.use('seaborn-paper')\n"
39
+ ],
40
+ "metadata": {
41
+ "collapsed": false
42
+ },
43
+ "id": "f3dc93e0b2eb9894"
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "3806928d-0624-4d9f-905f-3bf41b9725f1",
49
+ "metadata": {
50
+ "tags": []
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "sumy_individual_path = Path('output/summaries/sumy_individual/')\n",
55
+ "ours_individual_path = Path('output/summaries/methods_reviews_individual/')\n",
56
+ "\n",
57
+ "TABLE_PATH = Path(\"../../../EMIRR/papers/rsa_multi_document/tables/\")\n",
58
+ "FIGURE = Path(\"../../../EMIRR/papers/rsa_multi_document/figures/\")\n",
59
+ "\n",
60
+ "# make sure the folder exists\n",
61
+ "TABLE_PATH.mkdir(parents=True, exist_ok=True)\n",
62
+ "FIGURE.mkdir(parents=True, exist_ok=True)"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "id": "141a8192-773a-40af-a891-620a6ab81efd",
69
+ "metadata": {
70
+ "tags": []
71
+ },
72
+ "outputs": [],
73
+ "source": [
74
+ "\n",
75
+ "dfs = []\n",
76
+ "for file in sumy_individual_path.glob('*.csv'):\n",
77
+ " df = pd.read_csv(file)\n",
78
+ " method = file.stem.split('-_-')[1]\n",
79
+ " \n",
80
+ " sumy = file.stem.split('-_-')[-1].split('_')\n",
81
+ " if len(sumy) > 1:\n",
82
+ " sentence_count = int(sumy[-1])\n",
83
+ " df['metadata/sentence_count'] = sentence_count\n",
84
+ "\n",
85
+ " # df['Method'] = method\n",
86
+ " dfs.append(df)\n",
87
+ " \n",
88
+ " \n",
89
+ "for file in ours_individual_path.glob('*.csv'):\n",
90
+ " generation_method, dataset, generation_params, date, rsa_param, rsa_ranking_model, method = file.stem.split('-_-')\n",
91
+ " \n",
92
+ " method, n = \"_\".join( method.split('_')[:-1]), method.split('_')[-1]\n",
93
+ " \n",
94
+ " if \"metadata/method\" not in df.columns:\n",
95
+ " df['metadata/method'] = method\n",
96
+ " \n",
97
+ "# reranking_model = rsa_ranking_model[len(\"rsa_reranked-\"):]\n",
98
+ " \n",
99
+ "# df['Ranking Model'] = reranking_model\n",
100
+ "# df['Method'] = method\n",
101
+ "# df['N'] = int(n) if n != \"based\" else 3 \n",
102
+ " df['Generation Method'] = generation_method\n",
103
+ " \n",
104
+ " df = pd.read_csv(file)\n",
105
+ " dfs.append(df)\n",
106
+ " \n",
107
+ "df = pd.concat(dfs)\n",
108
+ "del dfs\n",
109
+ "\n",
110
+ "df = df.drop([c for c in df.columns if \"Unnamed\" in c], axis=1)\n",
111
+ "\n"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "outputs": [],
117
+ "source": [
118
+ "\n",
119
+ "df['metadata/method'] = df['metadata/method'].fillna('N/A')\n",
120
+ "df = df[~(df[\"metadata/method\"].str.contains('lead'))]\n",
121
+ "df = df[~(df[\"metadata/method\"].str.contains('Lead'))]\n",
122
+ "\n"
123
+ ],
124
+ "metadata": {
125
+ "collapsed": false
126
+ },
127
+ "id": "804fb4bacf2686d4",
128
+ "execution_count": null
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "id": "419db7d5-b90b-47a3-9c25-f39eff337849",
134
+ "metadata": {
135
+ "tags": []
136
+ },
137
+ "outputs": [],
138
+ "source": [
139
+ "def fix_generation(x):\n",
140
+ " if x == \"abstractive_sentences\":\n",
141
+ " return \"extractive_sentences\"\n",
142
+ " else:\n",
143
+ " return x\n",
144
+ "\n",
145
+ "\n",
146
+ "df['Generation Method'] = df[\"Generation Method\"].apply(fix_generation)"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "4a157db9-e408-46e8-9499-2751e4cbe7e4",
153
+ "metadata": {
154
+ "tags": []
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "df['N'] = (df['metadata/n_sentences'].fillna(0) + df['metadata/sentence_count'].fillna(0)).apply(int)\n",
159
+ "\n",
160
+ "def fix_methods(x):\n",
161
+ "\n",
162
+ " if \"consensus\" in str(x):\n",
163
+ " return \"Agreement\"\n",
164
+ " elif \"rsa\" in str(x):\n",
165
+ " return \"Speaker+Agreement\"\n",
166
+ " else:\n",
167
+ " return x\n",
168
+ " \n",
169
+ "df['metadata/method'] = df['metadata/method'].apply(fix_methods)\n",
170
+ "\n"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "id": "4b926a30-f4be-4de9-b921-1eac3145e87e",
177
+ "metadata": {
178
+ "tags": []
179
+ },
180
+ "outputs": [],
181
+ "source": [
182
+ "df['metadata/sentence_count'].unique()"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "id": "0451f550-0522-40e9-a64a-551337ae47f6",
189
+ "metadata": {
190
+ "tags": []
191
+ },
192
+ "outputs": [],
193
+ "source": [
194
+ "\n",
195
+ " "
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "id": "7d213f27-9b7f-4893-b2c1-251715401db5",
202
+ "metadata": {
203
+ "tags": []
204
+ },
205
+ "outputs": [],
206
+ "source": [
207
+ "\n",
208
+ "metric= 'SHMetric/Main ideas/proba_1'\n",
209
+ "\n",
210
+ "SHMetric = df.columns[df.columns.str.contains('SHMetric') & df.columns.str.contains('proba_1')].tolist()\n",
211
+ "\n",
212
+ "toplot = df.copy()\n",
213
+ "toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
214
+ "toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
215
+ "\n",
216
+ "\n",
217
+ "toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
218
+ "idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
219
+ "\n",
220
+ "toplot = toplot.loc[idx].reset_index()\n",
221
+ "\n",
222
+ "avg = toplot.groupby([\"metadata/method\"]).agg(['mean', 'std'])\n",
223
+ "avg = avg[SHMetric]\n",
224
+ "\n",
225
+ "display(avg)\n",
226
+ "\n",
227
+ "# rename columns Consiness, Main ideas, Repetition\n",
228
+ "avg.columns = pd.MultiIndex.from_tuples([(f'{c[0].split(\"/\")[1]}', c[1]) for c in avg.columns])\n",
229
+ "\n",
230
+ "def map_ours(x):\n",
231
+ " if \"Agreement\" in x:\n",
232
+ " return \"Ours\"\n",
233
+ " else:\n",
234
+ " return \"Bas.\"\n",
235
+ "\n",
236
+ "\n",
237
+ "avg = avg.groupby([\"metadata/method\"]).mean()\n",
238
+ "\n",
239
+ "avg['Ours'] = avg.index.get_level_values(0).map(map_ours)\n",
240
+ "\n",
241
+ "\n",
242
+ "avg = avg.reset_index().rename(columns={'metadata/method': 'Method'})\n",
243
+ "avg = avg.set_index(['Ours', 'Method'])\n",
244
+ "avg = avg.sort_index()\n",
245
+ "\n",
246
+ "# print avg columns level 0\n",
247
+ "print(avg.columns.get_level_values(0))\n",
248
+ "\n",
249
+ "#Index(['Comprehensible', 'Comprehensible', 'Repetition', 'Repetition',\n",
250
+ " # 'Grammar', 'Grammar', 'Attribution', 'Attribution', 'Main ideas',\n",
251
+ " # 'Main ideas', 'Conciseness', 'Conciseness'],\n",
252
+ " # dtype='object')\n",
253
+ " \n",
254
+ "# rename columns with shorter names\n",
255
+ "avg.columns = pd.MultiIndex.from_tuples([\n",
256
+ " ('Compr.', 'mean'), ('Compr.', 'std'),\n",
257
+ " ('Repet.', 'mean'), ('Repet.', 'std'),\n",
258
+ " ('Gram.', 'mean'), ('Gram.', 'std'),\n",
259
+ " ('Attr.', 'mean'), ('Attr.', 'std'),\n",
260
+ " ('M. i.', 'mean'), ('M. i.', 'std'),\n",
261
+ " ('Conc.', 'mean'), ('Conc.', 'std')\n",
262
+ "])\n",
263
+ "\n",
264
+ "\n",
265
+ "style = avg.style\n",
266
+ "style = style.format(\"{:.2f}\")\n",
267
+ "\n",
268
+ "# make std column smaller and lighter in latex\n",
269
+ "idx = pd.IndexSlice\n",
270
+ "# style = style.set_properties(subset=idx[:, ['std']], **{'font-size': '10pt', 'font-weight': 'lighter'})\n",
271
+ "\n",
272
+ "# bold the best value in each mean column\n",
273
+ "style = style.highlight_max(axis=0, subset=idx[:, idx[:, 'mean']], props=\"bfseries: ;\")\n",
274
+ "\n",
275
+ "# make std columns smaller and add +/- sign\n",
276
+ "style = style.set_properties(**{'color':'[HTML]{A0A1A3}'} ,subset=(idx[:], idx[:, 'std']))\n",
277
+ "style = style.format(\"±{:.2f}\", subset=(idx[:], idx[:, 'std']))\n",
278
+ "\n",
279
+ "# drop level 1 of columns\n",
280
+ "style = style.hide_columns(level=1)\n",
281
+ "\n",
282
+ "# to latex\n",
283
+ "latex = style.to_latex(clines=\"skip-last;data\", hrules=True, multirow_align=\"l\", environment=\"table*\", caption=\"Estimated human judgment using the SEAHORSE metrics for all baselines and our templated summaries compared against each document independently. M. i. stands for Main ideas, Attr. for Attribution, Gram. for Grammar, Compr. for Comprehensible, Conc. for Conciseness, and Repet. for Repetition. The best value in each column is in bold.\")\n",
284
+ "display(style)\n",
285
+ "\n",
286
+ "# add resize box\n",
287
+ "latex = latex.replace(\"\\\\begin{tabular}\", \"\\\\resizebox{\\\\textwidth}{!}{\\\\begin{tabular}\")\n",
288
+ "latex = latex.replace(\"\\\\end{tabular}\", \"\\\\end{tabular}}\")\n",
289
+ "\n",
290
+ "\n",
291
+ "# replace \n",
292
+ "\n",
293
+ "# write to file\n",
294
+ "with open(TABLE_PATH / \"seahorse.tex\", \"w\") as f:\n",
295
+ " f.write(latex)\n",
296
+ "\n",
297
+ "\n",
298
+ "\n",
299
+ "\n",
300
+ "\n",
301
+ "\n",
302
+ "# display(avg)\n",
303
+ "# avg.set_index('Method')\""
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": null,
309
+ "id": "cd4c7707-f1cd-4810-972b-1a69e0ec68ae",
310
+ "metadata": {
311
+ "tags": []
312
+ },
313
+ "outputs": [],
314
+ "source": [
315
+ "metric='SHMetric/Main ideas/proba_1'\n",
316
+ "# white grid\n",
317
+ "sns.set(style=\"whitegrid\")\n",
318
+ "avg = df.groupby([\"metadata/method\", \"id\", \"metadata/reranking_model\", \"Generation Method\"]).mean().reset_index()\n",
319
+ "avg = avg.sort_values(metric)\n",
320
+ "\n",
321
+ "# rename columns with human readable names\n",
322
+ "avg = avg.rename(columns={\n",
323
+ " 'metadata/method': 'Method',\n",
324
+ " 'metadata/reranking_model': 'Reranking Model',\n",
325
+ " 'Generation Method': 'Generation Method',\n",
326
+ " metric: 'Main Ideas'\n",
327
+ "})\n",
328
+ "\n",
329
+ "\n",
330
+ "\n",
331
+ "g = sns.catplot(data=avg, y=\"Main Ideas\", x=\"Method\", hue=\"Reranking Model\", col=\"Generation Method\", kind=\"bar\")\n",
332
+ "\n",
333
+ "\n",
334
+ "# get legend label and handle\n",
335
+ "handles, labels = g._legend_data.values(), g._legend_data.keys()\n",
336
+ "\n",
337
+ "# set legend\n",
338
+ "g._legend.remove()\n",
339
+ "g.fig.legend(handles, labels, loc='upper center', ncol=2, fontsize=25, title_fontsize=25, title=\"Reranking Model\", bbox_to_anchor=(0.4, -0.3))\n",
340
+ "\n",
341
+ "\n",
342
+ "# set title template \n",
343
+ "g.set_titles(\"{col_name}\")\n",
344
+ "\n",
345
+ "# add hline at 0.215 for the baseline, on each axis\n",
346
+ "for ax in g.axes.flat:\n",
347
+ " ax.axhline(0.215, ls='--', color='black', linewidth=5)\n",
348
+ " ax.set_xticklabels(ax.get_xticklabels(), rotation=30)\n",
349
+ " \n",
350
+ "# make label bigger\n",
351
+ "for ax in g.axes.flat:\n",
352
+ " ax.set_xlabel(\"\")\n",
353
+ " ax.set_ylabel(ax.get_ylabel(), fontsize=25, fontweight='bold')\n",
354
+ " ax.set_xticklabels(ax.get_xticklabels(), fontsize=25, fontweight='bold')\n",
355
+ " \n",
356
+ "# make title bigger\n",
357
+ "for ax in g.axes.flat:\n",
358
+ " ax.set_title(ax.get_title(), fontsize=25, fontweight='bold')\n",
359
+ " \n",
360
+ "# add annotation for the hline on the first axis\n",
361
+ "\n",
362
+ "\n",
363
+ "\n",
364
+ "\n",
365
+ "\n",
366
+ "plt.xticks(rotation=30)\n",
367
+ "\n",
368
+ "# save figure\n",
369
+ "g.savefig(FIGURE / \"seahorse_main_ideas.pdf\")\n",
370
+ "\n",
371
+ "\n",
372
+ "\n"
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "code",
377
+ "execution_count": null,
378
+ "id": "f84773d4-f2b7-4db6-851f-4683d545345b",
379
+ "metadata": {
380
+ "tags": []
381
+ },
382
+ "outputs": [],
383
+ "source": [
384
+ "metric='SHMetric/Main ideas/proba_1'\n",
385
+ "\n",
386
+ "toplot = df.copy()\n",
387
+ "toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
388
+ "toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
389
+ "\n",
390
+ "toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
391
+ "idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
392
+ "toplot = toplot.loc[idx].reset_index()\n",
393
+ "toplot = toplot[~toplot['metadata/method'].str.contains('Lead')]\n",
394
+ "\n",
395
+ "toplot = toplot.sort_values(metric, ascending=True)\n",
396
+ "order = toplot.groupby(\"metadata/method\").mean().sort_values(metric)\n",
397
+ "\n",
398
+ "\n",
399
+ "display(toplot.groupby(\"metadata/method\").mean().sort_values(metric)[metric])\n",
400
+ "\n",
401
+ "sns.barplot(data=toplot, y=metric, x=\"metadata/method\", order=order.index)\n",
402
+ "\n",
403
+ "plt.xticks(rotation=45)"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "id": "47f15195-e60d-45f8-aac0-dc88eeea9577",
410
+ "metadata": {
411
+ "tags": []
412
+ },
413
+ "outputs": [],
414
+ "source": [
415
+ "metric='SHMetric/Conciseness/proba_1'\n",
416
+ "\n",
417
+ "toplot = df.copy()\n",
418
+ "toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
419
+ "toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
420
+ "\n",
421
+ "toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
422
+ "idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
423
+ "toplot = toplot.loc[idx].reset_index()\n",
424
+ "toplot = toplot[~toplot['metadata/method'].str.contains('Lead')]\n",
425
+ "\n",
426
+ "toplot = toplot.sort_values(metric, ascending=True)\n",
427
+ "order = toplot.groupby(\"metadata/method\").mean().sort_values(metric)\n",
428
+ "\n",
429
+ "\n",
430
+ "\n",
431
+ "sns.barplot(data=toplot, y=metric, x=\"metadata/method\", order=order.index)\n",
432
+ "\n",
433
+ "plt.xticks(rotation=45)"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": null,
439
+ "id": "96b5d47b-f9d2-45bf-9284-844dedb24ce9",
440
+ "metadata": {
441
+ "tags": []
442
+ },
443
+ "outputs": [],
444
+ "source": [
445
+ "metric='SHMetric/Repetition/proba_1'\n",
446
+ "\n",
447
+ "toplot = df.copy()\n",
448
+ "toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
449
+ "toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
450
+ "\n",
451
+ "toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
452
+ "idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
453
+ "toplot = toplot.loc[idx].reset_index()\n",
454
+ "toplot = toplot[~toplot['metadata/method'].str.contains('Lead')]\n",
455
+ "\n",
456
+ "toplot = toplot.sort_values(metric, ascending=True)\n",
457
+ "order = toplot.groupby(\"metadata/method\").mean().sort_values(metric)\n",
458
+ "\n",
459
+ "\n",
460
+ "\n",
461
+ "sns.barplot(data=toplot, y=metric, x=\"metadata/method\", order=order.index)\n",
462
+ "\n",
463
+ "plt.xticks(rotation=45)"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "code",
468
+ "execution_count": null,
469
+ "id": "be9a5198-40e9-4b60-8ed4-6b2ea3a9683e",
470
+ "metadata": {},
471
+ "outputs": [],
472
+ "source": [
473
+ "metric='SHMetric/Repetition/proba_1'\n",
474
+ "\n",
475
+ "toplot = df.copy()\n",
476
+ "toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
477
+ "toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
478
+ "\n",
479
+ "toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
480
+ "idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
481
+ "toplot = toplot.loc[idx].reset_index()\n",
482
+ "toplot = toplot[~toplot['metadata/method'].str.contains('Lead')]\n",
483
+ "\n",
484
+ "toplot = toplot.sort_values(metric, ascending=True)\n",
485
+ "order = toplot.groupby(\"metadata/method\").mean().sort_values(metric)\n",
486
+ "\n",
487
+ "\n",
488
+ "\n",
489
+ "sns.barplot(data=toplot, y=metric, x=\"metadata/method\", order=order.index)\n",
490
+ "\n",
491
+ "plt.xticks(rotation=45)"
492
+ ]
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "execution_count": null,
497
+ "id": "6e52ec5a-3f68-4fdf-9828-205cd9b55e42",
498
+ "metadata": {
499
+ "tags": []
500
+ },
501
+ "outputs": [],
502
+ "source": [
503
+ "metric='SHMetric/Main ideas/proba_1'\n",
504
+ "\n",
505
+ "avg = df.groupby([\"metadata/method\", \"id\", \"N\"]).mean().reset_index()\n",
506
+ "avg = avg.sort_values(metric)\n",
507
+ "sns.barplot(data=avg[~avg['metadata/method'].str.contains('Lead')], y=metric, x=\"metadata/method\", hue='N')\n",
508
+ "plt.xticks(rotation=45)"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": null,
514
+ "id": "a675a584-c80f-49ca-94e3-55d868c8b594",
515
+ "metadata": {
516
+ "tags": []
517
+ },
518
+ "outputs": [],
519
+ "source": [
520
+ "metric='SHMetric/Main ideas/proba_1'\n",
521
+ "\n",
522
+ "avg = df.groupby([\"metadata/method\", \"id\"]).mean().reset_index()\n",
523
+ "avg = avg[~avg['metadata/method'].str.contains('Lead')].sort_values(metric, )\n",
524
+ "sns.barplot(data=avg, y=metric, x=\"metadata/method\")\n",
525
+ "plt.xticks(rotation=45)"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": null,
531
+ "id": "79f7e781-0e34-458d-826d-67c08109cef7",
532
+ "metadata": {
533
+ "tags": []
534
+ },
535
+ "outputs": [],
536
+ "source": [
537
+ "metric='rougeL'\n",
538
+ "\n",
539
+ "avg = df.groupby([\"metadata/method\", \"id\"]).mean().reset_index()\n",
540
+ "avg = avg[~avg['metadata/method'].str.contains('Lead')].sort_values(metric, )\n",
541
+ "sns.barplot(data=avg, y=metric, x=\"metadata/method\")\n",
542
+ "plt.xticks(rotation=45)"
543
+ ]
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": null,
548
+ "id": "8b65306c-9f5e-4d71-910a-48de9a195534",
549
+ "metadata": {
550
+ "tags": []
551
+ },
552
+ "outputs": [],
553
+ "source": [
554
+ "df.columns"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "code",
559
+ "execution_count": null,
560
+ "id": "cf6abd56-e9d1-4cea-bd01-d03046b56f3d",
561
+ "metadata": {},
562
+ "outputs": [],
563
+ "source": []
564
+ }
565
+ ],
566
+ "metadata": {
567
+ "kernelspec": {
568
+ "name": "python3",
569
+ "language": "python",
570
+ "display_name": "Python 3 (ipykernel)"
571
+ },
572
+ "language_info": {
573
+ "codemirror_mode": {
574
+ "name": "ipython",
575
+ "version": 3
576
+ },
577
+ "file_extension": ".py",
578
+ "mimetype": "text/x-python",
579
+ "name": "python",
580
+ "nbconvert_exporter": "python",
581
+ "pygments_lexer": "ipython3",
582
+ "version": "3.10.9"
583
+ }
584
+ },
585
+ "nbformat": 4,
586
+ "nbformat_minor": 5
587
+ }
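Note on the selection pattern used throughout the plotting cells above: scores are first averaged per (method, id, reranking model), the best reranking model per (method, id) is then kept via `idxmax`, and methods are ordered by their mean score before plotting. A minimal self-contained sketch of that pattern on made-up data (the column names mirror the notebook; the values are illustrative only):

```python
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

metric = "SHMetric/Main ideas/proba_1"

# Toy scores: two methods, two documents, two reranking models (values are made up).
df = pd.DataFrame({
    "metadata/method": ["RSA", "RSA", "Base", "Base"] * 2,
    "id": [1, 2, 1, 2] * 2,
    "metadata/reranking_model": ["A"] * 4 + ["B"] * 4,
    metric: [0.8, 0.7, 0.5, 0.6, 0.9, 0.65, 0.55, 0.4],
})

# Average per (method, id, reranking model) ...
grouped = df.groupby(["metadata/method", "id", "metadata/reranking_model"]).mean()
# ... keep the best reranking model per (method, id) ...
idx = grouped.groupby(["metadata/method", "id"])[metric].idxmax()
best = grouped.loc[idx].reset_index()

# ... and order the bars by the per-method mean, as the notebook does.
order = best.groupby("metadata/method")[metric].mean().sort_values()
sns.barplot(data=best, y=metric, x="metadata/method", order=order.index)
plt.xticks(rotation=45)
plt.show()
```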
glimpse-ui/glimpse/mds/Template summaries.ipynb ADDED
@@ -0,0 +1,531 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "12404068-3244-43d6-8556-41e11489bb48",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import pickle as pk\n",
13
+ "import pandas as pd\n",
14
+ "from pathlib import Path\n",
15
+ "import numpy as np\n",
16
+ "import seaborn as sns\n",
17
+ "\n",
18
+ "from rouge_score import rouge_scorer\n",
19
+ "\n",
20
+ "\n",
21
+ "from lexrank import LexRank\n",
22
+ "from lexrank.mappings.stopwords import STOPWORDS\n",
23
+ "import nltk \n",
24
+ "\n"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "id": "044c82d6-23c8-4c3b-a4b5-12acbbc1cc1a",
31
+ "metadata": {
32
+ "tags": []
33
+ },
34
+ "outputs": [],
35
+ "source": []
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "id": "1f5cdf3f-5485-4b9f-a6fc-b2b8bd8aca7f",
41
+ "metadata": {
42
+ "tags": []
43
+ },
44
+ "outputs": [],
45
+ "source": [
46
+ "\n",
47
+ "\n",
48
+ "\n",
49
+ "path = Path(\"output/summaries/rsa_reranking/reviews_rsa_matrices/\")\n",
50
+ "output_path = Path(\"output/summaries/methods_reviews/\")\n",
51
+ "\n"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "id": "ec5e8ff3-6bef-42bc-8430-df93c1a4e79a",
58
+ "metadata": {
59
+ "tags": []
60
+ },
61
+ "outputs": [],
62
+ "source": []
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "id": "d45cd444-0a81-4670-bbae-213e322ea281",
68
+ "metadata": {
69
+ "tags": []
70
+ },
71
+ "outputs": [],
72
+ "source": []
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "id": "9e6c2863-9e5a-4e5a-bdb6-02e03c5f6105",
78
+ "metadata": {
79
+ "tags": []
80
+ },
81
+ "outputs": [],
82
+ "source": []
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "id": "2ae23148-38ae-4385-99fb-db20da54334d",
88
+ "metadata": {
89
+ "tags": []
90
+ },
91
+ "outputs": [],
92
+ "source": []
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "id": "d15ce19b-a3c7-4554-878f-41acd3204878",
98
+ "metadata": {
99
+ "tags": []
100
+ },
101
+ "outputs": [],
102
+ "source": []
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "id": "1652ff75-b6c2-483a-a9af-ff3ca8616756",
108
+ "metadata": {
109
+ "tags": []
110
+ },
111
+ "outputs": [],
112
+ "source": []
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "id": "983cd24c-b996-4224-8f87-ea79842c41a0",
118
+ "metadata": {
119
+ "tags": []
120
+ },
121
+ "outputs": [],
122
+ "source": []
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "id": "0db0e8f5-4b4a-4d55-8596-fe095aa4135f",
127
+ "metadata": {},
128
+ "source": [
129
+ "# Consensus score based summaries:"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "id": "560c7f8d-6b8e-4b5b-ba36-0dedc509791f",
136
+ "metadata": {
137
+ "tags": []
138
+ },
139
+ "outputs": [],
140
+ "source": []
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "id": "662a5f1b-1e4d-458e-a437-b1d9c8db4552",
146
+ "metadata": {
147
+ "tags": []
148
+ },
149
+ "outputs": [],
150
+ "source": [
151
+ "def consensus_scores_based_summaries(sample, n_consensus=3, n_dissensus=3):\n",
152
+ " consensus_samples = sample['consensuality_scores'].sort_values(ascending=True).head(n_consensus).index.tolist()\n",
153
+ " disensus_samples = sample['consensuality_scores'].sort_values(ascending=False).head(n_dissensus).index.tolist()\n",
154
+ " \n",
155
+ " consensus = \".\".join(consensus_samples)\n",
156
+ " disensus = \".\".join(disensus_samples)\n",
157
+ " \n",
158
+ " return consensus + \"\\n\\n\" + disensus\n",
159
+ " \n",
160
+ " \n",
161
+ "def rsa_scores_based_summaries(sample, n_consensus=3, n_rsa_speaker=3):\n",
162
+ " consensus_samples = sample['consensuality_scores'].sort_values(ascending=True).head(n_consensus).index.tolist()\n",
163
+ " rsa = sample['best_rsa'].tolist()[:n_rsa_speaker]\n",
164
+ " \n",
165
+ " consensus = \".\".join(consensus_samples)\n",
166
+ " rsa = \".\".join(rsa)\n",
167
+ " \n",
168
+ " return consensus + \"\\n\\n\" + rsa\n",
169
+ "\n",
170
+ "scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)\n",
171
+ "\n",
172
+ "def lead(sample, N=10):\n",
173
+ " texts = sample['speaker_df'].index.tolist()\n",
174
+ " \n",
175
+ " summary = \"\\n\".join([\".\".join(t.split('.')[:N]) for t in texts])\n",
176
+ " \n",
177
+ " return summary\n",
178
+ "\n",
179
+ " \n",
180
+ " \n",
181
+ "\n",
182
+ "scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)\n",
183
+ "\n",
184
+ "\n",
185
+ "def construct_templated_summaries(data, fn, dataset=None): \n",
186
+ " records = []\n",
187
+ " for sample in data['results']:\n",
188
+ " summary = fn(sample)\n",
189
+ " text = \"\\n\\n\".join(sample['speaker_df'].index.tolist())\n",
190
+ " record = {'id' : sample['id'], 'summary': summary, 'metadata/reranking_model' : data['metadata/reranking_model'], 'metadata/rsa_iterations' : data['metadata/reranking_model'], \"text\": text}\n",
191
+ " if dataset is not None:\n",
192
+ " record['gold'] = dataset.loc[sample[\"id\"]]['gold'].tolist()[0]\n",
193
+ " if record['gold'] is not None:\n",
194
+ " rouges = scorer.score(summary, record['gold'])\n",
195
+ " record |= {r : v.fmeasure for r, v in rouges.items()}\n",
196
+ " \n",
197
+ " \n",
198
+ " \n",
199
+ " records.append(record)\n",
200
+ " \n",
201
+ " return pd.DataFrame.from_records(records)\n",
202
+ " \n",
203
+ "\n",
204
+ " \n",
205
+ " \n"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": null,
211
+ "id": "39c45180-a354-429c-8cde-3a7c78013cc6",
212
+ "metadata": {
213
+ "tags": []
214
+ },
215
+ "outputs": [],
216
+ "source": [
217
+ "def prepare_dataset(dataset_name, dataset_path=\"data/processed/\"):\n",
218
+ " dataset_path = Path(dataset_path)\n",
219
+ " if dataset_name == \"amazon\":\n",
220
+ " dataset = pd.read_csv(dataset_path / \"amazon_test.csv\")\n",
221
+ " elif dataset_name == \"space\":\n",
222
+ " dataset = pd.read_csv(dataset_path / \"space.csv\")\n",
223
+ " elif dataset_name == \"yelp\":\n",
224
+ " dataset = pd.read_csv(dataset_path / \"yelp_test.csv\")\n",
225
+ " elif dataset_name == \"reviews\":\n",
226
+ " dataset = pd.read_csv(dataset_path / \"test_metareviews.csv\")\n",
227
+ " else:\n",
228
+ " raise ValueError(f\"Unknown dataset {dataset_name}\")\n",
229
+ "\n",
230
+ "\n",
231
+ " return dataset\n"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "id": "addeda2b-71fc-4c9a-8e91-12cf70e52b1e",
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "# df = prepare_dataset('reviews')\n",
242
+ "\n",
243
+ "# for n, group in df.groupby('id'):\n",
244
+ "# for idx, row in group.iterrows():\n",
245
+ "# print(row['text'].replace('-----', \"\\n\"))\n",
246
+ "# print(\"===========\")\n",
247
+ "# break\n",
248
+ "rsa_scores_based_summaries"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "id": "7fe212dd-63d2-44b1-b06e-7792b9d504ac",
255
+ "metadata": {
256
+ "tags": []
257
+ },
258
+ "outputs": [],
259
+ "source": [
260
+ "\n",
261
+ "\n",
262
+ "for n in [1, 2, 3, 4, 5, 6]:\n",
263
+ " for file in path.glob(\"*.pk\"):\n",
264
+ " print(file)\n",
265
+ " with file.open('rb') as fd:\n",
266
+ " data = pk.load(fd)\n",
267
+ "\n",
268
+ " Path(output_path).mkdir(parents=True, exist_ok=True)\n",
269
+ " model_name, dataset_name, decoding_config, date = str(file.stem).split('-_-')[:4]\n",
270
+ "\n",
271
+ " dataset = prepare_dataset(dataset_name, dataset_path=\"data/processed/\")\n",
272
+ " dataset = dataset.set_index('id')\n",
273
+ " \n",
274
+ " fn = lambda sample: consensus_scores_based_summaries(sample, n_consensus=n, n_dissensus=n)\n",
275
+ "\n",
276
+ " df = construct_templated_summaries(data, fn, dataset=dataset)\n",
277
+ " \n",
278
+ " df['metadata/method'] = \"Agreement\"\n",
279
+ " df['metadata/n_sentences'] = 2*n\n",
280
+ " df['metadata/n_consensus'] = n\n",
281
+ " df['metadata/n_dissensus'] = n\n",
282
+ "\n",
283
+ " name = file.stem + \"-_-\" + f\"consensus_score_based_{n}.csv\"\n",
284
+ "\n",
285
+ " if (output_path / name).exists():\n",
286
+ " df_old = pd.read_csv(output_path / name)\n",
287
+ "\n",
288
+ " for col in df.columns:\n",
289
+ " if col not in df_old.columns:\n",
290
+ " df_old[col] = float(\"nan\")\n",
291
+ "\n",
292
+ " # add entry to the dataframe\n",
293
+ " for col in df.columns:\n",
294
+ " df_old[col] = df[col]\n",
295
+ "\n",
296
+ " df = df_old\n",
297
+ "\n",
298
+ " df.to_csv(output_path / name)\n",
299
+ " \n",
300
+ " \n",
301
+ " \n",
302
+ " "
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "id": "ab45111b-9c9f-44ee-8cc1-613bfa32a007",
309
+ "metadata": {
310
+ "tags": []
311
+ },
312
+ "outputs": [],
313
+ "source": [
314
+ "\n",
315
+ "for n in [1, 2, 3, 4, 5, 6]:\n",
316
+ " for file in path.glob(\"*.pk\"):\n",
317
+ " with file.open('rb') as fd:\n",
318
+ " data = pk.load(fd)\n",
319
+ "\n",
320
+ " Path(output_path).mkdir(parents=True, exist_ok=True)\n",
321
+ " model_name, dataset_name, decoding_config, date = str(file.stem).split('-_-')[:4]\n",
322
+ "\n",
323
+ " dataset = prepare_dataset(dataset_name, dataset_path=\"data/processed/\")\n",
324
+ " dataset = dataset.set_index('id')\n",
325
+ "\n",
326
+ " fn = lambda sample: rsa_scores_based_summaries(sample, n_consensus=n, n_rsa_speaker=n)\n",
327
+ " df = construct_templated_summaries(data, fn, dataset=dataset)\n",
328
+ "\n",
329
+ " df['metadata/method'] = \"Speaker+Agreement\"\n",
330
+ " df['metadata/n_sentences'] = 2*n\n",
331
+ " df['metadata/n_consensus'] = n\n",
332
+ " df['metadata/n_dissensus'] = n\n",
333
+ "\n",
334
+ " name = file.stem + \"-_-\" + f\"rsa_score_based_{n}.csv\"\n",
335
+ "\n",
336
+ " if (output_path / name).exists():\n",
337
+ " df_old = pd.read_csv(output_path / name)\n",
338
+ "\n",
339
+ " for col in df.columns:\n",
340
+ " if col not in df_old.columns:\n",
341
+ " df_old[col] = float(\"nan\")\n",
342
+ "\n",
343
+ " # add entry to the dataframe\n",
344
+ " for col in df.columns:\n",
345
+ " df_old[col] = df[col]\n",
346
+ "\n",
347
+ " df = df_old\n",
348
+ "\n",
349
+ " df.to_csv(output_path / name)"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "id": "8b57c318-5fc8-49fc-8746-128e1112e46a",
356
+ "metadata": {
357
+ "tags": []
358
+ },
359
+ "outputs": [],
360
+ "source": [
361
+ "\n",
362
+ "for n in [1, 2, 3, 4, 5, 6, 7, 8]:\n",
363
+ " for file in path.glob(\"*.pk\"):\n",
364
+ " with file.open('rb') as fd:\n",
365
+ " data = pk.load(fd)\n",
366
+ "\n",
367
+ " Path(output_path).mkdir(parents=True, exist_ok=True)\n",
368
+ " model_name, dataset_name, decoding_config, date = str(file.stem).split('-_-')[:4]\n",
369
+ "\n",
370
+ " dataset = prepare_dataset(dataset_name, dataset_path=\"data/processed/\")\n",
371
+ " dataset = dataset.set_index('id')\n",
372
+ "\n",
373
+ " fn = lambda sample: lead(sample, N=2*n)\n",
374
+ "\n",
375
+ "\n",
376
+ " df = construct_templated_summaries(data, fn, dataset=dataset)\n",
377
+ "\n",
378
+ " df['metadata/method'] = \"Lead\"\n",
379
+ " df['metadata/n_sentences'] = 2*n\n",
380
+ "\n",
381
+ " name = file.stem + \"-_-\" + f\"lead_{2*n}.csv\"\n",
382
+ "\n",
383
+ " if (output_path / name).exists():\n",
384
+ " df_old = pd.read_csv(output_path / name)\n",
385
+ "\n",
386
+ " for col in df.columns:\n",
387
+ " if col not in df_old.columns:\n",
388
+ " df_old[col] = float(\"nan\")\n",
389
+ "\n",
390
+ " # add entry to the dataframe\n",
391
+ " for col in df.columns:\n",
392
+ " df_old[col] = df[col]\n",
393
+ "\n",
394
+ " df = df_old\n",
395
+ "\n",
396
+ " df.to_csv(output_path / name)"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": null,
402
+ "id": "af574823-667d-4722-90bc-2bb095ad3a01",
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": []
406
+ },
407
+ {
408
+ "cell_type": "code",
409
+ "execution_count": null,
410
+ "id": "c32d1d88-c6f5-4ada-aec6-219d90cade16",
411
+ "metadata": {
412
+ "tags": []
413
+ },
414
+ "outputs": [],
415
+ "source": [
416
+ "import seaborn as sns\n",
417
+ "import matplotlib.pyplot as plt"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": null,
423
+ "id": "868ac37e-6187-46f5-935c-111ca532b1b0",
424
+ "metadata": {
425
+ "tags": []
426
+ },
427
+ "outputs": [],
428
+ "source": [
429
+ "output_path = Path(\"output/summaries/methods_reviews/\")"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "id": "7263e9d0-6a43-4698-bf6f-82fff0839316",
436
+ "metadata": {
437
+ "tags": []
438
+ },
439
+ "outputs": [],
440
+ "source": [
441
+ "import subprocess\n",
442
+ "\n",
443
+ "\n",
444
+ "for file in output_path.glob(\"*.csv\"):\n",
445
+ " print(file)\n",
446
+ " cmd = [\"python\", \"mds/evaluate_bartbert_metrics.py\", \"--summaries\", file]\n",
447
+ " subprocess.run(cmd)"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "id": "18a63537-c44e-421c-beff-50c6518115bf",
454
+ "metadata": {
455
+ "tags": []
456
+ },
457
+ "outputs": [],
458
+ "source": [
459
+ "dfs = []\n",
460
+ "for file in output_path.glob(\"*.csv\"):\n",
461
+ " model_name, dataset_name, decoding_config, date = str(file.stem).split('-_-')[:4]\n",
462
+ " method = str(file.stem).split('-_-')[-1]\n",
463
+ " \n",
464
+ " df = pd.read_csv(file)\n",
465
+ " df['metadata/Model'] = model_name\n",
466
+ " df['metadata/Dataset'] = dataset_name\n",
467
+ " df['metadata/method'] = method\n",
468
+ " \n",
469
+ " df[\"Method\"] = f\"{model_name}/{method}\"\n",
470
+ " \n",
471
+ " dfs.append(df)\n",
472
+ " \n",
473
+ "df = pd.concat(dfs)\n",
474
+ " \n",
475
+ " \n",
476
+ "df"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": null,
482
+ "id": "eef1ec67-62f9-458f-9380-debe40bac46a",
483
+ "metadata": {
484
+ "tags": []
485
+ },
486
+ "outputs": [],
487
+ "source": [
488
+ "sns.catplot(data=df, hue='Method', y='rougeL', x='metadata/Dataset', kind='bar')"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "id": "aed3e026-fd89-416f-a333-c841eaf566e4",
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": [
498
+ "sns.catplot(data=df, hue='metadata/method', y='rouge1', x='metadata/reranking_model', kind='bar', row=\"metadata/model\")"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "code",
503
+ "execution_count": null,
504
+ "id": "2e433925-6be6-4322-af07-45b4c07ff5ff",
505
+ "metadata": {},
506
+ "outputs": [],
507
+ "source": []
508
+ }
509
+ ],
510
+ "metadata": {
511
+ "kernelspec": {
512
+ "display_name": "pytorch-gpu-2.0.0_py3.10.9",
513
+ "language": "python",
514
+ "name": "module-conda-env-pytorch-gpu-2.0.0_py3.10.9"
515
+ },
516
+ "language_info": {
517
+ "codemirror_mode": {
518
+ "name": "ipython",
519
+ "version": 3
520
+ },
521
+ "file_extension": ".py",
522
+ "mimetype": "text/x-python",
523
+ "name": "python",
524
+ "nbconvert_exporter": "python",
525
+ "pygments_lexer": "ipython3",
526
+ "version": "3.10.9"
527
+ }
528
+ },
529
+ "nbformat": 4,
530
+ "nbformat_minor": 5
531
+ }
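For reference, the "Agreement" template above reduces to sorting the per-candidate consensuality scores and joining the extremes. A minimal sketch with a hand-made score series (the real series comes from the RSA pickle files; low scores mean consensual):

```python
import pandas as pd

# Made-up consensuality scores: low = consensual across reviews, high = divisive.
consensuality_scores = pd.Series({
    "The paper is well written": 0.1,
    "The method is novel": 0.3,
    "The experiments are insufficient": 1.2,
    "The theory is flawed": 1.5,
})

n = 1
consensus = ".".join(consensuality_scores.sort_values(ascending=True).head(n).index)
dissensus = ".".join(consensuality_scores.sort_values(ascending=False).head(n).index)

# "Agreement" template: consensual sentences first, divisive ones after a blank line.
print(consensus + "\n\n" + dissensus)
```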
glimpse-ui/glimpse/mds/discriminative_classification.py ADDED
@@ -0,0 +1,113 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ from sentence_transformers import SentenceTransformer
10
+
11
+ def xlogx(x):
12
+ if x == 0:
13
+ return 0
14
+ else:
15
+ return x * torch.log(x)
16
+
17
+ def parse_summaries(path : Path):
18
+
19
+ # Load the data
20
+ df = pd.read_csv(path)
21
+
22
+ if 'id' not in df.columns:
23
+ raise ValueError('id column not found in the summaries file')
24
+ if 'text' not in df.columns:
25
+ raise ValueError('text column not found in the summaries file')
26
+ if 'summary' not in df.columns:
27
+ raise ValueError('summary column not found in the summaries file')
28
+
29
+ return df
30
+
31
+
32
+ def embed_text_and_summaries(df : pd.DataFrame, model : SentenceTransformer) -> Tuple[torch.Tensor, torch.Tensor]:
33
+
34
+ text_embeddings = model.encode(df.text.tolist(), convert_to_tensor=True)
35
+ summary_embeddings = model.encode(df.summary.tolist(), convert_to_tensor=True)
36
+
37
+ return text_embeddings, summary_embeddings
38
+
39
+
40
+ def compute_dot_products(df : pd.DataFrame, text_embeddings : torch.Tensor, summary_embeddings : torch.Tensor):
41
+
42
+ df = df.reset_index()
43
+ df['index'] = df.index
44
+
45
+ # group by id
46
+ grouped = df.groupby('id')
47
+
48
+ # for each id gather the id of the text and the summary
49
+ ids_per_sample = grouped['index'].apply(list).tolist()
50
+
51
+ # compute the dot product between the text and the summary
52
+
53
+ metrics = {'proba_of_success' : []}
54
+ for text_ids in ids_per_sample:
55
+ # shape (num_text, embedding_dim)
56
+ text_embedding = text_embeddings[text_ids]
57
+ summary_embedding = summary_embeddings[text_ids]
58
+
59
+ # shape (num_text, num_text=num_summary)
60
+ dot_product = torch.matmul(text_embedding, summary_embedding.T)
61
+
62
+ # apply log softmax
63
+ log_softmax = torch.nn.functional.log_softmax(dot_product, dim=0)
64
+
65
+ # num_text
66
+ log_proba_of_success = torch.diag(log_softmax).squeeze()
67
+ entropy = torch.xlogy(log_proba_of_success, log_proba_of_success).sum(0).squeeze()
68
+
69
+ metrics['proba_of_success'].extend(log_proba_of_success.tolist())
70
+ # metrics['entropy'].extend(entropy.tolist())
71
+
72
+ df['proba_of_success'] = metrics['proba_of_success']
73
+ # df['entropy'] = metrics['entropy']
74
+
75
+ return df
76
+
77
+ def parse_args():
78
+ parser = argparse.ArgumentParser()
79
+ parser.add_argument('--summaries', type=Path, required=True)
80
+ parser.add_argument('--model', type=str, default='paraphrase-MiniLM-L6-v2')
81
+ parser.add_argument('--output', type=Path, required=True)
82
+ parser.add_argument('--device', type=str, default='cuda')
83
+
84
+ args = parser.parse_args()
85
+ return args
86
+
87
+ def main():
88
+
89
+ args = parse_args()
90
+
91
+ # load the model
92
+ model = SentenceTransformer(args.model, device=args.device)
93
+
94
+ # load the summaries
95
+ df = parse_summaries(args.summaries)
96
+
97
+ # embed the text and the summaries
98
+ text_embeddings, summary_embeddings = embed_text_and_summaries(df, model)
99
+
100
+ # compute the dot product between the text and the summary
101
+ df = compute_dot_products(df, text_embeddings, summary_embeddings)
102
+
103
+ # create the output directory
104
+ args.output.mkdir(parents=True, exist_ok=True)
105
+
106
+ path = args.output / f"{args.summaries.stem}.csv"
107
+
108
+ # save the results
109
+ df.to_csv(path, index=False)
110
+
111
+
112
+ if __name__ == '__main__':
113
+ main()
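The quantity computed in `compute_dot_products` above is essentially a retrieval score: within each group, the text/summary dot products are log-softmaxed over the texts, and the diagonal gives the log-probability that each summary is traced back to its own source text. A minimal sketch of that computation on random embeddings (shapes only; no SentenceTransformer needed, values are made up):

```python
import torch

torch.manual_seed(0)

# Toy embeddings for three (text, summary) pairs from one group.
text_emb = torch.randn(3, 8)
summary_emb = torch.randn(3, 8)

# dot[i, j] = similarity between text i and summary j.
dot = text_emb @ summary_emb.T

# Normalise over texts (dim=0): a listener guessing which text a summary came from.
log_softmax = torch.log_softmax(dot, dim=0)

# Diagonal entries: log-probability that summary i is matched back to its own text i.
log_proba_of_success = torch.diag(log_softmax)
print(log_proba_of_success.exp())
```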
glimpse-ui/glimpse/pyproject.toml ADDED
@@ -0,0 +1,21 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "rsasumm"
7
+ version = "0.0.1"
8
+ authors = [
9
+ ]
10
+ description = ""
11
+ readme = "Readme.md"
12
+ requires-python = ">=3.10"
13
+ classifiers = [
14
+ "Programming Language :: Python :: 3",
15
+ "License :: OSI Approved :: MIT License",
16
+ "Operating System :: OS Independent",
17
+ ]
18
+
19
+ [project.urls]
20
+ "Homepage" = ""
21
+ "Bug Tracker" = ""
glimpse-ui/glimpse/requirements ADDED
@@ -0,0 +1,10 @@
1
+ transformers
2
+ numpy==1.25.2
3
+ seaborn
4
+ matplotlib
5
+ gradio
6
+ pandas
7
+ datasets
8
+ nltk
9
+ SentencePiece
10
+ spacy
glimpse-ui/glimpse/rsasumm/__init__.py ADDED
File without changes
glimpse-ui/glimpse/rsasumm/beam_search.py ADDED
@@ -0,0 +1,430 @@
1
+ from typing import Tuple, Optional
2
+
3
+ import torch
4
+ from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper
5
+
6
+
7
+ def compute_rsa_probas(
8
+ logits: torch.Tensor, prior: torch.Tensor, rationality: float = 1.0
9
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
10
+ """
11
+ :param logits: (world_size, num_beam, vocab_size)
12
+ :param prior: (world_size, num_beam) for each beam the prior over the objects
13
+ :param rationality: rationality parameter; the higher, the more rational, i.e., the more the speaker tries to adapt
14
+ to the listener
15
+ :return: S1, L1: (world_size, num_beam, vocab_size).
16
+ S1[o, b, w] is the (log)probability of the word w given the object o and the current partial summary for the beam b
17
+ L1[o, b, w] is the (log)probability of the object o given the word w and the current partial summary for the beam b
18
+ """
19
+
20
+ prod = logits + prior[..., None]
21
+
22
+ L0 = torch.nan_to_num(torch.log_softmax(prod, dim=0), nan=-float("inf"))
23
+
24
+ prod_s = logits + L0 * rationality
25
+
26
+ S1 = torch.log_softmax(prod_s, dim=-1)
27
+ S1 = torch.nan_to_num(S1, nan=-float("inf"))
28
+
29
+ prod_l = logits + L0
30
+ L1 = torch.log_softmax(prod_l, dim=0)
31
+ L1 = torch.nan_to_num(L1, nan=-float("inf"))
32
+
33
+ return S1, L1
34
+
35
+
36
+ def sample_from_probs(
37
+ logits: torch.Tensor, num_beams: int, do_sample: bool, K: int = 10
38
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
39
+ """
40
+
41
+ :param logits: (num_beams, vocab_size) log proba for the next token only for the wanted object
42
+ :param num_beams: number of beams to sample. (Can be different from the shape of logits since some beams might have
43
+ finished earlier)
44
+ :param do_sample: sample or use argmax
45
+ :param K: number of samples to draw per beam to create the new population
46
+ :return: idx_beam, idx_token, tokens_scores, the indices of the sampled tokens and their scores
47
+ """
48
+
49
+ vocab_size = logits.shape[-1]
50
+ if do_sample:
51
+ # sample from the probability distribution
52
+ logits = logits.view(num_beams * logits.shape[-1])
53
+ probs = torch.softmax(logits, dim=-1)
54
+ samples = torch.multinomial(probs, num_samples=K * num_beams)
55
+
56
+ # get the indices of the sampled tokens
57
+ idx_beam, idx_token = samples // vocab_size, samples % vocab_size
58
+
59
+ logits = logits.view(num_beams * vocab_size)
60
+
61
+ tokens_scores = logits.gather(dim=-1, index=samples).squeeze(-1)
62
+
63
+ return idx_beam, idx_token, tokens_scores
64
+
65
+ else:
66
+ # get the indices of the most probable tokens
67
+ num_beams = logits.shape[0]
68
+ vocab_size = logits.shape[-1]
69
+
70
+ logits = logits.view(num_beams * vocab_size)
71
+ scores, samples = logits.topk(2 * num_beams, dim=-1)
72
+
73
+ idx_beam, idx_token = samples // vocab_size, samples % vocab_size
74
+
75
+ tokens_scores = scores.squeeze(-1)
76
+
77
+ return idx_beam, idx_token, tokens_scores
78
+
79
+
80
+ # Beam search RSA decoding
81
+ class RSAContextualDecoding:
82
+ def __init__(self, model, tokenizer, device):
83
+ """
84
+
85
+ :param model:
86
+ :param tokenizer:
87
+ :param device:
88
+ """
89
+
90
+ self.model = model.to(device)
91
+ self.tokenizer = tokenizer
92
+ self.device = device
93
+
94
+ def fwd_pass(
95
+ self,
96
+ input_ids: torch.Tensor,
97
+ decoder_input_ids: torch.Tensor,
98
+ attention_mask: torch.Tensor,
99
+ decoder_attention_mask: torch.Tensor,
100
+ ) -> torch.Tensor:
101
+ """
102
+ Make a forward pass through the model to get the logits for the next tokens
103
+ :param input_ids: (world_size, num_beams, input_length)
104
+ :param decoder_input_ids: (world_size, num_beams, partial_target_length)
105
+ :param attention_mask: (world_size, num_beams, input_length)
106
+ :param decoder_attention_mask: (world_size, num_beams, partial_target_length)
107
+ :return: logits: (world_size, num_beams, vocab_size)
108
+ """
109
+ with torch.no_grad():
110
+ world_size, num_beams = input_ids.shape[0], decoder_input_ids.shape[1]
111
+
112
+ input_ids = input_ids.view(world_size * num_beams, input_ids.shape[2]).to(self.device)
113
+ attention_mask = attention_mask.view(
114
+ world_size * num_beams, attention_mask.shape[2]
115
+ ).to(self.device)
116
+
117
+ decoder_input_ids = decoder_input_ids.view(
118
+ world_size * num_beams, decoder_input_ids.shape[2]
119
+ ).to(self.device)
120
+
121
+ decoder_attention_mask = decoder_attention_mask.view(
122
+ world_size * num_beams, decoder_attention_mask.shape[2]
123
+ ).to(self.device)
124
+
125
+ outputs = self.model(
126
+ input_ids=input_ids,
127
+ attention_mask=attention_mask,
128
+ decoder_input_ids=decoder_input_ids,
129
+ decoder_attention_mask=decoder_attention_mask,
130
+ )
131
+ logits = outputs.logits[..., -1, :]
132
+
133
+ logits = logits.view(self.world_size, num_beams, logits.shape[-1])
134
+
135
+ # return the probability of the next token when conditioned on the source text (world_size)
136
+ # and the partial target text (num_beam)
137
+ return logits
138
+
139
+ def duplicate_and_align_input_ids(
140
+ self,
141
+ input_ids: torch.Tensor,
142
+ input_ids_mask: torch.Tensor,
143
+ decoder_input_ids: torch.Tensor,
144
+ decoder_input_ids_mask: torch.Tensor,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ """
147
+ Duplicate the input_ids and decoder_input_ids to have all pairs of input_ids[i] and decoder_input_ids[j]
148
+ It uses torch.repeat and torch.repeat_interleave to get something like:
149
+ a 1
150
+ a 2
151
+ a 3
152
+ b 1
153
+ b 2
154
+ b 3
155
+ ...
156
+ :param input_ids: (world_size, input_length)
157
+ :param decoder_input_ids: (num_beam, partial_target_length)
158
+ :return: input_ids: (world_size, num_beam, input_length)
159
+ decoder_input_ids: (world_size, num_beam, partial_target_length)
160
+ aligned such that all pairs of input_ids[i] and decoder_input_ids[j] are present
161
+ """
162
+
163
+ num_beams = decoder_input_ids.shape[0]
164
+
165
+ input_ids = input_ids.unsqueeze(1).repeat(1, num_beams, 1)
166
+ input_ids_mask = input_ids_mask.unsqueeze(1).repeat(1, num_beams, 1)
167
+
168
+ # repeat interleave
169
+ decoder_input_ids = decoder_input_ids.repeat_interleave(self.world_size, dim=0)
170
+ decoder_input_ids_mask = decoder_input_ids_mask.repeat_interleave(
171
+ self.world_size, dim=0
172
+ )
173
+
174
+ decoder_input_ids = decoder_input_ids.view(self.world_size, num_beams, -1)
175
+ decoder_input_ids_mask = decoder_input_ids_mask.view(
176
+ self.world_size, num_beams, -1
177
+ )
178
+
179
+ # print(self.tokenizer.batch_decode(input_ids[0]))
180
+ # print(self.tokenizer.batch_decode(decoder_input_ids[0]))
181
+
182
+ return input_ids, input_ids_mask, decoder_input_ids, decoder_input_ids_mask
183
+
184
+ def compute_rsa_probas(
185
+ self,
186
+ input_ids: torch.Tensor,
187
+ attention_mask: torch.Tensor,
188
+ decoder_input_ids: torch.Tensor,
189
+ decoder_attention_mask: torch.Tensor,
190
+ do_sample: bool = True,
191
+ top_p: Optional[float] = None,
192
+ top_k: Optional[int] = None,
193
+ temperature: float = 1.0,
194
+ rationality: float = 8.0, # seems to be a good value
195
+ process_logits_before_rsa: bool = True,
196
+ beam_scores: torch.Tensor = None,
197
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
198
+ """
199
+
200
+ :param input_ids: input_ids to the encoder/decoder model = source texts
201
+ :param attention_mask: attention_mask to the encoder/decoder model
202
+ :param decoder_input_ids: decoder ids / partial summaries
203
+ :param decoder_attention_mask: attention mask for the decoder
204
+ :param do_sample: are we planning on sampling the tokens or using argmax (to apply or not the logits processor)
205
+ :param top_p: parameters for the logits processor top p
206
+ :param top_k: parameters for the logits processor top k
207
+ :param temperature: sampling temperature
208
+ :param rationality: how rational is the speaker (higher means more rational)
209
+ :param process_logits_before_rsa: should we apply the logits processor before or after the RSA computation
210
+ :param beam_scores: (world_size, num_beams) the scores of the beams to be added to the logits
211
+ :return: S1, L1: (world_size, num_beam, vocab_size).
212
+ """
213
+
214
+ # some sanity checks
215
+ assert (top_p is None) or (
216
+ top_k is None
217
+ ), "top_p and top_k cannot be used together"
218
+ assert ((top_p is not None) and (do_sample)) or (
219
+ top_p is None
220
+ ), "top_p can only be used with sampling"
221
+ assert ((top_k is not None) and (do_sample)) or (
222
+ top_k is None
223
+ ), "top_k can only be used with sampling"
224
+
225
+ # duplicate the input_ids and decoder_input_ids to have all pairs of input_ids[i] and decoder_input_ids[j]
226
+ (
227
+ input_ids,
228
+ attention_mask,
229
+ decoder_input_ids,
230
+ decoder_attention_mask,
231
+ ) = self.duplicate_and_align_input_ids(
232
+ input_ids,
233
+ attention_mask,
234
+ decoder_input_ids,
235
+ decoder_attention_mask,
236
+ )
237
+
238
+ logits = (
239
+ self.fwd_pass(
240
+ input_ids, decoder_input_ids, attention_mask, decoder_attention_mask
241
+ )
242
+ / temperature # apply the temperature
243
+ )
244
+
245
+ logits = torch.nn.functional.log_softmax(logits, dim=-1)
246
+
247
+ world_size = input_ids.shape[0]
248
+ num_beams = decoder_input_ids.shape[1]
249
+
250
+ logits = logits.view(world_size * num_beams, -1)
251
+
252
+ if do_sample and process_logits_before_rsa:
253
+ if top_p is not None:
254
+ logits = TopPLogitsWarper(top_p=top_p)(input_ids=None, scores=logits)
255
+ if top_k is not None:
256
+ logits = TopKLogitsWarper(top_k=top_k)(input_ids=None, scores=logits)
257
+
258
+ logits = logits.view(world_size, num_beams, -1)
259
+
260
+ if beam_scores is not None:
261
+ logits = logits + beam_scores[None, ..., None]
262
+
263
+ # compute the RSA probabilities
264
+ S1, L1 = compute_rsa_probas(logits, self.prior, rationality=rationality)
265
+ logits = S1
266
+
267
+ if do_sample and not process_logits_before_rsa:
268
+ logits = logits.view(world_size * num_beams, -1)
269
+ if top_p is not None:
270
+ logits = TopPLogitsWarper(top_p=top_p)(input_ids=None, scores=logits)
271
+ if top_k is not None:
272
+ logits = TopKLogitsWarper(top_k=top_k)(input_ids=None, scores=logits)
273
+
274
+ logits = logits.view(world_size, num_beams, -1)
275
+
276
+ return logits, L1
277
+
278
+ def generate(
279
+ self,
280
+ target_id: int,
281
+ source_texts_ids: torch.Tensor,
282
+ source_text_attention_mask: torch.Tensor,
283
+ max_length: int = 100,
284
+ num_beams: int = 8,
285
+ do_sample=True,
286
+ top_p: Optional[float] = None,
287
+ top_k: Optional[int] = None,
288
+ temperature: float = 1.0,
289
+ rationality: float = 1.0,
290
+ process_logits_before_rsa=True,
291
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
292
+ """
293
+
294
+ :param target_id: the id of the target object
295
+ :param source_texts_ids: (world_size, input_length) the tokenized source texts
296
+ :param source_text_attention_mask: (world_size, input_length) the attention mask for the source texts
297
+ :param max_length: the maximum length to generate
298
+ :param do_sample: are we sampling or using argmax
299
+ :param top_p: parameters for the logits processor top p
300
+ :param top_k: parameters for the logits processor top k
301
+ :param temperature: sampling temperature
302
+ :param rationality: how rational is the speaker (higher means more rational)
303
+ :param process_logits_before_rsa: should we apply the logits processor before or after the RSA computation
304
+ :return: decoder_input_ids : (num_beams, max_length) decoded sequences, beam_scores: (num_beams) the scores
305
+ of the beams
306
+ """
307
+
308
+ self.num_beam = num_beams
309
+ self.world_size = source_texts_ids.shape[0]
310
+
311
+ self.prior = torch.ones((self.world_size, self.num_beam)).to(self.device) / self.world_size
312
+ beam_scores = torch.zeros(self.num_beam).to(self.device)
313
+
314
+ # initialize the decoder input ids
315
+ decoder_input_ids = torch.full(
316
+ (self.num_beam, 2),
317
+ 0,
318
+ dtype=torch.long,
319
+ device=self.device,
320
+ )
321
+
322
+ # initialize the decoder attention mask
323
+ decoder_attention_mask = torch.ones_like(decoder_input_ids).to(self.device)
324
+
325
+ new_beams = []
326
+ finished_beams = []
327
+
328
+ # run the beam search
329
+ for t in range(max_length):
330
+ # compute the RSA probabilities
331
+ num_beams = decoder_input_ids.shape[0]
332
+
333
+ S1, L1 = self.compute_rsa_probas(
334
+ source_texts_ids,
335
+ source_text_attention_mask,
336
+ decoder_input_ids,
337
+ decoder_attention_mask,
338
+ do_sample=do_sample,
339
+ top_p=top_p,
340
+ top_k=top_k,
341
+ temperature=temperature,
342
+ rationality=rationality,
343
+ beam_scores=beam_scores,
344
+ process_logits_before_rsa=process_logits_before_rsa,
345
+ )
346
+
347
+ # sample from the probabilities
348
+ idx_beam, idx_token, tokens_scores = sample_from_probs(
349
+ S1[target_id].squeeze(), num_beams, do_sample
350
+ )
351
+
352
+ # create all the new beams
353
+
354
+ new_beams = []
355
+
356
+ for idx_t, idx_b, token_score in zip(idx_token, idx_beam, tokens_scores):
357
+ new_beams.append(
358
+ (
359
+ decoder_input_ids[idx_b].tolist() + [idx_t.item()],
360
+ beam_scores[idx_b] + token_score.item(),
361
+ L1[:, idx_b, idx_t.item()],
362
+ )
363
+ )
364
+
365
+ # sort the beams
366
+ new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
367
+
368
+ # keep only the best beams
369
+ new_beams = new_beams[: self.num_beam]
370
+
371
+ # check if the beams are finished
372
+ _new_beams = []
373
+ for beam in new_beams:
374
+ if beam[0][-1] == self.tokenizer.eos_token_id:
375
+ finished_beams.append(beam)
376
+
377
+ else:
378
+ _new_beams.append(beam)
379
+
380
+ new_beams = _new_beams
381
+
382
+ if len(new_beams) == 0:
383
+ break
384
+
385
+ # pad the beams
386
+ max_beam_len = max(len(x[0]) for x in new_beams)
387
+ new_beams = [
388
+ (
389
+ x[0] + [self.tokenizer.pad_token_id] * (max_beam_len - len(x[0])),
390
+ x[1],
391
+ x[2],
392
+ )
393
+ for x in new_beams
394
+ ]
395
+
396
+ # update the beam scores
397
+ beam_scores = torch.tensor([x[1] for x in new_beams]).to(self.device)
398
+
399
+ # update the decoder input ids
400
+ decoder_input_ids: torch.Tensor = torch.tensor(
401
+ [x[0] for x in new_beams], device=self.device
402
+ )
403
+
404
+ # update the decoder attention mask based on pad tokens
405
+ decoder_attention_mask = (
406
+ decoder_input_ids != self.tokenizer.pad_token_id
407
+ ).long()
408
+
409
+ self.prior = torch.stack([x[2] for x in new_beams], dim=1).to(self.device)
410
+
411
+ # self.prior = torch.ones((self.world_size, len(new_beams))) / self.world_size
412
+
413
+ results = []
414
+
415
+ # pad the beams
416
+ max_beam_len = max(len(x[0]) for x in finished_beams + new_beams)
417
+ for x in finished_beams + new_beams:
418
+ results.append(
419
+ (
420
+ x[0] + [self.tokenizer.pad_token_id] * (max_beam_len - len(x[0])),
421
+ x[1],
422
+ x[2],
423
+ )
424
+ )
425
+
426
+ decoder_input_ids = torch.tensor([x[0] for x in results], device=self.device)
427
+
428
+ beam_scores = torch.tensor([x[1] for x in results]).to(self.device)
429
+
430
+ return decoder_input_ids, beam_scores
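To see what `compute_rsa_probas` does in isolation, here is a toy call with random logits; shapes follow the docstring (two source texts, one beam, a five-token vocabulary), and the uniform prior mirrors how `RSAContextualDecoding.generate` initialises it. A minimal sketch, assuming the `rsasumm` package is importable:

```python
import torch
from rsasumm.beam_search import compute_rsa_probas  # assumes rsasumm is on the path

world_size, num_beams, vocab = 2, 1, 5  # two source texts, one beam, tiny vocabulary
logits = torch.log_softmax(torch.randn(world_size, num_beams, vocab), dim=-1)

# Uniform prior over the source texts, as initialised in generate().
prior = torch.ones(world_size, num_beams) / world_size

S1, L1 = compute_rsa_probas(logits, prior, rationality=8.0)

print(S1.exp().sum(-1))  # ~1 for every (object, beam): S1 is a distribution over tokens
print(L1.exp().sum(0))   # ~1 for every (beam, token): L1 is a distribution over objects
```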
glimpse-ui/glimpse/rsasumm/rsa_reranker.py ADDED
@@ -0,0 +1,280 @@
1
+ from functools import cache
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+
9
+
10
+ def kl_divergence(p, q):
11
+ """
12
+ Compute the KL divergence between two distributions
13
+ """
14
+ return torch.nan_to_num(p * (p / q).log(), nan=0.0).sum(-1)
15
+
16
+
17
+ def jensen_shannon_divergence(p, q):
18
+ """
19
+ Compute the Jensen-Shannon divergence between two distributions
20
+ """
21
+ m = 0.5 * (p + q)
22
+ return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))
23
+
24
+
25
+ class RSAReranking:
26
+ """
27
+ Rerank a list of candidates according to the RSA model.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ model,
33
+ tokenizer,
34
+ candidates: List[str],
35
+ source_texts: List[str],
36
+ batch_size: int = 32,
37
+ rationality: int = 1,
38
+ device="cuda",
39
+ ):
40
+ """
41
+ :param model: hf model used to compute the likelihoods (supposed to be a seq2seq model), is S0 in the RSA model
42
+ :param tokenizer:
43
+ :param candidates: list of candidates summaries
44
+ :param source_texts: list of source texts
45
+ :param batch_size: batch size used to compute the likelihoods (can be high since we don't need gradients and
46
+ it's a single forward pass)
47
+ :param rationality: rationality parameter of the RSA model
48
+ :param device: device used to compute the likelihoods
49
+ """
50
+ self.model = model
51
+ self.device = device
52
+ self.model = model.to(self.device)
53
+ self.tokenizer = tokenizer
54
+
55
+
56
+ self.candidates = candidates
57
+ self.source_texts = source_texts
58
+
59
+ self.batch_size = batch_size
60
+ self.rationality = rationality
61
+
62
+ def compute_conditionned_likelihood(
63
+ self, x: List[str], y: List[str], mean: bool = True
64
+ ) -> torch.Tensor:
65
+ """
66
+ Compute the likelihood of y given x
67
+
68
+ :param x: list of source texts len(x) = batch_size
69
+ :param y: list of candidates summaries len(y) = batch_size
70
+ :param mean: average the likelihoods over the tokens of y or take the sum
71
+ :return: tensor of shape (batch_size) containing the likelihoods of y given x
72
+ """
73
+
74
+ # Ensure x,y are pure Python lists of strings (not pandas.Series, np.ndarray, etc.)
75
+ x = [str(item) for item in list(x)]
76
+ y = [str(item) for item in list(y)]
77
+ assert len(x) == len(y), "x and y must have the same length"
78
+
79
+ loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
80
+ batch_size = len(x)
81
+
82
+ x = self.tokenizer(
83
+ x,
84
+ return_tensors="pt",
85
+ padding=True,
86
+ truncation=True,
87
+ max_length=1024,
88
+ )
89
+ y = self.tokenizer(
90
+ y,
91
+ return_tensors="pt",
92
+ padding=True,
93
+ truncation=True,
94
+ max_length=1024,
95
+ )
96
+
97
+ # Move all tensors to the correct device
98
+ x = {k: v.to(self.device) for k, v in x.items()}
99
+ y = {k: v.to(self.device) for k, v in y.items()}
100
+
101
+ # Feed x to the encoder and y to the decoder (no concatenation needed)
102
+ # Compute the likelihood of y given x
103
+
104
+ x_ids = x["input_ids"]
105
+ y_ids = y["input_ids"]
106
+
107
+ logits = self.model(
108
+ input_ids=x_ids,
109
+ decoder_input_ids=y_ids,
110
+ attention_mask=x["attention_mask"],
111
+ decoder_attention_mask=y["attention_mask"],
112
+ ).logits
113
+
114
+ # Compute the likelihood of y given x
115
+
116
+ shifted_logits = logits[..., :-1, :].contiguous()
117
+ shifted_ids = y_ids[..., 1:].contiguous()
118
+
119
+ likelihood = -loss_fn(
120
+ shifted_logits.view(-1, shifted_logits.size(-1)), shifted_ids.view(-1)
121
+ )
122
+
123
+ likelihood = likelihood.view(batch_size, -1).sum(-1)
124
+ if mean:
125
+ likelihood /= (y_ids != self.tokenizer.pad_token_id).float().sum(-1)
126
+
127
+ return likelihood
128
+
129
+ def score(self, x: List[str], y: List[str], **kwargs):
130
+ return self.compute_conditionned_likelihood(x, y, **kwargs)
131
+
132
+ def likelihood_matrix(self) -> torch.Tensor:
133
+ """
134
+ :return: likelihood matrix : (world_size, num_candidates), likelihood[i, j] is the likelihood of
135
+ candidate j being a summary for source text i.
136
+ """
137
+ likelihood_matrix = torch.zeros(
138
+ (len(self.source_texts), len(self.candidates))
139
+ ).to(self.device)
140
+
141
+ pairs = []
142
+ for i, source_text in enumerate(self.source_texts):
143
+ for j, candidate in enumerate(self.candidates):
144
+ pairs.append((i, j, source_text, candidate))
145
+
146
+ # split the pairs into batches
147
+ batches = [
148
+ pairs[i: i + self.batch_size]
149
+ for i in range(0, len(pairs), self.batch_size)
150
+ ]
151
+
152
+ for batch in tqdm(batches):
153
+ # get the source texts and candidates
154
+ source_texts = [pair[2] for pair in batch]
155
+ candidates = [pair[3] for pair in batch]
156
+
157
+ # compute the likelihoods
158
+ with torch.no_grad():
159
+ likelihoods = self.score(
160
+ source_texts, candidates, mean=True
161
+ )
162
+
163
+ # fill the matrix
164
+ for k, (i, j, _, _) in enumerate(batch):
165
+ likelihood_matrix[i, j] = likelihoods[k].detach()
166
+
167
+ return likelihood_matrix
168
+
169
+ @cache
170
+ def S(self, t):
171
+ if t == 0:
172
+ return self.initial_speaker_probas
173
+ else:
174
+ listener = self.L(t - 1)
175
+ prod = listener * self.rationality # + self.initial_speaker_probas.sum(0, keepdim=True)
176
+ return torch.log_softmax(prod, dim=-1)
177
+
178
+ @cache
179
+ def L(self, t):
180
+ speaker = self.S(t)
181
+ return torch.log_softmax(speaker, dim=-2)
182
+
183
+ def mk_listener_dataframe(self, t):
184
+ self.initial_speaker_probas = self.likelihood_matrix()
185
+
186
+ initial_listener_probas = self.L(0)
187
+
188
+ # compute consensus
189
+ uniform_distribution_over_source_texts = torch.ones_like(
190
+ initial_listener_probas
191
+ ) / len(self.source_texts)
192
+
193
+ initial_consensuality_score = (
194
+ torch.exp(initial_listener_probas)
195
+ * (
196
+ initial_listener_probas - torch.log(uniform_distribution_over_source_texts)
197
+ )
198
+ ).sum(0).cpu().numpy()
199
+
200
+ initial_consensuality_score = pd.Series(initial_consensuality_score, index=self.candidates)
201
+
202
+ initial_listener_probas = initial_listener_probas.cpu().numpy()
203
+
204
+ initial_listener_probas = pd.DataFrame(initial_listener_probas)
205
+ initial_listener_probas.index = self.source_texts
206
+ initial_listener_probas.columns = self.candidates
207
+
208
+ initial_speaker_probas = self.S(0).cpu().numpy()
209
+ initial_speaker_probas = pd.DataFrame(initial_speaker_probas)
210
+ initial_speaker_probas.index = self.source_texts
211
+ initial_speaker_probas.columns = self.candidates
212
+
213
+ listener_df = pd.DataFrame(self.L(t).cpu().numpy())
214
+
215
+ consensuality_scores = (
216
+ torch.exp(self.L(t))
217
+ * (self.L(t) - torch.log(uniform_distribution_over_source_texts))
218
+ ).sum(0).cpu().numpy()
219
+
220
+ consensuality_scores = pd.Series(consensuality_scores, index=self.candidates)
221
+
222
+ S = self.S(t).cpu().numpy()
223
+ speaker_df = pd.DataFrame(S)
224
+
225
+ # add the source texts and candidates as index
226
+
227
+ listener_df.index = self.source_texts
228
+ speaker_df.index = self.source_texts
229
+
230
+ listener_df.columns = self.candidates
231
+ speaker_df.columns = self.candidates
232
+
233
+ return listener_df, speaker_df, initial_listener_probas, initial_speaker_probas, initial_consensuality_score, consensuality_scores
234
+
235
+ def rerank(self, t=1):
236
+ """
237
+ return the best summary (according to rsa) for each text
238
+ """
239
+ (
240
+ listener_df,
241
+ speaker_df,
242
+ initial_listener_proba,
243
+ initial_speaker_proba,
244
+ initial_consensuality_score,
245
+ consensuality_scores,
246
+ ) = self.mk_listener_dataframe(t=t)
247
+ best_rsa = speaker_df.idxmax(axis=1).values
248
+ best_base = initial_listener_proba.idxmax(axis=1).values
249
+
250
+ return (
251
+ best_rsa,
252
+ best_base,
253
+ speaker_df,
254
+ listener_df,
255
+ initial_listener_proba,
256
+ initial_speaker_proba,
257
+ initial_consensuality_score,
258
+ consensuality_scores,
259
+ )
260
+
261
+
262
+ class RSARerankingEmbedder(RSAReranking):
263
+ def __init__(self, *args, **kwargs):
264
+ super().__init__(*args, **kwargs)
265
+
266
+ def compute_embeddings(self, x: List[str], y: List[str], **kwargs):
267
+ model_kwargs = kwargs.get("model_kwargs")
268
+
269
+ # shape: (batch_size, embedding_dim)
270
+ x_embeddings = self.model.encode(x, **model_kwargs)
271
+ y_embeddings = self.model.encode(y, **model_kwargs)
272
+
273
+ # dot product between the embeddings : shape (batch_size)
274
+ dot_products = (x_embeddings * y_embeddings).sum(-1)
275
+
276
+ return dot_products
277
+
278
+ def score(self, x: List[str], y: List[str], **kwargs):
279
+ return self.compute_embeddings(x, y, **kwargs)
280
+
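A minimal end-to-end sketch of the reranker on two toy source texts. The model choice mirrors the demo's `sshleifer/distilbart-cnn-12-3`; the candidate list here is hand-made, and the snippet assumes the `rsasumm` package is installed (see pyproject.toml):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from rsasumm.rsa_reranker import RSAReranking

model_name = "sshleifer/distilbart-cnn-12-3"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

source_texts = [
    "The paper is well written but the experiments are weak.",
    "The paper is well written and the theory is strong.",
]
candidates = [
    "The experiments are weak.",
    "The theory is strong.",
    "The paper is well written.",
]

reranker = RSAReranking(
    model, tokenizer,
    candidates=candidates,
    source_texts=source_texts,
    batch_size=8,
    rationality=1,
    device="cpu",
)
best_rsa, best_base, *_, consensuality_scores = reranker.rerank(t=1)
print(best_rsa)              # most discriminative candidate per source text
print(consensuality_scores)  # low = consensual across texts, high = divisive
```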
glimpse-ui/glimpse/scripts/abstractive.sh ADDED
@@ -0,0 +1,37 @@
1
+ #!/bin/bash
2
+ #SBATCH --partition=main # Ask for the main partition
3
+ #SBATCH --gres=gpu:1
4
+ #SBATCH --mem=10G # Ask for 10 GB of RAM
5
+ #SBATCH --time=2:00:00 # The job will run for 2 hours
6
+ #SBATCH --output=./logs/abstractive_out.txt
7
+ #SBATCH --error=./logs/abstractive_error.txt
8
+ #SBATCH -c 2
9
+
10
+
11
+ # Load the required modules
12
+ module --quiet load miniconda/3
13
+ module --quiet load cuda/12.1.1
14
+ conda activate "glimpse"
15
+
16
+ # Check if input file path is provided and valid
17
+ if [ -z "$1" ] || [ ! -f "$1" ]; then
18
+ # if no path is provided, or the path is invalid, use the default test dataset
19
+ echo "Couldn't find a valid path. Using default path: data/processed/all_reviews_2017.csv"
20
+ dataset_path="data/processed/all_reviews_2017.csv"
21
+ else
22
+ dataset_path="$1"
23
+ fi
24
+
25
+
26
+ # Generate abstractive summaries
27
+ if [[ "$@" =~ "--add-padding" ]]; then # check if padding argument is present
28
+ # add '--no-trimming' flag to the script
29
+ candidates=$(python glimpse/data_loading/generate_abstractive_candidates.py --dataset_path "$dataset_path" --scripted-run --no-trimming | tail -n 1)
30
+ else
31
+ # no additional flags
32
+ candidates=$(python glimpse/data_loading/generate_abstractive_candidates.py --dataset_path "$dataset_path" --scripted-run | tail -n 1)
33
+ fi
34
+
35
+
36
+ # Compute the RSA scores based on the generated summaries
37
+ rsa_scores=$(python glimpse/src/compute_rsa.py --summaries $candidates | tail -n 1)
glimpse-ui/glimpse/scripts/extractive.sh ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+ #SBATCH --partition=main # Ask for the main partition
3
+ #SBATCH --gres=gpu:1
4
+ #SBATCH --mem=10G # Ask for 10 GB of RAM
5
+ #SBATCH --time=2:00:00 # The job will run for 2 hours
6
+ #SBATCH --output=./logs/extractive_out.txt
7
+ #SBATCH --error=./logs/extractive_error.txt
8
+ #SBATCH -c 2
9
+
10
+
11
+ # Load the required modules
12
+ module --quiet load miniconda/3
13
+ module --quiet load cuda/12.1.1
14
+ conda activate "glimpse"
15
+
16
+
17
+ # Check if input file path is provided and valid
18
+ if [ -z "$1" ] || [ ! -f "$1" ]; then
19
+ # if no path is provided, or the path is invalid, use the default test dataset
20
+ echo "Couldn't find a valid path. Using default path: data/processed/all_reviews_2017.csv"
21
+ dataset_path="data/processed/all_reviews_2017.csv"
22
+ else
23
+ dataset_path="$1"
24
+ fi
25
+
26
+ # Generate extractive summaries
27
+ candidates=$(python glimpse/data_loading/generate_extractive_candidates.py --dataset_path "$dataset_path" --scripted-run | tail -n 1)
28
+
29
+ # Compute the RSA scores based on the generated summaries
30
+ rsa_scores=$(python glimpse/src/compute_rsa.py --summaries $candidates | tail -n 1)
31
+
glimpse-ui/glimpse_pk_csv_converter.py ADDED
@@ -0,0 +1,92 @@
1
+ import pickle
2
+ import pandas as pd
3
+ from pathlib import Path
4
+ import os
5
+ import glob
6
+ import re
7
+ import json
8
+
9
+ def process_pickle_results(pickle_path: Path, output_path: Path):
10
+ # === Load Pickle File ===
11
+ with open(pickle_path, 'rb') as f:
12
+ data = pickle.load(f)
13
+
14
+ # === Extract Metadata ===
15
+ reranking_model = data.get('metadata/reranking_model')
16
+ rsa_iterations = data.get('metadata/rsa_iterations')
17
+ results = data.get('results')
18
+
19
+ # print(f"Reranking model: {reranking_model}, RSA iterations: {rsa_iterations}")
20
+
21
+ # === Validate Results ===
22
+ if not isinstance(results, list):
23
+ raise ValueError("The 'results' key is not a list. Please check the pickle file structure.")
24
+
25
+ # === Process and Flatten Results ===
26
+ csv_data = []
27
+ for index, result in enumerate(results):
28
+ # row = {
29
+ # 'index': index,
30
+ # 'id': str(result.get('id')[0]),
31
+ # 'consensuality_scores': result.get('consensuality_scores').to_dict()
32
+ # if isinstance(result.get('consensuality_scores'), pd.Series) else None,
33
+
34
+ # # Optional fields — uncomment as needed
35
+ # # 'best_base': result.get('best_base').tolist() if isinstance(result.get('best_base'), np.ndarray) else None,
36
+ # # 'best_rsa': result.get('best_rsa').tolist() if isinstance(result.get('best_rsa'), np.ndarray) else None,
37
+ # # 'speaker_df': result.get('speaker_df').to_json() if isinstance(result.get('speaker_df'), pd.DataFrame) else None,
38
+ # # 'listener_df': result.get('listener_df').to_json() if isinstance(result.get('listener_df'), pd.DataFrame) else None,
39
+ # # 'initial_listener': result.get('initial_listener').to_json() if isinstance(result.get('initial_listener'), pd.DataFrame) else None,
40
+ # # 'language_model_proba_df': result.get('language_model_proba_df').to_json() if isinstance(result.get('language_model_proba_df'), pd.DataFrame) else None,
41
+ # # 'initial_consensuality_scores': result.get('initial_consensuality_scores').to_dict() if isinstance(result.get('initial_consensuality_scores'), pd.Series) else None,
42
+ # # 'gold': result.get('gold'),
43
+ # # 'rationality': result.get('rationality'),
44
+ # # 'text_candidates': result.get('text_candidates').to_json() if isinstance(result.get('text_candidates'), pd.DataFrame) else None,
45
+ # }
46
+
47
+
48
+ row = {
49
+ 'index': index,
50
+ 'id': str(result.get('id')[0]),
51
+ 'consensuality_scores': json.dumps(result.get('consensuality_scores').to_dict())
52
+ if isinstance(result.get('consensuality_scores'), pd.Series) else None,
53
+ }
54
+
55
+ csv_data.append(row)
56
+
57
+ # === Save to CSV ===
58
+ df = pd.DataFrame(csv_data)
59
+ df.to_csv(output_path, index=False)
60
+ print(f"Results saved to '{output_path}'.")
61
+
62
+
63
+ if __name__ == "__main__":
64
+
65
+ BASE_DIR = Path(__file__).resolve().parent
66
+
67
+ # Set the path to the pickle file and the output CSV file
68
+ # ==== Uncomment the appropriate line below to set the pickle file path ====
69
+ # pickle_file = BASE_DIR / "glimpse" / "output" / "extractive_sentences-_-all_reviews_2017-_-none-_-2025-05-20-20-22-18-_-r3-_-rsa_reranked-google-pegasus-arxiv.pk"
70
+
71
+ # ==== Find the latest file in the directory and use it instead ====
72
+ # This assumes the pickle files are stored in the 'glimpse/output' directory
73
+ # list_of_files = glob.glob('./glimpse/output/*.pk')
74
+ # pickle_file = max(list_of_files, key=os.path.getctime)
75
+ # print (f"Using pickle file: {pickle_file}")
76
+
77
+ # output_file = BASE_DIR / "data" / "GLIMPSE_results_from_pk.csv"
78
+
79
+ # process_pickle_results(pickle_file, output_file)
80
+
81
+ output_dir = BASE_DIR / "data"
82
+ output_dir.mkdir(parents=True, exist_ok=True)
83
+
84
+ pickle_files = sorted(glob.glob('./glimpse/output/*.pk'), key=os.path.getctime)
85
+
86
+ for pickle_file in pickle_files:
87
+ year_match = re.search(r'(\d{4})', os.path.basename(pickle_file))
88
+ year_tag = year_match.group(1) if year_match else 'unknown_year'
89
+ output_file = output_dir / f"GLIMPSE_results_{year_tag}.csv"
90
+
91
+ print(f"Using pickle file: {pickle_file}, saving as {output_file}")
92
+ process_pickle_results(Path(pickle_file), output_file)
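Since `consensuality_scores` is serialized as a JSON string, reading a converted CSV back requires one decoding step. A small sketch (the file name assumes a 2017 run has been converted):

```python
import json
import pandas as pd

df = pd.read_csv("data/GLIMPSE_results_2017.csv")

# Decode the JSON-encoded {sentence: consensuality score} mapping back into a dict.
df["consensuality_scores"] = df["consensuality_scores"].apply(
    lambda s: json.loads(s) if isinstance(s, str) else None
)
print(df.loc[0, "consensuality_scores"])
```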
glimpse-ui/interface/Demo.py ADDED
@@ -0,0 +1,800 @@
1
+ import math
2
+
3
+ import sys, os.path
4
+
5
+ import torch
6
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
7
+
8
+ from glimpse.rsasumm.rsa_reranker import RSAReranking
9
+ import gradio as gr
10
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
11
+ import pandas as pd
12
+
13
+ from scored_reviews_builder import load_scored_reviews
14
+ from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
15
+ # from scibert.scibert_polarity.scibert_polarity import predict_polarity
16
+
17
+ # Load scored reviews
18
+ years, all_scored_reviews_df = load_scored_reviews()
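+ # load_scored_reviews() is assumed to provide the list of available years plus a
+ # DataFrame with one row per year and a "scored_dict" column (consumed by
+ # get_preprocessed_scores below).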
19
+
20
+ # -----------------------------------
21
+ # Pre-processed Tab
22
+ # -----------------------------------
23
+
24
+ def get_preprocessed_scores(year):
25
+ scored_reviews = all_scored_reviews_df[all_scored_reviews_df["year"] == year]["scored_dict"].iloc[0]
26
+ return scored_reviews
27
+
28
+
29
+ # -----------------------------------
30
+ # Interactive Tab
31
+ # -----------------------------------
32
+
33
+ # RSA_model = "facebook/bart-large-cnn"
34
+ RSA_model = "sshleifer/distilbart-cnn-12-3"
35
+
36
+ model = AutoModelForSeq2SeqLM.from_pretrained(RSA_model)
37
+ tokenizer = AutoTokenizer.from_pretrained(RSA_model)
38
+
39
+ # Define the manual color map for topics
40
+ topic_color_map = {
41
+ "Substance": "#cce0ff", # lighter blue
42
+ "Clarity": "#e6ee9c", # lighter yellow-green
43
+ "Soundness/Correctness": "#ffcccc", # lighter red
44
+ "Originality": "#d1c4e9", # lighter purple
45
+ "Motivation/Impact": "#b2ebf2", # lighter teal
46
+ "Meaningful Comparison": "#fff9c4", # lighter yellow
47
+ "Replicability": "#c8e6c9", # lighter green
48
+ }
49
+
50
+
51
+ # GLIMPSE Home/Description Page
52
+ glimpse_description = """
53
+ # ReView: A Tool for Visualizing and Analyzing Scientific Reviews
54
+
55
+ ## Overview
56
+ ReView is a visualization tool designed to assist **area chairs** and **researchers** in efficiently analyzing scholarly reviews. The interface offers two main ways to explore them:
57
+ - Pre-Processed Reviews: Explore real peer reviews from ICLR (2017–2021) with structured visualizations of sentiment, topics, and reviewer agreement.
58
+ - Interactive Tab: Enter your own reviews and view them analyzed in real time using the same NLP-powered highlighting options.
59
+
60
+ All reviews are shown in their original, unaltered form, with visual overlays that surface key insights such as disagreements, sentiment, and common themes, reducing cognitive load and scrolling effort.
61
+
62
+ ---
63
+ ## **Key Features**
64
+ - *Traceability and Transparency:* The tool preserves the original text of each review and overlays highlights for key aspects (e.g., sentiment, topic, agreement), allowing area chairs to trace back every insight to its source without modifying or summarizing the content.
65
+ - *Structured Overview:* All reviews are displayed in a single interface, and radio buttons let you switch between highlighting options.
66
+ - *Interactive:* Users can input their own reviews and, within seconds, view them annotated with highlighted aspects.
67
+ ---
68
+ ## **Highlighting Options**
69
+ - *Agreement:* Identifies both shared and conflicting points across reviews, helping to surface consensus and disagreement.
70
+ - *Polarity:* Highlights positive and negative sentiments within the reviews to reveal tone and stance.
71
+ - *Topic:* Organizes the review sentences by their discussed topics, ensuring coverage of diverse reviewer perspectives and improving clarity.
72
+
73
+ ---
74
+
75
+ ### How to Use ReView
76
+
77
+ ReView offers two main ways to explore peer reviews: using the pre-processed reviews or entering your own.
78
+
79
+ #### 🗂️ Pre-Processed Reviews Tab
80
+
81
+ Use this tab to explore reviews from ICLR (2017–2021):
82
+
83
+ 1. **Select a conference year** from the dropdown menu on the right.
84
+ 2. **Navigate between submissions** using the *Next* and *Previous* buttons on the left.
85
+ 3. **Choose a highlighting view** using the radio buttons:
86
+ - **Original**: Displays unmodified review text.
87
+ - **Agreement**: Highlights consensus points in **red** and disagreements in **purple**.
88
+ - **Polarity**: Highlights **positive** sentiment in **green** and **negative** sentiment in **red**.
89
+ - **Topic**: Highlights comments by discussion topic using color-coded labels.
90
+
91
+ #### ✍️ Interactive Tab
92
+
93
+ Use this tab to analyze your own review text:
94
+
95
+ 1. **Enter up to three reviews** in the input fields labeled *Review 1*, *Review 2*, and *Review 3*.
96
+ 2. **Click "Process"** to analyze the input (average processing time: ~42 seconds).
97
+ 3. **Explore the results** using the same highlighting options as above (Agreement, Polarity, Topic).
98
+ """
99
+
100
+
101
+ EXAMPLES = [
102
+ "The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. I believe the authors missed Jane and al 2021. In addition, I think, there is a mistake in the math.",
103
+ "The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
104
+ "The paper gives really interesting insights on the topic of transfer learning. It is not well presented and lack experiments. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
105
+ ]
106
+
107
+ # Function to summarize the input texts using the RSAReranking model in interactive mode
108
+ def summarize(text1, text2, text3, focus, mode, rationality=1.0, iterations=1):
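+ # Tokenizes the three reviews into sentences, scores every sentence for
+ # polarity, topic, and RSA consensuality, and returns (sentence, label/score)
+ # tuples formatted for the gr.HighlightedText components defined below.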
109
+
110
+ # print(focus, mode, rationality, iterations)
111
+
112
+ # get sentences for each text
113
+ text1_sentences = glimpse_tokenizer(text1)
114
+ text2_sentences = glimpse_tokenizer(text2)
115
+ text3_sentences = glimpse_tokenizer(text3)
116
+
117
+
118
+ # remove empty sentences
119
+ text1_sentences = [sentence for sentence in text1_sentences if sentence != ""]
120
+ text2_sentences = [sentence for sentence in text2_sentences if sentence != ""]
121
+ text3_sentences = [sentence for sentence in text3_sentences if sentence != ""]
122
+
123
+ sentences = list(set(text1_sentences + text2_sentences + text3_sentences))
124
+
125
+ # Load polarity model and tokenizer (SciBERT). Loading inside summarize keeps the
+ # demo self-contained, though caching these at module level would avoid reloading on every call.
126
+ polarity_model_path = "scibert/scibert_polarity/final_model"
127
+ polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_model_path)
128
+ polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_model_path)
129
+ polarity_model.eval()
130
+ polarity_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
131
+ polarity_model.to(polarity_device)
132
+
133
+ def predict_polarity(sent_list):
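+ # Batch-classifies sentences with the fine-tuned SciBERT polarity head;
+ # class 0 is rendered as negative, class 2 as positive, and class 1
+ # (neutral) is left unhighlighted.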
134
+ inputs = polarity_tokenizer(
135
+ sent_list, return_tensors="pt", padding=True, truncation=True, max_length=512
136
+ ).to(polarity_device)
137
+ with torch.no_grad():
138
+ logits = polarity_model(**inputs).logits
139
+ preds = torch.argmax(logits, dim=1).cpu().tolist()
140
+ emoji_map = {0: "➖", 1: None, 2: "➕"}
141
+ return dict(zip(sent_list, [emoji_map[p] for p in preds]))
142
+
143
+
144
+ # Run polarity prediction
145
+ polarity_map = predict_polarity(sentences)
146
+
147
+
148
+ # Load topic model and tokenizer (SciBERT)
149
+ topic_model_path = "scibert/scibert_topic/final_model"
150
+ topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_path)
151
+ topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_path)
152
+ topic_model.eval()
153
+ topic_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
+ topic_model.to(topic_device)
155
+
156
+ def predict_topic(sent_list):
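+ # Batch-classifies sentences into the seven review topics below; class 7
+ # means no specific topic and is left unhighlighted.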
157
+ inputs = topic_tokenizer(
158
+ sent_list, return_tensors="pt", padding=True, truncation=True, max_length=512
159
+ ).to(topic_device)
160
+ with torch.no_grad():
161
+ logits = topic_model(**inputs).logits
162
+ preds = torch.argmax(logits, dim=1).cpu().tolist()
163
+
164
+ # Topic ID to label and emoji
165
+ id2label = {
166
+ 0: "Substance",
167
+ 1: "Clarity",
168
+ 2: "Soundness/Correctness",  # labels match topic_color_map so the interactive tab gets the same colors
169
+ 3: "Originality",
170
+ 4: "Motivation/Impact",
171
+ 5: "Meaningful Comparison",
172
+ 6: "Replicability",
173
+ 7: None  # sentences that do not match any specific topic are left unhighlighted
174
+ }
175
+ return dict(zip(sent_list, [id2label[p] for p in preds]))
176
+
177
+ # Run topic prediction
178
+ topic_map = predict_topic(sentences)
179
+
180
+
181
+
182
+ rsa_reranker = RSAReranking(
183
+ model,
184
+ tokenizer,
185
+ candidates=sentences,
186
+ source_texts=[text1, text2, text3],
187
+ device="cpu",
188
+ rationality=rationality,
189
+ )
190
+ (
191
+ best_rsa,
192
+ best_base,
193
+ speaker_df,
194
+ listener_df,
195
+ initial_listener,
196
+ language_model_proba_df,
197
+ initial_consensuality_scores,
198
+ consensuality_scores,
199
+ ) = rsa_reranker.rerank(t=iterations)
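+ # rerank() runs t rounds of RSA speaker/listener updates: speaker_df holds
+ # log-probabilities of each candidate sentence given each source review
+ # (exponentiated below), and consensuality_scores capture how review-specific
+ # each sentence is.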
200
+
201
+ # apply exp to the probabilities
202
+ speaker_df = speaker_df.applymap(lambda x: math.exp(x))
203
+
204
+ text_1_summaries = speaker_df.loc[text1][text1_sentences]
205
+ text_1_summaries = text_1_summaries / text_1_summaries.sum()
206
+
207
+ text_2_summaries = speaker_df.loc[text2][text2_sentences]
208
+ text_2_summaries = text_2_summaries / text_2_summaries.sum()
209
+
210
+ text_3_summaries = speaker_df.loc[text3][text3_sentences]
211
+ text_3_summaries = text_3_summaries / text_3_summaries.sum()
212
+
213
+ # make list of tuples
214
+ text_1_summaries = [(sentence, text_1_summaries[sentence]) for sentence in text1_sentences]
215
+ text_2_summaries = [(sentence, text_2_summaries[sentence]) for sentence in text2_sentences]
216
+ text_3_summaries = [(sentence, text_3_summaries[sentence]) for sentence in text3_sentences]
217
+
218
+ # normalize consensuality scores between -1 and 1
219
+ consensuality_scores = 2 * (consensuality_scores - consensuality_scores.min()) / (consensuality_scores.max() - consensuality_scores.min()) - 1
220
+
221
+ # get most and least consensual sentences
222
+ # most consensual --> most common; least consensual --> most unique
223
+ most_consensual = consensuality_scores.sort_values(ascending=True).head(3).index.tolist()
224
+ least_consensual = consensuality_scores.sort_values(ascending=False).head(3).index.tolist()
225
+
226
+ # Convert lists to strings
227
+ most_consensual = " ".join(most_consensual)
228
+ least_consensual = " ".join(least_consensual)
229
+
230
+ text_1_consensuality = consensuality_scores.loc[text1_sentences]
231
+ text_2_consensuality = consensuality_scores.loc[text2_sentences]
232
+ text_3_consensuality = consensuality_scores.loc[text3_sentences]
233
+
234
+ text_1_consensuality = [(sentence, text_1_consensuality[sentence]) for sentence in text1_sentences]
235
+ text_2_consensuality = [(sentence, text_2_consensuality[sentence]) for sentence in text2_sentences]
236
+ text_3_consensuality = [(sentence, text_3_consensuality[sentence]) for sentence in text3_sentences]
237
+
238
+
239
+ def highlight_reviews(text_sentences, consensuality_scores, threshold_common=0.0, threshold_unique=0.0):
240
+ highlighted = []
241
+ for sentence in text_sentences:
242
+ # print(f"Processing sentence: {sentence}", "score:", consensuality_scores.loc[sentence])
243
+ score = consensuality_scores.loc[sentence]
244
+ score = score*2 if score > 0 else score # amplify unique scores for better visibility
245
+
246
+ # common sentences --> positive consensuality scores
247
+ # unique sentences --> negative consensuality scores
248
+
249
+ score *= -1 # invert the score for highlighting
250
+
251
+ highlighted.append((sentence, score))
252
+ return highlighted
253
+
254
+ # Apply highlighting to each review
255
+ text_1_agreement = highlight_reviews(text1_sentences, consensuality_scores)
256
+ text_2_agreement = highlight_reviews(text2_sentences, consensuality_scores)
257
+ text_3_agreement = highlight_reviews(text3_sentences, consensuality_scores)
258
+
259
+ # Add polarity outputs
260
+ text_1_polarity = [(s, polarity_map[s]) for s in text1_sentences]
261
+ text_2_polarity = [(s, polarity_map[s]) for s in text2_sentences]
262
+ text_3_polarity = [(s, polarity_map[s]) for s in text3_sentences]
263
+
264
+ # Add topic outputs
265
+ text_1_topic = [(s, topic_map[s]) for s in text1_sentences]
266
+ text_2_topic = [(s, topic_map[s]) for s in text2_sentences]
267
+ text_3_topic = [(s, topic_map[s]) for s in text3_sentences]
268
+
269
+ # print(type(text_1_consensuality))
270
+ return (
271
+ # text_1_summaries, text_2_summaries, text_3_summaries,
272
+ # text_1_consensuality, text_2_consensuality, text_3_consensuality,
273
+ text_1_agreement, text_2_agreement, text_3_agreement,
274
+ most_consensual, least_consensual,
275
+ text_1_polarity, text_2_polarity, text_3_polarity,
276
+ text_1_topic, text_2_topic, text_3_topic,
277
+ )
278
+
279
+
280
+
281
+
282
+ with gr.Blocks(title="ReView") as demo:
283
+ # gr.Markdown("# ReView Interface")
284
+
285
+ with gr.Tab("Introduction"):
286
+ gr.Markdown(glimpse_description)
287
+
288
+ # -----------------------------------
289
+ # Pre-processed Tab
290
+ # -----------------------------------
291
+ with gr.Tab("Pre-processed Reviews"):
292
+ # Initialize state for this session.
293
+ initial_year = 2017
294
+ initial_scored_reviews = get_preprocessed_scores(initial_year)
295
+ initial_review_ids = list(initial_scored_reviews.keys())
296
+ initial_review = initial_scored_reviews[initial_review_ids[0]]
297
+ number_of_displayed_reviews = len(initial_scored_reviews[initial_review_ids[0]])
298
+ initial_state = {
299
+ "year_choice": initial_year,
300
+ "scored_reviews_for_year": initial_scored_reviews,
301
+ "review_ids": initial_review_ids,
302
+ "current_review_index": 0,
303
+ "current_review": initial_review,
304
+ "number_of_displayed_reviews": number_of_displayed_reviews,
305
+ }
306
+ state = gr.State(initial_state)
307
+
308
+ def update_review_display(state, score_type):
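+ # Rebuilds the gr.update() payloads for all eight review slots and the side
+ # panels, re-highlighting the current submission according to score_type.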
309
+
310
+ review_ids = state["review_ids"]
311
+ current_index = state["current_review_index"]
312
+ current_review = state["scored_reviews_for_year"][review_ids[current_index]]
313
+
314
+ show_polarity = score_type == "Polarity"
315
+ show_consensuality = score_type == "Agreement"
316
+ show_topic = score_type == "Topic"
317
+
318
+
319
+ if show_polarity:
320
+ color_map = {"➕": "#d4fcd6", "➖": "#fcd6d6"}
321
+ legend = False
322
+ elif show_topic:
323
+ color_map = topic_color_map  # use the manual color map defined for topics
324
+ legend = False
325
+ elif show_consensuality:
326
+ color_map = None # Continuous scale, no predefined colors
327
+ legend = True
328
+ else:
329
+ color_map = {} # Default to empty map
330
+ legend = False
331
+
332
+ new_review_id = (
333
+ f"### Submission Link:\n\n{review_ids[current_index]}<br>"
334
+ f"(Showing submission {current_index + 1} of {len(state['review_ids'])})"
335
+ )
336
+
337
+ number_of_displayed_reviews = len(current_review)
338
+ review_updates = []
339
+ consensuality_dict = {}
340
+
341
+ for i in range(8):
342
+ if i < number_of_displayed_reviews:
343
+ review_item = list(current_review[i].items())
344
+
345
+ if show_polarity:
346
+ highlighted = []
347
+ for sentence, metadata in review_item:
348
+ polarity = metadata.get("polarity", 0.0)  # default to neutral when the score is missing
349
+ if polarity >= 0.995:
350
+ label = "➕" # positive
351
+ elif polarity <= -0.99:
352
+ label = "➖" # negative
353
+ else:
354
+ label = None # ignore neutral (1)
355
+ highlighted.append((sentence, label))
356
+ elif show_consensuality:
357
+ highlighted = []
358
+ for sentence, metadata in review_item:
359
+ score = metadata.get("consensuality", 0.0)
360
+ score = score * 2 - 1 # Normalize to [-1, 1]
361
+ score = score/2.5 if score > 0 else score # attenuate positive scores so the highlight contrast stays readable
362
+ score *= -1 # Invert the score for highlighting
363
+
364
+ consensuality_dict[sentence] = score
365
+ highlighted.append((sentence, score))
366
+
367
+ elif show_topic:
368
+ highlighted = []
369
+ for sentence, metadata in review_item:
370
+ topic = metadata.get("topic", None)
371
+ if topic != "NONE":
372
+ highlighted.append((sentence, topic))
373
+ else:
374
+ highlighted.append((sentence, None))
375
+ else:
376
+ highlighted = [
377
+ (sentence, None)
378
+ for sentence, metadata in review_item
379
+ ]
380
+
381
+ review_updates.append(
382
+ gr.update(
383
+ visible=True,
384
+ value=highlighted,
385
+ color_map=color_map,
386
+ show_legend=legend,
387
+ key=f"updated_{score_type}_{i}"
388
+ )
389
+ )
390
+ else:
391
+ review_updates.append(
392
+ gr.update(
393
+ visible=False,
394
+ value=[],
395
+ show_legend=False,
396
+ color_map=color_map,
397
+ key=f"updated_{score_type}_{i}"
398
+ )
399
+ )
400
+
401
+ # Set most consensual / unique sentences
402
+ if show_consensuality and consensuality_dict:
403
+ scores = pd.Series(consensuality_dict)
404
+ most_unique = scores.sort_values(ascending=True).head(3).index.tolist()
405
+ most_common = scores.sort_values(ascending=False).head(3).index.tolist()
406
+ most_common_text = "\n".join(most_common)
407
+ most_unique_text = "\n".join(most_unique)
408
+
409
+ most_common_visibility = gr.update(visible=True, value=most_common_text)
410
+ most_unique_visibility = gr.update(visible=True, value=most_unique_text)
411
+ else:
412
+ # Debugging statements to check visibility settings
413
+ # print("Hiding most common and unique sentences")
414
+
415
+ most_common_visibility = gr.update(visible=False, value=[])
416
+ most_unique_visibility = gr.update(visible=False, value=[])
417
+
418
+ # update topic color map
419
+ if show_topic:
420
+ topic_color_map_visibility = gr.update(
421
+ visible=True,
422
+ color_map=topic_color_map,
423
+ value=[
424
+ ("", "Substance"),
425
+ ("", "Clarity"),
426
+ ("", "Soundness/Correctness"),
427
+ ("", "Originality"),
428
+ ("", "Motivation/Impact"),
429
+ ("", "Meaningful Comparison"),
430
+ ("", "Replicability"),
431
+ ]
432
+ )
433
+ else:
434
+ topic_color_map_visibility = gr.update(visible=False, value=[])
435
+
436
+ return (
437
+ new_review_id,
438
+ *review_updates,
439
+ most_common_visibility,
440
+ most_unique_visibility,
441
+ topic_color_map_visibility,
442
+ state
443
+ )
444
+
445
+
446
+
447
+ # Precompute the initial outputs so something is shown on load.
448
+ init_display = update_review_display(initial_state, score_type="Original")
449
+ # init_display returns: (review_id, review1..review8, most_common, most_unique, topic_box, state)
450
+
451
+ with gr.Row():
452
+
453
+ with gr.Column(scale=1):
454
+ review_id = gr.Markdown(value=init_display[0], container=True)
455
+ with gr.Row():
456
+ previous_button = gr.Button("Previous", variant="secondary", interactive=True)
457
+ next_button = gr.Button("Next", variant="primary", interactive=True)
458
+
459
+
460
+ with gr.Column(scale=1):
461
+ # Input controls.
462
+ year = gr.Dropdown(choices=years, label="Select Year", interactive=True, value=initial_year)
463
+ score_type = gr.Radio(
464
+ choices=["Original", "Agreement", "Polarity", "Topic"],
465
+ label="Score Type to Display",
466
+ value="Original",
467
+ interactive=True
468
+ )
469
+
470
+ # Output display.
471
+ with gr.Row():
472
+ most_common_sentences = gr.Textbox(
473
+ lines=8,
474
+ label="Most Common Opinions",
475
+ visible=False,
476
+ value=[]
477
+ )
478
+ most_unique_sentences = gr.Textbox(
479
+ lines=8,
480
+ label="Most Divergent Opinions",
481
+ visible=False,
482
+ value=[]
483
+ )
484
+
485
+ # Add a new textbox for topic labels and colors
486
+ topic_text_box = gr.HighlightedText(
487
+ label="Topic Labels (Color-Coded)",
488
+ visible=False,
489
+ value=[],
490
+ show_legend=True,
491
+ )
492
+
493
+ review1 = gr.HighlightedText(
494
+ show_legend=False,
495
+ label="Review 1",
496
+ visible= number_of_displayed_reviews >= 1,
497
+ key="initial_review1",
498
+ # color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
499
+ )
500
+ review2 = gr.HighlightedText(
501
+ show_legend=False,
502
+ label="Review 2",
503
+ visible= number_of_displayed_reviews >= 2,
504
+ key="initial_review2"
505
+ # color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
506
+ )
507
+ review3 = gr.HighlightedText(
508
+ show_legend=False,
509
+ label="Review 3",
510
+ visible= number_of_displayed_reviews >= 3,
511
+ key="initial_review3"
512
+ # color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
513
+ )
514
+ review4 = gr.HighlightedText(
515
+ show_legend=False,
516
+ label="Review 4",
517
+ visible= number_of_displayed_reviews >= 4,
518
+ key="initial_review4"
519
+ # color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
520
+ )
521
+ review5 = gr.HighlightedText(
522
+ show_legend=False,
523
+ label="Review 5",
524
+ visible= number_of_displayed_reviews >= 5,
525
+ key="initial_review5"
526
+ # color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
527
+ )
528
+ review6 = gr.HighlightedText(
529
+ show_legend=False,
530
+ label="Review 6",
531
+ visible= number_of_displayed_reviews >= 6,
532
+ key="initial_review6"
533
+ # color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
534
+ )
535
+ review7 = gr.HighlightedText(
536
+ show_legend=False,
537
+ label="Review 7",
538
+ visible= number_of_displayed_reviews >= 7,
539
+ key="initial_review7"
540
+ # color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
541
+ )
542
+ review8 = gr.HighlightedText(
543
+ show_legend=False,
544
+ label="Review 8",
545
+ visible= number_of_displayed_reviews >= 8,
546
+ key="initial_review8"
547
+ # color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
548
+ )
549
+
550
+ # Callback functions that update state.
551
+ def year_change(year, state, score_type):
552
+ state["year_choice"] = year
553
+ state["scored_reviews_for_year"] = get_preprocessed_scores(year)
554
+ state["review_ids"] = list(state["scored_reviews_for_year"].keys())
555
+ state["current_review_index"] = 0
556
+ state["current_review"] = state["scored_reviews_for_year"][state["review_ids"][0]]
557
+ return update_review_display(state, score_type)
558
+
559
+ def next_review(state, score_type):
560
+ state["current_review_index"] = (state["current_review_index"] + 1) % len(state["review_ids"])
561
+ state["current_review"] = state["scored_reviews_for_year"][state["review_ids"][state["current_review_index"]]]
562
+ return update_review_display(state, score_type)
563
+
564
+ def previous_review(state, score_type):
565
+ state["current_review_index"] = (state["current_review_index"] - 1) % len(state["review_ids"])
566
+ state["current_review"] = state["scored_reviews_for_year"][state["review_ids"][state["current_review_index"]]]
567
+ return update_review_display(state, score_type)
568
+
569
+ # Hook up the callbacks with the session state.
570
+ year.change(
571
+ fn=year_change,
572
+ inputs=[year, state, score_type],
573
+ outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
574
+ )
575
+ score_type.change(
576
+ fn=update_review_display,
577
+ inputs=[state, score_type],
578
+ outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
579
+ )
580
+ next_button.click(
581
+ fn=next_review,
582
+ inputs=[state, score_type],
583
+ outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
584
+ )
585
+ previous_button.click(
586
+ fn=previous_review,
587
+ inputs=[state, score_type],
588
+ outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
589
+ )
590
+
591
+
592
+
593
+
594
+ # -----------------------------------
595
+ # Interactive Tab
596
+ # -----------------------------------
597
+ with gr.Tab("Interactive", interactive=True):
598
+ with gr.Row():
599
+ with gr.Column():
600
+
601
+ gr.Markdown("## Input Reviews")
602
+
603
+ # review_count = gr.Slider(minimum=1, maximum=3, step=1, value=3, label="Number of Reviews", interactive=True)
604
+
605
+ review1_textbox = gr.Textbox(lines=5, value=EXAMPLES[0], label="Review 1", interactive=True)
606
+ review2_textbox = gr.Textbox(lines=5, value=EXAMPLES[1], label="Review 2", interactive=True)
607
+ review3_textbox = gr.Textbox(lines=5, value=EXAMPLES[2], label="Review 3", interactive=True)
608
+
609
+ with gr.Row():
610
+ submit_button = gr.Button("Process", variant="primary", interactive=True)
611
+ clear_button = gr.Button("Clear", variant="secondary", interactive=True)
612
+ gr.Markdown("**Note**: *Once your inputs are processed, you can view the different results by <ins>**changing the parameters only**</ins>, without re-processing.*", container=True)
613
+
614
+
615
+
616
+ with gr.Column():
617
+
618
+ gr.Markdown("## Results")
619
+
620
+ mode_radio = gr.Radio(
621
+ choices=[("In-line Highlighting", "highlight"), ("Generate Summaries", "summary")],
622
+ value="highlight",
623
+ label="Output Mode:",
624
+ interactive=False,
625
+ visible=False # Initially hidden, will be shown based on mode selection
626
+ )
627
+ focus_radio = gr.Radio(
628
+ choices=[("Agreement", "unique"), "Polarity", "Topic"],
629
+ value="unique",
630
+ label="Focus on:",
631
+ interactive=True
632
+ )
633
+ generation_method_radio = gr.Radio(
634
+ choices=[("Extractive", "extractive")], #TODO: add ("Abstractive", "abstractive") and abstractive generation
635
+ value="extractive",
636
+ label="Generation Method:",
637
+ interactive=True,
638
+ visible=False
639
+ )
640
+
641
+ # Fixed rationality (3.0) and iterations (2) to be consistent with the compute_rsa.py script
642
+ #iterations_slider = gr.Slider(minimum=1, maximum=10, step=1, value=2, label="Iterations", interactive=False, visible=False)
643
+ # rationality_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=2.0, label="Rationality", interactive=False, visible=False)
644
+
645
+ with gr.Row():
646
+ unique_sentences = gr.Textbox(
647
+ lines=6, label="Most Divergent Opinions", visible=True, value=None, container=True
648
+ )
649
+ common_sentences = gr.Textbox(
650
+ lines=6, label="Most Common Opinions", visible=True, value=None, container=True
651
+ )
652
+
653
+ uniqueness_score_text1 = gr.HighlightedText(
654
+ show_legend=True, label="Agreement in Review 1", visible=True, value=None,
655
+ )
656
+ uniqueness_score_text2 = gr.HighlightedText(
657
+ show_legend=True, label="Agreement in Review 2", visible=True, value=None,
658
+ )
659
+ uniqueness_score_text3 = gr.HighlightedText(
660
+ show_legend=True, label="Agreement in Review 3", visible=True, value=None,
661
+ )
662
+
663
+
664
+ polarity_score_text1 = gr.HighlightedText(
665
+ show_legend=True, label="Polarity in Review 1", visible=False, value=None,
666
+ color_map={"➕": "#d4fcd6", "➖": "#fcd6d6" }
667
+ )
668
+ polarity_score_text2 = gr.HighlightedText(
669
+ show_legend=True, label="Polarity in Review 2", visible=False, value=None,
670
+ color_map={"➕": "#d4fcd6", "➖": "#fcd6d6" }
671
+ )
672
+ polarity_score_text3 = gr.HighlightedText(
673
+ show_legend=True, label="Polarity in Review 3", visible=False, value=None,
674
+ color_map={"➕": "#d4fcd6", "➖": "#fcd6d6" }
675
+ )
676
+
677
+ aspect_score_text1 = gr.HighlightedText(
678
+ show_legend=False, label="Topic in Review 1", visible=False, value=None,
679
+ color_map = topic_color_map
680
+ )
681
+ aspect_score_text2 = gr.HighlightedText(
682
+ show_legend=False, label="Topic in Review 2", visible=False, value=None,
683
+ color_map = topic_color_map
684
+ )
685
+ aspect_score_text3 = gr.HighlightedText(
686
+ show_legend=False, label="Topic in Review 3", visible=False, value=None,
687
+ color_map = topic_color_map
688
+ )
689
+
690
+
691
+
692
+
693
+ # Connect summarize function to submit button
694
+ submit_button.click(
695
+ fn=summarize,
696
+ inputs=[
697
+ review1_textbox, review2_textbox, review3_textbox,
698
+ focus_radio, mode_radio
699
+ ],
700
+ outputs=[
701
+ uniqueness_score_text1, uniqueness_score_text2, uniqueness_score_text3,
702
+ common_sentences, unique_sentences,
703
+ polarity_score_text1, polarity_score_text2, polarity_score_text3,
704
+ aspect_score_text1, aspect_score_text2, aspect_score_text3
705
+
706
+ ]
707
+ )
708
+
709
+ # Define clear button behavior
710
+ clear_button.click(
711
+ fn=lambda: (None, None, None, None, None, None, None, None, None, None, None), # clear all fields
712
+ inputs=[],
713
+ outputs=[
714
+ review1_textbox, review2_textbox, review3_textbox,
715
+ uniqueness_score_text1, uniqueness_score_text2, uniqueness_score_text3,
716
+ common_sentences, unique_sentences,
+ polarity_score_text1, polarity_score_text2, polarity_score_text3  # 11 outputs, matching the 11 Nones above
717
+ ]
718
+ )
719
+
720
+ # Update visibility of generation_method_radio based on mode_radio value
721
+ # def toggle_generation_method(mode):
722
+ # if mode == "summary":
723
+ # return gr.update(visible=True), gr.update(visible=False) # show generation method radio, hide focus radio
724
+ # else:
725
+ # return gr.update(visible=False), gr.update(visible=True) # show focus radio, hide generation method radio
726
+
727
+ # mode_radio.change(
728
+ # fn=toggle_generation_method,
729
+ # inputs=mode_radio,
730
+ # outputs=[generation_method_radio, focus_radio]
731
+ # )
732
+
733
+ # Update visibility of output textboxes based on mode_radio and focus_radio values
734
+ def toggle_output_textboxes(mode, focus):
735
+ if mode == "highlight" and focus == "unique":
736
+ return (
737
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), # in-line uniqueness highlights
738
+ gr.update(visible=True), gr.update(visible=True), # summary highlights
739
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), # polarity highlights
740
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) # aspect highlights
741
+ )
742
+
743
+ elif focus == "Polarity":
744
+ return (
745
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), # in-line uniqueness highlights
746
+ gr.update(visible=False), gr.update(visible=False), # summary highlights
747
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), # polarity highlights
748
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) # aspect highlights
749
+ )
750
+
751
+ elif focus == "Topic":
752
+ return (
753
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), # in-line uniqueness highlights
754
+ gr.update(visible=False), gr.update(visible=False), # summary highlights
755
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), # polarity highlights
756
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) # aspect highlights
757
+ )
+
+ # Fallback for any other combination: default back to the agreement view
+ return (
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
+ gr.update(visible=True), gr.update(visible=True),
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+ )
758
+
759
+ focus_radio.change(
760
+ fn=toggle_output_textboxes,
761
+ inputs=[mode_radio, focus_radio],
762
+ outputs=[
763
+ uniqueness_score_text1, uniqueness_score_text2, uniqueness_score_text3,
764
+ common_sentences, unique_sentences,
765
+ polarity_score_text1, polarity_score_text2, polarity_score_text3,
766
+ aspect_score_text1, aspect_score_text2, aspect_score_text3
767
+ ]
768
+ )
769
+ # mode_radio.change(
770
+ # fn=toggle_output_textboxes,
771
+ # inputs=[mode_radio, focus_radio],
772
+ # outputs=[
773
+ # uniqueness_score_text1, uniqueness_score_text2, uniqueness_score_text3,
774
+ # consensuality_score_text1, consensuality_score_text2, consensuality_score_text3,
775
+ # most_consensual_sentences, most_unique_sentences
776
+ # ]
777
+ # )
778
+
779
+ # TODO: Configure the slider for the number of review boxes
780
+
781
+ # def toggle_reviews(number_of_displayed_reviews):
782
+ # number_of_displayed_reviews = int(number_of_displayed_reviews)
783
+ # updates = []
784
+ # # for review(i), set visible True if its index is <= n, otherwise False.
785
+ # for i in range(1, 4): updates.append(gr.update(visible=(i <= number_of_displayed_reviews)))
786
+ # return tuple(updates)
787
+
788
+ # review_count.change(
789
+ # fn=toggle_reviews,
790
+ # inputs=[review_count],
791
+ # outputs=[review1_textbox, review2_textbox, review3_textbox]
792
+ # )
793
+
794
+ demo.load(
795
+ fn=update_review_display,
796
+ inputs=[state, score_type],
797
+ outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
798
+ )
799
+
800
+ demo.launch(share=False)
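With the scored-review pickles and both SciBERT checkpoints in place, running `python interface/Demo.py` from the `glimpse-ui/` directory starts the app; `share=False` keeps the Gradio server local (http://127.0.0.1:7860 by default).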
glimpse-ui/scibert/scibert_polarity/final_model/config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "_name_or_path": "allenai/scibert_scivocab_uncased",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 768,
11
+ "id2label": {
12
+ "0": "LABEL_0",
13
+ "1": "LABEL_1",
14
+ "2": "LABEL_2"
15
+ },
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 3072,
18
+ "label2id": {
19
+ "LABEL_0": 0,
20
+ "LABEL_1": 1,
21
+ "LABEL_2": 2
22
+ },
23
+ "layer_norm_eps": 1e-12,
24
+ "max_position_embeddings": 512,
25
+ "model_type": "bert",
26
+ "num_attention_heads": 12,
27
+ "num_hidden_layers": 12,
28
+ "pad_token_id": 0,
29
+ "position_embedding_type": "absolute",
30
+ "torch_dtype": "float32",
31
+ "transformers_version": "4.46.1",
32
+ "type_vocab_size": 2,
33
+ "use_cache": true,
34
+ "vocab_size": 31090
35
+ }
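For reference, a minimal sketch of loading this checkpoint outside the demo; the relative path assumes the repo layout above, and the label semantics follow the mapping used in `Demo.py` (0 negative, 1 neutral, 2 positive):

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

ckpt = "scibert/scibert_polarity/final_model"  # relative to glimpse-ui/
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSequenceClassification.from_pretrained(ckpt)
model.eval()  # 3 classes; Demo.py maps 0 -> negative, 1 -> neutral, 2 -> positive
```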
glimpse-ui/scibert/scibert_polarity/final_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e259f6ae81187152e0aa80b9a478aec607c802eb270114eef2b64c8bac806d43
3
+ size 439706620
glimpse-ui/scibert/scibert_polarity/final_model/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }