Lightning Implementation of the CFA Model.
CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization.
This implementation uses PyTorch Lightning for training and inference.
- class(backbone='wide_resnet50_2', gamma_c=1, gamma_d=1, num_nearest_neighbors=3, num_hard_negative_features=3, radius=1e-05, pre_processor=True, post_processor=True, evaluator=True, visualizer=True)
CFA Lightning Module.
The CFA model performs anomaly detection and localization using coupled hypersphere-based feature adaptation.
- Parameters:
backbone (str) – Name of the backbone CNN network. Defaults to "wide_resnet50_2".
gamma_c (int, optional) – Centroid loss weight parameter. Defaults to 1.
gamma_d (int, optional) – Distance loss weight parameter. Defaults to 1.
num_nearest_neighbors (int) – Number of nearest neighbors to consider. Defaults to 3.
num_hard_negative_features (int) – Number of hard negative features to use. Defaults to 3.
radius (float) – Radius of the hypersphere for soft boundary search. Defaults to 1e-05.
pre_processor (PreProcessor | bool, optional) – Pre-processor instance or boolean flag. Defaults to True.
post_processor (PostProcessor | bool, optional) – Post-processor instance or boolean flag. Defaults to True.
evaluator (Evaluator | bool, optional) – Evaluator instance or boolean flag. Defaults to True.
visualizer (Visualizer | bool, optional) – Visualizer instance or boolean flag. Defaults to True.
- static backward(loss, *args, **kwargs)
Perform backward pass.
- Parameters:
loss (torch.Tensor) – Computed loss value.
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Return type:
None
- Note:
This method overrides the default backward pass due to computational graph requirements. See CVS-122673 for more details.
- configure_optimizers()
Configure the optimizer.
- Returns:
- AdamW optimizer configured with:
Learning rate:
Weight decay:
- Return type:
- property learning_type: LearningType
Get the learning type.
- Returns:
Indicates this is a one-class classification model.
- Return type:
LearningType
- on_train_start()
Initialize the centroid for memory bank computation.
This method is called at the start of training to compute the initial centroid using the training data.
- Return type:
None
- training_step(batch, *args, **kwargs)
Perform a training step.
- Parameters:
batch (Batch) – Input batch containing images and metadata.
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns:
Dictionary containing the loss value.
- Return type:
- validation_step(batch, *args, **kwargs)
Perform a validation step.
- Parameters:
batch (Batch) – Input batch containing images and metadata.
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns:
Batch object updated with model predictions.
- Return type:
Torch Implementation of the CFA Model.
CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization.
This module provides the PyTorch implementation of the CFA model for anomaly detection and localization. The model learns discriminative features by adapting them to coupled hyperspheres in the feature space.
- The model consists of:
A backbone CNN feature extractor
A descriptor network that generates target-oriented features
A memory bank that stores prototypical normal features
An anomaly map generator for localization
>>> import torch
>>> from import CfaModel
>>> # Initialize model
>>> model = CfaModel(
... backbone="resnet18",
... gamma_c=1,
... gamma_d=1,
... num_nearest_neighbors=3,
... num_hard_negative_features=3,
... radius=0.5
... )
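>>> # Forward pass (the memory bank must be initialized first, e.g. with
>>> # model.initialize_centroid(train_loader); otherwise forward raises ValueError)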
>>> x = torch.randn(32, 3, 256, 256)
>>> predictions = model(x)
- class CfaModel(backbone, gamma_c, gamma_d, num_nearest_neighbors, num_hard_negative_features, radius)
Torch implementation of the CFA Model.
The model learns discriminative features by adapting them to coupled hyperspheres in the feature space. It uses a teacher-student architecture where the teacher network extracts features from normal samples to guide the student network.
- Parameters:
backbone (str) – Name of the backbone CNN network.
gamma_c (int) – Weight for centroid loss.
gamma_d (int) – Weight for distance loss.
num_nearest_neighbors (int) – Number of nearest neighbors for score computation.
num_hard_negative_features (int) – Number of hard negative features to use.
radius (float) – Initial radius of the hypersphere decision boundary.
>>> model = CfaModel(
... backbone="resnet18",
... gamma_c=1,
... gamma_d=1,
... num_nearest_neighbors=3,
... num_hard_negative_features=3,
... radius=0.5
... )
- compute_distance(target_oriented_features)
Compute distances between features and memory bank centroids.
- Parameters:
target_oriented_features (torch.Tensor) – Features from the descriptor network.
- Returns:
Distance tensor.
- Return type:
torch.Tensor
>>> model = CfaModel(...)
>>> features = torch.randn(32, 256, 32, 32)  # B x C x H x W
>>> distances = model.compute_distance(features)
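compute_distance is described only as computing distances between features and memory-bank centroids. A common way to obtain every feature-to-centroid squared Euclidean distance at once is the expansion ‖f − c‖² = ‖f‖² + ‖c‖² − 2 f·c. The framework-free sketch below assumes that formulation (the source does not say which distance is used) and checks it against the direct computation; `pairwise_sq_dist` and `direct_sq_dist` are hypothetical helpers, not library functions.

```python
def pairwise_sq_dist(features, centroids):
    """Squared Euclidean distances via ||f||^2 + ||c||^2 - 2 f.c (assumed formulation).

    features:  N vectors (e.g. one per pixel), each of length C
    centroids: M vectors (memory-bank entries), each of length C
    returns:   N x M nested list of squared distances
    """
    f_norms = [sum(x * x for x in f) for f in features]
    c_norms = [sum(x * x for x in c) for c in centroids]
    return [
        [fn + cn - 2 * sum(a * b for a, b in zip(f, c)) for c, cn in zip(centroids, c_norms)]
        for f, fn in zip(features, f_norms)
    ]


def direct_sq_dist(features, centroids):
    """Reference: sum of squared coordinate differences for every pair."""
    return [[sum((a - b) ** 2 for a, b in zip(f, c)) for c in centroids] for f in features]


features = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]  # 3 "pixels", C = 2
centroids = [[0.0, 0.0], [1.0, 1.0]]  # 2 memory-bank entries
expanded = pairwise_sq_dist(features, centroids)
reference = direct_sq_dist(features, centroids)  # [[5.0, 1.0], [1.0, 5.0], [9.25, 4.25]]
```

With tensors, the f·c term for all pairs becomes a single matrix multiply, which is what makes the expansion attractive; in floating point it can also produce tiny negative values where the true distance is zero.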
- forward(input_tensor)
Forward pass through the model.
- Parameters:
input_tensor (torch.Tensor) – Input image tensor.
- Raises:
ValueError – When the memory bank is not initialized.
- Returns:
During training, returns the distance tensor; during inference, returns anomaly predictions.
- Return type:
>>> model = CfaModel(...)
>>> x = torch.randn(32, 3, 256, 256)
>>> predictions = model(x)
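The documented contract of forward() has three cases: raise ValueError if the memory bank is missing, return distances while training, and return predictions at inference. The sketch below reproduces only that control flow with a hypothetical `CfaLikeModel` class on plain lists; it skips the backbone and descriptor, and its inference output (distance to the nearest memory-bank entry) is a placeholder rather than the model's actual scoring.

```python
class CfaLikeModel:
    """Hypothetical stand-in that mimics the documented forward() contract.

    The backbone and descriptor are skipped: forward() receives feature vectors
    directly, and memory_bank is a plain list of vectors.
    """

    def __init__(self):
        self.training = True  # mirrors torch.nn.Module.training
        self.memory_bank = None  # filled in by an initialization step

    def compute_distance(self, features):
        # Squared Euclidean distance from every feature to every memory-bank entry.
        return [[sum((a - b) ** 2 for a, b in zip(f, c)) for c in self.memory_bank] for f in features]

    def forward(self, features):
        if self.memory_bank is None:
            msg = "Memory bank is not initialized; run the centroid initialization first."
            raise ValueError(msg)
        distance = self.compute_distance(features)
        if self.training:
            return distance  # training: raw distances are fed to the loss
        # Inference: placeholder score = distance to the nearest memory-bank entry.
        return [min(row) for row in distance]


model = CfaLikeModel()
model.memory_bank = [[0.0, 0.0], [1.0, 1.0]]
train_out = model.forward([[1.0, 0.0]])  # [[1.0, 1.0]]
model.training = False
infer_out = model.forward([[1.0, 0.0]])  # [1.0]
```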
- get_scale(input_size)
Get the scale of the feature maps.
- Parameters:
input_size (tuple[int, int] | torch.Size) – Input image dimensions (height, width).
- Returns:
Feature map dimensions.
- Return type:
>>> model = CfaModel(...)
>>> scale = model.get_scale((256, 256))
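The source does not say how get_scale derives the feature-map size. One option is to push a dummy input through the feature extractor and read the output shape; another is to apply the standard convolution output-size formula, floor((size + 2·padding − kernel) / stride) + 1, layer by layer, as sketched below. The layer list is a hypothetical ResNet-like downsampling path (stem convolution, max-pool, first stride-2 stage), an assumption rather than the configuration of any backbone used here.

```python
def conv_output_size(size, kernel, stride, padding):
    """Spatial size after one convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_size(input_size, layers):
    """Propagate (height, width) through a list of (kernel, stride, padding) layers."""
    height, width = input_size
    for kernel, stride, padding in layers:
        height = conv_output_size(height, kernel, stride, padding)
        width = conv_output_size(width, kernel, stride, padding)
    return height, width


# Hypothetical ResNet-like path: 7x7 stem conv (stride 2), 3x3 max-pool (stride 2),
# first 3x3 conv of the second stage (stride 2). Stride-1 layers do not change the size.
RESNET_LIKE_TO_STAGE2 = [(7, 2, 3), (3, 2, 1), (3, 2, 1)]

print(feature_map_size((256, 256), RESNET_LIKE_TO_STAGE2))  # (32, 32)
```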
- initialize_centroid(data_loader)
Initialize the centroid of the memory bank.
Computes the average feature representation of normal samples to initialize the memory bank centroids.
- Parameters:
data_loader (DataLoader) – DataLoader containing normal training samples.
- Return type:
None
>>> from import DataLoader
>>> model = CfaModel(...)
>>> train_loader = DataLoader(...)
>>> model.initialize_centroid(train_loader)
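initialize_centroid is described as averaging the feature representation of normal samples. Below is a minimal sketch of that averaging over a loader-like iterable of batches, assuming a per-sample mean (an implementation could instead average per-batch means, which differs only when batch sizes differ). `running_centroid` is hypothetical; each vector stands in for a flattened per-image feature map, so the result holds one mean value per feature position.

```python
def running_centroid(batches):
    """Mean feature vector over all samples yielded by an iterable of batches.

    Each batch is a list of equal-length feature vectors. The mean is weighted
    by sample count, so a smaller final batch does not skew the result.
    """
    total = None
    count = 0
    for batch in batches:
        for vector in batch:
            if total is None:
                total = [0.0] * len(vector)
            total = [t + v for t, v in zip(total, vector)]
            count += 1
    if count == 0:
        raise ValueError("data loader yielded no samples")
    return [t / count for t in total]


batches = [
    [[1.0, 2.0], [3.0, 4.0]],  # batch of two feature vectors
    [[5.0, 6.0]],  # smaller final batch
]
print(running_centroid(batches))  # [3.0, 4.0]
```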
Loss function for the CFA (Coupled-hypersphere-based Feature Adaptation) model.
This module implements the loss function used to train the CFA model for anomaly detection. The loss consists of two components:
Attraction loss that pulls normal samples inside a hypersphere
Repulsion loss that pushes anomalous samples outside the hypersphere
>>> import torch
>>> from import CfaLoss
>>> # Initialize loss function
>>> loss_fn = CfaLoss(
... num_nearest_neighbors=3,
... num_hard_negative_features=3,
... radius=0.5
... )
>>> # Compute loss on distance tensor
>>> distance = torch.randn(2, 1024, 1) # batch x pixels x 1
>>> loss = loss_fn(distance)
- class CfaLoss(num_nearest_neighbors, num_hard_negative_features, radius)
Loss function for the CFA model.
The loss encourages normal samples to lie within a hypersphere while pushing anomalous samples outside. It uses k-nearest neighbors to identify the closest samples and hard negative mining to find challenging anomalous examples.
- Parameters:
num_nearest_neighbors (int) – Number of nearest neighbors to consider for the attraction loss component.
num_hard_negative_features (int) – Number of hard negative features to use for the repulsion loss component.
radius (float) – Initial radius of the hypersphere that defines the decision boundary between normal and anomalous samples.
>>> loss_fn = CfaLoss(
... num_nearest_neighbors=3,
... num_hard_negative_features=3,
... radius=0.5
... )
>>> distance = torch.randn(2, 1024, 1)  # batch x pixels x 1
>>> loss = loss_fn(distance)
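Written out, the two penalties described for CfaLoss can take the following form. This is a sketch under assumptions rather than the implementation's exact formula: d_{i,(1)} ≤ d_{i,(2)} ≤ … are the sorted distances from pixel i to the memory-bank entries, N is the number of pixels, K = num_nearest_neighbors, J = num_hard_negative_features, r = radius, and α is a margin whose value is not documented. Treating the J entries ranked right after the K nearest as the hard negatives is also an assumption.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{att}} &= \frac{1}{NK}\sum_{i=1}^{N}\sum_{k=1}^{K}\max\bigl(0,\; d_{i,(k)} - r\bigr),\\
\mathcal{L}_{\mathrm{rep}} &= \frac{1}{NJ}\sum_{i=1}^{N}\sum_{j=K+1}^{K+J}\max\bigl(0,\; (r + \alpha) - d_{i,(j)}\bigr),\\
\mathcal{L} &= \mathcal{L}_{\mathrm{att}} + \mathcal{L}_{\mathrm{rep}}.
\end{aligned}
```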
- forward(distance)
Compute the CFA loss given distance features.
- The loss has two components:
Attraction loss (l_att): Encourages normal samples to lie within the hypersphere by penalizing distances greater than radius.
Repulsion loss (l_rep): Pushes anomalous samples outside the hypersphere by penalizing distances less than radius + margin.
- Parameters:
distance (torch.Tensor) – Distance tensor of shape (batch_size, num_pixels, 1) computed using target-oriented features.
- Returns:
Scalar loss value combining the attraction and repulsion losses.
- Return type:
torch.Tensor
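A framework-free sketch of the two loss components described for forward(), working on one list of distances per pixel (a (batch_size, num_pixels, …) tensor flattened to rows). The margin default, treating ranks K+1 to K+J as the hard negatives, and the unweighted sum of the two terms are assumptions; `cfa_style_loss` is a hypothetical helper, not the library's loss.

```python
def cfa_style_loss(distance, num_nearest_neighbors, num_hard_negative_features, radius, margin=0.1):
    """Attraction + repulsion loss over per-pixel neighbor distances (sketch).

    distance: one list per pixel holding that pixel's distances to every
    memory-bank entry; each list needs at least K + J values.
    margin: hypothetical default; the documentation does not state a value.
    """
    k, j = num_nearest_neighbors, num_hard_negative_features
    if not distance:
        raise ValueError("distance must contain at least one pixel row")
    att_terms, rep_terms = [], []
    for row in distance:
        if len(row) < k + j:
            raise ValueError("each row needs at least K + J distances")
        ranked = sorted(row)[: k + j]
        # Attraction: penalize the K nearest entries that lie outside the radius.
        att_terms.extend(max(0.0, d - radius) for d in ranked[:k])
        # Repulsion: penalize the next J entries (assumed hard negatives) inside radius + margin.
        rep_terms.extend(max(0.0, (radius + margin) - d) for d in ranked[k:])
    l_att = sum(att_terms) / len(att_terms) if att_terms else 0.0
    l_rep = sum(rep_terms) / len(rep_terms) if rep_terms else 0.0
    return l_att + l_rep


loss = cfa_style_loss([[0.5, 2.0, 3.0], [1.5, 1.25, 4.0]], 1, 1, radius=1.0, margin=0.75)  # 0.25
```

A pixel contributes nothing once its K nearest entries lie inside the radius and its hard negatives lie beyond radius + margin, so the sketch's loss is zero exactly when both conditions hold for every pixel.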
Anomaly Map Generator for the CFA model implementation.
This module provides functionality to generate anomaly heatmaps from distance features computed by the CFA model.
>>> import torch
>>> from import AnomalyMapGenerator
>>> # Initialize generator
>>> generator = AnomalyMapGenerator(num_nearest_neighbors=3)
>>> # Generate anomaly map
>>> distance = torch.randn(1, 1024, 1) # batch x pixels x 1
>>> scale = (32, 32) # height x width
>>> anomaly_map = generator(distance=distance, scale=scale)
- class AnomalyMapGenerator(num_nearest_neighbors, sigma=4)
Generate anomaly heatmaps from distance features.
The generator computes anomaly scores based on k-nearest neighbor distances and applies Gaussian smoothing to produce the final heatmap.
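- Parameters:
num_nearest_neighbors (int) – Number of nearest neighbors used to compute the anomaly scores.
sigma (int, optional) – Standard deviation of the Gaussian smoothing kernel. Defaults to 4.
>>> import torch
>>> generator = AnomalyMapGenerator(num_nearest_neighbors=3)
>>> distance = torch.randn(1, 1024, 1)  # batch x pixels x 1
>>> scale = (32, 32)  # height x width
>>> anomaly_map = generator(distance=distance, scale=scale)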
- compute_anomaly_map(score, image_size=None)
Generate smoothed anomaly map from scores.
- Parameters:
score (torch.Tensor) – Anomaly scores of shape (batch_size, 1, height, width).
image_size (tuple[int, int] | torch.Size | None, optional) – Target size for upsampling the anomaly map. Defaults to None.
- Returns:
Smoothed anomaly map of shape (batch_size, 1, height, width).
- Return type:
torch.Tensor
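compute_anomaly_map is described as producing a smoothed map, and AnomalyMapGenerator defaults to sigma=4, which suggests a Gaussian blur with that standard deviation. The pure-Python sketch below shows a separable Gaussian blur on a nested-list score map; the kernel truncation at 4·sigma, the edge-replicating border, and the helper names are assumptions, and upsampling to image_size is omitted.

```python
import math


def gaussian_kernel_1d(sigma, truncate=4.0):
    """Normalized 1-D Gaussian weights spanning +/- truncate * sigma samples."""
    radius = int(truncate * sigma + 0.5)
    weights = [math.exp(-0.5 * (x / sigma) ** 2) for x in range(-radius, radius + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def _smooth_rows(grid, kernel):
    """Convolve every row with the kernel, repeating edge values at the borders."""
    half = len(kernel) // 2
    smoothed = []
    for row in grid:
        last = len(row) - 1
        smoothed.append(
            [
                sum(w * row[min(max(x + offset, 0), last)] for offset, w in zip(range(-half, half + 1), kernel))
                for x in range(len(row))
            ]
        )
    return smoothed


def _transpose(grid):
    return [list(column) for column in zip(*grid)]


def gaussian_smooth(score_map, sigma=4):
    """Separable 2-D Gaussian blur: smooth along rows, then along columns."""
    kernel = gaussian_kernel_1d(sigma)
    rows_done = _smooth_rows(score_map, kernel)
    return _transpose(_smooth_rows(_transpose(rows_done), kernel))


# A single hot pixel in an otherwise clean 9 x 9 score map.
score_map = [[0.0] * 9 for _ in range(9)]
score_map[4][4] = 1.0
smoothed = gaussian_smooth(score_map, sigma=1)
```

Because the 2-D Gaussian factorizes into two 1-D passes, the cost per pixel grows with the kernel width rather than its area, which matters at sigma=4 where the kernel spans 33 samples.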
- compute_score(distance, scale)
Compute anomaly scores from distance features.
- Parameters:
distance (torch.Tensor) – Distance tensor of shape (batch_size, num_pixels, 1).
scale (tuple[int, int]) – Height and width of the feature map used to reshape the scores.
- Returns:
Anomaly scores of shape (batch_size, 1, height, width).
- Return type:
torch.Tensor
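compute_score turns per-pixel distances into a score map of shape (height, width) using k-nearest-neighbor distances. The sketch below uses one plausible rule, assumed here rather than taken from the source: weight the nearest distance by its softmin weight among the k nearest, then reshape the row-major pixel list into the feature-map grid given by scale. `knn_softmin_score` is hypothetical.

```python
import math


def knn_softmin_score(distance_rows, k, height, width):
    """Turn per-pixel neighbor distances into a height x width score grid (sketch).

    distance_rows: height * width lists in row-major order; each holds one pixel's
    distances to all memory-bank entries (at least k values).
    Assumed scoring rule: softmin weight of the nearest neighbor times its distance.
    """
    if len(distance_rows) != height * width:
        raise ValueError("expected height * width distance rows")
    scores = []
    for row in distance_rows:
        nearest = sorted(row)[:k]
        # Softmin over the k nearest distances; shifting by nearest[0] keeps exp() stable.
        exps = [math.exp(-(d - nearest[0])) for d in nearest]
        weight = exps[0] / sum(exps)
        scores.append(weight * nearest[0])
    # Reshape the flat list of pixel scores into rows of the feature map.
    return [scores[r * width : (r + 1) * width] for r in range(height)]


rows = [[1.0, 1.0, 9.0], [0.0, 5.0, 7.0], [2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
grid = knn_softmin_score(rows, k=2, height=2, width=2)  # [[0.5, 0.0], [1.0, 2.0]]
```

Under this rule a pixel whose nearest entry is clearly closer than the others keeps almost its full nearest distance, while a pixel with k equally distant neighbors is scaled down by 1/k.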
- forward(**kwargs)
Generate anomaly map from input features.
The method expects distance and scale as required inputs, with optional image_size for upsampling.
- Parameters:
**kwargs –
Keyword arguments containing:
- distance (torch.Tensor): Distance features.
- scale (tuple[int, int]): Feature map scale.
- image_size (tuple[int, int] | torch.Size, optional): Target size for upsampling.
- Raises:
ValueError – If required arguments are missing.
- Returns:
Anomaly heatmap of shape (batch_size, 1, height, width).
- Return type:
torch.Tensor
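Because forward() accepts only **kwargs, it has to check for the required keys itself and raise ValueError when one is missing. A minimal sketch of that validation step; `parse_generator_kwargs` is a hypothetical helper, the error message wording is invented, and in the real method the validated values would presumably flow on to compute_score and compute_anomaly_map.

```python
def parse_generator_kwargs(**kwargs):
    """Validate keyword arguments the way forward(**kwargs) is documented to.

    Returns (distance, scale, image_size); image_size is None when omitted.
    """
    missing = [name for name in ("distance", "scale") if name not in kwargs]
    if missing:
        msg = f"Expected 'distance' and 'scale' keyword arguments; missing: {', '.join(missing)}"
        raise ValueError(msg)
    return kwargs["distance"], kwargs["scale"], kwargs.get("image_size")


distance, scale, image_size = parse_generator_kwargs(distance=[[0.1]], scale=(1, 1))
```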