class Ganomaly(AnomalyModule):
    """PL Lightning Module for the GANomaly Algorithm.

    Args:
        batch_size (int): Batch size; sets the length of the constant real/fake label tensors.
        input_size (Tuple[int, int]): Input dimension.
        n_features (int): Number of feature layers in the CNNs.
        latent_vec_size (int): Size of autoencoder latent vector.
        extra_layers (int, optional): Number of extra layers for encoder/decoder. Defaults to 0.
        add_final_conv_layer (bool, optional): Add convolution layer at the end. Defaults to True.
        wadv (int, optional): Weight for adversarial loss. Defaults to 1.
        wcon (int, optional): Image regeneration weight. Defaults to 50.
        wenc (int, optional): Latent vector encoder weight. Defaults to 1.
        lr (float, optional): Learning rate stored as ``self.learning_rate``. Defaults to 0.0002.
        beta1 (float, optional): First Adam beta coefficient. Defaults to 0.5.
        beta2 (float, optional): Second Adam beta coefficient. Defaults to 0.999.
    """

    def __init__(
        self,
        batch_size: int,
        input_size: Tuple[int, int],
        n_features: int,
        latent_vec_size: int,
        extra_layers: int = 0,
        add_final_conv_layer: bool = True,
        wadv: int = 1,
        wcon: int = 50,
        wenc: int = 1,
        lr: float = 0.0002,
        beta1: float = 0.5,
        beta2: float = 0.999,
    ):
        super().__init__()

        # Wrapped GANomaly network; the input channel count is fixed at 3 here.
        self.model: GanomalyModel = GanomalyModel(
            input_size=input_size,
            num_input_channels=3,
            n_features=n_features,
            latent_vec_size=latent_vec_size,
            extra_layers=extra_layers,
            add_final_conv_layer=add_final_conv_layer,
        )

        # Constant target labels with one entry per sample of a full batch.
        label_shape = (batch_size,)
        self.real_label = torch.ones(size=label_shape, dtype=torch.float32)
        self.fake_label = torch.zeros(size=label_shape, dtype=torch.float32)

        # Running score extrema, seeded at +/-inf so the first observed score replaces them.
        self.min_scores: Tensor = torch.tensor(float("inf"), dtype=torch.float32)  # pylint: disable=not-callable
        self.max_scores: Tensor = torch.tensor(float("-inf"), dtype=torch.float32)  # pylint: disable=not-callable

        self.generator_loss = GeneratorLoss(wadv, wcon, wenc)
        self.discriminator_loss = DiscriminatorLoss()

        # TODO: LR should be part of optimizer in config.yaml! Since ganomaly has custom
        # optimizer this is to be addressed later.
        self.learning_rate, self.beta1, self.beta2 = lr, beta1, beta2
