from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Dict


@dataclass
class ChatbotConfig:
    """Configuration parameters for the chatbot, each with a default value.

    Instances round-trip through plain dictionaries via :meth:`to_dict` and
    :meth:`from_dict`.

    NOTE(review): the grouping comments below are inferred from the field
    names only; confirm the exact meaning of each value against the code
    that consumes this config.
    """

    # Sizes and limits (presumably token/turn/character limits -- confirm).
    max_context_length: int = 512
    embedding_dim: int = 384
    learning_rate: float = 0.0005
    min_text_length: int = 3
    max_context_turns: int = 24

    # Model identifiers (these look like Hugging Face hub names -- confirm).
    pretrained_model: str = 'sentence-transformers/all-MiniLM-L6-v2'
    cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
    summarizer_model: str = 't5-small'

    # Batch sizes.
    embedding_batch_size: int = 64
    search_batch_size: int = 64
    max_batch_size: int = 64

    # Remaining knobs.
    # NOTE(review): `nlist` is presumably an index clustering parameter -- confirm.
    neg_samples: int = 10
    max_retries: int = 3
    nlist: int = 100

    def to_dict(self) -> Dict[str, Any]:
        """Return the instance attributes as a plain dictionary.

        Any value that is a ``pathlib.Path`` is converted to ``str``; all
        other values are returned unchanged. No field declared on this class
        is annotated as a ``Path``, so the conversion only matters when a
        caller has stored one on the instance.
        """
        return {
            name: str(value) if isinstance(value, Path) else value
            for name, value in vars(self).items()
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'ChatbotConfig':
        """Build a config from a dictionary.

        Keys that are not constructor fields of this dataclass are ignored,
        and fields missing from ``config_dict`` keep their defaults. Values
        are passed through as-is; no type conversion or validation is done.
        """
        # Use the public dataclasses API and keep only fields accepted by
        # __init__, so an unexpected key can never reach the constructor.
        init_names = {f.name for f in fields(cls) if f.init}
        known = {
            key: value
            for key, value in config_dict.items()
            if key in init_names
        }
        return cls(**known)