GANomaly
GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training.
GANomaly is an anomaly detection model that uses a conditional GAN architecture to learn the normal data distribution. The model consists of a generator network that learns to reconstruct normal images, and a discriminator that helps ensure the reconstructions are realistic.
Example
>>> from anomalib.data import MVTec
>>> from anomalib.models import Ganomaly
>>> from anomalib.engine import Engine
>>> datamodule = MVTec()
>>> model = Ganomaly()
>>> engine = Engine()
>>> engine.fit(model, datamodule=datamodule)
>>> predictions = engine.predict(model, datamodule=datamodule)
- Paper:
Title: GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training
URL: https://arxiv.org/abs/1805.06725
See also
anomalib.models.image.ganomaly.torch_model.GanomalyModel
:PyTorch implementation of the GANomaly model architecture.
anomalib.models.image.ganomaly.loss.GeneratorLoss
:Loss function for the generator network.
anomalib.models.image.ganomaly.loss.DiscriminatorLoss
:Loss function for the discriminator network.
- class anomalib.models.image.ganomaly.lightning_model.Ganomaly(batch_size=32, n_features=64, latent_vec_size=100, extra_layers=0, add_final_conv_layer=True, wadv=1, wcon=50, wenc=1, lr=0.0002, beta1=0.5, beta2=0.999, pre_processor=True, post_processor=True, evaluator=True, visualizer=True)
Bases:
AnomalibModule
PyTorch Lightning module for the GANomaly algorithm.
The GANomaly model consists of a generator and discriminator network. The generator learns to reconstruct normal images while the discriminator helps ensure the reconstructions are realistic. Anomalies are detected by measuring the reconstruction error and latent space differences.
- Parameters:
batch_size (int) – Number of samples in each batch. Defaults to 32.
n_features (int) – Number of feature channels in CNN layers. Defaults to 64.
latent_vec_size (int) – Dimension of the latent space vectors. Defaults to 100.
extra_layers (int, optional) – Number of extra layers in encoder/decoder. Defaults to 0.
add_final_conv_layer (bool, optional) – Add convolution layer at the end. Defaults to True.
wadv (int, optional) – Weight for adversarial loss component. Defaults to 1.
wcon (int, optional) – Weight for image reconstruction loss component. Defaults to 50.
wenc (int, optional) – Weight for latent vector encoding loss component. Defaults to 1.
lr (float, optional) – Learning rate for optimizers. Defaults to 0.0002.
beta1 (float, optional) – Beta1 parameter for Adam optimizers. Defaults to 0.5.
beta2 (float, optional) – Beta2 parameter for Adam optimizers. Defaults to 0.999.
pre_processor (PreProcessor | bool, optional) – Pre-processor to transform inputs before passing to model. Defaults to True.
post_processor (PostProcessor | bool, optional) – Post-processor to generate predictions from model outputs. Defaults to True.
evaluator (Evaluator | bool, optional) – Evaluator to compute metrics. Defaults to True.
visualizer (Visualizer | bool, optional) – Visualizer to display results. Defaults to True.
Example
>>> from anomalib.models import Ganomaly
>>> model = Ganomaly(
...     batch_size=32,
...     n_features=64,
...     latent_vec_size=100,
...     wadv=1,
...     wcon=50,
...     wenc=1,
... )
See also
anomalib.models.image.ganomaly.torch_model.GanomalyModel
:PyTorch implementation of the GANomaly model architecture.
anomalib.models.image.ganomaly.loss.GeneratorLoss
:Loss function for the generator network.
anomalib.models.image.ganomaly.loss.DiscriminatorLoss
:Loss function for the discriminator network.
- configure_optimizers()
Configure the Adam optimizers for the discriminator and generator networks.
- Returns:
One Adam optimizer per network
- Return type:
Optimizer
- property learning_type: LearningType
Return the learning type of the model.
- Returns:
Learning type of the model.
- Return type:
LearningType
- on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)
Normalize outputs based on min/max values.
- Return type:
- on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=0)
Normalize outputs based on min/max values.
- Return type:
- test_step(batch, batch_idx, *args, **kwargs)
Update min and max scores from the current step.
- Return type:
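The min/max bookkeeping described for test_step and the batch-end hooks can be illustrated with a small, framework-free sketch. The class name RunningMinMax and its methods are hypothetical and are not part of anomalib; the sketch only assumes that the smallest and largest scores seen so far are tracked and then used to rescale scores into the range [0, 1].

```python
class RunningMinMax:
    """Hypothetical helper: tracks the score range seen so far (not anomalib API)."""

    def __init__(self):
        self.min_score = float("inf")
        self.max_score = float("-inf")

    def update(self, scores):
        # Widen the tracked range with the scores of the current step.
        self.min_score = min(self.min_score, min(scores))
        self.max_score = max(self.max_score, max(scores))

    def normalize(self, scores):
        # Rescale scores into [0, 1] using the tracked range.
        span = self.max_score - self.min_score
        if span == 0:
            return [0.0 for _ in scores]
        return [(s - self.min_score) / span for s in scores]


tracker = RunningMinMax()
batches = [[0.2, 0.8], [0.5, 1.2]]
for batch_scores in batches:
    tracker.update(batch_scores)

normalized = tracker.normalize(batches[1])
```

Because the range is updated step by step, scores normalized early in a run may be rescaled against a narrower range than scores normalized later.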
- training_step(batch, batch_idx)
Perform the training step.
Torch models defining encoder, decoder, generator and discriminator networks.
The GANomaly model consists of several key components:
Encoder: Compresses input images into latent vectors
Decoder: Reconstructs images from latent vectors
Generator: Combines encoder-decoder-encoder for image generation
Discriminator: Distinguishes real from generated images
The architecture follows an encoder-decoder-encoder pattern where:
- First encoder compresses input image to latent space
- Decoder reconstructs the image from latent vector
- Second encoder re-encodes reconstructed image
- Anomaly score is based on difference between latent vectors
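As a rough illustration of how the two latent vectors could be turned into a score, the sketch below computes one value per sample as the mean squared difference between the first and second encoder outputs. The choice of mean squared difference and the function name latent_anomaly_scores are assumptions made for illustration; the text above only states that the score is based on the difference between latent vectors.

```python
def latent_anomaly_scores(latent_i, latent_o):
    """Return one score per sample (hypothetical helper, not anomalib API).

    Assumption: the score is the mean squared difference between the
    latent vector of the input and that of its reconstruction.
    """
    scores = []
    for z_in, z_out in zip(latent_i, latent_o):
        squared = [(a - b) ** 2 for a, b in zip(z_in, z_out)]
        scores.append(sum(squared) / len(squared))
    return scores


# Two samples with 3-dimensional latent vectors.
latent_i = [[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]]
latent_o = [[0.0, 1.0, 2.0], [3.0, 1.0, -1.0]]
scores = latent_anomaly_scores(latent_i, latent_o)
```

The first sample is re-encoded to the same latent vector and scores zero, while the second drifts in latent space and receives a larger score.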
Example
>>> import torch
>>> from anomalib.models.image.ganomaly.torch_model import GanomalyModel
>>> model = GanomalyModel(
... input_size=(256, 256),
... num_input_channels=3,
... n_features=64,
... latent_vec_size=100,
... extra_layers=0,
... add_final_conv_layer=True
... )
>>> input_tensor = torch.randn(32, 3, 256, 256)
>>> output = model(input_tensor)
- Code adapted from:
Title: GANomaly - PyTorch Implementation
Authors: Samet Akcay
URL: samet-akcay/ganomaly
License: MIT
See also
anomalib.models.image.ganomaly.lightning_model.Ganomaly
:Lightning implementation of the GANomaly model
anomalib.models.image.ganomaly.loss.GeneratorLoss
:Loss function for the generator network
anomalib.models.image.ganomaly.loss.DiscriminatorLoss
:Loss function for the discriminator network
- class anomalib.models.image.ganomaly.torch_model.GanomalyModel(input_size, num_input_channels, n_features, latent_vec_size, extra_layers=0, add_final_conv_layer=True)
Bases:
Module
GANomaly model for anomaly detection.
Complete model combining Generator and Discriminator networks.
- Parameters:
input_size (tuple[int, int]) – Input image size (height, width)
num_input_channels (int) – Number of input image channels
n_features (int) – Number of feature maps in convolution layers
latent_vec_size (int) – Size of latent vector between encoder-decoder
extra_layers (int, optional) – Number of extra intermediate layers. Defaults to 0.
add_final_conv_layer (bool, optional) – Add final convolution to encoders. Defaults to True.
Example
>>> model = GanomalyModel(
...     input_size=(256, 256),
...     num_input_channels=3,
...     n_features=64,
...     latent_vec_size=100
... )
>>> input_tensor = torch.randn(32, 3, 256, 256)
>>> output = model(input_tensor)
References
- Title: GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training
Authors: Samet Akcay, Amir Atapour-Abarghouei, Toby P. Breckon
- forward(batch)
Forward pass through GANomaly model.
- Parameters:
batch (torch.Tensor) – Batch of input images
- Returns:
If training, a tuple of four tensors:
- Padded input batch
- Generated images
- First encoder's latent vectors
- Second encoder's latent vectors
If inference, a batch containing anomaly scores.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | InferenceBatch
Loss functions for the GANomaly model implementation.
The GANomaly model uses two loss functions:
Generator Loss: Combines adversarial loss, reconstruction loss and encoding loss
Discriminator Loss: Binary cross entropy loss for real/fake image discrimination
Example
>>> from anomalib.models.image.ganomaly.loss import GeneratorLoss
>>> generator_loss = GeneratorLoss(wadv=1, wcon=50, wenc=1)
>>> loss = generator_loss(latent_i, latent_o, images, fake, pred_real, pred_fake)
>>> from anomalib.models.image.ganomaly.loss import DiscriminatorLoss
>>> discriminator_loss = DiscriminatorLoss()
>>> loss = discriminator_loss(pred_real, pred_fake)
See also
anomalib.models.image.ganomaly.torch_model.GanomalyModel
:PyTorch implementation of the GANomaly model architecture.
- class anomalib.models.image.ganomaly.loss.DiscriminatorLoss
Bases:
Module
Discriminator loss for the GANomaly model.
Uses binary cross entropy to train the discriminator to distinguish between real and generated images.
Example
>>> # Binary cross entropy expects probabilities in [0, 1], so use torch.rand here.
>>> discriminator_loss = DiscriminatorLoss()
>>> loss = discriminator_loss(
...     pred_real=torch.rand(32, 1),
...     pred_fake=torch.rand(32, 1)
... )
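To make the binary cross entropy objective concrete, here is a minimal plain-Python sketch. It assumes that real predictions are scored against a target of 1, fake predictions against a target of 0, and that the two terms are averaged; the function names are hypothetical and the exact reduction used by DiscriminatorLoss may differ.

```python
import math


def binary_cross_entropy(preds, target):
    """Mean BCE of probabilities `preds` against a constant target (0.0 or 1.0)."""
    eps = 1e-12  # clamp to avoid log(0)
    total = 0.0
    for p in preds:
        p = min(max(p, eps), 1.0 - eps)
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(preds)


def discriminator_loss_sketch(pred_real, pred_fake):
    """Hypothetical stand-in: average of the real-vs-1 and fake-vs-0 BCE terms."""
    loss_real = binary_cross_entropy(pred_real, 1.0)
    loss_fake = binary_cross_entropy(pred_fake, 0.0)
    return 0.5 * (loss_real + loss_fake)


# A discriminator that separates real from fake well gets a low loss ...
loss_confident = discriminator_loss_sketch([0.9, 0.8], [0.1, 0.2])
# ... while one that outputs 0.5 everywhere sits at ln(2).
loss_confused = discriminator_loss_sketch([0.5, 0.5], [0.5, 0.5])
```

Under these assumptions the loss falls as the discriminator assigns higher probabilities to real images and lower probabilities to generated ones.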
- forward(pred_real, pred_fake)
Compute the discriminator loss for predicted batch.
- Parameters:
pred_real (torch.Tensor) – Discriminator predictions for real images.
pred_fake (torch.Tensor) – Discriminator predictions for fake images.
- Returns:
Average discriminator loss.
- Return type:
Example
>>> loss = discriminator_loss(pred_real, pred_fake)
- class anomalib.models.image.ganomaly.loss.GeneratorLoss(wadv=1, wcon=50, wenc=1)
Bases:
Module
Generator loss for the GANomaly model.
Combines three components:
1. Adversarial loss: Helps generate realistic images
2. Contextual loss: Ensures generated images match input
3. Encoding loss: Enforces consistency in latent space
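- Parameters:
wadv (int, optional) – Weight for adversarial loss component. Defaults to 1.
wcon (int, optional) – Weight for image reconstruction loss component. Defaults to 50.
wenc (int, optional) – Weight for latent vector encoding loss component. Defaults to 1.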
Example
>>> generator_loss = GeneratorLoss(wadv=1, wcon=50, wenc=1)
>>> loss = generator_loss(
...     latent_i=torch.randn(32, 100),
...     latent_o=torch.randn(32, 100),
...     images=torch.randn(32, 3, 256, 256),
...     fake=torch.randn(32, 3, 256, 256),
...     pred_real=torch.randn(32, 1),
...     pred_fake=torch.randn(32, 1)
... )
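The weighting of the three components can be sketched in plain Python on flat lists of numbers. The per-component error measures below (squared error for the adversarial and encoding terms, absolute error for the contextual term) are assumptions chosen for illustration, and weighted_generator_loss is a hypothetical name; only the weighted sum with wadv, wcon and wenc follows from the description above.

```python
def mean_squared_error(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def mean_absolute_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def weighted_generator_loss(latent_i, latent_o, images, fake,
                            pred_real, pred_fake, wadv=1, wcon=50, wenc=1):
    """Hypothetical sketch of a weighted three-term generator loss."""
    # Adversarial term: discriminator outputs for real and fake should agree.
    loss_adv = mean_squared_error(pred_real, pred_fake)
    # Contextual term: the reconstruction should match the input image.
    loss_con = mean_absolute_error(images, fake)
    # Encoding term: both encoders should produce the same latent vector.
    loss_enc = mean_squared_error(latent_i, latent_o)
    return wadv * loss_adv + wcon * loss_con + wenc * loss_enc


total = weighted_generator_loss(
    latent_i=[1.0, 2.0], latent_o=[1.0, 4.0],  # encoding error: 2.0
    images=[0.0, 1.0], fake=[0.5, 1.0],        # contextual error: 0.25
    pred_real=[1.0], pred_fake=[0.0],          # adversarial error: 1.0
)
```

With the default weights, the contextual term dominates: a reconstruction error of 0.25 contributes 12.5 to the total, compared with 1.0 and 2.0 from the other two terms.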
- forward(latent_i, latent_o, images, fake, pred_real, pred_fake)
Compute the generator loss for a batch.
- Parameters:
latent_i (torch.Tensor) – Latent features from the first encoder.
latent_o (torch.Tensor) – Latent features from the second encoder.
images (torch.Tensor) – Real images that served as generator input.
fake (torch.Tensor) – Generated/fake images.
pred_real (torch.Tensor) – Discriminator predictions for real images.
pred_fake (torch.Tensor) – Discriminator predictions for fake images.
- Returns:
Combined weighted generator loss.
- Return type:
Example
>>> loss = generator_loss(latent_i, latent_o, images, fake,
...                       pred_real, pred_fake)