LitGridDataModule

Bases: LightningDataModule

PyTorch Lightning DataModule for power grid datasets.

This datamodule handles loading, preprocessing, splitting, and batching of power grid graph datasets (GridDatasetDisk) for training, validation, testing, and prediction. It ensures reproducibility through fixed seeds.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `args` | `NestedNamespace` | Experiment configuration. | required |
| `data_dir` | `str` | Root directory for datasets. | `'./data'` |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `batch_size` | `int` | Batch size for all dataloaders; taken from `args.training.batch_size`. |
| `node_normalizers` | `list` | Node feature normalizers, one per dataset. |
| `edge_normalizers` | `list` | Edge feature normalizers, one per dataset. |
| `datasets` | `list` | Original datasets for each network. |
| `train_datasets` | `list` | Train splits for each network. |
| `val_datasets` | `list` | Validation splits for each network. |
| `test_datasets` | `list` | Test splits for each network. |
| `train_dataset_multi` | `ConcatDataset` | Concatenated train datasets for multi-network training. |
| `val_dataset_multi` | `ConcatDataset` | Concatenated validation datasets for multi-network validation. |
| `_is_setup_done` | `bool` | Tracks whether `setup` has been executed to avoid repeated processing. |

Methods:

| Name | Description |
| --- | --- |
| `setup` | Load and preprocess datasets, split into train/val/test, and store normalizers. Handles distributed preprocessing safely. |
| `train_dataloader` | Returns a DataLoader for the concatenated training datasets. |
| `val_dataloader` | Returns a DataLoader for the concatenated validation datasets. |
| `test_dataloader` | Returns a list of DataLoaders, one per test dataset. |
| `predict_dataloader` | Returns a list of DataLoaders, one per test dataset, for prediction. |

Notes
  • Preprocessing is only performed on rank 0 in distributed settings.
  • Subsets and splits are deterministic based on the provided random seed.
  • Normalizers are loaded for each network independently.
  • Test and predict dataloaders are returned as lists, one per dataset.
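The deterministic subsetting noted above can be sketched with plain Python. `select_subset` is a hypothetical helper that mirrors the seeding-and-shuffle logic in `setup`, not part of the API:

```python
import random

def select_subset(dataset_size, num_scenarios, seed):
    # Re-seed immediately before shuffling so the chosen subset depends
    # only on the seed, not on how many networks were processed earlier.
    indices = list(range(dataset_size))
    random.seed(seed)
    random.shuffle(indices)
    # Fall back to the full dataset if more scenarios are requested than exist.
    return indices[: min(num_scenarios, dataset_size)]

# Identical seeds produce identical subsets across runs.
assert select_subset(100, 10, seed=42) == select_subset(100, 10, seed=42)
```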
Example
```python
from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule
from gridfm_graphkit.io.param_handler import NestedNamespace
import yaml

with open("config/config.yaml") as f:
    base_config = yaml.safe_load(f)
args = NestedNamespace(**base_config)

datamodule = LitGridDataModule(args, data_dir="./data")

datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
```
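Because `test_dataloader()` and `predict_dataloader()` return one loader per network rather than a single concatenated loader, downstream code iterates over the list. A minimal, library-free sketch of that pattern, with `batched` as a hypothetical stand-in for a `DataLoader`:

```python
def batched(items, batch_size):
    # Hypothetical stand-in for a DataLoader: yields ordered, fixed-size batches.
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

# One "loader" per test dataset, consumed independently,
# mirroring the list returned by test_dataloader()/predict_dataloader().
test_datasets = [list(range(5)), list(range(3))]
loaders = [batched(ds, batch_size=2) for ds in test_datasets]
per_network_counts = [sum(len(batch) for batch in loader) for loader in loaders]
# Each network's samples are counted separately.
```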
Source code in gridfm_graphkit/datasets/powergrid_datamodule.py
class LitGridDataModule(L.LightningDataModule):
    """
    PyTorch Lightning DataModule for power grid datasets.

    This datamodule handles loading, preprocessing, splitting, and batching
    of power grid graph datasets (`GridDatasetDisk`) for training, validation,
    testing, and prediction. It ensures reproducibility through fixed seeds.

    Args:
        args (NestedNamespace): Experiment configuration.
        data_dir (str, optional): Root directory for datasets. Defaults to "./data".

    Attributes:
        batch_size (int): Batch size for all dataloaders. From ``args.training.batch_size``.
        node_normalizers (list): List of node feature normalizers, one per dataset.
        edge_normalizers (list): List of edge feature normalizers, one per dataset.
        datasets (list): Original datasets for each network.
        train_datasets (list): Train splits for each network.
        val_datasets (list): Validation splits for each network.
        test_datasets (list): Test splits for each network.
        train_dataset_multi (ConcatDataset): Concatenated train datasets for multi-network training.
        val_dataset_multi (ConcatDataset): Concatenated validation datasets for multi-network validation.
        _is_setup_done (bool): Tracks whether `setup` has been executed to avoid repeated processing.

    Methods:
        setup(stage):
            Load and preprocess datasets, split into train/val/test, and store normalizers.
            Handles distributed preprocessing safely.
        train_dataloader():
            Returns a DataLoader for concatenated training datasets.
        val_dataloader():
            Returns a DataLoader for concatenated validation datasets.
        test_dataloader():
            Returns a list of DataLoaders, one per test dataset.
        predict_dataloader():
            Returns a list of DataLoaders, one per test dataset for prediction.

    Notes:
        - Preprocessing is only performed on rank 0 in distributed settings.
        - Subsets and splits are deterministic based on the provided random seed.
        - Normalizers are loaded for each network independently.
        - Test and predict dataloaders are returned as lists, one per dataset.

    Example:
        ```python
        from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule
        from gridfm_graphkit.io.param_handler import NestedNamespace
        import yaml

        with open("config/config.yaml") as f:
            base_config = yaml.safe_load(f)
        args = NestedNamespace(**base_config)

        datamodule = LitGridDataModule(args, data_dir="./data")

        datamodule.setup("fit")
        train_loader = datamodule.train_dataloader()
        ```
    """

    def __init__(self, args: NestedNamespace, data_dir: str = "./data"):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = int(args.training.batch_size)
        self.args = args
        self.node_normalizers = []
        self.edge_normalizers = []
        self.datasets = []
        self.train_datasets = []
        self.val_datasets = []
        self.test_datasets = []
        self._is_setup_done = False

    def setup(self, stage: str):
        if self._is_setup_done:
            print(f"Setup already done for stage={stage}, skipping...")
            return

        for i, network in enumerate(self.args.data.networks):
            node_normalizer, edge_normalizer = load_normalizer(args=self.args)
            self.node_normalizers.append(node_normalizer)
            self.edge_normalizers.append(edge_normalizer)

            # Create torch dataset and split
            data_path_network = os.path.join(self.data_dir, network)

            # Run preprocessing only on rank 0
            if dist.is_available() and dist.is_initialized() and dist.get_rank() == 0:
                print(f"Pre-processing of {network} dataset on rank 0")
                _ = GridDatasetDisk(  # just to trigger processing
                    root=data_path_network,
                    norm_method=self.args.data.normalization,
                    node_normalizer=node_normalizer,
                    edge_normalizer=edge_normalizer,
                    pe_dim=self.args.model.pe_dim,
                    mask_dim=self.args.data.mask_dim,
                    transform=get_transform(args=self.args),
                )

            # All ranks wait here until processing is done
            if dist.is_available() and dist.is_initialized():
                dist.barrier()

            dataset = GridDatasetDisk(
                root=data_path_network,
                norm_method=self.args.data.normalization,
                node_normalizer=node_normalizer,
                edge_normalizer=edge_normalizer,
                pe_dim=self.args.model.pe_dim,
                mask_dim=self.args.data.mask_dim,
                transform=get_transform(args=self.args),
            )
            self.datasets.append(dataset)

            num_scenarios = self.args.data.scenarios[i]
            if num_scenarios > len(dataset):
                warnings.warn(
                    f"Requested number of scenarios ({num_scenarios}) exceeds dataset size ({len(dataset)}). "
                    "Using the full dataset instead.",
                )
                num_scenarios = len(dataset)

            # Create a subset
            all_indices = list(range(len(dataset)))
            # Re-seed before every shuffle so each network's subset is reproducible
            # regardless of the order in which the power grid datasets are processed
            random.seed(self.args.seed)
            random.shuffle(all_indices)
            subset_indices = all_indices[:num_scenarios]
            dataset = Subset(dataset, subset_indices)

            # Random seed set before every split, same as above
            np.random.seed(self.args.seed)
            train_dataset, val_dataset, test_dataset = split_dataset(
                dataset,
                self.data_dir,
                self.args.data.val_ratio,
                self.args.data.test_ratio,
            )

            self.train_datasets.append(train_dataset)
            self.val_datasets.append(val_dataset)
            self.test_datasets.append(test_dataset)

        self.train_dataset_multi = ConcatDataset(self.train_datasets)
        self.val_dataset_multi = ConcatDataset(self.val_datasets)
        self._is_setup_done = True

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset_multi,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.args.data.workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset_multi,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.args.data.workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.args.data.workers,
                pin_memory=True,
            )
            for dataset in self.test_datasets
        ]

    def predict_dataloader(self):
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.args.data.workers,
                pin_memory=True,
            )
            for dataset in self.test_datasets
        ]