batch surprisal computation, now GPU friendly
syntaxgym.py CHANGED (+59 -48)
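For context: the quantity computed throughout this file is per-token surprisal in bits, i.e. -log2 p(token | preceding tokens) under a causal language model. The sketch below shows that computation for a single sentence, using the same log_softmax-plus-gather pattern as the diff; the model name and example sentence are placeholder choices for illustration, not taken from this repository.

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Placeholder model; any Hugging Face causal LM works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("The keys to the cabinet are on the table.",
                      return_tensors="pt")["input_ids"]          # 1 * T

with torch.no_grad():
    logits = model(input_ids).logits                             # 1 * T * V
    # log_softmax gives log-probabilities in nats; divide by ln 2 to get bits.
    surprisals = -logits.log_softmax(dim=2) / np.log(2)

# Position t predicts token t+1, so drop the last prediction and the first token.
gt_idxs = input_ids[:, 1:]
token_surprisals = torch.gather(surprisals[:, :-1, :], 2,
                                gt_idxs.unsqueeze(2)).squeeze(2)  # 1 * (T - 1)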
@@ -174,68 +174,82 @@ class SyntaxGym(evaluate.EvaluationModule):
 
         tokenizer, tokenizer_kwargs = prepare_tokenizer(model, batch_size, add_start_token)
 
+        # Flatten sentences, enforcing that sentences are always ordered by the same condition.
+        condition_order = dataset[0]["conditions"]["condition_name"]
+        all_sentences = []
+        for item in dataset:
+            for condition_name in condition_order:
+                # Get idx of condition for this item.
+                condition_idx = item["conditions"]["condition_name"].index(condition_name)
+                all_sentences.append(item["conditions"]["content"][condition_idx])
+
+        # Tokenize sentences and split into batches.
+        all_tokenized_sentences = tokenizer(all_sentences, return_tensors="pt",
+                                            return_offsets_mapping=True,
+                                            **tokenizer_kwargs).to(device)
+        tokenized_batches = torch.split(all_tokenized_sentences["input_ids"], batch_size)
+
+        # Compute surprisal per-batch and combine into a single surprisal tensor.
+        n_sentences, n_timesteps = all_tokenized_sentences["input_ids"].shape
+        surprisals = torch.zeros(n_sentences, n_timesteps - 1).float().to(device)
+        for i, batch in enumerate(datasets.logging.tqdm(tokenized_batches)):
+            batch = batch.to(device)
+            with torch.no_grad():
+                # logits are B * T * V
+                b_logits = model(batch)["logits"]
+                b_surprisals = -b_logits.log_softmax(dim=2) / np.log(2)
+
+            # Get surprisals of ground-truth words.
+            gt_idxs = batch[:, 1:]
+            # Reindexed surprisals: B * (T - 1)
+            b_surprisals_gt = torch.gather(b_surprisals[:, :-1, :], 2, gt_idxs.unsqueeze(2)).squeeze(2)
+
+            surprisals[i * batch_size : (i + 1) * batch_size] = b_surprisals_gt
+
+        # Reshape to intuitive axes n_items * n_conditions * ...
+        surprisals = surprisals.reshape((len(dataset), len(condition_order), -1))
+        offset_mapping = all_tokenized_sentences["offset_mapping"] \
+            .reshape((len(dataset), len(condition_order), -1, 2))
+
+        # Now evaluate per-item.
         results = {}
         result_keys = ["prediction_results", "region_totals"]
-
-
-            result_single = self._compute_single(item, tokenizer, tokenizer_kwargs,
-                                                 model, device)
+        for item, item_surprisals, item_offset_mapping in zip(datasets.logging.tqdm(dataset), surprisals, offset_mapping):
+            result_i = self._compute_item(item, item_surprisals, item_offset_mapping, condition_order)
 
             suite_name = item["suite_name"]
             if suite_name not in results:
                 results[suite_name] = SyntaxGymMetricSuiteResult(suite_name, [], [])
             for k in result_keys:
-                getattr(results[suite_name], k).append(
+                getattr(results[suite_name], k).append(result_i[k])
 
         return results
 
-    def
-
-
-
-
-
-        # input_ids: B * T
-        input_ids = tokenized["input_ids"]
-        assert input_ids.ndim == 2
-
-        # Compute sentence level surprisals.
-        with torch.no_grad():
-            # Pre-softmax predictive distribution B * T * V
-            logits = model(input_ids).logits
-            surprisals = -logits.log_softmax(dim=2) / np.log(2)
-
-        # surprisals: B * T * V
-        assert surprisals.ndim == 3
-
-        # Get surprisals of expected words.
-        surps_shifted = surprisals[:, :-1, :]
-        expected_ids = input_ids[:, 1:]
-
-        # reindexed surprisals: B * (T - 1)
-        surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
-            .squeeze(2)
+    def _compute_item(self, item, item_surprisals, offset_mapping, condition_order):
+        """
+        Aggregate token-level surprisals to region-level surprisals for the given item,
+        and evaluate the item's predictions.
+        """
 
         #### aggregate
-        condition_names = item["conditions"]["condition_name"]
         region_totals = {condition_name: defaultdict(float)
-                         for condition_name in
+                         for condition_name in condition_order}
         region2tokens = self.compute_region_token_mapping(
-            item,
+            item, condition_order, offset_mapping)
 
-        for i, (
-            for region_number, region_tokens in region2tokens[
+        for i, (cond_i, surprisals_i) in enumerate(zip(condition_order, item_surprisals)):
+            for region_number, region_tokens in region2tokens[cond_i].items():
                 for token in region_tokens:
                     if token == 0:
                         # surprisal not defined. pass.
                         continue
-                    elif token <=
-                        region_totals[
+                    elif token <= item_surprisals.shape[1]:
+                        region_totals[cond_i][region_number] += surprisals_i[token - 1]
                     else:
                         # TODO don't think this is an issue, just should clean
                         # up the aggregation output
-                        assert token ==
-                            "%s %s" % (token,
+                        assert token == surprisals_i.shape[1], \
+                            "%s %s" % (token, surprisals_i.shape[1])
 
         region_totals = {(condition_name, region_number): float(total)
                          for condition_name, totals in region_totals.items()
@@ -275,23 +289,20 @@ class SyntaxGym(evaluate.EvaluationModule):
 
         return ret
 
-    def compute_region_token_mapping(self, item,
+    def compute_region_token_mapping(self, item, condition_order,
                                      offset_mapping: List[Tuple[int, int]]
                                      ) -> Dict[str, Dict[int, List[int]]]:
-        # input_ids: B * T
         # offset_mapping: B * T * 2
-        # assumes batch is sorted according to item's condition_name order
 
-
-        region2tokens = {cond: defaultdict(list) for cond in condition_names}
+        region2tokens = {cond: defaultdict(list) for cond in condition_order}
 
         max_long = torch.iinfo(torch.int64).max
 
-        for i_cond,
+        for i_cond, i_offsets in enumerate(offset_mapping):
             region_edges = self.get_region_edges(item, i_cond)
 
             t_cursor, r_cursor = 0, 0
-            while t_cursor <
+            while t_cursor < i_offsets.shape[0]:
                 # token = i_tokens[t_cursor]
                 token_char_start, token_char_end = i_offsets[t_cursor]
 
@@ -310,7 +321,7 @@ class SyntaxGym(evaluate.EvaluationModule):
                     r_cursor += 1
                     continue
 
-                region2tokens[
+                region2tokens[condition_order[i_cond]][r_cursor + 1].append(t_cursor)
                 t_cursor += 1
 
         return region2tokens
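The final reshape to (n_items, n_conditions, ...) is only valid because the flattening step emits every item's sentences in one fixed condition_order, looking each condition up by name rather than by position. A toy sketch of that invariant, with made-up items and a dummy tensor standing in for the model's per-sentence surprisals:

import torch

# Made-up items in a SyntaxGym-like layout: same condition names per item,
# but not necessarily listed in the same order.
dataset = [
    {"conditions": {"condition_name": ["match", "mismatch"],
                    "content": ["The key is lost.", "The key are lost."]}},
    {"conditions": {"condition_name": ["mismatch", "match"],
                    "content": ["The keys is lost.", "The keys are lost."]}},
]

# Flatten, always emitting conditions in the order fixed by the first item.
condition_order = dataset[0]["conditions"]["condition_name"]
all_sentences = []
for item in dataset:
    for condition_name in condition_order:
        idx = item["conditions"]["condition_name"].index(condition_name)
        all_sentences.append(item["conditions"]["content"][idx])

# Dummy per-sentence surprisals (N_sentences * (T - 1)); in the module this
# comes from the batched forward passes above.
flat_surprisals = torch.arange(len(all_sentences) * 5, dtype=torch.float).reshape(-1, 5)

# Because flattening used a fixed order, a plain reshape recovers the
# (n_items, n_conditions, T - 1) layout with every row in the right slot.
per_item = flat_surprisals.reshape(len(dataset), len(condition_order), -1)
assert per_item.shape == (2, 2, 5)
assert all_sentences[2] == "The keys are lost."   # item 1, condition "match"

Had the loop appended each item's conditions in its own listed order instead, the reshape would silently swap conditions for the second toy item; the index-by-name lookup in the diff guards against exactly that.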