Spaces:
Sleeping
Sleeping
| import abc | |
| import logging | |
| from typing import cast, Callable | |
| from sklearn.cluster import MiniBatchKMeans | |
| from feature_retrieval.index import NumpyArray | |
| logger = logging.getLogger(__name__) | |
class IFeatureMatrixTransform:
    """Contract for objects that resize an encoded voice feature matrix.

    An implementation receives a matrix of shape (n_features, vector_dim) and
    hands back a matrix of shape (m_features, vector_dim).
    """

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Map a (n_features, vector_dim) matrix to (m_features, vector_dim).

        This base version is only a placeholder: it always raises
        ``NotImplementedError`` and is meant to be overridden by subclasses.
        """
        raise NotImplementedError()
class DummyFeatureTransform(IFeatureMatrixTransform):
    """Identity transform: leaves the feature matrix untouched."""

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Return the very same ``matrix`` object that was passed in (no copy)."""
        unchanged = matrix
        return unchanged
class MinibatchKmeansFeatureTransform(IFeatureMatrixTransform):
    """Replace the input examples with k-means centroids (mini-batch algorithm)."""

    def __init__(self, n_clusters: int, n_parallel: int) -> None:
        """Remember the clustering settings.

        Args:
            n_clusters: number of centroids the transform produces.
            n_parallel: multiplier used to derive the mini-batch size
                (see ``_batch_size``).
        """
        self._n_clusters = n_clusters
        self._n_parallel = n_parallel

    def _batch_size(self) -> int:
        """Return the mini-batch size: 256 samples per unit of ``n_parallel``."""
        return self._n_parallel * 256

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Transform a (n_features, vector_dim) matrix to (n_clusters, vector_dim).

        Fits a ``MiniBatchKMeans`` model on ``matrix`` and returns its
        ``cluster_centers_``.
        """
        cluster = MiniBatchKMeans(
            n_clusters=self._n_clusters,
            verbose=True,
            # ``_batch_size`` is a regular method, so it has to be *called*;
            # passing the bound method itself would give the estimator a
            # non-integer batch size.
            batch_size=self._batch_size(),
            compute_labels=False,
            init="k-means++",
        )
        return cast(NumpyArray, cluster.fit(matrix).cluster_centers_)
class OnConditionFeatureTransform(IFeatureMatrixTransform):
    """Dispatch to one of two transforms depending on a predicate over the matrix."""

    def __init__(
        self,
        condition: Callable[[NumpyArray], bool],
        on_condition: IFeatureMatrixTransform,
        otherwise: IFeatureMatrixTransform,
    ) -> None:
        """Store the predicate and the two candidate transforms.

        Args:
            condition: callable evaluated on the matrix passed to ``transform``.
            on_condition: transform used when ``condition`` is truthy.
            otherwise: transform used when ``condition`` is falsy.
        """
        self._condition = condition
        self._on_condition = on_condition
        self._otherwise = otherwise

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Evaluate the condition once, log the chosen rule, and delegate to it."""
        passed = bool(self._condition(matrix))
        selected = self._on_condition if passed else self._otherwise
        transform_name = selected.__class__.__name__
        if passed:
            logger.info(f"pass condition. Transform by rule {transform_name}")
        else:
            logger.info(f"condition is not passed. Transform by rule {transform_name}")
        return selected.transform(matrix)