Applying profiling to the fastai Learner callback functions

fastai Training Event Lifecycle Methods

after_create

before_fit
    before_epoch
        before_train
            before_batch
                after_pred
                after_loss
                before_backward
                after_backward
                after_cancel_step
                after_step
                after_cancel_batch
            after_batch
        after_cancel_train
        after_train
        before_validate
            before_batch
                after_pred
                after_loss
                after_cancel_batch
            after_batch
        after_cancel_validate
        after_validate
    after_epoch
after_cancel_fit
after_fit
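
To make the pairing concrete, here is a minimal sketch of how a fastai Callback hooks these events by defining methods with matching names. TinyTimerCallback is a hypothetical name for illustration only, not the MyProfileCallback documented below:

import time
from fastai.callback.core import Callback

class TinyTimerCallback(Callback):
    "Time each epoch by pairing the before_epoch/after_epoch lifecycle events"
    def before_fit(self):   self.epoch_times = []                                  # runs once, at the start of fit
    def before_epoch(self): self.t0 = time.perf_counter()                          # start of each epoch
    def after_epoch(self):  self.epoch_times.append(time.perf_counter() - self.t0) # end of each epoch
    def after_fit(self):    print([f'{t:.3f}s' for t in self.epoch_times])

# Usage: learner.fit(1, cbs=TinyTimerCallback())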

class MyProfileCallback[source]

MyProfileCallback(reset=False) :: Callback

Callback to profile training lifecycle event performance

Learner.to_my_profile[source]

Learner.to_my_profile(reset=False)

Add my_profile callback to learner
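
A plausible sketch of how such a method can be attached to Learner using fastcore's @patch decorator (the actual implementation may differ; it assumes MyProfileCallback is importable):

from fastcore.basics import patch
from fastai.learner import Learner

@patch
def to_my_profile(self: Learner, reset=False):
    "Attach a MyProfileCallback and return the learner"
    self.add_cb(MyProfileCallback(reset=reset))  # registered callbacks become attributes named
    return self                                  # after the class: MyProfileCallback -> learner.my_profile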

Example Usage

from fastai.vision.all import *

# Grab the small MNIST sample dataset and point BASE_PATH at it
path = untar_data(URLs.MNIST_TINY)
Path.BASE_PATH = path

# Images in, labels taken from the parent folder name, train/valid split by grandparent folder
datablock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(),
    item_tfms=Resize(28),
    batch_tfms=[]
)
dls = datablock.dataloaders(path)

# Pretrained resnet18 with a new head; the body stays frozen until unfrozen
learner = cnn_learner(dls, resnet18, metrics=accuracy)
learner.summary()
Sequential (Input shape: 64)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     64 x 64 x 14 x 14   
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
____________________________________________________________________________
                     64 x 128 x 4 x 4    
Conv2d                                    73728      False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    8192       False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
____________________________________________________________________________
                     64 x 256 x 2 x 2    
Conv2d                                    294912     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    32768      False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
____________________________________________________________________________
                     64 x 512 x 1 x 1    
Conv2d                                    1179648    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    131072     False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
AdaptiveAvgPool2d                                              
AdaptiveMaxPool2d                                              
Flatten                                                        
BatchNorm1d                               2048       True      
Dropout                                                        
____________________________________________________________________________
                     64 x 512            
Linear                                    524288     True      
ReLU                                                           
BatchNorm1d                               1024       True      
Dropout                                                        
____________________________________________________________________________
                     64 x 2              
Linear                                    1024       True      
____________________________________________________________________________

Total params: 11,704,896
Total trainable params: 537,984
Total non-trainable params: 11,166,912

Optimizer used: <function Adam at 0x134d4d9e0>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
learner.to_my_profile()
<fastai.learner.Learner at 0x135ecafd0>
learner.my_profile
MyProfileCallback
learner.my_profile.print_stats()
fit has no data
   epoch has no data
      train has no data
         train_batch has no data
            train_pred has no data
            train_loss has no data
            train_backward has no data
            train_step has no data
            train_zero_grad has no data
      valid has no data
         valid_batch has no data
            valid_pred has no data
            valid_loss has no data
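The empty stats tree mirrors the lifecycle hierarchy above: fit wraps epoch, which wraps the train and valid phases, and each training batch is further broken down into pred, loss, backward, step, and zero_grad timings.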
learner.fit(1)
epoch   train_loss   valid_loss   accuracy   time
0       0.693655     0.486362     0.749642   00:14
learner.my_profile.print_stats()
fit  called 1 times. max: 14.826 avg: 14.826
   epoch  called 1 times. max: 14.826 avg: 14.826
      train  called 1 times. max: 12.539 avg: 12.539
         train_batch  called 11 times. max: 1.147 avg: 1.093
            train_pred  called 11 times. max: 0.253 avg: 0.219
            train_loss  called 11 times. max: 0.001 avg: 0.001
            train_backward  called 11 times. max: 0.900 avg: 0.861
            train_step  called 11 times. max: 0.014 avg: 0.010
            train_zero_grad  called 11 times. max: 0.002 avg: 0.002
      valid  called 1 times. max: 2.283 avg: 2.283
         valid_batch  called 11 times. max: 0.203 avg: 0.181
            valid_pred  called 11 times. max: 0.202 avg: 0.180
            valid_loss  called 11 times. max: 0.002 avg: 0.001
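Note that the top-level fit time (14.826s) lines up with the 00:14 reported in the training table above, and that train_backward dominates each training batch (roughly 0.86s of the 1.09s average).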
fit_stats = learner.my_profile.get_stats(); fit_stats
[('fit', 0, [14.826272010803223]),
 ('epoch', 1, [14.825506210327148]),
 ('train', 2, [12.53893232345581]),
 ('train_batch',
  3,
  [1.147028923034668,
   1.0965969562530518,
   1.0539379119873047,
   1.0700407028198242,
   1.0998239517211914,
   1.0905580520629883,
   1.0969460010528564,
   1.0751848220825195,
   1.1051452159881592,
   1.0613350868225098,
   1.1278557777404785]),
 ('train_pred',
  4,
  [0.25260400772094727,
   0.2151319980621338,
   0.21577811241149902,
   0.21297788619995117,
   0.2168900966644287,
   0.21621465682983398,
   0.21819210052490234,
   0.2154397964477539,
   0.21802592277526855,
   0.21338295936584473,
   0.21511292457580566]),
 ('train_loss',
  4,
  [0.0011301040649414062,
   0.0007872581481933594,
   0.000743865966796875,
   0.0007627010345458984,
   0.0007507801055908203,
   0.0007741451263427734,
   0.0007429122924804688,
   0.0007698535919189453,
   0.0007410049438476562,
   0.0007848739624023438,
   0.0007369518280029297]),
 ('train_backward',
  4,
  [0.8776719570159912,
   0.8692278861999512,
   0.8258969783782959,
   0.8449299335479736,
   0.8709321022033691,
   0.862293004989624,
   0.8666291236877441,
   0.8475267887115479,
   0.8750889301300049,
   0.8348186016082764,
   0.9004881381988525]),
 ('train_step',
  4,
  [0.013615131378173828,
   0.009528875350952148,
   0.009634017944335938,
   0.009547948837280273,
   0.009372234344482422,
   0.009392976760864258,
   0.009460210800170898,
   0.00952601432800293,
   0.009387016296386719,
   0.009891986846923828,
   0.009598970413208008]),
 ('train_zero_grad',
  4,
  [0.0018649101257324219,
   0.001825094223022461,
   0.0017778873443603516,
   0.0017271041870117188,
   0.0017819404602050781,
   0.0017867088317871094,
   0.0018239021301269531,
   0.0018229484558105469,
   0.0018050670623779297,
   0.0023589134216308594,
   0.0018210411071777344]),
 ('valid', 2, [2.282799005508423]),
 ('valid_batch',
  3,
  [0.20348501205444336,
   0.1881699562072754,
   0.1850287914276123,
   0.1786811351776123,
   0.1722097396850586,
   0.17976093292236328,
   0.1717219352722168,
   0.17983412742614746,
   0.18445897102355957,
   0.17809414863586426,
   0.16802597045898438]),
 ('valid_pred',
  4,
  [0.20157313346862793,
   0.18747615814208984,
   0.184248685836792,
   0.17804694175720215,
   0.17160391807556152,
   0.1789689064025879,
   0.171112060546875,
   0.1790611743927002,
   0.18384408950805664,
   0.17748594284057617,
   0.16724205017089844]),
 ('valid_loss',
  4,
  [0.0017740726470947266,
   0.0005881786346435547,
   0.000659942626953125,
   0.0005400180816650391,
   0.000514984130859375,
   0.0006740093231201172,
   0.0005199909210205078,
   0.0006546974182128906,
   0.0005199909210205078,
   0.0005161762237548828,
   0.0006620883941650391])]
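Since get_stats returns plain (name, depth, times) tuples, standard Python tooling works on the result; for example, a quick summary with statistics.mean (a sketch, assuming only the structure shown above):

from statistics import mean

for name, depth, times in fit_stats:
    if times:  # skip events that never fired
        print(f'{"   " * depth}{name}: n={len(times)} avg={mean(times):.3f}s max={max(times):.3f}s')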
learner.my_profile.print_stats('train_batch')
         train_batch  called 11 times. max: 1.147 avg: 1.093
train_batch_stats = learner.my_profile.get_stats('train_batch'); train_batch_stats
('train_batch',
 3,
 [1.147028923034668,
  1.0965969562530518,
  1.0539379119873047,
  1.0700407028198242,
  1.0998239517211914,
  1.0905580520629883,
  1.0969460010528564,
  1.0751848220825195,
  1.1051452159881592,
  1.0613350868225098,
  1.1278557777404785])
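The raw per-call timings also plot directly; a sketch with matplotlib (already a fastai dependency):

import matplotlib.pyplot as plt

name, _, times = train_batch_stats
plt.plot(range(1, len(times) + 1), times, marker='o')  # one point per batch
plt.xlabel('batch'); plt.ylabel('seconds'); plt.title(f'{name} time per call')
plt.show()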
learner.my_profile.clear_stats()
learner.my_profile.print_stats()
fit has no data
   epoch has no data
      train has no data
         train_batch has no data
            train_pred has no data
            train_loss has no data
            train_backward has no data
            train_step has no data
            train_zero_grad has no data
      valid has no data
         valid_batch has no data
            valid_pred has no data
            valid_loss has no data
learner.my_profile.print_stats('train')
      train has no data
learner.fine_tune(1)
epoch   train_loss   valid_loss   accuracy   time
0       0.339887     0.245247     0.908441   00:14

epoch   train_loss   valid_loss   accuracy   time
0       0.273823     0.196766     0.919886   00:21
learner.my_profile.print_stats()
fit  called 2 times. max: 21.981 avg: 18.306
   epoch  called 2 times. max: 21.980 avg: 18.305
      train  called 2 times. max: 19.679 avg: 15.999
         train_batch  called 22 times. max: 2.007 avg: 1.426
            train_pred  called 22 times. max: 0.259 avg: 0.219
            train_loss  called 22 times. max: 0.001 avg: 0.001
            train_backward  called 22 times. max: 1.592 avg: 1.150
            train_step  called 22 times. max: 0.148 avg: 0.051
            train_zero_grad  called 22 times. max: 0.007 avg: 0.004
      valid  called 2 times. max: 2.306 avg: 2.301
         valid_batch  called 22 times. max: 0.211 avg: 0.182
            valid_pred  called 22 times. max: 0.209 avg: 0.181
            valid_loss  called 22 times. max: 0.002 avg: 0.001
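Because fine_tune(1) performs two fits (one epoch with the body frozen, then one unfrozen) and reset defaults to False, the counts accumulate across both runs: fit is now called 2 times and train_batch 22 times (2 × 11 batches).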
learner.my_profile.reset = True
learner.fine_tune(1)
epoch   train_loss   valid_loss   accuracy   time
0       0.160970     0.161350     0.944206   00:15

epoch   train_loss   valid_loss   accuracy   time
0       0.120808     0.129148     0.958512   00:23
learner.my_profile.print_stats()
fit  called 1 times. max: 23.155 avg: 23.155
   epoch  called 1 times. max: 23.154 avg: 23.154
      train  called 1 times. max: 20.823 avg: 20.823
         train_batch  called 11 times. max: 1.939 avg: 1.864
            train_pred  called 11 times. max: 0.247 avg: 0.216
            train_loss  called 11 times. max: 0.001 avg: 0.001
            train_backward  called 11 times. max: 1.590 avg: 1.546
            train_step  called 11 times. max: 0.147 avg: 0.093
            train_zero_grad  called 11 times. max: 0.008 avg: 0.007
      valid  called 1 times. max: 2.326 avg: 2.326
         valid_batch  called 11 times. max: 0.214 avg: 0.183
            valid_pred  called 11 times. max: 0.212 avg: 0.182
            valid_loss  called 11 times. max: 0.002 avg: 0.001
learner.my_profile.reset
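True
With reset=True, the callback clears its accumulated stats at the start of each fit, which is why the counts above reflect only the second (unfrozen) fit of fine_tune: one fit call and 11 train batches rather than 22.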