|
from dataclasses import dataclass
from typing import List, Optional
|
|
|
@dataclass
class PipelineConfig:
    """Configuration values for the pipeline.

    Every field has a default, so ``PipelineConfig()`` is a valid
    configuration. The three list-valued fields use ``None`` as a sentinel
    and are filled in by ``__post_init__`` with a freshly built list, so
    instances never share a mutable default.

    Raises:
        ValueError: If ``context_similarity_weight`` and
            ``response_coherence_weight`` do not sum to 1.0 (within 1e-6).
            A NaN weight is rejected as well.
    """

    # Length and token-count bounds.
    # NOTE(review): units and inclusivity are decided by the consumers of
    # this config, which are not visible here -- confirm before relying on them.
    min_length: int = 1
    max_length: int = 512
    min_tokens: int = 1
    max_tokens: int = 128

    # ``None`` means "use the default list assigned in __post_init__".
    allowed_speakers: Optional[List[str]] = None
    required_fields: Optional[List[str]] = None

    # Augmentation settings.
    augmentation_factor: int = 4
    # ``None`` means "use the default list assigned in __post_init__".
    augmentation_techniques: Optional[List[str]] = None

    # Limits on dialogue size / generated variations.
    max_turns_per_dialogue: int = 6
    max_variations_per_turn: int = 3
    max_sampled_variations: int = 2
    max_complexity_threshold: int = 100
    complexity_reduction_turns: int = 4

    # Quality thresholds.
    semantic_similarity_threshold: float = 0.45
    grammar_error_threshold: int = 2
    rouge1_f1_threshold: float = 0.30
    rouge2_f1_threshold: float = 0.15

    # Coherence scoring; the two weights are validated to sum to 1.0.
    min_response_coherence: float = 0.3
    context_similarity_weight: float = 0.35
    response_coherence_weight: float = 0.65

    # Runtime switches.
    batch_size: int = 32
    use_cache: bool = True
    debug: bool = False

    context_window_size: int = 4

    def __post_init__(self):
        """Fill in the list defaults and validate the scoring weights.

        Raises:
            ValueError: If the two scoring weights do not sum to 1.0
                within a tolerance of 1e-6 (NaN included).
        """
        # Build the default lists here rather than in the field defaults so
        # that each instance gets its own list object.
        if self.allowed_speakers is None:
            self.allowed_speakers = ['user', 'assistant']
        if self.required_fields is None:
            self.required_fields = ['dialogue_id', 'turns']
        if self.augmentation_techniques is None:
            self.augmentation_techniques = ['paraphrase', 'back_translation']

        weight_sum = self.context_similarity_weight + self.response_coherence_weight
        # Phrased as "not (<=)" instead of ">" on purpose: every comparison
        # against NaN is False, so a ">" test would silently accept NaN.
        if not abs(weight_sum - 1.0) <= 1e-6:
            raise ValueError("Context similarity and response coherence weights must sum to 1.0")
|
|