
Feature Reconstruction Task

Bases: LightningModule

PyTorch Lightning task for node feature reconstruction on power grid graphs.

This task wraps a GridFM model inside a LightningModule and defines the full training, validation, testing, and prediction logic. It is designed to reconstruct masked node features from graph-structured input data, using datasets and normalizers provided by gridfm-graphkit.
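When a mask is supplied, `forward` overwrites the masked entries of the leading feature columns with the model's learned `mask_value` before the graph reaches the model. The sketch below reproduces that rule with plain Python lists so it runs without PyTorch. The function name `apply_feature_mask` and the toy values are illustrative assumptions, not part of gridfm-graphkit. It also assumes `mask_value` holds one value per maskable column, which is what the `expand(x.shape[0], -1)` call in `forward` implies.

```python
from typing import List


def apply_feature_mask(
    x: List[List[float]],
    mask: List[List[bool]],
    mask_value: List[float],
) -> List[List[float]]:
    """Return a copy of x where masked entries are replaced by mask_value.

    Mirrors `x[:, :mask.shape[1]][mask] = mask_value.expand(N, -1)[mask]`:
    only the first len(mask[0]) columns can be masked, and column j always
    receives mask_value[j] (the same value for every node).
    """
    n_masked_cols = len(mask[0])
    if len(mask_value) != n_masked_cols:
        raise ValueError("mask_value must have one entry per maskable column")
    out = [row[:] for row in x]  # copy; the real forward edits x in place
    for i, row_mask in enumerate(mask):
        for j, is_masked in enumerate(row_mask):
            if is_masked:
                out[i][j] = mask_value[j]
    return out


# Two nodes, three features; only the first two columns are maskable.
x = [[1.0, 2.0, 9.0], [3.0, 4.0, 9.0]]
mask = [[True, False], [False, True]]
print(apply_feature_mask(x, mask, mask_value=[0.0, -1.0]))
# [[0.0, 2.0, 9.0], [3.0, -1.0, 9.0]]
```

Because the real `forward` assigns into `x` directly, `batch.x` itself is modified during `shared_step`. The copy above only keeps the sketch side-effect free.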

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `args` | `NestedNamespace` | Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, and `data.networks`. | required |
| `node_normalizers` | `list` | One normalizer per dataset to (de)normalize node features. | required |
| `edge_normalizers` | `list` | One normalizer per dataset to (de)normalize edge features. | required |
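The task itself reads a small set of configuration fields: `training.batch_size`, `optimizer.learning_rate`, `optimizer.beta1`, `optimizer.beta2`, `optimizer.lr_decay`, `optimizer.lr_patience`, and `data.networks`. As a minimal sketch, assuming a `types.SimpleNamespace` tree is an acceptable stand-in for `NestedNamespace` (the real class may behave differently), `args` has roughly this shape. The numeric values and network names are placeholders, not recommended settings, and `load_model` and `get_loss_function` may read further fields not shown here.

```python
from types import SimpleNamespace

# Hypothetical stand-in for NestedNamespace: only the attributes that
# FeatureReconstructionTask itself reads are shown; values are placeholders.
args = SimpleNamespace(
    training=SimpleNamespace(batch_size="32"),  # cast with int() in __init__
    optimizer=SimpleNamespace(
        learning_rate=1e-3,
        beta1=0.9,
        beta2=0.999,
        lr_decay=0.5,    # used as the ReduceLROnPlateau factor
        lr_patience=10,  # used as the ReduceLROnPlateau patience (epochs)
    ),
    data=SimpleNamespace(networks=["case14", "case118"]),  # one entry per dataset
)

batch_size = int(args.training.batch_size)
betas = (args.optimizer.beta1, args.optimizer.beta2)
print(batch_size, betas, args.data.networks[1])
```

Since the task indexes `data.networks`, `node_normalizers`, and `edge_normalizers` with the same integer (the list position in `on_fit_start`, `dataloader_idx` in `test_step` and `predict_step`), the three lists need to be in the same dataset order.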

Attributes:

| Name | Type | Description |
|---|---|---|
| `model` | `Module` | Model loaded via `load_model`. |
| `loss_fn` | `callable` | Loss function resolved from configuration. |
| `batch_size` | `int` | Training batch size, from `args.training.batch_size`. |
| `node_normalizers` | `list` | Dataset-wise node feature normalizers. |
| `edge_normalizers` | `list` | Dataset-wise edge feature normalizers. |
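In `test_step` and `predict_step`, the task selects the normalizer for the current dataloader with `self.node_normalizers[dataloader_idx]` and calls `inverse_transform` on outputs and targets; `on_fit_start` calls `get_stats()` on every normalizer, and `test_step` reads `baseMVA` to rescale metrics reported in p.u. The class below is a hypothetical toy normalizer that only mimics that interface using per-column z-score scaling. The actual gridfm-graphkit normalizers may use a different scheme and operate on tensors rather than lists.

```python
from typing import Dict, List


class ToyNormalizer:
    """Hypothetical normalizer exposing the attributes the task uses.

    Uses per-column z-score scaling; this is an illustrative assumption,
    not the gridfm-graphkit normalization scheme.
    """

    def __init__(self, mean: List[float], std: List[float], baseMVA: float):
        self.mean = mean
        self.std = std
        self.baseMVA = baseMVA  # read by test_step for "p.u." metrics

    def transform(self, rows: List[List[float]]) -> List[List[float]]:
        return [[(v - m) / s for v, m, s in zip(r, self.mean, self.std)] for r in rows]

    def inverse_transform(self, rows: List[List[float]]) -> List[List[float]]:
        return [[v * s + m for v, m, s in zip(r, self.mean, self.std)] for r in rows]

    def get_stats(self) -> Dict[str, object]:
        return {"mean": self.mean, "std": self.std, "baseMVA": self.baseMVA}


# One normalizer per dataset, indexed the same way as dataloader_idx.
node_normalizers = [
    ToyNormalizer(mean=[10.0, 2.0], std=[5.0, 1.0], baseMVA=100.0),
    ToyNormalizer(mean=[0.0, 0.0], std=[2.0, 2.0], baseMVA=100.0),
]
dataloader_idx = 0
normed = node_normalizers[dataloader_idx].transform([[20.0, 3.0]])
print(normed)  # [[2.0, 1.0]]
print(node_normalizers[dataloader_idx].inverse_transform(normed))  # [[20.0, 3.0]]
```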

Methods:

| Name | Description |
|---|---|
| `forward` | Forward pass with optional feature masking. |
| `training_step` | One training step: computes loss, logs metrics, returns loss. |
| `validation_step` | One validation step: computes losses and logs metrics. |
| `test_step` | Evaluate on test data, compute per-node-type MSEs, and log per-dataset metrics. |
| `predict_step` | Run inference and return denormalized outputs, node-type masks, scenario IDs, and bus numbers. |
| `configure_optimizers` | Set up the Adam optimizer and the `ReduceLROnPlateau` scheduler. |
| `on_fit_start` | Save normalization statistics at the beginning of training. |
| `on_test_end` | Collect test metrics across datasets and export summary CSV reports. |
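`configure_optimizers` pairs Adam with `torch.optim.lr_scheduler.ReduceLROnPlateau` in `mode="min"`, passing `optimizer.lr_decay` as `factor` and `optimizer.lr_patience` as `patience`, and tells Lightning to monitor the logged `Validation loss`. The pure-Python sketch below shows the core plateau rule under simplifying assumptions: it ignores PyTorch's default relative `threshold` (1e-4), `cooldown`, `min_lr`, and `eps`, so it approximates the scheduler rather than replacing it.

```python
from typing import List


def plateau_lrs(losses: List[float], lr: float, factor: float, patience: int) -> List[float]:
    """Simplified ReduceLROnPlateau(mode="min") with one step per epoch.

    The LR is multiplied by `factor` once the monitored loss has failed to
    improve for more than `patience` consecutive epochs.
    """
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best:  # strict improvement (no threshold in this sketch)
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


# factor plays the role of args.optimizer.lr_decay,
# patience the role of args.optimizer.lr_patience.
print(plateau_lrs([1.0, 0.9, 0.9, 0.9, 0.9], lr=1e-3, factor=0.5, patience=2))
```

With `patience=2`, the LR is halved only on the third consecutive epoch without improvement, which is why the last entry drops.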

Notes
  • Node types are distinguished using the global constants (PQ, PV, REF).
  • The datamodule must provide batch.mask for masking node features.
  • Test metrics include per-node-type RMSE for [Pd, Qd, Pg, Qg, Vm, Va].
  • Reports are saved under <mlflow_artifacts>/test/<dataset>.csv.
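`test_step` builds boolean masks from the one-hot node-type columns (`batch.x[:, PQ] == 1`, and likewise for `PV` and `REF`), computes the per-feature MSE on denormalized outputs and targets for each node type, and `on_test_end` takes the square root to report RMSE. The sketch below reproduces that computation for a single batch in plain Python. The column layout (features in columns 0 to 5, node-type flags in columns 6 to 8) is an assumption made for illustration, since the real values of `PQ`, `PV`, `REF`, `PD`, `QD`, and the other constants are defined outside this listing.

```python
import math
from typing import Dict, List

# Hypothetical column layout for this sketch only.
FEATURES = ["PD", "QD", "PG", "QG", "VM", "VA"]  # prediction columns 0..5
NODE_TYPE_COLS = {"PQ": 6, "PV": 7, "REF": 8}    # one-hot flags in x


def per_type_rmse(
    x: List[List[float]],
    pred: List[List[float]],
    target: List[List[float]],
) -> Dict[str, List[float]]:
    """Per-node-type, per-feature RMSE over one batch of denormalized rows."""
    result = {}
    for node_type, col in NODE_TYPE_COLS.items():
        rows = [i for i, row in enumerate(x) if row[col] == 1]
        if not rows:
            # A mean over zero rows is NaN in torch as well.
            result[node_type] = [float("nan")] * len(FEATURES)
            continue
        mse = [
            sum((pred[i][f] - target[i][f]) ** 2 for i in rows) / len(rows)
            for f in range(len(FEATURES))
        ]
        result[node_type] = [math.sqrt(m) for m in mse]
    return result


x = [
    [0.0] * 6 + [1, 0, 0],  # PQ bus
    [0.0] * 6 + [1, 0, 0],  # PQ bus
    [0.0] * 6 + [0, 1, 0],  # PV bus
]
target = [[0.0] * 6 for _ in x]
pred = [[1.0] + [0.0] * 5, [7.0] + [0.0] * 5, [2.0] + [0.0] * 5]
rmse = per_type_rmse(x, pred, target)
print(rmse["PQ"][0], rmse["PV"][0])  # 5.0 2.0
```

In the task itself, the per-batch MSEs are averaged over batches by `self.log` (weighted by `batch.num_graphs`) before `on_test_end` takes the square root, so the reported value is the square root of an averaged MSE rather than one pooled RMSE over all nodes.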
Example
model = FeatureReconstructionTask(args, node_normalizers, edge_normalizers)
output = model(batch.x, batch.pe, batch.edge_index, batch.edge_attr, batch.batch)
Source code in gridfm_graphkit/tasks/feature_reconstruction_task.py
class FeatureReconstructionTask(L.LightningModule):
    """
    PyTorch Lightning task for node feature reconstruction on power grid graphs.

    This task wraps a GridFM model inside a LightningModule and defines the full
    training, validation, testing, and prediction logic. It is designed to
    reconstruct masked node features from graph-structured input data, using
    datasets and normalizers provided by `gridfm-graphkit`.

    Args:
        args (NestedNamespace): Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, `data.networks`, etc.
        node_normalizers (list): One normalizer per dataset to (de)normalize node features.
        edge_normalizers (list): One normalizer per dataset to (de)normalize edge features.

    Attributes:
        model (torch.nn.Module): Model loaded via `load_model`.
        loss_fn (callable): Loss function resolved from configuration.
        batch_size (int): Training batch size, from ``args.training.batch_size``.
        node_normalizers (list): Dataset-wise node feature normalizers.
        edge_normalizers (list): Dataset-wise edge feature normalizers.

    Methods:
        forward(x, pe, edge_index, edge_attr, batch, mask=None):
            Forward pass with optional feature masking.
        training_step(batch):
            One training step: computes loss, logs metrics, returns loss.
        validation_step(batch, batch_idx):
            One validation step: computes losses and logs metrics.
        test_step(batch, batch_idx, dataloader_idx=0):
            Evaluate on test data, compute per-node-type MSEs, and log per-dataset metrics.
        predict_step(batch, batch_idx, dataloader_idx=0):
            Run inference and return denormalized outputs, node-type masks, scenario IDs, and bus numbers.
        configure_optimizers():
            Set up the Adam optimizer and the ReduceLROnPlateau scheduler.
        on_fit_start():
            Save normalization statistics at the beginning of training.
        on_test_end():
            Collect test metrics across datasets and export summary CSV reports.

    Notes:
        - Node types are distinguished using the global constants (`PQ`, `PV`, `REF`).
        - The datamodule must provide `batch.mask` for masking node features.
        - Test metrics include per-node-type RMSE for [Pd, Qd, Pg, Qg, Vm, Va].
        - Reports are saved under `<mlflow_artifacts>/test/<dataset>.csv`.

    Example:
        ```python
        model = FeatureReconstructionTask(args, node_normalizers, edge_normalizers)
        output = model(batch.x, batch.pe, batch.edge_index, batch.edge_attr, batch.batch)
        ```
    """

    def __init__(self, args, node_normalizers, edge_normalizers):
        super().__init__()
        self.model = load_model(args=args)
        self.args = args
        self.loss_fn = get_loss_function(args)
        self.batch_size = int(args.training.batch_size)
        self.node_normalizers = node_normalizers
        self.edge_normalizers = edge_normalizers
        self.save_hyperparameters()

    def forward(self, x, pe, edge_index, edge_attr, batch, mask=None):
        if mask is not None:
            mask_value_expanded = self.model.mask_value.expand(x.shape[0], -1)
            x[:, : mask.shape[1]][mask] = mask_value_expanded[mask]
        return self.model(x, pe, edge_index, edge_attr, batch)

    @rank_zero_only
    def on_fit_start(self):
        # Determine save path
        if isinstance(self.logger, MLFlowLogger):
            log_dir = os.path.join(
                self.logger.save_dir,
                self.logger.experiment_id,
                self.logger.run_id,
                "artifacts",
                "stats",
            )
        else:
            log_dir = os.path.join(self.logger.save_dir, "stats")

        os.makedirs(log_dir, exist_ok=True)
        log_stats_path = os.path.join(log_dir, "normalization_stats.txt")

        # Collect normalization stats
        with open(log_stats_path, "w") as log_file:
            for i, normalizer in enumerate(self.node_normalizers):
                log_file.write(
                    f"Node Normalizer {self.args.data.networks[i]} stats:\n{normalizer.get_stats()}\n\n",
                )

            for i, normalizer in enumerate(self.edge_normalizers):
                log_file.write(
                    f"Edge Normalizer {self.args.data.networks[i]} stats:\n{normalizer.get_stats()}\n\n",
                )

    def shared_step(self, batch):
        output = self.forward(
            x=batch.x,
            pe=batch.pe,
            edge_index=batch.edge_index,
            edge_attr=batch.edge_attr,
            batch=batch.batch,
            mask=batch.mask,
        )

        loss_dict = self.loss_fn(
            output,
            batch.y,
            batch.edge_index,
            batch.edge_attr,
            batch.mask,
        )
        return output, loss_dict

    def training_step(self, batch):
        _, loss_dict = self.shared_step(batch)
        current_lr = self.optimizer.param_groups[0]["lr"]
        metrics = {}
        metrics["Training Loss"] = loss_dict["loss"].detach()
        metrics["Learning Rate"] = current_lr
        for metric, value in metrics.items():
            self.log(
                metric,
                value,
                batch_size=batch.num_graphs,
                sync_dist=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                on_step=False,
            )

        return loss_dict["loss"]

    def validation_step(self, batch, batch_idx):
        _, loss_dict = self.shared_step(batch)
        loss_dict["loss"] = loss_dict["loss"].detach()
        for metric, value in loss_dict.items():
            metric_name = f"Validation {metric}"
            self.log(
                metric_name,
                value,
                batch_size=batch.num_graphs,
                sync_dist=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                on_step=False,
            )

        return loss_dict["loss"]

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        output, loss_dict = self.shared_step(batch)

        dataset_name = self.args.data.networks[dataloader_idx]

        output_denorm = self.node_normalizers[dataloader_idx].inverse_transform(output)
        target_denorm = self.node_normalizers[dataloader_idx].inverse_transform(batch.y)

        mask_PQ = batch.x[:, PQ] == 1
        mask_PV = batch.x[:, PV] == 1
        mask_REF = batch.x[:, REF] == 1

        mse_PQ = F.mse_loss(
            output_denorm[mask_PQ],
            target_denorm[mask_PQ],
            reduction="none",
        )
        mse_PV = F.mse_loss(
            output_denorm[mask_PV],
            target_denorm[mask_PV],
            reduction="none",
        )
        mse_REF = F.mse_loss(
            output_denorm[mask_REF],
            target_denorm[mask_REF],
            reduction="none",
        )

        mse_PQ = mse_PQ.mean(dim=0)
        mse_PV = mse_PV.mean(dim=0)
        mse_REF = mse_REF.mean(dim=0)

        loss_dict["MSE PQ nodes - PD"] = mse_PQ[PD]
        loss_dict["MSE PV nodes - PD"] = mse_PV[PD]
        loss_dict["MSE REF nodes - PD"] = mse_REF[PD]

        loss_dict["MSE PQ nodes - QD"] = mse_PQ[QD]
        loss_dict["MSE PV nodes - QD"] = mse_PV[QD]
        loss_dict["MSE REF nodes - QD"] = mse_REF[QD]

        loss_dict["MSE PQ nodes - PG"] = mse_PQ[PG]
        loss_dict["MSE PV nodes - PG"] = mse_PV[PG]
        loss_dict["MSE REF nodes - PG"] = mse_REF[PG]

        loss_dict["MSE PQ nodes - QG"] = mse_PQ[QG]
        loss_dict["MSE PV nodes - QG"] = mse_PV[QG]
        loss_dict["MSE REF nodes - QG"] = mse_REF[QG]

        loss_dict["MSE PQ nodes - VM"] = mse_PQ[VM]
        loss_dict["MSE PV nodes - VM"] = mse_PV[VM]
        loss_dict["MSE REF nodes - VM"] = mse_REF[VM]

        loss_dict["MSE PQ nodes - VA"] = mse_PQ[VA]
        loss_dict["MSE PV nodes - VA"] = mse_PV[VA]
        loss_dict["MSE REF nodes - VA"] = mse_REF[VA]

        loss_dict["Test loss"] = loss_dict.pop("loss").detach()
        for metric, value in loss_dict.items():
            metric_name = f"{dataset_name}/{metric}"
            if "p.u." in metric:
                # Denormalize metrics expressed in p.u.
                value *= self.node_normalizers[dataloader_idx].baseMVA
                metric_name = metric_name.replace("in p.u.", "").strip()
            self.log(
                metric_name,
                value,
                batch_size=batch.num_graphs,
                add_dataloader_idx=False,
                sync_dist=True,
                logger=False,
            )
        return

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        output, _ = self.shared_step(batch)
        output_denorm = self.node_normalizers[dataloader_idx].inverse_transform(output)

        # Masks for node types
        mask_PQ = (batch.x[:, PQ] == 1).cpu()
        mask_PV = (batch.x[:, PV] == 1).cpu()
        mask_REF = (batch.x[:, REF] == 1).cpu()

        # Count buses and generate per-node scenario_id
        bus_counts = batch.batch.unique(return_counts=True)[1]
        scenario_ids = batch.scenario_id  # shape: [num_graphs]
        scenario_per_node = torch.cat(
            [
                torch.full((count,), sid, dtype=torch.int32)
                for count, sid in zip(bus_counts, scenario_ids)
            ],
        )

        bus_numbers = np.concatenate([np.arange(count.item()) for count in bus_counts])

        return {
            "output": output_denorm.cpu().numpy(),
            "mask_PQ": mask_PQ,
            "mask_PV": mask_PV,
            "mask_REF": mask_REF,
            "scenario_id": scenario_per_node,
            "bus_number": bus_numbers,
        }

    @rank_zero_only
    def on_test_end(self):
        if isinstance(self.logger, MLFlowLogger):
            artifact_dir = os.path.join(
                self.logger.save_dir,
                self.logger.experiment_id,
                self.logger.run_id,
                "artifacts",
            )
        else:
            artifact_dir = self.logger.save_dir

        final_metrics = self.trainer.callback_metrics
        grouped_metrics = {}

        for full_key, value in final_metrics.items():
            try:
                value = value.item()
            except AttributeError:
                pass

            if "/" in full_key:
                dataset_name, metric = full_key.split("/", 1)
                if dataset_name not in grouped_metrics:
                    grouped_metrics[dataset_name] = {}
                grouped_metrics[dataset_name][metric] = value

        for dataset, metrics in grouped_metrics.items():
            rmse_PQ = [
                metrics.get(f"MSE PQ nodes - {label}", float("nan")) ** 0.5
                for label in ["PD", "QD", "PG", "QG", "VM", "VA"]
            ]
            rmse_PV = [
                metrics.get(f"MSE PV nodes - {label}", float("nan")) ** 0.5
                for label in ["PD", "QD", "PG", "QG", "VM", "VA"]
            ]
            rmse_REF = [
                metrics.get(f"MSE REF nodes - {label}", float("nan")) ** 0.5
                for label in ["PD", "QD", "PG", "QG", "VM", "VA"]
            ]

            avg_active_res = metrics.get("Active Power Loss", " ")
            avg_reactive_res = metrics.get("Reactive Power Loss", " ")

            data = {
                "Metric": [
                    "RMSE-PQ",
                    "RMSE-PV",
                    "RMSE-REF",
                    "Avg. active res. (MW)",
                    "Avg. reactive res. (MVar)",
                ],
                "Pd (MW)": [
                    rmse_PQ[0],
                    rmse_PV[0],
                    rmse_REF[0],
                    avg_active_res,
                    avg_reactive_res,
                ],
                "Qd (MVar)": [rmse_PQ[1], rmse_PV[1], rmse_REF[1], " ", " "],
                "Pg (MW)": [rmse_PQ[2], rmse_PV[2], rmse_REF[2], " ", " "],
                "Qg (MVar)": [rmse_PQ[3], rmse_PV[3], rmse_REF[3], " ", " "],
                "Vm (p.u.)": [rmse_PQ[4], rmse_PV[4], rmse_REF[4], " ", " "],
                "Va (degree)": [rmse_PQ[5], rmse_PV[5], rmse_REF[5], " ", " "],
            }

            df = pd.DataFrame(data)

            test_dir = os.path.join(artifact_dir, "test")
            os.makedirs(test_dir, exist_ok=True)
            csv_path = os.path.join(test_dir, f"{dataset}.csv")
            df.to_csv(csv_path, index=False)

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.args.optimizer.learning_rate,
            betas=(self.args.optimizer.beta1, self.args.optimizer.beta2),
        )

        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            factor=self.args.optimizer.lr_decay,
            patience=self.args.optimizer.lr_patience,
        )
        config_optim = {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": self.scheduler,
                "monitor": "Validation loss",
                "reduce_on_plateau": True,
            },
        }
        return config_optim
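`predict_step` expands graph-level information to node level: it counts the nodes of each graph from `batch.batch`, repeats each graph's `scenario_id` once per node, and numbers the buses of every graph from 0. The sketch below mirrors that bookkeeping in plain Python. The batch vector and scenario IDs are made-up examples, and the sketch assumes, as PyTorch Geometric batching guarantees, that nodes of the same graph are contiguous in `batch.batch`.

```python
from itertools import groupby
from typing import List, Tuple


def expand_to_nodes(batch_vec: List[int], scenario_ids: List[int]) -> Tuple[List[int], List[int]]:
    """Return (scenario_per_node, bus_numbers) for a PyG-style batch vector."""
    # Node counts per graph, in graph order (nodes of one graph are contiguous).
    counts = [len(list(group)) for _, group in groupby(batch_vec)]
    # Repeat each graph's scenario ID once per node of that graph.
    scenario_per_node = [sid for count, sid in zip(counts, scenario_ids) for _ in range(count)]
    # Restart bus numbering at 0 for every graph.
    bus_numbers = [bus for count in counts for bus in range(count)]
    return scenario_per_node, bus_numbers


# Three graphs with 2, 3 and 1 buses; scenario IDs are arbitrary examples.
scen, buses = expand_to_nodes([0, 0, 1, 1, 1, 2], scenario_ids=[17, 42, 5])
print(scen)   # [17, 17, 42, 42, 42, 5]
print(buses)  # [0, 1, 0, 1, 2, 0]
```

One caveat in the listed code: `unique(return_counts=True)` omits graphs that have no nodes, which would misalign `zip(bus_counts, scenario_ids)`. `torch.bincount(batch.batch, minlength=batch.num_graphs)` would keep one count per graph, although empty grids are unlikely in practice.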