anomalib.models.ganomaly.loss
Loss function for the GANomaly Model Implementation.
Module Contents
Classes
GeneratorLoss | Generator loss for the GANomaly model.
DiscriminatorLoss | Discriminator loss for the GANomaly model.
- class anomalib.models.ganomaly.loss.GeneratorLoss(wadv=1, wcon=50, wenc=1)
Bases: torch.nn.Module
Generator loss for the GANomaly model.
- Parameters
wadv (int, optional) – Weight of the adversarial loss. Defaults to 1.
wcon (int, optional) – Weight of the contextual (image reconstruction) loss. Defaults to 50.
wenc (int, optional) – Weight of the latent-vector encoder loss. Defaults to 1.
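For reference, the GANomaly paper (Akçay et al., 2018) defines the generator objective as a weighted sum of an adversarial, a contextual, and an encoder term, with the same default weights listed above. Assuming this class follows that formulation, the weights combine as sketched below. The pairing of each symbol with a `forward` argument is inferred from the parameter names and is not stated on this page; the exact distance used for each term in the implementation may also differ (for example, a mean squared error rather than a plain L2 norm).

```latex
\mathcal{L}_G = w_{adv}\,\mathcal{L}_{adv} + w_{con}\,\mathcal{L}_{con} + w_{enc}\,\mathcal{L}_{enc}

\mathcal{L}_{adv} = \lVert f(x) - f(\hat{x}) \rVert_2   % f(x), f(\hat{x}): pred_real, pred_fake
\mathcal{L}_{con} = \lVert x - \hat{x} \rVert_1         % x, \hat{x}: images, fake
\mathcal{L}_{enc} = \lVert z - \hat{z} \rVert_2         % z, \hat{z}: latent_i, latent_o
```

With the defaults (w_adv = 1, w_con = 50, w_enc = 1), the reconstruction term dominates the objective.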
- forward(self, latent_i: torch.Tensor, latent_o: torch.Tensor, images: torch.Tensor, fake: torch.Tensor, pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor
Compute the loss for a batch.
- Parameters
latent_i (Tensor) – Latent features of the first encoder.
latent_o (Tensor) – Latent features of the second encoder.
images (Tensor) – Real images that served as input to the generator.
fake (Tensor) – Generated (reconstructed) images.
pred_real (Tensor) – Discriminator predictions for the real images.
pred_fake (Tensor) – Discriminator predictions for the generated images.
- Returns
The computed generator loss.
- Return type
Tensor
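The sketch below illustrates how `forward` could combine its six inputs into one scalar. It is a minimal pure-Python stand-in, not the library's code: tensors are replaced by flat lists of floats so it runs without PyTorch, and the per-term distances (mean squared error for the adversarial term, L1 for the contextual term, smooth L1 for the encoder term) are assumptions chosen for illustration. `generator_loss_sketch`, `mse`, `l1`, and `smooth_l1` are hypothetical helper names.

```python
from typing import Sequence


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def l1(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean absolute error (L1) between two equal-length sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def smooth_l1(a: Sequence[float], b: Sequence[float], beta: float = 1.0) -> float:
    """Mean smooth L1 (Huber-style) distance: quadratic below beta, linear above."""
    total = 0.0
    for x, y in zip(a, b):
        d = abs(x - y)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(a)


def generator_loss_sketch(
    latent_i: Sequence[float],
    latent_o: Sequence[float],
    images: Sequence[float],
    fake: Sequence[float],
    pred_real: Sequence[float],
    pred_fake: Sequence[float],
    wadv: float = 1,
    wcon: float = 50,
    wenc: float = 1,
) -> float:
    """Hypothetical weighted sum of the three GANomaly generator terms."""
    error_adv = mse(pred_real, pred_fake)      # assumed adversarial distance
    error_con = l1(images, fake)               # assumed reconstruction distance
    error_enc = smooth_l1(latent_i, latent_o)  # assumed latent distance
    return wadv * error_adv + wcon * error_con + wenc * error_enc


# Toy inputs: adversarial term 1.0, contextual term 0.125, encoder term 0.0625.
loss = generator_loss_sketch(
    latent_i=[0.0, 1.0], latent_o=[0.5, 1.0],
    images=[0.5, 0.25], fake=[0.25, 0.25],
    pred_real=[1.0], pred_fake=[0.0],
)
print(loss)  # 1.0 + 50 * 0.125 + 0.0625 = 7.3125
```

Setting a weight to zero drops its term entirely, which is a quick way to check which component dominates training; with the default `wcon=50`, even a small reconstruction error outweighs the other two terms.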