Power Grid datasets
GridDatasetMem
Bases: InMemoryDataset
A PyTorch Geometric InMemoryDataset for power grid data stored in tabular CSV format.
This dataset class reads node and edge data from CSV files, applies normalization using user-specified Normalizer instances, and builds graph data objects with edge weights and positional encodings.
- Reads raw node and edge CSV files (pf_node.csv, pf_edge.csv).
- Applies the specified normalization method to both node and edge features.
- Stores normalization statistics in the processed directory for reuse.
- Constructs torch_geometric.data.Data objects with edge weights and positional encodings (via random walk embeddings).
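To make the last step more concrete, here is a minimal pure-Python sketch of random walk positional encodings: for each node, the probability that a walk of 1, 2, ..., k steps returns to its starting node. The function name `random_walk_pe` and the unweighted, undirected adjacency are assumptions made for illustration; the class itself may weight edges or rely on a library transform (PyTorch Geometric ships `AddRandomWalkPE`, for instance), which this page does not specify. The `walk_length` argument plays the role that `pe_dim` appears to play here.

```python
def random_walk_pe(num_nodes, edges, walk_length):
    """Return a num_nodes x walk_length list of return probabilities.

    Entry [i][k - 1] is the probability that a k-step random walk
    starting at node i ends at node i again.
    """
    # Dense adjacency matrix (assumption: undirected, unweighted graph).
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        adj[u][v] = 1.0
        adj[v][u] = 1.0

    # Row-normalize to obtain the transition matrix P = D^-1 A.
    trans = []
    for row in adj:
        degree = sum(row)
        trans.append([a / degree if degree > 0 else 0.0 for a in row])

    def matmul(a, b):
        n = len(a)
        return [
            [sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)
        ]

    pe = [[] for _ in range(num_nodes)]
    power = trans  # holds P^k, starting at k = 1
    for _ in range(walk_length):
        for i in range(num_nodes):
            pe[i].append(power[i][i])  # diagonal entry of P^k
        power = matmul(power, trans)
    return pe


# Three buses connected in a line: 0 - 1 - 2
pe = random_walk_pe(3, [(0, 1), (1, 2)], walk_length=4)
print(pe[0])  # end bus
print(pe[1])  # middle bus
```

On this small line graph the middle bus returns to itself with probability 1 after every even number of steps, while the end buses return with probability 0.5, so structurally different nodes receive different encodings.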
Parameters:
Name | Type | Description | Default |
---|---|---|---|
root | str | Root directory where the dataset is stored. | required |
norm_method | str | Identifier for normalization method (e.g., "minmax", "standard"). | required |
node_normalizer | Normalizer | Normalizer used for node features. | required |
edge_normalizer | Normalizer | Normalizer used for edge features. | required |
pe_dim | int | Length of the random walk used for positional encoding. | required |
mask_dim | int | Number of features per-node that could be masked. Usually Pd, Qd, Pg, Qg, Vm, Va. | 6 |
transform | callable | Transformation applied at runtime. | None |
pre_transform | callable | Transformation applied before saving to disk. | None |
pre_filter | callable | Filter to determine which graphs to keep. | None |
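The normalizer parameters above can be pictured with a small sketch of min-max scaling whose statistics are saved once and reloaded later. Everything in it is hypothetical: the class `MinMaxSketch`, its method names, and the `stats.json` file name are illustrative assumptions and are not the library's Normalizer interface. The real normalizers presumably operate on tensors of several features rather than on a plain list.

```python
import json
import tempfile
from pathlib import Path


class MinMaxSketch:
    """Hypothetical min-max normalizer for a single feature column."""

    def __init__(self):
        self.min_val = None
        self.max_val = None

    def fit(self, values):
        # Compute the statistics and return them so they can be stored.
        self.min_val = min(values)
        self.max_val = max(values)
        return {"min": self.min_val, "max": self.max_val}

    def load(self, stats):
        # Reuse previously stored statistics instead of refitting.
        self.min_val = stats["min"]
        self.max_val = stats["max"]

    def _span(self):
        span = self.max_val - self.min_val
        return span if span != 0 else 1.0  # guard against constant columns

    def transform(self, values):
        return [(v - self.min_val) / self._span() for v in values]

    def inverse_transform(self, values):
        return [v * self._span() + self.min_val for v in values]


with tempfile.TemporaryDirectory() as tmp:
    stats_path = Path(tmp) / "stats.json"  # hypothetical file name

    # First run: fit on the raw data and persist the statistics.
    fitted = MinMaxSketch()
    stats_path.write_text(json.dumps(fitted.fit([10.0, 20.0, 30.0])))

    # Later run: restore the statistics without touching the raw data.
    reloaded = MinMaxSketch()
    reloaded.load(json.loads(stats_path.read_text()))

scaled = reloaded.transform([10.0, 20.0, 30.0])
restored = reloaded.inverse_transform(scaled)
print(scaled, restored)
```

Persisting the statistics means that later runs, or an inverse transform applied to model outputs, use exactly the same scaling as the run that first processed the data.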
Source code in gridfm_graphkit/datasets/powergrid.py
change_transform(new_transform)
Temporarily switch to a new transform function, used when evaluating different tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
new_transform | Callable | The new transform to use. | required |
Source code in gridfm_graphkit/datasets/powergrid.py
reset_transform()
Reverts the transform to the original one set during initialization. Usually called after the evaluation step.
Source code in gridfm_graphkit/datasets/powergrid.py
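The swap-and-restore pattern behind `change_transform` and `reset_transform` can be sketched as follows. `TransformSwapSketch` is a hypothetical stand-in written for this illustration, not the actual implementation in `powergrid.py`; in particular, how the real class remembers the original transform is an assumption.

```python
class TransformSwapSketch:
    """Hypothetical stand-in showing the swap-and-restore pattern."""

    def __init__(self, items, transform=None):
        self.items = list(items)
        self.transform = transform
        # Remember the transform given at construction time
        # so that reset_transform() can restore it later.
        self._initial_transform = transform

    def change_transform(self, new_transform):
        self.transform = new_transform

    def reset_transform(self):
        self.transform = self._initial_transform

    def __getitem__(self, idx):
        item = self.items[idx]
        return item if self.transform is None else self.transform(item)


ds = TransformSwapSketch([1.0, 2.0, 3.0], transform=lambda x: x * 2)
train_view = ds[1]  # uses the transform set at initialization

ds.change_transform(lambda x: -x)  # switch for an evaluation task
try:
    eval_view = ds[1]
finally:
    ds.reset_transform()  # always restore, even if evaluation fails

restored_view = ds[1]
print(train_view, eval_view, restored_view)
```

Wrapping the evaluation in `try`/`finally` is a design choice of this sketch: it keeps a failed evaluation from leaving the dataset with the wrong transform for the next training step.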
Usage Example
from gridfm_graphkit.datasets.data_normalization import IdentityNormalizer
from gridfm_graphkit.datasets.powergrid import GridDatasetMem

dataset = GridDatasetMem(
    root="./data",
    norm_method="identity",
    node_normalizer=IdentityNormalizer(),
    edge_normalizer=IdentityNormalizer(),
    pe_dim=10,
    mask_dim=6,
)