import preprocess
from shared import CustomTokens
from dataclasses import dataclass, field


@dataclass
class SegmentationArguments:
    pause_threshold: int = field(default=2, metadata={
        'help': 'When the time between words is greater than the pause threshold, force a new segment'})


# WORDS TO ALWAYS HAVE ON THEIR OWN
# always_split_re = re.compile(r'\[\w+\]')
# e.g., [Laughter], [Applause], [Music]
always_split = [
    CustomTokens.MUSIC.value,
    CustomTokens.APPLAUSE.value,
    CustomTokens.LAUGHTER.value
]


def get_overlapping_chunks_of_tokens(tokens, size, overlap):
    for i in range(0, len(tokens), size - overlap + 1):
        yield tokens[i:i + size]
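

# Illustrative example (not part of the original pipeline): with size=4 and
# overlap=2 the step is size - overlap + 1 = 3, so consecutive chunks share
# one token, i.e. the effective overlap is overlap - 1:
#   list(get_overlapping_chunks_of_tokens(list(range(10)), 4, 2))
#   -> [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]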


# Generate up to max_tokens - SAFETY_TOKENS
SAFETY_TOKENS = 12

# TODO play around with this?
OVERLAP_TOKEN_PERCENTAGE = 0.5  # 0.25


def add_labels_to_words(words, sponsor_segments):
    # TODO binary search
    for word in words:
        word['category'] = None
        for sponsor_segment in sponsor_segments:
            if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
                word['category'] = sponsor_segment['category']

    # TODO use extract_segment with mapping function?
    # TODO remove sponsor segments that contain mostly empty space?

    return words
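

# A minimal sketch of the expected input shapes, using the field names this
# file relies on ('sponsor' as a category value is an assumption for
# illustration):
#   words = [{'text': 'hey', 'start': 0.0, 'end': 0.4}, ...]
#   sponsor_segments = [{'start': 12.5, 'end': 45.0, 'category': 'sponsor'}, ...]
# After labelling, each word inside a sponsor segment carries that segment's
# category; all other words have word['category'] == None.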


def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
    segments = generate_segments(words, tokenizer, segmentation_args)

    labelled_segments = list(
        map(lambda x: add_labels_to_words(x, sponsor_segments), segments))

    return labelled_segments


def word_start(word):
    return word['start']


def word_end(word):
    return word.get('end', word['start'])


def generate_segments(words, tokenizer, segmentation_args):
    first_pass_segments = []

    for index, word in enumerate(words):
        # Get the length of the tokenized word
        cleaned = preprocess.clean_text(word['text'])
        word['num_tokens'] = len(
            tokenizer(cleaned, add_special_tokens=False, truncation=True).input_ids)

        add_new_segment = index == 0
        if not add_new_segment:
            # Always split on special tokens (e.g., [Music])
            if word['text'] in always_split or words[index - 1]['text'] in always_split:
                add_new_segment = True

            # Pause is long enough, split into a new segment
            elif word_start(words[index]) - word_end(words[index - 1]) >= segmentation_args.pause_threshold:
                add_new_segment = True

        if add_new_segment:  # Start a new segment
            first_pass_segments.append([word])
        else:  # Add to the current segment
            first_pass_segments[-1].append(word)

    max_q_size = tokenizer.model_max_length - SAFETY_TOKENS

    buffer_size = OVERLAP_TOKEN_PERCENTAGE * max_q_size  # tokenizer.model_max_length

    # In the second pass, split segments that are too big
    second_pass_segments = []
    for segment in first_pass_segments:
        current_segment_num_tokens = 0
        current_segment = []
        for word in segment:
            new_seg = current_segment_num_tokens + word['num_tokens'] >= max_q_size
            if new_seg:
                # Adding this word would exceed the token limit,
                # so save this batch and start a new one
                second_pass_segments.append(current_segment.copy())

            # Add the word to the current segment
            current_segment.append(word)
            current_segment_num_tokens += word['num_tokens']

            if new_seg:
                # Just started a new segment, so pop words from the front
                # until only buffer_size tokens remain (the overlap)
                while current_segment_num_tokens > buffer_size and current_segment:
                    first_word = current_segment.pop(0)
                    current_segment_num_tokens -= first_word['num_tokens']

        if current_segment:  # Add the remaining segment
            second_pass_segments.append(current_segment.copy())

    # Clean up: remove 'num_tokens' from each word
    for segment in second_pass_segments:
        for word in segment:
            word.pop('num_tokens', None)

    return second_pass_segments
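

# Hypothetical usage sketch, assuming a Hugging Face tokenizer (the checkpoint
# name and argument values below are assumptions, not taken from this file):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('t5-small')
#   args = SegmentationArguments(pause_threshold=2)
#   segments = generate_segments(words, tokenizer, args)
# Each returned segment is a list of word dicts whose total token count stays
# below tokenizer.model_max_length - SAFETY_TOKENS; consecutive splits of a
# long segment share roughly OVERLAP_TOKEN_PERCENTAGE of their tokens.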


def extract_segment(words, start, end, map_function=None):
    """Extracts all words with time in [start, end]"""
    a = binary_search(words, 0, len(words), start, True)
    b = min(binary_search(words, 0, len(words), end, False) + 1, len(words))

    to_transform = map_function is not None and callable(map_function)

    return [
        map_function(words[i]) if to_transform else words[i] for i in range(a, b)
    ]


# Binary search for the first index of the word whose start/end time is
# greater/less than some value
def binary_search(words, start_index, end_index, time, below):
    if start_index >= end_index:
        return end_index

    middle_index = (start_index + end_index) // 2
    middle_time = word_start(words[middle_index]) if below else word_end(words[middle_index])

    if time <= middle_time:
        return binary_search(words, start_index, middle_index, time, below)
    else:
        return binary_search(words, middle_index + 1, end_index, time, below)
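

# Worked example for the search helpers (illustrative values only):
#   words = [{'text': 'a', 'start': 0.0, 'end': 0.5},
#            {'text': 'b', 'start': 1.0, 'end': 1.5},
#            {'text': 'c', 'start': 2.0, 'end': 2.5}]
#   extract_segment(words, 0.9, 2.1)
# binary_search yields a = 1 (the first word starting at or after 0.9) and
# b = 3 (one past the first word ending at or after 2.1), so the call
# returns the 'b' and 'c' word dicts.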