anomalib.models.components.base

Base classes for all anomaly components.

Submodules

Package Contents

Classes

AnomalyModule

AnomalyModule to train, validate, predict and test images.

DynamicBufferModule

Torch module that allows loading variables from the state dict even in the case of shape mismatch.

class anomalib.models.components.base.AnomalyModule

Bases: pytorch_lightning.LightningModule, abc.ABC

AnomalyModule to train, validate, predict and test images.

Acts as a base class for all the Anomaly Modules in the library.

forward(batch)

Forward-pass the input tensor through the module.

Parameters

batch (Tensor) – Input tensor

Returns

Output tensor from the model.

Return type

Tensor

abstract validation_step(batch, batch_idx) → dict

To be implemented by subclasses.
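The base class follows the template-method pattern: it supplies the shared hooks and leaves `validation_step` abstract, so every concrete model must provide it. The stdlib-only sketch below mimics that contract with `abc`. `BaseModuleSketch` and `ScoreModel` are hypothetical stand-ins, not anomalib classes, and the sketch deliberately omits `LightningModule` so it runs without PyTorch Lightning.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class BaseModuleSketch(ABC):
    """Hypothetical stand-in for AnomalyModule: validation_step is abstract."""

    @abstractmethod
    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        """To be implemented by subclasses."""


class ScoreModel(BaseModuleSketch):
    """Hypothetical subclass that scores each image by its maximum value."""

    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        # Add one anomaly score per image to the batch dictionary and return it.
        batch["pred_scores"] = [max(image) for image in batch["image"]]
        return batch


# The abstract base cannot be instantiated directly.
try:
    BaseModuleSketch()
except TypeError:
    print("BaseModuleSketch is abstract")

out = ScoreModel().validation_step({"image": [[0.1, 0.9], [0.2, 0.3]]}, batch_idx=0)
print(out["pred_scores"])  # [0.9, 0.3]
```

Forgetting to implement `validation_step` in a subclass surfaces as a `TypeError` at construction time rather than as a failure midway through a training run.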

predict_step(batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) → Any

Step function called during predict().

By default, it calls forward(). Override to add any processing logic.

Parameters
  • batch (Tensor) – Current batch

  • batch_idx (int) – Index of current batch

  • _dataloader_idx (int) – Index of the current dataloader

Returns

Predicted output
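A plain-Python sketch of the default behaviour described above, where `predict_step` simply returns `forward(batch)`, and of a subclass that overrides it to add post-processing. The class names, the `forward` logic and the `0.5` threshold are hypothetical, and no Lightning trainer is involved.

```python
from typing import Any, List, Optional


class PredictSketch:
    """Hypothetical module whose default predict_step delegates to forward."""

    def forward(self, batch: List[float]) -> List[float]:
        # Hypothetical model: the score is twice the input value.
        return [2 * x for x in batch]

    def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any:
        # Default behaviour: return the forward pass unchanged.
        return self.forward(batch)


class LabelledPredict(PredictSketch):
    """Override that post-processes forward() output into boolean labels."""

    def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any:
        scores = super().predict_step(batch, batch_idx, _dataloader_idx)
        return {"pred_scores": scores, "pred_labels": [s >= 0.5 for s in scores]}


print(PredictSketch().predict_step([0.1, 0.4], batch_idx=0))  # [0.2, 0.8]
print(LabelledPredict().predict_step([0.1, 0.4], batch_idx=0)["pred_labels"])  # [False, True]
```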

test_step(batch, _)

Calls validation_step for anomaly map/score calculation.

Parameters
  • batch (Tensor) – Input batch

  • _ – Index of the batch.

Returns

Dictionary containing images, features, true labels and masks. These are required in validation_epoch_end for feature concatenation.
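The delegation amounts to a one-line override: `test_step` reuses `validation_step`, so the validation and test loops produce dictionaries with the same keys. The sketch below is hypothetical and stdlib-only; its key names and mean-based scoring are illustrative, not anomalib's exact output schema.

```python
from typing import Any, Dict


class DelegatingSketch:
    """Hypothetical module in which test_step reuses validation_step."""

    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        # Hypothetical scoring: one anomaly score per image (its mean value).
        batch["pred_scores"] = [sum(img) / len(img) for img in batch["image"]]
        return batch

    def test_step(self, batch: Dict[str, Any], _: int) -> Dict[str, Any]:
        # Same anomaly map/score computation as during validation.
        return self.validation_step(batch, _)


module = DelegatingSketch()
batch = {"image": [[0.0, 1.0]], "label": [1]}
# Shallow copies so each call mutates its own dictionary.
print(module.test_step(dict(batch), 0) == module.validation_step(dict(batch), 0))  # True
```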

validation_step_end(val_step_outputs)

Called at the end of each validation step.

test_step_end(test_step_outputs)

Called at the end of each test step.

validation_epoch_end(outputs)

Compute threshold and performance metrics.

Parameters

outputs – List of per-step outputs collected over the validation epoch
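This page does not say how the threshold is derived. One common choice for an adaptive threshold, assumed here rather than confirmed by this documentation, is the score that maximizes F1 on the validation set. A dependency-free sketch using the hypothetical helper `f1_optimal_threshold`:

```python
from typing import Sequence


def f1_optimal_threshold(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Return the observed score that, used as a threshold, gives the highest F1.

    A sample is predicted anomalous when its score >= threshold.
    Ties keep the lowest such threshold.
    """
    best_threshold, best_f1 = max(scores), -1.0
    for threshold in sorted(set(scores)):
        tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
        fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 1)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0  # F1 = 2TP / (2TP + FP + FN)
        if f1 > best_f1:
            best_threshold, best_f1 = threshold, f1
    return best_threshold


scores = [0.1, 0.2, 0.35, 0.8, 0.9]
labels = [0, 0, 1, 1, 1]
print(f1_optimal_threshold(scores, labels))  # 0.35
```

Only observed scores are tried as candidates, which is enough because F1 can only change at those points.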

test_epoch_end(outputs)

Compute and save anomaly scores of the test set.

Parameters

outputs – List of per-step outputs collected over the test epoch

_compute_adaptive_threshold(outputs)

static _collect_outputs(image_metric, pixel_metric, outputs)

static _post_process(outputs)

Compute labels based on model predictions.
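A sketch of what "compute labels" can mean: image-level scores are thresholded into boolean labels and pixel-level anomaly maps into binary masks, with a score derived from the map maximum when none is given. The output keys (`pred_scores`, `pred_labels`, `anomaly_maps`, `pred_masks`), the max-of-map fallback and the explicit threshold arguments are assumptions for illustration; the documented static method takes only `outputs`.

```python
from typing import Any, Dict


def post_process_sketch(outputs: Dict[str, Any], image_threshold: float, pixel_threshold: float) -> None:
    """Hypothetical helper: add boolean labels/masks to `outputs` in place."""
    # Assumed fallback: use the maximum of each anomaly map as the image score.
    if "pred_scores" not in outputs and "anomaly_maps" in outputs:
        outputs["pred_scores"] = [max(max(row) for row in amap) for amap in outputs["anomaly_maps"]]
    outputs["pred_labels"] = [score >= image_threshold for score in outputs["pred_scores"]]
    if "anomaly_maps" in outputs:
        outputs["pred_masks"] = [
            [[value >= pixel_threshold for value in row] for row in amap]
            for amap in outputs["anomaly_maps"]
        ]


out = {"anomaly_maps": [[[0.1, 0.7], [0.2, 0.3]]]}  # one 2x2 anomaly map
post_process_sketch(out, image_threshold=0.5, pixel_threshold=0.5)
print(out["pred_scores"], out["pred_labels"])  # [0.7] [True]
print(out["pred_masks"])  # [[[False, True], [False, False]]]
```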

static _outputs_to_cpu(output)

_log_metrics()

Log computed performance metrics.

_load_normalization_class(state_dict: OrderedDict[str, torch.Tensor])

Assigns the normalization method to use.

load_state_dict(state_dict: OrderedDict[str, torch.Tensor], strict: bool = True)

Load state dict from checkpoint.

Ensures that the normalization and thresholding attributes are properly set up before the model is loaded.
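A minimal PyTorch sketch of why the order matters: a strict `load_state_dict` rejects keys that belong to submodules the freshly constructed model does not have yet. The override therefore creates the normalization submodule from the checkpoint keys before delegating to `nn.Module.load_state_dict`. `MinMaxStats`, `NormalizedModel` and the `normalization_metrics.` key prefix are hypothetical; the sketch only illustrates the idea behind `_load_normalization_class` and is not anomalib's code.

```python
from typing import Mapping

import torch
from torch import nn


class MinMaxStats(nn.Module):
    """Hypothetical container for min-max normalization statistics."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("min", torch.tensor(float("inf")))
        self.register_buffer("max", torch.tensor(float("-inf")))


class NormalizedModel(nn.Module):
    """Hypothetical model whose normalization submodule is created on demand."""

    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(4, 1)

    def load_state_dict(self, state_dict: Mapping[str, torch.Tensor], strict: bool = True):
        # Create the normalization attribute first so its keys are expected
        # (not reported as unexpected) when strict=True.
        needs_stats = any(key.startswith("normalization_metrics.") for key in state_dict)
        if needs_stats and not hasattr(self, "normalization_metrics"):
            self.normalization_metrics = MinMaxStats()
        return super().load_state_dict(state_dict, strict=strict)


trained = NormalizedModel()
trained.normalization_metrics = MinMaxStats()
trained.normalization_metrics.min.fill_(0.1)
trained.normalization_metrics.max.fill_(0.9)

fresh = NormalizedModel()
fresh.load_state_dict(trained.state_dict())  # a plain strict load would reject the extra keys
print(fresh.normalization_metrics.max)
```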

class anomalib.models.components.base.DynamicBufferModule

Bases: abc.ABC, torch.nn.Module

Torch module that allows loading variables from the state dict even in the case of shape mismatch.

get_tensor_attribute(attribute_name: str) → torch.Tensor

Get the tensor attribute with the given name.

Parameters

attribute_name (str) – Name of the tensor

Raises

ValueError – attribute_name is not a torch Tensor

Returns

Tensor attribute

Return type

Tensor
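A short sketch of the documented contract: look the attribute up by name and raise `ValueError` if it is not a tensor. `TensorAttrSketch`, its `memory_bank` buffer and its `description` attribute are hypothetical; the method body is an illustration, not anomalib's source.

```python
import torch
from torch import nn


class TensorAttrSketch(nn.Module):
    """Hypothetical module with a tensor buffer and a non-tensor attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("memory_bank", torch.zeros(2, 3))
        self.description = "not a tensor"

    def get_tensor_attribute(self, attribute_name: str) -> torch.Tensor:
        # getattr also finds registered buffers via nn.Module.__getattr__.
        attribute = getattr(self, attribute_name)
        if isinstance(attribute, torch.Tensor):
            return attribute
        raise ValueError(f"Attribute with name '{attribute_name}' is not a torch Tensor")


module = TensorAttrSketch()
print(module.get_tensor_attribute("memory_bank").shape)  # torch.Size([2, 3])
try:
    module.get_tensor_attribute("description")
except ValueError as err:
    print(err)
```

In this sketch a name that does not exist at all raises `AttributeError` from `getattr`, not `ValueError`.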

_load_from_state_dict(state_dict: dict, prefix: str, *args)

Resizes the local buffers to match those stored in the state dict.

Overrides the method from the parent class.

Parameters
  • state_dict (dict) – State dictionary containing weights

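  • prefix (str) – Prefix of this module's keys in the state dict.

  • *args – Remaining positional arguments forwarded to the parent method (local_metadata, strict, missing_keys, unexpected_keys, error_msgs).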
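The override can be sketched in a few lines of PyTorch. Before delegating to `nn.Module._load_from_state_dict`, each local buffer whose shape differs from the checkpoint entry is resized in place, so the parent's shape check passes and the values are copied. `ResizableBank` and its `memory_bank` buffer are hypothetical; this is a sketch of the technique rather than anomalib's exact implementation.

```python
import torch
from torch import nn


class ResizableBank(nn.Module):
    """Hypothetical module whose buffer shape is only known after fitting."""

    def __init__(self) -> None:
        super().__init__()
        # Placeholder buffer; its shape will differ from a fitted checkpoint.
        self.register_buffer("memory_bank", torch.empty(0))

    def _load_from_state_dict(self, state_dict, prefix, *args):
        # Resize each local buffer in place to the checkpoint's shape so the
        # parent implementation copies values instead of reporting a mismatch.
        for name, buffer in self._buffers.items():
            key = prefix + name
            if buffer is not None and key in state_dict and buffer.shape != state_dict[key].shape:
                buffer.resize_(state_dict[key].shape)
        super()._load_from_state_dict(state_dict, prefix, *args)


fitted = ResizableBank()
fitted.memory_bank = torch.randn(5, 8)  # e.g. features collected during training

restored = ResizableBank()
restored.load_state_dict(fitted.state_dict())
print(restored.memory_bank.shape)  # torch.Size([5, 8])
```

A plain `nn.Module` with the same placeholder buffer would instead raise a size-mismatch `RuntimeError` on load, which is the situation this base class is meant to handle for models whose buffers are filled during training.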