# Source code for anomalib.utils.callbacks.nncf.callback
"""Callbacks for NNCF optimization."""# Copyright (C) 2022 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions# and limitations under the License.importosfromtypingimportAny,Dict,Optionalimportpytorch_lightningasplfromnncfimportNNCFConfigfromnncf.api.compressionimportCompressionAlgorithmControllerfromnncf.torchimportregister_default_init_argsfrompytorch_lightningimportCallbackfrompytorch_lightning.utilities.cliimportCALLBACK_REGISTRYfromanomalib.utils.callbacks.nncf.utilsimportInitLoader,wrap_nncf_model@CALLBACK_REGISTRY
class NNCFCallback(Callback):
    """Callback for NNCF compression.

    Assumes that the pl module contains a 'model' attribute, which is the PyTorch module that must be compressed.

    Args:
        config (Dict): NNCF Configuration
        export_dir (Optional[str]): Path where the export `onnx` and the OpenVINO `xml` and `bin` IR are saved.
            If None model will not be exported.
    """

    def __init__(self, config: Dict, export_dir: Optional[str] = None):
        super().__init__()
        self.export_dir = export_dir
        self.config = NNCFConfig(config)
        # Set in ``setup`` once the training dataloader is available; stays None until then.
        self.nncf_ctrl: Optional[CompressionAlgorithmController] = None

    # pylint: disable=unused-argument
    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
        """Call when fit or test begins.

        Takes the pytorch model and wraps it using the compression controller so that it is ready for nncf
        fine-tuning. The wrapped model replaces ``pl_module.model`` in place.

        Args:
            trainer (pl.Trainer): Trainer whose ``datamodule`` provides the training dataloader.
            pl_module (pl.LightningModule): Module whose ``model`` attribute is wrapped for compression.
            stage (Optional[str]): Lightning stage name; unused.
        """
        # ``setup`` can be invoked more than once (e.g. fit then test); only wrap the model the first time.
        if self.nncf_ctrl is not None:
            return

        # Build the training dataloader once and share it between NNCF's initialisation loader and the
        # model wrapper, rather than constructing two separate dataloaders.
        train_loader = trainer.datamodule.train_dataloader()  # type: ignore
        init_loader = InitLoader(train_loader)
        config = register_default_init_args(self.config, init_loader)

        self.nncf_ctrl, pl_module.model = wrap_nncf_model(
            model=pl_module.model,
            config=config,
            dataloader=train_loader,
        )

    def on_train_batch_start(
        self,
        trainer: pl.Trainer,
        _pl_module: pl.LightningModule,
        _batch: Any,
        _batch_idx: int,
        _unused: Optional[int] = 0,
    ) -> None:
        """Call when the train batch begins.

        Prepare compression method to continue training the model in the next step.
        """
        if self.nncf_ctrl is not None:
            self.nncf_ctrl.scheduler.step()

    def on_train_epoch_start(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None:
        """Call when the train epoch starts.

        Prepare compression method to continue training the model in the next epoch.
        """
        if self.nncf_ctrl is not None:
            self.nncf_ctrl.scheduler.epoch_step()

    def on_train_end(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None:
        """Call when the train ends.

        When both ``export_dir`` and the compression controller are set, exports the compressed model to
        ``<export_dir>/model_nncf.onnx`` and then runs the ``mo`` command on that file with ``export_dir`` as
        the output directory to generate the OpenVINO IR. Does nothing if either one is None.

        Running ``mo`` is best-effort: if it cannot be launched or exits with a non-zero code, a warning is
        emitted instead of raising.
        """
        if self.export_dir is None or self.nncf_ctrl is None:
            return

        os.makedirs(self.export_dir, exist_ok=True)
        onnx_path = os.path.join(self.export_dir, "model_nncf.onnx")
        self.nncf_ctrl.export_model(onnx_path)

        # Pass the arguments as a list (no shell) so export paths containing spaces or shell metacharacters
        # are neither split nor interpreted by a shell.
        optimize_command = ["mo", "--input_model", onnx_path, "--output_dir", self.export_dir]
        try:
            result = subprocess.run(optimize_command, check=False)
        except OSError as error:
            warnings.warn(f"Could not run OpenVINO Model Optimizer (`mo`): {error}. Skipping OpenVINO IR export.")
            return
        if result.returncode != 0:
            warnings.warn(
                f"OpenVINO Model Optimizer (`mo`) exited with code {result.returncode}; "
                "the OpenVINO IR may not have been generated."
            )