File size: 1,760 Bytes
3190e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PipelineConfig:
    """Configuration values for the pipeline.

    The list-valued fields (``allowed_speakers``, ``required_fields`` and
    ``augmentation_techniques``) default to ``None``; ``__post_init__``
    replaces ``None`` with a freshly built default list, so instances never
    share a mutable default.

    Raises:
        ValueError: If ``min_length`` exceeds ``max_length``, if
            ``min_tokens`` exceeds ``max_tokens``, or if
            ``context_similarity_weight`` and ``response_coherence_weight``
            do not sum to 1.0 (within a tolerance of 1e-6).
    """

    # Validation settings
    min_length: int = 1
    max_length: int = 512
    min_tokens: int = 1
    max_tokens: int = 128

    # None means "use the default list assigned in __post_init__".
    allowed_speakers: Optional[List[str]] = None
    required_fields: Optional[List[str]] = None

    # Text augmentation settings
    augmentation_factor: int = 4
    # None means "use the default list assigned in __post_init__".
    augmentation_techniques: Optional[List[str]] = None

    max_turns_per_dialogue: int = 6
    max_variations_per_turn: int = 3
    max_sampled_variations: int = 2
    max_complexity_threshold: int = 100
    complexity_reduction_turns: int = 4

    # Quality thresholds
    semantic_similarity_threshold: float = 0.45
    grammar_error_threshold: int = 2
    rouge1_f1_threshold: float = 0.30
    rouge2_f1_threshold: float = 0.15

    # Response coherence thresholds
    min_response_coherence: float = 0.3
    # These two weights must sum to 1.0 (checked in __post_init__).
    context_similarity_weight: float = 0.35
    response_coherence_weight: float = 0.65

    # Performance settings
    batch_size: int = 32
    use_cache: bool = True
    debug: bool = False

    context_window_size: int = 4

    def __post_init__(self):
        """Fill in list defaults and validate cross-field constraints."""
        # Build the default lists per instance rather than as class-level
        # defaults, so mutating one config never affects another.
        if self.allowed_speakers is None:
            self.allowed_speakers = ['user', 'assistant']
        if self.required_fields is None:
            self.required_fields = ['dialogue_id', 'turns']
        if self.augmentation_techniques is None:
            self.augmentation_techniques = ['paraphrase', 'back_translation']

        # Reject inverted bounds here instead of failing later at the point
        # of use.
        if self.min_length > self.max_length:
            raise ValueError(
                f"min_length ({self.min_length}) must not exceed "
                f"max_length ({self.max_length})"
            )
        if self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens ({self.min_tokens}) must not exceed "
                f"max_tokens ({self.max_tokens})"
            )

        # Validate weights sum to 1.0; the small tolerance absorbs
        # floating-point rounding in the addition.
        if abs((self.context_similarity_weight + self.response_coherence_weight) - 1.0) > 1e-6:
            raise ValueError("Context similarity and response coherence weights must sum to 1.0")