ncalab.models.basicNCA.abstractNCA
Classes
AbstractNCAModel – Abstract base class for NCA models.
Module Contents
- class ncalab.models.basicNCA.abstractNCA.AbstractNCAModel(device: torch.device, num_image_channels: int, num_hidden_channels: int, num_output_channels: int, plot_function: ncalab.visualization.Visual | None = None, validation_metric: str | None = None, fire_rate: float = 0.5, hidden_size: int = 128, use_alive_mask: bool = False, immutable_image_channels: bool = True, num_learned_filters: int = 0, filter_padding: Literal['zero', 'reflect', 'replicate', 'circular'] = 'reflect', use_laplace: bool = False, kernel_size: int = 3, pad_noise: bool = False, use_temporal_encoding: bool = False, rule_type: type[ncalab.models.basicNCA.abstractNCArule.AbstractNCARule] = MLPNCARule, rule_args=None, training_timesteps: int | Tuple[int, int] = 100, inference_timesteps: int | Tuple[int, int] = 100)
Bases:
torch.nn.Module, abc.ABC
Abstract base class for NCA models.
An NCA model is a composition of an NCA backbone model (called the “rule”) and an optional head module for downstream tasks.
- Parameters:
device – PyTorch device descriptor.
num_image_channels – Number of channels reserved for the input image.
num_hidden_channels – Number of hidden channels (communication channels).
num_output_channels – Number of output channels.
validation_metric – Optional string. Defaults to None.
fire_rate – Fire rate for the stochastic cell update. Defaults to 0.5.
hidden_size – Number of neurons in hidden layer. Defaults to 128.
use_alive_mask – Whether to use alive masking (channel 3) during training. Defaults to False.
immutable_image_channels – If image channels should be fixed during inference, which is the case for most segmentation or classification problems. Defaults to True.
num_learned_filters – Number of learned filters. If zero, two Sobel filters are used instead. Defaults to 0.
filter_padding – Padding type to use. Might affect reliance on spatial cues. Defaults to “reflect”.
use_laplace – Whether to use a Laplace filter (only if num_learned_filters == 0). Defaults to False.
kernel_size – Filter kernel size (only for learned filters). Defaults to 3.
pad_noise – Whether to pad the input image tensor with noise in the hidden and output channels. Defaults to False.
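use_temporal_encoding – Defaults to False.
rule_type – Rule class, a subclass of AbstractNCARule. Defaults to MLPNCARule.
rule_args – Defaults to None.
training_timesteps – int or Tuple[int, int]. Defaults to 100.
inference_timesteps – int or Tuple[int, int]. Defaults to 100.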
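To illustrate what a fire rate means in a stochastic cell update, the following minimal sketch applies a residual update to each cell of a small grid with probability fire_rate. It is not taken from ncalab: the function name stochastic_update and the nested-list grid are hypothetical stand-ins for the BCWH tensors used by the model, and it is an assumption that a cell fires with probability fire_rate (rather than 1 - fire_rate).

```python
import random
from typing import List

Grid = List[List[float]]


def stochastic_update(
    state: Grid, residual: Grid, fire_rate: float, rng: random.Random
) -> Grid:
    """Add `residual` to each cell of `state` with probability `fire_rate`.

    Hypothetical helper for illustration only; not part of ncalab.
    """
    return [
        [
            # rng.random() lies in [0, 1), so fire_rate=1.0 always fires
            # and fire_rate=0.0 never fires.
            cell + (delta if rng.random() < fire_rate else 0.0)
            for cell, delta in zip(state_row, residual_row)
        ]
        for state_row, residual_row in zip(state, residual)
    ]


state = [[0.0, 0.0], [0.0, 0.0]]
residual = [[1.0, 1.0], [1.0, 1.0]]
rng = random.Random(0)

always = stochastic_update(state, residual, 1.0, rng)  # every cell updated
never = stochastic_update(state, residual, 0.0, rng)   # no cell updated
```

With the default fire_rate of 0.5, roughly half of the cells would receive their update in any given step.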
- device
- num_image_channels
- num_output_channels
- num_channels
- fire_rate = 0.5
- use_alive_mask = False
- immutable_image_channels = True
- num_learned_filters = 0
- use_laplace = False
- kernel_size = 3
- filter_padding = 'reflect'
- pad_noise = False
- use_temporal_encoding = False
- plot_function = None
- validation_metric = None
- training_timesteps = 100
- inference_timesteps = 100
- perception
- input_vector_size
- rule_type
- rule_args = None
- rule
- head: ncalab.models.basicNCA.abstractNCAhead.AbstractNCAHead | None = None
- metrics: Dict[str, torchmetrics.Metric]
- _define_rule() ncalab.models.basicNCA.abstractNCArule.AbstractNCARule
- prepare_input(x: torch.Tensor) torch.Tensor
Preprocess input. Intended to be overridden by subclasses if preprocessing is necessary.
- Parameters:
x (torch.Tensor) – Input tensor to preprocess.
- Returns:
Processed tensor.
- _alive(x)
- _update(x: torch.Tensor, step: int) torch.Tensor
Compute residual cell update.
- Parameters:
x (torch.Tensor) – Input tensor, BCWH.
step (int) – Current timestep, required for computing temporal encoding.
- Returns:
Residual cell update, BCWH.
- _forward_step(x: torch.Tensor, step: int)
- forward(x: torch.Tensor, steps: int = 1) ncalab.prediction.Prediction
- Parameters:
x (torch.Tensor) – Input image, padded along the channel dimension, BCWH.
steps (int) – Time steps in forward pass.
- Returns:
Prediction object.
- Return type:
Prediction
- _post_forward_step(x: torch.Tensor) torch.Tensor
- loss(pred: ncalab.prediction.Prediction, label: torch.Tensor) Dict[str, torch.Tensor]
Compute the loss. Must be overridden by every subclass. The returned dict must contain a “total” key holding the total loss, which is typically a weighted sum of the other losses. The total loss is backpropagated, whereas the other losses are sent to TensorBoard.
- Parameters:
pred (ncalab.prediction.Prediction) – Prediction to evaluate.
label (torch.Tensor) – Ground truth, BCWH.
- Returns:
Dictionary of identifiers mapped to computed losses.
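The “total” key contract described for loss can be sketched as follows. This is an illustration under stated assumptions, not ncalab code: plain floats stand in for torch tensors so that the snippet runs without dependencies, and the helper combine_losses, the loss names “dice” and “bce”, and the weights are all hypothetical.

```python
from typing import Dict


def combine_losses(
    losses: Dict[str, float], weights: Dict[str, float]
) -> Dict[str, float]:
    """Return the individual losses plus a weighted sum under "total".

    Hypothetical helper; a real subclass would compute torch tensors
    from `pred` and `label` inside its `loss` method.
    """
    total = sum(weights[name] * value for name, value in losses.items())
    # Keep the individual terms for logging and add the "total" entry,
    # which is the value intended for backpropagation.
    return {**losses, "total": total}


result = combine_losses(
    losses={"dice": 0.4, "bce": 0.2},
    weights={"dice": 1.0, "bce": 0.5},
)
```

Returning the individual terms alongside “total” keeps them available for logging, while only the “total” entry drives the optimization.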
- finetune(freeze_head: bool = False)
Prepare the model for fine-tuning by freezing everything except the final layer and setting it to “train” mode.
- Parameters:
freeze_head (bool) – Defaults to False.
- predict(image: torch.Tensor, steps: int | Tuple[int, int] | None = None) ncalab.prediction.Prediction
Make an NCA prediction, performing multiple forward passes to yield a final result.
- Parameters:
image (torch.Tensor) – Input image, BCWH.
steps (int | Tuple[int, int] | None) – Time steps.
- Returns:
Prediction object.
- Return type:
Prediction
- record(image: torch.Tensor, steps: int | Tuple[int, int] | None = None) List[ncalab.prediction.Prediction]
Record predictions for all time steps and return the resulting sequence of predictions.
- Parameters:
image (torch.Tensor) – Input image, BCWH.
steps (int | Tuple[int, int] | None) – Time steps.
- Returns:
List of Prediction objects.
- Return type:
List[Prediction]
- validate(dataloader: torch.utils.data.DataLoader, steps: int | None = None) Tuple[Dict[str, float], List[ncalab.prediction.Prediction]]
Make a prediction on an image of the validation set and return metrics computed with respect to a labelled validation image.
- Parameters:
dataloader (torch.utils.data.DataLoader) – Dataloader for validation images.
steps (int | None) – Inference steps.
- Returns:
Dictionary of validation metrics, and list of Prediction objects.
- Return type:
Tuple[Dict[str, float], List[Prediction]]
- _to_dict() Dict[str, Any]
- to_dict() Dict[str, Any]
- classmethod from_dict(d: Dict[str, Any])
- num_trainable_parameters() int
Returns the number of trainable model parameters.
- Returns:
Number of trainable parameters.
- Return type:
int
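A common way to count trainable parameters of a torch.nn.Module is to sum the element counts of all parameters that require gradients. The sketch below shows that idiom on a small stand-in network. It is an assumption that num_trainable_parameters behaves the same way; count_trainable_parameters and the example model are hypothetical and not part of ncalab.

```python
import torch.nn as nn


def count_trainable_parameters(model: nn.Module) -> int:
    """Sum the sizes of all parameters that take part in gradient updates."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Stand-in network: Linear(4, 8) has 4*8 + 8 = 40 parameters,
# Linear(8, 2) has 8*2 + 2 = 18 parameters.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
before = count_trainable_parameters(model)

# Freeze the first layer, leaving only the final layer trainable.
for param in model[0].parameters():
    param.requires_grad = False
after = count_trainable_parameters(model)
```

Freezing layers in this way lowers the count, which makes it a quick check that a fine-tuning setup has frozen what was intended.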
- save(path: str | os.PathLike)
- static load(model: AbstractNCAModel, path: str | os.PathLike) AbstractNCAModel
- post_prediction(prediction: ncalab.prediction.Prediction) ncalab.prediction.Prediction