Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Joshua Lochner
		
	committed on
		
		
					Commit 
							
							·
						
						25f1183
	
1
								Parent(s):
							
							320a2ba
								
Use multiclass classifier to filter predictions
Browse files
    	
        out/runs/Jan18_13-34-23_DESKTOP-I39NJG7/1642505668.7632372/events.out.tfevents.1642505668.DESKTOP-I39NJG7.27016.1
    ADDED
    
    | Binary file (5.12 kB). View file | 
|  | 
    	
        out/runs/Jan18_13-34-23_DESKTOP-I39NJG7/events.out.tfevents.1642505668.DESKTOP-I39NJG7.27016.0
    ADDED
    
    | Binary file (3.51 kB). View file | 
|  | 
    	
        src/predict.py
    CHANGED
    
    | @@ -1,3 +1,4 @@ | |
|  | |
| 1 | 
             
            from utils import re_findall
         | 
| 2 | 
             
            from shared import OutputArguments
         | 
| 3 | 
             
            from typing import Optional
         | 
| @@ -25,6 +26,7 @@ import logging | |
| 25 |  | 
| 26 | 
             
            import re
         | 
| 27 |  | 
|  | |
| 28 | 
             
            def seconds_to_time(seconds, remove_leading_zeroes=False):
         | 
| 29 | 
             
                fractional = round(seconds % 1, 3)
         | 
| 30 | 
             
                fractional = '' if fractional == 0 else str(fractional)[1:]
         | 
| @@ -35,6 +37,7 @@ def seconds_to_time(seconds, remove_leading_zeroes=False): | |
| 35 | 
             
                    hms = re.sub(r'^0(?:0:0?)?', '', hms)
         | 
| 36 | 
             
                return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
         | 
| 37 |  | 
|  | |
| 38 | 
             
            @dataclass
         | 
| 39 | 
             
            class TrainingOutputArguments:
         | 
| 40 |  | 
| @@ -68,13 +71,15 @@ class PredictArguments(TrainingOutputArguments): | |
| 68 | 
             
                )
         | 
| 69 |  | 
| 70 |  | 
| 71 | 
            -
             | 
|  | |
|  | |
| 72 |  | 
| 73 | 
             
            MATCH_WINDOW = 25       # Increase for accuracy, but takes longer: O(n^3)
         | 
| 74 | 
             
            MERGE_TIME_WITHIN = 8   # Merge predictions if they are within x seconds
         | 
| 75 |  | 
| 76 |  | 
| 77 | 
            -
            @dataclass
         | 
| 78 | 
             
            class ClassifierArguments:
         | 
| 79 | 
             
                classifier_dir: Optional[str] = field(
         | 
| 80 | 
             
                    default='classifiers',
         | 
| @@ -101,7 +106,7 @@ class ClassifierArguments: | |
| 101 | 
             
                    default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
         | 
| 102 |  | 
| 103 |  | 
| 104 | 
            -
            def  | 
| 105 | 
             
                """Use classifier to filter predictions"""
         | 
| 106 | 
             
                if not predictions:
         | 
| 107 | 
             
                    return predictions
         | 
| @@ -114,14 +119,34 @@ def filter_predictions(predictions, classifier_args):  # classifier, vectorizer, | |
| 114 | 
             
                ])
         | 
| 115 | 
             
                probabilities = classifier.predict_proba(transformed_segments)
         | 
| 116 |  | 
|  | |
|  | |
| 117 | 
             
                filtered_predictions = []
         | 
| 118 | 
            -
                for prediction,  | 
| 119 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 120 |  | 
| 121 | 
            -
                     | 
| 122 | 
            -
             | 
| 123 | 
            -
                    #  | 
| 124 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 125 |  | 
| 126 | 
             
                return filtered_predictions
         | 
| 127 |  | 
| @@ -140,7 +165,6 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie | |
| 140 | 
             
                )
         | 
| 141 |  | 
| 142 | 
             
                predictions = segments_to_predictions(segments, model, tokenizer)
         | 
| 143 | 
            -
             | 
| 144 | 
             
                # Add words back to time_ranges
         | 
| 145 | 
             
                for prediction in predictions:
         | 
| 146 | 
             
                    # Stores words in the range
         | 
| @@ -148,8 +172,8 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie | |
| 148 | 
             
                        words, prediction['start'], prediction['end'])
         | 
| 149 |  | 
| 150 | 
             
                # TODO add back
         | 
| 151 | 
            -
                 | 
| 152 | 
            -
             | 
| 153 |  | 
| 154 | 
             
                return predictions
         | 
| 155 |  | 
| @@ -171,6 +195,9 @@ def greedy_match(list, sublist): | |
| 171 | 
             
                return best_i, best_j, best_k
         | 
| 172 |  | 
| 173 |  | 
|  | |
|  | |
|  | |
| 174 | 
             
            def predict_sponsor_text(text, model, tokenizer):
         | 
| 175 | 
             
                """Given a body of text, predict the words which are part of the sponsor"""
         | 
