ncalab.models.basicNCA.abstractNCA

Classes

AbstractNCAModel

Abstract base class for NCA models.

Module Contents

class ncalab.models.basicNCA.abstractNCA.AbstractNCAModel(device: torch.device, num_image_channels: int, num_hidden_channels: int, num_output_channels: int, plot_function: ncalab.visualization.Visual | None = None, validation_metric: str | None = None, fire_rate: float = 0.5, hidden_size: int = 128, use_alive_mask: bool = False, immutable_image_channels: bool = True, num_learned_filters: int = 0, filter_padding: Literal['zero', 'reflect', 'replicate', 'circular'] = 'reflect', use_laplace: bool = False, kernel_size: int = 3, pad_noise: bool = False, use_temporal_encoding: bool = False, rule_type: type[ncalab.models.basicNCA.abstractNCArule.AbstractNCARule] = MLPNCARule, rule_args=None, training_timesteps: int | Tuple[int, int] = 100, inference_timesteps: int | Tuple[int, int] = 100)

Bases: torch.nn.Module, abc.ABC

Abstract base class for NCA models.

AbstractNCAModel is a composition of an NCA backbone model (called the “rule”) and an optional head module for downstream tasks.

Parameters:
  • device – PyTorch device descriptor.

  • num_image_channels – Number of channels reserved for input image.

  • num_hidden_channels – Number of hidden channels (communication channels).

  • num_output_channels – Number of output channels.

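  • plot_function – Visualization object (ncalab.visualization.Visual) used for plotting. Defaults to None.

  • validation_metric – Name of the metric used for validation. Defaults to None.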

  • fire_rate – Fire rate for stochastic weight update. Defaults to 0.5.

  • hidden_size – Number of neurons in hidden layer. Defaults to 128.

  • use_alive_mask – Whether to use alive masking (channel 3) during training. Defaults to False.

  • immutable_image_channels – Whether the image channels are kept fixed during inference, as is the case for most segmentation and classification problems. Defaults to True.

  • num_learned_filters – Number of learned filters. If zero, two Sobel filters are used instead. Defaults to 0.

  • filter_padding – Padding type to use. Might affect reliance on spatial cues. Defaults to “reflect”.

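  • use_laplace – Whether to use a Laplace filter (only if num_learned_filters == 0). Defaults to False.

  • kernel_size – Filter kernel size (only for learned filters). Defaults to 3.

  • pad_noise – Whether to pad the input image tensor with noise in the hidden and output channels. Defaults to False.

  • use_temporal_encoding – Whether to use a temporal encoding of the current time step when computing cell updates. Defaults to False.

  • rule_type – Class of the NCA rule (backbone), a subclass of AbstractNCARule. Defaults to MLPNCARule.

  • rule_args – Optional additional arguments for the rule. Defaults to None.

  • training_timesteps – Number of time steps during training, given as an int or as a tuple of two ints. Defaults to 100.

  • inference_timesteps – Number of time steps during inference, given as an int or as a tuple of two ints. Defaults to 100.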
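The channel layout implied by these parameters can be illustrated with a small NumPy sketch. It assumes (not confirmed by this page) that num_channels is the sum of image, hidden and output channels, and that padding fills the extra channels with zeros or, with pad_noise, uniform noise. pad_input is a hypothetical helper, not part of ncalab.

```python
from typing import Optional

import numpy as np


def pad_input(image: np.ndarray, num_hidden_channels: int, num_output_channels: int,
              pad_noise: bool = False,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Hypothetical helper: append hidden and output channels to a BCWH image."""
    b, _, w, h = image.shape
    extra = num_hidden_channels + num_output_channels
    if pad_noise:
        rng = rng or np.random.default_rng()
        # Assumption: uniform noise in [0, 1) for the non-image channels.
        padding = rng.random((b, extra, w, h)).astype(image.dtype)
    else:
        padding = np.zeros((b, extra, w, h), dtype=image.dtype)
    # Image channels come first, followed by hidden and output channels.
    return np.concatenate([image, padding], axis=1)


image = np.ones((2, 3, 8, 8), dtype=np.float32)
x = pad_input(image, num_hidden_channels=16, num_output_channels=1)
print(x.shape)  # (2, 20, 8, 8)
```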

device
num_image_channels
num_hidden_channels
num_output_channels
num_channels
fire_rate = 0.5
hidden_size = 128
use_alive_mask = False
immutable_image_channels = True
num_learned_filters = 0
use_laplace = False
kernel_size = 3
filter_padding = 'reflect'
pad_noise = False
use_temporal_encoding = False
plot_function = None
validation_metric = None
training_timesteps = 100
inference_timesteps = 100
perception
input_vector_size
rule_type
rule_args = None
rule
head: ncalab.models.basicNCA.abstractNCAhead.AbstractNCAHead | None = None
metrics: Dict[str, torchmetrics.Metric]
_define_rule() ncalab.models.basicNCA.abstractNCArule.AbstractNCARule
prepare_input(x: torch.Tensor) torch.Tensor

Preprocess the input. Intended to be overridden by subclasses that require preprocessing.

Parameters:

x (torch.Tensor) – Input tensor to preprocess.

Returns:

Processed tensor.

_alive(x)
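_alive(x) is undocumented here. In the Growing Neural Cellular Automata formulation, a cell counts as alive if the maximum of its alive channel over its 3x3 neighbourhood exceeds 0.1. The sketch below follows that formulation, using channel 3 as stated for use_alive_mask; the 0.1 threshold and 3x3 neighbourhood are assumptions, and alive_mask is a hypothetical NumPy helper rather than ncalab's implementation.

```python
import numpy as np


def alive_mask(x: np.ndarray, alive_channel: int = 3, threshold: float = 0.1) -> np.ndarray:
    """Hypothetical helper: B1WH boolean mask of cells with a live 3x3 neighbourhood."""
    alpha = x[:, alive_channel:alive_channel + 1]  # B1WH
    w, h = alpha.shape[2], alpha.shape[3]
    # Pad with -inf so border cells only "see" real neighbours in the max pool.
    padded = np.pad(alpha, ((0, 0), (0, 0), (1, 1), (1, 1)), constant_values=-np.inf)
    pooled = np.full_like(alpha, -np.inf)
    # 3x3 max pooling with stride 1, written as a maximum over shifted views.
    for dx in range(3):
        for dy in range(3):
            pooled = np.maximum(pooled, padded[:, :, dx:dx + w, dy:dy + h])
    return pooled > threshold


x = np.zeros((1, 4, 5, 5))
x[0, 3, 2, 2] = 1.0  # a single live cell in the centre
print(alive_mask(x)[0, 0].astype(int))
```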
_update(x: torch.Tensor, step: int) torch.Tensor

Compute residual cell update.

Parameters:
  • x (torch.Tensor) – Input tensor, BCWH.

  • step (int) – Current time step, required for computing the temporal encoding.

Returns:

Residual cell update, BCWH.
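How fire_rate and immutable_image_channels could act on the residual update is sketched below with NumPy. This is the generic stochastic NCA update (each cell applies its update with probability fire_rate, image channels left untouched), not ncalab's code; stochastic_residual_step is hypothetical and assumes the update dx has already been computed by the rule.

```python
from typing import Optional

import numpy as np


def stochastic_residual_step(x: np.ndarray, dx: np.ndarray, fire_rate: float = 0.5,
                             num_image_channels: int = 3,
                             immutable_image_channels: bool = True,
                             rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Hypothetical helper: apply a residual update dx to a BCWH state x."""
    rng = rng or np.random.default_rng()
    b, _, w, h = x.shape
    # Each cell fires independently with probability fire_rate (shared across channels).
    fire = (rng.random((b, 1, w, h)) < fire_rate).astype(x.dtype)
    dx = dx * fire  # new array, so the caller's dx is not modified
    if immutable_image_channels:
        dx[:, :num_image_channels] = 0  # keep the input image fixed
    return x + dx


x = np.zeros((1, 5, 4, 4))
y = stochastic_residual_step(x, np.ones_like(x), fire_rate=1.0)
```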

_forward_step(x: torch.Tensor, step: int)
forward(x: torch.Tensor, steps: int = 1) ncalab.prediction.Prediction
Parameters:
  • x (torch.Tensor) – Input image, padded along the channel dimension, BCWH.

  • steps (int) – Number of time steps in the forward pass.

Returns:

Prediction object.

Return type:

Prediction

_post_forward_step(x: torch.Tensor) torch.Tensor
loss(pred: ncalab.prediction.Prediction, label: torch.Tensor) Dict[str, torch.Tensor]

Compute the loss. Must be overridden by every subclass. The returned dict must contain a “total” key holding the total loss, which is typically a weighted sum of the other losses. The total loss is backpropagated, whereas the other losses are logged to TensorBoard.

Parameters:
  • pred (ncalab.prediction.Prediction) – Model prediction.

  • label (torch.Tensor) – Ground truth, BCWH.

Returns:

Dictionary of identifiers mapped to computed losses.
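The “total” convention can be sketched independently of any model. loss_dict and the weights mapping are hypothetical; in an actual subclass the values would be torch tensors computed from pred and label, and the same sum keeps them differentiable.

```python
from typing import Dict


def loss_dict(components: Dict[str, float], weights: Dict[str, float]) -> Dict[str, float]:
    """Hypothetical helper: return individual losses plus their weighted sum under "total"."""
    if "total" in components:
        raise ValueError('"total" is reserved for the combined loss')
    result = dict(components)
    # Unweighted components default to a weight of 1.0.
    result["total"] = sum(weights.get(name, 1.0) * value for name, value in components.items())
    return result


losses = loss_dict({"dice": 0.5, "bce": 0.25}, weights={"dice": 1.0, "bce": 2.0})
print(losses)  # {'dice': 0.5, 'bce': 0.25, 'total': 1.0}
```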

finetune(freeze_head: bool = False)

Prepare the model for fine-tuning by freezing everything except the final layer and setting it to training mode.

Parameters:

freeze_head (bool) – Whether to also freeze the head module. Defaults to False.

predict(image: torch.Tensor, steps: int | Tuple[int, int] | None = None) ncalab.prediction.Prediction

Make an NCA prediction, performing multiple forward passes to yield a final result.

Parameters:
  • image (torch.Tensor) – Input image, BCWH.

  • steps (int | Tuple[int, int] | None) – Number of time steps.

Returns:

Prediction object.

Return type:

Prediction
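steps accepts an int, a tuple of two ints, or None. A plausible resolution rule is sketched below, assuming (not confirmed by this page) that None falls back to inference_timesteps and that a tuple denotes an inclusive range from which the step count is sampled. resolve_steps is hypothetical.

```python
import random
from typing import Optional, Tuple, Union

Steps = Union[int, Tuple[int, int]]


def resolve_steps(steps: Optional[Steps], default: Steps = 100,
                  rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper: turn a steps argument into a concrete step count."""
    if steps is None:
        steps = default  # assumption: fall back to inference_timesteps
    if isinstance(steps, tuple):
        low, high = steps
        rng = rng or random.Random()
        return rng.randint(low, high)  # inclusive on both ends
    return steps


print(resolve_steps(None), resolve_steps(64))  # 100 64
```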

record(image: torch.Tensor, steps: int | Tuple[int, int] | None = None) List[ncalab.prediction.Prediction]

Record predictions for all time steps and return the resulting sequence of predictions.

Parameters:
  • image (torch.Tensor) – Input image, BCWH.

  • steps (int | Tuple[int, int] | None) – Number of time steps.

Returns:

List of Prediction objects.

Return type:

List[Prediction]
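The difference between predict (final state only) and record (one result per time step) can be sketched with a generic update loop. predict_final and record_all are hypothetical; step_fn stands in for a single NCA step. If step_fn updated its state in place, each recorded state would need to be copied.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def predict_final(x: T, step_fn: Callable[[T, int], T], steps: int) -> T:
    """Hypothetical: apply step_fn repeatedly and return only the final state."""
    for step in range(steps):
        x = step_fn(x, step)
    return x


def record_all(x: T, step_fn: Callable[[T, int], T], steps: int) -> List[T]:
    """Hypothetical: apply step_fn repeatedly and keep the state after every step."""
    trajectory: List[T] = []
    for step in range(steps):
        x = step_fn(x, step)
        trajectory.append(x)
    return trajectory


print(record_all(0, lambda x, s: x + s, 4))  # [0, 1, 3, 6]
```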

validate(dataloader: torch.utils.data.DataLoader, steps: int | None = None) Tuple[Dict[str, float], List[ncalab.prediction.Prediction]]

Make predictions on images from the validation set and return metrics computed with respect to their labels.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – Dataloader for validation images.

  • steps (int | None) – Number of inference steps.

Returns:

Dictionary mapping metric names to values, and the list of predictions.

Return type:

Tuple[Dict[str, float], List[Prediction]]

_to_dict() Dict[str, Any]
to_dict() Dict[str, Any]
classmethod from_dict(d: Dict[str, Any])
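to_dict and from_dict suggest a round trip between a model and a plain dictionary. The keys ncalab uses are not documented here, so the sketch below shows only the pattern, on a hypothetical NCAConfig dataclass holding a few constructor arguments.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class NCAConfig:  # hypothetical, for illustration only
    num_image_channels: int
    num_hidden_channels: int
    num_output_channels: int
    fire_rate: float = 0.5

    def to_dict(self) -> Dict[str, Any]:
        # Plain, serializable mapping of constructor arguments.
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "NCAConfig":
        # Rebuild an instance from the mapping produced by to_dict().
        return cls(**d)


cfg = NCAConfig(num_image_channels=3, num_hidden_channels=16, num_output_channels=1)
restored = NCAConfig.from_dict(cfg.to_dict())
```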
num_trainable_parameters() int

Returns the number of trainable model parameters.

Returns:

Number of trainable parameters.

Return type:

int

save(path: str | os.PathLike)
static load(model: AbstractNCAModel, path: str | os.PathLike) AbstractNCAModel
post_prediction(prediction: ncalab.prediction.Prediction) ncalab.prediction.Prediction