# csc525_retrieval_based_chatbot / pipeline_config.py
# Last commit by JoeArmani: 300fe5d ("updates through 4th iteration")
import math
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class PipelineConfig:
    """
    Configuration for the pipeline.

    Groups validation limits, text-augmentation settings, quality thresholds,
    response-coherence weighting and performance options into one object.

    The list-valued fields default to ``None`` and are replaced with fresh
    per-instance lists in ``__post_init__``; passing ``None`` explicitly has
    the same effect as omitting the argument.

    Raises:
        ValueError: if ``context_similarity_weight`` and
            ``response_coherence_weight`` do not sum to 1.0 within an
            absolute tolerance of 1e-6 (a NaN weight is also rejected).
    """

    # Validation settings
    # NOTE(review): whether the length bounds count characters and how the
    # token bounds are tokenized is not visible here — confirm with the
    # consumer of these fields.
    min_length: int = 1
    max_length: int = 512
    min_tokens: int = 1
    max_tokens: int = 128
    # None -> ['user', 'assistant'] (filled in by __post_init__)
    allowed_speakers: Optional[List[str]] = None
    # None -> ['dialogue_id', 'turns'] (filled in by __post_init__)
    required_fields: Optional[List[str]] = None

    # Text augmentation settings
    augmentation_factor: int = 4
    # None -> ['paraphrase', 'back_translation'] (filled in by __post_init__)
    augmentation_techniques: Optional[List[str]] = None
    max_turns_per_dialogue: int = 6
    max_variations_per_turn: int = 3
    max_sampled_variations: int = 2
    max_complexity_threshold: int = 100
    complexity_reduction_turns: int = 4

    # Quality thresholds
    semantic_similarity_threshold: float = 0.45
    grammar_error_threshold: int = 2
    rouge1_f1_threshold: float = 0.30
    rouge2_f1_threshold: float = 0.15

    # Response coherence thresholds
    min_response_coherence: float = 0.3
    # These two weights must sum to 1.0; enforced in __post_init__.
    context_similarity_weight: float = 0.35
    response_coherence_weight: float = 0.65

    # Performance settings
    batch_size: int = 32
    use_cache: bool = True
    debug: bool = False
    context_window_size: int = 4

    def __post_init__(self) -> None:
        """Fill in default lists and validate the coherence weights.

        Raises:
            ValueError: if the two coherence weights do not sum to 1.0
                (absolute tolerance 1e-6), or if either weight is NaN.
        """
        # Built here instead of as field defaults so every instance owns its
        # own list objects (dataclasses reject mutable list defaults).
        if self.allowed_speakers is None:
            self.allowed_speakers = ['user', 'assistant']
        if self.required_fields is None:
            self.required_fields = ['dialogue_id', 'turns']
        if self.augmentation_techniques is None:
            self.augmentation_techniques = ['paraphrase', 'back_translation']

        # Validate weights sum to 1.0. A bare `abs(s - 1.0) > tol` test is
        # False for NaN and would accept it; math.isclose returns False for
        # NaN, so NaN is rejected. rel_tol=0.0 keeps a purely absolute 1e-6
        # tolerance.
        weight_sum = self.context_similarity_weight + self.response_coherence_weight
        if not math.isclose(weight_sum, 1.0, rel_tol=0.0, abs_tol=1e-6):
            raise ValueError("Context similarity and response coherence weights must sum to 1.0")