Commit 51dabd6
1 Parent(s): aa426fb

refactoring of code

Changed files:
- .env.example +4 -0
- README.md +3 -0
- base_model/main.py → main.py +10 -7
- poetry.lock +121 -2
- pyproject.toml +18 -0
- src/es_retriever.py +9 -0
- base_model/evaluate.py → src/evaluation.py +10 -9
- base_model/retriever.py → src/fais_retriever.py +15 -18
- {base_model → src}/reader.py +0 -0
- src/utils/log.py +31 -0
- {base_model → src/utils}/string_utils.py +0 -0
.env.example
ADDED

@@ -0,0 +1,4 @@
+ELASTIC_USERNAME=elastic
+ELASTIC_PASSWORD=<password>
+
+LOG_LEVEL=INFO
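
Note: of the new variables, LOG_LEVEL is consumed by src/utils/log.py below, while the Elastic credentials have no consumer yet (src/es_retriever.py is still a stub and the new README section is empty). A minimal sketch of how they could be fed to the elasticsearch client added in this commit, assuming a local instance on https://localhost:9200 (the URL and the verify_certs setting are assumptions, not part of the commit):

    # Sketch only: wiring the .env values into the Elasticsearch client.
    import os

    from dotenv import load_dotenv
    from elasticsearch import Elasticsearch

    load_dotenv()

    es = Elasticsearch(
        "https://localhost:9200",  # assumed local instance
        basic_auth=(os.getenv("ELASTIC_USERNAME", "elastic"),
                    os.getenv("ELASTIC_PASSWORD", "")),
        verify_certs=False,  # assumption: self-signed certs during local development
    )
    print(es.info())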
README.md
CHANGED

@@ -73,3 +73,6 @@ poetry run python main.py
 > shows that MT systems perform worse when they are asked to translate sentences
 > that describe people with non-stereotypical gender roles, like "The doctor
 > asked the nurse to help her in the > operation".
+
+
+## Setting up elastic search.
base_model/main.py → main.py
RENAMED

@@ -1,20 +1,23 @@
-from
+from src.fais_retriever import FAISRetriever
+from src.utils.log import get_logger
+
+
+logger = get_logger()
 
 
 if __name__ == '__main__':
     # Initialize retriever
-    r =
+    r = FAISRetriever()
 
     # Retrieve example
     scores, result = r.retrieve(
         "What is the perplexity of a language model?")
 
     for i, score in enumerate(scores):
-
-
-        print()  # Newline
+        logger.info(f"Result {i+1} (score: {score:.02f}):")
+        logger.info(result['text'][i])
 
     # Compute overall performance
     exact_match, f1_score = r.evaluate()
-
-
+    logger.info(f"Exact match: {exact_match:.02f}\n"
+                f"F1-score: {f1_score:.02f}")
poetry.lock
CHANGED

@@ -149,6 +149,36 @@ python-versions = ">=2.7, !=3.0.*"
 [package.extras]
 graph = ["objgraph (>=1.7.2)"]
 
+[[package]]
+name = "elastic-transport"
+version = "8.1.0"
+description = "Transport classes and utilities shared among Python Elastic client libraries"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+certifi = "*"
+urllib3 = ">=1.26.2,<2"
+
+[package.extras]
+develop = ["pytest", "pytest-cov", "pytest-mock", "pytest-asyncio", "mock", "requests", "aiohttp"]
+
+[[package]]
+name = "elasticsearch"
+version = "8.1.0"
+description = "Python client for Elasticsearch"
+category = "main"
+optional = false
+python-versions = ">=3.6, <4"
+
+[package.dependencies]
+elastic-transport = ">=8,<9"
+
+[package.extras]
+async = ["aiohttp (>=3,<4)"]
+requests = ["requests (>=2.4.0,<3.0.0)"]
+
 [[package]]
 name = "faiss-cpu"
 version = "1.7.2"

@@ -291,6 +321,32 @@ python-versions = "*"
 [package.dependencies]
 dill = ">=0.3.4"
 
+[[package]]
+name = "mypy"
+version = "0.941"
+description = "Optional static typing for Python"
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+mypy-extensions = ">=0.4.3"
+tomli = ">=1.1.0"
+typing-extensions = ">=3.10"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+python2 = ["typed-ast (>=1.4.0,<2)"]
+reports = ["lxml"]
+
+[[package]]
+name = "mypy-extensions"
+version = "0.4.3"
+description = "Experimental type system extensions for programs checked with the mypy typechecker."
+category = "dev"
+optional = false
+python-versions = "*"
+
 [[package]]
 name = "numpy"
 version = "1.22.3"

@@ -380,6 +436,17 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
 [package.dependencies]
 six = ">=1.5"
 
+[[package]]
+name = "python-dotenv"
+version = "0.19.2"
+description = "Read key-value pairs from a .env file and set them as environment variables"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
+[package.extras]
+cli = ["click (>=5.0)"]
+
 [[package]]
 name = "pytz"
 version = "2021.3"

@@ -480,6 +547,14 @@ category = "dev"
 optional = false
 python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
 
+[[package]]
+name = "tomli"
+version = "2.0.1"
+description = "A lil' TOML parser"
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
 [[package]]
 name = "torch"
 version = "1.11.0"

@@ -610,7 +685,7 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.8"
-content-hash = "
+content-hash = "7fadbb5aabac268ecd27c257e2c8f651d26896e78c9cc0ea7e61a8b6ec61c84c"
 
 [metadata.files]
 aiohttp = [

@@ -727,6 +802,14 @@ dill = [
     {file = "dill-0.3.4-py2.py3-none-any.whl", hash = "sha256:7e40e4a70304fd9ceab3535d36e58791d9c4a776b38ec7f7ec9afc8d3dca4d4f"},
     {file = "dill-0.3.4.zip", hash = "sha256:9f9734205146b2b353ab3fec9af0070237b6ddae78452af83d2fca84d739e675"},
 ]
+elastic-transport = [
+    {file = "elastic-transport-8.1.0.tar.gz", hash = "sha256:769ee4c7b28d270cdbce71359973b88129ac312b13be95b4f7479e35c49d9455"},
+    {file = "elastic_transport-8.1.0-py3-none-any.whl", hash = "sha256:0bb2ae3d13348e9e4587ca1f17cd813a528a7cc07f879505f56d69c81823b660"},
+]
+elasticsearch = [
+    {file = "elasticsearch-8.1.0-py3-none-any.whl", hash = "sha256:11e36565dfdf649b7911c2d3cb1f15b99267acfb7f82e94e7613c0323a9936e9"},
+    {file = "elasticsearch-8.1.0.tar.gz", hash = "sha256:648d1c707a632279535356d2762cbc63ae728c4633211fe160f43f87a3e1cdcd"},
+]
 faiss-cpu = [
     {file = "faiss-cpu-1.7.2.tar.gz", hash = "sha256:f7ea89de997f55764e3710afaf0a457b2529252f99ee63510d4d9348d5b419dd"},
     {file = "faiss_cpu-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b7461f989d757917a3e6dc81eb171d0b563eb98d23ebaf7fc6684d0093ba267e"},

@@ -918,12 +1001,40 @@ multiprocess = [
     {file = "multiprocess-0.70.12.2-py39-none-any.whl", hash = "sha256:6f812a1d3f198b7cacd63983f60e2dc1338bd4450893f90c435067b5a3127e6f"},
     {file = "multiprocess-0.70.12.2.zip", hash = "sha256:206bb9b97b73f87fec1ed15a19f8762950256aa84225450abc7150d02855a083"},
 ]
+mypy = [
+    {file = "mypy-0.941-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:98f61aad0bb54f797b17da5b82f419e6ce214de0aa7e92211ebee9e40eb04276"},
+    {file = "mypy-0.941-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6a8e1f63357851444940351e98fb3252956a15f2cabe3d698316d7a2d1f1f896"},
+    {file = "mypy-0.941-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b30d29251dff4c59b2e5a1fa1bab91ff3e117b4658cb90f76d97702b7a2ae699"},
+    {file = "mypy-0.941-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8eaf55fdf99242a1c8c792247c455565447353914023878beadb79600aac4a2a"},
+    {file = "mypy-0.941-cp310-cp310-win_amd64.whl", hash = "sha256:080097eee5393fd740f32c63f9343580aaa0fb1cda0128fd859dfcf081321c3d"},
+    {file = "mypy-0.941-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f79137d012ff3227866222049af534f25354c07a0d6b9a171dba9f1d6a1fdef4"},
+    {file = "mypy-0.941-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8e5974583a77d630a5868eee18f85ac3093caf76e018c510aeb802b9973304ce"},
+    {file = "mypy-0.941-cp36-cp36m-win_amd64.whl", hash = "sha256:0dd441fbacf48e19dc0c5c42fafa72b8e1a0ba0a39309c1af9c84b9397d9b15a"},
+    {file = "mypy-0.941-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0d3bcbe146247997e03bf030122000998b076b3ac6925b0b6563f46d1ce39b50"},
+    {file = "mypy-0.941-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3bada0cf7b6965627954b3a128903a87cac79a79ccd83b6104912e723ef16c7b"},
+    {file = "mypy-0.941-cp37-cp37m-win_amd64.whl", hash = "sha256:eea10982b798ff0ccc3b9e7e42628f932f552c5845066970e67cd6858655d52c"},
+    {file = "mypy-0.941-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:108f3c7e14a038cf097d2444fa0155462362c6316e3ecb2d70f6dd99cd36084d"},
+    {file = "mypy-0.941-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d61b73c01fc1de799226963f2639af831307fe1556b04b7c25e2b6c267a3bc76"},
+    {file = "mypy-0.941-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:42c216a33d2bdba08098acaf5bae65b0c8196afeb535ef4b870919a788a27259"},
+    {file = "mypy-0.941-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fc5ecff5a3bbfbe20091b1cad82815507f5ae9c380a3a9bf40f740c70ce30a9b"},
+    {file = "mypy-0.941-cp38-cp38-win_amd64.whl", hash = "sha256:bf446223b2e0e4f0a4792938e8d885e8a896834aded5f51be5c3c69566495540"},
+    {file = "mypy-0.941-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:745071762f32f65e77de6df699366d707fad6c132a660d1342077cbf671ef589"},
+    {file = "mypy-0.941-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:465a6ce9ca6268cadfbc27a2a94ddf0412568a6b27640ced229270be4f5d394d"},
+    {file = "mypy-0.941-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d051ce0946521eba48e19b25f27f98e5ce4dbc91fff296de76240c46b4464df0"},
+    {file = "mypy-0.941-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:818cfc51c25a5dbfd0705f3ac1919fff6971eb0c02e6f1a1f6a017a42405a7c0"},
+    {file = "mypy-0.941-cp39-cp39-win_amd64.whl", hash = "sha256:b2ce2788df0c066c2ff4ba7190fa84f18937527c477247e926abeb9b1168b8cc"},
+    {file = "mypy-0.941-py3-none-any.whl", hash = "sha256:3cf77f138efb31727ee7197bc824c9d6d7039204ed96756cc0f9ca7d8e8fc2a4"},
+    {file = "mypy-0.941.tar.gz", hash = "sha256:cbcc691d8b507d54cb2b8521f0a2a3d4daa477f62fe77f0abba41e5febb377b7"},
+]
+mypy-extensions = [
+    {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
+    {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
+]
 numpy = [
     {file = "numpy-1.22.3-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:92bfa69cfbdf7dfc3040978ad09a48091143cffb778ec3b03fa170c494118d75"},
     {file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
     {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
     {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
-    {file = "numpy-1.22.3-cp310-cp310-win32.whl", hash = "sha256:f950f8845b480cffe522913d35567e29dd381b0dc7e4ce6a4a9f9156417d2430"},
     {file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
     {file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
     {file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},

@@ -1015,6 +1126,10 @@ python-dateutil = [
     {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
     {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
 ]
+python-dotenv = [
+    {file = "python-dotenv-0.19.2.tar.gz", hash = "sha256:a5de49a31e953b45ff2d2fd434bbc2670e8db5273606c1e737cc6b93eff3655f"},
+    {file = "python_dotenv-0.19.2-py2.py3-none-any.whl", hash = "sha256:32b2bdc1873fd3a3c346da1c6db83d0053c3c62f28f1f38516070c4c8971b1d3"},
+]
 pytz = [
     {file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"},
     {file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"},

@@ -1189,6 +1304,10 @@ toml = [
     {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
     {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
 ]
+tomli = [
+    {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
+    {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+]
 torch = [
     {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
     {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
pyproject.toml
CHANGED

@@ -11,10 +11,28 @@ transformers = "^4.17.0"
 torch = "^1.11.0"
 datasets = "^1.18.4"
 faiss-cpu = "^1.7.2"
+python-dotenv = "^0.19.2"
+elasticsearch = "^8.1.0"
 
 [tool.poetry.dev-dependencies]
 flake8 = "^4.0.1"
 autopep8 = "^1.6.0"
+mypy = "^0.941"
+
+[tool.mypy]
+no_implicit_optional=true
+
+[[tool.mypy.overrides]]
+module = [
+    "transformers",
+    "datasets",
+]
+ignore_missing_imports = true
+
+
+[tool.isort]
+profile = "black"
+
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
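
The new [tool.mypy] block enables no_implicit_optional. As a rough illustration (the function below is hypothetical, not from this repository), that setting forces implicit None defaults to be spelled out as Optional:

    from typing import Optional

    # Rejected under no_implicit_optional = true:
    #     def retrieve(query: str, k: int = None) -> None: ...
    # The None default has to be declared explicitly instead:
    def retrieve(query: str, k: Optional[int] = None) -> None:
        """Hypothetical signature, used only to illustrate the mypy setting."""
        _ = (query, k)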
src/es_retriever.py
ADDED

@@ -0,0 +1,9 @@
+class ESRetriever:
+    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp"):
+        self.dataset_name = dataset_name
+
+    def _setup_data(self):
+        pass
+
+    def retrieve(self, query: str, k: int):
+        pass
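
ESRetriever is committed as a stub. One possible way it could later be filled in with the elasticsearch client added in this commit is sketched below; the injected client, the index name, and the "text" field are assumptions, not the authors' implementation:

    # Sketch only: a possible Elasticsearch-backed retriever.
    from elasticsearch import Elasticsearch


    class ESRetrieverSketch:
        def __init__(self, es: Elasticsearch, index: str = "paragraphs"):
            self.es = es
            self.index = index

        def retrieve(self, query: str, k: int = 5):
            # Full-text match over an assumed "text" field, returning the top-k hits.
            response = self.es.search(
                index=self.index,
                query={"match": {"text": query}},
                size=k,
            )
            hits = response["hits"]["hits"]
            scores = [hit["_score"] for hit in hits]
            texts = [hit["_source"]["text"] for hit in hits]
            return scores, texts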
base_model/evaluate.py → src/evaluation.py
RENAMED

@@ -1,15 +1,16 @@
 from typing import Callable, List
 
-from
+from src.utils.string_utils import (lower, remove_articles, remove_punc,
+                                    white_space_fix)
 
 
-def
+def _normalize_text(inp: str, preprocessing_functions: List[Callable[[str], str]]):
     for fun in preprocessing_functions:
         inp = fun(inp)
     return inp
 
 
-def
+def _normalize_text_default(inp: str) -> str:
     """Preprocesses the sentence string by normalizing.
 
     Args:

@@ -21,10 +22,10 @@ def normalize_text_default(inp: str) -> str:
 
     steps = [remove_articles, white_space_fix, remove_punc, lower]
 
-    return
+    return _normalize_text(inp, steps)
 
 
-def
+def exact_match(prediction: str, answer: str) -> int:
     """Computes exact match for sentences.
 
     Args:

@@ -34,10 +35,10 @@ def compute_exact_match(prediction: str, answer: str) -> int:
     Returns:
         int: 1 for exact match, 0 for not
     """
-    return int(
+    return int(_normalize_text_default(prediction) == _normalize_text_default(answer))
 
 
-def
+def f1(prediction: str, answer: str) -> float:
     """Computes F1-score on token overlap for sentences.
 
     Args:

@@ -47,8 +48,8 @@ def compute_f1(prediction: str, answer: str) -> float:
     Returns:
         boolean: the f1 score
     """
-    pred_tokens =
-    answer_tokens =
+    pred_tokens = _normalize_text_default(prediction).split()
+    answer_tokens = _normalize_text_default(answer).split()
 
     if len(pred_tokens) == 0 or len(answer_tokens) == 0:
         return int(pred_tokens == answer_tokens)
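
A short usage sketch for the renamed helpers (the example strings are made up): both functions normalize their inputs first, so surface differences in casing, punctuation, and whitespace do not count against a prediction.

    from src.evaluation import exact_match, f1

    prediction = "Paris, France"
    answer = "paris france"

    print(exact_match(prediction, answer))  # 1: identical after normalization
    print(f1(prediction, answer))           # 1.0: identical token sets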
base_model/retriever.py → src/fais_retriever.py
RENAMED

@@ -1,23 +1,19 @@
-from transformers import (
-    DPRContextEncoder,
-    DPRContextEncoderTokenizer,
-    DPRQuestionEncoder,
-    DPRQuestionEncoderTokenizer,
-)
-from datasets import load_dataset
-import torch
-import os.path
-
-import evaluate
-
 # Hacky fix for FAISS error on macOS
 # See https://stackoverflow.com/a/63374568/4545692
 import os
+import os.path
+
+import torch
+from datasets import load_dataset
+from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
+                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)
+
+from src.evaluation import exact_match, f1
 
 os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
 
 
-class
+class FAISRetriever:
     """A class used to retrieve relevant documents based on some query.
     based on https://huggingface.co/docs/datasets/faiss_es#faiss.
     """

@@ -67,12 +63,13 @@ class Retriever:
         embeddings.
         """
         # Load dataset
-        ds = load_dataset(dataset_name, name="paragraphs")[
+        ds = load_dataset(dataset_name, name="paragraphs")[
+            "train"]  # type: ignore
         print(ds)
 
         if os.path.exists(embedding_path):
             # If we already have FAISS embeddings, load them from disk
-            ds.load_faiss_index('embeddings', embedding_path)
+            ds.load_faiss_index('embeddings', embedding_path)  # type: ignore
             return ds
         else:
             # If there are no FAISS embeddings, generate them

@@ -85,7 +82,7 @@ class Retriever:
             return {"embeddings": enc}
 
         # Add FAISS embeddings
-        ds_with_embeddings = ds.map(embed)
+        ds_with_embeddings = ds.map(embed)  # type: ignore
 
         ds_with_embeddings.add_faiss_index(column="embeddings")
 

@@ -141,9 +138,9 @@ class Retriever:
             scores += score[0]
             predictions.append(result['text'][0])
 
-        exact_matches = [
+        exact_matches = [exact_match(
             predictions[i], answers[i]) for i in range(len(answers))]
-        f1_scores = [
+        f1_scores = [f1(
             predictions[i], answers[i]) for i in range(len(answers))]
 
         return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
{base_model → src}/reader.py
RENAMED

File without changes
src/utils/log.py
ADDED

@@ -0,0 +1,31 @@
+import logging
+import os
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+def get_logger():
+    # creates a default logger for the project
+    logger = logging.getLogger("Flashcards")
+
+    log_level = os.getenv("LOG_LEVEL", "INFO")
+    logger.setLevel(log_level)
+
+    # Log format
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+    # file handler
+    fh = logging.FileHandler("logs.log")
+    fh.setFormatter(formatter)
+
+    # stout
+    ch = logging.StreamHandler()
+    ch.setFormatter(formatter)
+
+    logger.addHandler(fh)
+    logger.addHandler(ch)
+
+    return logger
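
One caveat about the logger as committed: every call to get_logger() attaches a fresh FileHandler and StreamHandler to the same "Flashcards" logger, so calling it from more than one module duplicates each log line. A guard like the following (a sketch, not part of this commit) would avoid that:

    import logging


    def get_logger_once() -> logging.Logger:
        logger = logging.getLogger("Flashcards")
        if not logger.handlers:  # attach handlers only on the first call
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
            logger.addHandler(handler)
        return logger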
{base_model → src/utils}/string_utils.py
RENAMED

File without changes