Trainer
A flexible, modular training loop designed for GNN models in the GridFM framework. It handles training, validation, early stopping, learning-rate scheduling, and plugin callbacks.
Trainer
A flexible training loop for GridFM models with optional validation, learning rate scheduling, and plugin callbacks for logging or custom behavior.
Attributes:

Name | Type | Description |
---|---|---|
model | Module | The PyTorch model to train. |
optimizer | Optimizer | The optimizer used for updating model parameters. |
device | | The device to train on (CPU or CUDA). |
loss_fn | Module | Loss function that returns a loss dictionary. |
early_stopper | EarlyStopper | Callback for early stopping based on validation loss. |
train_dataloader | DataLoader | Dataloader for training data. |
val_dataloader | DataLoader | Dataloader for validation data. |
lr_scheduler | optional | Learning rate scheduler. |
plugins | List[TrainerPlugin] | List of plugin callbacks. |
Source code in gridfm_graphkit/training/trainer.py
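In a typical loop of this kind, the attributes above interact roughly once per epoch: train on `train_dataloader`, evaluate on `val_dataloader`, notify the `plugins`, step the `lr_scheduler`, and consult the `early_stopper`. The sketch below illustrates that flow in plain Python so it runs without PyTorch or GridFM. It is a toy under stated assumptions, not the `Trainer` implementation: the "model" is a single scalar fitted by gradient descent, `StepDecayScheduler`, `PatienceStopper`, `RecordingPlugin`, `run_training` and the `on_epoch_end` hook are hypothetical stand-ins rather than GridFM APIs, and the order of operations within an epoch is assumed.

```python
from typing import Dict, List


class StepDecayScheduler:
    """Hypothetical stand-in for lr_scheduler: multiplies the LR by `gamma` every `step_size` epochs."""

    def __init__(self, lr: float, step_size: int = 10, gamma: float = 0.5) -> None:
        self.lr, self.step_size, self.gamma, self.epoch = lr, step_size, gamma, 0

    def step(self) -> None:
        self.epoch += 1
        if self.epoch % self.step_size == 0:
            self.lr *= self.gamma


class PatienceStopper:
    """Hypothetical stand-in for early_stopper: stop after `patience` epochs without a new best val loss."""

    def __init__(self, patience: int = 3) -> None:
        self.patience, self.best, self.bad_epochs = patience, float("inf"), 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best:
            self.best, self.bad_epochs = val_loss, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


class RecordingPlugin:
    """Hypothetical plugin: stores per-epoch metrics (a real one might forward them to a logger)."""

    def __init__(self) -> None:
        self.history: List[Dict[str, float]] = []

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        self.history.append({"epoch": epoch, **metrics})


def loss_fn(w: float, batch: List[float]) -> Dict[str, float]:
    """Toy loss returning a loss dictionary: mean squared error of a scalar parameter."""
    return {"loss": sum((w - x) ** 2 for x in batch) / len(batch)}


def run_training(train_batches: List[List[float]], val_batches: List[List[float]], epochs: int,
                 scheduler: StepDecayScheduler, stopper: PatienceStopper,
                 plugins: List[RecordingPlugin]) -> float:
    w = 0.0  # the whole "model" is one scalar parameter
    for epoch in range(epochs):
        # Training phase: one gradient-descent step per batch.
        for batch in train_batches:
            grad = sum(2 * (w - x) for x in batch) / len(batch)
            w -= scheduler.lr * grad
        # Validation phase: evaluate only, no parameter updates.
        val_loss = sum(loss_fn(w, b)["loss"] for b in val_batches) / len(val_batches)
        # Plugins see this epoch's metrics before the LR changes.
        for plugin in plugins:
            plugin.on_epoch_end(epoch, {"val_loss": val_loss, "lr": scheduler.lr})
        scheduler.step()
        # Early stopping based on validation loss.
        if stopper.should_stop(val_loss):
            break
    return w


plugin = RecordingPlugin()
stopper = PatienceStopper(patience=3)
w = run_training([[1.0, 2.0], [3.0]], [[2.0, 2.5]], epochs=100,
                 scheduler=StepDecayScheduler(lr=0.1), stopper=stopper, plugins=[plugin])
```

Keeping the scheduler, stopper, and plugins as separate objects with small interfaces mirrors the attribute layout above: each one can be swapped out without touching the loop body.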
train(start_epoch=0, epochs=1, prev_step=-1)
Main training loop.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
start_epoch | int | Epoch to start training from. | 0 |
epochs | int | Total number of epochs to train. | 1 |
prev_step | int | Previous training step (for logging continuity). | -1 |
Source code in gridfm_graphkit/training/trainer.py
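The `prev_step` parameter exists for logging continuity when training is resumed. Below is a minimal sketch of how a step counter can continue across runs. It assumes, without confirmation from the source, that epochs run over `range(start_epoch, epochs)` and that the counter advances once per training batch; `step_schedule` and `batches_per_epoch` are hypothetical names used only for illustration.

```python
from typing import Iterator, Tuple


def step_schedule(start_epoch: int, epochs: int, prev_step: int,
                  batches_per_epoch: int) -> Iterator[Tuple[int, int]]:
    """Yield (epoch, global_step) pairs, continuing the step counter from `prev_step`.

    Assumption: epochs cover range(start_epoch, epochs) and the counter
    advances once per training batch.
    """
    step = prev_step
    for epoch in range(start_epoch, epochs):
        for _ in range(batches_per_epoch):
            step += 1
            yield epoch, step


# First run: epochs 0 and 1, three batches each, so steps 0..5.
first = list(step_schedule(start_epoch=0, epochs=2, prev_step=-1, batches_per_epoch=3))
# Resumed run: pass the last logged step so logging continues at step 6.
last_step = first[-1][1]
resumed = list(step_schedule(start_epoch=2, epochs=4, prev_step=last_step, batches_per_epoch=3))
```

Under these assumptions, the default `prev_step=-1` makes the first logged step 0, and passing the last logged step on resume helps avoid colliding step numbers in a metrics logger.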
Usage Example
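```python
import torch

from gridfm_graphkit.training.trainer import Trainer
from gridfm_graphkit.training.callbacks import EarlyStopper
from gridfm_graphkit.training.plugins import MLflowLoggerPlugin

# model, optimizer, loss_function, save_path, train_loader, val_loader and
# scheduler are assumed to be built beforehand (e.g. from your configuration).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    device=device,
    loss_fn=loss_function,  # must return a loss dictionary
    early_stopper=EarlyStopper(save_path),  # early stopping on validation loss
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    lr_scheduler=scheduler,  # optional
    plugins=[MLflowLoggerPlugin()],  # logging plugin
)
trainer.train(start_epoch=0, epochs=100)
```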