ncalab.models.basicNCA.basicNCArule
Classes
BasicNCARule – NCA rule module based on a two-layer Multi-Layer Perceptron (MLP).
Module Contents
- class ncalab.models.basicNCA.basicNCArule.BasicNCARule(device: torch.device, input_size: int, hidden_size: int, output_size: int, nonlinearity: type[torch.nn.Module] = nn.ReLU)
Bases: torch.nn.Module
NCA rule module based on a two-layer Multi-Layer Perceptron (MLP).
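- Parameters:
device (torch.device) – Device on which the rule network is allocated.
input_size (int) – Number of input channels, i.e. the channel count of the perception vector.
hidden_size (int) – Number of hidden units in the MLP.
output_size (int) – Number of output channels of the residual update.
nonlinearity (type[nn.Module], optional) – Activation module class applied between the two layers, defaults to nn.ReLU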
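The constructor parameters describe a per-cell two-layer MLP. The following is a minimal sketch of such a rule, not ncalab's actual implementation. It assumes the MLP is applied independently at every cell and realises this with 1x1 convolutions so that BCWH tensors can be passed directly; the class name `SketchNCARule` and the attribute `network` are hypothetical.

```python
import torch
import torch.nn as nn


class SketchNCARule(nn.Module):
    """Hypothetical two-layer MLP rule; NOT ncalab's BasicNCARule.

    Assumption: the MLP acts on each cell independently, implemented here
    with 1x1 convolutions so it accepts BCWH tensors without reshaping.
    """

    def __init__(
        self,
        device: torch.device,
        input_size: int,
        hidden_size: int,
        output_size: int,
        nonlinearity: type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        self.device = device
        self.input_size = input_size
        self.output_size = output_size
        self.nonlinearity = nonlinearity
        self.network = nn.Sequential(
            # First layer: biased, followed by the configurable activation.
            nn.Conv2d(input_size, hidden_size, kernel_size=1),
            nonlinearity(),
            # Final layer: purely linear and unbiased.
            nn.Conv2d(hidden_size, output_size, kernel_size=1, bias=False),
        ).to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: BCWH perception tensor -> BCWH residual update
        return self.network(x)
```

With, for example, `SketchNCARule(torch.device("cpu"), 48, 128, 16)`, a `(2, 48, 8, 8)` perception tensor maps to a `(2, 16, 8, 8)` update. A 1x1 convolution is equivalent to a linear layer applied at each pixel, which avoids permuting BCWH tensors to channels-last and back.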
- nonlinearity
- input_size
- output_size
- device
- _build_network()
- _initialize_network()
Initialize network weights of the MLP.
We assume that the default initialization of the first layer is good enough. Since the final layer is purely linear and unbiased, we initialize its weights with zeros, so the untrained rule produces a zero residual update.
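The effect of zero-initializing an unbiased final layer can be checked directly. The sketch below uses a hypothetical stand-in network (1x1 convolutions, illustrative sizes), not the library's `_build_network()` output, and shows that every element of the initial update is exactly zero.

```python
import torch
import torch.nn as nn

# Hypothetical per-cell MLP (1x1 convolutions) standing in for the rule network.
network = nn.Sequential(
    nn.Conv2d(48, 128, kernel_size=1),              # first layer: keep PyTorch's default init
    nn.ReLU(),
    nn.Conv2d(128, 16, kernel_size=1, bias=False),  # final layer: purely linear, no bias
)

# Zero the final layer's weights; with no bias term its output is exactly zero.
with torch.no_grad():
    nn.init.zeros_(network[-1].weight)

update = network(torch.randn(1, 48, 8, 8))
print(torch.count_nonzero(update).item())  # 0
```

Because the update is residual, a zero update leaves the cell state unchanged at the start of training, while gradients still reach the final layer's weights through the nonzero hidden activations.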
- forward(x: torch.Tensor) → torch.Tensor
- Parameters:
x (torch.Tensor) – BCWH perception vector
- Returns:
BCWH residual update
- Return type:
torch.Tensor
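Because `forward` returns a residual update rather than a new state, a caller adds it to the current state. The sketch below shows that contract. The perception construction (state concatenated with two finite-difference gradients) and the stand-in rule network are hypothetical illustrations, not ncalab's perception or update step.

```python
import torch
import torch.nn as nn

state_channels = 16
input_size = 3 * state_channels  # hypothetical: state + x-gradient + y-gradient

# Stand-in rule: maps (B, input_size, W, H) to (B, state_channels, W, H).
rule = nn.Sequential(
    nn.Conv2d(input_size, 64, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(64, state_channels, kernel_size=1, bias=False),
)

state = torch.rand(2, state_channels, 32, 32)  # B, C, W, H

# Hypothetical perception: periodic forward differences along W and H.
dx = torch.roll(state, shifts=-1, dims=2) - state
dy = torch.roll(state, shifts=-1, dims=3) - state
perception = torch.cat([state, dx, dy], dim=1)

update = rule(perception)     # BCWH residual update
new_state = state + update    # residual step: add update to current state
print(new_state.shape)  # torch.Size([2, 16, 32, 32])
```

The only requirement from the documented signature is that the perception tensor has `input_size` channels and the returned update has `output_size` channels matching the state being updated.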
- freeze(freeze_last: bool = False)
Freeze the first layer of the NCA rule network and, optionally, the final layer.
- Parameters:
freeze_last (bool, optional) – Whether to also freeze the final layer, defaults to False
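Freezing a layer in PyTorch usually means disabling gradients for its parameters. The following is a minimal sketch of the documented behaviour, assuming freezing is done with `requires_grad_(False)`; the helper `freeze_rule` and the stand-in network are hypothetical and not the library's `freeze()` implementation.

```python
import torch.nn as nn


def freeze_rule(network: nn.Sequential, freeze_last: bool = False) -> None:
    """Hypothetical helper mirroring the documented behaviour of freeze()."""
    # Always freeze the first layer.
    network[0].requires_grad_(False)
    # Optionally freeze the final layer as well.
    if freeze_last:
        network[-1].requires_grad_(False)


network = nn.Sequential(
    nn.Conv2d(48, 128, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(128, 16, kernel_size=1, bias=False),
)
freeze_rule(network)
trainable = [name for name, p in network.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight']
```

With the default `freeze_last=False`, only the final layer stays trainable, which suits fine-tuning the output mapping while keeping the learned hidden features fixed.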