Commit
·
6fe7180
0
Parent(s):
Super-squash branch 'main' using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +36 -0
- .gitignore +2 -0
- .gitmodules +3 -0
- README.md +13 -0
- glimpse-ui/.gitignore +362 -0
- glimpse-ui/LICENSE +21 -0
- glimpse-ui/alternative_polarity/deberta/deberta_v3_base_polarity.py +95 -0
- glimpse-ui/alternative_polarity/deberta/deberta_v3_base_polarity_train.py +98 -0
- glimpse-ui/alternative_polarity/manual_polarity_tester.py +65 -0
- glimpse-ui/alternative_polarity/scideberta/scideberta_full_polarity.py +79 -0
- glimpse-ui/alternative_polarity/scideberta/scideberta_full_polarity_train.py +108 -0
- glimpse-ui/alternative_topic/debetra/deberta_topic.py +92 -0
- glimpse-ui/alternative_topic/debetra/deberta_topic_train.py +80 -0
- glimpse-ui/alternative_topic/scideberta/scideberta_topic.py +92 -0
- glimpse-ui/alternative_topic/scideberta/scideberta_topic_train.py +80 -0
- glimpse-ui/data/ExtractDISAPEREData.py +106 -0
- glimpse-ui/glimpse/.gitignore +203 -0
- glimpse-ui/glimpse/Readme.md +69 -0
- glimpse-ui/glimpse/examples/RSA Sum tests.ipynb +189 -0
- glimpse-ui/glimpse/examples/reviews/reviews_app.py +274 -0
- glimpse-ui/glimpse/examples/reviews/reviews_latex_generation.py +272 -0
- glimpse-ui/glimpse/glimpse/baselines/generate_llm_summaries.py +112 -0
- glimpse-ui/glimpse/glimpse/baselines/sumy_baselines.py +129 -0
- glimpse-ui/glimpse/glimpse/data_loading/Glimpse_tokenizer.py +74 -0
- glimpse-ui/glimpse/glimpse/data_loading/data_processing.py +15 -0
- glimpse-ui/glimpse/glimpse/data_loading/generate_abstractive_candidates.py +230 -0
- glimpse-ui/glimpse/glimpse/data_loading/generate_extractive_candidates.py +129 -0
- glimpse-ui/glimpse/glimpse/evaluate/Evaluate informativeness.ipynb +258 -0
- glimpse-ui/glimpse/glimpse/evaluate/evaluate_bartbert_metrics.py +110 -0
- glimpse-ui/glimpse/glimpse/evaluate/evaluate_common_metrics_samples.py +122 -0
- glimpse-ui/glimpse/glimpse/evaluate/evaluate_seahorse_metrics_samples.py +150 -0
- glimpse-ui/glimpse/glimpse/src/beam_rsa_decoding.py +207 -0
- glimpse-ui/glimpse/glimpse/src/compute_rsa.py +137 -0
- glimpse-ui/glimpse/glimpse/src/rsa_merge_into_single.py +52 -0
- glimpse-ui/glimpse/glimpse/src/rsa_reranking.py +127 -0
- glimpse-ui/glimpse/mds/Single summaries expes.ipynb +587 -0
- glimpse-ui/glimpse/mds/Template summaries.ipynb +531 -0
- glimpse-ui/glimpse/mds/discriminative_classification.py +113 -0
- glimpse-ui/glimpse/pyproject.toml +21 -0
- glimpse-ui/glimpse/requirements +10 -0
- glimpse-ui/glimpse/rsasumm/__init__.py +0 -0
- glimpse-ui/glimpse/rsasumm/beam_search.py +430 -0
- glimpse-ui/glimpse/rsasumm/rsa_reranker.py +280 -0
- glimpse-ui/glimpse/scripts/abstractive.sh +37 -0
- glimpse-ui/glimpse/scripts/extractive.sh +31 -0
- glimpse-ui/glimpse_pk_csv_converter.py +92 -0
- glimpse-ui/interface/Demo.py +800 -0
- glimpse-ui/scibert/scibert_polarity/final_model/config.json +35 -0
- glimpse-ui/scibert/scibert_polarity/final_model/model.safetensors +3 -0
- glimpse-ui/scibert/scibert_polarity/final_model/special_tokens_map.json +7 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
glimpse-ui/data/preprocessed_scored_reviews.csv filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
./scibert/*
|
2 |
+
./alternative_*/
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "glimpse-ui"]
|
2 |
+
path = glimpse-ui
|
3 |
+
url = https://github.com/Sina1138/glimpse-ui.git
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: GlimpSys
|
3 |
+
emoji: 📊
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.35.0
|
8 |
+
app_file: glimpse-ui/interface/Demo.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
glimpse-ui/.gitignore
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# project ignores
|
2 |
+
glimpse/
|
3 |
+
data/DISAPERE-main/
|
4 |
+
*checkpoints/
|
5 |
+
.gradio/
|
6 |
+
test.py
|
7 |
+
data/*
|
8 |
+
final_model/
|
9 |
+
alternative_polarity/llama/
|
10 |
+
!data/ExtractDISAPEREData.py
|
11 |
+
!data/preprocessed_scored_reviews.csv
|
12 |
+
|
13 |
+
# Byte-compiled / optimized / DLL files
|
14 |
+
__pycache__/
|
15 |
+
*.py[cod]
|
16 |
+
*$py.class
|
17 |
+
|
18 |
+
# C extensions
|
19 |
+
*.so
|
20 |
+
|
21 |
+
# Distribution / packaging
|
22 |
+
.Python
|
23 |
+
build/
|
24 |
+
develop-eggs/
|
25 |
+
dist/
|
26 |
+
downloads/
|
27 |
+
eggs/
|
28 |
+
.eggs/
|
29 |
+
lib/
|
30 |
+
lib64/
|
31 |
+
parts/
|
32 |
+
sdist/
|
33 |
+
var/
|
34 |
+
wheels/
|
35 |
+
share/python-wheels/
|
36 |
+
*.egg-info/
|
37 |
+
.installed.cfg
|
38 |
+
*.egg
|
39 |
+
MANIFEST
|
40 |
+
|
41 |
+
# PyInstaller
|
42 |
+
# Usually these files are written by a python script from a template
|
43 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
44 |
+
*.manifest
|
45 |
+
*.spec
|
46 |
+
|
47 |
+
# Installer logs
|
48 |
+
pip-log.txt
|
49 |
+
pip-delete-this-directory.txt
|
50 |
+
|
51 |
+
# Unit test / coverage reports
|
52 |
+
htmlcov/
|
53 |
+
.tox/
|
54 |
+
.nox/
|
55 |
+
.coverage
|
56 |
+
.coverage.*
|
57 |
+
.cache
|
58 |
+
nosetests.xml
|
59 |
+
coverage.xml
|
60 |
+
*.cover
|
61 |
+
*.py,cover
|
62 |
+
.hypothesis/
|
63 |
+
.pytest_cache/
|
64 |
+
cover/
|
65 |
+
|
66 |
+
# Translations
|
67 |
+
*.mo
|
68 |
+
*.pot
|
69 |
+
|
70 |
+
# Django stuff:
|
71 |
+
*.log
|
72 |
+
local_settings.py
|
73 |
+
db.sqlite3
|
74 |
+
db.sqlite3-journal
|
75 |
+
|
76 |
+
# Flask stuff:
|
77 |
+
instance/
|
78 |
+
.webassets-cache
|
79 |
+
|
80 |
+
# Scrapy stuff:
|
81 |
+
.scrapy
|
82 |
+
|
83 |
+
# Sphinx documentation
|
84 |
+
docs/_build/
|
85 |
+
|
86 |
+
# PyBuilder
|
87 |
+
.pybuilder/
|
88 |
+
target/
|
89 |
+
|
90 |
+
# Jupyter Notebook
|
91 |
+
.ipynb_checkpoints
|
92 |
+
|
93 |
+
# IPython
|
94 |
+
profile_default/
|
95 |
+
ipython_config.py
|
96 |
+
|
97 |
+
# pyenv
|
98 |
+
# For a library or package, you might want to ignore these files since the code is
|
99 |
+
# intended to run in multiple environments; otherwise, check them in:
|
100 |
+
# .python-version
|
101 |
+
|
102 |
+
# pipenv
|
103 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
104 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
105 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
106 |
+
# install all needed dependencies.
|
107 |
+
#Pipfile.lock
|
108 |
+
|
109 |
+
# UV
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
111 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
112 |
+
# commonly ignored for libraries.
|
113 |
+
#uv.lock
|
114 |
+
|
115 |
+
# poetry
|
116 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
117 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
118 |
+
# commonly ignored for libraries.
|
119 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
120 |
+
#poetry.lock
|
121 |
+
|
122 |
+
# pdm
|
123 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
124 |
+
#pdm.lock
|
125 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
126 |
+
# in version control.
|
127 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
128 |
+
.pdm.toml
|
129 |
+
.pdm-python
|
130 |
+
.pdm-build/
|
131 |
+
|
132 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
133 |
+
__pypackages__/
|
134 |
+
|
135 |
+
# Celery stuff
|
136 |
+
celerybeat-schedule
|
137 |
+
celerybeat.pid
|
138 |
+
|
139 |
+
# SageMath parsed files
|
140 |
+
*.sage.py
|
141 |
+
|
142 |
+
# Environments
|
143 |
+
.env
|
144 |
+
.venv
|
145 |
+
env/
|
146 |
+
venv/
|
147 |
+
ENV/
|
148 |
+
env.bak/
|
149 |
+
venv.bak/
|
150 |
+
|
151 |
+
# Spyder project settings
|
152 |
+
.spyderproject
|
153 |
+
.spyproject
|
154 |
+
|
155 |
+
# Rope project settings
|
156 |
+
.ropeproject
|
157 |
+
|
158 |
+
# mkdocs documentation
|
159 |
+
/site
|
160 |
+
|
161 |
+
# mypy
|
162 |
+
.mypy_cache/
|
163 |
+
.dmypy.json
|
164 |
+
dmypy.json
|
165 |
+
|
166 |
+
# Pyre type checker
|
167 |
+
.pyre/
|
168 |
+
|
169 |
+
# pytype static type analyzer
|
170 |
+
.pytype/
|
171 |
+
|
172 |
+
# Cython debug symbols
|
173 |
+
cython_debug/
|
174 |
+
|
175 |
+
# PyCharm
|
176 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
177 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
178 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
179 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
180 |
+
#.idea/
|
181 |
+
|
182 |
+
# Ruff stuff:
|
183 |
+
.ruff_cache/
|
184 |
+
|
185 |
+
# PyPI configuration file
|
186 |
+
.pypirc
|
187 |
+
|
188 |
+
# Byte-compiled / optimized / DLL files
|
189 |
+
__pycache__/
|
190 |
+
*.py[cod]
|
191 |
+
*$py.class
|
192 |
+
|
193 |
+
# C extensions
|
194 |
+
*.so
|
195 |
+
|
196 |
+
# Distribution / packaging
|
197 |
+
.Python
|
198 |
+
build/
|
199 |
+
develop-eggs/
|
200 |
+
dist/
|
201 |
+
downloads/
|
202 |
+
eggs/
|
203 |
+
.eggs/
|
204 |
+
lib/
|
205 |
+
lib64/
|
206 |
+
parts/
|
207 |
+
sdist/
|
208 |
+
var/
|
209 |
+
wheels/
|
210 |
+
share/python-wheels/
|
211 |
+
*.egg-info/
|
212 |
+
.installed.cfg
|
213 |
+
*.egg
|
214 |
+
MANIFEST
|
215 |
+
|
216 |
+
# PyInstaller
|
217 |
+
# Usually these files are written by a python script from a template
|
218 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
219 |
+
*.manifest
|
220 |
+
*.spec
|
221 |
+
|
222 |
+
# Installer logs
|
223 |
+
pip-log.txt
|
224 |
+
pip-delete-this-directory.txt
|
225 |
+
|
226 |
+
# Unit test / coverage reports
|
227 |
+
htmlcov/
|
228 |
+
.tox/
|
229 |
+
.nox/
|
230 |
+
.coverage
|
231 |
+
.coverage.*
|
232 |
+
.cache
|
233 |
+
nosetests.xml
|
234 |
+
coverage.xml
|
235 |
+
*.cover
|
236 |
+
*.py,cover
|
237 |
+
.hypothesis/
|
238 |
+
.pytest_cache/
|
239 |
+
cover/
|
240 |
+
|
241 |
+
# Translations
|
242 |
+
*.mo
|
243 |
+
*.pot
|
244 |
+
|
245 |
+
# Django stuff:
|
246 |
+
*.log
|
247 |
+
local_settings.py
|
248 |
+
db.sqlite3
|
249 |
+
db.sqlite3-journal
|
250 |
+
|
251 |
+
# Flask stuff:
|
252 |
+
instance/
|
253 |
+
.webassets-cache
|
254 |
+
|
255 |
+
# Scrapy stuff:
|
256 |
+
.scrapy
|
257 |
+
|
258 |
+
# Sphinx documentation
|
259 |
+
docs/_build/
|
260 |
+
|
261 |
+
# PyBuilder
|
262 |
+
.pybuilder/
|
263 |
+
target/
|
264 |
+
|
265 |
+
# Jupyter Notebook
|
266 |
+
.ipynb_checkpoints
|
267 |
+
|
268 |
+
# IPython
|
269 |
+
profile_default/
|
270 |
+
ipython_config.py
|
271 |
+
|
272 |
+
# pyenv
|
273 |
+
# For a library or package, you might want to ignore these files since the code is
|
274 |
+
# intended to run in multiple environments; otherwise, check them in:
|
275 |
+
# .python-version
|
276 |
+
|
277 |
+
# pipenv
|
278 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
279 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
280 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
281 |
+
# install all needed dependencies.
|
282 |
+
#Pipfile.lock
|
283 |
+
|
284 |
+
# UV
|
285 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
286 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
287 |
+
# commonly ignored for libraries.
|
288 |
+
#uv.lock
|
289 |
+
|
290 |
+
# poetry
|
291 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
292 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
293 |
+
# commonly ignored for libraries.
|
294 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
295 |
+
#poetry.lock
|
296 |
+
|
297 |
+
# pdm
|
298 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
299 |
+
#pdm.lock
|
300 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
301 |
+
# in version control.
|
302 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
303 |
+
.pdm.toml
|
304 |
+
.pdm-python
|
305 |
+
.pdm-build/
|
306 |
+
|
307 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
308 |
+
__pypackages__/
|
309 |
+
|
310 |
+
# Celery stuff
|
311 |
+
celerybeat-schedule
|
312 |
+
celerybeat.pid
|
313 |
+
|
314 |
+
# SageMath parsed files
|
315 |
+
*.sage.py
|
316 |
+
|
317 |
+
# Environments
|
318 |
+
.env
|
319 |
+
.venv
|
320 |
+
env/
|
321 |
+
venv/
|
322 |
+
ENV/
|
323 |
+
env.bak/
|
324 |
+
venv.bak/
|
325 |
+
|
326 |
+
# Spyder project settings
|
327 |
+
.spyderproject
|
328 |
+
.spyproject
|
329 |
+
|
330 |
+
# Rope project settings
|
331 |
+
.ropeproject
|
332 |
+
|
333 |
+
# mkdocs documentation
|
334 |
+
/site
|
335 |
+
|
336 |
+
# mypy
|
337 |
+
.mypy_cache/
|
338 |
+
.dmypy.json
|
339 |
+
dmypy.json
|
340 |
+
|
341 |
+
# Pyre type checker
|
342 |
+
.pyre/
|
343 |
+
|
344 |
+
# pytype static type analyzer
|
345 |
+
.pytype/
|
346 |
+
|
347 |
+
# Cython debug symbols
|
348 |
+
cython_debug/
|
349 |
+
|
350 |
+
# PyCharm
|
351 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
352 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
353 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
354 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
355 |
+
#.idea/
|
356 |
+
|
357 |
+
# Ruff stuff:
|
358 |
+
.ruff_cache/
|
359 |
+
|
360 |
+
# PyPI configuration file
|
361 |
+
.pypirc
|
362 |
+
data/DISAPERE_test.py
|
glimpse-ui/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Sina Salmannia
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
glimpse-ui/alternative_polarity/deberta/deberta_v3_base_polarity.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
+
from pathlib import Path
|
5 |
+
import nltk
|
6 |
+
from tqdm import tqdm
|
7 |
+
import sys, os.path
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
nltk.download('punkt')
|
11 |
+
|
12 |
+
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
13 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
14 |
+
|
15 |
+
from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
|
16 |
+
|
17 |
+
# === CONFIGURATION ===
|
18 |
+
|
19 |
+
MODEL_DIR = BASE_DIR / "alternative_polarity" / "deberta" / "deberta_v3_base_polarity_final_model"
|
20 |
+
DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
|
21 |
+
OUTPUT_DIR = BASE_DIR / "data" / "polarity_scored"
|
22 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
23 |
+
|
24 |
+
# === Load model and tokenizer ===
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
26 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
27 |
+
model.eval()
|
28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
+
model.to(device)
|
30 |
+
|
31 |
+
# === Tokenize like GLIMPSE ===
|
32 |
+
# def tokenize_sentences(text: str) -> list:
|
33 |
+
# # same tokenization as in the original glimpse code
|
34 |
+
# text = text.replace('-----', '\n')
|
35 |
+
# sentences = nltk.sent_tokenize(text)
|
36 |
+
# sentences = [sentence for sentence in sentences if sentence != ""]
|
37 |
+
# return sentences
|
38 |
+
|
39 |
+
|
40 |
+
# def predict_polarity(sentences):
|
41 |
+
# inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
|
42 |
+
# with torch.no_grad():
|
43 |
+
# outputs = model(**inputs)
|
44 |
+
# logits = outputs.logits
|
45 |
+
# temperature = 2.7 # Adjust temperature for scaling logits
|
46 |
+
# probs = F.softmax(logits / temperature, dim=-1)
|
47 |
+
# # Get probability of positive class
|
48 |
+
# polarity_scores = probs[:, 1]
|
49 |
+
# # Rescale: 0 → -1 (very negative), 1 → +1 (very positive)
|
50 |
+
# polarity_scores = (polarity_scores * 2) - 1
|
51 |
+
# return polarity_scores.cpu().tolist()
|
52 |
+
|
53 |
+
def predict_polarity(sentences):
|
54 |
+
inputs = tokenizer(
|
55 |
+
sentences,
|
56 |
+
return_tensors="pt",
|
57 |
+
padding=True,
|
58 |
+
truncation=True,
|
59 |
+
max_length=512
|
60 |
+
).to(device)
|
61 |
+
|
62 |
+
with torch.no_grad():
|
63 |
+
logits = model(**inputs).logits # (batch, 2)
|
64 |
+
logit_diff = logits[:,1] - logits[:,0]
|
65 |
+
alpha = 2.1 # tweak
|
66 |
+
scores = torch.tanh(alpha * logit_diff) # in [-1,1]
|
67 |
+
return scores.cpu().tolist()
|
68 |
+
|
69 |
+
|
70 |
+
def find_polarity(start_year=2017, end_year=2021):
|
71 |
+
for year in range(start_year, end_year + 1):
|
72 |
+
print(f"Processing {year}...")
|
73 |
+
input_path = DATA_DIR / f"all_reviews_{year}.csv"
|
74 |
+
output_path = OUTPUT_DIR / f"polarity_scored_reviews_{year}.csv"
|
75 |
+
|
76 |
+
df = pd.read_csv(input_path)
|
77 |
+
|
78 |
+
all_rows = []
|
79 |
+
for _, row in tqdm(df.iterrows(), total=len(df)):
|
80 |
+
review_id = row["id"]
|
81 |
+
text = row["text"]
|
82 |
+
sentences = glimpse_tokenizer(text)
|
83 |
+
if not sentences:
|
84 |
+
continue
|
85 |
+
labels = predict_polarity(sentences)
|
86 |
+
for sentence, polarity in zip(sentences, labels):
|
87 |
+
all_rows.append({"id": review_id, "sentence": sentence, "polarity": polarity})
|
88 |
+
|
89 |
+
output_df = pd.DataFrame(all_rows)
|
90 |
+
output_df.to_csv(output_path, index=False)
|
91 |
+
print(f"Saved polarity-scored data to {output_path}")
|
92 |
+
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
find_polarity()
|
glimpse-ui/alternative_polarity/deberta/deberta_v3_base_polarity_train.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from datasets import Dataset
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
|
4 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from transformers import Trainer
|
10 |
+
|
11 |
+
|
12 |
+
# Load data
|
13 |
+
train_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_train.csv")
|
14 |
+
dev_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_dev.csv")
|
15 |
+
test_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_test.csv")
|
16 |
+
|
17 |
+
# Convert to HuggingFace Datasets
|
18 |
+
train_ds = Dataset.from_pandas(train_df)
|
19 |
+
dev_ds = Dataset.from_pandas(dev_df)
|
20 |
+
test_ds = Dataset.from_pandas(test_df)
|
21 |
+
|
22 |
+
# Tokenize
|
23 |
+
model_name = "microsoft/deberta-v3-base"
|
24 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
25 |
+
def tokenize(batch):
|
26 |
+
return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=512)
|
27 |
+
|
28 |
+
train_ds = train_ds.map(tokenize, batched=True)
|
29 |
+
dev_ds = dev_ds.map(tokenize, batched=True)
|
30 |
+
test_ds = test_ds.map(tokenize, batched=True)
|
31 |
+
|
32 |
+
# Set format for PyTorch
|
33 |
+
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
34 |
+
dev_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
35 |
+
test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
36 |
+
|
37 |
+
# Load model
|
38 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
|
39 |
+
|
40 |
+
# Compute class weights
|
41 |
+
label_counts = train_df['label'].value_counts()
|
42 |
+
total_samples = len(train_df)
|
43 |
+
class_weights = torch.tensor([total_samples / (len(label_counts) * count) for count in label_counts.sort_index().values])
|
44 |
+
class_weights = class_weights.to(dtype=torch.float32)
|
45 |
+
print("Class weights:", class_weights)
|
46 |
+
|
47 |
+
class WeightedTrainer(Trainer):
|
48 |
+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
49 |
+
labels = inputs.pop("labels")
|
50 |
+
outputs = model(**inputs)
|
51 |
+
logits = outputs.logits
|
52 |
+
weights = class_weights.to(logits.device)
|
53 |
+
loss = F.cross_entropy(logits, labels, weight=weights)
|
54 |
+
return (loss, outputs) if return_outputs else loss
|
55 |
+
|
56 |
+
|
57 |
+
# Metrics
|
58 |
+
def compute_metrics(eval_pred):
|
59 |
+
logits, labels = eval_pred
|
60 |
+
preds = np.argmax(logits, axis=1)
|
61 |
+
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
|
62 |
+
acc = accuracy_score(labels, preds)
|
63 |
+
return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
|
64 |
+
|
65 |
+
# Training arguments
|
66 |
+
args = TrainingArguments(
|
67 |
+
output_dir="./alternative_polarity/deberta/checkpoints",
|
68 |
+
eval_strategy="epoch",
|
69 |
+
save_strategy="epoch",
|
70 |
+
learning_rate=2e-5,
|
71 |
+
per_device_train_batch_size=4,
|
72 |
+
per_device_eval_batch_size=8,
|
73 |
+
num_train_epochs=4,
|
74 |
+
weight_decay=0.01,
|
75 |
+
load_best_model_at_end=True,
|
76 |
+
metric_for_best_model="f1"
|
77 |
+
)
|
78 |
+
|
79 |
+
# Trainer
|
80 |
+
trainer = WeightedTrainer(
|
81 |
+
model=model,
|
82 |
+
args=args,
|
83 |
+
train_dataset=train_ds,
|
84 |
+
eval_dataset=dev_ds,
|
85 |
+
tokenizer=tokenizer,
|
86 |
+
compute_metrics=compute_metrics
|
87 |
+
)
|
88 |
+
|
89 |
+
# Train
|
90 |
+
trainer.train()
|
91 |
+
|
92 |
+
# Evaluate on test
|
93 |
+
results = trainer.evaluate(test_ds)
|
94 |
+
print("Test results:", results)
|
95 |
+
|
96 |
+
# Save the model and tokenizer
|
97 |
+
model.save_pretrained("./alternative_polarity/deberta/deberta_v3_base_polarity_final_model")
|
98 |
+
tokenizer.save_pretrained("./alternative_polarity/deberta/deberta_v3_base_polarity_final_model")
|
glimpse-ui/alternative_polarity/manual_polarity_tester.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
+
from pathlib import Path
|
5 |
+
import sys, os
|
6 |
+
|
7 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
8 |
+
from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
|
9 |
+
|
10 |
+
# === CONFIGURATION ===
|
11 |
+
BASE_DIR = Path(__file__).resolve().parent.parent
|
12 |
+
MODEL_DIR = BASE_DIR / "alternative_polarity" / "deberta" / "deberta_v3_large_polarity_final_model"
|
13 |
+
# MODEL_DIR = BASE_DIR / "alternative_polarity" / "llama" / "final_model"
|
14 |
+
# MODEL_DIR = BASE_DIR / "alternative_polarity" / "scideberta" / "scideberta_full_polarity_final_model"
|
15 |
+
|
16 |
+
# --> Best so far: deberta_v3 (passes "pros" test)
|
17 |
+
|
18 |
+
|
19 |
+
# === Load model and tokenizer ===
|
20 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
21 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
22 |
+
model.eval()
|
23 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
+
model.to(device)
|
25 |
+
|
26 |
+
# === Prediction function with confidence ===
|
27 |
+
def predict_polarity(sentences):
|
28 |
+
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
29 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
30 |
+
|
31 |
+
with torch.no_grad():
|
32 |
+
outputs = model(**inputs)
|
33 |
+
probs = F.softmax(outputs.logits, dim=1)
|
34 |
+
confidences, preds = torch.max(probs, dim=1)
|
35 |
+
|
36 |
+
results = []
|
37 |
+
for sentence, pred, conf, prob in zip(sentences, preds, confidences, probs):
|
38 |
+
results.append({
|
39 |
+
"sentence": sentence,
|
40 |
+
"label": "Positive" if pred.item() == 1 else "Negative",
|
41 |
+
"confidence": conf.item(),
|
42 |
+
"probs": prob.cpu().numpy().tolist()
|
43 |
+
})
|
44 |
+
return results
|
45 |
+
|
46 |
+
# === Example: test a multi-sentence peer review ===
|
47 |
+
if __name__ == "__main__":
|
48 |
+
# Replace this with your review
|
49 |
+
full_review = """
|
50 |
+
Pros:
|
51 |
+
Con: The experiments lack comparison with prior work.
|
52 |
+
The authors clearly explain their methodology, which is a strong point.
|
53 |
+
"""
|
54 |
+
|
55 |
+
# Use glimpse tokenizer to split into sentences
|
56 |
+
sentences = glimpse_tokenizer(full_review)
|
57 |
+
|
58 |
+
# Run polarity prediction
|
59 |
+
results = predict_polarity(sentences)
|
60 |
+
|
61 |
+
# Display results
|
62 |
+
for res in results:
|
63 |
+
print(f"\nSentence: {res['sentence']}")
|
64 |
+
print(f" → Prediction: {res['label']} (Confidence: {res['confidence']:.3f})")
|
65 |
+
print(f" Probabilities: [Negative: {res['probs'][0]:.3f}, Positive: {res['probs'][1]:.3f}]")
|
glimpse-ui/alternative_polarity/scideberta/scideberta_full_polarity.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
+
from pathlib import Path
|
5 |
+
import nltk
|
6 |
+
from tqdm import tqdm
|
7 |
+
import sys, os.path
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
nltk.download('punkt')
|
11 |
+
|
12 |
+
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
13 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
14 |
+
|
15 |
+
from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
|
16 |
+
|
17 |
+
# === CONFIGURATION ===
|
18 |
+
|
19 |
+
MODEL_DIR = BASE_DIR / "alternative_polarity" / "scideberta" / "scideberta_full_polarity_final_model"
|
20 |
+
DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
|
21 |
+
OUTPUT_DIR = BASE_DIR / "data" / "polarity_scored"
|
22 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
23 |
+
|
24 |
+
# === Load model and tokenizer ===
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
26 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
27 |
+
model.eval()
|
28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
+
model.to(device)
|
30 |
+
|
31 |
+
# === Tokenize like GLIMPSE ===
|
32 |
+
# def tokenize_sentences(text: str) -> list:
|
33 |
+
# # same tokenization as in the original glimpse code
|
34 |
+
# text = text.replace('-----', '\n')
|
35 |
+
# sentences = nltk.sent_tokenize(text)
|
36 |
+
# sentences = [sentence for sentence in sentences if sentence != ""]
|
37 |
+
# return sentences
|
38 |
+
|
39 |
+
|
40 |
+
def predict_polarity(sentences):
|
41 |
+
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
|
42 |
+
with torch.no_grad():
|
43 |
+
outputs = model(**inputs)
|
44 |
+
logits = outputs.logits
|
45 |
+
temperature = 2.7
|
46 |
+
probs = F.softmax(logits / temperature, dim=-1)
|
47 |
+
# Get probability of positive class
|
48 |
+
polarity_scores = probs[:, 1]
|
49 |
+
# Rescale: 0 → -1 (very negative), 1 → +1 (very positive)
|
50 |
+
polarity_scores = (polarity_scores * 2) - 1
|
51 |
+
return polarity_scores.cpu().tolist()
|
52 |
+
|
53 |
+
|
54 |
+
def find_polarity(start_year=2017, end_year=2021):
|
55 |
+
for year in range(start_year, end_year + 1):
|
56 |
+
print(f"Processing {year}...")
|
57 |
+
input_path = DATA_DIR / f"all_reviews_{year}.csv"
|
58 |
+
output_path = OUTPUT_DIR / f"polarity_scored_reviews_{year}.csv"
|
59 |
+
|
60 |
+
df = pd.read_csv(input_path)
|
61 |
+
|
62 |
+
all_rows = []
|
63 |
+
for _, row in tqdm(df.iterrows(), total=len(df)):
|
64 |
+
review_id = row["id"]
|
65 |
+
text = row["text"]
|
66 |
+
sentences = glimpse_tokenizer(text)
|
67 |
+
if not sentences:
|
68 |
+
continue
|
69 |
+
labels = predict_polarity(sentences)
|
70 |
+
for sentence, polarity in zip(sentences, labels):
|
71 |
+
all_rows.append({"id": review_id, "sentence": sentence, "polarity": polarity})
|
72 |
+
|
73 |
+
output_df = pd.DataFrame(all_rows)
|
74 |
+
output_df.to_csv(output_path, index=False)
|
75 |
+
print(f"Saved polarity-scored data to {output_path}")
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
find_polarity()
|
glimpse-ui/alternative_polarity/scideberta/scideberta_full_polarity_train.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from datasets import Dataset
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
|
4 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from transformers import Trainer
|
10 |
+
|
11 |
+
class WeightedTrainer(Trainer):
|
12 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
13 |
+
labels = inputs.pop("labels")
|
14 |
+
outputs = model(**inputs)
|
15 |
+
logits = outputs.logits
|
16 |
+
weights = class_weights.to(logits.device)
|
17 |
+
loss = F.cross_entropy(logits, labels, weight=weights)
|
18 |
+
return (loss, outputs) if return_outputs else loss
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
# Load data
|
23 |
+
train_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_train.csv")
|
24 |
+
dev_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_dev.csv")
|
25 |
+
test_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_polarity_test.csv")
|
26 |
+
|
27 |
+
# Convert to HuggingFace Datasets
|
28 |
+
train_ds = Dataset.from_pandas(train_df)
|
29 |
+
dev_ds = Dataset.from_pandas(dev_df)
|
30 |
+
test_ds = Dataset.from_pandas(test_df)
|
31 |
+
|
32 |
+
model_name = "KISTI-AI/Scideberta-full"
|
33 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
34 |
+
|
35 |
+
def tokenize(batch):
|
36 |
+
return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=512)
|
37 |
+
|
38 |
+
train_ds = train_ds.map(tokenize, batched=True)
|
39 |
+
dev_ds = dev_ds.map(tokenize, batched=True)
|
40 |
+
test_ds = test_ds.map(tokenize, batched=True)
|
41 |
+
|
42 |
+
# Set format for PyTorch
|
43 |
+
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
44 |
+
dev_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
45 |
+
test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
46 |
+
|
47 |
+
# Load model
|
48 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
|
49 |
+
|
50 |
+
# Compute class weights
|
51 |
+
label_counts = train_df['label'].value_counts()
|
52 |
+
total_samples = len(train_df)
|
53 |
+
class_weights = torch.tensor([total_samples / (len(label_counts) * count) for count in label_counts.sort_index().values])
|
54 |
+
class_weights = class_weights.to(dtype=torch.float32)
|
55 |
+
print("Class weights:", class_weights)
|
56 |
+
|
57 |
+
class WeightedTrainer(Trainer):
|
58 |
+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
59 |
+
labels = inputs.pop("labels")
|
60 |
+
outputs = model(**inputs)
|
61 |
+
logits = outputs.logits
|
62 |
+
weights = class_weights.to(logits.device)
|
63 |
+
loss = F.cross_entropy(logits, labels, weight=weights)
|
64 |
+
return (loss, outputs) if return_outputs else loss
|
65 |
+
|
66 |
+
|
67 |
+
# Metrics
|
68 |
+
def compute_metrics(eval_pred):
|
69 |
+
logits, labels = eval_pred
|
70 |
+
preds = np.argmax(logits, axis=1)
|
71 |
+
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
|
72 |
+
acc = accuracy_score(labels, preds)
|
73 |
+
return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
|
74 |
+
|
75 |
+
# Training arguments
|
76 |
+
args = TrainingArguments(
|
77 |
+
output_dir="./alternative_polarity/scideberta/checkpoints",
|
78 |
+
eval_strategy="epoch",
|
79 |
+
save_strategy="epoch",
|
80 |
+
learning_rate=2e-5,
|
81 |
+
per_device_train_batch_size=4,
|
82 |
+
per_device_eval_batch_size=8,
|
83 |
+
num_train_epochs=4,
|
84 |
+
weight_decay=0.01,
|
85 |
+
load_best_model_at_end=True,
|
86 |
+
metric_for_best_model="f1"
|
87 |
+
)
|
88 |
+
|
89 |
+
# Trainer
|
90 |
+
trainer = WeightedTrainer(
|
91 |
+
model=model,
|
92 |
+
args=args,
|
93 |
+
train_dataset=train_ds,
|
94 |
+
eval_dataset=dev_ds,
|
95 |
+
tokenizer=tokenizer,
|
96 |
+
compute_metrics=compute_metrics
|
97 |
+
)
|
98 |
+
|
99 |
+
# Train
|
100 |
+
trainer.train()
|
101 |
+
|
102 |
+
# Evaluate on test
|
103 |
+
results = trainer.evaluate(test_ds)
|
104 |
+
print("Test results:", results)
|
105 |
+
|
106 |
+
# Save the model and tokenizer
|
107 |
+
model.save_pretrained("./alternative_polarity/scideberta/scideberta_full_polarity_final_model")
|
108 |
+
tokenizer.save_pretrained("./alternative_polarity/scideberta/scideberta_full_polarity_final_model")
|
glimpse-ui/alternative_topic/debetra/deberta_topic.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
+
from pathlib import Path
|
5 |
+
import nltk
|
6 |
+
from tqdm import tqdm
|
7 |
+
import sys, os.path
|
8 |
+
|
9 |
+
nltk.download('punkt')
|
10 |
+
|
11 |
+
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
12 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
13 |
+
|
14 |
+
from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
|
15 |
+
|
16 |
+
# === CONFIGURATION ===
|
17 |
+
|
18 |
+
MODEL_DIR = BASE_DIR / "alternative_topic" / "deberta" / "final_model"
|
19 |
+
DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
|
20 |
+
OUTPUT_DIR = BASE_DIR / "data" / "topic_scored"
|
21 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
22 |
+
|
23 |
+
# === Load model and tokenizer ===
|
24 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
25 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
26 |
+
model.eval()
|
27 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
+
model.to(device)
|
29 |
+
|
30 |
+
# === Tokenize like GLIMPSE ===
|
31 |
+
# def tokenize_sentences(text: str) -> list:
|
32 |
+
# # same tokenization as in the original glimpse code
|
33 |
+
# text = text.replace('-----', '\n')
|
34 |
+
# sentences = nltk.sent_tokenize(text)
|
35 |
+
# sentences = [sentence for sentence in sentences if sentence != ""]
|
36 |
+
# return sentences
|
37 |
+
|
38 |
+
|
39 |
+
# === Label map (optional: for human-readable output) ===
|
40 |
+
id2label = {
|
41 |
+
# 0: "Evaluative",
|
42 |
+
# 1: "Structuring",
|
43 |
+
# 2: "Request",
|
44 |
+
# 3: "Fact",
|
45 |
+
# 4: "Social",
|
46 |
+
# 5: "Other",
|
47 |
+
0: "Substance",
|
48 |
+
1: "Clarity",
|
49 |
+
2: "Soundness/Correctness",
|
50 |
+
3: "Originality",
|
51 |
+
4: "Motivation/Impact",
|
52 |
+
5: "Meaningful Comparison",
|
53 |
+
6: "Replicability",
|
54 |
+
7: "NONE" # This is used for sentences that do not match any specific topic
|
55 |
+
}
|
56 |
+
|
57 |
+
def predict_topic(sentences):
|
58 |
+
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
|
59 |
+
with torch.no_grad():
|
60 |
+
outputs = model(**inputs)
|
61 |
+
predictions = torch.argmax(outputs.logits, dim=1).cpu().tolist()
|
62 |
+
# Convert predictions to human-readable labels
|
63 |
+
predictions = [id2label[pred] for pred in predictions]
|
64 |
+
return predictions
|
65 |
+
|
66 |
+
|
67 |
+
def find_topic(start_year=2017, end_year=2021):
|
68 |
+
for year in range(start_year, end_year + 1):
|
69 |
+
print(f"Processing {year}...")
|
70 |
+
input_path = DATA_DIR / f"all_reviews_{year}.csv"
|
71 |
+
output_path = OUTPUT_DIR / f"topic_scored_reviews_{year}.csv"
|
72 |
+
|
73 |
+
df = pd.read_csv(input_path)
|
74 |
+
|
75 |
+
all_rows = []
|
76 |
+
for _, row in tqdm(df.iterrows(), total=len(df)):
|
77 |
+
review_id = row["id"]
|
78 |
+
text = row["text"]
|
79 |
+
sentences = glimpse_tokenizer(text)
|
80 |
+
if not sentences:
|
81 |
+
continue
|
82 |
+
labels = predict_topic(sentences)
|
83 |
+
for sentence, topic in zip(sentences, labels):
|
84 |
+
all_rows.append({"id": review_id, "sentence": sentence, "topic": topic})
|
85 |
+
|
86 |
+
output_df = pd.DataFrame(all_rows)
|
87 |
+
output_df.to_csv(output_path, index=False)
|
88 |
+
print(f"Saved topic-scored data to {output_path}")
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
find_topic()
|
glimpse-ui/alternative_topic/debetra/deberta_topic_train.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from datasets import Dataset
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
|
4 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
# Load data
|
8 |
+
dev_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_dev.csv")
|
9 |
+
train_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_train.csv")
|
10 |
+
test_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_test.csv")
|
11 |
+
|
12 |
+
# Convert to HuggingFace Datasets
|
13 |
+
train_ds = Dataset.from_pandas(train_df)
|
14 |
+
dev_ds = Dataset.from_pandas(dev_df)
|
15 |
+
test_ds = Dataset.from_pandas(test_df)
|
16 |
+
|
17 |
+
# Tokenize
|
18 |
+
model_name = "microsoft/deberta-v3-base"
|
19 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
20 |
+
|
21 |
+
def tokenize(batch):
|
22 |
+
return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=256)
|
23 |
+
|
24 |
+
train_ds = train_ds.map(tokenize, batched=True)
|
25 |
+
dev_ds = dev_ds.map(tokenize, batched=True)
|
26 |
+
test_ds = test_ds.map(tokenize, batched=True)
|
27 |
+
|
28 |
+
# Set format for PyTorch
|
29 |
+
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
30 |
+
dev_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
31 |
+
test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
32 |
+
|
33 |
+
print(train_df['label'].value_counts().sort_index())
|
34 |
+
|
35 |
+
|
36 |
+
# Load model
|
37 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
|
38 |
+
|
39 |
+
# Metrics
|
40 |
+
def compute_metrics(eval_pred):
|
41 |
+
logits, labels = eval_pred
|
42 |
+
preds = np.argmax(logits, axis=1)
|
43 |
+
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
|
44 |
+
acc = accuracy_score(labels, preds)
|
45 |
+
return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
|
46 |
+
|
47 |
+
# Training arguments
|
48 |
+
args = TrainingArguments(
|
49 |
+
output_dir="./alternative_topic/deberta/checkpoints",
|
50 |
+
eval_strategy="epoch",
|
51 |
+
save_strategy="epoch",
|
52 |
+
learning_rate=2e-5,
|
53 |
+
per_device_train_batch_size=8,
|
54 |
+
per_device_eval_batch_size=16,
|
55 |
+
num_train_epochs=4,
|
56 |
+
weight_decay=0.01,
|
57 |
+
load_best_model_at_end=True,
|
58 |
+
metric_for_best_model="f1"
|
59 |
+
)
|
60 |
+
|
61 |
+
# Trainer
|
62 |
+
trainer = Trainer(
|
63 |
+
model=model,
|
64 |
+
args=args,
|
65 |
+
train_dataset=train_ds,
|
66 |
+
eval_dataset=dev_ds,
|
67 |
+
tokenizer=tokenizer,
|
68 |
+
compute_metrics=compute_metrics
|
69 |
+
)
|
70 |
+
|
71 |
+
# Train
|
72 |
+
trainer.train()
|
73 |
+
|
74 |
+
# Evaluate on test
|
75 |
+
results = trainer.evaluate(test_ds)
|
76 |
+
print("Test results:", results)
|
77 |
+
|
78 |
+
# Save the model and tokenizer
|
79 |
+
model.save_pretrained("./alternative_topic/deberta/final_model")
|
80 |
+
tokenizer.save_pretrained("./alternative_topic/deberta/final_model")
|
glimpse-ui/alternative_topic/scideberta/scideberta_topic.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
+
from pathlib import Path
|
5 |
+
import nltk
|
6 |
+
from tqdm import tqdm
|
7 |
+
import sys, os.path
|
8 |
+
|
9 |
+
nltk.download('punkt')
|
10 |
+
|
11 |
+
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
12 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
13 |
+
|
14 |
+
from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
|
15 |
+
|
16 |
+
# === CONFIGURATION ===
|
17 |
+
|
18 |
+
MODEL_DIR = BASE_DIR / "alternative_topic" / "scideberta" / "final_model"
|
19 |
+
DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
|
20 |
+
OUTPUT_DIR = BASE_DIR / "data" / "topic_scored"
|
21 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
22 |
+
|
23 |
+
# === Load model and tokenizer ===
|
24 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
25 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
26 |
+
model.eval()
|
27 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
+
model.to(device)
|
29 |
+
|
30 |
+
# === Tokenize like GLIMPSE ===
|
31 |
+
# def tokenize_sentences(text: str) -> list:
|
32 |
+
# # same tokenization as in the original glimpse code
|
33 |
+
# text = text.replace('-----', '\n')
|
34 |
+
# sentences = nltk.sent_tokenize(text)
|
35 |
+
# sentences = [sentence for sentence in sentences if sentence != ""]
|
36 |
+
# return sentences
|
37 |
+
|
38 |
+
|
39 |
+
# === Label map (optional: for human-readable output) ===
|
40 |
+
id2label = {
|
41 |
+
# 0: "Evaluative",
|
42 |
+
# 1: "Structuring",
|
43 |
+
# 2: "Request",
|
44 |
+
# 3: "Fact",
|
45 |
+
# 4: "Social",
|
46 |
+
# 5: "Other",
|
47 |
+
0: "Substance",
|
48 |
+
1: "Clarity",
|
49 |
+
2: "Soundness/Correctness",
|
50 |
+
3: "Originality",
|
51 |
+
4: "Motivation/Impact",
|
52 |
+
5: "Meaningful Comparison",
|
53 |
+
6: "Replicability",
|
54 |
+
7: "NONE" # This is used for sentences that do not match any specific topic
|
55 |
+
}
|
56 |
+
|
57 |
+
def predict_topic(sentences):
|
58 |
+
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
|
59 |
+
with torch.no_grad():
|
60 |
+
outputs = model(**inputs)
|
61 |
+
predictions = torch.argmax(outputs.logits, dim=1).cpu().tolist()
|
62 |
+
# Convert predictions to human-readable labels
|
63 |
+
predictions = [id2label[pred] for pred in predictions]
|
64 |
+
return predictions
|
65 |
+
|
66 |
+
|
67 |
+
def find_topic(start_year=2017, end_year=2021):
|
68 |
+
for year in range(start_year, end_year + 1):
|
69 |
+
print(f"Processing {year}...")
|
70 |
+
input_path = DATA_DIR / f"all_reviews_{year}.csv"
|
71 |
+
output_path = OUTPUT_DIR / f"topic_scored_reviews_{year}.csv"
|
72 |
+
|
73 |
+
df = pd.read_csv(input_path)
|
74 |
+
|
75 |
+
all_rows = []
|
76 |
+
for _, row in tqdm(df.iterrows(), total=len(df)):
|
77 |
+
review_id = row["id"]
|
78 |
+
text = row["text"]
|
79 |
+
sentences = glimpse_tokenizer(text)
|
80 |
+
if not sentences:
|
81 |
+
continue
|
82 |
+
labels = predict_topic(sentences)
|
83 |
+
for sentence, topic in zip(sentences, labels):
|
84 |
+
all_rows.append({"id": review_id, "sentence": sentence, "topic": topic})
|
85 |
+
|
86 |
+
output_df = pd.DataFrame(all_rows)
|
87 |
+
output_df.to_csv(output_path, index=False)
|
88 |
+
print(f"Saved topic-scored data to {output_path}")
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
find_topic()
|
glimpse-ui/alternative_topic/scideberta/scideberta_topic_train.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from datasets import Dataset
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
|
4 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
# Load data
|
8 |
+
dev_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_dev.csv")
|
9 |
+
train_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_train.csv")
|
10 |
+
test_df = pd.read_csv("./data/DISAPERE-main/SELFExtractedData/disapere_topic_test.csv")
|
11 |
+
|
12 |
+
# Convert to HuggingFace Datasets
|
13 |
+
train_ds = Dataset.from_pandas(train_df)
|
14 |
+
dev_ds = Dataset.from_pandas(dev_df)
|
15 |
+
test_ds = Dataset.from_pandas(test_df)
|
16 |
+
|
17 |
+
# Tokenize
|
18 |
+
model_name = "KISTI-AI/Scideberta-full"
|
19 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
20 |
+
|
21 |
+
def tokenize(batch):
|
22 |
+
return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=256)
|
23 |
+
|
24 |
+
train_ds = train_ds.map(tokenize, batched=True)
|
25 |
+
dev_ds = dev_ds.map(tokenize, batched=True)
|
26 |
+
test_ds = test_ds.map(tokenize, batched=True)
|
27 |
+
|
28 |
+
# Set format for PyTorch
|
29 |
+
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
30 |
+
dev_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
31 |
+
test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
|
32 |
+
|
33 |
+
print(train_df['label'].value_counts().sort_index())
|
34 |
+
|
35 |
+
|
36 |
+
# Load model
|
37 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
|
38 |
+
|
39 |
+
# Metrics
|
40 |
+
def compute_metrics(eval_pred):
|
41 |
+
logits, labels = eval_pred
|
42 |
+
preds = np.argmax(logits, axis=1)
|
43 |
+
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
|
44 |
+
acc = accuracy_score(labels, preds)
|
45 |
+
return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
|
46 |
+
|
47 |
+
# Training arguments
|
48 |
+
args = TrainingArguments(
|
49 |
+
output_dir="./alternative_topic/scideberta/checkpoints",
|
50 |
+
eval_strategy="epoch",
|
51 |
+
save_strategy="epoch",
|
52 |
+
learning_rate=2e-5,
|
53 |
+
per_device_train_batch_size=8,
|
54 |
+
per_device_eval_batch_size=16,
|
55 |
+
num_train_epochs=4,
|
56 |
+
weight_decay=0.01,
|
57 |
+
load_best_model_at_end=True,
|
58 |
+
metric_for_best_model="f1"
|
59 |
+
)
|
60 |
+
|
61 |
+
# Trainer
|
62 |
+
trainer = Trainer(
|
63 |
+
model=model,
|
64 |
+
args=args,
|
65 |
+
train_dataset=train_ds,
|
66 |
+
eval_dataset=dev_ds,
|
67 |
+
tokenizer=tokenizer,
|
68 |
+
compute_metrics=compute_metrics
|
69 |
+
)
|
70 |
+
|
71 |
+
# Train
|
72 |
+
trainer.train()
|
73 |
+
|
74 |
+
# Evaluate on test
|
75 |
+
results = trainer.evaluate(test_ds)
|
76 |
+
print("Test results:", results)
|
77 |
+
|
78 |
+
# Save the model and tokenizer
|
79 |
+
model.save_pretrained("./alternative_topic/scideberta/final_model")
|
80 |
+
tokenizer.save_pretrained("./alternative_topic/scideberta/final_model")
|
glimpse-ui/data/ExtractDISAPEREData.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import pandas as pd
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
BASE_DIR = Path(__file__).resolve().parent.parent
|
7 |
+
base_path = BASE_DIR / "data" / "DISAPERE-main" / "DISAPERE" / "final_dataset"
|
8 |
+
output_path = BASE_DIR / "data" / "DISAPERE-main" / "SELFExtractedData"
|
9 |
+
|
10 |
+
###################################################################################
|
11 |
+
###################################################################################
|
12 |
+
|
13 |
+
# EXTRACTING POLARITY SENTENCES FROM DISAPERE DATASET
|
14 |
+
|
15 |
+
# def extract_polarity_sentences(json_dir):
|
16 |
+
# data = []
|
17 |
+
# for filename in os.listdir(json_dir):
|
18 |
+
# if filename.endswith(".json"):
|
19 |
+
# with open(os.path.join(json_dir, filename), "r") as f:
|
20 |
+
# thread = json.load(f)
|
21 |
+
# for sentence in thread.get("review_sentences", []):
|
22 |
+
# text = sentence.get("text", "").strip()
|
23 |
+
# polarity = sentence.get("polarity")
|
24 |
+
# if text:
|
25 |
+
# if polarity == "pol_positive":
|
26 |
+
# label = 2
|
27 |
+
# elif polarity == "pol_negative":
|
28 |
+
# label = 0
|
29 |
+
# else:
|
30 |
+
# label = 1
|
31 |
+
# data.append({"text": text, "label": label})
|
32 |
+
# return pd.DataFrame(data)
|
33 |
+
|
34 |
+
# # Extract and save each split
|
35 |
+
# for split in ["train", "dev", "test"]:
|
36 |
+
# df = extract_polarity_sentences(os.path.join(base_path, split))
|
37 |
+
# out_file = os.path.join(output_path, f"disapere_polarity_{split}.csv")
|
38 |
+
# df.to_csv(out_file, index=False)
|
39 |
+
# print(f"{split.capitalize()} saved to {out_file}: {len(df)} samples")
|
40 |
+
|
41 |
+
|
42 |
+
###################################################################################
|
43 |
+
###################################################################################
|
44 |
+
|
45 |
+
# 2. EXTRACTING TOPIC SENTENCES FROM DISAPERE DATASET
|
46 |
+
#
|
47 |
+
# === Topic Label Mapping ===
|
48 |
+
# 1: "Structuring"
|
49 |
+
# 0: "Evaluative"
|
50 |
+
# 2: "Request"
|
51 |
+
# 3: "Fact"
|
52 |
+
# 4: "Social"
|
53 |
+
# 5: "Other"
|
54 |
+
# 6: "Substance"
|
55 |
+
# 7: "Clarity"
|
56 |
+
# 8: "Soundness/Correctness"
|
57 |
+
# 9: "Originality"
|
58 |
+
# 10: "Motivation/Impact"
|
59 |
+
# 11: "Meaningful Comparison"
|
60 |
+
# 12: "Replicability"
|
61 |
+
|
62 |
+
# Final topic classes
|
63 |
+
topic_classes = [
|
64 |
+
"asp_substance",
|
65 |
+
"asp_clarity",
|
66 |
+
"asp_soundness-correctness",
|
67 |
+
"asp_originality",
|
68 |
+
"asp_impact",
|
69 |
+
"asp_comparison",
|
70 |
+
"asp_replicability",
|
71 |
+
"None", # This is used for sentences that do not match any specific topic
|
72 |
+
# "arg-structuring_summary"
|
73 |
+
]
|
74 |
+
|
75 |
+
label_map = {label: idx for idx, label in enumerate(topic_classes)}
|
76 |
+
|
77 |
+
def extract_topic_sentences(json_dir):
|
78 |
+
data = []
|
79 |
+
for filename in os.listdir(json_dir):
|
80 |
+
if filename.endswith(".json"):
|
81 |
+
with open(os.path.join(json_dir, filename), "r") as f:
|
82 |
+
thread = json.load(f)
|
83 |
+
for sentence in thread.get("review_sentences", []):
|
84 |
+
text = sentence.get("text", "").strip()
|
85 |
+
aspect = sentence.get("aspect", "")
|
86 |
+
# fine_action = sentence.get("fine_review_action", "")
|
87 |
+
|
88 |
+
# Decide label source
|
89 |
+
topic = aspect if aspect in label_map else "None"
|
90 |
+
|
91 |
+
if text and topic in label_map:
|
92 |
+
label = label_map[topic]
|
93 |
+
data.append({"text": text, "label": label})
|
94 |
+
return pd.DataFrame(data)
|
95 |
+
|
96 |
+
# Extract and save each split
|
97 |
+
for split in ["train", "dev", "test"]:
|
98 |
+
df = extract_topic_sentences(os.path.join(base_path, split))
|
99 |
+
out_file = os.path.join(output_path, f"disapere_topic_{split}.csv")
|
100 |
+
df.to_csv(out_file, index=False)
|
101 |
+
print(f"{split.capitalize()} saved to {out_file}: {len(df)} samples")
|
102 |
+
|
103 |
+
###################################################################################
|
104 |
+
###################################################################################
|
105 |
+
|
106 |
+
|
glimpse-ui/glimpse/.gitignore
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by .ignore support plugin (hsz.mobi)
|
2 |
+
### Python template
|
3 |
+
|
4 |
+
# GLIMPSE
|
5 |
+
# Ignore all the data except orignial files
|
6 |
+
data/*
|
7 |
+
!data/
|
8 |
+
summaries/
|
9 |
+
output/
|
10 |
+
slurm*
|
11 |
+
!scripts/
|
12 |
+
.gradio/
|
13 |
+
.test/
|
14 |
+
|
15 |
+
# Byte-compiled / optimized / DLL files
|
16 |
+
__pycache__/
|
17 |
+
*.py[cod]
|
18 |
+
*$py.class
|
19 |
+
|
20 |
+
# C extensions
|
21 |
+
*.so
|
22 |
+
|
23 |
+
# Distribution / packaging
|
24 |
+
.Python
|
25 |
+
env/
|
26 |
+
build/
|
27 |
+
develop-eggs/
|
28 |
+
dist/
|
29 |
+
downloads/
|
30 |
+
eggs/
|
31 |
+
.eggs/
|
32 |
+
lib/
|
33 |
+
lib64/
|
34 |
+
parts/
|
35 |
+
sdist/
|
36 |
+
var/
|
37 |
+
*.egg-info/
|
38 |
+
.installed.cfg
|
39 |
+
*.egg
|
40 |
+
|
41 |
+
# PyInstaller
|
42 |
+
# Usually these files are written by a python script from a template
|
43 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
44 |
+
*.manifest
|
45 |
+
*.spec
|
46 |
+
|
47 |
+
# Installer logs
|
48 |
+
pip-log.txt
|
49 |
+
pip-delete-this-directory.txt
|
50 |
+
|
51 |
+
# Unit test / coverage reports
|
52 |
+
htmlcov/
|
53 |
+
.tox/
|
54 |
+
.coverage
|
55 |
+
.coverage.*
|
56 |
+
.cache
|
57 |
+
nosetests.xml
|
58 |
+
coverage.xml
|
59 |
+
*,cover
|
60 |
+
.hypothesis/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
|
70 |
+
# Flask stuff:
|
71 |
+
instance/
|
72 |
+
.webassets-cache
|
73 |
+
|
74 |
+
# Scrapy stuff:
|
75 |
+
.scrapy
|
76 |
+
|
77 |
+
# Sphinx documentation
|
78 |
+
docs/_build/
|
79 |
+
|
80 |
+
# PyBuilder
|
81 |
+
target/
|
82 |
+
|
83 |
+
# IPython Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# pyenv
|
87 |
+
.python-version
|
88 |
+
|
89 |
+
# celery beat schedule file
|
90 |
+
celerybeat-schedule
|
91 |
+
|
92 |
+
# dotenv
|
93 |
+
.env
|
94 |
+
|
95 |
+
# virtualenv
|
96 |
+
venv/
|
97 |
+
ENV/
|
98 |
+
|
99 |
+
# Spyder project settings
|
100 |
+
.spyderproject
|
101 |
+
|
102 |
+
# Rope project settings
|
103 |
+
.ropeproject
|
104 |
+
### VirtualEnv template
|
105 |
+
# Virtualenv
|
106 |
+
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
107 |
+
[Bb]in
|
108 |
+
[Ii]nclude
|
109 |
+
[Ll]ib
|
110 |
+
[Ll]ib64
|
111 |
+
[Ll]ocal
|
112 |
+
[Ss]cripts
|
113 |
+
pyvenv.cfg
|
114 |
+
.venv
|
115 |
+
pip-selfcheck.json
|
116 |
+
|
117 |
+
### JetBrains template
|
118 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
119 |
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
120 |
+
|
121 |
+
# User-specific stuff
|
122 |
+
.idea/**/workspace.xml
|
123 |
+
.idea/**/tasks.xml
|
124 |
+
.idea/**/usage.statistics.xml
|
125 |
+
.idea/**/dictionaries
|
126 |
+
.idea/**/shelf
|
127 |
+
|
128 |
+
# AWS User-specific
|
129 |
+
.idea/**/aws.xml
|
130 |
+
|
131 |
+
# Generated files
|
132 |
+
.idea/**/contentModel.xml
|
133 |
+
|
134 |
+
# Sensitive or high-churn files
|
135 |
+
.idea/**/dataSources/
|
136 |
+
.idea/**/dataSources.ids
|
137 |
+
.idea/**/dataSources.local.xml
|
138 |
+
.idea/**/sqlDataSources.xml
|
139 |
+
.idea/**/dynamic.xml
|
140 |
+
.idea/**/uiDesigner.xml
|
141 |
+
.idea/**/dbnavigator.xml
|
142 |
+
|
143 |
+
# Gradle
|
144 |
+
.idea/**/gradle.xml
|
145 |
+
.idea/**/libraries
|
146 |
+
|
147 |
+
# Gradle and Maven with auto-import
|
148 |
+
# When using Gradle or Maven with auto-import, you should exclude module files,
|
149 |
+
# since they will be recreated, and may cause churn. Uncomment if using
|
150 |
+
# auto-import.
|
151 |
+
# .idea/artifacts
|
152 |
+
# .idea/compiler.xml
|
153 |
+
# .idea/jarRepositories.xml
|
154 |
+
# .idea/modules.xml
|
155 |
+
# .idea/*.iml
|
156 |
+
# .idea/modules
|
157 |
+
# *.iml
|
158 |
+
# *.ipr
|
159 |
+
|
160 |
+
# CMake
|
161 |
+
cmake-build-*/
|
162 |
+
|
163 |
+
# Mongo Explorer plugin
|
164 |
+
.idea/**/mongoSettings.xml
|
165 |
+
|
166 |
+
# File-based project format
|
167 |
+
*.iws
|
168 |
+
|
169 |
+
# IntelliJ
|
170 |
+
out/
|
171 |
+
|
172 |
+
# mpeltonen/sbt-idea plugin
|
173 |
+
.idea_modules/
|
174 |
+
|
175 |
+
# JIRA plugin
|
176 |
+
atlassian-ide-plugin.xml
|
177 |
+
|
178 |
+
# Cursive Clojure plugin
|
179 |
+
.idea/replstate.xml
|
180 |
+
|
181 |
+
# SonarLint plugin
|
182 |
+
.idea/sonarlint/
|
183 |
+
|
184 |
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
185 |
+
com_crashlytics_export_strings.xml
|
186 |
+
crashlytics.properties
|
187 |
+
crashlytics-build.properties
|
188 |
+
fabric.properties
|
189 |
+
|
190 |
+
# Editor-based Rest Client
|
191 |
+
.idea/httpRequests
|
192 |
+
|
193 |
+
# Android studio 3.1+ serialized cache file
|
194 |
+
.idea/caches/build_file_checksums.ser
|
195 |
+
|
196 |
+
# idea folder, uncomment if you don't need it
|
197 |
+
# .idea
|
198 |
+
share/man/man1/isympy.1
|
199 |
+
share/man/man1/ttx.1
|
200 |
+
|
201 |
+
# IDEs
|
202 |
+
.idea/
|
203 |
+
.vscode/
|
glimpse-ui/glimpse/Readme.md
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
This is the repositotry of GLIMPSE: Pragmatically Informative Multi-Document Summarization for Scholarly Reviews
|
3 |
+
[Paper](https://arxiv.org/abs/2406.07359) | [Code](https://github.com/icannos/glimpse-mds)
|
4 |
+
|
5 |
+
|
6 |
+
### Installation
|
7 |
+
|
8 |
+
- We use python 3.10 and CUDA 12.1
|
9 |
+
``` bash
|
10 |
+
module load miniconda/3
|
11 |
+
module load cuda12
|
12 |
+
```
|
13 |
+
- First, create a virtual environment using:
|
14 |
+
``` bash
|
15 |
+
conda create -n glimpse python=3.10
|
16 |
+
```
|
17 |
+
- Second, activate the environment and install pytorch:
|
18 |
+
``` bash
|
19 |
+
conda activate glimpse
|
20 |
+
conda install pytorch==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia
|
21 |
+
```
|
22 |
+
|
23 |
+
- Finally, all remaining required packages could be installed with the requirements file:
|
24 |
+
|
25 |
+
``` bash
|
26 |
+
pip install -r requirements
|
27 |
+
```
|
28 |
+
### Data Loading
|
29 |
+
|
30 |
+
Step 1: Start by processing the input files from data.
|
31 |
+
|
32 |
+
``` bash
|
33 |
+
python glimpse/data_loading/data_processing.py
|
34 |
+
```
|
35 |
+
|
36 |
+
### Generating Summaries and Computing RSA Scores
|
37 |
+
Step 2: Now, we generate candidate summaries and compute RSA scores for each candidate
|
38 |
+
- for extractive candidates, use the following command:
|
39 |
+
``` bash
|
40 |
+
sbatch scripts/extractive.sh Path_of_Your_Processed_Dataset_Step1.csv
|
41 |
+
```
|
42 |
+
- for abstractive candidates, use either of the following commands:
|
43 |
+
- In case the last batch is incomplete, you can add padding using `--add-padding` argument to complete it:
|
44 |
+
``` bash
|
45 |
+
sbatch scripts/abstractive.sh Path_of_Your_Processed_Dataset_Step1.csv --add-padding
|
46 |
+
```
|
47 |
+
- If you want to remove the last incomplete batch, you can run the script without the argument:
|
48 |
+
``` bash
|
49 |
+
sbatch scripts/abstractive.sh Path_of_Your_Processed_Dataset_Step1.csv
|
50 |
+
```
|
51 |
+
|
52 |
+
`rsasumm/` provides a python package with an implementation of RSA incremental decoding and RSA reranking of candidates.
|
53 |
+
`mds/` provides the experiment scripts and analysis for the MultiDocument Summarization task.
|
54 |
+
|
55 |
+
|
56 |
+
## Citation
|
57 |
+
|
58 |
+
If you use this code, please cite the following papers:
|
59 |
+
|
60 |
+
```@misc{darrin2024glimpsepragmaticallyinformativemultidocument,
|
61 |
+
title={GLIMPSE: Pragmatically Informative Multi-Document Summarization for Scholarly Reviews},
|
62 |
+
author={Maxime Darrin and Ines Arous and Pablo Piantanida and Jackie CK Cheung},
|
63 |
+
year={2024},
|
64 |
+
eprint={2406.07359},
|
65 |
+
archivePrefix={arXiv},
|
66 |
+
primaryClass={cs.CL},
|
67 |
+
url={https://arxiv.org/abs/2406.07359},
|
68 |
+
}
|
69 |
+
```
|
glimpse-ui/glimpse/examples/RSA Sum tests.ipynb
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "initial_id",
|
7 |
+
"metadata": {
|
8 |
+
"collapsed": true,
|
9 |
+
"ExecuteTime": {
|
10 |
+
"end_time": "2024-01-12T16:31:17.690349522Z",
|
11 |
+
"start_time": "2024-01-12T16:31:15.472874479Z"
|
12 |
+
}
|
13 |
+
},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"import torch"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"cell_type": "code",
|
21 |
+
"execution_count": 2,
|
22 |
+
"outputs": [],
|
23 |
+
"source": [
|
24 |
+
"%reload_ext autoreload\n",
|
25 |
+
"%autoreload 2"
|
26 |
+
],
|
27 |
+
"metadata": {
|
28 |
+
"collapsed": false,
|
29 |
+
"ExecuteTime": {
|
30 |
+
"end_time": "2024-01-12T16:31:17.717430741Z",
|
31 |
+
"start_time": "2024-01-12T16:31:17.695066680Z"
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"id": "ecefdad828c7daa3"
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"cell_type": "code",
|
38 |
+
"execution_count": 3,
|
39 |
+
"outputs": [
|
40 |
+
{
|
41 |
+
"name": "stderr",
|
42 |
+
"output_type": "stream",
|
43 |
+
"text": [
|
44 |
+
"Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-cnn and are newly initialized: ['model.shared.weight']\n",
|
45 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
46 |
+
]
|
47 |
+
}
|
48 |
+
],
|
49 |
+
"source": [
|
50 |
+
"\n",
|
51 |
+
"from transformers import AutoTokenizer, BartForConditionalGeneration\n",
|
52 |
+
"\n",
|
53 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-large-cnn\")\n",
|
54 |
+
"model = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")\n"
|
55 |
+
],
|
56 |
+
"metadata": {
|
57 |
+
"collapsed": false,
|
58 |
+
"ExecuteTime": {
|
59 |
+
"end_time": "2024-01-12T16:31:26.058437142Z",
|
60 |
+
"start_time": "2024-01-12T16:31:17.720106168Z"
|
61 |
+
}
|
62 |
+
},
|
63 |
+
"id": "8c32b182fbcac2b6"
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": 4,
|
68 |
+
"outputs": [],
|
69 |
+
"source": [
|
70 |
+
"from rsasumm.beam_search import RSAContextualDecoding"
|
71 |
+
],
|
72 |
+
"metadata": {
|
73 |
+
"collapsed": false,
|
74 |
+
"ExecuteTime": {
|
75 |
+
"end_time": "2024-01-12T16:31:26.097766981Z",
|
76 |
+
"start_time": "2024-01-12T16:31:26.056626187Z"
|
77 |
+
}
|
78 |
+
},
|
79 |
+
"id": "cb33d902fe736c25"
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "code",
|
83 |
+
"execution_count": 5,
|
84 |
+
"outputs": [],
|
85 |
+
"source": [
|
86 |
+
"\n",
|
87 |
+
"\n",
|
88 |
+
"texts = ['The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. I believe the authors missed Jane and al 2021. In addition, I think, there is a mistake in the math.',\n",
|
89 |
+
" 'The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. However, some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.',\n",
|
90 |
+
" 'The paper gives really interesting insights on the topic of transfer learning. It is not well presented and lack experiments. In addition, some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.'\n",
|
91 |
+
" ]\n",
|
92 |
+
"\n",
|
93 |
+
"# texts = [texts[2], texts[1], texts[0]]\n",
|
94 |
+
"\n"
|
95 |
+
],
|
96 |
+
"metadata": {
|
97 |
+
"collapsed": false,
|
98 |
+
"ExecuteTime": {
|
99 |
+
"end_time": "2024-01-12T16:31:26.127922110Z",
|
100 |
+
"start_time": "2024-01-12T16:31:26.098805312Z"
|
101 |
+
}
|
102 |
+
},
|
103 |
+
"id": "436ef1482c361159"
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "code",
|
107 |
+
"execution_count": 6,
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"source_texts = tokenizer(texts, return_tensors=\"pt\", padding=True)\n",
|
111 |
+
"\n",
|
112 |
+
"rsa = RSAContextualDecoding(model, tokenizer, 'cpu')\n",
|
113 |
+
"\n"
|
114 |
+
],
|
115 |
+
"metadata": {
|
116 |
+
"collapsed": false,
|
117 |
+
"ExecuteTime": {
|
118 |
+
"end_time": "2024-01-12T16:31:26.169520864Z",
|
119 |
+
"start_time": "2024-01-12T16:31:26.125283164Z"
|
120 |
+
}
|
121 |
+
},
|
122 |
+
"id": "84b9943cac6cd7b2"
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "code",
|
126 |
+
"execution_count": 7,
|
127 |
+
"outputs": [],
|
128 |
+
"source": [
|
129 |
+
"output = rsa.generate(target_id=1, source_texts_ids=source_texts.input_ids, source_text_attention_mask=source_texts.attention_mask, max_length=50, top_p=0.95, do_sample=True, rationality=8.0, temperature=1.0, process_logits_before_rsa=True)"
|
130 |
+
],
|
131 |
+
"metadata": {
|
132 |
+
"collapsed": false,
|
133 |
+
"ExecuteTime": {
|
134 |
+
"end_time": "2024-01-12T16:32:14.857034731Z",
|
135 |
+
"start_time": "2024-01-12T16:31:26.164578792Z"
|
136 |
+
}
|
137 |
+
},
|
138 |
+
"id": "620e54a63dd2099c"
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "code",
|
142 |
+
"execution_count": 8,
|
143 |
+
"outputs": [
|
144 |
+
{
|
145 |
+
"data": {
|
146 |
+
"text/plain": "['Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.',\n 'Some parts of the paper remain unclear. I would like to see a more detailed explanation of the proposed method.']"
|
147 |
+
},
|
148 |
+
"execution_count": 8,
|
149 |
+
"metadata": {},
|
150 |
+
"output_type": "execute_result"
|
151 |
+
}
|
152 |
+
],
|
153 |
+
"source": [
|
154 |
+
"\n",
|
155 |
+
"tokenizer.batch_decode(output[0], skip_special_tokens=True)\n",
|
156 |
+
"\n"
|
157 |
+
],
|
158 |
+
"metadata": {
|
159 |
+
"collapsed": false,
|
160 |
+
"ExecuteTime": {
|
161 |
+
"end_time": "2024-01-12T16:32:14.858531480Z",
|
162 |
+
"start_time": "2024-01-12T16:32:14.856763396Z"
|
163 |
+
}
|
164 |
+
},
|
165 |
+
"id": "fb3a5a9a8f9990ee"
|
166 |
+
}
|
167 |
+
],
|
168 |
+
"metadata": {
|
169 |
+
"kernelspec": {
|
170 |
+
"display_name": "Python 3",
|
171 |
+
"language": "python",
|
172 |
+
"name": "python3"
|
173 |
+
},
|
174 |
+
"language_info": {
|
175 |
+
"codemirror_mode": {
|
176 |
+
"name": "ipython",
|
177 |
+
"version": 2
|
178 |
+
},
|
179 |
+
"file_extension": ".py",
|
180 |
+
"mimetype": "text/x-python",
|
181 |
+
"name": "python",
|
182 |
+
"nbconvert_exporter": "python",
|
183 |
+
"pygments_lexer": "ipython2",
|
184 |
+
"version": "2.7.6"
|
185 |
+
}
|
186 |
+
},
|
187 |
+
"nbformat": 4,
|
188 |
+
"nbformat_minor": 5
|
189 |
+
}
|
glimpse-ui/glimpse/examples/reviews/reviews_app.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
import nltk
|
5 |
+
import numpy as np
|
6 |
+
import seaborn as sns
|
7 |
+
|
8 |
+
import sys, os.path
|
9 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
10 |
+
|
11 |
+
from rsasumm.rsa_reranker import RSAReranking
|
12 |
+
import gradio as gr
|
13 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
14 |
+
import seaborn as sns
|
15 |
+
import pandas as pd
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
MODEL = "facebook/bart-large-cnn"
|
19 |
+
|
20 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
|
21 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
22 |
+
|
23 |
+
|
24 |
+
latex_template = r"""
|
25 |
+
\begin{subfigure}[b]{0.48\textwidth}
|
26 |
+
\resizebox{\textwidth}{!}{
|
27 |
+
\begin{coloredbox}{darkgray}{Review 1}
|
28 |
+
[REVIEW 1]
|
29 |
+
|
30 |
+
\end{coloredbox}}
|
31 |
+
\end{subfigure}
|
32 |
+
\begin{subfigure}[b]{0.48\textwidth}
|
33 |
+
\resizebox{\textwidth}{!}{
|
34 |
+
\begin{coloredbox}{darkgray}{Review 2}
|
35 |
+
[REVIEW 2]
|
36 |
+
\end{coloredbox}}
|
37 |
+
\end{subfigure}
|
38 |
+
\begin{subfigure}[b]{0.48\textwidth}
|
39 |
+
\resizebox{\textwidth}{!}{
|
40 |
+
\begin{coloredbox}{darkgray}{Review 3}
|
41 |
+
[REVIEW 3]
|
42 |
+
\end{coloredbox}}
|
43 |
+
\end{subfigure}
|
44 |
+
"""
|
45 |
+
|
46 |
+
EXAMPLES = [
|
47 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. I believe the authors missed Jane and al 2021. In addition, I think, there is a mistake in the math.",
|
48 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
|
49 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is not well presented and lack experiments. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
|
50 |
+
]
|
51 |
+
|
52 |
+
|
53 |
+
def make_colored_text_to_latex(scored_texts : List[Tuple[str, float]]):
|
54 |
+
"""
|
55 |
+
Make a latex string from a list of scored texts.
|
56 |
+
"""
|
57 |
+
|
58 |
+
# cast scores between 0 and 1
|
59 |
+
scores = np.array([score for _, score in scored_texts])
|
60 |
+
scores = (scores - scores.min()) / (scores.max() - scores.min())
|
61 |
+
|
62 |
+
# make color map in hex
|
63 |
+
cmap = sns.diverging_palette(250, 30, l=50, center="dark", as_cmap=True)
|
64 |
+
hex_colors = [cmap(score)[0:3] for score in scores]
|
65 |
+
# make html color string
|
66 |
+
hex_colors = [",".join([str(round(x, 2)) for x in color]) for color in hex_colors]
|
67 |
+
# make latex string
|
68 |
+
latex_string = ""
|
69 |
+
for (text, score), hex_color in zip(scored_texts, hex_colors):
|
70 |
+
latex_string += "\\textcolor[rgb]{" + str(hex_color) + "}{" + text + "} "
|
71 |
+
|
72 |
+
return latex_string
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
def summarize(text1, text2, text3, iterations, rationality=1.0):
|
78 |
+
# get sentences for each text
|
79 |
+
|
80 |
+
text1_sentences = nltk.sent_tokenize(text1)
|
81 |
+
text2_sentences = nltk.sent_tokenize(text2)
|
82 |
+
text3_sentences = nltk.sent_tokenize(text3)
|
83 |
+
|
84 |
+
|
85 |
+
# remove empty sentences
|
86 |
+
text1_sentences = [sentence for sentence in text1_sentences if sentence != ""]
|
87 |
+
text2_sentences = [sentence for sentence in text2_sentences if sentence != ""]
|
88 |
+
text3_sentences = [sentence for sentence in text3_sentences if sentence != ""]
|
89 |
+
|
90 |
+
sentences = list(set(text1_sentences + text2_sentences + text3_sentences))
|
91 |
+
|
92 |
+
rsa_reranker = RSAReranking(
|
93 |
+
model,
|
94 |
+
tokenizer,
|
95 |
+
candidates=sentences,
|
96 |
+
source_texts=[text1, text2, text3],
|
97 |
+
device="cpu",
|
98 |
+
rationality=rationality,
|
99 |
+
)
|
100 |
+
(
|
101 |
+
best_rsa,
|
102 |
+
best_base,
|
103 |
+
speaker_df,
|
104 |
+
listener_df,
|
105 |
+
initial_listener,
|
106 |
+
language_model_proba_df,
|
107 |
+
initial_consensuality_scores,
|
108 |
+
consensuality_scores,
|
109 |
+
) = rsa_reranker.rerank(t=iterations)
|
110 |
+
|
111 |
+
# apply exp to the probabilities
|
112 |
+
speaker_df = speaker_df.applymap(lambda x: math.exp(x))
|
113 |
+
|
114 |
+
text_1_summaries = speaker_df.loc[text1][text1_sentences]
|
115 |
+
text_1_summaries = text_1_summaries / text_1_summaries.sum()
|
116 |
+
|
117 |
+
text_2_summaries = speaker_df.loc[text2][text2_sentences]
|
118 |
+
text_2_summaries = text_2_summaries / text_2_summaries.sum()
|
119 |
+
|
120 |
+
text_3_summaries = speaker_df.loc[text3][text3_sentences]
|
121 |
+
text_3_summaries = text_3_summaries / text_3_summaries.sum()
|
122 |
+
|
123 |
+
# make list of tuples
|
124 |
+
text_1_summaries = [(sentence, text_1_summaries[sentence]) for sentence in text1_sentences]
|
125 |
+
text_2_summaries = [(sentence, text_2_summaries[sentence]) for sentence in text2_sentences]
|
126 |
+
text_3_summaries = [(sentence, text_3_summaries[sentence]) for sentence in text3_sentences]
|
127 |
+
|
128 |
+
# normalize consensuality scores between -1 and 1
|
129 |
+
|
130 |
+
consensuality_scores = (consensuality_scores - (consensuality_scores.max() - consensuality_scores.min()) / 2) / (consensuality_scores.max() - consensuality_scores.min()) / 2
|
131 |
+
consensuality_scores_01 = (consensuality_scores - consensuality_scores.min()) / (consensuality_scores.max() - consensuality_scores.min())
|
132 |
+
|
133 |
+
|
134 |
+
most_consensual = consensuality_scores.sort_values(ascending=True).head(3).index.tolist()
|
135 |
+
least_consensual = consensuality_scores.sort_values(ascending=False).head(3).index.tolist()
|
136 |
+
|
137 |
+
most_consensual = [(sentence, consensuality_scores[sentence]) for sentence in most_consensual]
|
138 |
+
least_consensual = [(sentence, consensuality_scores[sentence]) for sentence in least_consensual]
|
139 |
+
|
140 |
+
text_1_consensuality = consensuality_scores.loc[text1_sentences]
|
141 |
+
text_2_consensuality = consensuality_scores.loc[text2_sentences]
|
142 |
+
text_3_consensuality = consensuality_scores.loc[text3_sentences]
|
143 |
+
|
144 |
+
# rescale between -1 and 1
|
145 |
+
# text_1_consensuality = (text_1_consensuality - (text_1_consensuality.max() - text_1_consensuality.min()) / 2) / (text_1_consensuality.max() - text_1_consensuality.min()) / 2
|
146 |
+
# text_2_consensuality = (text_2_consensuality - (text_2_consensuality.max() - text_2_consensuality.min()) / 2) / (text_2_consensuality.max() - text_2_consensuality.min()) / 2
|
147 |
+
# text_3_consensuality = (text_3_consensuality - (text_3_consensuality.max() - text_3_consensuality.min()) / 2) / (text_3_consensuality.max() - text_3_consensuality.min()) / 2
|
148 |
+
|
149 |
+
text_1_consensuality = [(sentence, text_1_consensuality[sentence]) for sentence in text1_sentences]
|
150 |
+
text_2_consensuality = [(sentence, text_2_consensuality[sentence]) for sentence in text2_sentences]
|
151 |
+
text_3_consensuality = [(sentence, text_3_consensuality[sentence]) for sentence in text3_sentences]
|
152 |
+
|
153 |
+
fig1 = plt.figure(figsize=(20, 10))
|
154 |
+
ax = fig1.add_subplot(111)
|
155 |
+
sns.heatmap(
|
156 |
+
listener_df,
|
157 |
+
ax=ax,
|
158 |
+
cmap="Blues",
|
159 |
+
annot=True,
|
160 |
+
fmt=".2f",
|
161 |
+
cbar=False,
|
162 |
+
annot_kws={"size": 10},
|
163 |
+
)
|
164 |
+
ax.set_title("Listener probabilities")
|
165 |
+
ax.set_xlabel("Candidate sentences")
|
166 |
+
ax.set_ylabel("Source texts")
|
167 |
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
168 |
+
fig1.tight_layout()
|
169 |
+
|
170 |
+
fig2 = plt.figure(figsize=(20, 10))
|
171 |
+
ax = fig2.add_subplot(111)
|
172 |
+
sns.heatmap(
|
173 |
+
speaker_df,
|
174 |
+
ax=ax,
|
175 |
+
cmap="Blues",
|
176 |
+
annot=True,
|
177 |
+
fmt=".2f",
|
178 |
+
cbar=False,
|
179 |
+
annot_kws={"size": 10},
|
180 |
+
)
|
181 |
+
ax.set_title("Speaker probabilities")
|
182 |
+
ax.set_xlabel("Candidate sentences")
|
183 |
+
ax.set_ylabel("Source texts")
|
184 |
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
185 |
+
fig2.tight_layout()
|
186 |
+
|
187 |
+
latex_text_1 = make_colored_text_to_latex(text_1_summaries)
|
188 |
+
latex_text_2 = make_colored_text_to_latex(text_2_summaries)
|
189 |
+
latex_text_3 = make_colored_text_to_latex(text_3_summaries)
|
190 |
+
|
191 |
+
text_1_consensuality_ = consensuality_scores_01.loc[text1_sentences]
|
192 |
+
text_2_consensuality_ = consensuality_scores_01.loc[text2_sentences]
|
193 |
+
text_3_consensuality_ = consensuality_scores_01.loc[text3_sentences]
|
194 |
+
|
195 |
+
text_1_consensuality_ = [(sentence, text_1_consensuality_[sentence]) for sentence in text1_sentences]
|
196 |
+
text_2_consensuality_ = [(sentence, text_2_consensuality_[sentence]) for sentence in text2_sentences]
|
197 |
+
text_3_consensuality_ = [(sentence, text_3_consensuality_[sentence]) for sentence in text3_sentences]
|
198 |
+
|
199 |
+
latex_text_1_consensuality = make_colored_text_to_latex(text_1_consensuality_)
|
200 |
+
latex_text_2_consensuality = make_colored_text_to_latex(text_2_consensuality_)
|
201 |
+
latex_text_3_consensuality = make_colored_text_to_latex(text_3_consensuality_)
|
202 |
+
|
203 |
+
latex = latex_template.replace("[REVIEW 1]", latex_text_1)
|
204 |
+
latex = latex.replace("[REVIEW 2]", latex_text_2)
|
205 |
+
latex = latex.replace("[REVIEW 3]", latex_text_3)
|
206 |
+
|
207 |
+
|
208 |
+
return text_1_summaries, text_2_summaries, text_3_summaries, text_1_consensuality, text_2_consensuality, text_3_consensuality, most_consensual, least_consensual, fig1, fig2, latex
|
209 |
+
|
210 |
+
|
211 |
+
# make gradiot highlightedText component
|
212 |
+
|
213 |
+
|
214 |
+
iface = gr.Interface(
|
215 |
+
fn=summarize,
|
216 |
+
inputs=[
|
217 |
+
gr.Textbox(lines=10, value=EXAMPLES[0]),
|
218 |
+
gr.Textbox(lines=10, value=EXAMPLES[1]),
|
219 |
+
gr.Textbox(lines=10, value=EXAMPLES[2]),
|
220 |
+
gr.Number(value=1, label="Iterations"),
|
221 |
+
gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=1.0, label="Rationality"),
|
222 |
+
],
|
223 |
+
outputs=[
|
224 |
+
gr.Highlightedtext(
|
225 |
+
show_legend=True,
|
226 |
+
label="Uniqueness score for each sentence in text 1",
|
227 |
+
),
|
228 |
+
gr.Highlightedtext(
|
229 |
+
show_legend=True,
|
230 |
+
label="Uniqueness score for each sentence in text 2",
|
231 |
+
),
|
232 |
+
gr.Highlightedtext(
|
233 |
+
show_legend=True,
|
234 |
+
label="Uniqueness score for each sentence in text 3",
|
235 |
+
),
|
236 |
+
gr.Highlightedtext(
|
237 |
+
show_legend=True,
|
238 |
+
label="Consensuality score for each sentence in text 1",
|
239 |
+
|
240 |
+
),
|
241 |
+
gr.Highlightedtext(
|
242 |
+
show_legend=True,
|
243 |
+
label="Consensuality score for each sentence in text 2",
|
244 |
+
),
|
245 |
+
gr.Highlightedtext(
|
246 |
+
show_legend=True,
|
247 |
+
label="Consensuality score for each sentence in text 3",
|
248 |
+
),
|
249 |
+
gr.Highlightedtext(
|
250 |
+
show_legend=True,
|
251 |
+
label="Most consensual sentences",
|
252 |
+
|
253 |
+
),
|
254 |
+
gr.Highlightedtext(
|
255 |
+
show_legend=True,
|
256 |
+
label="Least consensual sentences",
|
257 |
+
),
|
258 |
+
gr.Plot(
|
259 |
+
label="Listener probabilities",
|
260 |
+
),
|
261 |
+
gr.Plot(
|
262 |
+
label="Speaker probabilities",
|
263 |
+
),
|
264 |
+
|
265 |
+
gr.Textbox(lines=10, label="Latex Consensuality scores"),
|
266 |
+
|
267 |
+
|
268 |
+
|
269 |
+
],
|
270 |
+
title="RSA Summarizer",
|
271 |
+
description="Summarize 3 texts using RSA",
|
272 |
+
)
|
273 |
+
|
274 |
+
iface.launch(share=True)
|
glimpse-ui/glimpse/examples/reviews/reviews_latex_generation.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
import nltk
|
5 |
+
import numpy as np
|
6 |
+
import seaborn as sns
|
7 |
+
|
8 |
+
from rsasumm.rsa_reranker import RSAReranking
|
9 |
+
import gradio as gr
|
10 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
11 |
+
import seaborn as sns
|
12 |
+
import pandas as pd
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
|
15 |
+
MODEL = "facebook/bart-large-cnn"
|
16 |
+
|
17 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
19 |
+
|
20 |
+
|
21 |
+
latex_template = r"""
|
22 |
+
\begin{subfigure}[b]{0.48\textwidth}
|
23 |
+
\resizebox{\textwidth}{!}{
|
24 |
+
\begin{coloredbox}{darkgray}{Review 1}
|
25 |
+
[REVIEW 1]
|
26 |
+
|
27 |
+
\end{coloredbox}}
|
28 |
+
\end{subfigure}
|
29 |
+
\begin{subfigure}[b]{0.48\textwidth}
|
30 |
+
\resizebox{\textwidth}{!}{
|
31 |
+
\begin{coloredbox}{darkgray}{Review 2}
|
32 |
+
[REVIEW 2]
|
33 |
+
\end{coloredbox}}
|
34 |
+
\end{subfigure}
|
35 |
+
\begin{subfigure}[b]{0.48\textwidth}
|
36 |
+
\resizebox{\textwidth}{!}{
|
37 |
+
\begin{coloredbox}{darkgray}{Review 3}
|
38 |
+
[REVIEW 3]
|
39 |
+
\end{coloredbox}}
|
40 |
+
\end{subfigure}
|
41 |
+
"""
|
42 |
+
|
43 |
+
EXAMPLES = [
|
44 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. I believe the authors missed Jane and al 2021. In addition, I think, there is a mistake in the math.",
|
45 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
|
46 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is not well presented and lack experiments. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
|
47 |
+
]
|
48 |
+
|
49 |
+
|
50 |
+
def make_colored_text_to_latex(scored_texts : List[Tuple[str, float]]):
|
51 |
+
"""
|
52 |
+
Make a latex string from a list of scored texts.
|
53 |
+
"""
|
54 |
+
|
55 |
+
# cast scores between 0 and 1
|
56 |
+
scores = np.array([score for _, score in scored_texts])
|
57 |
+
scores = (scores - scores.min()) / (scores.max() - scores.min())
|
58 |
+
|
59 |
+
# make color map in hex
|
60 |
+
cmap = sns.diverging_palette(250, 30, l=50, center="dark", as_cmap=True)
|
61 |
+
hex_colors = [cmap(score)[0:3] for score in scores]
|
62 |
+
# make html color string
|
63 |
+
hex_colors = [",".join([str(round(x, 2)) for x in color]) for color in hex_colors]
|
64 |
+
# make latex string
|
65 |
+
latex_string = ""
|
66 |
+
for (text, score), hex_color in zip(scored_texts, hex_colors):
|
67 |
+
#latex_string += "\\textcolor[rgb]{" + str(hex_color) + "}{" + text + "} "
|
68 |
+
latex_string += "\\hlc{" + str(hex_color)[1:-1] + "}{" + text + "} "
|
69 |
+
|
70 |
+
return latex_string
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
def summarize(text1, text2, text3, iterations, rationality=1.0):
|
76 |
+
# get sentences for each text
|
77 |
+
|
78 |
+
text1_sentences = nltk.sent_tokenize(text1)
|
79 |
+
text2_sentences = nltk.sent_tokenize(text2)
|
80 |
+
text3_sentences = nltk.sent_tokenize(text3)
|
81 |
+
|
82 |
+
|
83 |
+
# remove empty sentences
|
84 |
+
text1_sentences = [sentence for sentence in text1_sentences if sentence != ""]
|
85 |
+
text2_sentences = [sentence for sentence in text2_sentences if sentence != ""]
|
86 |
+
text3_sentences = [sentence for sentence in text3_sentences if sentence != ""]
|
87 |
+
|
88 |
+
sentences = list(set(text1_sentences + text2_sentences + text3_sentences))
|
89 |
+
|
90 |
+
rsa_reranker = RSAReranking(
|
91 |
+
model,
|
92 |
+
tokenizer,
|
93 |
+
candidates=sentences,
|
94 |
+
source_texts=[text1, text2, text3],
|
95 |
+
device="cpu",
|
96 |
+
rationality=rationality,
|
97 |
+
)
|
98 |
+
(
|
99 |
+
best_rsa,
|
100 |
+
best_base,
|
101 |
+
speaker_df,
|
102 |
+
listener_df,
|
103 |
+
initial_listener,
|
104 |
+
language_model_proba_df,
|
105 |
+
initial_consensuality_scores,
|
106 |
+
consensuality_scores,
|
107 |
+
) = rsa_reranker.rerank(t=iterations)
|
108 |
+
|
109 |
+
# apply exp to the probabilities
|
110 |
+
speaker_df = speaker_df.applymap(lambda x: math.exp(x))
|
111 |
+
|
112 |
+
text_1_summaries = speaker_df.loc[text1][text1_sentences]
|
113 |
+
text_1_summaries = text_1_summaries / text_1_summaries.sum()
|
114 |
+
|
115 |
+
text_2_summaries = speaker_df.loc[text2][text2_sentences]
|
116 |
+
text_2_summaries = text_2_summaries / text_2_summaries.sum()
|
117 |
+
|
118 |
+
text_3_summaries = speaker_df.loc[text3][text3_sentences]
|
119 |
+
text_3_summaries = text_3_summaries / text_3_summaries.sum()
|
120 |
+
|
121 |
+
# make list of tuples
|
122 |
+
text_1_summaries = [(sentence, text_1_summaries[sentence]) for sentence in text1_sentences]
|
123 |
+
text_2_summaries = [(sentence, text_2_summaries[sentence]) for sentence in text2_sentences]
|
124 |
+
text_3_summaries = [(sentence, text_3_summaries[sentence]) for sentence in text3_sentences]
|
125 |
+
|
126 |
+
# normalize consensuality scores between -1 and 1
|
127 |
+
|
128 |
+
consensuality_scores = (consensuality_scores - (consensuality_scores.max() - consensuality_scores.min()) / 2) / (consensuality_scores.max() - consensuality_scores.min()) / 2
|
129 |
+
consensuality_scores_01 = (consensuality_scores - consensuality_scores.min()) / (consensuality_scores.max() - consensuality_scores.min())
|
130 |
+
|
131 |
+
|
132 |
+
most_consensual = consensuality_scores.sort_values(ascending=True).head(3).index.tolist()
|
133 |
+
least_consensual = consensuality_scores.sort_values(ascending=False).head(3).index.tolist()
|
134 |
+
|
135 |
+
most_consensual = [(sentence, consensuality_scores[sentence]) for sentence in most_consensual]
|
136 |
+
least_consensual = [(sentence, consensuality_scores[sentence]) for sentence in least_consensual]
|
137 |
+
|
138 |
+
text_1_consensuality = consensuality_scores.loc[text1_sentences]
|
139 |
+
text_2_consensuality = consensuality_scores.loc[text2_sentences]
|
140 |
+
text_3_consensuality = consensuality_scores.loc[text3_sentences]
|
141 |
+
|
142 |
+
# rescale between -1 and 1
|
143 |
+
# text_1_consensuality = (text_1_consensuality - (text_1_consensuality.max() - text_1_consensuality.min()) / 2) / (text_1_consensuality.max() - text_1_consensuality.min()) / 2
|
144 |
+
# text_2_consensuality = (text_2_consensuality - (text_2_consensuality.max() - text_2_consensuality.min()) / 2) / (text_2_consensuality.max() - text_2_consensuality.min()) / 2
|
145 |
+
# text_3_consensuality = (text_3_consensuality - (text_3_consensuality.max() - text_3_consensuality.min()) / 2) / (text_3_consensuality.max() - text_3_consensuality.min()) / 2
|
146 |
+
|
147 |
+
text_1_consensuality = [(sentence, text_1_consensuality[sentence]) for sentence in text1_sentences]
|
148 |
+
text_2_consensuality = [(sentence, text_2_consensuality[sentence]) for sentence in text2_sentences]
|
149 |
+
text_3_consensuality = [(sentence, text_3_consensuality[sentence]) for sentence in text3_sentences]
|
150 |
+
|
151 |
+
fig1 = plt.figure(figsize=(20, 10))
|
152 |
+
ax = fig1.add_subplot(111)
|
153 |
+
sns.heatmap(
|
154 |
+
listener_df,
|
155 |
+
ax=ax,
|
156 |
+
cmap="Blues",
|
157 |
+
annot=True,
|
158 |
+
fmt=".2f",
|
159 |
+
cbar=False,
|
160 |
+
annot_kws={"size": 10},
|
161 |
+
)
|
162 |
+
ax.set_title("Listener probabilities")
|
163 |
+
ax.set_xlabel("Candidate sentences")
|
164 |
+
ax.set_ylabel("Source texts")
|
165 |
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
166 |
+
fig1.tight_layout()
|
167 |
+
|
168 |
+
fig2 = plt.figure(figsize=(20, 10))
|
169 |
+
ax = fig2.add_subplot(111)
|
170 |
+
sns.heatmap(
|
171 |
+
speaker_df,
|
172 |
+
ax=ax,
|
173 |
+
cmap="Blues",
|
174 |
+
annot=True,
|
175 |
+
fmt=".2f",
|
176 |
+
cbar=False,
|
177 |
+
annot_kws={"size": 10},
|
178 |
+
)
|
179 |
+
ax.set_title("Speaker probabilities")
|
180 |
+
ax.set_xlabel("Candidate sentences")
|
181 |
+
ax.set_ylabel("Source texts")
|
182 |
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
183 |
+
fig2.tight_layout()
|
184 |
+
|
185 |
+
latex_text_1 = make_colored_text_to_latex(text_1_summaries)
|
186 |
+
latex_text_2 = make_colored_text_to_latex(text_2_summaries)
|
187 |
+
latex_text_3 = make_colored_text_to_latex(text_3_summaries)
|
188 |
+
|
189 |
+
text_1_consensuality_ = consensuality_scores_01.loc[text1_sentences]
|
190 |
+
text_2_consensuality_ = consensuality_scores_01.loc[text2_sentences]
|
191 |
+
text_3_consensuality_ = consensuality_scores_01.loc[text3_sentences]
|
192 |
+
|
193 |
+
text_1_consensuality_ = [(sentence, text_1_consensuality_[sentence]) for sentence in text1_sentences]
|
194 |
+
text_2_consensuality_ = [(sentence, text_2_consensuality_[sentence]) for sentence in text2_sentences]
|
195 |
+
text_3_consensuality_ = [(sentence, text_3_consensuality_[sentence]) for sentence in text3_sentences]
|
196 |
+
|
197 |
+
latex_text_1_consensuality = make_colored_text_to_latex(text_1_consensuality_)
|
198 |
+
latex_text_2_consensuality = make_colored_text_to_latex(text_2_consensuality_)
|
199 |
+
latex_text_3_consensuality = make_colored_text_to_latex(text_3_consensuality_)
|
200 |
+
|
201 |
+
latex = latex_template.replace("[REVIEW 1]", latex_text_1)
|
202 |
+
latex = latex.replace("[REVIEW 2]", latex_text_2)
|
203 |
+
latex = latex.replace("[REVIEW 3]", latex_text_3)
|
204 |
+
|
205 |
+
|
206 |
+
return text_1_summaries, text_2_summaries, text_3_summaries, text_1_consensuality, text_2_consensuality, text_3_consensuality, most_consensual, least_consensual, fig1, fig2, latex
|
207 |
+
|
208 |
+
|
209 |
+
# make gradiot highlightedText component
|
210 |
+
|
211 |
+
|
212 |
+
iface = gr.Interface(
|
213 |
+
fn=summarize,
|
214 |
+
inputs=[
|
215 |
+
gr.Textbox(lines=10, value=EXAMPLES[0]),
|
216 |
+
gr.Textbox(lines=10, value=EXAMPLES[1]),
|
217 |
+
gr.Textbox(lines=10, value=EXAMPLES[2]),
|
218 |
+
gr.Number(value=1, label="Iterations"),
|
219 |
+
gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=1.0, label="Rationality"),
|
220 |
+
],
|
221 |
+
outputs=[
|
222 |
+
gr.Highlightedtext(
|
223 |
+
show_legend=True,
|
224 |
+
label="Uniqueness score for each sentence in text 1",
|
225 |
+
),
|
226 |
+
gr.Highlightedtext(
|
227 |
+
show_legend=True,
|
228 |
+
label="Uniqueness score for each sentence in text 2",
|
229 |
+
),
|
230 |
+
gr.Highlightedtext(
|
231 |
+
show_legend=True,
|
232 |
+
label="Uniqueness score for each sentence in text 3",
|
233 |
+
),
|
234 |
+
gr.Highlightedtext(
|
235 |
+
show_legend=True,
|
236 |
+
label="Consensuality score for each sentence in text 1",
|
237 |
+
|
238 |
+
),
|
239 |
+
gr.Highlightedtext(
|
240 |
+
show_legend=True,
|
241 |
+
label="Consensuality score for each sentence in text 2",
|
242 |
+
),
|
243 |
+
gr.Highlightedtext(
|
244 |
+
show_legend=True,
|
245 |
+
label="Consensuality score for each sentence in text 3",
|
246 |
+
),
|
247 |
+
gr.Highlightedtext(
|
248 |
+
show_legend=True,
|
249 |
+
label="Most consensual sentences",
|
250 |
+
|
251 |
+
),
|
252 |
+
gr.Highlightedtext(
|
253 |
+
show_legend=True,
|
254 |
+
label="Least consensual sentences",
|
255 |
+
),
|
256 |
+
gr.Plot(
|
257 |
+
label="Listener probabilities",
|
258 |
+
),
|
259 |
+
gr.Plot(
|
260 |
+
label="Speaker probabilities",
|
261 |
+
),
|
262 |
+
|
263 |
+
gr.Textbox(lines=10, label="Latex Consensuality scores"),
|
264 |
+
|
265 |
+
|
266 |
+
|
267 |
+
],
|
268 |
+
title="RSA Summarizer",
|
269 |
+
description="Summarize 3 texts using RSA",
|
270 |
+
)
|
271 |
+
|
272 |
+
iface.launch()
|
glimpse-ui/glimpse/glimpse/baselines/generate_llm_summaries.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
6 |
+
import re
|
7 |
+
import argparse
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
def parse_args():
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument("--dataset", default="")
|
14 |
+
parser.add_argument("--batch_size", type=int, default=4)
|
15 |
+
parser.add_argument("--device", type=str, default="cuda")
|
16 |
+
parser.add_argument("--output", type=Path, default="")
|
17 |
+
|
18 |
+
args = parser.parse_args()
|
19 |
+
return args
|
20 |
+
|
21 |
+
def prepare_dataset(dataset_name, dataset_path="rsasumm/data/processed/"):
|
22 |
+
dataset_path = Path(dataset_path)
|
23 |
+
if dataset_name == "amazon":
|
24 |
+
dataset = pd.read_csv(dataset_path / "amazon_test.csv")
|
25 |
+
elif dataset_name == "space":
|
26 |
+
dataset = pd.read_csv(dataset_path / "space.csv")
|
27 |
+
elif dataset_name == "yelp":
|
28 |
+
dataset = pd.read_csv(dataset_path / "yelp_test.csv")
|
29 |
+
elif dataset_name == "reviews":
|
30 |
+
dataset = pd.read_csv(dataset_path / "test_metareviews.csv")
|
31 |
+
else:
|
32 |
+
raise ValueError(f"Unknown dataset {dataset_name}")
|
33 |
+
|
34 |
+
|
35 |
+
return dataset
|
36 |
+
|
37 |
+
|
38 |
+
# group text by sample id and concatenate text
|
39 |
+
|
40 |
+
def group_text_by_id(df: pd.DataFrame) -> pd.DataFrame:
|
41 |
+
"""
|
42 |
+
Group the text by the sample id and concatenate the text.
|
43 |
+
:param df: The dataframe
|
44 |
+
:return: The dataframe with the text grouped by the sample id
|
45 |
+
"""
|
46 |
+
texts = df.groupby("id")["text"].apply(lambda x: " ".join(x))
|
47 |
+
|
48 |
+
# retrieve first gold by id
|
49 |
+
gold = df.groupby("id")["gold"].first()
|
50 |
+
|
51 |
+
# create new dataframe
|
52 |
+
df = pd.DataFrame({"text": texts, "gold": gold}, index=texts.index)
|
53 |
+
|
54 |
+
return df
|
55 |
+
|
56 |
+
|
57 |
+
def generate_summaries(model, tokenizer, df, batch_size, device):
|
58 |
+
|
59 |
+
# df columns = id, text, gold
|
60 |
+
# make instruction:
|
61 |
+
|
62 |
+
def make_instruction(text):
|
63 |
+
return f"[INST]\n{text}\n Summarize the previous text:[/INST]\n\n"
|
64 |
+
|
65 |
+
df["instruction"] = df["text"].apply(make_instruction)
|
66 |
+
|
67 |
+
# make data loader
|
68 |
+
dataset = df[["instruction"]].values.tolist()
|
69 |
+
|
70 |
+
|
71 |
+
model = model.to(device).eval()
|
72 |
+
|
73 |
+
summaries = []
|
74 |
+
with torch.no_grad():
|
75 |
+
for batch in tqdm(dataset):
|
76 |
+
print(batch)
|
77 |
+
inputs = tokenizer.encode(batch, padding=True, truncation=True, return_tensors="pt")
|
78 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
79 |
+
outputs = model.generate(**inputs, temperature=0.7, top_p=0.7, top_k=50, max_new_tokens=500)
|
80 |
+
summaries.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
81 |
+
|
82 |
+
# remove the instruction from the summaries
|
83 |
+
df["summary"] = [re.sub(r"\[INST\]\n.*\[/INST\]\n\n", "", summary) for summary in summaries]
|
84 |
+
|
85 |
+
return df
|
86 |
+
|
87 |
+
def main():
|
88 |
+
|
89 |
+
args = parse_args()
|
90 |
+
model = "togethercomputer/Llama-2-7B-32K-Instruct"
|
91 |
+
tokenizer = AutoTokenizer.from_pretrained(model)
|
92 |
+
model = AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True, torch_dtype=torch.float16)
|
93 |
+
|
94 |
+
df = prepare_dataset(args.dataset)
|
95 |
+
|
96 |
+
df = group_text_by_id(df)
|
97 |
+
|
98 |
+
df = generate_summaries(model, tokenizer, df, args.batch_size, args.device)
|
99 |
+
df['metadata/Method'] = "LLM"
|
100 |
+
df['metadata/Model'] = model
|
101 |
+
|
102 |
+
name = f"{args.dataset}-_-{model.replace('/', '-')}-_-llm_summaries.csv"
|
103 |
+
path = Path(args.output) / name
|
104 |
+
|
105 |
+
Path(args.output).mkdir(exist_ok=True, parents=True)
|
106 |
+
df.to_csv(path, index=True)
|
107 |
+
|
108 |
+
|
109 |
+
main()
|
110 |
+
|
111 |
+
|
112 |
+
|
glimpse-ui/glimpse/glimpse/baselines/sumy_baselines.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sumy.parsers.plaintext import PlaintextParser
|
2 |
+
from sumy.parsers.html import HtmlParser
|
3 |
+
from sumy.nlp.tokenizers import Tokenizer
|
4 |
+
from sumy.nlp.stemmers import Stemmer
|
5 |
+
from sumy.utils import get_stop_words
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import pandas as pd
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import nltk
|
13 |
+
|
14 |
+
|
15 |
+
def summarize(method, language, sentence_count, input_type, input_):
|
16 |
+
if method == 'LSA':
|
17 |
+
from sumy.summarizers.lsa import LsaSummarizer as Summarizer
|
18 |
+
if method == 'text-rank':
|
19 |
+
from sumy.summarizers.text_rank import TextRankSummarizer as Summarizer
|
20 |
+
if method == 'lex-rank':
|
21 |
+
from sumy.summarizers.lex_rank import LexRankSummarizer as Summarizer
|
22 |
+
if method == 'edmundson':
|
23 |
+
from sumy.summarizers.edmundson import EdmundsonSummarizer as Summarizer
|
24 |
+
if method == 'luhn':
|
25 |
+
from sumy.summarizers.luhn import LuhnSummarizer as Summarizer
|
26 |
+
if method == 'kl-sum':
|
27 |
+
from sumy.summarizers.kl import KLSummarizer as Summarizer
|
28 |
+
if method == 'random':
|
29 |
+
from sumy.summarizers.random import RandomSummarizer as Summarizer
|
30 |
+
if method == 'reduction':
|
31 |
+
from sumy.summarizers.reduction import ReductionSummarizer as Summarizer
|
32 |
+
|
33 |
+
if input_type == "URL":
|
34 |
+
parser = HtmlParser.from_url(input_, Tokenizer(language))
|
35 |
+
if input_type == "text":
|
36 |
+
parser = PlaintextParser.from_string(input_, Tokenizer(language))
|
37 |
+
|
38 |
+
stemmer = Stemmer(language)
|
39 |
+
summarizer = Summarizer(stemmer)
|
40 |
+
stop_words = get_stop_words(language)
|
41 |
+
|
42 |
+
if method == 'edmundson':
|
43 |
+
summarizer.null_words = stop_words
|
44 |
+
summarizer.bonus_words = parser.significant_words
|
45 |
+
summarizer.stigma_words = parser.stigma_words
|
46 |
+
else:
|
47 |
+
summarizer.stop_words = stop_words
|
48 |
+
|
49 |
+
summary_sentences = summarizer(parser.document, sentence_count)
|
50 |
+
summary = ' '.join([str(sentence) for sentence in summary_sentences])
|
51 |
+
|
52 |
+
return summary
|
53 |
+
|
54 |
+
|
55 |
+
# methods = ['LSA', 'text-rank', 'lex-rank', 'edmundson', 'luhn', 'kl-sum', 'random', 'reduction']
|
56 |
+
|
57 |
+
|
58 |
+
def parse_args():
|
59 |
+
parser = argparse.ArgumentParser()
|
60 |
+
parser.add_argument("--dataset", default="")
|
61 |
+
# method
|
62 |
+
parser.add_argument("--method", type=str, choices=['LSA', 'text-rank', 'lex-rank', 'edmundson', 'luhn', 'kl-sum', 'random', 'reduction'], default="LSA")
|
63 |
+
parser.add_argument("--batch_size", type=int, default=4)
|
64 |
+
parser.add_argument("--device", type=str, default="cuda")
|
65 |
+
parser.add_argument("--output", type=Path, default="")
|
66 |
+
|
67 |
+
args = parser.parse_args()
|
68 |
+
return args
|
69 |
+
|
70 |
+
def prepare_dataset(dataset_name, dataset_path="rsasumm/data/processed/"):
|
71 |
+
dataset_path = Path(dataset_path)
|
72 |
+
if dataset_name == "amazon":
|
73 |
+
dataset = pd.read_csv(dataset_path / "amazon_test.csv")
|
74 |
+
elif dataset_name == "space":
|
75 |
+
dataset = pd.read_csv(dataset_path / "space.csv")
|
76 |
+
elif dataset_name == "yelp":
|
77 |
+
dataset = pd.read_csv(dataset_path / "yelp_test.csv")
|
78 |
+
elif dataset_name == "reviews":
|
79 |
+
dataset = pd.read_csv(dataset_path / "test_metareviews.csv")
|
80 |
+
else:
|
81 |
+
raise ValueError(f"Unknown dataset {dataset_name}")
|
82 |
+
|
83 |
+
|
84 |
+
return dataset
|
85 |
+
|
86 |
+
|
87 |
+
# group text by sample id and concatenate text
|
88 |
+
|
89 |
+
def group_text_by_id(df: pd.DataFrame) -> pd.DataFrame:
|
90 |
+
"""
|
91 |
+
Group the text by the sample id and concatenate the text.
|
92 |
+
:param df: The dataframe
|
93 |
+
:return: The dataframe with the text grouped by the sample id
|
94 |
+
"""
|
95 |
+
texts = df.groupby("id")["text"].apply(lambda x: " ".join(x))
|
96 |
+
|
97 |
+
# retrieve first gold by id
|
98 |
+
gold = df.groupby("id")["gold"].first()
|
99 |
+
|
100 |
+
# create new dataframe
|
101 |
+
df = pd.DataFrame({"text": texts, "gold": gold}, index=texts.index)
|
102 |
+
|
103 |
+
return df
|
104 |
+
|
105 |
+
|
106 |
+
def main():
|
107 |
+
args = parse_args()
|
108 |
+
for N in [1]:
|
109 |
+
dataset = prepare_dataset(args.dataset)
|
110 |
+
# dataset = group_text_by_id(dataset)
|
111 |
+
|
112 |
+
summaries = []
|
113 |
+
for text in dataset.text:
|
114 |
+
summary = summarize(args.method, "english", N, "text", text)
|
115 |
+
summaries.append(summary)
|
116 |
+
|
117 |
+
dataset['summary'] = summaries
|
118 |
+
dataset['metadata/dataset'] = args.dataset
|
119 |
+
dataset["metadata/method"] = args.method
|
120 |
+
dataset["metadata/sentence_count"] = N
|
121 |
+
|
122 |
+
name = f"{args.dataset}-_-{args.method}-_-sumy_{N}.csv"
|
123 |
+
path = Path(args.output) / name
|
124 |
+
|
125 |
+
Path(args.output).mkdir(exist_ok=True, parents=True)
|
126 |
+
dataset.to_csv(path, index=True)
|
127 |
+
|
128 |
+
|
129 |
+
main()
|
glimpse-ui/glimpse/glimpse/data_loading/Glimpse_tokenizer.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import spacy
|
3 |
+
import importlib
|
4 |
+
import nltk
|
5 |
+
|
6 |
+
############################################
|
7 |
+
### CHANGE THIS LINE TO CHOOSE TOKENIZER ###
|
8 |
+
ORIGINAL_TOKENIZER = False
|
9 |
+
############################################
|
10 |
+
|
11 |
+
try:
|
12 |
+
importlib.util.find_spec("en_core_web_sm")
|
13 |
+
nlp = spacy.load("en_core_web_sm")
|
14 |
+
except:
|
15 |
+
import spacy.cli
|
16 |
+
spacy.cli.download("en_core_web_sm")
|
17 |
+
nlp = spacy.load("en_core_web_sm")
|
18 |
+
|
19 |
+
def glimpse_tokenizer(text: str) -> list:
|
20 |
+
|
21 |
+
# If the original tokenizer is set to True, use the original tokenizer
|
22 |
+
if ORIGINAL_TOKENIZER:
|
23 |
+
return original_tokenizer(text)
|
24 |
+
|
25 |
+
# else, use the new tokenizer
|
26 |
+
else:
|
27 |
+
|
28 |
+
# More general-purpose tokenizer that handles both natural paragraph text and structured reviews.
|
29 |
+
|
30 |
+
# Normalize long dashes
|
31 |
+
text = re.sub(r"[-]{2,}", "\n", text)
|
32 |
+
|
33 |
+
# Keep line breaks meaningful (but fallback to sentence splitting)
|
34 |
+
chunks = re.split(r"\n+", text)
|
35 |
+
sentences = []
|
36 |
+
|
37 |
+
for chunk in chunks:
|
38 |
+
chunk = chunk.strip()
|
39 |
+
if not chunk:
|
40 |
+
continue
|
41 |
+
|
42 |
+
# Section headers and bullets become single “sentences”
|
43 |
+
if re.match(r"^(Summary|Strengths?|Weaknesses?|Minor)\s*:?", chunk, re.IGNORECASE):
|
44 |
+
sentences.append(chunk)
|
45 |
+
continue
|
46 |
+
|
47 |
+
if re.match(r"^(\d+(\.\d+)*\.|-)\s+.+", chunk):
|
48 |
+
sentences.append(chunk)
|
49 |
+
continue
|
50 |
+
|
51 |
+
# Otherwise, apply SpaCy sentence splitting
|
52 |
+
doc = nlp(chunk)
|
53 |
+
sentences.extend([sent.text.strip() for sent in doc.sents if sent.text.strip()])
|
54 |
+
|
55 |
+
return sentences
|
56 |
+
|
57 |
+
# reuse the original glimpse tokenizer
|
58 |
+
# def glimpse_tokenizer(text: str) -> list:
|
59 |
+
# return tokenize_sentences(text)
|
60 |
+
|
61 |
+
# Default glimpse tokenizer from the original code
|
62 |
+
def original_tokenizer(text: str) -> list:
|
63 |
+
"""
|
64 |
+
Tokenizes the input text into sentences.
|
65 |
+
|
66 |
+
@param text: The input text to be tokenized
|
67 |
+
@return: A list of tokenized sentences
|
68 |
+
"""
|
69 |
+
text = text.replace('-----', '\n')
|
70 |
+
sentences = nltk.sent_tokenize(text)
|
71 |
+
# remove empty sentences
|
72 |
+
sentences = [sentence for sentence in sentences if sentence != ""]
|
73 |
+
|
74 |
+
return sentences
|
glimpse-ui/glimpse/glimpse/data_loading/data_processing.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import os
|
3 |
+
|
4 |
+
data_glimpse = "data/processed/"
|
5 |
+
if not os.path.exists(data_glimpse):
|
6 |
+
os.makedirs(data_glimpse)
|
7 |
+
|
8 |
+
for year in range (2017, 2021 + 1):
|
9 |
+
dataset = pd.read_csv(f"data/all_reviews_{year}.csv")
|
10 |
+
sub_dataset = dataset[['id','review', 'metareview']]
|
11 |
+
sub_dataset.rename(columns={"review": "text", "metareview": "gold"}, inplace=True)
|
12 |
+
|
13 |
+
sub_dataset.to_csv(f"{data_glimpse}all_reviews_{year}.csv", index=False)
|
14 |
+
|
15 |
+
|
glimpse-ui/glimpse/glimpse/data_loading/generate_abstractive_candidates.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from datasets import Dataset
|
7 |
+
from tqdm import tqdm
|
8 |
+
import datetime
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
12 |
+
|
13 |
+
GENERATION_CONFIGS = {
|
14 |
+
"top_p_sampling": {
|
15 |
+
"max_new_tokens": 200,
|
16 |
+
"do_sample": True,
|
17 |
+
"top_p": 0.95,
|
18 |
+
"temperature": 1.0,
|
19 |
+
"num_return_sequences": 8,
|
20 |
+
"num_beams" : 1,
|
21 |
+
|
22 |
+
#"num_beam_groups" : 4,
|
23 |
+
},
|
24 |
+
|
25 |
+
**{
|
26 |
+
f"sampling_topp_{str(topp).replace('.', '')}": {
|
27 |
+
"max_new_tokens": 200,
|
28 |
+
"do_sample": True,
|
29 |
+
"num_return_sequences": 8,
|
30 |
+
"top_p": 0.95,
|
31 |
+
}
|
32 |
+
for topp in [0.5, 0.8, 0.95, 0.99]
|
33 |
+
},
|
34 |
+
}
|
35 |
+
|
36 |
+
# add base.csv config to all configs
|
37 |
+
for key, value in GENERATION_CONFIGS.items():
|
38 |
+
GENERATION_CONFIGS[key] = {
|
39 |
+
# "max_length": 2048,
|
40 |
+
"min_length": 0,
|
41 |
+
"early_stopping": True,
|
42 |
+
**value,
|
43 |
+
}
|
44 |
+
|
45 |
+
|
46 |
+
def parse_args():
|
47 |
+
parser = argparse.ArgumentParser()
|
48 |
+
parser.add_argument("--model_name", type=str, default="facebook/bart-large-cnn")
|
49 |
+
parser.add_argument("--dataset_path", type=Path, default="data/processed/all_reviews_2017.csv")
|
50 |
+
parser.add_argument("--decoding_config", type=str, default="top_p_sampling", choices=GENERATION_CONFIGS.keys())
|
51 |
+
|
52 |
+
parser.add_argument("--batch_size", type=int, default=16)
|
53 |
+
parser.add_argument("--device", type=str, default="cuda")
|
54 |
+
parser.add_argument("--trimming", action=argparse.BooleanOptionalAction, default=True)
|
55 |
+
|
56 |
+
parser.add_argument("--output_dir", type=str, default="data/candidates")
|
57 |
+
|
58 |
+
# if ran in a scripted way, the output path will be printed
|
59 |
+
parser.add_argument("--scripted-run", action=argparse.BooleanOptionalAction, default=False)
|
60 |
+
|
61 |
+
# limit the number of samples to generate
|
62 |
+
parser.add_argument("--limit", type=int, default=None)
|
63 |
+
|
64 |
+
args = parser.parse_args()
|
65 |
+
|
66 |
+
return args
|
67 |
+
|
68 |
+
|
69 |
+
def prepare_dataset(dataset_path) -> Dataset:
|
70 |
+
try:
|
71 |
+
dataset = pd.read_csv(dataset_path)
|
72 |
+
except:
|
73 |
+
raise ValueError(f"Unknown dataset {dataset_path}")
|
74 |
+
|
75 |
+
# make a dataset from the dataframe
|
76 |
+
dataset = Dataset.from_pandas(dataset)
|
77 |
+
|
78 |
+
return dataset
|
79 |
+
|
80 |
+
|
81 |
+
def evaluate_summarizer(
|
82 |
+
model, tokenizer, dataset: Dataset, decoding_config, batch_size: int,
|
83 |
+
device: str, trimming: bool
|
84 |
+
) -> Dataset:
|
85 |
+
"""
|
86 |
+
@param model: The model used to generate the summaries
|
87 |
+
@param tokenizer: The tokenizer used to tokenize the text and the summary
|
88 |
+
@param dataset: A dataset with the text
|
89 |
+
@param decoding_config: Dictionary with the decoding config
|
90 |
+
@param batch_size: The batch size used to generate the summaries
|
91 |
+
@return: The same dataset with the summaries added
|
92 |
+
"""
|
93 |
+
# create a dataset with the text and the summary
|
94 |
+
|
95 |
+
# create a dataloader
|
96 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=trimming)
|
97 |
+
|
98 |
+
# generate summaries
|
99 |
+
summaries = []
|
100 |
+
print("Generating summaries...")
|
101 |
+
|
102 |
+
for batch in tqdm(dataloader):
|
103 |
+
text = batch["text"]
|
104 |
+
|
105 |
+
inputs = tokenizer(
|
106 |
+
text,
|
107 |
+
max_length=1024,
|
108 |
+
padding="max_length",
|
109 |
+
truncation=True,
|
110 |
+
return_tensors="pt",
|
111 |
+
)
|
112 |
+
|
113 |
+
# move inputs to device
|
114 |
+
inputs = {key: value.to(device) for key, value in inputs.items()}
|
115 |
+
|
116 |
+
# generate summaries
|
117 |
+
outputs = model.generate(
|
118 |
+
**inputs,
|
119 |
+
**decoding_config,
|
120 |
+
)
|
121 |
+
|
122 |
+
total_size = outputs.numel() # Total number of elements in the tensor
|
123 |
+
target_size = batch_size * outputs.shape[-1] # Target size of the last dimension
|
124 |
+
pad_size = (target_size - (total_size % target_size)) % target_size # Calculate the required padding size to make the total number of elements divisible by the target size
|
125 |
+
|
126 |
+
# Pad the tensor with zeros to make the total number of elements divisible by the target size
|
127 |
+
if not trimming and pad_size != 0: outputs = torch.nn.functional.pad(outputs, (0, 0, 0, pad_size // outputs.shape[-1]))
|
128 |
+
|
129 |
+
# output : (batch_size * num_return_sequences, max_length)
|
130 |
+
try:
|
131 |
+
outputs = outputs.reshape(batch_size, -1, outputs.shape[-1])
|
132 |
+
except Exception as e:
|
133 |
+
print(f"Error reshaping outputs: {e}")
|
134 |
+
raise ValueError(f"Cannot reshape tensor of size {outputs.numel()} into shape "
|
135 |
+
f"({batch_size}, -1, {outputs.shape[-1]}).")
|
136 |
+
|
137 |
+
# decode summaries
|
138 |
+
for b in range(batch_size):
|
139 |
+
summaries.append(
|
140 |
+
[
|
141 |
+
tokenizer.decode(
|
142 |
+
outputs[b, i],
|
143 |
+
skip_special_tokens=True,
|
144 |
+
)
|
145 |
+
for i in range(outputs.shape[1])
|
146 |
+
]
|
147 |
+
)
|
148 |
+
|
149 |
+
# if trimming the last batch, remove them from the dataset
|
150 |
+
if trimming: dataset = dataset.select(range(len(summaries)))
|
151 |
+
|
152 |
+
# add summaries to the huggingface dataset
|
153 |
+
dataset = dataset.map(lambda example: {"summary": summaries.pop(0)})
|
154 |
+
|
155 |
+
return dataset
|
156 |
+
|
157 |
+
|
158 |
+
def sanitize_model_name(model_name: str) -> str:
|
159 |
+
"""
|
160 |
+
Sanitize the model name to be used as a folder name.
|
161 |
+
@param model_name: The model name
|
162 |
+
@return: The sanitized model name
|
163 |
+
"""
|
164 |
+
return model_name.replace("/", "_")
|
165 |
+
|
166 |
+
|
167 |
+
def main():
|
168 |
+
args = parse_args()
|
169 |
+
|
170 |
+
# load the model
|
171 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
172 |
+
args.model_name
|
173 |
+
)
|
174 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
175 |
+
|
176 |
+
tokenizer.pad_token = tokenizer.unk_token
|
177 |
+
tokenizer.pad_token_id = tokenizer.unk_token_id
|
178 |
+
|
179 |
+
# move model to device
|
180 |
+
model = model.to(args.device)
|
181 |
+
|
182 |
+
# load the dataset
|
183 |
+
print("Loading dataset...")
|
184 |
+
dataset = prepare_dataset(args.dataset_path)
|
185 |
+
|
186 |
+
# limit the number of samples
|
187 |
+
if args.limit is not None:
|
188 |
+
_lim = min(args.limit, len(dataset))
|
189 |
+
dataset = dataset.select(range(_lim))
|
190 |
+
|
191 |
+
# generate summaries
|
192 |
+
dataset = evaluate_summarizer(
|
193 |
+
model,
|
194 |
+
tokenizer,
|
195 |
+
dataset,
|
196 |
+
GENERATION_CONFIGS[args.decoding_config],
|
197 |
+
args.batch_size,
|
198 |
+
args.device,
|
199 |
+
args.trimming,
|
200 |
+
)
|
201 |
+
|
202 |
+
df_dataset = dataset.to_pandas()
|
203 |
+
df_dataset = df_dataset.explode('summary')
|
204 |
+
df_dataset = df_dataset.reset_index()
|
205 |
+
# add an idx with the id of the summary for each example
|
206 |
+
df_dataset['id_candidate'] = df_dataset.groupby(['index']).cumcount()
|
207 |
+
|
208 |
+
# save the dataset
|
209 |
+
# add unique date in name
|
210 |
+
now = datetime.datetime.now()
|
211 |
+
date = now.strftime("%Y-%m-%d-%H-%M-%S")
|
212 |
+
model_name = sanitize_model_name(args.model_name)
|
213 |
+
padding_status = "trimmed" if args.trimming else "padded"
|
214 |
+
output_path = (
|
215 |
+
Path(args.output_dir)
|
216 |
+
/ f"{model_name}-_-{args.dataset_path.stem}-_-{args.decoding_config}-_-{padding_status}-_-{date}.csv"
|
217 |
+
)
|
218 |
+
|
219 |
+
# create output dir if it doesn't exist
|
220 |
+
if not output_path.parent.exists():
|
221 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
222 |
+
|
223 |
+
df_dataset.to_csv(output_path, index=False, encoding="utf-8")
|
224 |
+
|
225 |
+
# in case of scripted run, print the output path
|
226 |
+
if args.scripted_run: print(output_path)
|
227 |
+
|
228 |
+
|
229 |
+
if __name__ == "__main__":
|
230 |
+
main()
|
glimpse-ui/glimpse/glimpse/data_loading/generate_extractive_candidates.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
from datasets import Dataset
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
import nltk
|
10 |
+
|
11 |
+
import sys, os.path
|
12 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
|
13 |
+
|
14 |
+
from glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
|
15 |
+
|
16 |
+
# def tokenize_sentences(text: str) -> list:
|
17 |
+
# """
|
18 |
+
# Tokenizes the input text into sentences.
|
19 |
+
|
20 |
+
# @param text: The input text to be tokenized
|
21 |
+
# @return: A list of tokenized sentences
|
22 |
+
# """
|
23 |
+
# text = text.replace('-----', '\n')
|
24 |
+
# sentences = nltk.sent_tokenize(text)
|
25 |
+
# # remove empty sentences
|
26 |
+
# sentences = [sentence for sentence in sentences if sentence != ""]
|
27 |
+
|
28 |
+
# return sentences
|
29 |
+
|
30 |
+
def parse_args():
|
31 |
+
parser = argparse.ArgumentParser()
|
32 |
+
|
33 |
+
parser.add_argument("--dataset_path", type=Path, default="glimpse/data/processed/all_reviews_2017.csv")
|
34 |
+
parser.add_argument("--output_dir", type=str, default="glimpse/data/candidates")
|
35 |
+
|
36 |
+
# if ran in a scripted way, the output path will be printed
|
37 |
+
parser.add_argument("--scripted-run", action=argparse.BooleanOptionalAction, default=False)
|
38 |
+
|
39 |
+
# limit the number of samples to generate
|
40 |
+
parser.add_argument("--limit", type=int, default=None)
|
41 |
+
|
42 |
+
args = parser.parse_args()
|
43 |
+
|
44 |
+
return args
|
45 |
+
|
46 |
+
|
47 |
+
def prepare_dataset(dataset_path) -> Dataset:
|
48 |
+
|
49 |
+
try:
|
50 |
+
dataset = pd.read_csv(dataset_path)
|
51 |
+
except:
|
52 |
+
raise ValueError(f"Unknown dataset {dataset_path}")
|
53 |
+
|
54 |
+
# make a dataset from the dataframe
|
55 |
+
dataset = Dataset.from_pandas(dataset)
|
56 |
+
|
57 |
+
return dataset
|
58 |
+
|
59 |
+
|
60 |
+
def evaluate_summarizer(dataset: Dataset) -> Dataset:
|
61 |
+
"""
|
62 |
+
@param dataset: A dataset with the text
|
63 |
+
@return: The same dataset with the summaries added
|
64 |
+
"""
|
65 |
+
# create a dataset with the text and the summary
|
66 |
+
|
67 |
+
# create a dataloader
|
68 |
+
|
69 |
+
# generate summaries
|
70 |
+
summaries = []
|
71 |
+
print("Generating summaries...")
|
72 |
+
|
73 |
+
# (tqdm library for progress bar)
|
74 |
+
for sample in tqdm(dataset):
|
75 |
+
text = sample["text"]
|
76 |
+
|
77 |
+
sentences = glimpse_tokenizer(text)
|
78 |
+
|
79 |
+
summaries.append(sentences)
|
80 |
+
|
81 |
+
# add summaries to the huggingface dataset
|
82 |
+
dataset = dataset.map(lambda example: {"summary": summaries.pop(0)})
|
83 |
+
|
84 |
+
return dataset
|
85 |
+
|
86 |
+
|
87 |
+
def main():
|
88 |
+
args = parse_args()
|
89 |
+
# load the dataset
|
90 |
+
print("Loading dataset...")
|
91 |
+
dataset = prepare_dataset(args.dataset_path)
|
92 |
+
|
93 |
+
# limit the number of samples
|
94 |
+
if args.limit is not None:
|
95 |
+
_lim = min(args.limit, len(dataset))
|
96 |
+
dataset = dataset.select(range(_lim))
|
97 |
+
|
98 |
+
# generate summaries
|
99 |
+
dataset = evaluate_summarizer(
|
100 |
+
dataset,
|
101 |
+
)
|
102 |
+
|
103 |
+
df_dataset = dataset.to_pandas()
|
104 |
+
df_dataset = df_dataset.explode("summary")
|
105 |
+
df_dataset = df_dataset.reset_index()
|
106 |
+
# add an idx with the id of the summary for each example
|
107 |
+
df_dataset["id_candidate"] = df_dataset.groupby(["index"]).cumcount()
|
108 |
+
|
109 |
+
# save the dataset
|
110 |
+
# add unique date in name
|
111 |
+
now = datetime.datetime.now()
|
112 |
+
date = now.strftime("%Y-%m-%d-%H-%M-%S")
|
113 |
+
output_path = (
|
114 |
+
Path(args.output_dir)
|
115 |
+
/ f"extractive_sentences-_-{args.dataset_path.stem}-_-none-_-{date}.csv"
|
116 |
+
)
|
117 |
+
|
118 |
+
# create output dir if it doesn't exist
|
119 |
+
if not output_path.parent.exists():
|
120 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
121 |
+
|
122 |
+
df_dataset.to_csv(output_path, index=False, encoding="utf-8")
|
123 |
+
|
124 |
+
# in case of scripted run, print the output path
|
125 |
+
if args.scripted_run: print(output_path)
|
126 |
+
|
127 |
+
|
128 |
+
if __name__ == "__main__":
|
129 |
+
main()
|
glimpse-ui/glimpse/glimpse/evaluate/Evaluate informativeness.ipynb
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "38be7bab-8a42-49dd-8976-2755ee84edbe",
|
7 |
+
"metadata": {
|
8 |
+
"tags": []
|
9 |
+
},
|
10 |
+
"outputs": [],
|
11 |
+
"source": [
|
12 |
+
"import pandas as pd\n",
|
13 |
+
"import numpy as np\n",
|
14 |
+
"from pathlib import Path\n",
|
15 |
+
"import pickle as pk\n",
|
16 |
+
"import nltk\n",
|
17 |
+
"import seaborn as sns\n",
|
18 |
+
"import matplotlib.pyplot as plt\n",
|
19 |
+
"\n",
|
20 |
+
"\n",
|
21 |
+
"\n",
|
22 |
+
"export_summaries_path = Path('output/summaries/methods_per_text')\n",
|
23 |
+
"export_summaries_path.mkdir(parents=True, exist_ok=True)"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": null,
|
29 |
+
"id": "1c24ddc1-978e-4a24-8fdb-e48b2848edd0",
|
30 |
+
"metadata": {
|
31 |
+
"tags": []
|
32 |
+
},
|
33 |
+
"outputs": [],
|
34 |
+
"source": [
|
35 |
+
"\n",
|
36 |
+
"dfs = []\n",
|
37 |
+
"for file in export_summaries_path.glob('*.csv'):\n",
|
38 |
+
" df = pd.read_csv(file)\n",
|
39 |
+
" generation_method, dataset = file.stem.split('-_-')[:2]\n",
|
40 |
+
" \n",
|
41 |
+
" df['metadata/Generation'] = generation_method\n",
|
42 |
+
" df['metadata/Dataset'] = dataset\n",
|
43 |
+
" \n",
|
44 |
+
" dfs.append(df)\n",
|
45 |
+
" \n",
|
46 |
+
"df = pd.concat(dfs)\n",
|
47 |
+
"\n",
|
48 |
+
"df = df.drop([c for c in df.columns if \"Unnamed:\" in c], axis=1)\n",
|
49 |
+
"\n",
|
50 |
+
"del dfs\n",
|
51 |
+
"\n",
|
52 |
+
"def replace_abstractive(x):\n",
|
53 |
+
" if \"abstractive\" in x:\n",
|
54 |
+
" return \"extractive_sentences\"\n",
|
55 |
+
" else:\n",
|
56 |
+
" return x\n",
|
57 |
+
"\n",
|
58 |
+
"df['metadata/Generation'] = df['metadata/Generation'].apply(replace_abstractive)\n",
|
59 |
+
"\n"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "code",
|
64 |
+
"execution_count": null,
|
65 |
+
"id": "582af03f-ac60-4644-88d7-28caf84bb552",
|
66 |
+
"metadata": {
|
67 |
+
"tags": []
|
68 |
+
},
|
69 |
+
"outputs": [],
|
70 |
+
"source": [
|
71 |
+
"df = df[~(df['Method'].str.contains('Lead')).fillna(False)]\n",
|
72 |
+
"\n"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"id": "099fda3b-137c-4e3b-bb70-933ad183d378",
|
79 |
+
"metadata": {
|
80 |
+
"tags": []
|
81 |
+
},
|
82 |
+
"outputs": [],
|
83 |
+
"source": []
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"cell_type": "code",
|
87 |
+
"execution_count": null,
|
88 |
+
"id": "7d9babd3-a409-409c-848e-db6a3f846d6f",
|
89 |
+
"metadata": {
|
90 |
+
"tags": []
|
91 |
+
},
|
92 |
+
"outputs": [],
|
93 |
+
"source": [
|
94 |
+
"ddf = df.copy()\n",
|
95 |
+
"ddf['proba_of_success'] = ddf['proba_of_success'].apply(np.exp)\n",
|
96 |
+
"\n",
|
97 |
+
"discriminativity = ddf.groupby(['metadata/Generation', 'metadata/reranking_model', 'Method'])[['proba_of_success', 'LM Perplexity']].agg(['mean']).droplevel(1, axis=1).sort_values('proba_of_success', ascending=False).reset_index()\n",
|
98 |
+
"discriminativity"
|
99 |
+
]
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"execution_count": null,
|
104 |
+
"id": "d1bffd58-0eed-4007-bcca-cd0a6dfdcd1d",
|
105 |
+
"metadata": {
|
106 |
+
"tags": []
|
107 |
+
},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"\n",
|
111 |
+
"ddf = ddf.sort_values('proba_of_success', ascending=False)\n",
|
112 |
+
"sns.catplot(data=ddf, y='proba_of_success', x='Method', hue=\"metadata/Generation\", kind='bar', col='metadata/reranking_model')"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"id": "6684f99b-68d8-4a4f-a1f7-8943f9755bb7",
|
119 |
+
"metadata": {
|
120 |
+
"tags": []
|
121 |
+
},
|
122 |
+
"outputs": [],
|
123 |
+
"source": [
|
124 |
+
"\n",
|
125 |
+
"fig, ax = plt.subplots(1, 1)\n",
|
126 |
+
"\n",
|
127 |
+
"paretto = get_pareto_points(discriminativity[['proba_of_success', 'LM Perplexity']].values)\n",
|
128 |
+
"\n",
|
129 |
+
"ax.plot(paretto[:, 1], paretto[:, 0], c='purple', linewidth=5, linestyle=\"--\", label=\"Pareto front\")\n",
|
130 |
+
"sns.scatterplot(data=discriminativity, y='proba_of_success', x='LM Perplexity', hue=\"Method\", s=200, alpha=0.8, style='metadata/reranking_model')\n",
|
131 |
+
"plt.xlim(-70, 0)\n",
|
132 |
+
"\n",
|
133 |
+
" \n",
|
134 |
+
"\n",
|
135 |
+
"def get_pareto_points(data):\n",
|
136 |
+
" # data : [N, 2]\n",
|
137 |
+
" \n",
|
138 |
+
" optima = []\n",
|
139 |
+
" for p in data:\n",
|
140 |
+
" x, y = p\n",
|
141 |
+
" if len([p2 for p2 in data if p2[0] > p[0] and p2[1] > p[1]]) == 0:\n",
|
142 |
+
" optima.append(p)\n",
|
143 |
+
" \n",
|
144 |
+
" return np.array(optima)\n",
|
145 |
+
"\n",
|
146 |
+
"\n",
|
147 |
+
"\n",
|
148 |
+
" \n",
|
149 |
+
"\n",
|
150 |
+
" \n",
|
151 |
+
" \n",
|
152 |
+
" "
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": null,
|
158 |
+
"id": "6f6d6a10-e13e-4d44-ace3-dae5de0f362c",
|
159 |
+
"metadata": {
|
160 |
+
"tags": []
|
161 |
+
},
|
162 |
+
"outputs": [],
|
163 |
+
"source": []
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": null,
|
168 |
+
"id": "6f2b4cc9-8584-463e-b130-6c48c85e2665",
|
169 |
+
"metadata": {
|
170 |
+
"tags": []
|
171 |
+
},
|
172 |
+
"outputs": [],
|
173 |
+
"source": []
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": null,
|
178 |
+
"id": "d49c5df4-23a4-473c-81a6-b8928ffdf8af",
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [],
|
181 |
+
"source": [
|
182 |
+
"\n",
|
183 |
+
"ddf = df.copy().drop('level_0', axis=1)\n",
|
184 |
+
"ddf['proba_of_success'] = -ddf['proba_of_success']\n",
|
185 |
+
"ddf['LM Perplexity'] = -ddf['LM Perplexity']\n",
|
186 |
+
"ddf = ddf.reset_index()\n",
|
187 |
+
"\n",
|
188 |
+
"\n",
|
189 |
+
"\n",
|
190 |
+
"\n",
|
191 |
+
"sns.displot(data=ddf, x='LM Perplexity', y='proba_of_success', hue=\"Method\", kind='kde')\n",
|
192 |
+
"# plt.xscale('log')\n",
|
193 |
+
"#plt.yscale('log')\n",
|
194 |
+
"\n",
|
195 |
+
"plt.xlim(0, 100)\n",
|
196 |
+
"plt.ylim(-2, 3)\n",
|
197 |
+
"\n"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": null,
|
203 |
+
"id": "d1d1235a-7bff-4e51-ad43-20dd5d1a0734",
|
204 |
+
"metadata": {
|
205 |
+
"tags": []
|
206 |
+
},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"sns.scatterplot(data=df, x='LM Perplexity', y='proba_of_success', hue=\"Method\")\n"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": null,
|
215 |
+
"id": "faf6ec08-feca-4106-bf51-7a3ffb438267",
|
216 |
+
"metadata": {},
|
217 |
+
"outputs": [],
|
218 |
+
"source": []
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"cell_type": "code",
|
222 |
+
"execution_count": null,
|
223 |
+
"id": "e0805bfe-660f-4d1f-a28a-cf09220b5da3",
|
224 |
+
"metadata": {},
|
225 |
+
"outputs": [],
|
226 |
+
"source": []
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": null,
|
231 |
+
"id": "5235f917-81b5-40ca-9c62-cb63a8aaaffb",
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [],
|
234 |
+
"source": []
|
235 |
+
}
|
236 |
+
],
|
237 |
+
"metadata": {
|
238 |
+
"kernelspec": {
|
239 |
+
"display_name": "pytorch-gpu-2.0.0_py3.10.9",
|
240 |
+
"language": "python",
|
241 |
+
"name": "module-conda-env-pytorch-gpu-2.0.0_py3.10.9"
|
242 |
+
},
|
243 |
+
"language_info": {
|
244 |
+
"codemirror_mode": {
|
245 |
+
"name": "ipython",
|
246 |
+
"version": 3
|
247 |
+
},
|
248 |
+
"file_extension": ".py",
|
249 |
+
"mimetype": "text/x-python",
|
250 |
+
"name": "python",
|
251 |
+
"nbconvert_exporter": "python",
|
252 |
+
"pygments_lexer": "ipython3",
|
253 |
+
"version": "3.10.9"
|
254 |
+
}
|
255 |
+
},
|
256 |
+
"nbformat": 4,
|
257 |
+
"nbformat_minor": 5
|
258 |
+
}
|
glimpse-ui/glimpse/glimpse/evaluate/evaluate_bartbert_metrics.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
from bert_score import BERTScorer
|
7 |
+
|
8 |
+
def sanitize_model_name(model_name: str) -> str:
|
9 |
+
"""
|
10 |
+
Sanitize the model name to be used as a folder name.
|
11 |
+
@param model_name: The model name
|
12 |
+
@return: The sanitized model name
|
13 |
+
"""
|
14 |
+
return model_name.replace("/", "_")
|
15 |
+
|
16 |
+
# logging.basicConfig(stream=stdout, level=logging.)
|
17 |
+
def parse_args():
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument("--summaries", type=Path, default="")
|
20 |
+
|
21 |
+
# device
|
22 |
+
parser.add_argument("--device", type=str, default="cuda")
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
return args
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
def parse_summaries(path: Path):
|
30 |
+
"""
|
31 |
+
:return: a pandas dataframe with at least the columns 'text' and 'summary'
|
32 |
+
"""
|
33 |
+
# read csv file
|
34 |
+
|
35 |
+
df = pd.read_csv(path).dropna()
|
36 |
+
|
37 |
+
|
38 |
+
# check if the csv file has the correct columns
|
39 |
+
if not all([col in df.columns for col in ["gold", "summary"]]):
|
40 |
+
raise ValueError("The csv file must have the columns 'text' and 'summary'.")
|
41 |
+
|
42 |
+
return df
|
43 |
+
|
44 |
+
|
45 |
+
def evaluate_bartbert(df, device="cuda"):
|
46 |
+
# make a list of the tuples (text, summary)
|
47 |
+
|
48 |
+
# texts = df.text.tolist()
|
49 |
+
texts = df.gold.tolist()
|
50 |
+
summaries = df.summary.tolist()
|
51 |
+
|
52 |
+
scorer = BERTScorer(lang="en", rescale_with_baseline=True, device=device)
|
53 |
+
|
54 |
+
metrics = {'BERTScore': []}
|
55 |
+
for i in range(len(texts)):
|
56 |
+
texts[i] = texts[i].replace("\n", " ")
|
57 |
+
summaries[i] = summaries[i].replace("\n", " ")
|
58 |
+
|
59 |
+
P, R, F1 = scorer.score([summaries[i]], [texts[i]])
|
60 |
+
|
61 |
+
metrics['BERTScore'].append(F1.mean().item())
|
62 |
+
|
63 |
+
# compute the mean of the metrics
|
64 |
+
# metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
|
65 |
+
|
66 |
+
return metrics
|
67 |
+
|
68 |
+
|
69 |
+
def main():
|
70 |
+
args = parse_args()
|
71 |
+
|
72 |
+
path = args.summaries
|
73 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
74 |
+
|
75 |
+
# load the model
|
76 |
+
df = parse_summaries(args.summaries)
|
77 |
+
|
78 |
+
metrics = evaluate_bartbert(df)
|
79 |
+
|
80 |
+
# make a dataframe with the metric
|
81 |
+
df = pd.DataFrame(metrics)
|
82 |
+
|
83 |
+
# Add the model name in the metrics names
|
84 |
+
df = df.add_prefix(f"common/")
|
85 |
+
|
86 |
+
# save the dataframe
|
87 |
+
|
88 |
+
# check if exists already, if it does load it and add the new columns
|
89 |
+
|
90 |
+
print(df)
|
91 |
+
|
92 |
+
if path.exists():
|
93 |
+
df_old = pd.read_csv(path, index_col=0)
|
94 |
+
|
95 |
+
# create the colums if they do not exist
|
96 |
+
for col in df.columns:
|
97 |
+
if col not in df_old.columns:
|
98 |
+
df_old[col] = float("nan")
|
99 |
+
|
100 |
+
# add entry to the dataframe
|
101 |
+
for col in df.columns:
|
102 |
+
df_old[col] = df[col]
|
103 |
+
|
104 |
+
df = df_old
|
105 |
+
|
106 |
+
df.to_csv(path)
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
main()
|
glimpse-ui/glimpse/glimpse/evaluate/evaluate_common_metrics_samples.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from rouge_score import rouge_scorer
|
6 |
+
|
7 |
+
|
8 |
+
def sanitize_model_name(model_name: str) -> str:
|
9 |
+
"""
|
10 |
+
Sanitize the model name to be used as a folder name.
|
11 |
+
@param model_name: The model name
|
12 |
+
@return: The sanitized model name
|
13 |
+
"""
|
14 |
+
return model_name.replace("/", "_")
|
15 |
+
|
16 |
+
|
17 |
+
# logging.basicConfig(stream=stdout, level=logging.)
|
18 |
+
def parse_args():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument("--summaries", type=Path, default="")
|
21 |
+
|
22 |
+
args = parser.parse_args()
|
23 |
+
return args
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
def parse_summaries(path: Path):
|
28 |
+
"""
|
29 |
+
:return: a pandas dataframe with at least the columns 'text' and 'summary'
|
30 |
+
"""
|
31 |
+
# read csv file
|
32 |
+
|
33 |
+
df = pd.read_csv(path).dropna()
|
34 |
+
|
35 |
+
# check if the csv file has the correct columns
|
36 |
+
if not all([col in df.columns for col in ["gold", "summary"]]):
|
37 |
+
raise ValueError("The csv file must have the columns 'text' and 'summary'.")
|
38 |
+
|
39 |
+
return df
|
40 |
+
|
41 |
+
|
42 |
+
def evaluate_rouge(
|
43 |
+
df,
|
44 |
+
):
|
45 |
+
# make a list of the tuples (text, summary)
|
46 |
+
|
47 |
+
texts = df.gold.tolist()
|
48 |
+
summaries = df.summary.tolist()
|
49 |
+
|
50 |
+
# rouges
|
51 |
+
metrics = {"rouge1": [], "rouge2": [], "rougeL": [], "rougeLsum": []}
|
52 |
+
|
53 |
+
rouges = rouge_scorer.RougeScorer(
|
54 |
+
["rouge1", "rouge2", "rougeL", "rougeLsum"], use_stemmer=True
|
55 |
+
)
|
56 |
+
|
57 |
+
metrics["rouge1"].extend(
|
58 |
+
[
|
59 |
+
rouges.score(summary, text)["rouge1"].fmeasure
|
60 |
+
for summary, text in zip(summaries, texts)
|
61 |
+
]
|
62 |
+
)
|
63 |
+
metrics["rouge2"].extend(
|
64 |
+
[
|
65 |
+
rouges.score(summary, text)["rouge2"].fmeasure
|
66 |
+
for summary, text in zip(summaries, texts)
|
67 |
+
]
|
68 |
+
)
|
69 |
+
metrics["rougeL"].extend(
|
70 |
+
[
|
71 |
+
rouges.score(summary, text)["rougeL"].fmeasure
|
72 |
+
for summary, text in zip(summaries, texts)
|
73 |
+
]
|
74 |
+
)
|
75 |
+
metrics["rougeLsum"].extend(
|
76 |
+
[
|
77 |
+
rouges.score(summary, text)["rougeLsum"].fmeasure
|
78 |
+
for summary, text in zip(summaries, texts)
|
79 |
+
]
|
80 |
+
)
|
81 |
+
|
82 |
+
# compute the mean of the metrics
|
83 |
+
# metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
|
84 |
+
|
85 |
+
return metrics
|
86 |
+
|
87 |
+
|
88 |
+
def main():
|
89 |
+
args = parse_args()
|
90 |
+
|
91 |
+
# load the model
|
92 |
+
df = parse_summaries(args.summaries)
|
93 |
+
|
94 |
+
metrics = evaluate_rouge(df)
|
95 |
+
|
96 |
+
|
97 |
+
# # add index to the metrics
|
98 |
+
# metrics["index"] = [i for i in range(len(df))]
|
99 |
+
|
100 |
+
df = pd.DataFrame.from_dict(metrics)
|
101 |
+
df = df.add_prefix(f"common/")
|
102 |
+
|
103 |
+
# merge the metrics with the summaries
|
104 |
+
|
105 |
+
if args.summaries.exists():
|
106 |
+
df_old = parse_summaries(args.summaries)
|
107 |
+
|
108 |
+
for col in df.columns:
|
109 |
+
if col not in df_old.columns:
|
110 |
+
df_old[col] = float("nan")
|
111 |
+
|
112 |
+
# add entry to the dataframe
|
113 |
+
for col in df.columns:
|
114 |
+
df_old[col] = df[col]
|
115 |
+
|
116 |
+
df = df_old
|
117 |
+
|
118 |
+
df.to_csv(args.summaries, index=False)
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
main()
|
glimpse-ui/glimpse/glimpse/evaluate/evaluate_seahorse_metrics_samples.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.utils.data
|
8 |
+
from tqdm import tqdm
|
9 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
10 |
+
|
11 |
+
map_questionnumber_to_question = {
|
12 |
+
"question1": "SHMetric/Comprehensible",
|
13 |
+
"question2": "SHMetric/Repetition",
|
14 |
+
"question3": "SHMetric/Grammar",
|
15 |
+
"question4": "SHMetric/Attribution",
|
16 |
+
"question5": "SHMetric/Main ideas",
|
17 |
+
"question6": "SHMetric/Conciseness",
|
18 |
+
}
|
19 |
+
|
20 |
+
def sanitize_model_name(model_name: str) -> str:
|
21 |
+
"""
|
22 |
+
Sanitize the model name to be used as a folder name.
|
23 |
+
@param model_name: The model name
|
24 |
+
@return: The sanitized model name
|
25 |
+
"""
|
26 |
+
return model_name.replace("/", "_")
|
27 |
+
|
28 |
+
# logging.basicConfig(stream=stdout, level=logging.)
|
29 |
+
def parse_args():
|
30 |
+
parser = argparse.ArgumentParser()
|
31 |
+
parser.add_argument(
|
32 |
+
"--question",
|
33 |
+
type=str,
|
34 |
+
default="repetition",
|
35 |
+
)
|
36 |
+
parser.add_argument("--summaries", type=Path, default="")
|
37 |
+
parser.add_argument("--select", type=str, default="*")
|
38 |
+
parser.add_argument("--batch_size", type=int, default=16)
|
39 |
+
parser.add_argument("--device", type=str, default="cuda")
|
40 |
+
|
41 |
+
|
42 |
+
args = parser.parse_args()
|
43 |
+
return args
|
44 |
+
|
45 |
+
def parse_summaries(path: Path):
|
46 |
+
"""
|
47 |
+
:return: a pandas dataframe with at least the columns 'text' and 'summary'
|
48 |
+
"""
|
49 |
+
# read csv file
|
50 |
+
|
51 |
+
df = pd.read_csv(path).dropna()
|
52 |
+
|
53 |
+
# check if the csv file has the correct columns
|
54 |
+
if not all([col in df.columns for col in ["text", "summary"]]):
|
55 |
+
raise ValueError("The csv file must have the columns 'text' and 'summary'.")
|
56 |
+
|
57 |
+
return df
|
58 |
+
|
59 |
+
|
60 |
+
def evaluate_classification_task(model, tokenizer, question, df, batch_size):
|
61 |
+
|
62 |
+
texts = df.text.tolist()
|
63 |
+
summaries = df.summary.tolist()
|
64 |
+
|
65 |
+
template = "premise: {premise} hypothesis: {hypothesis}"
|
66 |
+
ds = [template.format(premise=text[:20*1024], hypothesis=summary) for text, summary in zip(texts, summaries)]
|
67 |
+
|
68 |
+
|
69 |
+
eval_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size)
|
70 |
+
|
71 |
+
metrics = {f"{question}/proba_1": [], f"{question}/proba_0": [], f"{question}/guess": []}
|
72 |
+
|
73 |
+
with torch.no_grad():
|
74 |
+
for batch in tqdm(eval_loader):
|
75 |
+
# tokenize the batch
|
76 |
+
inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
|
77 |
+
# move the inputs to the device
|
78 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
79 |
+
|
80 |
+
N_inputs = inputs["input_ids"].shape[0]
|
81 |
+
# make decoder inputs to be <pad>
|
82 |
+
decoder_input_ids = torch.full((N_inputs, 1), tokenizer.pad_token_id, dtype=torch.long, device=model.device)
|
83 |
+
|
84 |
+
outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
|
85 |
+
logits = outputs.logits
|
86 |
+
# retrieve logits for the last token and the scores for 0 and 1
|
87 |
+
logits = logits[:, -1, [497, 333]]
|
88 |
+
|
89 |
+
# compute the probabilities
|
90 |
+
probs = F.softmax(logits, dim=-1)
|
91 |
+
|
92 |
+
# compute the guess
|
93 |
+
guess = probs.argmax(dim=-1)
|
94 |
+
|
95 |
+
# append the metrics
|
96 |
+
metrics[f"{question}/proba_1"].extend(probs[:, 1].tolist())
|
97 |
+
metrics[f"{question}/proba_0"].extend(probs[:, 0].tolist())
|
98 |
+
metrics[f"{question}/guess"].extend(guess.tolist())
|
99 |
+
|
100 |
+
# average the metrics
|
101 |
+
|
102 |
+
# metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
|
103 |
+
|
104 |
+
return metrics
|
105 |
+
|
106 |
+
def main():
|
107 |
+
args = parse_args()
|
108 |
+
|
109 |
+
model_name = f"google/seahorse-large-q{args.question}"
|
110 |
+
question = map_questionnumber_to_question[f"question{args.question}"]
|
111 |
+
|
112 |
+
# load the model
|
113 |
+
# load in float16 to save memory
|
114 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.float16)
|
115 |
+
|
116 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
117 |
+
|
118 |
+
df = parse_summaries(args.summaries)
|
119 |
+
|
120 |
+
metrics = evaluate_classification_task(model, tokenizer, question, df, args.batch_size)
|
121 |
+
|
122 |
+
# make a dataframe with the metric
|
123 |
+
df_metrics = pd.DataFrame(metrics)
|
124 |
+
|
125 |
+
# merge the metrics with the summaries
|
126 |
+
df = parse_summaries(args.summaries)
|
127 |
+
df = pd.concat([df, df_metrics], axis=1)
|
128 |
+
|
129 |
+
path = Path(args.summaries)
|
130 |
+
|
131 |
+
if path.exists():
|
132 |
+
df_old = pd.read_csv(path, index_col=0)
|
133 |
+
|
134 |
+
# create the colums if they do not exist
|
135 |
+
for col in df.columns:
|
136 |
+
if col not in df_old.columns:
|
137 |
+
df_old[col] = float("nan")
|
138 |
+
|
139 |
+
# add entry to the dataframe
|
140 |
+
for col in df.columns:
|
141 |
+
df_old[col] = df[col]
|
142 |
+
|
143 |
+
df = df_old
|
144 |
+
|
145 |
+
# save the dataframe
|
146 |
+
df.to_csv(args.summaries)
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == "__main__":
|
150 |
+
main()
|
glimpse-ui/glimpse/glimpse/src/beam_rsa_decoding.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
from datasets import Dataset
|
7 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
8 |
+
|
9 |
+
from rsasumm.beam_search import RSAContextualDecoding
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
GENERATION_CONFIGS = {
|
13 |
+
"top_p_sampling": {
|
14 |
+
"max_new_tokens": 200,
|
15 |
+
"do_sample": True,
|
16 |
+
"top_p": 0.95,
|
17 |
+
"temperature": 1.0,
|
18 |
+
"num_return_sequences": 8,
|
19 |
+
"num_beams": 1,
|
20 |
+
# "num_beam_groups" : 4,
|
21 |
+
},
|
22 |
+
**{
|
23 |
+
f"sampling_topp_{str(topp).replace('.', '')}": {
|
24 |
+
"max_new_tokens": 200,
|
25 |
+
"do_sample": True,
|
26 |
+
"num_return_sequences": 8,
|
27 |
+
"top_p": 0.95,
|
28 |
+
}
|
29 |
+
for topp in [0.5, 0.8, 0.95, 0.99]
|
30 |
+
},
|
31 |
+
}
|
32 |
+
|
33 |
+
# add base.csv config to all configs
|
34 |
+
for key, value in GENERATION_CONFIGS.items():
|
35 |
+
GENERATION_CONFIGS[key] = {
|
36 |
+
# "max_length": 2048,
|
37 |
+
"min_length": 0,
|
38 |
+
"early_stopping": True,
|
39 |
+
**value,
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
def parse_args():
|
44 |
+
parser = argparse.ArgumentParser()
|
45 |
+
parser.add_argument("--model_name", type=str, default="facebook/bart-large-cnn")
|
46 |
+
parser.add_argument("--dataset_name", type=str, default="amazon")
|
47 |
+
parser.add_argument("--dataset_path", type=str, default=None)
|
48 |
+
parser.add_argument(
|
49 |
+
"--decoding_config",
|
50 |
+
type=str,
|
51 |
+
default="top_p_sampling",
|
52 |
+
choices=GENERATION_CONFIGS.keys(),
|
53 |
+
)
|
54 |
+
|
55 |
+
parser.add_argument("--batch_size", type=int, default=16)
|
56 |
+
parser.add_argument("--device", type=str, default="cuda")
|
57 |
+
|
58 |
+
parser.add_argument("--output_dir", type=str, default="output")
|
59 |
+
|
60 |
+
# limit the number of samples to generate
|
61 |
+
parser.add_argument("--limit", type=int, default=None)
|
62 |
+
|
63 |
+
args = parser.parse_args()
|
64 |
+
|
65 |
+
return args
|
66 |
+
|
67 |
+
|
68 |
+
def prepare_dataset(dataset_name, dataset_path=None) -> Dataset:
|
69 |
+
dataset_path = Path(dataset_path)
|
70 |
+
if dataset_name == "amazon":
|
71 |
+
dataset = pd.read_csv(dataset_path / "amazon_test.csv")
|
72 |
+
elif dataset_name == "space":
|
73 |
+
dataset = pd.read_csv(dataset_path / "space.csv")
|
74 |
+
elif dataset_name == "yelp":
|
75 |
+
dataset = pd.read_csv(dataset_path / "yelp_test.csv")
|
76 |
+
elif dataset_name == "reviews":
|
77 |
+
dataset = pd.read_csv(dataset_path / "test_metareviews.csv")
|
78 |
+
elif dataset_name == "multi_news":
|
79 |
+
dataset = pd.read_csv(dataset_path / "multi_news.csv")
|
80 |
+
else:
|
81 |
+
raise ValueError(f"Unknown dataset {dataset_name}")
|
82 |
+
|
83 |
+
# make a dataset from the dataframe
|
84 |
+
dataset = Dataset.from_pandas(dataset)
|
85 |
+
|
86 |
+
return dataset
|
87 |
+
|
88 |
+
|
89 |
+
def evaluate_summarizer(model, tokenizer, dataset: Dataset, decoding_config) -> Dataset:
|
90 |
+
"""
|
91 |
+
@param model: The model used to generate the summaries
|
92 |
+
@param tokenizer: The tokenizer used to tokenize the text and the summary
|
93 |
+
@param dataset: A dataset with the text
|
94 |
+
@param decoding_config: Dictoionary with the decoding config
|
95 |
+
@return: The same dataset with the summaries added
|
96 |
+
"""
|
97 |
+
|
98 |
+
rsa = RSAContextualDecoding(model, tokenizer, device=model.device)
|
99 |
+
|
100 |
+
# generate summaries
|
101 |
+
summaries = []
|
102 |
+
|
103 |
+
print("Generating summaries...")
|
104 |
+
|
105 |
+
for id, batch in tqdm(dataset.to_pandas().groupby("id")):
|
106 |
+
text = batch["text"].tolist()
|
107 |
+
|
108 |
+
inputs = tokenizer(
|
109 |
+
text,
|
110 |
+
max_length=1024,
|
111 |
+
padding="max_length",
|
112 |
+
truncation=True,
|
113 |
+
return_tensors="pt",
|
114 |
+
)
|
115 |
+
batch_size = inputs["input_ids"].shape[0]
|
116 |
+
|
117 |
+
for k in tqdm(range(len(text))):
|
118 |
+
# move inputs to device
|
119 |
+
inputs = {key: value.to("cuda") for key, value in inputs.items()}
|
120 |
+
|
121 |
+
output = rsa.generate(
|
122 |
+
target_id=k,
|
123 |
+
source_texts_ids=inputs["input_ids"],
|
124 |
+
source_text_attention_mask=inputs["attention_mask"],
|
125 |
+
max_length=50,
|
126 |
+
top_p=0.95,
|
127 |
+
do_sample=True,
|
128 |
+
rationality=8.0,
|
129 |
+
temperature=1.0,
|
130 |
+
process_logits_before_rsa=True,
|
131 |
+
)
|
132 |
+
# output : (batch_size * num_return_sequences, max_length)
|
133 |
+
outputs = output[0]
|
134 |
+
summaries.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
135 |
+
|
136 |
+
# decode summaries
|
137 |
+
|
138 |
+
# add summaries to the huggingface dataset
|
139 |
+
dataset = dataset.add_column("summary", summaries)
|
140 |
+
|
141 |
+
return dataset
|
142 |
+
|
143 |
+
|
144 |
+
def sanitize_model_name(model_name: str) -> str:
|
145 |
+
"""
|
146 |
+
Sanitize the model name to be used as a folder name.
|
147 |
+
@param model_name: The model name
|
148 |
+
@return: The sanitized model name
|
149 |
+
"""
|
150 |
+
return model_name.replace("/", "_")
|
151 |
+
|
152 |
+
|
153 |
+
def main():
|
154 |
+
args = parse_args()
|
155 |
+
|
156 |
+
# load the model
|
157 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
|
158 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
159 |
+
|
160 |
+
tokenizer.pad_token = tokenizer.unk_token
|
161 |
+
tokenizer.pad_token_id = tokenizer.unk_token_id
|
162 |
+
|
163 |
+
# move model to device
|
164 |
+
model = model.to(args.device)
|
165 |
+
|
166 |
+
# load the dataset
|
167 |
+
print("Loading dataset...")
|
168 |
+
dataset = prepare_dataset(args.dataset_name, args.dataset_path)
|
169 |
+
|
170 |
+
# limit the number of samples
|
171 |
+
if args.limit is not None:
|
172 |
+
_lim = min(args.limit, len(dataset))
|
173 |
+
dataset = dataset.select(range(_lim))
|
174 |
+
|
175 |
+
# generate summaries
|
176 |
+
dataset = evaluate_summarizer(
|
177 |
+
model,
|
178 |
+
tokenizer,
|
179 |
+
dataset,
|
180 |
+
GENERATION_CONFIGS[args.decoding_config],
|
181 |
+
)
|
182 |
+
|
183 |
+
df_dataset = dataset.to_pandas()
|
184 |
+
df_dataset = df_dataset.explode("summary")
|
185 |
+
df_dataset = df_dataset.reset_index()
|
186 |
+
# add an idx with the id of the summary for each example
|
187 |
+
# df_dataset["id_candidate"] = df_dataset.groupby(["index"]).cumcount()
|
188 |
+
|
189 |
+
# save the dataset
|
190 |
+
# add unique date in name
|
191 |
+
now = datetime.datetime.now()
|
192 |
+
date = now.strftime("%Y-%m-%d-%H-%M-%S")
|
193 |
+
model_name = sanitize_model_name(args.model_name)
|
194 |
+
output_path = (
|
195 |
+
Path(args.output_dir)
|
196 |
+
/ f"{model_name}-_-{args.dataset_name}-_-{args.decoding_config}-_-{date}.csv"
|
197 |
+
)
|
198 |
+
|
199 |
+
# create output dir if it doesn't exist
|
200 |
+
if not output_path.parent.exists():
|
201 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
202 |
+
|
203 |
+
df_dataset.to_csv(output_path, index=False, encoding="utf-8")
|
204 |
+
|
205 |
+
|
206 |
+
if __name__ == "__main__":
|
207 |
+
main()
|
glimpse-ui/glimpse/glimpse/src/compute_rsa.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PegasusTokenizer
|
5 |
+
import argparse
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from pickle import dump
|
9 |
+
|
10 |
+
import sys, os.path
|
11 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
12 |
+
|
13 |
+
from rsasumm.rsa_reranker import RSAReranking
|
14 |
+
|
15 |
+
|
16 |
+
DESC = """
|
17 |
+
Compute the RSA matrices for all the set of multi-document samples and dump these along with additional information in a pickle file.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def parse_args():
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument("--model_name", type=str, default="facebook/bart-large-cnn")
|
23 |
+
parser.add_argument("--summaries", type=Path, default="glimpse/data/candidates/extractive_sentences-_-all_reviews_2017-_-none-_-2025-05-20-20-22-18.csv")
|
24 |
+
parser.add_argument("--output_dir", type=str, default="glimpse/output")
|
25 |
+
|
26 |
+
parser.add_argument("--filter", type=str, default=None)
|
27 |
+
|
28 |
+
# if ran in a scripted way, the output path will be printed
|
29 |
+
parser.add_argument("--scripted-run", action=argparse.BooleanOptionalAction, default=False)
|
30 |
+
|
31 |
+
parser.add_argument("--device", type=str, default="cuda")
|
32 |
+
|
33 |
+
return parser.parse_args()
|
34 |
+
|
35 |
+
|
36 |
+
def parse_summaries(path: Path) -> pd.DataFrame:
|
37 |
+
|
38 |
+
try:
|
39 |
+
summaries = pd.read_csv(path)
|
40 |
+
except:
|
41 |
+
raise ValueError(f"Unknown dataset {path}")
|
42 |
+
|
43 |
+
# check if the dataframe has the right columns
|
44 |
+
if not all(
|
45 |
+
col in summaries.columns for col in ["index", "id", "text", "gold", "summary", "id_candidate"]
|
46 |
+
):
|
47 |
+
raise ValueError(
|
48 |
+
"The dataframe must have columns ['index', 'id', 'text', 'gold', 'summary', 'id_candidate']"
|
49 |
+
)
|
50 |
+
|
51 |
+
return summaries
|
52 |
+
|
53 |
+
|
54 |
+
def compute_rsa(summaries: pd.DataFrame, model, tokenizer, device):
|
55 |
+
results = []
|
56 |
+
for name, group in tqdm(summaries.groupby(["id"])):
|
57 |
+
rsa_reranker = RSAReranking(
|
58 |
+
model,
|
59 |
+
tokenizer,
|
60 |
+
device=device,
|
61 |
+
candidates=group.summary.unique().tolist(),
|
62 |
+
source_texts=group.text.unique().tolist(),
|
63 |
+
rationality=1,
|
64 |
+
)
|
65 |
+
(
|
66 |
+
best_rsa,
|
67 |
+
best_base,
|
68 |
+
speaker_df,
|
69 |
+
listener_df,
|
70 |
+
initial_listener,
|
71 |
+
language_model_proba_df,
|
72 |
+
initial_consensuality_scores,
|
73 |
+
consensuality_scores,
|
74 |
+
) = rsa_reranker.rerank(t=1)
|
75 |
+
|
76 |
+
gold = group['gold'].tolist()[0]
|
77 |
+
|
78 |
+
results.append(
|
79 |
+
{
|
80 |
+
"id": name,
|
81 |
+
"best_rsa": best_rsa, # best speaker score
|
82 |
+
"best_base": best_base, # naive baseline
|
83 |
+
"speaker_df": speaker_df, # all speaker results
|
84 |
+
"listener_df": listener_df, # all listener results (chances of guessing correctly)
|
85 |
+
"initial_listener": initial_listener,
|
86 |
+
"language_model_proba_df": language_model_proba_df,
|
87 |
+
"initial_consensuality_scores": initial_consensuality_scores,
|
88 |
+
"consensuality_scores": consensuality_scores, # consensuality scores
|
89 |
+
"gold": gold,
|
90 |
+
"rationality": 1, # hyperparameter
|
91 |
+
"text_candidates" : group
|
92 |
+
}
|
93 |
+
)
|
94 |
+
|
95 |
+
return results
|
96 |
+
|
97 |
+
|
98 |
+
def main():
|
99 |
+
args = parse_args()
|
100 |
+
|
101 |
+
if args.filter is not None:
|
102 |
+
if args.filter not in args.summaries.stem:
|
103 |
+
return
|
104 |
+
|
105 |
+
# load the model and the tokenizer
|
106 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
|
107 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
108 |
+
|
109 |
+
model = model.to(args.device)
|
110 |
+
|
111 |
+
# load the summaries
|
112 |
+
summaries = parse_summaries(args.summaries)
|
113 |
+
|
114 |
+
# rerank the summaries
|
115 |
+
results = compute_rsa(summaries, model, tokenizer, args.device)
|
116 |
+
results = {"results": results}
|
117 |
+
|
118 |
+
results["metadata/reranking_model"] = args.model_name
|
119 |
+
results["metadata/rsa_iterations"] = 1
|
120 |
+
|
121 |
+
# save the summaries
|
122 |
+
# make the output directory if it does not exist
|
123 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
124 |
+
output_path = Path(args.output_dir) / f"{args.summaries.stem}-_-r3-_-rsa_reranked-{args.model_name.replace('/', '-')}.pk"
|
125 |
+
output_path_base = (
|
126 |
+
Path(args.output_dir) / f"{args.summaries.stem}-_-base_reranked.pk"
|
127 |
+
)
|
128 |
+
|
129 |
+
with open(output_path, "wb") as f:
|
130 |
+
dump(results, f)
|
131 |
+
|
132 |
+
# in case of scripted run, print the output path
|
133 |
+
if args.scripted_run: print(output_path)
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == "__main__":
|
137 |
+
main()
|
glimpse-ui/glimpse/glimpse/src/rsa_merge_into_single.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
from datasets import Dataset
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
def parse_args():
|
11 |
+
parser = argparse.ArgumentParser()
|
12 |
+
parser.add_argument("--summaries", type=Path)
|
13 |
+
|
14 |
+
parser.add_argument("--output_dir", type=str, default="output")
|
15 |
+
|
16 |
+
# limit the number of samples to generate
|
17 |
+
parser.add_argument("--limit", type=int, default=None)
|
18 |
+
|
19 |
+
args = parser.parse_args()
|
20 |
+
|
21 |
+
return args
|
22 |
+
|
23 |
+
|
24 |
+
def main():
|
25 |
+
args = parse_args()
|
26 |
+
|
27 |
+
path = Path(args.summaries)
|
28 |
+
|
29 |
+
for file in path.glob("*.csv"):
|
30 |
+
model_name, dataset, decoding_config, date, reranking_type = file.stem.split('-_-')
|
31 |
+
df = pd.read_csv(file)
|
32 |
+
df = df.drop(["Unnamed: 0.1", "Unnamed: 0", ], axis=1)
|
33 |
+
|
34 |
+
# df = df[['id', 'id_text', 'text', 'summary', 'gold']]
|
35 |
+
|
36 |
+
|
37 |
+
merged_summaries = df.groupby("id").agg({"summary": " ".join}).reset_index()
|
38 |
+
|
39 |
+
# add gold and text
|
40 |
+
|
41 |
+
merged_summaries = merged_summaries.merge(df[["id", "gold", "text"]], on="id")
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
main()
|
glimpse-ui/glimpse/glimpse/src/rsa_reranking.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
5 |
+
import argparse
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
from rsasumm.rsa_reranker import RSAReranking
|
10 |
+
|
11 |
+
def parse_args():
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument("--model_name", type=str, default="facebook/bart-large-cnn")
|
14 |
+
parser.add_argument("--summaries", type=Path, default="")
|
15 |
+
parser.add_argument("--output_dir", type=str, default="output")
|
16 |
+
|
17 |
+
parser.add_argument("--filter", type=str, default=None)
|
18 |
+
|
19 |
+
parser.add_argument("--device", type=str, default="cuda")
|
20 |
+
|
21 |
+
return parser.parse_args()
|
22 |
+
|
23 |
+
|
24 |
+
def parse_summaries(path : Path) -> pd.DataFrame:
|
25 |
+
summaries = pd.read_csv(path)
|
26 |
+
|
27 |
+
# check if the dataframe has the right columns
|
28 |
+
if not all(col in summaries.columns for col in ["id", "text", "id_candidate", "summary"]):
|
29 |
+
raise ValueError("The dataframe must have columns ['id', 'text', 'id_candidate', 'summary']")
|
30 |
+
|
31 |
+
return summaries
|
32 |
+
|
33 |
+
def reranking_rsa(summaries : pd.DataFrame, model, tokenizer, device):
|
34 |
+
|
35 |
+
best_summaries = []
|
36 |
+
best_bases = []
|
37 |
+
for name, group in tqdm(summaries.groupby(["id"])):
|
38 |
+
rsa_reranker = RSAReranking(model, tokenizer, device, group.summary.unique().tolist(), group.text.unique().tolist())
|
39 |
+
best_rsa, best_base, speaker_df, listener_df, initial_listener, language_model_proba_df = rsa_reranker.rerank(t=3)
|
40 |
+
|
41 |
+
group = group.set_index("summary")
|
42 |
+
group_lines = group.loc[best_rsa]
|
43 |
+
group_lines['speaker_proba'] = 0
|
44 |
+
group_lines['listener_proba'] = 0
|
45 |
+
group_lines['language_model_proba'] = 0
|
46 |
+
group_lines['initial_listener_proba'] = 0
|
47 |
+
|
48 |
+
group_lines = group_lines.reset_index()
|
49 |
+
|
50 |
+
for i, (idx, line) in enumerate(group_lines.iterrows()):
|
51 |
+
summary = line['summary']
|
52 |
+
text = line['text']
|
53 |
+
|
54 |
+
group_lines['speaker_proba'].loc[i] = speaker_df.loc[text, summary]
|
55 |
+
group_lines['listener_proba'].loc[i] = listener_df.loc[text, summary]
|
56 |
+
group_lines['language_model_proba'].loc[i] = language_model_proba_df.loc[text, summary]
|
57 |
+
group_lines['initial_listener_proba'].loc[i] = initial_listener.loc[text, summary]
|
58 |
+
|
59 |
+
|
60 |
+
group_lines["id"] = name
|
61 |
+
best_summaries.append(group_lines)
|
62 |
+
|
63 |
+
best_base_lines = group.loc[best_base]
|
64 |
+
best_base_lines = best_base_lines.reset_index()
|
65 |
+
|
66 |
+
best_base_lines['speaker_proba'] = 0
|
67 |
+
best_base_lines['listener_proba'] = 0
|
68 |
+
best_base_lines['language_model_proba'] = 0
|
69 |
+
best_base_lines['initial_listener_proba'] = 0
|
70 |
+
|
71 |
+
for i, (idx, line) in enumerate(best_base_lines.iterrows()):
|
72 |
+
summary = line['summary']
|
73 |
+
text = line['text']
|
74 |
+
|
75 |
+
best_base_lines['speaker_proba'].loc[i] = speaker_df.loc[text, summary]
|
76 |
+
best_base_lines['listener_proba'].loc[i] = listener_df.loc[text, summary]
|
77 |
+
best_base_lines['language_model_proba'].loc[i] = language_model_proba_df.loc[text, summary]
|
78 |
+
best_base_lines['initial_listener_proba'].loc[i] = initial_listener.loc[text, summary]
|
79 |
+
|
80 |
+
|
81 |
+
best_base_lines["id"] = name
|
82 |
+
best_bases.append(best_base_lines)
|
83 |
+
|
84 |
+
best_summaries = pd.concat(best_summaries)
|
85 |
+
best_bases = pd.concat(best_bases)
|
86 |
+
|
87 |
+
return best_summaries, best_bases
|
88 |
+
|
89 |
+
|
90 |
+
def main():
|
91 |
+
args = parse_args()
|
92 |
+
|
93 |
+
if args.filter is not None:
|
94 |
+
if args.filter not in args.summaries.stem:
|
95 |
+
return
|
96 |
+
|
97 |
+
# load the model and the tokenizer
|
98 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
|
99 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
100 |
+
|
101 |
+
model = model.to(args.device)
|
102 |
+
|
103 |
+
# load the summaries
|
104 |
+
summaries = parse_summaries(args.summaries)
|
105 |
+
|
106 |
+
# rerank the summaries
|
107 |
+
best_summaries, bast_base = reranking_rsa(summaries, model, tokenizer, device=args.device)
|
108 |
+
|
109 |
+
best_summaries['metadata/reranking_model'] = args.model_name
|
110 |
+
best_summaries['metadata/rsa_iterations'] = 3
|
111 |
+
|
112 |
+
bast_base['metadata/reranking_model'] = args.model_name
|
113 |
+
bast_base['metadata/rsa_iterations'] = 3
|
114 |
+
|
115 |
+
|
116 |
+
# save the summaries
|
117 |
+
# make the output directory if it does not exist
|
118 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
119 |
+
output_path = Path(args.output_dir) / f"{args.summaries.stem}-_-rsa_reranked.csv"
|
120 |
+
output_path_base = Path(args.output_dir) / f"{args.summaries.stem}-_-base_reranked.csv"
|
121 |
+
|
122 |
+
best_summaries.to_csv(output_path)
|
123 |
+
bast_base.to_csv(output_path_base)
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
main()
|
glimpse-ui/glimpse/mds/Single summaries expes.ipynb
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "64c5b118-5a32-4220-89f2-4e3ccd7a28d2",
|
7 |
+
"metadata": {
|
8 |
+
"tags": []
|
9 |
+
},
|
10 |
+
"outputs": [],
|
11 |
+
"source": [
|
12 |
+
"import matplotlib as mpl\n",
|
13 |
+
"# Use the pgf backend (must be set before pyplot imported)\n",
|
14 |
+
"mpl.use('pgf')\n",
|
15 |
+
"\n",
|
16 |
+
"import pandas as pd\n",
|
17 |
+
"import numpy as np\n",
|
18 |
+
"import matplotlib.pyplot as plt\n",
|
19 |
+
"import seaborn as sns\n",
|
20 |
+
"import re\n",
|
21 |
+
"from pathlib import Path"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"outputs": [],
|
27 |
+
"source": [],
|
28 |
+
"metadata": {
|
29 |
+
"collapsed": false
|
30 |
+
},
|
31 |
+
"id": "7893116c24574642"
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"outputs": [],
|
36 |
+
"source": [
|
37 |
+
"# use pgf backend\n",
|
38 |
+
"plt.style.use('seaborn-paper')\n"
|
39 |
+
],
|
40 |
+
"metadata": {
|
41 |
+
"collapsed": false
|
42 |
+
},
|
43 |
+
"id": "f3dc93e0b2eb9894"
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": null,
|
48 |
+
"id": "3806928d-0624-4d9f-905f-3bf41b9725f1",
|
49 |
+
"metadata": {
|
50 |
+
"tags": []
|
51 |
+
},
|
52 |
+
"outputs": [],
|
53 |
+
"source": [
|
54 |
+
"sumy_individual_path = Path('output/summaries/sumy_individual/')\n",
|
55 |
+
"ours_individual_path = Path('output/summaries/methods_reviews_individual/')\n",
|
56 |
+
"\n",
|
57 |
+
"TABLE_PATH = Path(\"../../../EMIRR/papers/rsa_multi_document/tables/\")\n",
|
58 |
+
"FIGURE = Path(\"../../../EMIRR/papers/rsa_multi_document/figures/\")\n",
|
59 |
+
"\n",
|
60 |
+
"# make sure the folder exists\n",
|
61 |
+
"TABLE_PATH.mkdir(parents=True, exist_ok=True)\n",
|
62 |
+
"FIGURE.mkdir(parents=True, exist_ok=True)"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": null,
|
68 |
+
"id": "141a8192-773a-40af-a891-620a6ab81efd",
|
69 |
+
"metadata": {
|
70 |
+
"tags": []
|
71 |
+
},
|
72 |
+
"outputs": [],
|
73 |
+
"source": [
|
74 |
+
"\n",
|
75 |
+
"dfs = []\n",
|
76 |
+
"for file in sumy_individual_path.glob('*.csv'):\n",
|
77 |
+
" df = pd.read_csv(file)\n",
|
78 |
+
" method = file.stem.split('-_-')[1]\n",
|
79 |
+
" \n",
|
80 |
+
" sumy = file.stem.split('-_-')[-1].split('_')\n",
|
81 |
+
" if len(sumy) > 1:\n",
|
82 |
+
" sentence_count = int(sumy[-1])\n",
|
83 |
+
" df['metadata/sentence_count'] = sentence_count\n",
|
84 |
+
"\n",
|
85 |
+
" # df['Method'] = method\n",
|
86 |
+
" dfs.append(df)\n",
|
87 |
+
" \n",
|
88 |
+
" \n",
|
89 |
+
"for file in ours_individual_path.glob('*.csv'):\n",
|
90 |
+
" generation_method, dataset, generation_params, date, rsa_param, rsa_ranking_model, method = file.stem.split('-_-')\n",
|
91 |
+
" \n",
|
92 |
+
" method, n = \"_\".join( method.split('_')[:-1]), method.split('_')[-1]\n",
|
93 |
+
" \n",
|
94 |
+
" if \"metadata/method\" not in df.columns:\n",
|
95 |
+
" df['metadata/method'] = method\n",
|
96 |
+
" \n",
|
97 |
+
"# reranking_model = rsa_ranking_model[len(\"rsa_reranked-\"):]\n",
|
98 |
+
" \n",
|
99 |
+
"# df['Ranking Model'] = reranking_model\n",
|
100 |
+
"# df['Method'] = method\n",
|
101 |
+
"# df['N'] = int(n) if n != \"based\" else 3 \n",
|
102 |
+
" df['Generation Method'] = generation_method\n",
|
103 |
+
" \n",
|
104 |
+
" df = pd.read_csv(file)\n",
|
105 |
+
" dfs.append(df)\n",
|
106 |
+
" \n",
|
107 |
+
"df = pd.concat(dfs)\n",
|
108 |
+
"del dfs\n",
|
109 |
+
"\n",
|
110 |
+
"df = df.drop([c for c in df.columns if \"Unnamed\" in c], axis=1)\n",
|
111 |
+
"\n"
|
112 |
+
]
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "code",
|
116 |
+
"outputs": [],
|
117 |
+
"source": [
|
118 |
+
"\n",
|
119 |
+
"df['metadata/method'] = df['metadata/method'].fillna('N/A')\n",
|
120 |
+
"df = df[~(df[\"metadata/method\"].str.contains('lead'))]\n",
|
121 |
+
"df = df[~(df[\"metadata/method\"].str.contains('Lead'))]\n",
|
122 |
+
"\n"
|
123 |
+
],
|
124 |
+
"metadata": {
|
125 |
+
"collapsed": false
|
126 |
+
},
|
127 |
+
"id": "804fb4bacf2686d4",
|
128 |
+
"execution_count": null
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"cell_type": "code",
|
132 |
+
"execution_count": null,
|
133 |
+
"id": "419db7d5-b90b-47a3-9c25-f39eff337849",
|
134 |
+
"metadata": {
|
135 |
+
"tags": []
|
136 |
+
},
|
137 |
+
"outputs": [],
|
138 |
+
"source": [
|
139 |
+
"def fix_generation(x):\n",
|
140 |
+
" if x == \"abstractive_sentences\":\n",
|
141 |
+
" return \"extractive_sentences\"\n",
|
142 |
+
" else:\n",
|
143 |
+
" return x\n",
|
144 |
+
"\n",
|
145 |
+
"\n",
|
146 |
+
"df['Generation Method'] = df[\"Generation Method\"].apply(fix_generation)"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"cell_type": "code",
|
151 |
+
"execution_count": null,
|
152 |
+
"id": "4a157db9-e408-46e8-9499-2751e4cbe7e4",
|
153 |
+
"metadata": {
|
154 |
+
"tags": []
|
155 |
+
},
|
156 |
+
"outputs": [],
|
157 |
+
"source": [
|
158 |
+
"df['N'] = (df['metadata/n_sentences'].fillna(0) + df['metadata/sentence_count'].fillna(0)).apply(int)\n",
|
159 |
+
"\n",
|
160 |
+
"def fix_methods(x):\n",
|
161 |
+
"\n",
|
162 |
+
" if \"consensus\" in str(x):\n",
|
163 |
+
" return \"Agreement\"\n",
|
164 |
+
" elif \"rsa\" in str(x):\n",
|
165 |
+
" return \"Speaker+Agreement\"\n",
|
166 |
+
" else:\n",
|
167 |
+
" return x\n",
|
168 |
+
" \n",
|
169 |
+
"df['metadata/method'] = df['metadata/method'].apply(fix_methods)\n",
|
170 |
+
"\n"
|
171 |
+
]
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"cell_type": "code",
|
175 |
+
"execution_count": null,
|
176 |
+
"id": "4b926a30-f4be-4de9-b921-1eac3145e87e",
|
177 |
+
"metadata": {
|
178 |
+
"tags": []
|
179 |
+
},
|
180 |
+
"outputs": [],
|
181 |
+
"source": [
|
182 |
+
"df['metadata/sentence_count'].unique()"
|
183 |
+
]
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"cell_type": "code",
|
187 |
+
"execution_count": null,
|
188 |
+
"id": "0451f550-0522-40e9-a64a-551337ae47f6",
|
189 |
+
"metadata": {
|
190 |
+
"tags": []
|
191 |
+
},
|
192 |
+
"outputs": [],
|
193 |
+
"source": [
|
194 |
+
"\n",
|
195 |
+
" "
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": null,
|
201 |
+
"id": "7d213f27-9b7f-4893-b2c1-251715401db5",
|
202 |
+
"metadata": {
|
203 |
+
"tags": []
|
204 |
+
},
|
205 |
+
"outputs": [],
|
206 |
+
"source": [
|
207 |
+
"\n",
|
208 |
+
"metric= 'SHMetric/Main ideas/proba_1'\n",
|
209 |
+
"\n",
|
210 |
+
"SHMetric = df.columns[df.columns.str.contains('SHMetric') & df.columns.str.contains('proba_1')].tolist()\n",
|
211 |
+
"\n",
|
212 |
+
"toplot = df.copy()\n",
|
213 |
+
"toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
|
214 |
+
"toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
|
215 |
+
"\n",
|
216 |
+
"\n",
|
217 |
+
"toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
|
218 |
+
"idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
|
219 |
+
"\n",
|
220 |
+
"toplot = toplot.loc[idx].reset_index()\n",
|
221 |
+
"\n",
|
222 |
+
"avg = toplot.groupby([\"metadata/method\"]).agg(['mean', 'std'])\n",
|
223 |
+
"avg = avg[SHMetric]\n",
|
224 |
+
"\n",
|
225 |
+
"display(avg)\n",
|
226 |
+
"\n",
|
227 |
+
"# rename columns Consiness, Main ideas, Repetition\n",
|
228 |
+
"avg.columns = pd.MultiIndex.from_tuples([(f'{c[0].split(\"/\")[1]}', c[1]) for c in avg.columns])\n",
|
229 |
+
"\n",
|
230 |
+
"def map_ours(x):\n",
|
231 |
+
" if \"Agreement\" in x:\n",
|
232 |
+
" return \"Ours\"\n",
|
233 |
+
" else:\n",
|
234 |
+
" return \"Bas.\"\n",
|
235 |
+
"\n",
|
236 |
+
"\n",
|
237 |
+
"avg = avg.groupby([\"metadata/method\"]).mean()\n",
|
238 |
+
"\n",
|
239 |
+
"avg['Ours'] = avg.index.get_level_values(0).map(map_ours)\n",
|
240 |
+
"\n",
|
241 |
+
"\n",
|
242 |
+
"avg = avg.reset_index().rename(columns={'metadata/method': 'Method'})\n",
|
243 |
+
"avg = avg.set_index(['Ours', 'Method'])\n",
|
244 |
+
"avg = avg.sort_index()\n",
|
245 |
+
"\n",
|
246 |
+
"# print avg columns level 0\n",
|
247 |
+
"print(avg.columns.get_level_values(0))\n",
|
248 |
+
"\n",
|
249 |
+
"#Index(['Comprehensible', 'Comprehensible', 'Repetition', 'Repetition',\n",
|
250 |
+
" # 'Grammar', 'Grammar', 'Attribution', 'Attribution', 'Main ideas',\n",
|
251 |
+
" # 'Main ideas', 'Conciseness', 'Conciseness'],\n",
|
252 |
+
" # dtype='object')\n",
|
253 |
+
" \n",
|
254 |
+
"# rename columns with shorter names\n",
|
255 |
+
"avg.columns = pd.MultiIndex.from_tuples([\n",
|
256 |
+
" ('Compr.', 'mean'), ('Compr.', 'std'),\n",
|
257 |
+
" ('Repet.', 'mean'), ('Repet.', 'std'),\n",
|
258 |
+
" ('Gram.', 'mean'), ('Gram.', 'std'),\n",
|
259 |
+
" ('Attr.', 'mean'), ('Attr.', 'std'),\n",
|
260 |
+
" ('M. i.', 'mean'), ('M. i.', 'std'),\n",
|
261 |
+
" ('Conc.', 'mean'), ('Conc.', 'std')\n",
|
262 |
+
"])\n",
|
263 |
+
"\n",
|
264 |
+
"\n",
|
265 |
+
"style = avg.style\n",
|
266 |
+
"style = style.format(\"{:.2f}\")\n",
|
267 |
+
"\n",
|
268 |
+
"# make std column smaller and lighter in latex\n",
|
269 |
+
"idx = pd.IndexSlice\n",
|
270 |
+
"# style = style.set_properties(subset=idx[:, ['std']], **{'font-size': '10pt', 'font-weight': 'lighter'})\n",
|
271 |
+
"\n",
|
272 |
+
"# bold the best value in each mean column\n",
|
273 |
+
"style = style.highlight_max(axis=0, subset=idx[:, idx[:, 'mean']], props=\"bfseries: ;\")\n",
|
274 |
+
"\n",
|
275 |
+
"# make std columns smaller and add +/- sign\n",
|
276 |
+
"style = style.set_properties(**{'color':'[HTML]{A0A1A3}'} ,subset=(idx[:], idx[:, 'std']))\n",
|
277 |
+
"style = style.format(\"±{:.2f}\", subset=(idx[:], idx[:, 'std']))\n",
|
278 |
+
"\n",
|
279 |
+
"# drop level 1 of columns\n",
|
280 |
+
"style = style.hide_columns(level=1)\n",
|
281 |
+
"\n",
|
282 |
+
"# to latex\n",
|
283 |
+
"latex = style.to_latex(clines=\"skip-last;data\", hrules=True, multirow_align=\"l\", environment=\"table*\", caption=\"Estimated human judgment using the SEAHORSE metrics for all baselines and our templated summaries compared against each document independently. M. i. stands for Main ideas, Attr. for Attribution, Gram. for Grammar, Compr. for Comprehensible, Conc. for Conciseness, and Repet. for Repetition. The best value in each column is in bold.\")\n",
|
284 |
+
"display(style)\n",
|
285 |
+
"\n",
|
286 |
+
"# add resize box\n",
|
287 |
+
"latex = latex.replace(\"\\\\begin{tabular}\", \"\\\\resizebox{\\\\textwidth}{!}{\\\\begin{tabular}\")\n",
|
288 |
+
"latex = latex.replace(\"\\\\end{tabular}\", \"\\\\end{tabular}}\")\n",
|
289 |
+
"\n",
|
290 |
+
"\n",
|
291 |
+
"# replace \n",
|
292 |
+
"\n",
|
293 |
+
"# write to file\n",
|
294 |
+
"with open(TABLE_PATH / \"seahorse.tex\", \"w\") as f:\n",
|
295 |
+
" f.write(latex)\n",
|
296 |
+
"\n",
|
297 |
+
"\n",
|
298 |
+
"\n",
|
299 |
+
"\n",
|
300 |
+
"\n",
|
301 |
+
"\n",
|
302 |
+
"# display(avg)\n",
|
303 |
+
"# avg.set_index('Method')\""
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"execution_count": null,
|
309 |
+
"id": "cd4c7707-f1cd-4810-972b-1a69e0ec68ae",
|
310 |
+
"metadata": {
|
311 |
+
"tags": []
|
312 |
+
},
|
313 |
+
"outputs": [],
|
314 |
+
"source": [
|
315 |
+
"metric='SHMetric/Main ideas/proba_1'\n",
|
316 |
+
"# white grid\n",
|
317 |
+
"sns.set(style=\"whitegrid\")\n",
|
318 |
+
"avg = df.groupby([\"metadata/method\", \"id\", \"metadata/reranking_model\", \"Generation Method\"]).mean().reset_index()\n",
|
319 |
+
"avg = avg.sort_values(metric)\n",
|
320 |
+
"\n",
|
321 |
+
"# rename columns with human readable names\n",
|
322 |
+
"avg = avg.rename(columns={\n",
|
323 |
+
" 'metadata/method': 'Method',\n",
|
324 |
+
" 'metadata/reranking_model': 'Reranking Model',\n",
|
325 |
+
" 'Generation Method': 'Generation Method',\n",
|
326 |
+
" metric: 'Main Ideas'\n",
|
327 |
+
"})\n",
|
328 |
+
"\n",
|
329 |
+
"\n",
|
330 |
+
"\n",
|
331 |
+
"g = sns.catplot(data=avg, y=\"Main Ideas\", x=\"Method\", hue=\"Reranking Model\", col=\"Generation Method\", kind=\"bar\")\n",
|
332 |
+
"\n",
|
333 |
+
"\n",
|
334 |
+
"# get legend label and handle\n",
|
335 |
+
"handles, labels = g._legend_data.values(), g._legend_data.keys()\n",
|
336 |
+
"\n",
|
337 |
+
"# set legend\n",
|
338 |
+
"g._legend.remove()\n",
|
339 |
+
"g.fig.legend(handles, labels, loc='upper center', ncol=2, fontsize=25, title_fontsize=25, title=\"Reranking Model\", bbox_to_anchor=(0.4, -0.3))\n",
|
340 |
+
"\n",
|
341 |
+
"\n",
|
342 |
+
"# set title template \n",
|
343 |
+
"g.set_titles(\"{col_name}\")\n",
|
344 |
+
"\n",
|
345 |
+
"# add hline at 0.215 for the baseline, on each axis\n",
|
346 |
+
"for ax in g.axes.flat:\n",
|
347 |
+
" ax.axhline(0.215, ls='--', color='black', linewidth=5)\n",
|
348 |
+
" ax.set_xticklabels(ax.get_xticklabels(), rotation=30)\n",
|
349 |
+
" \n",
|
350 |
+
"# make label bigger\n",
|
351 |
+
"for ax in g.axes.flat:\n",
|
352 |
+
" ax.set_xlabel(\"\")\n",
|
353 |
+
" ax.set_ylabel(ax.get_ylabel(), fontsize=25, fontweight='bold')\n",
|
354 |
+
" ax.set_xticklabels(ax.get_xticklabels(), fontsize=25, fontweight='bold')\n",
|
355 |
+
" \n",
|
356 |
+
"# make title bigger\n",
|
357 |
+
"for ax in g.axes.flat:\n",
|
358 |
+
" ax.set_title(ax.get_title(), fontsize=25, fontweight='bold')\n",
|
359 |
+
" \n",
|
360 |
+
"# add annotation for the hline on the first axis\n",
|
361 |
+
"\n",
|
362 |
+
"\n",
|
363 |
+
"\n",
|
364 |
+
"\n",
|
365 |
+
"\n",
|
366 |
+
"plt.xticks(rotation=30)\n",
|
367 |
+
"\n",
|
368 |
+
"# save figure\n",
|
369 |
+
"g.savefig(FIGURE / \"seahorse_main_ideas.pdf\")\n",
|
370 |
+
"\n",
|
371 |
+
"\n",
|
372 |
+
"\n"
|
373 |
+
]
|
374 |
+
},
|
375 |
+
{
|
376 |
+
"cell_type": "code",
|
377 |
+
"execution_count": null,
|
378 |
+
"id": "f84773d4-f2b7-4db6-851f-4683d545345b",
|
379 |
+
"metadata": {
|
380 |
+
"tags": []
|
381 |
+
},
|
382 |
+
"outputs": [],
|
383 |
+
"source": [
|
384 |
+
"metric='SHMetric/Main ideas/proba_1'\n",
|
385 |
+
"\n",
|
386 |
+
"toplot = df.copy()\n",
|
387 |
+
"toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
|
388 |
+
"toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
|
389 |
+
"\n",
|
390 |
+
"toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
|
391 |
+
"idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
|
392 |
+
"toplot = toplot.loc[idx].reset_index()\n",
|
393 |
+
"toplot = toplot[~toplot['metadata/method'].str.contains('Lead')]\n",
|
394 |
+
"\n",
|
395 |
+
"toplot = toplot.sort_values(metric, ascending=True)\n",
|
396 |
+
"order = toplot.groupby(\"metadata/method\").mean().sort_values(metric)\n",
|
397 |
+
"\n",
|
398 |
+
"\n",
|
399 |
+
"display(toplot.groupby(\"metadata/method\").mean().sort_values(metric)[metric])\n",
|
400 |
+
"\n",
|
401 |
+
"sns.barplot(data=toplot, y=metric, x=\"metadata/method\", order=order.index)\n",
|
402 |
+
"\n",
|
403 |
+
"plt.xticks(rotation=45)"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "code",
|
408 |
+
"execution_count": null,
|
409 |
+
"id": "47f15195-e60d-45f8-aac0-dc88eeea9577",
|
410 |
+
"metadata": {
|
411 |
+
"tags": []
|
412 |
+
},
|
413 |
+
"outputs": [],
|
414 |
+
"source": [
|
415 |
+
"metric='SHMetric/Conciseness/proba_1'\n",
|
416 |
+
"\n",
|
417 |
+
"toplot = df.copy()\n",
|
418 |
+
"toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
|
419 |
+
"toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
|
420 |
+
"\n",
|
421 |
+
"toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
|
422 |
+
"idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
|
423 |
+
"toplot = toplot.loc[idx].reset_index()\n",
|
424 |
+
"toplot = toplot[~toplot['metadata/method'].str.contains('Lead')]\n",
|
425 |
+
"\n",
|
426 |
+
"toplot = toplot.sort_values(metric, ascending=True)\n",
|
427 |
+
"order = toplot.groupby(\"metadata/method\").mean().sort_values(metric)\n",
|
428 |
+
"\n",
|
429 |
+
"\n",
|
430 |
+
"\n",
|
431 |
+
"sns.barplot(data=toplot, y=metric, x=\"metadata/method\", order=order.index)\n",
|
432 |
+
"\n",
|
433 |
+
"plt.xticks(rotation=45)"
|
434 |
+
]
|
435 |
+
},
|
436 |
+
{
|
437 |
+
"cell_type": "code",
|
438 |
+
"execution_count": null,
|
439 |
+
"id": "96b5d47b-f9d2-45bf-9284-844dedb24ce9",
|
440 |
+
"metadata": {
|
441 |
+
"tags": []
|
442 |
+
},
|
443 |
+
"outputs": [],
|
444 |
+
"source": [
|
445 |
+
"metric='SHMetric/Repetition/proba_1'\n",
|
446 |
+
"\n",
|
447 |
+
"toplot = df.copy()\n",
|
448 |
+
"toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
|
449 |
+
"toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
|
450 |
+
"\n",
|
451 |
+
"toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
|
452 |
+
"idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
|
453 |
+
"toplot = toplot.loc[idx].reset_index()\n",
|
454 |
+
"toplot = toplot[~toplot['metadata/method'].str.contains('Lead')]\n",
|
455 |
+
"\n",
|
456 |
+
"toplot = toplot.sort_values(metric, ascending=True)\n",
|
457 |
+
"order = toplot.groupby(\"metadata/method\").mean().sort_values(metric)\n",
|
458 |
+
"\n",
|
459 |
+
"\n",
|
460 |
+
"\n",
|
461 |
+
"sns.barplot(data=toplot, y=metric, x=\"metadata/method\", order=order.index)\n",
|
462 |
+
"\n",
|
463 |
+
"plt.xticks(rotation=45)"
|
464 |
+
]
|
465 |
+
},
|
466 |
+
{
|
467 |
+
"cell_type": "code",
|
468 |
+
"execution_count": null,
|
469 |
+
"id": "be9a5198-40e9-4b60-8ed4-6b2ea3a9683e",
|
470 |
+
"metadata": {},
|
471 |
+
"outputs": [],
|
472 |
+
"source": [
|
473 |
+
"metric='SHMetric/Repetition/proba_1'\n",
|
474 |
+
"\n",
|
475 |
+
"toplot = df.copy()\n",
|
476 |
+
"toplot['metadata/reranking_model'] = toplot['metadata/reranking_model'].fillna('N/A')\n",
|
477 |
+
"toplot['Generation Method'] = toplot['Generation Method'].fillna('N/A')\n",
|
478 |
+
"\n",
|
479 |
+
"toplot = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\", \"metadata/reranking_model\"]).mean()\n",
|
480 |
+
"idx = toplot.groupby([\"metadata/method\", \"id\", \"Generation Method\"])[metric].idxmax()\n",
|
481 |
+
"toplot = toplot.loc[idx].reset_index()\n",
|
482 |
+
"toplot = toplot[~toplot['metadata/method'].str.contains('Lead')]\n",
|
483 |
+
"\n",
|
484 |
+
"toplot = toplot.sort_values(metric, ascending=True)\n",
|
485 |
+
"order = toplot.groupby(\"metadata/method\").mean().sort_values(metric)\n",
|
486 |
+
"\n",
|
487 |
+
"\n",
|
488 |
+
"\n",
|
489 |
+
"sns.barplot(data=toplot, y=metric, x=\"metadata/method\", order=order.index)\n",
|
490 |
+
"\n",
|
491 |
+
"plt.xticks(rotation=45)"
|
492 |
+
]
|
493 |
+
},
|
494 |
+
{
|
495 |
+
"cell_type": "code",
|
496 |
+
"execution_count": null,
|
497 |
+
"id": "6e52ec5a-3f68-4fdf-9828-205cd9b55e42",
|
498 |
+
"metadata": {
|
499 |
+
"tags": []
|
500 |
+
},
|
501 |
+
"outputs": [],
|
502 |
+
"source": [
|
503 |
+
"metric='SHMetric/Main ideas/proba_1'\n",
|
504 |
+
"\n",
|
505 |
+
"avg = df.groupby([\"metadata/method\", \"id\", \"N\"]).mean().reset_index()\n",
|
506 |
+
"avg = avg.sort_values(metric)\n",
|
507 |
+
"sns.barplot(data=avg[~avg['metadata/method'].str.contains('Lead')], y=metric, x=\"metadata/method\", hue='N')\n",
|
508 |
+
"plt.xticks(rotation=45)"
|
509 |
+
]
|
510 |
+
},
|
511 |
+
{
|
512 |
+
"cell_type": "code",
|
513 |
+
"execution_count": null,
|
514 |
+
"id": "a675a584-c80f-49ca-94e3-55d868c8b594",
|
515 |
+
"metadata": {
|
516 |
+
"tags": []
|
517 |
+
},
|
518 |
+
"outputs": [],
|
519 |
+
"source": [
|
520 |
+
"metric='SHMetric/Main ideas/proba_1'\n",
|
521 |
+
"\n",
|
522 |
+
"avg = df.groupby([\"metadata/method\", \"id\"]).mean().reset_index()\n",
|
523 |
+
"avg = avg[~avg['metadata/method'].str.contains('Lead')].sort_values(metric, )\n",
|
524 |
+
"sns.barplot(data=avg, y=metric, x=\"metadata/method\")\n",
|
525 |
+
"plt.xticks(rotation=45)"
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"cell_type": "code",
|
530 |
+
"execution_count": null,
|
531 |
+
"id": "79f7e781-0e34-458d-826d-67c08109cef7",
|
532 |
+
"metadata": {
|
533 |
+
"tags": []
|
534 |
+
},
|
535 |
+
"outputs": [],
|
536 |
+
"source": [
|
537 |
+
"metric='rougeL'\n",
|
538 |
+
"\n",
|
539 |
+
"avg = df.groupby([\"metadata/method\", \"id\"]).mean().reset_index()\n",
|
540 |
+
"avg = avg[~avg['metadata/method'].str.contains('Lead')].sort_values(metric, )\n",
|
541 |
+
"sns.barplot(data=avg, y=metric, x=\"metadata/method\")\n",
|
542 |
+
"plt.xticks(rotation=45)"
|
543 |
+
]
|
544 |
+
},
|
545 |
+
{
|
546 |
+
"cell_type": "code",
|
547 |
+
"execution_count": null,
|
548 |
+
"id": "8b65306c-9f5e-4d71-910a-48de9a195534",
|
549 |
+
"metadata": {
|
550 |
+
"tags": []
|
551 |
+
},
|
552 |
+
"outputs": [],
|
553 |
+
"source": [
|
554 |
+
"df.columns"
|
555 |
+
]
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"cell_type": "code",
|
559 |
+
"execution_count": null,
|
560 |
+
"id": "cf6abd56-e9d1-4cea-bd01-d03046b56f3d",
|
561 |
+
"metadata": {},
|
562 |
+
"outputs": [],
|
563 |
+
"source": []
|
564 |
+
}
|
565 |
+
],
|
566 |
+
"metadata": {
|
567 |
+
"kernelspec": {
|
568 |
+
"name": "python3",
|
569 |
+
"language": "python",
|
570 |
+
"display_name": "Python 3 (ipykernel)"
|
571 |
+
},
|
572 |
+
"language_info": {
|
573 |
+
"codemirror_mode": {
|
574 |
+
"name": "ipython",
|
575 |
+
"version": 3
|
576 |
+
},
|
577 |
+
"file_extension": ".py",
|
578 |
+
"mimetype": "text/x-python",
|
579 |
+
"name": "python",
|
580 |
+
"nbconvert_exporter": "python",
|
581 |
+
"pygments_lexer": "ipython3",
|
582 |
+
"version": "3.10.9"
|
583 |
+
}
|
584 |
+
},
|
585 |
+
"nbformat": 4,
|
586 |
+
"nbformat_minor": 5
|
587 |
+
}
|
glimpse-ui/glimpse/mds/Template summaries.ipynb
ADDED
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "12404068-3244-43d6-8556-41e11489bb48",
|
7 |
+
"metadata": {
|
8 |
+
"tags": []
|
9 |
+
},
|
10 |
+
"outputs": [],
|
11 |
+
"source": [
|
12 |
+
"import pickle as pk\n",
|
13 |
+
"import pandas as pd\n",
|
14 |
+
"from pathlib import Path\n",
|
15 |
+
"import numpy as np\n",
|
16 |
+
"import seaborn as sns\n",
|
17 |
+
"\n",
|
18 |
+
"from rouge_score import rouge_scorer\n",
|
19 |
+
"\n",
|
20 |
+
"\n",
|
21 |
+
"from lexrank import LexRank\n",
|
22 |
+
"from lexrank.mappings.stopwords import STOPWORDS\n",
|
23 |
+
"import nltk \n",
|
24 |
+
"\n"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": null,
|
30 |
+
"id": "044c82d6-23c8-4c3b-a4b5-12acbbc1cc1a",
|
31 |
+
"metadata": {
|
32 |
+
"tags": []
|
33 |
+
},
|
34 |
+
"outputs": [],
|
35 |
+
"source": []
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": null,
|
40 |
+
"id": "1f5cdf3f-5485-4b9f-a6fc-b2b8bd8aca7f",
|
41 |
+
"metadata": {
|
42 |
+
"tags": []
|
43 |
+
},
|
44 |
+
"outputs": [],
|
45 |
+
"source": [
|
46 |
+
"\n",
|
47 |
+
"\n",
|
48 |
+
"\n",
|
49 |
+
"path = Path(\"output/summaries/rsa_reranking/reviews_rsa_matrices/\")\n",
|
50 |
+
"output_path = Path(\"output/summaries/methods_reviews/\")\n",
|
51 |
+
"\n"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": null,
|
57 |
+
"id": "ec5e8ff3-6bef-42bc-8430-df93c1a4e79a",
|
58 |
+
"metadata": {
|
59 |
+
"tags": []
|
60 |
+
},
|
61 |
+
"outputs": [],
|
62 |
+
"source": []
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": null,
|
67 |
+
"id": "d45cd444-0a81-4670-bbae-213e322ea281",
|
68 |
+
"metadata": {
|
69 |
+
"tags": []
|
70 |
+
},
|
71 |
+
"outputs": [],
|
72 |
+
"source": []
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": null,
|
77 |
+
"id": "9e6c2863-9e5a-4e5a-bdb6-02e03c5f6105",
|
78 |
+
"metadata": {
|
79 |
+
"tags": []
|
80 |
+
},
|
81 |
+
"outputs": [],
|
82 |
+
"source": []
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "code",
|
86 |
+
"execution_count": null,
|
87 |
+
"id": "2ae23148-38ae-4385-99fb-db20da54334d",
|
88 |
+
"metadata": {
|
89 |
+
"tags": []
|
90 |
+
},
|
91 |
+
"outputs": [],
|
92 |
+
"source": []
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": null,
|
97 |
+
"id": "d15ce19b-a3c7-4554-878f-41acd3204878",
|
98 |
+
"metadata": {
|
99 |
+
"tags": []
|
100 |
+
},
|
101 |
+
"outputs": [],
|
102 |
+
"source": []
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": null,
|
107 |
+
"id": "1652ff75-b6c2-483a-a9af-ff3ca8616756",
|
108 |
+
"metadata": {
|
109 |
+
"tags": []
|
110 |
+
},
|
111 |
+
"outputs": [],
|
112 |
+
"source": []
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "code",
|
116 |
+
"execution_count": null,
|
117 |
+
"id": "983cd24c-b996-4224-8f87-ea79842c41a0",
|
118 |
+
"metadata": {
|
119 |
+
"tags": []
|
120 |
+
},
|
121 |
+
"outputs": [],
|
122 |
+
"source": []
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "markdown",
|
126 |
+
"id": "0db0e8f5-4b4a-4d55-8596-fe095aa4135f",
|
127 |
+
"metadata": {},
|
128 |
+
"source": [
|
129 |
+
"# Consensus score based summaries:"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "code",
|
134 |
+
"execution_count": null,
|
135 |
+
"id": "560c7f8d-6b8e-4b5b-ba36-0dedc509791f",
|
136 |
+
"metadata": {
|
137 |
+
"tags": []
|
138 |
+
},
|
139 |
+
"outputs": [],
|
140 |
+
"source": []
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": null,
|
145 |
+
"id": "662a5f1b-1e4d-458e-a437-b1d9c8db4552",
|
146 |
+
"metadata": {
|
147 |
+
"tags": []
|
148 |
+
},
|
149 |
+
"outputs": [],
|
150 |
+
"source": [
|
151 |
+
"def consensus_scores_based_summaries(sample, n_consensus=3, n_dissensus=3):\n",
|
152 |
+
" consensus_samples = sample['consensuality_scores'].sort_values(ascending=True).head(n_consensus).index.tolist()\n",
|
153 |
+
" disensus_samples = sample['consensuality_scores'].sort_values(ascending=False).head(n_dissensus).index.tolist()\n",
|
154 |
+
" \n",
|
155 |
+
" consensus = \".\".join(consensus_samples)\n",
|
156 |
+
" disensus = \".\".join(disensus_samples)\n",
|
157 |
+
" \n",
|
158 |
+
" return consensus + \"\\n\\n\" + disensus\n",
|
159 |
+
" \n",
|
160 |
+
" \n",
|
161 |
+
"def rsa_scores_based_summaries(sample, n_consensus=3, n_rsa_speaker=3):\n",
|
162 |
+
" consensus_samples = sample['consensuality_scores'].sort_values(ascending=True).head(n_consensus).index.tolist()\n",
|
163 |
+
" rsa = sample['best_rsa'].tolist()[:n_rsa_speaker]\n",
|
164 |
+
" \n",
|
165 |
+
" consensus = \".\".join(consensus_samples)\n",
|
166 |
+
" rsa = \".\".join(rsa)\n",
|
167 |
+
" \n",
|
168 |
+
" return consensus + \"\\n\\n\" + rsa\n",
|
169 |
+
"\n",
|
170 |
+
"scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)\n",
|
171 |
+
"\n",
|
172 |
+
"def lead(sample, N=10):\n",
|
173 |
+
" texts = sample['speaker_df'].index.tolist()\n",
|
174 |
+
" \n",
|
175 |
+
" summary = \"\\n\".join([\".\".join(t.split('.')[:N]) for t in texts])\n",
|
176 |
+
" \n",
|
177 |
+
" return summary\n",
|
178 |
+
"\n",
|
179 |
+
" \n",
|
180 |
+
" \n",
|
181 |
+
"\n",
|
182 |
+
"scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)\n",
|
183 |
+
"\n",
|
184 |
+
"\n",
|
185 |
+
"def construct_templated_summaries(data, fn, dataset=None): \n",
|
186 |
+
" records = []\n",
|
187 |
+
" for sample in data['results']:\n",
|
188 |
+
" summary = fn(sample)\n",
|
189 |
+
" text = \"\\n\\n\".join(sample['speaker_df'].index.tolist())\n",
|
190 |
+
" record = {'id' : sample['id'], 'summary': summary, 'metadata/reranking_model' : data['metadata/reranking_model'], 'metadata/rsa_iterations' : data['metadata/reranking_model'], \"text\": text}\n",
|
191 |
+
" if dataset is not None:\n",
|
192 |
+
" record['gold'] = dataset.loc[sample[\"id\"]]['gold'].tolist()[0]\n",
|
193 |
+
" if record['gold'] is not None:\n",
|
194 |
+
" rouges = scorer.score(summary, record['gold'])\n",
|
195 |
+
" record |= {r : v.fmeasure for r, v in rouges.items()}\n",
|
196 |
+
" \n",
|
197 |
+
" \n",
|
198 |
+
" \n",
|
199 |
+
" records.append(record)\n",
|
200 |
+
" \n",
|
201 |
+
" return pd.DataFrame.from_records(records)\n",
|
202 |
+
" \n",
|
203 |
+
"\n",
|
204 |
+
" \n",
|
205 |
+
" \n"
|
206 |
+
]
|
207 |
+
},
|
208 |
+
{
|
209 |
+
"cell_type": "code",
|
210 |
+
"execution_count": null,
|
211 |
+
"id": "39c45180-a354-429c-8cde-3a7c78013cc6",
|
212 |
+
"metadata": {
|
213 |
+
"tags": []
|
214 |
+
},
|
215 |
+
"outputs": [],
|
216 |
+
"source": [
|
217 |
+
"def prepare_dataset(dataset_name, dataset_path=\"data/processed/\"):\n",
|
218 |
+
" dataset_path = Path(dataset_path)\n",
|
219 |
+
" if dataset_name == \"amazon\":\n",
|
220 |
+
" dataset = pd.read_csv(dataset_path / \"amazon_test.csv\")\n",
|
221 |
+
" elif dataset_name == \"space\":\n",
|
222 |
+
" dataset = pd.read_csv(dataset_path / \"space.csv\")\n",
|
223 |
+
" elif dataset_name == \"yelp\":\n",
|
224 |
+
" dataset = pd.read_csv(dataset_path / \"yelp_test.csv\")\n",
|
225 |
+
" elif dataset_name == \"reviews\":\n",
|
226 |
+
" dataset = pd.read_csv(dataset_path / \"test_metareviews.csv\")\n",
|
227 |
+
" else:\n",
|
228 |
+
" raise ValueError(f\"Unknown dataset {dataset_name}\")\n",
|
229 |
+
"\n",
|
230 |
+
"\n",
|
231 |
+
" return dataset\n"
|
232 |
+
]
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"cell_type": "code",
|
236 |
+
"execution_count": null,
|
237 |
+
"id": "addeda2b-71fc-4c9a-8e91-12cf70e52b1e",
|
238 |
+
"metadata": {},
|
239 |
+
"outputs": [],
|
240 |
+
"source": [
|
241 |
+
"# df = prepare_dataset('reviews')\n",
|
242 |
+
"\n",
|
243 |
+
"# for n, group in df.groupby('id'):\n",
|
244 |
+
"# for idx, row in group.iterrows():\n",
|
245 |
+
"# print(row['text'].replace('-----', \"\\n\"))\n",
|
246 |
+
"# print(\"===========\")\n",
|
247 |
+
"# break\n",
|
248 |
+
"rsa_scores_based_summaries"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "code",
|
253 |
+
"execution_count": null,
|
254 |
+
"id": "7fe212dd-63d2-44b1-b06e-7792b9d504ac",
|
255 |
+
"metadata": {
|
256 |
+
"tags": []
|
257 |
+
},
|
258 |
+
"outputs": [],
|
259 |
+
"source": [
|
260 |
+
"\n",
|
261 |
+
"\n",
|
262 |
+
"for n in [1, 2, 3, 4, 5, 6]:\n",
|
263 |
+
" for file in path.glob(\"*.pk\"):\n",
|
264 |
+
" print(file)\n",
|
265 |
+
" with file.open('rb') as fd:\n",
|
266 |
+
" data = pk.load(fd)\n",
|
267 |
+
"\n",
|
268 |
+
" Path(output_path).mkdir(parents=True, exist_ok=True)\n",
|
269 |
+
" model_name, dataset_name, decoding_config, date = str(file.stem).split('-_-')[:4]\n",
|
270 |
+
"\n",
|
271 |
+
" dataset = prepare_dataset(dataset_name, dataset_path=\"data/processed/\")\n",
|
272 |
+
" dataset = dataset.set_index('id')\n",
|
273 |
+
" \n",
|
274 |
+
" fn = lambda sample: consensus_scores_based_summaries(sample, n_consensus=n, n_dissensus=n)\n",
|
275 |
+
"\n",
|
276 |
+
" df = construct_templated_summaries(data, fn, dataset=dataset)\n",
|
277 |
+
" \n",
|
278 |
+
" df['metadata/method'] = \"Agreement\"\n",
|
279 |
+
" df['metadata/n_sentences'] = 2*n\n",
|
280 |
+
" df['metadata/n_consensus'] = n\n",
|
281 |
+
" df['metadata/n_dissensus'] = n\n",
|
282 |
+
"\n",
|
283 |
+
" name = file.stem + \"-_-\" + f\"consensus_score_based_{n}.csv\"\n",
|
284 |
+
"\n",
|
285 |
+
" if (output_path / name).exists():\n",
|
286 |
+
" df_old = pd.read_csv(output_path / name)\n",
|
287 |
+
"\n",
|
288 |
+
" for col in df.columns:\n",
|
289 |
+
" if col not in df_old.columns:\n",
|
290 |
+
" df_old[col] = float(\"nan\")\n",
|
291 |
+
"\n",
|
292 |
+
" # add entry to the dataframe\n",
|
293 |
+
" for col in df.columns:\n",
|
294 |
+
" df_old[col] = df[col]\n",
|
295 |
+
"\n",
|
296 |
+
" df = df_old\n",
|
297 |
+
"\n",
|
298 |
+
" df.to_csv(output_path / name)\n",
|
299 |
+
" \n",
|
300 |
+
" \n",
|
301 |
+
" \n",
|
302 |
+
" "
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"cell_type": "code",
|
307 |
+
"execution_count": null,
|
308 |
+
"id": "ab45111b-9c9f-44ee-8cc1-613bfa32a007",
|
309 |
+
"metadata": {
|
310 |
+
"tags": []
|
311 |
+
},
|
312 |
+
"outputs": [],
|
313 |
+
"source": [
|
314 |
+
"\n",
|
315 |
+
"for n in [1, 2, 3, 4, 5, 6]:\n",
|
316 |
+
" for file in path.glob(\"*.pk\"):\n",
|
317 |
+
" with file.open('rb') as fd:\n",
|
318 |
+
" data = pk.load(fd)\n",
|
319 |
+
"\n",
|
320 |
+
" Path(output_path).mkdir(parents=True, exist_ok=True)\n",
|
321 |
+
" model_name, dataset_name, decoding_config, date = str(file.stem).split('-_-')[:4]\n",
|
322 |
+
"\n",
|
323 |
+
" dataset = prepare_dataset(dataset_name, dataset_path=\"data/processed/\")\n",
|
324 |
+
" dataset = dataset.set_index('id')\n",
|
325 |
+
"\n",
|
326 |
+
" fn = lambda sample: rsa_scores_based_summaries(sample, n_consensus=n, n_rsa_speaker=n)\n",
|
327 |
+
" df = construct_templated_summaries(data, fn, dataset=dataset)\n",
|
328 |
+
"\n",
|
329 |
+
" df['metadata/method'] = \"Speaker+Agreement\"\n",
|
330 |
+
" df['metadata/n_sentences'] = 2*n\n",
|
331 |
+
" df['metadata/n_consensus'] = n\n",
|
332 |
+
" df['metadata/n_dissensus'] = n\n",
|
333 |
+
"\n",
|
334 |
+
" name = file.stem + \"-_-\" + f\"rsa_score_based_{n}.csv\"\n",
|
335 |
+
"\n",
|
336 |
+
" if (output_path / name).exists():\n",
|
337 |
+
" df_old = pd.read_csv(output_path / name)\n",
|
338 |
+
"\n",
|
339 |
+
" for col in df.columns:\n",
|
340 |
+
" if col not in df_old.columns:\n",
|
341 |
+
" df_old[col] = float(\"nan\")\n",
|
342 |
+
"\n",
|
343 |
+
" # add entry to the dataframe\n",
|
344 |
+
" for col in df.columns:\n",
|
345 |
+
" df_old[col] = df[col]\n",
|
346 |
+
"\n",
|
347 |
+
" df = df_old\n",
|
348 |
+
"\n",
|
349 |
+
" df.to_csv(output_path / name)"
|
350 |
+
]
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"cell_type": "code",
|
354 |
+
"execution_count": null,
|
355 |
+
"id": "8b57c318-5fc8-49fc-8746-128e1112e46a",
|
356 |
+
"metadata": {
|
357 |
+
"tags": []
|
358 |
+
},
|
359 |
+
"outputs": [],
|
360 |
+
"source": [
|
361 |
+
"\n",
|
362 |
+
"for n in [1, 2, 3, 4, 5, 6, 7, 8]:\n",
|
363 |
+
" for file in path.glob(\"*.pk\"):\n",
|
364 |
+
" with file.open('rb') as fd:\n",
|
365 |
+
" data = pk.load(fd)\n",
|
366 |
+
"\n",
|
367 |
+
" Path(output_path).mkdir(parents=True, exist_ok=True)\n",
|
368 |
+
" model_name, dataset_name, decoding_config, date = str(file.stem).split('-_-')[:4]\n",
|
369 |
+
"\n",
|
370 |
+
" dataset = prepare_dataset(dataset_name, dataset_path=\"data/processed/\")\n",
|
371 |
+
" dataset = dataset.set_index('id')\n",
|
372 |
+
"\n",
|
373 |
+
" fn = lambda sample: lead(sample, N=2*n)\n",
|
374 |
+
"\n",
|
375 |
+
"\n",
|
376 |
+
" df = construct_templated_summaries(data, fn, dataset=dataset)\n",
|
377 |
+
"\n",
|
378 |
+
" df['metadata/method'] = \"Lead\"\n",
|
379 |
+
" df['metadata/n_sentences'] = 2*n\n",
|
380 |
+
"\n",
|
381 |
+
" name = file.stem + \"-_-\" + f\"lead_{2*n}.csv\"\n",
|
382 |
+
"\n",
|
383 |
+
" if (output_path / name).exists():\n",
|
384 |
+
" df_old = pd.read_csv(output_path / name)\n",
|
385 |
+
"\n",
|
386 |
+
" for col in df.columns:\n",
|
387 |
+
" if col not in df_old.columns:\n",
|
388 |
+
" df_old[col] = float(\"nan\")\n",
|
389 |
+
"\n",
|
390 |
+
" # add entry to the dataframe\n",
|
391 |
+
" for col in df.columns:\n",
|
392 |
+
" df_old[col] = df[col]\n",
|
393 |
+
"\n",
|
394 |
+
" df = df_old\n",
|
395 |
+
"\n",
|
396 |
+
" df.to_csv(output_path / name)"
|
397 |
+
]
|
398 |
+
},
|
399 |
+
{
|
400 |
+
"cell_type": "code",
|
401 |
+
"execution_count": null,
|
402 |
+
"id": "af574823-667d-4722-90bc-2bb095ad3a01",
|
403 |
+
"metadata": {},
|
404 |
+
"outputs": [],
|
405 |
+
"source": []
|
406 |
+
},
|
407 |
+
{
|
408 |
+
"cell_type": "code",
|
409 |
+
"execution_count": null,
|
410 |
+
"id": "c32d1d88-c6f5-4ada-aec6-219d90cade16",
|
411 |
+
"metadata": {
|
412 |
+
"tags": []
|
413 |
+
},
|
414 |
+
"outputs": [],
|
415 |
+
"source": [
|
416 |
+
"import seaborn as sns\n",
|
417 |
+
"import matplotlib.pyplot as plt"
|
418 |
+
]
|
419 |
+
},
|
420 |
+
{
|
421 |
+
"cell_type": "code",
|
422 |
+
"execution_count": null,
|
423 |
+
"id": "868ac37e-6187-46f5-935c-111ca532b1b0",
|
424 |
+
"metadata": {
|
425 |
+
"tags": []
|
426 |
+
},
|
427 |
+
"outputs": [],
|
428 |
+
"source": [
|
429 |
+
"output_path = Path(\"output/summaries/methods_reviews/\")"
|
430 |
+
]
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"cell_type": "code",
|
434 |
+
"execution_count": null,
|
435 |
+
"id": "7263e9d0-6a43-4698-bf6f-82fff0839316",
|
436 |
+
"metadata": {
|
437 |
+
"tags": []
|
438 |
+
},
|
439 |
+
"outputs": [],
|
440 |
+
"source": [
|
441 |
+
"import subprocess\n",
|
442 |
+
"\n",
|
443 |
+
"\n",
|
444 |
+
"for file in output_path.glob(\"*.csv\"):\n",
|
445 |
+
" print(file)\n",
|
446 |
+
" cmd = [\"python\", \"mds/evaluate_bartbert_metrics.py\", \"--summaries\", file]\n",
|
447 |
+
" subprocess.run(cmd)"
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "code",
|
452 |
+
"execution_count": null,
|
453 |
+
"id": "18a63537-c44e-421c-beff-50c6518115bf",
|
454 |
+
"metadata": {
|
455 |
+
"tags": []
|
456 |
+
},
|
457 |
+
"outputs": [],
|
458 |
+
"source": [
|
459 |
+
"dfs = []\n",
|
460 |
+
"for file in output_path.glob(\"*.csv\"):\n",
|
461 |
+
" model_name, dataset_name, decoding_config, date = str(file.stem).split('-_-')[:4]\n",
|
462 |
+
" method = str(file.stem).split('-_-')[-1]\n",
|
463 |
+
" \n",
|
464 |
+
" df = pd.read_csv(file)\n",
|
465 |
+
" df['metadata/Model'] = model_name\n",
|
466 |
+
" df['metadata/Dataset'] = dataset_name\n",
|
467 |
+
" df['metadata/method'] = method\n",
|
468 |
+
" \n",
|
469 |
+
" df[\"Method\"] = f\"{model_name}/{method}\"\n",
|
470 |
+
" \n",
|
471 |
+
" dfs.append(df)\n",
|
472 |
+
" \n",
|
473 |
+
"df = pd.concat(dfs)\n",
|
474 |
+
" \n",
|
475 |
+
" \n",
|
476 |
+
"df"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
{
|
480 |
+
"cell_type": "code",
|
481 |
+
"execution_count": null,
|
482 |
+
"id": "eef1ec67-62f9-458f-9380-debe40bac46a",
|
483 |
+
"metadata": {
|
484 |
+
"tags": []
|
485 |
+
},
|
486 |
+
"outputs": [],
|
487 |
+
"source": [
|
488 |
+
"sns.catplot(data=df, hue='Method', y='rougeL', x='metadata/Dataset', kind='bar')"
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"cell_type": "code",
|
493 |
+
"execution_count": null,
|
494 |
+
"id": "aed3e026-fd89-416f-a333-c841eaf566e4",
|
495 |
+
"metadata": {},
|
496 |
+
"outputs": [],
|
497 |
+
"source": [
|
498 |
+
"sns.catplot(data=df, hue='metadata/method', y='rouge1', x='metadata/reranking_model', kind='bar', row=\"metadata/model\")"
|
499 |
+
]
|
500 |
+
},
|
501 |
+
{
|
502 |
+
"cell_type": "code",
|
503 |
+
"execution_count": null,
|
504 |
+
"id": "2e433925-6be6-4322-af07-45b4c07ff5ff",
|
505 |
+
"metadata": {},
|
506 |
+
"outputs": [],
|
507 |
+
"source": []
|
508 |
+
}
|
509 |
+
],
|
510 |
+
"metadata": {
|
511 |
+
"kernelspec": {
|
512 |
+
"display_name": "pytorch-gpu-2.0.0_py3.10.9",
|
513 |
+
"language": "python",
|
514 |
+
"name": "module-conda-env-pytorch-gpu-2.0.0_py3.10.9"
|
515 |
+
},
|
516 |
+
"language_info": {
|
517 |
+
"codemirror_mode": {
|
518 |
+
"name": "ipython",
|
519 |
+
"version": 3
|
520 |
+
},
|
521 |
+
"file_extension": ".py",
|
522 |
+
"mimetype": "text/x-python",
|
523 |
+
"name": "python",
|
524 |
+
"nbconvert_exporter": "python",
|
525 |
+
"pygments_lexer": "ipython3",
|
526 |
+
"version": "3.10.9"
|
527 |
+
}
|
528 |
+
},
|
529 |
+
"nbformat": 4,
|
530 |
+
"nbformat_minor": 5
|
531 |
+
}
|
glimpse-ui/glimpse/mds/discriminative_classification.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import argparse
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from sentence_transformers import SentenceTransformer
|
10 |
+
|
11 |
+
def xlogx(x):
|
12 |
+
if x == 0:
|
13 |
+
return 0
|
14 |
+
else:
|
15 |
+
return x * torch.log(x)
|
16 |
+
|
17 |
+
def parse_summaries(path : Path):
|
18 |
+
|
19 |
+
# Load the data
|
20 |
+
df = pd.read_csv(path)
|
21 |
+
|
22 |
+
if 'id' not in df.columns:
|
23 |
+
raise ValueError('id column not found in the summaries file')
|
24 |
+
if 'text' not in df.columns:
|
25 |
+
raise ValueError('text column not found in the summaries file')
|
26 |
+
if 'summary' not in df.columns:
|
27 |
+
raise ValueError('summary column not found in the summaries file')
|
28 |
+
|
29 |
+
return df
|
30 |
+
|
31 |
+
|
32 |
+
def embed_text_and_summaries(df : pd.DataFrame, model : SentenceTransformer) -> Tuple[torch.Tensor, torch.Tensor]:
|
33 |
+
|
34 |
+
text_embeddings = model.encode(df.text.tolist(), convert_to_tensor=True)
|
35 |
+
summary_embeddings = model.encode(df.summary.tolist(), convert_to_tensor=True)
|
36 |
+
|
37 |
+
return text_embeddings, summary_embeddings
|
38 |
+
|
39 |
+
|
40 |
+
def compute_dot_products(df : pd.DataFrame, text_embeddings : torch.Tensor, summary_embeddings : torch.Tensor):
|
41 |
+
|
42 |
+
df = df.reset_index()
|
43 |
+
df['index'] = df.index
|
44 |
+
|
45 |
+
# group by id
|
46 |
+
grouped = df.groupby('id')
|
47 |
+
|
48 |
+
# for each id gather the id of the text and the summary
|
49 |
+
ids_per_sample = grouped.index.apply(list).tolist()
|
50 |
+
|
51 |
+
# compute the dot product between the text and the summary
|
52 |
+
|
53 |
+
metrics = {'proba_of_success' : []}
|
54 |
+
for text_ids in ids_per_sample:
|
55 |
+
# shape (num_text, embedding_dim)
|
56 |
+
text_embedding = text_embeddings[text_ids]
|
57 |
+
summary_embedding = summary_embeddings[text_ids]
|
58 |
+
|
59 |
+
# shape (num_text, num_text=num_summary)
|
60 |
+
dot_product = torch.matmul(text_embedding, summary_embedding.T)
|
61 |
+
|
62 |
+
# apply log softmax
|
63 |
+
log_softmax = torch.nn.functional.log_softmax(dot_product, dim=0)
|
64 |
+
|
65 |
+
# num_text
|
66 |
+
log_proba_of_success = torch.diag(log_softmax).squeeze()
|
67 |
+
entropy = torch.xlogy(log_proba_of_success, log_proba_of_success).sum(0).squeeze()
|
68 |
+
|
69 |
+
metrics['proba_of_success'].extend(log_proba_of_success.tolist())
|
70 |
+
# metrics['entropy'].extend(entropy.tolist())
|
71 |
+
|
72 |
+
df['proba_of_success'] = metrics['proba_of_success']
|
73 |
+
# df['entropy'] = metrics['entropy']
|
74 |
+
|
75 |
+
return df
|
76 |
+
|
77 |
+
def parse_args():
|
78 |
+
parser = argparse.ArgumentParser()
|
79 |
+
parser.add_argument('--summaries', type=Path, required=True)
|
80 |
+
parser.add_argument('--model', type=str, default='paraphrase-MiniLM-L6-v2')
|
81 |
+
parser.add_argument('--output', type=Path, required=True)
|
82 |
+
parser.add_argument('--device', type=str, default='cuda')
|
83 |
+
|
84 |
+
args = parser.parse_args()
|
85 |
+
return args
|
86 |
+
|
87 |
+
def main():
|
88 |
+
|
89 |
+
args = parse_args()
|
90 |
+
|
91 |
+
# load the model
|
92 |
+
model = SentenceTransformer(args.model, device=args.device)
|
93 |
+
|
94 |
+
# load the summaries
|
95 |
+
df = parse_summaries(args.summaries)
|
96 |
+
|
97 |
+
# embedd the text and the summary
|
98 |
+
text_embeddings, summary_embeddings = embed_text_and_summaries(df, model)
|
99 |
+
|
100 |
+
# compute the dot product between the text and the summary
|
101 |
+
df = compute_dot_products(df, text_embeddings, summary_embeddings)
|
102 |
+
|
103 |
+
# create the output directory
|
104 |
+
args.output.mkdir(parents=True, exist_ok=True)
|
105 |
+
|
106 |
+
path = args.output / f"{args.summaries.stem}.csv"
|
107 |
+
|
108 |
+
# save the results
|
109 |
+
df.to_csv(path, index=False)
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == '__main__':
|
113 |
+
main()
|
glimpse-ui/glimpse/pyproject.toml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["hatchling"]
|
3 |
+
build-backend = "hatchling.build"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "rsasumm"
|
7 |
+
version = "0.0.1"
|
8 |
+
authors = [
|
9 |
+
]
|
10 |
+
description = ""
|
11 |
+
readme = "Readme.md"
|
12 |
+
requires-python = ">=3.10"
|
13 |
+
classifiers = [
|
14 |
+
"Programming Language :: Python :: 3",
|
15 |
+
"License :: OSI Approved :: MIT License",
|
16 |
+
"Operating System :: OS Independent",
|
17 |
+
]
|
18 |
+
|
19 |
+
[project.urls]
|
20 |
+
"Homepage" = ""
|
21 |
+
"Bug Tracker" = ""
|
glimpse-ui/glimpse/requirements
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
numpy==1.25.2
|
3 |
+
seaborn
|
4 |
+
matplotlib
|
5 |
+
gradio
|
6 |
+
pandas
|
7 |
+
datasets
|
8 |
+
nltk
|
9 |
+
SentencePiece
|
10 |
+
spacy
|
glimpse-ui/glimpse/rsasumm/__init__.py
ADDED
File without changes
|
glimpse-ui/glimpse/rsasumm/beam_search.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper
|
5 |
+
|
6 |
+
|
7 |
+
def compute_rsa_probas(
|
8 |
+
logits: torch.Tensor, prior: torch.Tensor, rationality: float = 1.0
|
9 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
10 |
+
"""
|
11 |
+
:param logits: (world_size, num_beam, vocab_size)
|
12 |
+
:param prior: (world_size, num_beam) for each beam the prior over the objects
|
13 |
+
:param rationality: rationality parameter, the higher the more rational ie the more the speaker will try to adapt
|
14 |
+
to the listener
|
15 |
+
:return: S1, L1: (world_size, num_beam, vocab_size).
|
16 |
+
S1[o, b, w] is the (log)probability of the word w given the object o and the current partial summary for the beam b
|
17 |
+
L1[o, b, w] is the (log)probability of the object o given the word w and the current partial summary for the beam b
|
18 |
+
"""
|
19 |
+
|
20 |
+
prod = logits + prior[..., None]
|
21 |
+
|
22 |
+
L0 = torch.nan_to_num(torch.log_softmax(prod, dim=0), nan=-float("inf"))
|
23 |
+
|
24 |
+
prod_s = logits + L0 * rationality
|
25 |
+
|
26 |
+
S1 = torch.log_softmax(prod_s, dim=-1)
|
27 |
+
S1 = torch.nan_to_num(S1, nan=-float("inf"))
|
28 |
+
|
29 |
+
prod_l = logits + L0
|
30 |
+
L1 = torch.log_softmax(prod_l, dim=0)
|
31 |
+
L1 = torch.nan_to_num(L1, nan=-float("inf"))
|
32 |
+
|
33 |
+
return S1, L1
|
34 |
+
|
35 |
+
|
36 |
+
def sample_from_probs(
|
37 |
+
logits: torch.Tensor, num_beams: torch.Tensor, do_sample: bool, K: int = 10
|
38 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
39 |
+
"""
|
40 |
+
|
41 |
+
:param logits: (num_beams, vocab_size) log proba for the next token only for the wanted object
|
42 |
+
:param num_beams: number of beam to sample. (Can be different from the shape of logits since some beams might have
|
43 |
+
finished earlier)
|
44 |
+
:param do_sample: sample or use argmax
|
45 |
+
:param K: number of samples to draw per beam to create the new population
|
46 |
+
:return: idx_beam, idx_token, tokens_scores, the indices of the sampled tokens and their scores
|
47 |
+
"""
|
48 |
+
|
49 |
+
vocab_size = logits.shape[-1]
|
50 |
+
if do_sample:
|
51 |
+
# sample from the probability distribution
|
52 |
+
logits = logits.view(num_beams * logits.shape[-1])
|
53 |
+
probs = torch.softmax(logits, dim=-1)
|
54 |
+
samples = torch.multinomial(probs, num_samples=K * num_beams)
|
55 |
+
|
56 |
+
# get the indices of the sampled tokens
|
57 |
+
idx_beam, idx_token = samples // vocab_size, samples % vocab_size
|
58 |
+
|
59 |
+
logits = logits.view(num_beams * vocab_size)
|
60 |
+
|
61 |
+
tokens_scores = logits.gather(dim=-1, index=samples).squeeze(-1)
|
62 |
+
|
63 |
+
return idx_beam, idx_token, tokens_scores
|
64 |
+
|
65 |
+
else:
|
66 |
+
# get the indices of the most probable tokens
|
67 |
+
num_beams = logits.shape[0]
|
68 |
+
vocab_size = logits.shape[-1]
|
69 |
+
|
70 |
+
logits = logits.view(num_beams * vocab_size)
|
71 |
+
scores, samples = logits.topk(2 * num_beams, dim=-1)
|
72 |
+
|
73 |
+
idx_beam, idx_token = samples // vocab_size, samples % vocab_size
|
74 |
+
|
75 |
+
tokens_scores = scores.squeeze(-1)
|
76 |
+
|
77 |
+
return idx_beam, idx_token, tokens_scores
|
78 |
+
|
79 |
+
|
80 |
+
# Beam search RSA decoding
|
81 |
+
class RSAContextualDecoding:
|
82 |
+
def __init__(self, model, tokenizer, device):
|
83 |
+
"""
|
84 |
+
|
85 |
+
:param model:
|
86 |
+
:param tokenizer:
|
87 |
+
:param device:
|
88 |
+
"""
|
89 |
+
|
90 |
+
self.model = model.to(device)
|
91 |
+
self.tokenizer = tokenizer
|
92 |
+
self.device = device
|
93 |
+
|
94 |
+
def fwd_pass(
|
95 |
+
self,
|
96 |
+
input_ids: torch.Tensor,
|
97 |
+
decoder_input_ids: torch.Tensor,
|
98 |
+
attention_mask: torch.Tensor,
|
99 |
+
decoder_attention_mask: torch.Tensor,
|
100 |
+
) -> torch.Tensor:
|
101 |
+
"""
|
102 |
+
Make a forward pass through the model to get the logits for the next tokens
|
103 |
+
:param input_ids: (world_size, num_beams, input_length)
|
104 |
+
:param decoder_input_ids: (world_size, num_beams, partial_target_length)
|
105 |
+
:param attention_mask: (world_size, num_beams, input_length)
|
106 |
+
:param decoder_attention_mask: (world_size, num_beams, partial_target_length)
|
107 |
+
:return: logits: (world_size, num_beams, vocab_size)
|
108 |
+
"""
|
109 |
+
with torch.no_grad():
|
110 |
+
world_size, num_beams = input_ids.shape[0], decoder_input_ids.shape[1]
|
111 |
+
|
112 |
+
input_ids = input_ids.view(world_size * num_beams, input_ids.shape[2]).to(self.device)
|
113 |
+
attention_mask = attention_mask.view(
|
114 |
+
world_size * num_beams, attention_mask.shape[2]
|
115 |
+
).to(self.device)
|
116 |
+
|
117 |
+
decoder_input_ids = decoder_input_ids.view(
|
118 |
+
world_size * num_beams, decoder_input_ids.shape[2]
|
119 |
+
).to(self.device)
|
120 |
+
|
121 |
+
decoder_attention_mask = decoder_attention_mask.view(
|
122 |
+
world_size * num_beams, decoder_attention_mask.shape[2]
|
123 |
+
).to(self.device)
|
124 |
+
|
125 |
+
outputs = self.model(
|
126 |
+
input_ids=input_ids,
|
127 |
+
attention_mask=attention_mask,
|
128 |
+
decoder_input_ids=decoder_input_ids,
|
129 |
+
decoder_attention_mask=decoder_attention_mask,
|
130 |
+
)
|
131 |
+
logits = outputs.logits[..., -1, :]
|
132 |
+
|
133 |
+
logits = logits.view(self.world_size, num_beams, logits.shape[-1])
|
134 |
+
|
135 |
+
# return the probability of the next token when conditioned on the source text (world_size)
|
136 |
+
# and the partial target text (num_beam)
|
137 |
+
return logits
|
138 |
+
|
139 |
+
def duplicate_and_align_input_ids(
|
140 |
+
self,
|
141 |
+
input_ids: torch.Tensor,
|
142 |
+
input_ids_mask: torch.Tensor,
|
143 |
+
decoder_input_ids: torch.Tensor,
|
144 |
+
decoder_input_ids_mask: torch.Tensor,
|
145 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
146 |
+
"""
|
147 |
+
Duplicate the input_ids and decoder_input_ids to have all pairs of input_ids[i] and decoder_input_ids[j]
|
148 |
+
It uses torch.repeat and torch.repeat_interleave to do get something like:
|
149 |
+
a 1
|
150 |
+
a 2
|
151 |
+
a 3
|
152 |
+
b 1
|
153 |
+
b 2
|
154 |
+
b 3
|
155 |
+
...
|
156 |
+
:param input_ids: (world_size, input_length)
|
157 |
+
:param decoder_input_ids: (num_beam, partial_target_length)
|
158 |
+
:return: input_ids: (world_size, num_beam, input_length)
|
159 |
+
decoder_input_ids: (world_size, num_beam, partial_target_length)
|
160 |
+
aligned such that all pairs of input_ids[i] and decoder_input_ids[j] are present
|
161 |
+
"""
|
162 |
+
|
163 |
+
num_beams = decoder_input_ids.shape[0]
|
164 |
+
|
165 |
+
input_ids = input_ids.unsqueeze(1).repeat(1, num_beams, 1)
|
166 |
+
input_ids_mask = input_ids_mask.unsqueeze(1).repeat(1, num_beams, 1)
|
167 |
+
|
168 |
+
# repeat interleave
|
169 |
+
decoder_input_ids = decoder_input_ids.repeat_interleave(self.world_size, dim=0)
|
170 |
+
decoder_input_ids_mask = decoder_input_ids_mask.repeat_interleave(
|
171 |
+
self.world_size, dim=0
|
172 |
+
)
|
173 |
+
|
174 |
+
decoder_input_ids = decoder_input_ids.view(self.world_size, num_beams, -1)
|
175 |
+
decoder_input_ids_mask = decoder_input_ids_mask.view(
|
176 |
+
self.world_size, num_beams, -1
|
177 |
+
)
|
178 |
+
|
179 |
+
# print(self.tokenizer.batch_decode(input_ids[0]))
|
180 |
+
# print(self.tokenizer.batch_decode(decoder_input_ids[0]))
|
181 |
+
|
182 |
+
return input_ids, input_ids_mask, decoder_input_ids, decoder_input_ids_mask
|
183 |
+
|
184 |
+
def compute_rsa_probas(
|
185 |
+
self,
|
186 |
+
input_ids: torch.Tensor,
|
187 |
+
attention_mask: torch.Tensor,
|
188 |
+
decoder_input_ids: torch.Tensor,
|
189 |
+
decoder_attention_mask: torch.Tensor,
|
190 |
+
do_sample: bool = True,
|
191 |
+
top_p: Optional[float] = None,
|
192 |
+
top_k: Optional[int] = None,
|
193 |
+
temperature: float = 1.0,
|
194 |
+
rationality: float = 8.0, # seems to be a good value
|
195 |
+
process_logits_before_rsa: bool = True,
|
196 |
+
beam_scores: torch.Tensor = None,
|
197 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
198 |
+
"""
|
199 |
+
|
200 |
+
:param input_ids: input_ids to the encoder/decoder model = source texts
|
201 |
+
:param attention_mask: attention_mask to the encoder/decoder model
|
202 |
+
:param decoder_input_ids: decoder ids / partial summaries
|
203 |
+
:param decoder_attention_mask: attention mask for the decoder
|
204 |
+
:param do_sample: are we planning on sampling the tokens or using argmax (to apply or not the logits processor)
|
205 |
+
:param top_p: parameters for the logits processor top p
|
206 |
+
:param top_k: parameters for the logits processor top k
|
207 |
+
:param temperature: sampling temperature
|
208 |
+
:param rationality: how rational is the speaker (higher means more rational)
|
209 |
+
:param process_logits_before_rsa: should we apply the logits processor before or after the RSA computation
|
210 |
+
:param beam_scores: (world_size, num_beams) the scores of the beams to be added to the logits
|
211 |
+
:return: S1, L1: (world_size, num_beam, vocab_size).
|
212 |
+
"""
|
213 |
+
|
214 |
+
# some sanity checks
|
215 |
+
assert (top_p is None) or (
|
216 |
+
top_k is None
|
217 |
+
), "top_p and top_k cannot be used together"
|
218 |
+
assert ((top_p is not None) and (do_sample)) or (
|
219 |
+
top_p is None
|
220 |
+
), "top_p can only be used with sampling"
|
221 |
+
assert ((top_k is not None) and (do_sample)) or (
|
222 |
+
top_k is None
|
223 |
+
), "top_k can only be used with sampling"
|
224 |
+
|
225 |
+
# duplicate the input_ids and decoder_input_ids to have all pairs of input_ids[i] and decoder_input_ids[j]
|
226 |
+
(
|
227 |
+
input_ids,
|
228 |
+
attention_mask,
|
229 |
+
decoder_input_ids,
|
230 |
+
decoder_attention_mask,
|
231 |
+
) = self.duplicate_and_align_input_ids(
|
232 |
+
input_ids,
|
233 |
+
attention_mask,
|
234 |
+
decoder_input_ids,
|
235 |
+
decoder_attention_mask,
|
236 |
+
)
|
237 |
+
|
238 |
+
logits = (
|
239 |
+
self.fwd_pass(
|
240 |
+
input_ids, decoder_input_ids, attention_mask, decoder_attention_mask
|
241 |
+
)
|
242 |
+
/ temperature # apply the temperature
|
243 |
+
)
|
244 |
+
|
245 |
+
logits = torch.nn.functional.log_softmax(logits, dim=-1)
|
246 |
+
|
247 |
+
world_size = input_ids.shape[0]
|
248 |
+
num_beams = decoder_input_ids.shape[1]
|
249 |
+
|
250 |
+
logits = logits.view(world_size * num_beams, -1)
|
251 |
+
|
252 |
+
if do_sample and process_logits_before_rsa:
|
253 |
+
if top_p is not None:
|
254 |
+
logits = TopPLogitsWarper(top_p=top_p)(input_ids=None, scores=logits)
|
255 |
+
if top_k is not None:
|
256 |
+
logits = TopKLogitsWarper(top_k=top_k)(input_ids=None, scores=logits)
|
257 |
+
|
258 |
+
logits = logits.view(world_size, num_beams, -1)
|
259 |
+
|
260 |
+
if beam_scores is not None:
|
261 |
+
logits = logits + beam_scores[None, ..., None]
|
262 |
+
|
263 |
+
# compute the RSA probabilities
|
264 |
+
S1, L1 = compute_rsa_probas(logits, self.prior, rationality=rationality)
|
265 |
+
logits = S1
|
266 |
+
|
267 |
+
if do_sample and not process_logits_before_rsa:
|
268 |
+
logits = logits.view(world_size * num_beams, -1)
|
269 |
+
if top_p is not None:
|
270 |
+
logits = TopPLogitsWarper(top_p=top_p)(input_ids=None, scores=logits)
|
271 |
+
if top_k is not None:
|
272 |
+
logits = TopKLogitsWarper(top_k=top_k)(input_ids=None, scores=logits)
|
273 |
+
|
274 |
+
logits = logits.view(world_size, num_beams, -1)
|
275 |
+
|
276 |
+
return logits, L1
|
277 |
+
|
278 |
+
def generate(
|
279 |
+
self,
|
280 |
+
target_id: int,
|
281 |
+
source_texts_ids: torch.Tensor,
|
282 |
+
source_text_attention_mask: torch.Tensor,
|
283 |
+
max_length: int = 100,
|
284 |
+
num_beams: int = 8,
|
285 |
+
do_sample=True,
|
286 |
+
top_p: Optional[float] = None,
|
287 |
+
top_k: Optional[int] = None,
|
288 |
+
temperature: float = 1.0,
|
289 |
+
rationality: float = 1.0,
|
290 |
+
process_logits_before_rsa=True,
|
291 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
292 |
+
"""
|
293 |
+
|
294 |
+
:param target_id: the id of the target object
|
295 |
+
:param source_texts_ids: (world_size, input_length) the tokenized source texts
|
296 |
+
:param source_text_attention_mask: (world_size, input_length) the attention mask for the source texts
|
297 |
+
:param max_length: the maximum length to generate
|
298 |
+
:param do_sample: are we sampling or using argmax
|
299 |
+
:param top_p: parameters for the logits processor top p
|
300 |
+
:param top_k: parameters for the logits processor top k
|
301 |
+
:param temperature: sampling temperature
|
302 |
+
:param rationality: how rational is the speaker (higher means more rational)
|
303 |
+
:param process_logits_before_rsa: should we apply the logits processor before or after the RSA computation
|
304 |
+
:return: decoder_input_ids : (num_beams, max_length) decoded sequences, beam_scores: (num_beams) the scores
|
305 |
+
of the beams
|
306 |
+
"""
|
307 |
+
|
308 |
+
self.num_beam = num_beams
|
309 |
+
self.world_size = source_texts_ids.shape[0]
|
310 |
+
|
311 |
+
self.prior = torch.ones((self.world_size, self.num_beam)).to(self.device) / self.world_size
|
312 |
+
beam_scores = torch.zeros(self.num_beam).to(self.device)
|
313 |
+
|
314 |
+
# initialize the decoder input ids
|
315 |
+
decoder_input_ids = torch.full(
|
316 |
+
(self.num_beam, 2),
|
317 |
+
0,
|
318 |
+
dtype=torch.long,
|
319 |
+
device=self.device,
|
320 |
+
)
|
321 |
+
|
322 |
+
# initialize the decoder attention mask
|
323 |
+
decoder_attention_mask = torch.ones_like(decoder_input_ids).to(self.device)
|
324 |
+
|
325 |
+
new_beams = []
|
326 |
+
finished_beams = []
|
327 |
+
|
328 |
+
# run the beam search
|
329 |
+
for t in range(max_length):
|
330 |
+
# compute the RSA probabilities
|
331 |
+
num_beams = decoder_input_ids.shape[0]
|
332 |
+
|
333 |
+
S1, L1 = self.compute_rsa_probas(
|
334 |
+
source_texts_ids,
|
335 |
+
source_text_attention_mask,
|
336 |
+
decoder_input_ids,
|
337 |
+
decoder_attention_mask,
|
338 |
+
do_sample=do_sample,
|
339 |
+
top_p=top_p,
|
340 |
+
top_k=top_k,
|
341 |
+
temperature=temperature,
|
342 |
+
rationality=rationality,
|
343 |
+
beam_scores=beam_scores,
|
344 |
+
process_logits_before_rsa=process_logits_before_rsa,
|
345 |
+
)
|
346 |
+
|
347 |
+
# sample from the probabilities
|
348 |
+
idx_beam, idx_token, tokens_scores = sample_from_probs(
|
349 |
+
S1[target_id].squeeze(), num_beams, do_sample
|
350 |
+
)
|
351 |
+
|
352 |
+
# create all the new beams
|
353 |
+
|
354 |
+
new_beams = []
|
355 |
+
|
356 |
+
for idx_t, idx_b, token_score in zip(idx_token, idx_beam, tokens_scores):
|
357 |
+
new_beams.append(
|
358 |
+
(
|
359 |
+
decoder_input_ids[idx_b].tolist() + [idx_t.item()],
|
360 |
+
beam_scores[idx_b] + token_score.item(),
|
361 |
+
L1[:, idx_b, idx_t.item()],
|
362 |
+
)
|
363 |
+
)
|
364 |
+
|
365 |
+
# sort the beams
|
366 |
+
new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
|
367 |
+
|
368 |
+
# keep only the best beams
|
369 |
+
new_beams = new_beams[: self.num_beam]
|
370 |
+
|
371 |
+
# check if the beams are finished
|
372 |
+
_new_beams = []
|
373 |
+
for beam in new_beams:
|
374 |
+
if beam[0][-1] == self.tokenizer.eos_token_id:
|
375 |
+
finished_beams.append(beam)
|
376 |
+
|
377 |
+
else:
|
378 |
+
_new_beams.append(beam)
|
379 |
+
|
380 |
+
new_beams = _new_beams
|
381 |
+
|
382 |
+
if len(new_beams) == 0:
|
383 |
+
break
|
384 |
+
|
385 |
+
# pad the beams
|
386 |
+
max_beam_len = max(len(x[0]) for x in new_beams)
|
387 |
+
new_beams = [
|
388 |
+
(
|
389 |
+
x[0] + [self.tokenizer.pad_token_id] * (max_beam_len - len(x[0])),
|
390 |
+
x[1],
|
391 |
+
x[2],
|
392 |
+
)
|
393 |
+
for x in new_beams
|
394 |
+
]
|
395 |
+
|
396 |
+
# update the beam scores
|
397 |
+
beam_scores = torch.tensor([x[1] for x in new_beams]).to(self.device)
|
398 |
+
|
399 |
+
# update the decoder input ids
|
400 |
+
decoder_input_ids: torch.Tensor = torch.tensor(
|
401 |
+
[x[0] for x in new_beams], device=self.device
|
402 |
+
)
|
403 |
+
|
404 |
+
# update the decoder attention mask based on pad tokens
|
405 |
+
decoder_attention_mask = (
|
406 |
+
decoder_input_ids != self.tokenizer.pad_token_id
|
407 |
+
).long()
|
408 |
+
|
409 |
+
self.prior = torch.stack([x[2] for x in new_beams], dim=1).to(self.device)
|
410 |
+
|
411 |
+
# self.prior = torch.ones((self.world_size, len(new_beams))) / self.world_size
|
412 |
+
|
413 |
+
results = []
|
414 |
+
|
415 |
+
# pad the beams
|
416 |
+
max_beam_len = max(len(x[0]) for x in finished_beams + new_beams)
|
417 |
+
for x in finished_beams + new_beams:
|
418 |
+
results.append(
|
419 |
+
(
|
420 |
+
x[0] + [self.tokenizer.pad_token_id] * (max_beam_len - len(x[0])),
|
421 |
+
x[1],
|
422 |
+
x[2],
|
423 |
+
)
|
424 |
+
)
|
425 |
+
|
426 |
+
decoder_input_ids = torch.tensor([x[0] for x in results], device=self.device)
|
427 |
+
|
428 |
+
beam_scores = torch.tensor([x[1] for x in results]).to(self.device)
|
429 |
+
|
430 |
+
return decoder_input_ids, beam_scores
|
glimpse-ui/glimpse/rsasumm/rsa_reranker.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import cache
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import pandas as pd
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
def kl_divergence(p, q):
|
11 |
+
"""
|
12 |
+
Compute the KL divergence between two distributions
|
13 |
+
"""
|
14 |
+
return torch.nan_to_num(p * (p / q).log(), nan=0.0).sum(-1)
|
15 |
+
|
16 |
+
|
17 |
+
def jensen_shannon_divergence(p, q):
|
18 |
+
"""
|
19 |
+
Compute the Jensen-Shannon divergence between two distributions
|
20 |
+
"""
|
21 |
+
m = 0.5 * (p + q)
|
22 |
+
return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))
|
23 |
+
|
24 |
+
|
25 |
+
class RSAReranking:
|
26 |
+
"""
|
27 |
+
Rerank a list of candidates according to the RSA model.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
model,
|
33 |
+
tokenizer,
|
34 |
+
candidates: List[str],
|
35 |
+
source_texts: List[str],
|
36 |
+
batch_size: int = 32,
|
37 |
+
rationality: int = 1,
|
38 |
+
device="cuda",
|
39 |
+
):
|
40 |
+
"""
|
41 |
+
:param model: hf model used to compute the likelihoods (supposed to be a seq2seq model), is S0 in the RSA model
|
42 |
+
:param tokenizer:
|
43 |
+
:param candidates: list of candidates summaries
|
44 |
+
:param source_texts: list of source texts
|
45 |
+
:param batch_size: batch size used to compute the likelihoods (can be high since we don't need gradients and
|
46 |
+
it's a single forward pass)
|
47 |
+
:param rationality: rationality parameter of the RSA model
|
48 |
+
:param device: device used to compute the likelihoods
|
49 |
+
"""
|
50 |
+
self.model = model
|
51 |
+
self.device = device
|
52 |
+
self.model = model.to(self.device)
|
53 |
+
self.tokenizer = tokenizer
|
54 |
+
|
55 |
+
|
56 |
+
self.candidates = candidates
|
57 |
+
self.source_texts = source_texts
|
58 |
+
|
59 |
+
self.batch_size = batch_size
|
60 |
+
self.rationality = rationality
|
61 |
+
|
62 |
+
def compute_conditionned_likelihood(
|
63 |
+
self, x: List[str], y: List[str], mean: bool = True
|
64 |
+
) -> torch.Tensor:
|
65 |
+
"""
|
66 |
+
Compute the likelihood of y given x
|
67 |
+
|
68 |
+
:param x: list of source texts len(x) = batch_size
|
69 |
+
:param y: list of candidates summaries len(y) = batch_size
|
70 |
+
:param mean: average the likelihoods over the tokens of y or take the sum
|
71 |
+
:return: tensor of shape (batch_size) containing the likelihoods of y given x
|
72 |
+
"""
|
73 |
+
|
74 |
+
# Ensure x,y are pure Python lists of strings (not pandas.Series, np.ndarray, etc.)
|
75 |
+
x = [str(item) for item in list(x)]
|
76 |
+
y = [str(item) for item in list(y)]
|
77 |
+
assert len(x) == len(y), "x and y must have the same length"
|
78 |
+
|
79 |
+
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
|
80 |
+
batch_size = len(x)
|
81 |
+
|
82 |
+
x = self.tokenizer(
|
83 |
+
x,
|
84 |
+
return_tensors="pt",
|
85 |
+
padding=True,
|
86 |
+
truncation=True,
|
87 |
+
max_length=1024,
|
88 |
+
)
|
89 |
+
y = self.tokenizer(
|
90 |
+
y,
|
91 |
+
return_tensors="pt",
|
92 |
+
padding=True,
|
93 |
+
truncation=True,
|
94 |
+
max_length=1024,
|
95 |
+
)
|
96 |
+
|
97 |
+
# Move all tensors to the correct device
|
98 |
+
x = {k: v.to(self.device) for k, v in x.items()}
|
99 |
+
y = {k: v.to(self.device) for k, v in y.items()}
|
100 |
+
|
101 |
+
# Concatenate the two inputs
|
102 |
+
# Compute the likelihood of y given x
|
103 |
+
|
104 |
+
x_ids = x["input_ids"]
|
105 |
+
y_ids = y["input_ids"]
|
106 |
+
|
107 |
+
logits = self.model(
|
108 |
+
input_ids=x_ids,
|
109 |
+
decoder_input_ids=y_ids,
|
110 |
+
attention_mask=x["attention_mask"],
|
111 |
+
decoder_attention_mask=y["attention_mask"],
|
112 |
+
).logits
|
113 |
+
|
114 |
+
# Compute the likelihood of y given x
|
115 |
+
|
116 |
+
shifted_logits = logits[..., :-1, :].contiguous()
|
117 |
+
shifted_ids = y_ids[..., 1:].contiguous()
|
118 |
+
|
119 |
+
likelihood = -loss_fn(
|
120 |
+
shifted_logits.view(-1, shifted_logits.size(-1)), shifted_ids.view(-1)
|
121 |
+
)
|
122 |
+
|
123 |
+
likelihood = likelihood.view(batch_size, -1).sum(-1)
|
124 |
+
if mean:
|
125 |
+
likelihood /= (y_ids != self.tokenizer.pad_token_id).float().sum(-1)
|
126 |
+
|
127 |
+
return likelihood
|
128 |
+
|
129 |
+
def score(self, x: List[str], y: List[str], **kwargs):
|
130 |
+
return self.compute_conditionned_likelihood(x, y, **kwargs)
|
131 |
+
|
132 |
+
def likelihood_matrix(self) -> torch.Tensor:
|
133 |
+
"""
|
134 |
+
:return: likelihood matrix : (world_size, num_candidates), likelihood[i, j] is the likelihood of
|
135 |
+
candidate j being a summary for source text i.
|
136 |
+
"""
|
137 |
+
likelihood_matrix = torch.zeros(
|
138 |
+
(len(self.source_texts), len(self.candidates))
|
139 |
+
).to(self.device)
|
140 |
+
|
141 |
+
pairs = []
|
142 |
+
for i, source_text in enumerate(self.source_texts):
|
143 |
+
for j, candidate in enumerate(self.candidates):
|
144 |
+
pairs.append((i, j, source_text, candidate))
|
145 |
+
|
146 |
+
# split the pairs into batches
|
147 |
+
batches = [
|
148 |
+
pairs[i: i + self.batch_size]
|
149 |
+
for i in range(0, len(pairs), self.batch_size)
|
150 |
+
]
|
151 |
+
|
152 |
+
for batch in tqdm(batches):
|
153 |
+
# get the source texts and candidates
|
154 |
+
source_texts = [pair[2] for pair in batch]
|
155 |
+
candidates = [pair[3] for pair in batch]
|
156 |
+
|
157 |
+
# compute the likelihoods
|
158 |
+
with torch.no_grad():
|
159 |
+
likelihoods = self.score(
|
160 |
+
source_texts, candidates, mean=True
|
161 |
+
)
|
162 |
+
|
163 |
+
# fill the matrix
|
164 |
+
for k, (i, j, _, _) in enumerate(batch):
|
165 |
+
likelihood_matrix[i, j] = likelihoods[k].detach()
|
166 |
+
|
167 |
+
return likelihood_matrix
|
168 |
+
|
169 |
+
@cache
|
170 |
+
def S(self, t):
|
171 |
+
if t == 0:
|
172 |
+
return self.initial_speaker_probas
|
173 |
+
else:
|
174 |
+
listener = self.L(t - 1)
|
175 |
+
prod = listener * self.rationality # + self.initial_speaker_probas.sum(0, keepdim=True)
|
176 |
+
return torch.log_softmax(prod, dim=-1)
|
177 |
+
|
178 |
+
@cache
|
179 |
+
def L(self, t):
|
180 |
+
speaker = self.S(t)
|
181 |
+
return torch.log_softmax(speaker, dim=-2)
|
182 |
+
|
183 |
+
def mk_listener_dataframe(self, t):
|
184 |
+
self.initial_speaker_probas = self.likelihood_matrix()
|
185 |
+
|
186 |
+
initial_listener_probas = self.L(0)
|
187 |
+
|
188 |
+
# compute consensus
|
189 |
+
uniform_distribution_over_source_texts = torch.ones_like(
|
190 |
+
initial_listener_probas
|
191 |
+
) / len(self.source_texts)
|
192 |
+
|
193 |
+
initital_consensuality_score = (
|
194 |
+
torch.exp(initial_listener_probas)
|
195 |
+
* (
|
196 |
+
initial_listener_probas - torch.log(uniform_distribution_over_source_texts)
|
197 |
+
)
|
198 |
+
).sum(0).cpu().numpy()
|
199 |
+
|
200 |
+
initital_consensuality_score = pd.Series(initital_consensuality_score, index=self.candidates)
|
201 |
+
|
202 |
+
initial_listener_probas = initial_listener_probas.cpu().numpy()
|
203 |
+
|
204 |
+
initial_listener_probas = pd.DataFrame(initial_listener_probas)
|
205 |
+
initial_listener_probas.index = self.source_texts
|
206 |
+
initial_listener_probas.columns = self.candidates
|
207 |
+
|
208 |
+
initial_speaker_probas = self.S(0).cpu().numpy()
|
209 |
+
initial_speaker_probas = pd.DataFrame(initial_speaker_probas)
|
210 |
+
initial_speaker_probas.index = self.source_texts
|
211 |
+
initial_speaker_probas.columns = self.candidates
|
212 |
+
|
213 |
+
listener_df = pd.DataFrame(self.L(t).cpu().numpy())
|
214 |
+
|
215 |
+
consensuality_scores = (
|
216 |
+
torch.exp(self.L(t))
|
217 |
+
* (self.L(t) - torch.log(uniform_distribution_over_source_texts))
|
218 |
+
).sum(0).cpu().numpy()
|
219 |
+
|
220 |
+
consensuality_scores = pd.Series(consensuality_scores, index=self.candidates)
|
221 |
+
|
222 |
+
S = self.S(t).cpu().numpy()
|
223 |
+
speaker_df = pd.DataFrame(S)
|
224 |
+
|
225 |
+
# add the source texts and candidates as index
|
226 |
+
|
227 |
+
listener_df.index = self.source_texts
|
228 |
+
speaker_df.index = self.source_texts
|
229 |
+
|
230 |
+
listener_df.columns = self.candidates
|
231 |
+
speaker_df.columns = self.candidates
|
232 |
+
|
233 |
+
return listener_df, speaker_df, initial_listener_probas, initial_speaker_probas, initital_consensuality_score, consensuality_scores
|
234 |
+
|
235 |
+
def rerank(self, t=1):
|
236 |
+
"""
|
237 |
+
return the best summary (according to rsa) for each text
|
238 |
+
"""
|
239 |
+
(
|
240 |
+
listener_df,
|
241 |
+
speaker_df,
|
242 |
+
initial_listener_proba,
|
243 |
+
initial_speaker_proba,
|
244 |
+
initital_consensuality_score,
|
245 |
+
consensuality_scores,
|
246 |
+
) = self.mk_listener_dataframe(t=t)
|
247 |
+
best_rsa = speaker_df.idxmax(axis=1).values
|
248 |
+
best_base = initial_listener_proba.idxmax(axis=1).values
|
249 |
+
|
250 |
+
return (
|
251 |
+
best_rsa,
|
252 |
+
best_base,
|
253 |
+
speaker_df,
|
254 |
+
listener_df,
|
255 |
+
initial_listener_proba,
|
256 |
+
initial_speaker_proba,
|
257 |
+
initital_consensuality_score,
|
258 |
+
consensuality_scores,
|
259 |
+
)
|
260 |
+
|
261 |
+
|
262 |
+
class RSARerankingEmbedder(RSAReranking):
|
263 |
+
def __init__(self, *args, **kwargs):
|
264 |
+
super().__init__(*args, **kwargs)
|
265 |
+
|
266 |
+
def compute_embeddings(self, x: List[str], y: List[str], **kwargs):
|
267 |
+
model_kwargs = kwargs.get("model_kwargs")
|
268 |
+
|
269 |
+
# shape: (batch_size, embedding_dim)
|
270 |
+
x_embeddings = self.model.encode(x, **model_kwargs)
|
271 |
+
y_embeddings = self.model.encode(y, **model_kwargs)
|
272 |
+
|
273 |
+
# dot product between the embeddings : shape (batch_size)
|
274 |
+
dot_products = (x_embeddings * y_embeddings).sum(-1)
|
275 |
+
|
276 |
+
return dot_products
|
277 |
+
|
278 |
+
def score(self, x: List[str], y: List[str], **kwargs):
|
279 |
+
return self.compute_embeddings(x, y, **kwargs)
|
280 |
+
|
glimpse-ui/glimpse/scripts/abstractive.sh
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=main # Ask for unkillable job
|
3 |
+
#SBATCH --gres=gpu:1
|
4 |
+
#SBATCH --mem=10G # Ask for 10 GB of RAM
|
5 |
+
#SBATCH --time=2:00:00 # The job will run for 3 hours
|
6 |
+
#SBATCH --output=./logs/abstractive_out.txt
|
7 |
+
#SBATCH --error=./logs/abstractive_error.txt
|
8 |
+
#SBATCH -c 2
|
9 |
+
|
10 |
+
|
11 |
+
# Load the required modules
|
12 |
+
module --quiet load miniconda/3
|
13 |
+
module --quiet load cuda/12.1.1
|
14 |
+
conda activate "glimpse"
|
15 |
+
|
16 |
+
# Check if input file path is provided and valid
|
17 |
+
if [ -z "$1" ] || [ ! -f "$1" ]; then
|
18 |
+
# if no path is provided, or the path is invalid, use the default test dataset
|
19 |
+
echo "Couldn't find a valid path. Using default path: data/processed/all_reviews_2017.csv"
|
20 |
+
dataset_path="data/processed/all_reviews_2017.csv"
|
21 |
+
else
|
22 |
+
dataset_path="$1"
|
23 |
+
fi
|
24 |
+
|
25 |
+
|
26 |
+
# Generate abstractive summaries
|
27 |
+
if [[ "$@" =~ "--add-padding" ]]; then # check if padding argument is present
|
28 |
+
# add '--no-trimming' flag to the script
|
29 |
+
candidates=$(python glimpse/data_loading/generate_abstractive_candidates.py --dataset_path "$dataset_path" --scripted-run --no-trimming | tail -n 1)
|
30 |
+
else
|
31 |
+
# no additional flags
|
32 |
+
candidates=$(python glimpse/data_loading/generate_abstractive_candidates.py --dataset_path "$dataset_path" --scripted-run | tail -n 1)
|
33 |
+
fi
|
34 |
+
|
35 |
+
|
36 |
+
# Compute the RSA scores based on the generated summaries
|
37 |
+
rsa_scores=$(python glimpse/src/compute_rsa.py --summaries $candidates | tail -n 1)
|
glimpse-ui/glimpse/scripts/extractive.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=main # Ask for unkillable job
|
3 |
+
#SBATCH --gres=gpu:1
|
4 |
+
#SBATCH --mem=10G # Ask for 10 GB of RAM
|
5 |
+
#SBATCH --time=2:00:00 # The job will run for 3 hours
|
6 |
+
#SBATCH --output=./logs/abstractive_out.txt
|
7 |
+
#SBATCH --error=./logs/abstractive_error.txt
|
8 |
+
#SBATCH -c 2
|
9 |
+
|
10 |
+
|
11 |
+
# Load the required modules
|
12 |
+
module --quiet load miniconda/3
|
13 |
+
module --quiet load cuda/12.1.1
|
14 |
+
conda activate "glimpse"
|
15 |
+
|
16 |
+
|
17 |
+
# Check if input file path is provided and valid
|
18 |
+
if [ -z "$1" ] || [ ! -f "$1" ]; then
|
19 |
+
# if no path is provided, or the path is invalid, use the default test dataset
|
20 |
+
echo "Couldn't find a valid path. Using default path: data/processed/all_reviews_2017.csv"
|
21 |
+
dataset_path="data/processed/all_reviews_2017.csv"
|
22 |
+
else
|
23 |
+
dataset_path="$1"
|
24 |
+
fi
|
25 |
+
|
26 |
+
# Generate extractive summaries
|
27 |
+
candidates=$(python glimpse/data_loading/generate_extractive_candidates.py --dataset_path "$dataset_path" --scripted-run | tail -n 1)
|
28 |
+
|
29 |
+
# Compute the RSA scores based on the generated summaries
|
30 |
+
rsa_scores=$(python glimpse/src/compute_rsa.py --summaries $candidates | tail -n 1)
|
31 |
+
|
glimpse-ui/glimpse_pk_csv_converter.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
+
import os
|
5 |
+
import glob
|
6 |
+
import re
|
7 |
+
import json
|
8 |
+
|
9 |
+
def process_pickle_results(pickle_path: Path, output_path: Path):
|
10 |
+
# === Load Pickle File ===
|
11 |
+
with open(pickle_path, 'rb') as f:
|
12 |
+
data = pickle.load(f)
|
13 |
+
|
14 |
+
# === Extract Metadata ===
|
15 |
+
reranking_model = data.get('metadata/reranking_model')
|
16 |
+
rsa_iterations = data.get('metadata/rsa_iterations')
|
17 |
+
results = data.get('results')
|
18 |
+
|
19 |
+
# print(f"Reranking model: {reranking_model}, RSA iterations: {rsa_iterations}")
|
20 |
+
|
21 |
+
# === Validate Results ===
|
22 |
+
if not isinstance(results, list):
|
23 |
+
raise ValueError("The 'results' key is not a list. Please check the pickle file structure.")
|
24 |
+
|
25 |
+
# === Process and Flatten Results ===
|
26 |
+
csv_data = []
|
27 |
+
for index, result in enumerate(results):
|
28 |
+
# row = {
|
29 |
+
# 'index': index,
|
30 |
+
# 'id': str(result.get('id')[0]),
|
31 |
+
# 'consensuality_scores': result.get('consensuality_scores').to_dict()
|
32 |
+
# if isinstance(result.get('consensuality_scores'), pd.Series) else None,
|
33 |
+
|
34 |
+
# # Optional fields — uncomment as needed
|
35 |
+
# # 'best_base': result.get('best_base').tolist() if isinstance(result.get('best_base'), np.ndarray) else None,
|
36 |
+
# # 'best_rsa': result.get('best_rsa').tolist() if isinstance(result.get('best_rsa'), np.ndarray) else None,
|
37 |
+
# # 'speaker_df': result.get('speaker_df').to_json() if isinstance(result.get('speaker_df'), pd.DataFrame) else None,
|
38 |
+
# # 'listener_df': result.get('listener_df').to_json() if isinstance(result.get('listener_df'), pd.DataFrame) else None,
|
39 |
+
# # 'initial_listener': result.get('initial_listener').to_json() if isinstance(result.get('initial_listener'), pd.DataFrame) else None,
|
40 |
+
# # 'language_model_proba_df': result.get('language_model_proba_df').to_json() if isinstance(result.get('language_model_proba_df'), pd.DataFrame) else None,
|
41 |
+
# # 'initial_consensuality_scores': result.get('initial_consensuality_scores').to_dict() if isinstance(result.get('initial_consensuality_scores'), pd.Series) else None,
|
42 |
+
# # 'gold': result.get('gold'),
|
43 |
+
# # 'rationality': result.get('rationality'),
|
44 |
+
# # 'text_candidates': result.get('text_candidates').to_json() if isinstance(result.get('text_candidates'), pd.DataFrame) else None,
|
45 |
+
# }
|
46 |
+
|
47 |
+
|
48 |
+
row = {
|
49 |
+
'index': index,
|
50 |
+
'id': str(result.get('id')[0]),
|
51 |
+
'consensuality_scores': json.dumps(result.get('consensuality_scores').to_dict())
|
52 |
+
if isinstance(result.get('consensuality_scores'), pd.Series) else None,
|
53 |
+
}
|
54 |
+
|
55 |
+
csv_data.append(row)
|
56 |
+
|
57 |
+
# === Save to CSV ===
|
58 |
+
df = pd.DataFrame(csv_data)
|
59 |
+
df.to_csv(output_path, index=False)
|
60 |
+
print(f"Results saved to '{output_path}'.")
|
61 |
+
|
62 |
+
|
63 |
+
if __name__ == "__main__":
|
64 |
+
|
65 |
+
BASE_DIR = Path(__file__).resolve().parent
|
66 |
+
|
67 |
+
# Set the path to the pickle file and the output CSV file
|
68 |
+
# ==== Uncomment the appropriate line below to set the pickle file path ====
|
69 |
+
# pickle_file = BASE_DIR / "glimpse" / "output" / "extractive_sentences-_-all_reviews_2017-_-none-_-2025-05-20-20-22-18-_-r3-_-rsa_reranked-google-pegasus-arxiv.pk"
|
70 |
+
|
71 |
+
# ==== Find the latest file in the directory and use it instead ====
|
72 |
+
# This assumes the pickle files are stored in the 'glimpse/output' directory
|
73 |
+
# list_of_files = glob.glob('./glimpse/output/*.pk')
|
74 |
+
# pickle_file = max(list_of_files, key=os.path.getctime)
|
75 |
+
# print (f"Using pickle file: {pickle_file}")
|
76 |
+
|
77 |
+
# output_file = BASE_DIR / "data" / "GLIMPSE_results_from_pk.csv"
|
78 |
+
|
79 |
+
# process_pickle_results(pickle_file, output_file)
|
80 |
+
|
81 |
+
output_dir = BASE_DIR / "data"
|
82 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
83 |
+
|
84 |
+
pickle_files = sorted(glob.glob('./glimpse/output/*.pk'), key=os.path.getctime)
|
85 |
+
|
86 |
+
for pickle_file in pickle_files:
|
87 |
+
year_match = re.search(r'(\d{4})', os.path.basename(pickle_file))
|
88 |
+
year_tag = year_match.group(1) if year_match else 'unknown_year'
|
89 |
+
output_file = output_dir / f"GLIMPSE_results_{year_tag}.csv"
|
90 |
+
|
91 |
+
print(f"Using pickle file: {pickle_file}, saving as {output_file}")
|
92 |
+
process_pickle_results(Path(pickle_file), output_file)
|
glimpse-ui/interface/Demo.py
ADDED
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import sys, os.path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
7 |
+
|
8 |
+
from glimpse.rsasumm.rsa_reranker import RSAReranking
|
9 |
+
import gradio as gr
|
10 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
from scored_reviews_builder import load_scored_reviews
|
14 |
+
from glimpse.glimpse.data_loading.Glimpse_tokenizer import glimpse_tokenizer
|
15 |
+
# from scibert.scibert_polarity.scibert_polarity import predict_polarity
|
16 |
+
|
17 |
+
# Load scored reviews
|
18 |
+
years, all_scored_reviews_df = load_scored_reviews()
|
19 |
+
|
20 |
+
# -----------------------------------
|
21 |
+
# Pre-processed Tab
|
22 |
+
# -----------------------------------
|
23 |
+
|
24 |
+
def get_preprocessed_scores(year):
|
25 |
+
scored_reviews = all_scored_reviews_df[all_scored_reviews_df["year"] == year]["scored_dict"].iloc[0]
|
26 |
+
return scored_reviews
|
27 |
+
|
28 |
+
|
29 |
+
# -----------------------------------
|
30 |
+
# Interactive Tab
|
31 |
+
# -----------------------------------
|
32 |
+
|
33 |
+
# RSA_model = "facebook/bart-large-cnn"
|
34 |
+
RSA_model = "sshleifer/distilbart-cnn-12-3"
|
35 |
+
|
36 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(RSA_model)
|
37 |
+
tokenizer = AutoTokenizer.from_pretrained(RSA_model)
|
38 |
+
|
39 |
+
# Define the manual color map for topics
|
40 |
+
topic_color_map = {
|
41 |
+
"Substance": "#cce0ff", # lighter blue
|
42 |
+
"Clarity": "#e6ee9c", # lighter yellow-green
|
43 |
+
"Soundness/Correctness": "#ffcccc", # lighter red
|
44 |
+
"Originality": "#d1c4e9", # lighter purple
|
45 |
+
"Motivation/Impact": "#b2ebf2", # lighter teal
|
46 |
+
"Meaningful Comparison": "#fff9c4", # lighter yellow
|
47 |
+
"Replicability": "#c8e6c9", # lighter green
|
48 |
+
}
|
49 |
+
|
50 |
+
|
51 |
+
# GLIMPSE Home/Description Page
|
52 |
+
glimpse_description = """
|
53 |
+
# ReView: A Tool for Visualizing and Analyzing Scientific Reviews
|
54 |
+
|
55 |
+
## Overview
|
56 |
+
ReView is a visualization tool designed to assist **area chairs** and **researchers** in efficiently analyzing scholarly reviews. The interface offers two main ways to explore scholarly reviews:
|
57 |
+
- Pre-Processed Reviews: Explore real peer reviews from ICLR (2017–2021) with structured visualizations of sentiment, topics, and reviewer agreement.
|
58 |
+
- Interactive Tab: Enter your own reviews and view them analyzed in real time using the same NLP-powered highlighting options.
|
59 |
+
|
60 |
+
All reviews are shown in their original, unaltered form, with visual overlays to help identify key insights such as disagreements, sentiment and common themes—reducing cognitive load and scrolling effort.
|
61 |
+
|
62 |
+
---
|
63 |
+
## **Key Features**
|
64 |
+
- *Traceability and Transparency:* The tool preserves the original text of each review and overlays highlights for key aspects (e.g., sentiment, topic, agreement), allowing area chairs to trace back every insight to its source without modifying or summarizing the content.
|
65 |
+
- *Structured Overview*: All reviews are displayed in one interface and with radio buttons, one can navigate from one highlighting option to the other.
|
66 |
+
- *Interactive*: The tool allows users to input their own reviews and, within seconds, view them annotated with highlighted aspects
|
67 |
+
---
|
68 |
+
## **Highlighting Options**
|
69 |
+
- *Agreement:* Identifies both shared and conflicting points across reviews, helping to surface consensus and disagreement.
|
70 |
+
- *Polarity:* Highlights positive and negative sentiments within the reviews to reveal tone and stance.
|
71 |
+
- *Topic:* Organizes the review sentences by their discussed topics, ensuring coverage of diverse reviewer perspectives and improving clarity.
|
72 |
+
|
73 |
+
---
|
74 |
+
|
75 |
+
### How to Use ReView
|
76 |
+
|
77 |
+
ReView offers two main ways to explore peer reviews: using pre-processed reviews or by entering your own.
|
78 |
+
|
79 |
+
#### 🗂️ Pre-Processed Reviews Tab
|
80 |
+
|
81 |
+
Use this tab to explore reviews from ICLR (2017–2021):
|
82 |
+
|
83 |
+
1. **Select a conference year** from the dropdown menu on the right.
|
84 |
+
2. **Navigate between submissions** using the *Next* and *Previous* buttons on the left.
|
85 |
+
3. **Choose a highlighting view** using the radio buttons:
|
86 |
+
- **Original**: Displays unmodified review text.
|
87 |
+
- **Agreement**: Highlights consensus points in **red** and disagreements in **purple**.
|
88 |
+
- **Polarity**: Highlights **positive** sentiment in **green** and **negative** sentiment in **red**.
|
89 |
+
- **Topic**: Highlights comments by discussion topic using color-coded labels.
|
90 |
+
|
91 |
+
#### ✍️ Interactive Tab
|
92 |
+
|
93 |
+
Use this tab to analyze your own review text:
|
94 |
+
|
95 |
+
1. **Enter up to three reviews** in the input fields labeled *Review 1*, *Review 2*, and *Review 3*.
|
96 |
+
2. **Click "Process"** to analyze the input (average processing time: ~42 seconds).
|
97 |
+
3. **Explore the results** using the same highlighting options as above (Agreement, Polarity, Topic).
|
98 |
+
"""
|
99 |
+
|
100 |
+
|
101 |
+
EXAMPLES = [
|
102 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. I believe the authors missed Jane and al 2021. In addition, I think, there is a mistake in the math.",
|
103 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is well presented and the experiment are extensive. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
|
104 |
+
"The paper gives really interesting insights on the topic of transfer learning. It is not well presented and lack experiments. Some parts remain really unclear and I would like to see a more detailed explanation of the proposed method.",
|
105 |
+
]
|
106 |
+
|
107 |
+
# Function to summarize the input texts using the RSAReranking model in interactive mode
|
108 |
+
def summarize(text1, text2, text3, focus, mode, rationality=1.0, iterations=1):
|
109 |
+
|
110 |
+
# print(focus, mode, rationality, iterations)
|
111 |
+
|
112 |
+
# get sentences for each text
|
113 |
+
text2_sentences = glimpse_tokenizer(text2)
|
114 |
+
text1_sentences = glimpse_tokenizer(text1)
|
115 |
+
text3_sentences = glimpse_tokenizer(text3)
|
116 |
+
|
117 |
+
|
118 |
+
# remove empty sentences
|
119 |
+
text1_sentences = [sentence for sentence in text1_sentences if sentence != ""]
|
120 |
+
text2_sentences = [sentence for sentence in text2_sentences if sentence != ""]
|
121 |
+
text3_sentences = [sentence for sentence in text3_sentences if sentence != ""]
|
122 |
+
|
123 |
+
sentences = list(set(text1_sentences + text2_sentences + text3_sentences))
|
124 |
+
|
125 |
+
# Load polarity model and tokenizer (SciBERT)
|
126 |
+
polarity_model_path = "scibert/scibert_polarity/final_model"
|
127 |
+
polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_model_path)
|
128 |
+
polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_model_path)
|
129 |
+
polarity_model.eval()
|
130 |
+
polarity_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
131 |
+
polarity_model.to(polarity_device)
|
132 |
+
|
133 |
+
def predict_polarity(sent_list):
|
134 |
+
inputs = polarity_tokenizer(
|
135 |
+
sent_list, return_tensors="pt", padding=True, truncation=True, max_length=512
|
136 |
+
).to(polarity_device)
|
137 |
+
with torch.no_grad():
|
138 |
+
logits = polarity_model(**inputs).logits
|
139 |
+
preds = torch.argmax(logits, dim=1).cpu().tolist()
|
140 |
+
emoji_map = {0: "➖", 1: None, 2: "➕"}
|
141 |
+
return dict(zip(sent_list, [emoji_map[p] for p in preds]))
|
142 |
+
|
143 |
+
|
144 |
+
# Run polarity prediction
|
145 |
+
polarity_map = predict_polarity(sentences)
|
146 |
+
|
147 |
+
|
148 |
+
# Load topic model and tokenizer (SciBERT)
|
149 |
+
topic_model_path = "scibert/scibert_topic/final_model"
|
150 |
+
topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_path)
|
151 |
+
topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_path)
|
152 |
+
topic_model.eval()
|
153 |
+
topic_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
154 |
+
topic_model.to(topic_device)
|
155 |
+
|
156 |
+
def predict_topic(sent_list):
|
157 |
+
inputs = topic_tokenizer(
|
158 |
+
sent_list, return_tensors="pt", padding=True, truncation=True, max_length=512
|
159 |
+
).to(topic_device)
|
160 |
+
with torch.no_grad():
|
161 |
+
logits = topic_model(**inputs).logits
|
162 |
+
preds = torch.argmax(logits, dim=1).cpu().tolist()
|
163 |
+
|
164 |
+
# Topic ID to label and emoji
|
165 |
+
id2label = {
|
166 |
+
0: "Substance",
|
167 |
+
1: "Clarity",
|
168 |
+
2: "Correctness",
|
169 |
+
3: "Originality",
|
170 |
+
4: "Impact",
|
171 |
+
5: "Comparison",
|
172 |
+
6: "Replicability",
|
173 |
+
7: None # This is used for sentences that do not match any specific topic,
|
174 |
+
}
|
175 |
+
return dict(zip(sent_list, [id2label[p] for p in preds]))
|
176 |
+
|
177 |
+
# Run topic prediction
|
178 |
+
topic_map = predict_topic(sentences)
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
rsa_reranker = RSAReranking(
|
183 |
+
model,
|
184 |
+
tokenizer,
|
185 |
+
candidates=sentences,
|
186 |
+
source_texts=[text1, text2, text3],
|
187 |
+
device="cpu",
|
188 |
+
rationality=rationality,
|
189 |
+
)
|
190 |
+
(
|
191 |
+
best_rsa,
|
192 |
+
best_base,
|
193 |
+
speaker_df,
|
194 |
+
listener_df,
|
195 |
+
initial_listener,
|
196 |
+
language_model_proba_df,
|
197 |
+
initial_consensuality_scores,
|
198 |
+
consensuality_scores,
|
199 |
+
) = rsa_reranker.rerank(t=iterations)
|
200 |
+
|
201 |
+
# apply exp to the probabilities
|
202 |
+
speaker_df = speaker_df.applymap(lambda x: math.exp(x))
|
203 |
+
|
204 |
+
text_1_summaries = speaker_df.loc[text1][text1_sentences]
|
205 |
+
text_1_summaries = text_1_summaries / text_1_summaries.sum()
|
206 |
+
|
207 |
+
text_2_summaries = speaker_df.loc[text2][text2_sentences]
|
208 |
+
text_2_summaries = text_2_summaries / text_2_summaries.sum()
|
209 |
+
|
210 |
+
text_3_summaries = speaker_df.loc[text3][text3_sentences]
|
211 |
+
text_3_summaries = text_3_summaries / text_3_summaries.sum()
|
212 |
+
|
213 |
+
# make list of tuples
|
214 |
+
text_1_summaries = [(sentence, text_1_summaries[sentence]) for sentence in text1_sentences]
|
215 |
+
text_2_summaries = [(sentence, text_2_summaries[sentence]) for sentence in text2_sentences]
|
216 |
+
text_3_summaries = [(sentence, text_3_summaries[sentence]) for sentence in text3_sentences]
|
217 |
+
|
218 |
+
# normalize consensuality scores between -1 and 1
|
219 |
+
consensuality_scores = (consensuality_scores - (consensuality_scores.max() - consensuality_scores.min()) / 2) / (consensuality_scores.max() - consensuality_scores.min()) / 2
|
220 |
+
|
221 |
+
# get most and least consensual sentences
|
222 |
+
# most consensual --> most common; least consensual --> most unique
|
223 |
+
most_consensual = consensuality_scores.sort_values(ascending=True).head(3).index.tolist()
|
224 |
+
least_consensual = consensuality_scores.sort_values(ascending=False).head(3).index.tolist()
|
225 |
+
|
226 |
+
# Convert lists to strings
|
227 |
+
most_consensual = " ".join(most_consensual)
|
228 |
+
least_consensual = " ".join(least_consensual)
|
229 |
+
|
230 |
+
text_1_consensuality = consensuality_scores.loc[text1_sentences]
|
231 |
+
text_2_consensuality = consensuality_scores.loc[text2_sentences]
|
232 |
+
text_3_consensuality = consensuality_scores.loc[text3_sentences]
|
233 |
+
|
234 |
+
text_1_consensuality = [(sentence, text_1_consensuality[sentence]) for sentence in text1_sentences]
|
235 |
+
text_2_consensuality = [(sentence, text_2_consensuality[sentence]) for sentence in text2_sentences]
|
236 |
+
text_3_consensuality = [(sentence, text_3_consensuality[sentence]) for sentence in text3_sentences]
|
237 |
+
|
238 |
+
|
239 |
+
def highlight_reviews(text_sentences, consensuality_scores, threshold_common=0.0, threshold_unique=0.0):
|
240 |
+
highlighted = []
|
241 |
+
for sentence in text_sentences:
|
242 |
+
# print(f"Processing sentence: {sentence}", "score:", consensuality_scores.loc[sentence])
|
243 |
+
score = consensuality_scores.loc[sentence]
|
244 |
+
score = score*2 if score > 0 else score # amplify unique scores for better visibility
|
245 |
+
|
246 |
+
# common sentences --> positive consensuality scores
|
247 |
+
# unique sentences --> negative consensuality scores
|
248 |
+
|
249 |
+
score *= -1 # invert the score for highlighting
|
250 |
+
|
251 |
+
highlighted.append((sentence, score))
|
252 |
+
return highlighted
|
253 |
+
|
254 |
+
# Apply highlighting to each review
|
255 |
+
text_1_agreement = highlight_reviews(text1_sentences, consensuality_scores)
|
256 |
+
text_2_agreement = highlight_reviews(text2_sentences, consensuality_scores)
|
257 |
+
text_3_agreement = highlight_reviews(text3_sentences, consensuality_scores)
|
258 |
+
|
259 |
+
# Add polarity outputs
|
260 |
+
text_1_polarity = [(s, polarity_map[s]) for s in text1_sentences]
|
261 |
+
text_2_polarity = [(s, polarity_map[s]) for s in text2_sentences]
|
262 |
+
text_3_polarity = [(s, polarity_map[s]) for s in text3_sentences]
|
263 |
+
|
264 |
+
# Add topic outputs
|
265 |
+
text_1_topic = [(s, topic_map[s]) for s in text1_sentences]
|
266 |
+
text_2_topic = [(s, topic_map[s]) for s in text2_sentences]
|
267 |
+
text_3_topic = [(s, topic_map[s]) for s in text3_sentences]
|
268 |
+
|
269 |
+
# print(type(text_1_consensuality))
|
270 |
+
return (
|
271 |
+
# text_1_summaries, text_2_summaries, text_3_summaries,
|
272 |
+
# text_1_consensuality, text_2_consensuality, text_3_consensuality,
|
273 |
+
text_1_agreement, text_2_agreement, text_3_agreement,
|
274 |
+
most_consensual, least_consensual,
|
275 |
+
text_1_polarity, text_2_polarity, text_3_polarity,
|
276 |
+
text_1_topic, text_2_topic, text_3_topic,
|
277 |
+
)
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
with gr.Blocks(title="ReView") as demo:
|
283 |
+
# gr.Markdown("# ReView Interface")
|
284 |
+
|
285 |
+
with gr.Tab("Introduction"):
|
286 |
+
gr.Markdown(glimpse_description)
|
287 |
+
|
288 |
+
# -----------------------------------
|
289 |
+
# Pre-processed Tab
|
290 |
+
# -----------------------------------
|
291 |
+
with gr.Tab("Pre-processed Reviews"):
|
292 |
+
# Initialize state for this session.
|
293 |
+
initial_year = 2017
|
294 |
+
initial_scored_reviews = get_preprocessed_scores(initial_year)
|
295 |
+
initial_review_ids = list(initial_scored_reviews.keys())
|
296 |
+
initial_review = initial_scored_reviews[initial_review_ids[0]]
|
297 |
+
number_of_displayed_reviews = len(initial_scored_reviews[initial_review_ids[0]])
|
298 |
+
initial_state = {
|
299 |
+
"year_choice": initial_year,
|
300 |
+
"scored_reviews_for_year": initial_scored_reviews,
|
301 |
+
"review_ids": initial_review_ids,
|
302 |
+
"current_review_index": 0,
|
303 |
+
"current_review": initial_review,
|
304 |
+
"number_of_displayed_reviews": number_of_displayed_reviews,
|
305 |
+
}
|
306 |
+
state = gr.State(initial_state)
|
307 |
+
|
308 |
+
def update_review_display(state, score_type):
|
309 |
+
|
310 |
+
review_ids = state["review_ids"]
|
311 |
+
current_index = state["current_review_index"]
|
312 |
+
current_review = state["scored_reviews_for_year"][review_ids[current_index]]
|
313 |
+
|
314 |
+
show_polarity = score_type == "Polarity"
|
315 |
+
show_consensuality = score_type == "Agreement"
|
316 |
+
show_topic = score_type == "Topic"
|
317 |
+
|
318 |
+
|
319 |
+
if show_polarity:
|
320 |
+
color_map = {"➕": "#d4fcd6", "➖": "#fcd6d6"}
|
321 |
+
legend = False
|
322 |
+
elif show_topic:
|
323 |
+
color_map = topic_color_map # No color map for topics
|
324 |
+
legend = False
|
325 |
+
elif show_consensuality:
|
326 |
+
color_map = None # Continuous scale, no predefined colors
|
327 |
+
legend = True
|
328 |
+
else:
|
329 |
+
color_map = {} # Default to empty map
|
330 |
+
legend = False
|
331 |
+
|
332 |
+
new_review_id = (
|
333 |
+
f"### Submission Link:\n\n{review_ids[current_index]}<br>"
|
334 |
+
f"(Showing {current_index + 1} of {len(state['review_ids'])} reviews)"
|
335 |
+
)
|
336 |
+
|
337 |
+
number_of_displayed_reviews = len(current_review)
|
338 |
+
review_updates = []
|
339 |
+
consensuality_dict = {}
|
340 |
+
|
341 |
+
for i in range(8):
|
342 |
+
if i < number_of_displayed_reviews:
|
343 |
+
review_item = list(current_review[i].items())
|
344 |
+
|
345 |
+
if show_polarity:
|
346 |
+
highlighted = []
|
347 |
+
for sentence, metadata in review_item:
|
348 |
+
polarity = metadata.get("polarity", None)
|
349 |
+
if polarity >= 0.995:
|
350 |
+
label = "➕" # positive
|
351 |
+
elif polarity <= -0.99:
|
352 |
+
label = "➖" # negative
|
353 |
+
else:
|
354 |
+
label = None # ignore neutral (1)
|
355 |
+
highlighted.append((sentence, label))
|
356 |
+
elif show_consensuality:
|
357 |
+
highlighted = []
|
358 |
+
for sentence, metadata in review_item:
|
359 |
+
score = metadata.get("consensuality", 0.0)
|
360 |
+
score = score * 2 - 1 # Normalize to [-1, 1]
|
361 |
+
score = score/2.5 if score > 0 else score # Amplify unique scores for better visibility
|
362 |
+
score *= -1 # Invert the score for highlighting
|
363 |
+
|
364 |
+
consensuality_dict[sentence] = score
|
365 |
+
highlighted.append((sentence, score))
|
366 |
+
|
367 |
+
elif show_topic:
|
368 |
+
highlighted = []
|
369 |
+
for sentence, metadata in review_item:
|
370 |
+
topic = metadata.get("topic", None)
|
371 |
+
if topic != "NONE":
|
372 |
+
highlighted.append((sentence, topic))
|
373 |
+
else:
|
374 |
+
highlighted.append((sentence, None))
|
375 |
+
else:
|
376 |
+
highlighted = [
|
377 |
+
(sentence, None)
|
378 |
+
for sentence, metadata in review_item
|
379 |
+
]
|
380 |
+
|
381 |
+
review_updates.append(
|
382 |
+
gr.update(
|
383 |
+
visible=True,
|
384 |
+
value=highlighted,
|
385 |
+
color_map=color_map,
|
386 |
+
show_legend=legend,
|
387 |
+
key=f"updated_{score_type}_{i}"
|
388 |
+
)
|
389 |
+
)
|
390 |
+
else:
|
391 |
+
review_updates.append(
|
392 |
+
gr.update(
|
393 |
+
visible=False,
|
394 |
+
value=[],
|
395 |
+
show_legend=False,
|
396 |
+
color_map=color_map,
|
397 |
+
key=f"updated_{score_type}_{i}"
|
398 |
+
)
|
399 |
+
)
|
400 |
+
|
401 |
+
# Set most consensual / unique sentences
|
402 |
+
if show_consensuality and consensuality_dict:
|
403 |
+
scores = pd.Series(consensuality_dict)
|
404 |
+
most_unique = scores.sort_values(ascending=True).head(3).index.tolist()
|
405 |
+
most_common = scores.sort_values(ascending=False).head(3).index.tolist()
|
406 |
+
most_common_text = "\n".join(most_common)
|
407 |
+
most_unique_text = "\n".join(most_unique)
|
408 |
+
|
409 |
+
most_common_visibility = gr.update(visible=True, value=most_common_text)
|
410 |
+
most_unique_visibility = gr.update(visible=True, value=most_unique_text)
|
411 |
+
else:
|
412 |
+
# Debugging statements to check visibility settings
|
413 |
+
# print("Hiding most common and unique sentences")
|
414 |
+
|
415 |
+
most_common_visibility = gr.update(visible=False, value=[])
|
416 |
+
most_unique_visibility = gr.update(visible=False, value=[])
|
417 |
+
|
418 |
+
# update topic color map
|
419 |
+
if show_topic:
|
420 |
+
topic_color_map_visibility = gr.update(
|
421 |
+
visible=True,
|
422 |
+
color_map=topic_color_map,
|
423 |
+
value=[
|
424 |
+
("", "Substance"),
|
425 |
+
("", "Clarity"),
|
426 |
+
("", "Soundness/Correctness"),
|
427 |
+
("", "Originality"),
|
428 |
+
("", "Motivation/Impact"),
|
429 |
+
("", "Meaningful Comparison"),
|
430 |
+
("", "Replicability"),
|
431 |
+
]
|
432 |
+
)
|
433 |
+
else:
|
434 |
+
topic_color_map_visibility = gr.update(visible=False, value=[])
|
435 |
+
|
436 |
+
return (
|
437 |
+
new_review_id,
|
438 |
+
*review_updates,
|
439 |
+
most_common_visibility,
|
440 |
+
most_unique_visibility,
|
441 |
+
topic_color_map_visibility,
|
442 |
+
state
|
443 |
+
)
|
444 |
+
|
445 |
+
|
446 |
+
|
447 |
+
# Precompute the initial outputs so something is shown on load.
|
448 |
+
init_display = update_review_display(initial_state, score_type="Original")
|
449 |
+
# init_display returns: (review_id, review1, review2, review3, review4, review5, review6, review7, review8, state)
|
450 |
+
|
451 |
+
with gr.Row():
|
452 |
+
|
453 |
+
with gr.Column(scale=1):
|
454 |
+
review_id = gr.Markdown(value=init_display[0], container=True)
|
455 |
+
with gr.Row():
|
456 |
+
previous_button = gr.Button("Previous", variant="secondary", interactive=True)
|
457 |
+
next_button = gr.Button("Next", variant="primary", interactive=True)
|
458 |
+
|
459 |
+
|
460 |
+
with gr.Column(scale=1):
|
461 |
+
# Input controls.
|
462 |
+
year = gr.Dropdown(choices=years, label="Select Year", interactive=True, value=initial_year)
|
463 |
+
score_type = gr.Radio(
|
464 |
+
choices=["Original", "Agreement", "Polarity", "Topic"],
|
465 |
+
label="Score Type to Display",
|
466 |
+
value="Original",
|
467 |
+
interactive=True
|
468 |
+
)
|
469 |
+
|
470 |
+
# Output display.
|
471 |
+
with gr.Row():
|
472 |
+
most_common_sentences = gr.Textbox(
|
473 |
+
lines=8,
|
474 |
+
label="Most Common Opinions",
|
475 |
+
visible=False,
|
476 |
+
value=[]
|
477 |
+
)
|
478 |
+
most_unique_sentences = gr.Textbox(
|
479 |
+
lines=8,
|
480 |
+
label="Most Divergent Opinions",
|
481 |
+
visible=False,
|
482 |
+
value=[]
|
483 |
+
)
|
484 |
+
|
485 |
+
# Add a new textbox for topic labels and colors
|
486 |
+
topic_text_box = gr.HighlightedText(
|
487 |
+
label="Topic Labels (Color-Coded)",
|
488 |
+
visible=False,
|
489 |
+
value=[],
|
490 |
+
show_legend=True,
|
491 |
+
)
|
492 |
+
|
493 |
+
review1 = gr.HighlightedText(
|
494 |
+
show_legend=False,
|
495 |
+
label="Review 1",
|
496 |
+
visible= number_of_displayed_reviews >= 1,
|
497 |
+
key="initial_review1",
|
498 |
+
# color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
|
499 |
+
)
|
500 |
+
review2 = gr.HighlightedText(
|
501 |
+
show_legend=False,
|
502 |
+
label="Review 2",
|
503 |
+
visible= number_of_displayed_reviews >= 2,
|
504 |
+
key="initial_review2"
|
505 |
+
# color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
|
506 |
+
)
|
507 |
+
review3 = gr.HighlightedText(
|
508 |
+
show_legend=False,
|
509 |
+
label="Review 3",
|
510 |
+
visible= number_of_displayed_reviews >= 3,
|
511 |
+
key="initial_review3"
|
512 |
+
# color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
|
513 |
+
)
|
514 |
+
review4 = gr.HighlightedText(
|
515 |
+
show_legend=False,
|
516 |
+
label="Review 4",
|
517 |
+
visible= number_of_displayed_reviews >= 4,
|
518 |
+
key="initial_review4"
|
519 |
+
# color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
|
520 |
+
)
|
521 |
+
review5 = gr.HighlightedText(
|
522 |
+
show_legend=False,
|
523 |
+
label="Review 5",
|
524 |
+
visible= number_of_displayed_reviews >= 5,
|
525 |
+
key="initial_review5"
|
526 |
+
# color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
|
527 |
+
)
|
528 |
+
review6 = gr.HighlightedText(
|
529 |
+
show_legend=False,
|
530 |
+
label="Review 6",
|
531 |
+
visible= number_of_displayed_reviews >= 6,
|
532 |
+
key="initial_review6"
|
533 |
+
# color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
|
534 |
+
)
|
535 |
+
review7 = gr.HighlightedText(
|
536 |
+
show_legend=False,
|
537 |
+
label="Review 7",
|
538 |
+
visible= number_of_displayed_reviews >= 7,
|
539 |
+
key="initial_review7"
|
540 |
+
# color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
|
541 |
+
)
|
542 |
+
review8 = gr.HighlightedText(
|
543 |
+
show_legend=False,
|
544 |
+
label="Review 8",
|
545 |
+
visible= number_of_displayed_reviews >= 8,
|
546 |
+
key="initial_review8"
|
547 |
+
# color_map={"Positive": "#d4fcd6", "Negative": "#fcd6d6"}
|
548 |
+
)
|
549 |
+
|
550 |
+
# Callback functions that update state.
|
551 |
+
def year_change(year, state, score_type):
|
552 |
+
state["year_choice"] = year
|
553 |
+
state["scored_reviews_for_year"] = get_preprocessed_scores(year)
|
554 |
+
state["review_ids"] = list(state["scored_reviews_for_year"].keys())
|
555 |
+
state["current_review_index"] = 0
|
556 |
+
state["current_review"] = state["scored_reviews_for_year"][state["review_ids"][0]]
|
557 |
+
return update_review_display(state, score_type)
|
558 |
+
|
559 |
+
def next_review(state, score_type):
|
560 |
+
state["current_review_index"] = (state["current_review_index"] + 1) % len(state["review_ids"])
|
561 |
+
state["current_review"] = state["scored_reviews_for_year"][state["review_ids"][state["current_review_index"]]]
|
562 |
+
return update_review_display(state, score_type)
|
563 |
+
|
564 |
+
def previous_review(state, score_type):
|
565 |
+
state["current_review_index"] = (state["current_review_index"] - 1) % len(state["review_ids"])
|
566 |
+
state["current_review"] = state["scored_reviews_for_year"][state["review_ids"][state["current_review_index"]]]
|
567 |
+
return update_review_display(state, score_type)
|
568 |
+
|
569 |
+
# Hook up the callbacks with the session state.
|
570 |
+
year.change(
|
571 |
+
fn=year_change,
|
572 |
+
inputs=[year, state, score_type],
|
573 |
+
outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
|
574 |
+
)
|
575 |
+
score_type.change(
|
576 |
+
fn=update_review_display,
|
577 |
+
inputs=[state, score_type],
|
578 |
+
outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
|
579 |
+
)
|
580 |
+
next_button.click(
|
581 |
+
fn=next_review,
|
582 |
+
inputs=[state, score_type],
|
583 |
+
outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
|
584 |
+
)
|
585 |
+
previous_button.click(
|
586 |
+
fn=previous_review,
|
587 |
+
inputs=[state, score_type],
|
588 |
+
outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
|
589 |
+
)
|
590 |
+
|
591 |
+
|
592 |
+
|
593 |
+
|
594 |
+
# -----------------------------------
|
595 |
+
# Interactive Tab
|
596 |
+
# -----------------------------------
|
597 |
+
with gr.Tab("Interactive", interactive=True):
|
598 |
+
with gr.Row():
|
599 |
+
with gr.Column():
|
600 |
+
|
601 |
+
gr.Markdown("## Input Reviews")
|
602 |
+
|
603 |
+
# review_count = gr.Slider(minimum=1, maximum=3, step=1, value=3, label="Number of Reviews", interactive=True)
|
604 |
+
|
605 |
+
review1_textbox = gr.Textbox(lines=5, value=EXAMPLES[0], label="Review 1", interactive=True)
|
606 |
+
review2_textbox = gr.Textbox(lines=5, value=EXAMPLES[1], label="Review 2", interactive=True)
|
607 |
+
review3_textbox = gr.Textbox(lines=5, value=EXAMPLES[2], label="Review 3", interactive=True)
|
608 |
+
|
609 |
+
with gr.Row():
|
610 |
+
submit_button = gr.Button("Process", variant="primary", interactive=True)
|
611 |
+
clear_button = gr.Button("Clear", variant="secondary", interactive=True)
|
612 |
+
gr.Markdown("**Note**: *Once your inputs are processed, you can see the different result by <ins>**only changing the parameters**</ins>, and without the need to re-process.*", container=True)
|
613 |
+
|
614 |
+
|
615 |
+
|
616 |
+
with gr.Column():
|
617 |
+
|
618 |
+
gr.Markdown("## Results")
|
619 |
+
|
620 |
+
mode_radio = gr.Radio(
|
621 |
+
choices=[("In-line Highlighting", "highlight"), ("Generate Summaries", "summary")],
|
622 |
+
value="highlight",
|
623 |
+
label="Output Mode:",
|
624 |
+
interactive=False,
|
625 |
+
visible=False # Initially hidden, will be shown based on mode selection
|
626 |
+
)
|
627 |
+
focus_radio = gr.Radio(
|
628 |
+
choices=[("Agreement", "unique"), "Polarity", "Topic",],
|
629 |
+
value="unique",
|
630 |
+
label="Focus on:",
|
631 |
+
interactive=True
|
632 |
+
)
|
633 |
+
generation_method_radio = gr.Radio(
|
634 |
+
choices=[("Extractive", "extractive")], #TODO: add ("Abstractive", "abstractive") and abstractive generation
|
635 |
+
value="extractive",
|
636 |
+
label="Generation Method:",
|
637 |
+
interactive=True,
|
638 |
+
visible=False
|
639 |
+
)
|
640 |
+
|
641 |
+
# Fixed rationality (3.0) and iterations (2) to be consistent with the compute_rsa.py script
|
642 |
+
#iterations_slider = gr.Slider(minimum=1, maximum=10, step=1, value=2, label="Iterations", interactive=False, visible=False)
|
643 |
+
# rationality_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=2.0, label="Rationality", interactive=False, visible=False)
|
644 |
+
|
645 |
+
with gr.Row():
|
646 |
+
unique_sentences = gr.Textbox(
|
647 |
+
lines=6, label="Most Divergent Opinions", visible=True, value=None, container=True
|
648 |
+
)
|
649 |
+
common_sentences = gr.Textbox(
|
650 |
+
lines=6, label="Most Common Opinions", visible=True, value=None, container=True
|
651 |
+
)
|
652 |
+
|
653 |
+
uniqueness_score_text1 = gr.HighlightedText(
|
654 |
+
show_legend=True, label="Agreement in Review 1", visible=True, value=None,
|
655 |
+
)
|
656 |
+
uniqueness_score_text2 = gr.HighlightedText(
|
657 |
+
show_legend=True, label="Agreement in Review 2", visible=True, value=None,
|
658 |
+
)
|
659 |
+
uniqueness_score_text3 = gr.HighlightedText(
|
660 |
+
show_legend=True, label="Agreement in Review 3", visible=True, value=None,
|
661 |
+
)
|
662 |
+
|
663 |
+
|
664 |
+
polarity_score_text1 = gr.HighlightedText(
|
665 |
+
show_legend=True, label="Polarity in Review 1", visible=False, value=None,
|
666 |
+
color_map={"➕": "#d4fcd6", "➖": "#fcd6d6" }
|
667 |
+
)
|
668 |
+
polarity_score_text2 = gr.HighlightedText(
|
669 |
+
show_legend=True, label="Polarity in Review 2", visible=False, value=None,
|
670 |
+
color_map={"➕": "#d4fcd6", "➖": "#fcd6d6" }
|
671 |
+
)
|
672 |
+
polarity_score_text3 = gr.HighlightedText(
|
673 |
+
show_legend=True, label="Polarity in Review 3", visible=False, value=None,
|
674 |
+
color_map={"➕": "#d4fcd6", "➖": "#fcd6d6" }
|
675 |
+
)
|
676 |
+
|
677 |
+
aspect_score_text1 = gr.HighlightedText(
|
678 |
+
show_legend=False, label="Topic in Review 1", visible=False, value=None,
|
679 |
+
color_map = topic_color_map
|
680 |
+
)
|
681 |
+
aspect_score_text2 = gr.HighlightedText(
|
682 |
+
show_legend=False, label="Topic in Review 2", visible=False, value=None,
|
683 |
+
color_map = topic_color_map
|
684 |
+
)
|
685 |
+
aspect_score_text3 = gr.HighlightedText(
|
686 |
+
show_legend=False, label="Topic in Review 3", visible=False, value=None,
|
687 |
+
color_map = topic_color_map
|
688 |
+
)
|
689 |
+
|
690 |
+
|
691 |
+
|
692 |
+
|
693 |
+
# Connect summarize function to submit button
|
694 |
+
submit_button.click(
|
695 |
+
fn=summarize,
|
696 |
+
inputs=[
|
697 |
+
review1_textbox, review2_textbox, review3_textbox,
|
698 |
+
focus_radio, mode_radio
|
699 |
+
],
|
700 |
+
outputs=[
|
701 |
+
uniqueness_score_text1, uniqueness_score_text2, uniqueness_score_text3,
|
702 |
+
common_sentences, unique_sentences,
|
703 |
+
polarity_score_text1, polarity_score_text2, polarity_score_text3,
|
704 |
+
aspect_score_text1, aspect_score_text2, aspect_score_text3
|
705 |
+
|
706 |
+
]
|
707 |
+
)
|
708 |
+
|
709 |
+
# Define clear button behavior
|
710 |
+
clear_button.click(
|
711 |
+
fn=lambda: (None, None, None, None, None, None, None, None, None, None, None), # clear all fields
|
712 |
+
inputs=[],
|
713 |
+
outputs=[
|
714 |
+
review1_textbox, review2_textbox, review3_textbox,
|
715 |
+
uniqueness_score_text1, uniqueness_score_text2, uniqueness_score_text3,
|
716 |
+
common_sentences, unique_sentences
|
717 |
+
]
|
718 |
+
)
|
719 |
+
|
720 |
+
# Update visibility of generation_method_radio based on mode_radio value
|
721 |
+
# def toggle_generation_method(mode):
|
722 |
+
# if mode == "summary":
|
723 |
+
# return gr.update(visible=True), gr.update(visible=False) # show generation method radio, hide focus radio
|
724 |
+
# else:
|
725 |
+
# return gr.update(visible=False), gr.update(visible=True) # show focus radio, hide generation method radio
|
726 |
+
|
727 |
+
# mode_radio.change(
|
728 |
+
# fn=toggle_generation_method,
|
729 |
+
# inputs=mode_radio,
|
730 |
+
# outputs=[generation_method_radio, focus_radio]
|
731 |
+
# )
|
732 |
+
|
733 |
+
# Update visibility of output textboxes based on mode_radio and focus_radio values
|
734 |
+
def toggle_output_textboxes(mode, focus):
|
735 |
+
if mode == "highlight" and focus == "unique":
|
736 |
+
return (
|
737 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), # in-line uniqueness highlights
|
738 |
+
gr.update(visible=True), gr.update(visible=True), # summary highlights
|
739 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), # polarity highlights
|
740 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) # aspect highlights
|
741 |
+
)
|
742 |
+
|
743 |
+
elif focus == "Polarity":
|
744 |
+
return (
|
745 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), # in-line uniqueness highlights
|
746 |
+
gr.update(visible=False), gr.update(visible=False), # summary highlights
|
747 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), # polarity highlights
|
748 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) # aspect highlights
|
749 |
+
)
|
750 |
+
|
751 |
+
elif focus == "Topic":
|
752 |
+
return (
|
753 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), # in-line uniqueness highlights
|
754 |
+
gr.update(visible=False), gr.update(visible=False), # summary highlights
|
755 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), # polarity highlights
|
756 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) # aspect highlights
|
757 |
+
)
|
758 |
+
|
759 |
+
focus_radio.change(
|
760 |
+
fn=toggle_output_textboxes,
|
761 |
+
inputs=[mode_radio, focus_radio],
|
762 |
+
outputs=[
|
763 |
+
uniqueness_score_text1, uniqueness_score_text2, uniqueness_score_text3,
|
764 |
+
common_sentences, unique_sentences,
|
765 |
+
polarity_score_text1, polarity_score_text2, polarity_score_text3,
|
766 |
+
aspect_score_text1, aspect_score_text2, aspect_score_text3
|
767 |
+
]
|
768 |
+
)
|
769 |
+
# mode_radio.change(
|
770 |
+
# fn=toggle_output_textboxes,
|
771 |
+
# inputs=[mode_radio, focus_radio],
|
772 |
+
# outputs=[
|
773 |
+
# uniqueness_score_text1, uniqueness_score_text2, uniqueness_score_text3,
|
774 |
+
# consensuality_score_text1, consensuality_score_text2, consensuality_score_text3,
|
775 |
+
# most_consensual_sentences, most_unique_sentences
|
776 |
+
# ]
|
777 |
+
# )
|
778 |
+
|
779 |
+
# TODO: Configure the slider for the number of review boxes
|
780 |
+
|
781 |
+
# def toggle_reviews(number_of_displayed_reviews):
|
782 |
+
# number_of_displayed_reviews = int(number_of_displayed_reviews)
|
783 |
+
# updates = []
|
784 |
+
# # for review(i), set visible True if its index is <= n, otherwise False.
|
785 |
+
# for i in range(1, 4): updates.append(gr.update(visible=(i <= number_of_displayed_reviews)))
|
786 |
+
# return tuple(updates)
|
787 |
+
|
788 |
+
# review_count.change(
|
789 |
+
# fn=toggle_reviews,
|
790 |
+
# inputs=[review_count],
|
791 |
+
# outputs=[review1_textbox, review2_textbox, review3_textbox]
|
792 |
+
# )
|
793 |
+
|
794 |
+
demo.load(
|
795 |
+
fn=update_review_display,
|
796 |
+
inputs=[state, score_type],
|
797 |
+
outputs=[review_id, review1, review2, review3, review4, review5, review6, review7, review8, most_common_sentences, most_unique_sentences, topic_text_box, state]
|
798 |
+
)
|
799 |
+
|
800 |
+
demo.launch(share=False)
|
glimpse-ui/scibert/scibert_polarity/final_model/config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "allenai/scibert_scivocab_uncased",
|
3 |
+
"architectures": [
|
4 |
+
"BertForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"classifier_dropout": null,
|
8 |
+
"hidden_act": "gelu",
|
9 |
+
"hidden_dropout_prob": 0.1,
|
10 |
+
"hidden_size": 768,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0",
|
13 |
+
"1": "LABEL_1",
|
14 |
+
"2": "LABEL_2"
|
15 |
+
},
|
16 |
+
"initializer_range": 0.02,
|
17 |
+
"intermediate_size": 3072,
|
18 |
+
"label2id": {
|
19 |
+
"LABEL_0": 0,
|
20 |
+
"LABEL_1": 1,
|
21 |
+
"LABEL_2": 2
|
22 |
+
},
|
23 |
+
"layer_norm_eps": 1e-12,
|
24 |
+
"max_position_embeddings": 512,
|
25 |
+
"model_type": "bert",
|
26 |
+
"num_attention_heads": 12,
|
27 |
+
"num_hidden_layers": 12,
|
28 |
+
"pad_token_id": 0,
|
29 |
+
"position_embedding_type": "absolute",
|
30 |
+
"torch_dtype": "float32",
|
31 |
+
"transformers_version": "4.46.1",
|
32 |
+
"type_vocab_size": 2,
|
33 |
+
"use_cache": true,
|
34 |
+
"vocab_size": 31090
|
35 |
+
}
|
glimpse-ui/scibert/scibert_polarity/final_model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e259f6ae81187152e0aa80b9a478aec607c802eb270114eef2b64c8bac806d43
|
3 |
+
size 439706620
|
glimpse-ui/scibert/scibert_polarity/final_model/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|