File size: 7,081 Bytes
960b1a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92da7ef
960b1a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92da7ef
960b1a0
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# ---------------------------
# Настройки корпусов данных
# ---------------------------

[datasets.meld]
base_dir = "E:/MELD"
csv_path = "{base_dir}/meld_{split}_labels.csv"
wav_dir  = "{base_dir}/wavs/{split}"

[datasets.resd]
base_dir = "E:/RESD"
csv_path = "{base_dir}/resd_{split}_labels.csv"
wav_dir  = "{base_dir}/wavs/{split}"

[synthetic_data]
use_synthetic_data = false
synthetic_path = "E:/MELD_S"
synthetic_ratio = 0.005

# ---------------------------
# Список модальностей и эмоций
# ---------------------------
modalities = ["audio"]
# emotion_columns = ["neutral", "happy", "sad", "anger", "surprise", "disgust", "fear"]
emotion_columns = ["anger", "disgust", "fear", "happy", "neutral", "sad", "surprise"]

# ---------------------------
# DataLoader параметры
# ---------------------------
[dataloader]
num_workers = 0
shuffle = true
prepare_only = false

# ---------------------------
# Аудио
# ---------------------------
[audio]
sample_rate = 16000           # Целевая частота дискретизации
wav_length = 4                # Целевая длина (в секундах) для аудио
save_merged_audio = true
merged_audio_base_path = "saved_merges"
merged_audio_suffix = "_merged"
force_remerge = false

# ---------------------------
# Whisper и текст
# ---------------------------
[text]
# Если "csv", то мы стараемся брать текст из CSV, если там есть
# (поле text_column). Если нет - тогда Whisper (если нужно).
source = "csv"
text_column = "text"
whisper_model = "base"

# Указываем, где запускать Whisper: "cuda" (GPU) или "cpu"
whisper_device = "cpu"

# Если для dev/test в CSV нет текста, нужно ли всё же вызывать Whisper?
use_whisper_for_nontrain_if_no_text = true

# ---------------------------
# Общие параметры тренировки
# ---------------------------
[train.general]
random_seed = 42         # фиксируем random seed для воспроизводимости (0 = каждый раз разный)
subset_size = 100         # ограничение на количество примеров (0 = использовать весь датасет)
merge_probability = 0    # процент склеивания коротких файлов
batch_size = 8         # размер батча
num_epochs = 75           # число эпох тренировки
max_patience = 10        # максимальное число эпох без улучшений (для Early Stopping)
save_best_model = false
save_prepared_data = true # сохранять извлеченные признаки (эмбеддинги)
save_feature_path = './features/' # путь для сохранения эмбеддингов
search_type = "none" # стратегия поиска: "greedy", "exhaustive" или "none"
path_to_df_ls = 'Phi-4-mini-instruct_emotions_union.csv'  # путь к датафрейму со смягченными метками - Qwen3-4B_emotions_union или Phi-4-mini-instruct_emotions_union
smoothing_probability = 0.0 # процент использования смягченных меток

# ---------------------------
# Параметры модели
# ---------------------------
[train.model]
model_name = "BiFormer"    # название модели (BiGraphFormer, BiFormer, BiGatedGraphFormer, BiGatedFormer, BiMamba, PredictionsFusion, BiFormerWithProb, BiMambaWithProb, BiGraphFormerWithProb, BiGatedGraphFormerWithProb)
hidden_dim = 256            # размер скрытого состояния
hidden_dim_gated = 128      # скрытое состояние для gated механизмов
num_transformer_heads = 16   # количество attention голов в трансформере
num_graph_heads = 2         # количество голов в граф-механизме
tr_layer_number = 5         # количество слоев в трансформере
mamba_d_state = 16          # размер состояния в Mamba
mamba_ker_size = 6          # размер кернела в Mamba
mamba_layer_number = 5      # количество слоев Mamba
positional_encoding = false # использовать ли позиционное кодирование
dropout = 0.15              # dropout между слоями
out_features = 256          # размер финальных признаков перед классификацией
mode = 'mean'               # способ агрегации признаков (например, "mean", "max", и т.д.)

# ---------------------------
# Параметры оптимизатора
# ---------------------------
[train.optimizer]
optimizer = "adam"        # тип оптимизатора: "adam", "adamw", "lion", "sgd", "rmsprop"
lr = 1e-4                 # начальная скорость обучения
weight_decay = 0.0        # weight decay для регуляризации
momentum = 0.9            # momentum (используется только в SGD)

# ---------------------------
# Параметры шедулера
# ---------------------------
[train.scheduler]
scheduler_type = "plateau" # тип шедулера: "none", "plateau", "cosine", "onecycle" ил  и HuggingFace-стиль ("huggingface_linear", "huggingface_cosine" "huggingface_cosine_with_restarts" и т.д.)
warmup_ratio = 0.1         # отношение количества warmup-итераций к общему числу шагов (0.1 = 10%)

[embeddings]
# audio_model = "amiriparian/ExHuBERT"  # Hugging Face имя модели для аудио
audio_model = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"  # Hugging Face имя модели для аудио
audio_classifier_checkpoint = "best_audio_model_2.pt"
text_classifier_checkpoint = "best_text_model.pth"
text_model = "jinaai/jina-embeddings-v3"  # Hugging Face имя модели для текста
audio_embedding_dim = 256  # размерность аудио-эмбеддинга
text_embedding_dim = 1024   # размерность текст-эмбеддинга
emb_normalize = false  # нормализовать ли вектор L2-нормой
max_tokens = 95         # ограничение на длину текста (токенов) при токенизации
device = "cpu"          # "cuda" или "cpu", куда грузить модель

# audio_pooling = "mean"        # "mean", "cls", "max", "min", "last", "attention"
# text_pooling = "cls"          # "mean", "cls", "max", "min", "last", "sum", "attention"

[textgen]
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # deepseek-ai/deepseek-llm-1.3b-base или любая другая модель
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # deepseek-ai/deepseek-llm-1.3b-base или любая другая модель
max_new_tokens = 50
temperature = 1.0
top_p = 0.95