import hydra
from omegaconf import DictConfig, OmegaConf
import torch
from torch.utils.data import DataLoader
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from datasets import load_dataset
import mlflow
import wandb
from pathlib import Path
import logging
from typing import Dict, Any
import numpy as np
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FormIQTrainer:
    def __init__(self, config: DictConfig):
        """Initialize the trainer with configuration."""
        self.config = config
        self.device = torch.device(config.model.device)
        
        # Initialize model and processor
        self.processor = LayoutLMv3Processor.from_pretrained(config.model.name)
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            config.model.name,
            num_labels=config.model.num_labels
        )
        self.model.to(self.device)
        
        # Initialize optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.training.learning_rate,
            weight_decay=config.training.weight_decay
        )
        
        # Set up experiment tracking (MLflow / W&B)
        self.setup_logging()
        
    def setup_logging(self):
        """Setup MLflow and W&B logging."""
        if self.config.logging.mlflow.enabled:
            mlflow.set_tracking_uri(self.config.logging.mlflow.tracking_uri)
            mlflow.set_experiment(self.config.logging.mlflow.experiment_name)
            
        if self.config.logging.wandb.enabled:
            wandb.init(
                project=self.config.logging.wandb.project,
                entity=self.config.logging.wandb.entity,
                config=OmegaConf.to_container(self.config, resolve=True)
            )
    
    def prepare_dataset(self):
        """Prepare the training and evaluation DataLoaders.

        Returns:
            Tuple of (train_loader, eval_loader).
        """
        # TODO: Implement dataset preparation; this is still a placeholder.
        # See the hedged sketch below for one possible approach.
        return None, None
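
    # A minimal, hedged sketch of one way prepare_dataset() could be filled in.
    # Everything here is an assumption rather than part of this repo: the FUNSD
    # form-understanding dataset as re-released for LayoutLMv3
    # ("nielsr/funsd-layoutlmv3", columns "image", "tokens", "bboxes",
    # "ner_tags"), a processor re-created with apply_ocr=False so the provided
    # words/boxes are used as-is, and a config.training.batch_size key.
    # config.model.num_labels would need to match the dataset's label set
    # (7 for FUNSD).
    def prepare_dataset_funsd_sketch(self):
        """Illustrative dataset preparation; not wired into train()."""
        processor = LayoutLMv3Processor.from_pretrained(
            self.config.model.name, apply_ocr=False
        )
        raw = load_dataset("nielsr/funsd-layoutlmv3")

        def encode(examples):
            # Tokenize the words, align bounding boxes, and map word-level
            # labels to token-level labels (-100 for special/continuation
            # tokens, which the loss and the accuracy computation below ignore)
            return processor(
                examples["image"],
                examples["tokens"],
                boxes=examples["bboxes"],
                word_labels=examples["ner_tags"],
                truncation=True,
                padding="max_length",
            )

        encoded = raw.map(
            encode, batched=True, remove_columns=raw["train"].column_names
        )
        encoded.set_format("torch")

        train_loader = DataLoader(
            encoded["train"],
            batch_size=self.config.training.batch_size,
            shuffle=True,
        )
        eval_loader = DataLoader(
            encoded["test"], batch_size=self.config.training.batch_size
        )
        return train_loader, eval_loader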
    
    def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
        """Train for one epoch.
        
        Args:
            train_loader: DataLoader for training data
            
        Returns:
            Dictionary containing training metrics
        """
        self.model.train()
        total_loss = 0
        correct_predictions = 0
        total_predictions = 0
        
        progress_bar = tqdm(train_loader, desc="Training")
        for batch in progress_bar:
            # Move batch to device
            batch = {k: v.to(self.device) for k, v in batch.items()}
            
            # Forward pass
            outputs = self.model(**batch)
            loss = outputs.loss
            
            # Backward pass
            loss.backward()
            
            # Update weights
            self.optimizer.step()
            self.optimizer.zero_grad()
            
            # Update metrics, ignoring positions labeled -100 (the ignore index
            # used for padding/special tokens, which the loss also skips)
            total_loss += loss.item()
            predictions = outputs.logits.argmax(-1)
            labels = batch["labels"]
            valid = labels != -100
            correct_predictions += ((predictions == labels) & valid).sum().item()
            total_predictions += valid.sum().item()
            
            # Update progress bar
            progress_bar.set_postfix({
                "loss": loss.item(),
                "accuracy": correct_predictions / total_predictions
            })
        
        # Calculate epoch metrics
        metrics = {
            "train_loss": total_loss / len(train_loader),
            "train_accuracy": correct_predictions / total_predictions
        }
        
        return metrics
    
    def evaluate(self, eval_loader: DataLoader) -> Dict[str, float]:
        """Evaluate the model.
        
        Args:
            eval_loader: DataLoader for evaluation data
            
        Returns:
            Dictionary containing evaluation metrics
        """
        self.model.eval()
        total_loss = 0
        correct_predictions = 0
        total_predictions = 0
        
        with torch.no_grad():
            for batch in tqdm(eval_loader, desc="Evaluating"):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}
                
                # Forward pass
                outputs = self.model(**batch)
                loss = outputs.loss
                
                # Update metrics with the same -100 masking as in train_epoch
                total_loss += loss.item()
                predictions = outputs.logits.argmax(-1)
                labels = batch["labels"]
                valid = labels != -100
                correct_predictions += ((predictions == labels) & valid).sum().item()
                total_predictions += valid.sum().item()
        
        # Calculate evaluation metrics
        metrics = {
            "eval_loss": total_loss / len(eval_loader),
            "eval_accuracy": correct_predictions / total_predictions
        }
        
        return metrics
    
    def train(self):
        """Train the model."""
        # Prepare datasets (prepare_dataset is still a placeholder; fail with a
        # clear message rather than crashing inside the training loop)
        train_loader, eval_loader = self.prepare_dataset()
        if train_loader is None or eval_loader is None:
            raise NotImplementedError(
                "prepare_dataset() must return (train_loader, eval_loader)"
            )
        
        # Training loop
        best_eval_loss = float('inf')
        for epoch in range(self.config.training.num_epochs):
            logger.info(f"Epoch {epoch + 1}/{self.config.training.num_epochs}")
            
            # Train
            train_metrics = self.train_epoch(train_loader)
            
            # Evaluate
            eval_metrics = self.evaluate(eval_loader)
            
            # Log metrics
            metrics = {**train_metrics, **eval_metrics}
            if self.config.logging.mlflow.enabled:
                mlflow.log_metrics(metrics, step=epoch)
            if self.config.logging.wandb.enabled:
                wandb.log(metrics, step=epoch)
            
            # Save best model
            if eval_metrics["eval_loss"] < best_eval_loss:
                best_eval_loss = eval_metrics["eval_loss"]
                self.save_model("best_model")
            
            # Save checkpoint
            self.save_model(f"checkpoint_epoch_{epoch + 1}")
    
    def save_model(self, name: str):
        """Save the model.
        
        Args:
            name: Name of the saved model
        """
        save_path = Path(self.config.model.save_dir) / name
        save_path.mkdir(parents=True, exist_ok=True)
        
        self.model.save_pretrained(save_path)
        self.processor.save_pretrained(save_path)
        
        if self.config.logging.mlflow.enabled:
            mlflow.log_artifacts(str(save_path), f"models/{name}")
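
# Hedged sketch of the config structure this trainer reads from
# ../config/config.yaml (the real file is not shown here; the keys are the ones
# referenced above, the values are only illustrative):
#
#   model:
#     name: microsoft/layoutlmv3-base
#     device: cuda
#     num_labels: 7
#     save_dir: models
#   training:
#     learning_rate: 5.0e-5
#     weight_decay: 0.01
#     num_epochs: 10
#     batch_size: 8        # only needed by the dataset sketch above
#   logging:
#     mlflow:
#       enabled: true
#       tracking_uri: http://localhost:5000
#       experiment_name: formiq
#     wandb:
#       enabled: false
#       project: formiq
#       entity: null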

@hydra.main(config_path="../config", config_name="config")
def main(config: DictConfig):
    """Main training function."""
    trainer = FormIQTrainer(config)
    trainer.train()

if __name__ == "__main__":
    main()
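
# Example invocation (hedged: the script path is illustrative; Hydra lets any
# of the keys sketched above be overridden on the command line):
#
#   python src/train.py training.num_epochs=5 model.device=cpu \
#       logging.wandb.enabled=false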