Spaces:
Running
Running
Joshua Lochner
committed on
Commit
·
fb87012
1
Parent(s):
02e576a
Improve caching and downloading of classifier for predictions
Browse files
- src/evaluate.py +4 -3
- src/predict.py +20 -8
src/evaluate.py
CHANGED
|
@@ -205,7 +205,7 @@ def main():
|
|
| 205 |
|
| 206 |
evaluation_args, dataset_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()
|
| 207 |
|
| 208 |
-
model, tokenizer = get_model_tokenizer(evaluation_args.model_path)
|
| 209 |
|
| 210 |
# # TODO find better way of evaluating videos not trained on
|
| 211 |
# dataset = load_dataset('json', data_files=os.path.join(
|
|
@@ -313,8 +313,9 @@ def main():
|
|
| 313 |
[w['text'] for w in missed_segment['words']]), '"', sep='')
|
| 314 |
print('\t\tCategory:',
|
| 315 |
missed_segment.get('category'))
|
| 316 |
-
|
| 317 |
-
|
|
|
|
| 318 |
|
| 319 |
segments_to_submit.append({
|
| 320 |
'segment': [missed_segment['start'], missed_segment['end']],
|
|
|
|
| 205 |
|
| 206 |
evaluation_args, dataset_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()
|
| 207 |
|
| 208 |
+
model, tokenizer = get_model_tokenizer(evaluation_args.model_path, evaluation_args.cache_dir)
|
| 209 |
|
| 210 |
# # TODO find better way of evaluating videos not trained on
|
| 211 |
# dataset = load_dataset('json', data_files=os.path.join(
|
|
|
|
| 313 |
[w['text'] for w in missed_segment['words']]), '"', sep='')
|
| 314 |
print('\t\tCategory:',
|
| 315 |
missed_segment.get('category'))
|
| 316 |
+
if 'probability' in missed_segment:
|
| 317 |
+
print('\t\tProbability:',
|
| 318 |
+
missed_segment['probability'])
|
| 319 |
|
| 320 |
segments_to_submit.append({
|
| 321 |
'segment': [missed_segment['start'], missed_segment['end']],
|
src/predict.py
CHANGED
|
@@ -11,8 +11,8 @@ from segment import (
|
|
| 11 |
SegmentationArguments
|
| 12 |
)
|
| 13 |
import preprocess
|
| 14 |
-
from errors import TranscriptError, ModelLoadError
|
| 15 |
-
from model import get_classifier_vectorizer, get_model_tokenizer
|
| 16 |
from transformers import HfArgumentParser
|
| 17 |
from transformers.trainer_utils import get_last_checkpoint
|
| 18 |
from dataclasses import dataclass, field
|
|
@@ -29,6 +29,7 @@ class TrainingOutputArguments:
|
|
| 29 |
'help': 'Path to pretrained model used for prediction'
|
| 30 |
}
|
| 31 |
)
|
|
|
|
| 32 |
|
| 33 |
output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
|
| 34 |
'output_dir']
|
|
@@ -43,7 +44,8 @@ class TrainingOutputArguments:
|
|
| 43 |
self.model_path = last_checkpoint
|
| 44 |
return
|
| 45 |
|
| 46 |
-
raise ModelLoadError(
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,6 +67,13 @@ MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
|
|
| 65 |
|
| 66 |
@dataclass(frozen=True, eq=True)
|
| 67 |
class ClassifierArguments:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
classifier_dir: Optional[str] = field(
|
| 69 |
default='classifiers',
|
| 70 |
metadata={
|
|
@@ -90,7 +99,6 @@ class ClassifierArguments:
|
|
| 90 |
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
| 91 |
|
| 92 |
|
| 93 |
-
# classifier, vectorizer,
|
| 94 |
def filter_and_add_probabilities(predictions, classifier_args):
|
| 95 |
"""Use classifier to filter predictions"""
|
| 96 |
if not predictions:
|
|
@@ -160,8 +168,11 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
|
|
| 160 |
|
| 161 |
# TODO add back
|
| 162 |
if classifier_args is not None:
|
| 163 |
-
|
| 164 |
-
predictions
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
return predictions
|
| 167 |
|
|
@@ -290,7 +301,7 @@ def main():
|
|
| 290 |
print('No video ID supplied. Use `--video_id`.')
|
| 291 |
return
|
| 292 |
|
| 293 |
-
model, tokenizer = get_model_tokenizer(predict_args.model_path)
|
| 294 |
|
| 295 |
predict_args.video_id = predict_args.video_id.strip()
|
| 296 |
predictions = predict(predict_args.video_id, model, tokenizer,
|
|
@@ -308,8 +319,9 @@ def main():
|
|
| 308 |
' '.join([w['text'] for w in prediction['words']]), '"', sep='')
|
| 309 |
print('Time:', seconds_to_time(
|
| 310 |
prediction['start']), '\u2192', seconds_to_time(prediction['end']))
|
| 311 |
-
print('Probability:', prediction.get('probability'))
|
| 312 |
print('Category:', prediction.get('category'))
|
|
|
|
|
|
|
| 313 |
print()
|
| 314 |
|
| 315 |
|
|
|
|
| 11 |
SegmentationArguments
|
| 12 |
)
|
| 13 |
import preprocess
|
| 14 |
+
from errors import TranscriptError, ModelLoadError, ClassifierLoadError
|
| 15 |
+
from model import ModelArguments, get_classifier_vectorizer, get_model_tokenizer
|
| 16 |
from transformers import HfArgumentParser
|
| 17 |
from transformers.trainer_utils import get_last_checkpoint
|
| 18 |
from dataclasses import dataclass, field
|
|
|
|
| 29 |
'help': 'Path to pretrained model used for prediction'
|
| 30 |
}
|
| 31 |
)
|
| 32 |
+
cache_dir: Optional[str] = ModelArguments.__dataclass_fields__['cache_dir']
|
| 33 |
|
| 34 |
output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
|
| 35 |
'output_dir']
|
|
|
|
| 44 |
self.model_path = last_checkpoint
|
| 45 |
return
|
| 46 |
|
| 47 |
+
raise ModelLoadError(
|
| 48 |
+
'Unable to find model, explicitly set `--model_path`')
|
| 49 |
|
| 50 |
|
| 51 |
@dataclass
|
|
|
|
| 67 |
|
| 68 |
@dataclass(frozen=True, eq=True)
|
| 69 |
class ClassifierArguments:
|
| 70 |
+
classifier_model: Optional[str] = field(
|
| 71 |
+
default='Xenova/sponsorblock-classifier',
|
| 72 |
+
metadata={
|
| 73 |
+
'help': 'Use a pretrained classifier'
|
| 74 |
+
}
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
classifier_dir: Optional[str] = field(
|
| 78 |
default='classifiers',
|
| 79 |
metadata={
|
|
|
|
| 99 |
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
| 100 |
|
| 101 |
|
|
|
|
| 102 |
def filter_and_add_probabilities(predictions, classifier_args):
|
| 103 |
"""Use classifier to filter predictions"""
|
| 104 |
if not predictions:
|
|
|
|
| 168 |
|
| 169 |
# TODO add back
|
| 170 |
if classifier_args is not None:
|
| 171 |
+
try:
|
| 172 |
+
predictions = filter_and_add_probabilities(
|
| 173 |
+
predictions, classifier_args)
|
| 174 |
+
except ClassifierLoadError:
|
| 175 |
+
print('Unable to load classifer')
|
| 176 |
|
| 177 |
return predictions
|
| 178 |
|
|
|
|
| 301 |
print('No video ID supplied. Use `--video_id`.')
|
| 302 |
return
|
| 303 |
|
| 304 |
+
model, tokenizer = get_model_tokenizer(predict_args.model_path, predict_args.cache_dir)
|
| 305 |
|
| 306 |
predict_args.video_id = predict_args.video_id.strip()
|
| 307 |
predictions = predict(predict_args.video_id, model, tokenizer,
|
|
|
|
| 319 |
' '.join([w['text'] for w in prediction['words']]), '"', sep='')
|
| 320 |
print('Time:', seconds_to_time(
|
| 321 |
prediction['start']), '\u2192', seconds_to_time(prediction['end']))
|
|
|
|
| 322 |
print('Category:', prediction.get('category'))
|
| 323 |
+
if 'probability' in prediction:
|
| 324 |
+
print('Probability:', prediction['probability'])
|
| 325 |
print()
|
| 326 |
|
| 327 |
|