anomalib.models.components.base.anomaly_module

Base Anomaly Module for Training Task.

Module Contents

Classes

AnomalyModule

AnomalyModule to train, validate, predict and test images.

Attributes

anomalib.models.components.base.anomaly_module.logger
class anomalib.models.components.base.anomaly_module.AnomalyModule

Bases: pytorch_lightning.LightningModule, abc.ABC

AnomalyModule to train, validate, predict and test images.

Acts as a base class for all the Anomaly Modules in the library.

forward(batch)

Forward-pass the input tensor through the module.

Parameters

batch (Tensor) – Input Tensor

Returns

Output tensor from the model.

Return type

Tensor
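As a rough illustration, forward() in a module like this can simply delegate to a wrapped network. The sketch below assumes, hypothetically, that the network is stored in an attribute named model. SketchModule is an illustrative name, and plain Python callables stand in for torch.nn.Module so the example runs without PyTorch.

```python
from typing import Callable, List

Batch = List[float]


class SketchModule:
    """Hypothetical stand-in for an AnomalyModule subclass (no Lightning)."""

    def __init__(self, model: Callable[[Batch], Batch]) -> None:
        # Assumption: the wrapped network lives in `self.model`.
        self.model = model

    def forward(self, batch: Batch) -> Batch:
        # Forward-pass the input batch through the wrapped model.
        return self.model(batch)

    def __call__(self, batch: Batch) -> Batch:
        # Mirrors nn.Module, where calling the module dispatches to forward().
        return self.forward(batch)


# Usage: a toy "model" that doubles every value.
module = SketchModule(lambda xs: [2 * x for x in xs])
print(module([0.5, 1.0]))  # [1.0, 2.0]
```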

abstract validation_step(batch, batch_idx) → dict

Must be implemented by subclasses.
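Because the class derives from abc.ABC, declaring validation_step abstract is meant to force every concrete model to supply its own. The framework-free sketch below shows the standard abc behaviour: the base cannot be instantiated, while a subclass that implements the method can. BaseSketch, MeanScoreModel, and the pred_scores key are hypothetical names, and the mean-pixel scoring rule is only a placeholder.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class BaseSketch(ABC):
    """Hypothetical stand-in for the abstract interface of an anomaly module."""

    @abstractmethod
    def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        """Must be implemented by every concrete subclass."""


class MeanScoreModel(BaseSketch):
    """Toy subclass: scores each image by the mean of its pixel values."""

    def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        scores = [sum(image) / len(image) for image in batch]
        return {"pred_scores": scores}


# The abstract base cannot be instantiated on its own.
try:
    BaseSketch()
except TypeError as err:
    print("TypeError:", err)

print(MeanScoreModel().validation_step([[0.0, 1.0], [1.0, 1.0]], 0))
# {'pred_scores': [0.5, 1.0]}
```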

predict_step(batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) → Any

Step function called during predict().

By default, it calls forward(). Override it to add custom processing logic.

Parameters
  • batch (Tensor) – Current batch

  • batch_idx (int) – Index of current batch

  • _dataloader_idx (int) – Index of the current dataloader

Returns

Predicted output
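To illustrate the override pattern described above, the sketch below keeps the default behaviour (return forward(batch)) in a base class and adds thresholding in a subclass. ForwardOnly, ThresholdedPredictor, the fixed threshold, and the output keys are all hypothetical, and lists of floats stand in for tensors.

```python
from typing import Any, Dict, List, Optional


class ForwardOnly:
    """Hypothetical base whose predict_step just calls forward() (the default)."""

    def forward(self, batch: List[float]) -> List[float]:
        # Toy model: the anomaly score is the input value itself.
        return list(batch)

    def predict_step(
        self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None
    ) -> Any:
        return self.forward(batch)


class ThresholdedPredictor(ForwardOnly):
    """Overrides predict_step to add post-processing on top of forward()."""

    def __init__(self, threshold: float) -> None:
        # Assumption: a fixed image-level threshold chosen elsewhere.
        self.threshold = threshold

    def predict_step(
        self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None
    ) -> Dict[str, List[Any]]:
        scores = super().predict_step(batch, batch_idx, _dataloader_idx)
        # Scores at or above the threshold are labelled anomalous.
        labels = [score >= self.threshold for score in scores]
        return {"pred_scores": scores, "pred_labels": labels}


print(ThresholdedPredictor(threshold=0.5).predict_step([0.2, 0.7], batch_idx=0))
# {'pred_scores': [0.2, 0.7], 'pred_labels': [False, True]}
```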

test_step(batch, _)

Calls validation_step for anomaly map/score calculation.

Parameters
  • batch (Tensor) – Input batch

  • _ – Index of the batch.

Returns

Dictionary containing images, features, true labels and masks. These are required in validation_epoch_end for feature concatenation.
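A minimal sketch of this delegation, assuming a toy validation_step that scores each image by its maximum pixel value (a placeholder rule, not the library's). DelegatingSketch is a hypothetical name. The point is only that test_step returns whatever validation_step computes, so both stages share one scoring path.

```python
from typing import Any, Dict, List


class DelegatingSketch:
    """Hypothetical module showing test_step reusing validation_step."""

    def validation_step(self, batch: List[List[float]], batch_idx: int) -> Dict[str, Any]:
        # Toy scoring: the image-level score is the maximum pixel value.
        return {"pred_scores": [max(image) for image in batch]}

    def test_step(self, batch: List[List[float]], _: int) -> Dict[str, Any]:
        # Reuse the validation computation; the batch index is passed through unchanged.
        return self.validation_step(batch, _)


module = DelegatingSketch()
batch = [[0.1, 0.9], [0.3, 0.2]]
assert module.test_step(batch, 0) == module.validation_step(batch, 0)
print(module.test_step(batch, 0))  # {'pred_scores': [0.9, 0.3]}
```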

validation_step_end(val_step_outputs)

Called at the end of each validation step.

test_step_end(test_step_outputs)

Called at the end of each test step.

validation_epoch_end(outputs)

Compute threshold and performance metrics.

Parameters

outputs – Batch of outputs from the validation step
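Assuming the adaptive threshold is chosen as the score that maximizes F1 against the ground-truth labels, the computation can be sketched in plain Python by trying every observed score as a candidate threshold. f1_at and adaptive_threshold are hypothetical helpers, not library functions, and a real implementation would work on tensors rather than lists.

```python
from typing import Sequence, Tuple


def f1_at(threshold: float, scores: Sequence[float], labels: Sequence[int]) -> float:
    """F1 when every score >= threshold is predicted anomalous (label 1)."""
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
    fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 1)
    denom = 2 * tp + fp + fn
    # F1 = 2TP / (2TP + FP + FN); define it as 0 when there is nothing to score.
    return 2 * tp / denom if denom else 0.0


def adaptive_threshold(scores: Sequence[float], labels: Sequence[int]) -> Tuple[float, float]:
    """Return (threshold, f1) for the candidate score with the highest F1."""
    candidates = sorted(set(scores))
    # max() keeps the first maximum, so ties resolve to the lowest threshold.
    return max(((t, f1_at(t, scores, labels)) for t in candidates), key=lambda pair: pair[1])


scores = [0.1, 0.4, 0.35, 0.8, 0.9]
labels = [0, 0, 1, 1, 1]  # 1 = anomalous
threshold, f1 = adaptive_threshold(scores, labels)
print(threshold, round(f1, 3))  # 0.35 0.857
```

Scanning only the observed scores is enough here because F1 can only change when the threshold crosses one of them.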

test_epoch_end(outputs)

Compute and save anomaly scores of the test set.

Parameters

outputs – Batch of outputs from the test step (which delegates to validation_step)

_compute_adaptive_threshold(outputs)
_collect_outputs(image_metric, pixel_metric, outputs)
_post_process(outputs)

Compute labels based on model predictions.
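The sketch below illustrates one plausible post-processing rule under stated assumptions. When a batch has anomaly maps but no image-level scores, each image's score is taken as the peak of its map. Labels are then obtained by comparing scores with a threshold. post_process, the dictionary keys, and the threshold value are hypothetical, and nested lists stand in for tensors.

```python
from typing import Any, Dict


def post_process(outputs: Dict[str, Any], threshold: float) -> Dict[str, Any]:
    """Hypothetical post-processing: fill in image scores, then derive labels."""
    if "pred_scores" not in outputs and "anomaly_maps" in outputs:
        # Assumption: an image's score is the peak value of its anomaly map.
        outputs["pred_scores"] = [
            max(max(row) for row in amap) for amap in outputs["anomaly_maps"]
        ]
    # Scores at or above the threshold are labelled anomalous.
    outputs["pred_labels"] = [score >= threshold for score in outputs["pred_scores"]]
    return outputs


outputs = {
    "anomaly_maps": [
        [[0.1, 0.2], [0.3, 0.1]],  # image 0: peak 0.3
        [[0.0, 0.9], [0.4, 0.2]],  # image 1: peak 0.9
    ]
}
result = post_process(outputs, threshold=0.5)
print(result["pred_scores"], result["pred_labels"])  # [0.3, 0.9] [False, True]
```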

_outputs_to_cpu(output)
_log_metrics()

Log computed performance metrics.