Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Replabel transforms for use with flashlight's ASG criterion. | |
| """ | |
def replabel_symbol(i):
    """
    Return the dictionary symbol that stands for ``i`` repetitions.

    Flashlight currently spells replabels as the plain numerals "1", "2", ...
    Because of that, numeral tokens cannot be used as ordinary training
    symbols, so this convention may change in the future.
    """
    symbol = str(i)
    return symbol
def pack_replabels(tokens, dictionary, max_reps):
    """
    Pack a token sequence so that repeated symbols are replaced by replabels.

    Each run of a symbol is written as the symbol followed by the replabel
    for the number of extra copies dropped (1..max_reps). A run longer than
    ``max_reps + 1`` is split into several symbol/replabel groups.

    Args:
        tokens: sequence of token indices.
        dictionary: object whose ``index(symbol)`` method returns the token
            index of a replabel symbol.
        max_reps: largest repetition count a single replabel may encode.

    Returns:
        ``tokens`` itself when it is empty or ``max_reps <= 0``; otherwise a
        new list with repetitions replaced by replabel indices.
    """
    if len(tokens) == 0 or max_reps <= 0:
        return tokens

    # rep_index[k] is the token index of the replabel meaning "k repeats";
    # slot 0 is an unused placeholder so counts can index directly.
    rep_index = [0]
    rep_index.extend(
        dictionary.index(replabel_symbol(k)) for k in range(1, max_reps + 1)
    )

    packed = []
    last = -1  # sentinel: nothing emitted yet
    pending = 0  # repeats of `last` seen but not yet written out
    for tok in tokens:
        if tok == last and pending < max_reps:
            pending += 1
            continue
        # Close the current run (if it had repeats) before starting a new one.
        if pending:
            packed.append(rep_index[pending])
            pending = 0
        packed.append(tok)
        last = tok

    # Flush repeats belonging to the final run.
    if pending:
        packed.append(rep_index[pending])
    return packed
def unpack_replabels(tokens, dictionary, max_reps):
    """
    Unpack a token sequence so that replabels are replaced by repeated symbols.

    This is the inverse of ``pack_replabels``. A replabel that stands for
    ``k`` repetitions is expanded into ``k`` copies of the symbol just
    before it. A replabel with no symbol to repeat is dropped rather than
    emitted. That happens at the very start of the sequence, or directly
    after another replabel.

    Args:
        tokens: sequence of token indices, possibly containing replabels.
        dictionary: object whose ``index(symbol)`` method returns the token
            index of a replabel symbol.
        max_reps: largest repetition count a single replabel may encode.

    Returns:
        ``tokens`` itself when it is empty or ``max_reps <= 0``; otherwise a
        new list with replabels expanded into repeated symbols.
    """
    if len(tokens) == 0 or max_reps <= 0:
        return tokens

    # Map replabel token index -> number of repetitions it encodes.
    replabel_idx_to_value = {
        dictionary.index(replabel_symbol(i)): i for i in range(1, max_reps + 1)
    }

    result = []
    prev_token = None  # symbol a following replabel may repeat, if any
    for token in tokens:
        num_reps = replabel_idx_to_value.get(token)
        if num_reps is None:
            # Ordinary symbol: keep it and make it available for repetition.
            result.append(token)
            prev_token = token
            continue
        # Only expand when there is a preceding symbol; otherwise the
        # sentinel would leak into the output as a bogus token index.
        if prev_token is not None:
            result.extend([prev_token] * num_reps)
        # Each repeat count applies to exactly one replabel.
        prev_token = None
    return result