demo_detoxi / model_wrapper /model_wrapper.py
Cricles's picture
Upload 5 files
5d567f2 verified
raw
history blame
489 Bytes
from bert_wrapper import BertWrapper
from fasttext_wrapper import FasttextWrapper
from frida_wrapper import FridaWrapper
from typing import Any
class ModelWrapper:
    """Dispatch a piece of text to one of several named model wrappers.

    All wrappers are constructed eagerly in ``__init__``; any loading cost
    they have is paid once, when ``ModelWrapper`` is instantiated.
    """

    def __init__(self) -> None:
        # Maps the public model name (as passed to ``__call__``) to a
        # callable wrapper that accepts the input text.
        self.models_dict: dict[str, Any] = {
            "fasttext": FasttextWrapper(),
            "ru-BERT": BertWrapper(),
            "FRIDA": FridaWrapper(),
        }

    def __call__(self, text: str, model_name: str) -> str:
        """Run the model registered as *model_name* on *text*.

        Args:
            text: Input text, passed unchanged to the selected wrapper.
            model_name: Key into ``models_dict`` (e.g. ``"fasttext"``).

        Returns:
            Whatever the selected wrapper returns for *text*.

        Raises:
            KeyError: If *model_name* is not a registered model. The message
                lists the valid names to make UI/config mistakes obvious.
        """
        try:
            model = self.models_dict[model_name]
        except KeyError:
            available = ", ".join(repr(name) for name in self.models_dict)
            raise KeyError(
                f"Unknown model {model_name!r}; available models: {available}"
            ) from None
        return model(text)