[docs]defconfigure_optimizers(self)->List[optim.Optimizer]:"""Configures optimizers for each decoder. Note: This method is used for the existing CLI. When PL CLI is introduced, configure optimizers method will be deprecated, and optimizers will be configured from either config.yaml file or from CLI. Returns: Optimizer: Adam optimizer for each decoder """optimizer_d=optim.Adam(self.model.discriminator.parameters(),lr=self.learning_rate,betas=(self.beta1,self.beta2),)optimizer_g=optim.Adam(self.model.generator.parameters(),lr=self.learning_rate,betas=(self.beta1,self.beta2),)return[optimizer_d,optimizer_g]
[docs]deftraining_step(self,batch,_,optimizer_idx):# pylint: disable=arguments-differ"""Training step. Args: batch (Dict): Input batch containing images. optimizer_idx (int): Optimizer which is being called for current training step. Returns: Dict[str, Tensor]: Loss """# forward passpadded,fake,latent_i,latent_o=self.model(batch["image"])pred_real,_=self.model.discriminator(padded)ifoptimizer_idx==0:# Discriminatorpred_fake,_=self.model.discriminator(fake.detach())loss=self.discriminator_loss(pred_real,pred_fake)else:# Generatorpred_fake,_=self.model.discriminator(fake)loss=self.generator_loss(latent_i,latent_o,padded,fake,pred_real,pred_fake)return{"loss":loss}
[docs]defon_validation_start(self)->None:"""Reset min and max values for current validation epoch."""self._reset_min_max()returnsuper().on_validation_start()
[docs]defvalidation_step(self,batch,_)->Dict[str,Tensor]:# type: ignore # pylint: disable=arguments-differ"""Update min and max scores from the current step. Args: batch (Dict[str, Tensor]): Predicted difference between z and z_hat. Returns: Dict[str, Tensor]: batch """batch["pred_scores"]=self.model(batch["image"])self.max_scores=max(self.max_scores,torch.max(batch["pred_scores"]))self.min_scores=min(self.min_scores,torch.min(batch["pred_scores"]))returnbatch
[docs]defvalidation_epoch_end(self,outputs):"""Normalize outputs based on min/max values."""logger.info("Normalizing validation outputs based on min/max values.")forpredictioninoutputs:prediction["pred_scores"]=self._normalize(prediction["pred_scores"])super().validation_epoch_end(outputs)returnoutputs
[docs]defon_test_start(self)->None:"""Reset min max values before test batch starts."""self._reset_min_max()returnsuper().on_test_start()
[docs]deftest_step(self,batch,_):"""Update min and max scores from the current step."""super().test_step(batch,_)self.max_scores=max(self.max_scores,torch.max(batch["pred_scores"]))self.min_scores=min(self.min_scores,torch.min(batch["pred_scores"]))returnbatch
[docs]deftest_epoch_end(self,outputs):"""Normalize outputs based on min/max values."""logger.info("Normalizing test outputs based on min/max values.")forpredictioninoutputs:prediction["pred_scores"]=self._normalize(prediction["pred_scores"])super().test_epoch_end(outputs)returnoutputs
[docs]def_normalize(self,scores:Tensor)->Tensor:"""Normalize the scores based on min/max of entire dataset. Args: scores (Tensor): Un-normalized scores. Returns: Tensor: Normalized scores. """scores=(scores-self.min_scores.to(scores.device))/(self.max_scores.to(scores.device)-self.min_scores.to(scores.device))returnscores
class GanomalyLightning(Ganomaly):
    """PL Lightning Module for the GANomaly Algorithm.

    Builds a ``Ganomaly`` module from ``hparams.dataset.train_batch_size`` and the
    ``hparams.model`` fields, then saves ``hparams`` as the module hyperparameters.

    Args:
        hparams (Union[DictConfig, ListConfig]): Model params
    """

    def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
        train_batch_size = hparams.dataset.train_batch_size
        model_cfg = hparams.model
        super().__init__(
            batch_size=train_batch_size,
            input_size=model_cfg.input_size,
            n_features=model_cfg.n_features,
            latent_vec_size=model_cfg.latent_vec_size,
            extra_layers=model_cfg.extra_layers,
            add_final_conv_layer=model_cfg.add_final_conv,
            wadv=model_cfg.wadv,
            wcon=model_cfg.wcon,
            wenc=model_cfg.wenc,
            lr=model_cfg.lr,
            beta1=model_cfg.beta1,
            beta2=model_cfg.beta2,
        )
        # Narrow the declared type of ``hparams`` for type checkers.
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(hparams)
[docs]defconfigure_callbacks(self):"""Configure model-specific callbacks. Note: This method is used for the existing CLI. When PL CLI is introduced, configure callback method will be deprecated, and callbacks will be configured from either config.yaml file or from CLI. """early_stopping=EarlyStopping(monitor=self.hparams.model.early_stopping.metric,patience=self.hparams.model.early_stopping.patience,mode=self.hparams.model.early_stopping.mode,)return[early_stopping]