Trainer Plugins
Reusable plugins that hook into the training loop to perform tasks like logging or checkpointing.
TrainerPlugin
Base class for training plugins.
A TrainerPlugin is invoked during training at regular step intervals, at the end of each epoch, or both. It can be extended to perform actions such as logging, checkpointing, or validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
steps | int | Interval (in steps) to run the plugin. If … | None |
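A minimal sketch of extending the base class, assuming only the interface documented on this page. `TrainerPluginStandIn` and `CallCounter` are hypothetical names used so the snippet runs without the library installed; in real use you would subclass `TrainerPlugin` from `gridfm_graphkit/training/plugins.py` instead.

```python
from abc import ABC, abstractmethod


class TrainerPluginStandIn(ABC):
    """Hypothetical stand-in mirroring the documented constructor and `step` signature."""

    def __init__(self, steps=None):
        self.steps = steps

    @abstractmethod
    def step(self, epoch, step, metrics={}, end_of_epoch=False, **kwargs):
        """Subclasses implement their logging/checkpointing/validation here."""


class CallCounter(TrainerPluginStandIn):
    """Example plugin (not part of the library): counts how often it is called."""

    def __init__(self, steps=None):
        super().__init__(steps)
        self.calls = 0

    def step(self, epoch, step, metrics={}, end_of_epoch=False, **kwargs):
        self.calls += 1


# The abstract base cannot be instantiated directly.
try:
    TrainerPluginStandIn()
    base_is_abstract = False
except TypeError:
    base_is_abstract = True

counter = CallCounter(steps=10)
counter.step(epoch=0, step=0, metrics={"train_loss": 1.2})
counter.step(epoch=0, step=None, end_of_epoch=True)
print(base_is_abstract, counter.calls)
```

Because `step` is abstract, every concrete plugin has to provide its own implementation, which is where the plugin-specific behavior lives.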
Source code in gridfm_graphkit/training/plugins.py
run(step, end_of_epoch)
Determines whether to execute the plugin at the current step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step | int | The current step number. | required |
end_of_epoch | bool | Whether this is the end of the epoch. | required |
Returns:
Name | Type | Description |
---|---|---|
bool | bool | True if the plugin should run; False otherwise. |
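The decision described above could look roughly like the following. This is a sketch of one plausible rule, not the library's code: the function name `should_run`, the treatment of step 0, and the choice to always fire at the end of an epoch are all assumptions.

```python
from typing import Optional


def should_run(step: Optional[int], end_of_epoch: bool, steps: Optional[int]) -> bool:
    """Assumed decision rule; the library's actual `run` may differ."""
    if end_of_epoch:
        return True  # assumption: epoch boundaries always trigger
    if steps is None or step is None:
        return False  # no step interval configured
    return step % steps == 0


# Decision table for an interval of 100 steps: (step, end_of_epoch) pairs.
cases = [(0, False), (50, False), (100, False), (None, True)]
results = [should_run(step, eoe, steps=100) for step, eoe in cases]
print(results)  # [True, False, True, True]
```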
step(epoch, step, metrics={}, end_of_epoch=False, **kwargs)
abstractmethod
This method is called on every training step, and once more with step=None at the end of each epoch. Implementations can use the passed-in parameters for validation, checkpointing, logging, and similar tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | The current epoch number. | required |
step | int | The current step within the epoch. | required |
metrics | dict | Dictionary of training metrics (e.g., loss). | {} |
end_of_epoch | bool | Indicates if this call is at the end of an epoch. | False |
**kwargs | Any | Additional parameters such as model, optimizer, scheduler. | {} |
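To illustrate the calling convention, here is a toy loop that drives a plugin the way the description suggests: once per step, then once with `step=None` and `end_of_epoch=True`. `RecordingPlugin` and `train_one_epoch` are hypothetical helpers, and the "loss" is faked; the real trainer may pass different metrics and keyword arguments.

```python
class RecordingPlugin:
    """Hypothetical duck-typed plugin that records every call it receives."""

    def __init__(self):
        self.calls = []

    def step(self, epoch, step, metrics={}, end_of_epoch=False, **kwargs):
        self.calls.append((epoch, step, end_of_epoch, sorted(kwargs)))


def train_one_epoch(epoch, batches, plugins, **objects):
    """Toy loop: `batches` stands in for a data loader, the loss is faked."""
    loss = None
    for step, batch in enumerate(batches):
        loss = float(batch)  # placeholder for the forward/backward pass
        for plugin in plugins:
            plugin.step(epoch, step, metrics={"train_loss": loss}, **objects)
    # End-of-epoch call: step=None and end_of_epoch=True, as documented.
    for plugin in plugins:
        plugin.step(epoch, None, metrics={"train_loss": loss},
                    end_of_epoch=True, **objects)


recorder = RecordingPlugin()
train_one_epoch(0, batches=[0.9, 0.7], plugins=[recorder], model="model-placeholder")
print(recorder.calls)
```

Extra objects such as the model are forwarded as keyword arguments, so plugins that do not need them can simply absorb them through `**kwargs`.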
MLflowLoggerPlugin
Bases: TrainerPlugin
Plugin to log training metrics to MLflow.
Logs metrics dynamically during training at defined step intervals and/or at the end of each epoch. Also logs initial training parameters once.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
steps | int | Interval in steps to log metrics. | None |
params | dict | Parameters to log to MLflow at the start. | None |
step(epoch, step, metrics={}, end_of_epoch=False, **kwargs)
Logs metrics to MLflow dynamically at each specified step and at the end of each epoch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | The current epoch number. | required |
step | int | The current step within the epoch. | required |
metrics | Dict | Dictionary of metrics to log, e.g., {'train_loss': value}. | {} |
end_of_epoch | bool | Flag indicating whether this is the end of the epoch. | False |
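The control flow described for this plugin (parameters logged once, metrics logged at step intervals and at epoch end) can be sketched without an MLflow server. `RecordingBackend` and `MetricLoggerSketch` are hypothetical names, and indexing epoch-level metrics by the epoch number is an assumption. The backend methods are modeled on `mlflow.log_params(params)` and `mlflow.log_metrics(metrics, step=...)`, so the `mlflow` module itself should be usable as the backend, though that is not verified here.

```python
class RecordingBackend:
    """Hypothetical in-memory backend used here in place of MLflow."""

    def __init__(self):
        self.params = {}
        self.metrics = []

    def log_params(self, params):
        self.params.update(params)

    def log_metrics(self, metrics, step=None):
        self.metrics.append((step, dict(metrics)))


class MetricLoggerSketch:
    """Illustrative logger: params once, metrics at intervals and at epoch end."""

    def __init__(self, backend, steps=None, params=None):
        self.backend = backend
        self.steps = steps
        if params:
            backend.log_params(params)  # logged a single time, up front

    def step(self, epoch, step, metrics={}, end_of_epoch=False, **kwargs):
        if end_of_epoch:
            # Assumption: index epoch-level metrics by the epoch number.
            self.backend.log_metrics(metrics, step=epoch)
        elif self.steps is not None and step % self.steps == 0:
            self.backend.log_metrics(metrics, step=step)


backend = RecordingBackend()
logger = MetricLoggerSketch(backend, steps=10, params={"lr": 1e-3})
for s in range(25):
    logger.step(epoch=0, step=s, metrics={"train_loss": 1.0})
logger.step(epoch=0, step=None, metrics={"val_loss": 0.5}, end_of_epoch=True)
print(backend.params, [step for step, _ in backend.metrics])
```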
CheckpointerPlugin
Bases: TrainerPlugin
Plugin to periodically save model checkpoints.
Stores the model, optimizer, and scheduler states to a given directory at specified step intervals or at the end of each epoch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
checkpoint_dir | str | Directory where checkpoints will be saved. | required |
steps | int | Interval in steps for checkpointing. | None |
step(epoch, step, metrics={}, end_of_epoch=False, model=None, optimizer=None, scheduler=None)
Saves a checkpoint if the conditions to run the plugin are met.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | Current epoch number. | required |
step | int | Current training step. | required |
metrics | dict | Optional metrics dictionary (unused here). | {} |
end_of_epoch | bool | Whether this is the end of the epoch. | False |
model | Module | Model to be checkpointed. | None |
optimizer | Optimizer | Optimizer to save. | None |
scheduler | LRScheduler | Scheduler to save. | None |
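As a rough illustration of bundling model, optimizer, and scheduler state into one checkpoint file, the sketch below uses only the standard library. Everything here is hypothetical: the `Stateful` stand-in, the `save_checkpoint` helper, the file naming scheme, and the use of `pickle`. With PyTorch objects one would more typically pass the same dictionary of `state_dict()` results to `torch.save`.

```python
import os
import pickle
import tempfile


class Stateful:
    """Hypothetical stand-in for any object exposing `state_dict()`."""

    def __init__(self, **state):
        self._state = state

    def state_dict(self):
        return dict(self._state)


def save_checkpoint(checkpoint_dir, epoch, step, model, optimizer=None, scheduler=None):
    """Write one checkpoint file and return its path (the file name is made up)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    payload = {
        "epoch": epoch,
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
    }
    tag = "end" if step is None else f"step{step}"
    path = os.path.join(checkpoint_dir, f"checkpoint_epoch{epoch}_{tag}.pkl")
    with open(path, "wb") as fh:
        pickle.dump(payload, fh)
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = save_checkpoint(tmp, epoch=3, step=None,
                           model=Stateful(w=[0.1, 0.2]),
                           optimizer=Stateful(lr=0.01))
    with open(path, "rb") as fh:
        restored = pickle.load(fh)

print(os.path.basename(path), restored["epoch"])
```

Saving state dictionaries rather than whole objects keeps the checkpoint independent of how the objects were constructed, which is the usual reason for this layout.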
MetricsTrackerPlugin
Bases: TrainerPlugin
Logs metrics at the end of each epoch. Currently, only the validation loss is returned.
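A plugin with this behavior might be sketched as follows. The class name `ValidationLossTracker`, the `history` attribute, and the metrics key `"val_loss"` are assumptions made for illustration; the actual plugin may use different names and may compute the loss itself.

```python
class ValidationLossTracker:
    """Illustrative tracker; the key name "val_loss" is an assumption."""

    def __init__(self):
        self.history = []

    def step(self, epoch, step, metrics={}, end_of_epoch=False, **kwargs):
        if not end_of_epoch:
            return None  # ignore intermediate steps
        value = metrics.get("val_loss")
        self.history.append((epoch, value))
        return value


tracker = ValidationLossTracker()
tracker.step(0, 5, {"train_loss": 0.9})
latest = tracker.step(0, None, {"train_loss": 0.8, "val_loss": 0.75}, end_of_epoch=True)
print(latest, tracker.history)
```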