Spaces:
Build error
Build error
Don't mix up condition indexing in suites whose items list their conditions in different orders (e.g. number_prep in syntaxgym2020).
Browse files - syntaxgym.py (+5 -4)
syntaxgym.py
CHANGED
|
@@ -265,12 +265,13 @@ class SyntaxGym(evaluate.EvaluationModule):
|
|
| 265 |
}
|
| 266 |
return results
|
| 267 |
|
| 268 |
-
def get_region_edges(self, item,
|
| 269 |
"""
|
| 270 |
Get left edge of each region as a character index.
|
| 271 |
"""
|
| 272 |
# NB this is coupled with `condition_to_string` logic of course
|
| 273 |
|
|
|
|
| 274 |
regions = item["conditions"]["regions"][condition_idx]
|
| 275 |
|
| 276 |
idx = 0
|
|
@@ -298,8 +299,8 @@ class SyntaxGym(evaluate.EvaluationModule):
|
|
| 298 |
|
| 299 |
max_long = torch.iinfo(torch.int64).max
|
| 300 |
|
| 301 |
-
for
|
| 302 |
-
region_edges = self.get_region_edges(item,
|
| 303 |
|
| 304 |
t_cursor, r_cursor = 0, 0
|
| 305 |
while t_cursor < i_offsets.shape[0]:
|
|
@@ -321,7 +322,7 @@ class SyntaxGym(evaluate.EvaluationModule):
|
|
| 321 |
r_cursor += 1
|
| 322 |
continue
|
| 323 |
|
| 324 |
-
region2tokens[
|
| 325 |
t_cursor += 1
|
| 326 |
|
| 327 |
return region2tokens
|
|
|
|
| 265 |
}
|
| 266 |
return results
|
| 267 |
|
| 268 |
+
def get_region_edges(self, item, condition_name):
|
| 269 |
"""
|
| 270 |
Get left edge of each region as a character index.
|
| 271 |
"""
|
| 272 |
# NB this is coupled with `condition_to_string` logic of course
|
| 273 |
|
| 274 |
+
condition_idx = item["conditions"]["condition_name"].index(condition_name)
|
| 275 |
regions = item["conditions"]["regions"][condition_idx]
|
| 276 |
|
| 277 |
idx = 0
|
|
|
|
| 299 |
|
| 300 |
max_long = torch.iinfo(torch.int64).max
|
| 301 |
|
| 302 |
+
for cond_name, i_offsets in zip(condition_order, offset_mapping):
|
| 303 |
+
region_edges = self.get_region_edges(item, cond_name)
|
| 304 |
|
| 305 |
t_cursor, r_cursor = 0, 0
|
| 306 |
while t_cursor < i_offsets.shape[0]:
|
|
|
|
| 322 |
r_cursor += 1
|
| 323 |
continue
|
| 324 |
|
| 325 |
+
region2tokens[cond_name][r_cursor + 1].append(t_cursor)
|
| 326 |
t_cursor += 1
|
| 327 |
|
| 328 |
return region2tokens
|