Trainer

A flexible, modular training loop designed for GNN models in the GridFM framework. Handles training, validation, early stopping, learning rate scheduling, and plugin callbacks.


Trainer

A flexible training loop for GridFM models with optional validation, learning rate scheduling, and plugin callbacks for logging or custom behavior.

Attributes:

    model (Module): The PyTorch model to train.
    optimizer (Optimizer): The optimizer used for updating model parameters.
    device: The device to train on (CPU or CUDA).
    loss_fn (Module): Loss function that returns a loss dictionary (see the sketch after this list).
    early_stopper (EarlyStopper): Callback for early stopping based on validation loss.
    train_dataloader (DataLoader): Dataloader for training data.
    val_dataloader (DataLoader): Dataloader for validation data.
    lr_scheduler (optional): Learning rate scheduler.
    plugins (List[TrainerPlugin]): List of plugin callbacks.
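
The loss function is expected to return a dictionary containing at least a "loss" entry holding the scalar tensor that the trainer backpropagates and logs. A minimal sketch of a compatible module, assuming a plain MSE objective and ignoring the edge and mask arguments that the trainer also passes in, could look like this (MSELossDict is a hypothetical name, not part of the library):

import torch.nn as nn
import torch.nn.functional as F

class MSELossDict(nn.Module):
    """Hypothetical loss module matching the Trainer's loss_fn contract."""

    def forward(self, output, label, edge_index, edge_attr, mask):
        # The Trainer forwards edge_index, edge_attr and mask to the loss;
        # this sketch ignores them and compares node-level predictions only.
        loss = F.mse_loss(output, label)
        # "loss" is the key the Trainer backpropagates via loss_dict["loss"].backward().
        return {"loss": loss}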

Source code in gridfm_graphkit/training/trainer.py
class Trainer:
    """
    A flexible training loop for GridFM models with optional validation, learning rate scheduling,
    and plugin callbacks for logging or custom behavior.

    Attributes:
        model (nn.Module): The PyTorch model to train.
        optimizer (Optimizer): The optimizer used for updating model parameters.
        device: The device to train on (CPU or CUDA).
        loss_fn (nn.Module): Loss function that returns a loss dictionary.
        early_stopper (EarlyStopper): Callback for early stopping based on validation loss.
        train_dataloader (DataLoader): Dataloader for training data.
        val_dataloader (DataLoader, optional): Dataloader for validation data.
        lr_scheduler (optional): Learning rate scheduler.
        plugins (List[TrainerPlugin]): List of plugin callbacks.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        device,
        loss_fn: nn.Module,
        early_stopper: EarlyStopper,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        lr_scheduler=None,
        plugins: List[TrainerPlugin] = [],
    ):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.early_stopper = early_stopper
        self.loss_fn = loss_fn
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.lr_scheduler = lr_scheduler
        self.plugins = plugins

    def __one_step(
        self,
        input: torch.Tensor,
        edge_index: torch.Tensor,
        label: torch.Tensor,
        edge_attr: torch.Tensor,
        mask: torch.Tensor = None,
        batch: torch.Tensor = None,
        pe: torch.Tensor = None,
        val: bool = False,
    ):
        # expand the learnable mask to the input shape
        mask_value_expanded = self.model.mask_value.expand(input.shape[0], -1)
        # The line below overwrites the previous mask values, which is fine as long as the masked features do not change between batches
        # write the learnable mask values into the input at the masked positions
        input[:, : mask.shape[1]][mask] = mask_value_expanded[mask]
        output = self.model(input, pe, edge_index, edge_attr, batch)

        loss_dict = self.loss_fn(output, label, edge_index, edge_attr, mask)

        if not val:
            self.optimizer.zero_grad()
            loss_dict["loss"].backward()
            self.optimizer.step()

        return loss_dict

    def __one_epoch(self, epoch: int, prev_step: int):
        self.model.train()

        highest_step = prev_step
        for step, batch in enumerate(self.train_dataloader):
            step = prev_step + step + 1
            highest_step = step
            batch = batch.to(self.device)

            mask = getattr(batch, "mask", None)

            loss_dict = self.__one_step(
                batch.x,
                batch.edge_index,
                batch.y,
                batch.edge_attr,
                mask,
                batch.batch,
                batch.pe,
            )
            current_lr = self.optimizer.param_groups[0]["lr"]
            metrics = {}
            metrics["Training Loss"] = loss_dict["loss"].item()
            metrics["Learning Rate"] = current_lr

            if self.model.learn_mask:
                metrics["Mask Gradient Norm"] = self.model.mask_value.grad.norm().item()

            for plugin in self.plugins:
                plugin.step(epoch, step, metrics=metrics)

        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in self.val_dataloader:
                batch = batch.to(self.device)
                mask = getattr(batch, "mask", None)
                metrics = self.__one_step(
                    batch.x,
                    batch.edge_index,
                    batch.y,
                    batch.edge_attr,
                    mask,
                    batch.batch,
                    batch.pe,
                    True,
                )
                val_loss += metrics["loss"].item()
                metrics["Validation Loss"] = metrics.pop("loss").item()

                for plugin in self.plugins:
                    plugin.step(epoch, step, metrics=metrics)
        val_loss /= len(self.val_dataloader)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(val_loss)
        for plugin in self.plugins:
            plugin.step(
                epoch,
                step=highest_step,
                end_of_epoch=True,
                model=self.model,
                optimizer=self.optimizer,
                scheduler=self.lr_scheduler,
            )
        return val_loss

    def train(self, start_epoch: int = 0, epochs: int = 1, prev_step: int = -1):
        """
        Main training loop.

        Args:
            start_epoch (int): Epoch to start training from.
            epochs (int): Total number of epochs to train.
            prev_step (int): Previous training step (for logging continuity).
        """
        for epoch in tqdm(range(start_epoch, start_epoch + epochs), desc="Epochs"):
            val_loss = self.__one_epoch(epoch, prev_step)
            if self.early_stopper.early_stop(val_loss, self.model):
                break
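
Plugins receive a callback after every training batch, after every validation batch, and once more at the end of each epoch: during batches the trainer calls plugin.step(epoch, step, metrics=...), and at the end of an epoch it calls plugin.step(epoch, step=..., end_of_epoch=True, model=..., optimizer=..., scheduler=...). The snippet below is a hedged sketch of a custom logging plugin built around those calls; the exact interface of the TrainerPlugin base class in gridfm_graphkit.training.plugins may differ, so treat the signature as an assumption.

from gridfm_graphkit.training.plugins import TrainerPlugin

class PrintLoggerPlugin(TrainerPlugin):
    """Hypothetical plugin that prints metrics as they arrive."""

    def step(self, epoch, step, metrics=None, end_of_epoch=False, **kwargs):
        # Called once per training/validation batch with a metrics dict, and once
        # per epoch with end_of_epoch=True plus model, optimizer and scheduler kwargs.
        if metrics is not None:
            formatted = ", ".join(f"{name}={value:.4g}" for name, value in metrics.items())
            print(f"epoch {epoch}, step {step}: {formatted}")
        if end_of_epoch:
            print(f"epoch {epoch} finished at step {step}")
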
train(start_epoch=0, epochs=1, prev_step=-1)

Main training loop.

Parameters:

    start_epoch (int): Epoch to start training from. Default: 0
    epochs (int): Total number of epochs to train. Default: 1
    prev_step (int): Previous training step (for logging continuity). Default: -1
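
For example, resuming an interrupted run while keeping the step counter continuous for logging could look like the following; last_epoch and last_step are hypothetical values restored by your own checkpointing logic:

# Hypothetical resume: last_epoch and last_step come from checkpoint metadata.
trainer.train(start_epoch=last_epoch + 1, epochs=20, prev_step=last_step)
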
Source code in gridfm_graphkit/training/trainer.py
def train(self, start_epoch: int = 0, epochs: int = 1, prev_step: int = -1):
    """
    Main training loop.

    Args:
        start_epoch (int): Epoch to start training from.
        epochs (int): Total number of epochs to train.
        prev_step (int): Previous training step (for logging continuity).
    """
    for epoch in tqdm(range(start_epoch, start_epoch + epochs), desc="Epochs"):
        val_loss = self.__one_epoch(epoch, prev_step)
        if self.early_stopper.early_stop(val_loss, self.model):
            break

Usage Example

from gridfm_graphkit.training.trainer import Trainer
from gridfm_graphkit.training.callbacks import EarlyStopper
from gridfm_graphkit.training.plugins import MLflowLoggerPlugin

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    device=device,
    loss_fn=loss_function,
    early_stopper=EarlyStopper(save_path),
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    lr_scheduler=scheduler,
    plugins=[MLflowLoggerPlugin()]
)

trainer.train(start_epoch=0, epochs=100)
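
The early stopper is consulted once per epoch with the averaged validation loss and the model, and training stops as soon as it returns True. If you want to experiment with a custom stopping rule, a purely illustrative stand-in with the same early_stop(val_loss, model) interface could look like the patience-based sketch below; the real EarlyStopper in gridfm_graphkit.training.callbacks takes a save_path, as in the example above, and may behave differently.

class PatienceStopper:
    """Illustrative stand-in exposing the early_stop(val_loss, model) interface."""

    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best = float("inf")
        self.counter = 0

    def early_stop(self, val_loss, model):
        # Reset the counter when validation loss improves, otherwise count up;
        # returning True tells the Trainer to break out of its epoch loop.
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience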