| 176 | 
             
                input_ids = tokenizer(
         | 
| @@ -189,7 +216,7 @@ def predict_sponsor_matches(text, model, tokenizer): | |
| 189 | 
             
                if CustomTokens.NO_SEGMENT.value in sponsorship_text:
         | 
| 190 | 
             
                    return []
         | 
| 191 |  | 
| 192 | 
            -
                return re_findall( | 
| 193 |  | 
| 194 |  | 
| 195 | 
             
            def segments_to_predictions(segments, model, tokenizer):
         | 
| @@ -237,12 +264,11 @@ def segments_to_predictions(segments, model, tokenizer): | |
| 237 | 
             
                    start_time = range['start']
         | 
| 238 | 
             
                    end_time = range['end']
         | 
| 239 |  | 
| 240 | 
            -
                    if prev_prediction is not None and  | 
| 241 | 
            -
             | 
| 242 | 
            -
             | 
| 243 | 
            -
             | 
| 244 | 
            -
                        #  | 
| 245 | 
            -
                        # so we extend last prediction range
         | 
| 246 | 
             
                        final_predicted_time_ranges[-1]['end'] = end_time
         | 
| 247 |  | 
| 248 | 
             
                    else:  # No overlap, is a new prediction
         | 
| @@ -279,7 +305,7 @@ def main(): | |
| 279 |  | 
| 280 | 
             
                predict_args.video_id = predict_args.video_id.strip()
         | 
| 281 | 
             
                predictions = predict(predict_args.video_id, model, tokenizer,
         | 
| 282 | 
            -
                                      segmentation_args | 
| 283 |  | 
| 284 | 
             
                video_url = f'https://www.youtube.com/watch?v={predict_args.video_id}'
         | 
| 285 | 
             
                if not predictions:
         | 
| @@ -292,7 +318,7 @@ def main(): | |
| 292 | 
             
                    print('Text: "',
         | 
| 293 | 
             
                          ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
         | 
| 294 | 
             
                    print('Time:', seconds_to_time(
         | 
| 295 | 
            -
                        prediction['start']), ' | 
| 296 | 
             
                    print('Probability:', prediction.get('probability'))
         | 
| 297 | 
             
                    print('Category:', prediction.get('category'))
         | 
| 298 | 
             
                    print()
         | 
|  | |
| 1 | 
            +
            from shared import START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE
         | 
| 2 | 
             
            from utils import re_findall
         | 
| 3 | 
             
            from shared import OutputArguments
         | 
| 4 | 
             
            from typing import Optional
         | 
|  | |
| 26 |  | 
| 27 | 
             
            import re
         | 
| 28 |  | 
| 29 | 
            +
             | 
| 30 | 
             
            def seconds_to_time(seconds, remove_leading_zeroes=False):
         | 
| 31 | 
             
                fractional = round(seconds % 1, 3)
         | 
| 32 | 
             
                fractional = '' if fractional == 0 else str(fractional)[1:]
         | 
|  | |
| 37 | 
             
                    hms = re.sub(r'^0(?:0:0?)?', '', hms)
         | 
| 38 | 
             
                return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
         | 
| 39 |  | 
| 40 | 
            +
             | 
| 41 | 
             
            @dataclass
         | 
| 42 | 
             
            class TrainingOutputArguments:
         | 
| 43 |  | 
|  | |
| 71 | 
             
                )
         | 
| 72 |  | 
| 73 |  | 
| 74 | 
            +
            _SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
         | 
| 75 | 
            +
            _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
         | 
| 76 | 
            +
            SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
         | 
| 77 |  | 
| 78 | 
             
            MATCH_WINDOW = 25       # Increase for accuracy, but takes longer: O(n^3)
         | 
| 79 | 
             
            MERGE_TIME_WITHIN = 8   # Merge predictions if they are within x seconds
         | 
| 80 |  | 
| 81 |  | 
| 82 | 
            +
            @dataclass(frozen=True, eq=True)
         | 
| 83 | 
             
            class ClassifierArguments:
         | 
| 84 | 
             
                classifier_dir: Optional[str] = field(
         | 
| 85 | 
             
                    default='classifiers',
         | 
|  | |
| 106 | 
             
                    default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
         | 
| 107 |  | 
| 108 |  | 
| 109 | 
            +
            def add_predictions(predictions, classifier_args):  # classifier, vectorizer,
         | 
| 110 | 
             
                """Use classifier to filter predictions"""
         | 
| 111 | 
             
                if not predictions:
         | 
| 112 | 
             
                    return predictions
         | 
|  | |
| 119 | 
             
                ])
         | 
| 120 | 
             
                probabilities = classifier.predict_proba(transformed_segments)
         | 
| 121 |  | 
| 122 | 
            +
                # Transformer sometimes says segment is of another category, so we
         | 
| 123 | 
            +
                # update category and probabilities if classifier is confident it is another category
         | 
| 124 | 
             
                filtered_predictions = []
         | 
| 125 | 
            +
                for prediction, probabilities in zip(predictions, probabilities):
         | 
| 126 | 
            +
                    predicted_probabilities = {k: v for k,
         | 
| 127 | 
            +
                                               v in zip(CATEGORIES, probabilities)}
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # Get best category + probability
         | 
| 130 | 
            +
                    classifier_category = max(
         | 
| 131 | 
            +
                        predicted_probabilities, key=predicted_probabilities.get)
         | 
| 132 | 
            +
                    classifier_probability = predicted_probabilities[classifier_category]
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    if classifier_category is None and classifier_probability > classifier_args.min_probability:
         | 
| 135 | 
            +
                        continue  # Ignore
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    if classifier_category is not None and classifier_probability > 0.5:  # TODO make param
         | 
| 138 | 
            +
                        # Confident enough to overrule, so we update category
         | 
| 139 | 
            +
                        prediction['category'] = classifier_category
         | 
| 140 |  | 
| 141 | 
            +
                    prediction['probability'] = predicted_probabilities[prediction['category']]
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    # TODO add probabilities, but remove None and normalise rest
         | 
| 144 | 
            +
                    prediction['probabilities'] = predicted_probabilities
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    # if prediction['probability'] < classifier_args.min_probability:
         | 
| 147 | 
            +
                    #     continue
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    filtered_predictions.append(prediction)
         | 
| 150 |  | 
| 151 | 
             
                return filtered_predictions
         | 
| 152 |  | 
|  | |
| 165 | 
             
                )
         | 
| 166 |  | 
| 167 | 
             
                predictions = segments_to_predictions(segments, model, tokenizer)
         | 
|  | |
| 168 | 
             
                # Add words back to time_ranges
         | 
| 169 | 
             
                for prediction in predictions:
         | 
| 170 | 
             
                    # Stores words in the range
         | 
|  | |
| 172 | 
             
                        words, prediction['start'], prediction['end'])
         | 
| 173 |  | 
| 174 | 
             
                # TODO add back
         | 
| 175 | 
            +
                if classifier_args is not None:
         | 
| 176 | 
            +
                    predictions = add_predictions(predictions, classifier_args)
         | 
| 177 |  | 
| 178 | 
             
                return predictions
         | 
| 179 |  | 
|  | |
| 195 | 
             
                return best_i, best_j, best_k
         | 
| 196 |  | 
| 197 |  | 
| 198 | 
            +
            CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
         | 
| 199 | 
            +
             | 
| 200 | 
            +
             | 
| 201 | 
             
            def predict_sponsor_text(text, model, tokenizer):
         | 
| 202 | 
             
                """Given a body of text, predict the words which are part of the sponsor"""
         | 
| 203 | 
             
                input_ids = tokenizer(
         | 
|  | |
| 216 | 
             
                if CustomTokens.NO_SEGMENT.value in sponsorship_text:
         | 
| 217 | 
             
                    return []
         | 
| 218 |  | 
| 219 | 
            +
                return re_findall(SEGMENT_MATCH_RE, sponsorship_text)
         | 
| 220 |  | 
| 221 |  | 
| 222 | 
             
            def segments_to_predictions(segments, model, tokenizer):
         | 
|  | |
| 264 | 
             
                    start_time = range['start']
         | 
| 265 | 
             
                    end_time = range['end']
         | 
| 266 |  | 
| 267 | 
            +
                    if prev_prediction is not None and \
         | 
| 268 | 
            +
                            (start_time <= prev_prediction['end'] <= end_time or    # Merge overlapping segments
         | 
| 269 | 
            +
                                (range['category'] == prev_prediction['category']   # Merge disconnected segments if same category and within threshold
         | 
| 270 | 
            +
                                    and start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN)):
         | 
| 271 | 
            +
                        # Extend last prediction range
         | 
|  | |
| 272 | 
             
                        final_predicted_time_ranges[-1]['end'] = end_time
         | 
| 273 |  | 
| 274 | 
             
                    else:  # No overlap, is a new prediction
         | 
|  | |
| 305 |  | 
| 306 | 
             
                predict_args.video_id = predict_args.video_id.strip()
         | 
| 307 | 
             
                predictions = predict(predict_args.video_id, model, tokenizer,
         | 
| 308 | 
            +
                                      segmentation_args, classifier_args=classifier_args)
         | 
| 309 |  | 
| 310 | 
             
                video_url = f'https://www.youtube.com/watch?v={predict_args.video_id}'
         | 
| 311 | 
             
                if not predictions:
         | 
|  | |
| 318 | 
             
                    print('Text: "',
         | 
| 319 | 
             
                          ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
         | 
| 320 | 
             
                    print('Time:', seconds_to_time(
         | 
| 321 | 
            +
                        prediction['start']), '\u2192', seconds_to_time(prediction['end']))
         | 
| 322 | 
             
                    print('Probability:', prediction.get('probability'))
         | 
| 323 | 
             
                    print('Category:', prediction.get('category'))
         | 
| 324 | 
             
                    print()
         |