ncalab.models.basicNCA.abstractNCA
==================================

.. py:module:: ncalab.models.basicNCA.abstractNCA


Classes
-------

.. autoapisummary::

   ncalab.models.basicNCA.abstractNCA.AbstractNCAModel


Module Contents
---------------

.. py:class:: AbstractNCAModel(device: torch.device, num_image_channels: int, num_hidden_channels: int, num_output_channels: int, plot_function: Optional[ncalab.visualization.Visual] = None, validation_metric: Optional[str] = None, fire_rate: float = 0.5, hidden_size: int = 128, use_alive_mask: bool = False, immutable_image_channels: bool = True, num_learned_filters: int = 0, filter_padding: Literal['zero', 'reflect', 'replicate', 'circular'] = 'reflect', use_laplace: bool = False, kernel_size: int = 3, pad_noise: bool = False, use_temporal_encoding: bool = False, rule_type: type[ncalab.models.basicNCA.abstractNCArule.AbstractNCARule] = MLPNCARule, rule_args=None, training_timesteps: int | Tuple[int, int] = 100, inference_timesteps: int | Tuple[int, int] = 100)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`abc.ABC`


   Abstract base class for NCA models.

   An NCA model is a composition of an NCA backbone model (called "rule") and an optional head module for downstream tasks.

   :param device: PyTorch device descriptor.
   :param num_image_channels: Number of channels reserved for the input image.
   :param num_hidden_channels: Number of hidden channels (communication channels).
   :param num_output_channels: Number of output channels.
   :param validation_metric:
   :param fire_rate: Fire rate (per-cell update probability) for the stochastic cell update. Defaults to 0.5.
   :param hidden_size: Number of neurons in the hidden layer. Defaults to 128.
   :param use_alive_mask: Whether to use alive masking (channel 3) during training. Defaults to False.
   :param immutable_image_channels: Whether the image channels are kept fixed during inference, which is the case for most segmentation and classification problems. Defaults to True.
   :param num_learned_filters: Number of learned filters. If zero, two fixed Sobel filters are used instead. Defaults to 0.
   :param filter_padding: Padding mode for the perception filters. May affect how much the model relies on spatial cues. Defaults to ``'reflect'``.
   :param use_laplace: Whether to use a Laplace filter (only if ``num_learned_filters == 0``).
   :param kernel_size: Filter kernel size (only for learned filters).
   :param pad_noise: Whether to pad the input image tensor with noise in the hidden and output channels.
   :param use_temporal_encoding: Whether to include a temporal encoding of the current time step in the cell update.
   :param rule_type: Class of the NCA rule (backbone) to instantiate. Defaults to ``MLPNCARule``.
   :param rule_args:
   :param training_timesteps: Number of time steps during training, given as an int or a tuple of two ints. Defaults to 100.
   :param inference_timesteps: Number of time steps during inference, given as an int or a tuple of two ints. Defaults to 100.


   .. py:attribute:: device


   .. py:attribute:: num_image_channels


   .. py:attribute:: num_hidden_channels


   .. py:attribute:: num_output_channels


   .. py:attribute:: num_channels


   .. py:attribute:: fire_rate
      :value: 0.5


   .. py:attribute:: hidden_size
      :value: 128


   .. py:attribute:: use_alive_mask
      :value: False


   .. py:attribute:: immutable_image_channels
      :value: True


   .. py:attribute:: num_learned_filters
      :value: 0


   .. py:attribute:: use_laplace
      :value: False


   .. py:attribute:: kernel_size
      :value: 3


   .. py:attribute:: filter_padding
      :value: 'reflect'


   .. py:attribute:: pad_noise
      :value: False


   .. py:attribute:: use_temporal_encoding
      :value: False


   .. py:attribute:: plot_function
      :value: None


   .. py:attribute:: validation_metric
      :value: None


   .. py:attribute:: training_timesteps
      :value: 100


   .. py:attribute:: inference_timesteps
      :value: 100


   .. py:attribute:: perception


   .. py:attribute:: input_vector_size


   .. py:attribute:: rule_type


   .. py:attribute:: rule_args
      :value: None


   .. py:attribute:: rule


   .. py:attribute:: head
      :type: ncalab.models.basicNCA.abstractNCAhead.AbstractNCAHead | None
      :value: None


   .. py:attribute:: metrics
      :type: Dict[str, torchmetrics.Metric]


   .. py:method:: _define_rule() -> ncalab.models.basicNCA.abstractNCArule.AbstractNCARule


   .. py:method:: prepare_input(x: torch.Tensor) -> torch.Tensor

      Preprocess the input. Intended to be overridden by subclasses that need preprocessing.

      :param x: Input tensor to preprocess.
      :type x: torch.Tensor
      :returns: Processed tensor.
   .. py:method:: _alive(x)


   .. py:method:: _update(x: torch.Tensor, step: int) -> torch.Tensor

      Compute the residual cell update.

      :param x: Input tensor, BCWH.
      :type x: torch.Tensor
      :param step: Current time step, required for computing the temporal encoding.
      :type step: int
      :returns: Residual cell update, BCWH.


   .. py:method:: _forward_step(x: torch.Tensor, step: int)


   .. py:method:: forward(x: torch.Tensor, steps: int = 1) -> ncalab.prediction.Prediction

      :param x: Input image, padded along the channel dimension, BCWH.
      :type x: torch.Tensor
      :param steps: Number of time steps in the forward pass.
      :type steps: int
      :returns: Prediction object.
      :rtype: Prediction


   .. py:method:: _post_forward_step(x: torch.Tensor) -> torch.Tensor


   .. py:method:: loss(pred: ncalab.prediction.Prediction, label: torch.Tensor) -> Dict[str, torch.Tensor]

      Compute the loss. Must be overridden by subclasses.

      The returned dict must contain a ``"total"`` key holding the total loss, which is typically a weighted sum of the other losses. Only the total loss is backpropagated; the other losses are logged to TensorBoard.

      :param pred: Prediction object.
      :type pred: Prediction
      :param label: Ground truth, BCWH.
      :type label: torch.Tensor
      :returns: Dictionary mapping identifiers to computed losses.


   .. py:method:: finetune(freeze_head: bool = False)

      Prepare the model for fine-tuning by freezing everything except the final layer, and set it to ``train`` mode.

      :param freeze_head:


   .. py:method:: predict(image: torch.Tensor, steps: Optional[int | Tuple[int, int]] = None) -> ncalab.prediction.Prediction

      Make an NCA prediction, performing multiple forward passes to yield a final result.

      :param image: Input image, BCWH.
      :type image: torch.Tensor
      :param steps: Number of time steps.
      :type steps: Optional[int | Tuple[int, int]]
      :returns: Prediction object.
      :rtype: Prediction


   .. py:method:: record(image: torch.Tensor, steps: Optional[int | Tuple[int, int]] = None) -> List[ncalab.prediction.Prediction]

      Record predictions for all time steps and return the resulting sequence of predictions.
      :param image: Input image, BCWH.
      :type image: torch.Tensor
      :param steps: Number of time steps.
      :type steps: Optional[int | Tuple[int, int]]
      :returns: List of Prediction objects.
      :rtype: List[Prediction]


   .. py:method:: validate(dataloader: torch.utils.data.DataLoader, steps: Optional[int] = None) -> Tuple[Dict[str, float], List[ncalab.prediction.Prediction]]

      Make a prediction on an image from the validation set and return metrics computed against its label.

      :param dataloader: Dataloader for validation images.
      :type dataloader: torch.utils.data.DataLoader
      :param steps: Number of inference steps.
      :type steps: Optional[int]
      :returns: Validation metrics keyed by name, and the predictions (BCWH).
      :rtype: Tuple[Dict[str, float], List[Prediction]]


   .. py:method:: _to_dict() -> Dict[str, Any]


   .. py:method:: to_dict() -> Dict[str, Any]


   .. py:method:: from_dict(d: Dict[str, Any])
      :classmethod:


   .. py:method:: num_trainable_parameters() -> int

      Return the number of trainable model parameters.

      :return: Number of trainable parameters.
      :rtype: int


   .. py:method:: save(path: str | os.PathLike)


   .. py:method:: load(model: AbstractNCAModel, path: str | os.PathLike) -> AbstractNCAModel
      :staticmethod:


   .. py:method:: post_prediction(prediction: ncalab.prediction.Prediction) -> ncalab.prediction.Prediction
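
With ``num_learned_filters`` set to zero, :py:class:`AbstractNCAModel` perceives its neighbourhood through two fixed Sobel filters. The sketch below illustrates that computation for a single channel in plain Python. It assumes 3x3 Sobel kernels, an identity term in the perception vector, and ``'reflect'`` padding that mirrors indices at the border without repeating the edge cell. The helper names (``reflect_index``, ``filter2d``, ``sobel_perceive``) are hypothetical and not part of ncalab.

```python
from typing import List

Grid = List[List[float]]

# 3x3 Sobel kernels for horizontal (x) and vertical (y) gradients.
SOBEL_X = [[-1.0, 0.0, 1.0],
           [-2.0, 0.0, 2.0],
           [-1.0, 0.0, 1.0]]
SOBEL_Y = [[-1.0, -2.0, -1.0],
           [0.0, 0.0, 0.0],
           [1.0, 2.0, 1.0]]


def reflect_index(i: int, n: int) -> int:
    """Map an out-of-range index back into [0, n) by mirroring at the
    border without repeating the edge element ('reflect' padding).
    Assumes n >= 2 and that i is at most one step out of range."""
    if i < 0:
        return -i
    if i >= n:
        return 2 * (n - 1) - i
    return i


def filter2d(img: Grid, kernel: Grid) -> Grid:
    """Cross-correlate a 2D grid with a 3x3 kernel using reflect padding."""
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            acc = 0.0
            for dy in range(-1, 2):
                for dx in range(-1, 2):
                    yy = reflect_index(y + dy, h)
                    xx = reflect_index(x + dx, w)
                    acc += kernel[dy + 1][dx + 1] * img[yy][xx]
            out[y][x] = acc
    return out


def sobel_perceive(img: Grid) -> List[Grid]:
    """Hypothetical perception vector per cell: identity, x and y Sobel responses."""
    return [img, filter2d(img, SOBEL_X), filter2d(img, SOBEL_Y)]
```

A real implementation would more likely run these kernels as a depthwise convolution over all channels at once; the nested loops only make the arithmetic explicit.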
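
The ``fire_rate`` and ``immutable_image_channels`` parameters of :py:class:`AbstractNCAModel` jointly decide which values change in a step. The following is a minimal sketch of such a step, assuming a per-cell Bernoulli draw with probability ``fire_rate`` and a state stored as one channel vector per cell (image channels first). The function ``stochastic_step`` and its arguments are hypothetical, not ncalab's implementation.

```python
import random
from typing import List, Optional

Cell = List[float]  # channel values: image channels first, then hidden/output


def stochastic_step(
    cells: List[Cell],
    deltas: List[Cell],
    fire_rate: float,
    num_image_channels: int,
    immutable_image_channels: bool = True,
    rng: Optional[random.Random] = None,
) -> List[Cell]:
    """Apply residual updates to a random subset of cells.

    Each cell fires (receives its update) with probability ``fire_rate``.
    If ``immutable_image_channels`` is set, the first ``num_image_channels``
    channels are never changed, so the input image is preserved.
    """
    rng = rng or random.Random()
    new_cells = []
    for cell, delta in zip(cells, deltas):
        fires = rng.random() < fire_rate  # one Bernoulli draw per cell
        updated = []
        for c, (value, d) in enumerate(zip(cell, delta)):
            frozen = immutable_image_channels and c < num_image_channels
            updated.append(value + d if fires and not frozen else value)
        new_cells.append(updated)
    return new_cells
```

Drawing the mask per cell rather than per channel means a cell either updates all of its mutable channels or none of them, which is the usual asynchronous-update interpretation of a fire rate.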
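
``use_alive_mask`` enables alive masking based on channel 3. The sketch below assumes the convention popularised by the Growing Neural Cellular Automata work: a cell counts as alive if the maximum of the alpha channel over its 3x3 neighbourhood exceeds 0.1, and dead cells are zeroed. The threshold, the neighbourhood size and the helper names ``alive_mask`` and ``apply_alive_mask`` are assumptions; ncalab's actual rule may differ.

```python
from typing import List

Grid = List[List[float]]


def alive_mask(alpha: Grid, threshold: float = 0.1) -> List[List[bool]]:
    """A cell is alive if any cell in its 3x3 neighbourhood (itself
    included) has an alpha value above ``threshold``. Out-of-range
    neighbours are ignored, like max pooling with -inf padding."""
    h, w = len(alpha), len(alpha[0])
    mask = []
    for y in range(h):
        row = []
        for x in range(w):
            neighbourhood = [
                alpha[yy][xx]
                for yy in range(max(0, y - 1), min(h, y + 2))
                for xx in range(max(0, x - 1), min(w, x + 2))
            ]
            row.append(max(neighbourhood) > threshold)
        mask.append(row)
    return mask


def apply_alive_mask(state: List[Grid], alive_channel: int = 3) -> List[Grid]:
    """Zero every channel of cells that are not alive (hypothetical helper)."""
    mask = alive_mask(state[alive_channel])
    return [
        [[v if mask[y][x] else 0.0 for x, v in enumerate(row)]
         for y, row in enumerate(channel)]
        for channel in state
    ]
```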
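
``training_timesteps``, ``inference_timesteps`` and the ``steps`` argument of :py:meth:`AbstractNCAModel.predict` accept either an ``int`` or a tuple of two ints. A common convention, and only an assumption here, is that a tuple denotes an inclusive range from which the number of steps is drawn at random, and that ``steps=None`` falls back to the configured default. The helper ``resolve_steps`` is hypothetical.

```python
import random
from typing import Optional, Tuple, Union

Steps = Union[int, Tuple[int, int]]


def resolve_steps(
    steps: Optional[Steps], default: Steps, rng: Optional[random.Random] = None
) -> int:
    """Turn a step specification into a concrete number of time steps.

    ``None`` falls back to ``default`` (assumed: e.g. ``inference_timesteps``);
    an ``int`` is used as-is; a ``(low, high)`` tuple is sampled uniformly
    from the inclusive range [low, high].
    """
    rng = rng or random.Random()
    spec = default if steps is None else steps
    if isinstance(spec, int):
        return spec
    low, high = spec
    if low > high:
        raise ValueError(f"invalid step range: {spec}")
    return rng.randint(low, high)
```

Randomising the number of steps during training is a typical way to keep an NCA from depending on one exact step count.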
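
:py:meth:`AbstractNCAModel.loss` must return a dict with a ``"total"`` key holding a weighted sum of the individual losses. A minimal sketch of that contract, using plain floats and a hypothetical helper ``combine_losses`` with illustrative loss names; with ``torch.Tensor`` values the same arithmetic applies, and only ``"total"`` would be backpropagated.

```python
from typing import Dict, Mapping


def combine_losses(
    losses: Mapping[str, float], weights: Mapping[str, float]
) -> Dict[str, float]:
    """Return the individual losses plus a ``"total"`` entry.

    ``"total"`` is the weighted sum of the individual losses; names missing
    from ``weights`` get weight 1.0. The other entries are kept for logging.
    """
    if "total" in losses:
        raise ValueError('"total" is reserved for the combined loss')
    result = dict(losses)
    result["total"] = sum(weights.get(name, 1.0) * value for name, value in losses.items())
    return result
```

Keeping the individual terms in the returned dict lets a training loop log them separately while still calling ``backward()`` on a single scalar.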