Classes to replace LRFinder and patches to Learner

Modifications to the existing LRFinder callback are needed in order to run lr_find across multiple TPU cores. An equivalent xla_lr_find method is patched onto Learner so the LR finder can run on multiple TPU cores.

class SkipValidationCallback[source]

SkipValidationCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

A Callback that skips the validation loop during training
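The mechanics follow fastai's cancel-exception pattern: raising a cancel exception in `before_validate` makes the training loop skip the validation phase entirely. A minimal framework-free sketch (the names `CancelValid`, `SkipValidation`, and `run_epoch` here are illustrative, not the library's actual API):

```python
# Framework-free sketch of the skip-validation idea.

class CancelValid(Exception):
    "Signal that the validation phase should be skipped."

class SkipValidation:
    "Cancel validation before it starts."
    def before_validate(self):
        raise CancelValid()

def run_epoch(callbacks):
    "Tiny stand-in for a train/validate loop that honors the callback."
    phases = ["train"]
    try:
        for cb in callbacks:
            cb.before_validate()
        phases.append("validate")
    except CancelValid:
        pass
    return phases

print(run_epoch([]))                  # ['train', 'validate']
print(run_epoch([SkipValidation()]))  # ['train']
```

Skipping validation matters for the LR finder, which only needs the training losses.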

PerDeviceLoader.close[source]

PerDeviceLoader.close()

Close the data loader queues of the XLA parallel loader

class SyncedCancelCallback[source]

SyncedCancelCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

A Callback to cancel training in sync across all ranks (closing the data loader queues on each rank)
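Why the synchronization is needed: if only one rank stops early, the other ranks block on loader queues that rank will never drain. The usual fix is to reduce the per-rank cancel flags (e.g. with a max) so every rank reaches the same decision. An illustrative sketch, not the library's code:

```python
# Simulate an all-reduce(max) over one cancel flag per rank: if any rank
# wants to cancel, every rank cancels, so all loader queues get closed.

def synced_cancel(local_flags):
    "Return the globally agreed cancel flag for each rank."
    global_flag = max(local_flags)
    return [global_flag] * len(local_flags)

print(synced_cancel([0, 0, 0, 0]))  # [0, 0, 0, 0] -> nobody cancels
print(synced_cancel([0, 1, 0, 0]))  # [1, 1, 1, 1] -> all ranks cancel together
```

In the real multi-core setting the reduction would go through an XLA collective rather than a local `max`.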

class XLALRFinder[source]

XLALRFinder(start_lr=1e-07, end_lr=10, num_it=100, stop_div=True) :: ParamScheduler

Training with exponentially growing learning rate
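The schedule itself is the standard LR-finder one: each step multiplies the learning rate by a constant factor so it moves from `start_lr` to `end_lr` in `num_it` steps. A standalone sketch of the interpolation:

```python
# Exponential LR schedule: interpolate from start_lr to end_lr in log space.

def exp_lr(start_lr, end_lr, pos):
    "Learning rate at relative position pos in [0, 1]."
    return start_lr * (end_lr / start_lr) ** pos

num_it = 100
lrs = [exp_lr(1e-7, 10, i / (num_it - 1)) for i in range(num_it)]
print(lrs[0], lrs[-1])  # roughly 1e-07 and 10
```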

Learner.get_suggested_lrs[source]

Learner.get_suggested_lrs(num_it)

Compute suggested learning rates from the losses recorded by the LR finder
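The two classic suggestions can be sketched standalone (this is a rough approximation, not the library's exact code): `lr_min` is the LR at the minimum loss divided by 10, and `lr_steep` is the LR where the loss falls fastest against `log(lr)`.

```python
import math

def suggested_lrs(lrs, losses):
    "Rough sketch of the lr_min / lr_steep suggestions."
    i_min = min(range(len(losses)), key=losses.__getitem__)
    lr_min = lrs[i_min] / 10
    # slope of loss vs log(lr) between consecutive points
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    i_steep = min(range(len(slopes)), key=slopes.__getitem__)
    return lr_min, lrs[i_steep]

# Synthetic V-shaped loss curve with its minimum at index 10
lrs = [10 ** (i / 10 - 4) for i in range(21)]   # 1e-4 ... 1e-2
losses = [abs(i - 10) + 1 for i in range(21)]
lr_min, lr_steep = suggested_lrs(lrs, losses)
print(lr_min, lr_steep)
```

The real implementation typically works on smoothed losses, so its picks differ slightly from this sketch.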

# import pickle
# from pathlib import Path
# from fastai.learner import Recorder
# from fastcore.basics import patch
# @patch
# def reload_lr_find_attrs(self:Recorder, fn='_plt_loss.pkl'):
#     if isinstance(fn, str):
#         fn = Path(fn)
#     if not fn.is_file():
#         return
#     with open(fn, 'rb') as f:
#         d = pickle.load(f)
#         self.lrs, self.losses = d['lrs'], d['losses']
#         self.values, self.iters = d['values'], d['iters']
#         if 'hps' in d:
#             self.hps = d['hps']
#     # delete the file after reloading
#     if fn.is_file():
#         fn.unlink()

xla_run_lr_find[source]

xla_run_lr_find(rank, learner_args, add_args, lr_find_args, ctrl_args)
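The first argument is the rank because this function is the spawn target: each TPU process runs it with its own rank plus pickled arguments for rebuilding the Learner and running the LR finder. A sketch of that pattern using stdlib multiprocessing in place of torch_xla's `xmp.spawn` (the names `run_lr_find` and `lr_find_args` are illustrative):

```python
import multiprocessing as mp

def run_lr_find(rank, lr_find_args, results):
    # Real code would rebuild the Learner here, run the LR finder, and
    # have rank 0 persist the recorded lrs/losses for the main process.
    results[rank] = {"rank": rank, "num_it": lr_find_args["num_it"]}

if __name__ == "__main__":
    with mp.Manager() as m:
        results = m.dict()
        procs = [mp.Process(target=run_lr_find, args=(r, {"num_it": 100}, results))
                 for r in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        print(dict(results))
```

Because spawned processes cannot return Python objects directly, the real implementation writes the recorded results to disk and the main process reloads them afterwards (see the commented-out `reload_lr_find_attrs` snippet above).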

Learner.xla_lr_find[source]

Learner.xla_lr_find(num_cores=8, start_method='fork', start_lr=1e-07, end_lr=10, num_it=100, stop_div=True, show_plot=True, suggestions=True)

Test out xla_lr_find

from fastai.vision.all import *
path = untar_data(URLs.MNIST_TINY)
# path = untar_data(URLs.MNIST)
data = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(),
    # splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
    item_tfms=Resize(28),
    batch_tfms=[]
)
dls = data.dataloaders(path, bs=8)
# dls = data.dataloaders(path, bs=64)
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth

learner.unfreeze()
%%time
learner.xla_lr_find()
start fit
CPU times: user 121 ms, sys: 106 ms, total: 227 ms
Wall time: 1min 9s
SuggestedLRs(lr_min=7.585775847473997e-08, lr_steep=6.309573450380412e-07)
learner.xla_fit(1)
start fit
epoch train_loss valid_loss accuracy time
0 0.151200 0.250101 0.910511 00:17
suggest_lr = learner.xla_lr_find(stop_div=False)
start fit
print(suggest_lr)
SuggestedLRs(lr_min=7.585775847473997e-08, lr_steep=6.309573450380412e-07)
learner.xla_fit(1)
start fit
epoch train_loss valid_loss accuracy time
0 0.103489 1.654110 0.616477 00:16
learner.unfreeze()
suggest_lr = learner.xla_lr_find(stop_div=False)
start fit
print(suggest_lr)
SuggestedLRs(lr_min=7.585775847473997e-08, lr_steep=6.309573450380412e-07)
learner.xla_fit_one_cycle(3)
start fit
epoch train_loss valid_loss accuracy time
0 0.051903 0.543764 0.752841 00:13
1 0.067877 1.030375 0.690341 00:07
2 0.071983 0.573838 0.833807 00:06
Note that after the XLA methods return, the model and dataloaders in the main process are still on the CPU: the actual training ran in the spawned TPU processes.
one_param(learner.model).device
device(type='cpu')
learner.dls.device
device(type='cpu')
learner.xla_lr_find()
start fit
learner.xla_fit(1, lr=1e-2)
start fit
epoch train_loss valid_loss accuracy time
0 0.417607 77015.023438 0.498580 00:17