# Training Pipeline
This document outlines the training process flow and the implementation details of each training method.
The DDPM trainer (`methods/ddpm_trainer.py`) implements:

- **Velocity Prediction**
  - v-prediction parameterization
  - Enhanced high-frequency details
  - Reduced color shifting
- **SNR Weighting**
  - Signal-to-noise ratio gamma: 5.0
  - Better detail preservation
  - Improved stability
- **Zero Terminal SNR**
  - Infinite noise approximation
  - High-frequency artifact reduction
  - Enhanced contrast
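
A minimal sketch of how v-prediction targets and min-SNR weighting (gamma = 5.0) typically fit together; the `model` callable, the `alphas_cumprod` buffer, and the exact weighting form are assumptions for illustration, not the code in `methods/ddpm_trainer.py`:

```python
import torch

def ddpm_v_prediction_loss(model, x0, noise, timesteps, alphas_cumprod, snr_gamma=5.0):
    """Sketch: v-prediction target with min-SNR(gamma) loss weighting."""
    # Per-sample sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t), broadcast over C, H, W
    alpha_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    sqrt_alpha = alpha_bar.sqrt()
    sqrt_one_minus = (1.0 - alpha_bar).sqrt()

    # Noisy sample and v-target: v = sqrt(alpha_bar) * eps - sqrt(1 - alpha_bar) * x0
    x_t = sqrt_alpha * x0 + sqrt_one_minus * noise
    v_target = sqrt_alpha * noise - sqrt_one_minus * x0

    v_pred = model(x_t, timesteps)

    # Min-SNR weighting: clamp the per-timestep SNR at gamma (5.0 here);
    # the (snr + 1) denominator is the common variant for v-prediction
    snr = (alpha_bar / (1.0 - alpha_bar)).reshape(-1)
    weight = torch.clamp(snr, max=snr_gamma) / (snr + 1.0)

    per_sample = (v_pred - v_target).pow(2).mean(dim=(1, 2, 3))
    return (weight * per_sample).mean()
```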
The flow matching trainer (`methods/flow_matching_trainer.py`) offers:

- **Optimal Transport**
  - Direct velocity field learning
  - Faster convergence rates
  - Lower memory overhead
- **Time Sampling**
  - Logit-normal distribution
  - Better coverage of time steps
  - Reduced training variance
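
A minimal sketch of a flow matching objective with logit-normal time sampling; the linear interpolation path and the `model(x_t, t)` signature are assumptions, not the implementation in `methods/flow_matching_trainer.py`:

```python
import torch

def flow_matching_loss(model, x0, noise, logit_mean=0.0, logit_std=1.0):
    """Sketch: straight-path flow matching with logit-normal time sampling."""
    batch_size = x0.shape[0]

    # Logit-normal time sampling: t = sigmoid(u), u ~ N(logit_mean, logit_std)
    u = torch.randn(batch_size, device=x0.device) * logit_std + logit_mean
    t = torch.sigmoid(u).view(-1, 1, 1, 1)

    # Interpolate between data and noise; the target velocity is the straight-line direction
    x_t = (1.0 - t) * x0 + t * noise
    v_target = noise - x0

    v_pred = model(x_t, t.reshape(-1))
    return (v_pred - v_target).pow(2).mean()
```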
Data pipeline:

```mermaid
graph TD
    A[Dataset Load] --> B[Cache Check]
    B -->|Cache Hit| C[Load Cached]
    B -->|Cache Miss| D[Process New]
    D --> E[Cache Results]
    C --> F[Batch Formation]
    E --> F
```
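
The cache-check branch above could look like this sketch; the cache layout, the hash key, and `process_fn` are hypothetical:

```python
import hashlib
from pathlib import Path

import torch

def load_or_process(sample_path: str, cache_dir: str, process_fn):
    """Sketch: reuse cached tensors on a cache hit, process and cache on a miss."""
    cache_root = Path(cache_dir)
    cache_root.mkdir(parents=True, exist_ok=True)

    # Key the cache entry by a hash of the source path
    key = hashlib.sha256(sample_path.encode()).hexdigest()
    cache_file = cache_root / f"{key}.pt"

    if cache_file.exists():
        return torch.load(cache_file)      # cache hit: load cached result

    result = process_fn(sample_path)       # cache miss: process the sample
    torch.save(result, cache_file)         # cache results for the next run
    return result
```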
Training loop:

```mermaid
graph TD
    A[Forward Pass] --> B[Loss Calculation]
    B --> C[Gradient Accumulation]
    C -->|Steps Complete| D[Optimizer Step]
    D --> E[Memory Cleanup]
    E --> F[Metrics Logging]
```
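
A sketch of that loop with gradient accumulation, periodic memory cleanup, and metrics logging; the `model(batch) -> loss` interface and `log_fn` are assumptions:

```python
import torch

def run_epoch(model, optimizer, dataloader, accumulation_steps=4, log_fn=print):
    """Sketch: forward, accumulate gradients, step every N batches, clean up, log."""
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(dataloader):
        loss = model(batch)                        # forward pass returning a scalar loss
        (loss / accumulation_steps).backward()     # scale so accumulated grads match a larger batch

        if (step + 1) % accumulation_steps == 0:   # "steps complete"
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            if torch.cuda.is_available():          # memory cleanup
                torch.cuda.empty_cache()

            log_fn({"step": step, "loss": loss.item()})  # metrics logging
```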
Validation flow:

```mermaid
graph TD
    A[Generate Samples] --> B[Quality Metrics]
    B --> C[Log Images]
    C --> D[Archive Results]
```
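
A sketch of the validation pass; `generate_samples` and `metric_fns` are assumed interfaces rather than the project's actual API:

```python
import torch
from torchvision.utils import save_image

@torch.no_grad()
def validation_step(generate_samples, prompts, metric_fns, output_dir, step):
    """Sketch: generate samples, score them, and archive the images."""
    images = generate_samples(prompts)   # assumed sampler -> (N, C, H, W) tensor in [0, 1]

    # Quality metrics, e.g. {"fid": ..., "clip_score": ...}
    metrics = {name: fn(images) for name, fn in metric_fns.items()}

    # Log/archive the image grid next to the metrics for later comparison
    save_image(images, f"{output_dir}/samples_step_{step}.png", nrow=4)
    return metrics
```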
Memory and data-loading optimizations include:

- Dynamic tensor offloading
- Gradient checkpointing
- Activation caching
- Async memory transfers
- Mixed precision training
- Optimal batch formation
- Worker thread management
- Pipeline prefetching
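
A sketch of how several of these optimizations combine in PyTorch; the loader settings and the `block` module are illustrative assumptions:

```python
import torch
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader

def build_loader(dataset, batch_size=4):
    """Sketch: data-loading settings for worker management and prefetching."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=8,            # worker thread management
        pin_memory=True,          # required for async (non_blocking) copies to the GPU
        prefetch_factor=2,        # pipeline prefetching per worker
        persistent_workers=True,
    )

def forward_block(block, batch):
    """Sketch: async transfer + mixed precision + gradient checkpointing."""
    # Async host-to-device transfer; overlaps with compute when memory is pinned
    batch = batch.to("cuda", non_blocking=True)

    # Mixed precision forward; checkpointing recomputes activations in the backward pass
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        return checkpoint(block, batch, use_reentrant=False)
```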
To add a custom training method:

- Extend the base trainer:

  ```python
  from .base import BaseTrainer

  class CustomTrainer(BaseTrainer):
      def training_step(self, batch):
          # Implement the custom training logic and return the loss
          ...
  ```
- Register it in the configuration:

  ```python
  from dataclasses import dataclass

  @dataclass
  class CustomConfig:
      """Custom training configuration."""
      enabled: bool = False
      # Add custom parameters here
  ```
- Define the loss in the trainer:

  ```python
  def custom_loss(self, pred, target):
      # Implement the custom loss between predictions and targets
      ...
  ```
- Update the training step to use it:

  ```python
  loss = self.custom_loss(predictions, targets)
  ```
Training runs are monitored, stabilized, and analyzed along the following dimensions:

- Loss curves
- Learning rate progression
- Memory utilization
- Sample quality metrics
- Training throughput
- FID scores
- CLIP scores
- Perceptual metrics
- User feedback integration
- Gradient flow monitoring (see the sketch after this list)
- Layer-wise updates
- Update magnitude tracking
- Convergence indicators
- Loss smoothing techniques
- Gradient clipping strategies
- Learning rate adaptation
- Batch size scaling
- Noise schedule adaptation
- Learning rate modulation
- Batch size adjustment
- Memory optimization timing
- Hypothesis validation
- Parameter sensitivity
- Performance correlation
- Quality metrics evolution
- Training speed analysis
- Memory efficiency
- Quality benchmarks
- Stability metrics
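
As a concrete illustration of the gradient flow monitoring and clipping items above, a hypothetical helper that records per-layer gradient norms and applies global-norm clipping between `backward()` and the optimizer step:

```python
import torch

def gradient_diagnostics(model, max_grad_norm=1.0):
    """Sketch: per-layer gradient norms plus global-norm clipping."""
    per_layer = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            # Layer-wise update magnitudes help spot vanishing/exploding gradients
            per_layer[name] = param.grad.detach().norm().item()

    # Global-norm clipping; the returned pre-clip norm doubles as a stability metric
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    return total_norm.item(), per_layer
```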
Next: See Research Guidelines for experimental protocols and Quality Metrics for evaluation methods.