Classes to replace LRFinder and patches to Learner
Modifications to the existing `LRFinder` callback
are needed in order to run `lr_find`
on multiple TPU cores. An equivalent `xla_lr_find`
method is patched onto `Learner`
so that it can run on multiple TPU cores.
# Disabled helper, kept for reference: a Recorder patch that reloads lr_find
# results (lrs, losses, values, iters and optional hps) pickled to `fn`, then
# deletes the file. If re-enabled, `Path` (pathlib) must also be in scope.
# import pickle
# from fastai.learner import Recorder
# from fastcore.basics import patch
# @patch
# def reload_lr_find_attrs(self:Recorder, fn='_plt_loss.pkl'):
#     if isinstance(fn,str):
#         fn = Path(fn)
#     if not fn.is_file():
#         return
#     with open(fn,'rb') as f:
#         d = pickle.load(f)
#     self.lrs,self.losses = d['lrs'],d['losses']
#     self.values, self.iters = d['values'], d['iters']
#     if 'hps' in d:
#         self.hps = d['hps']
#     # delete file after
#     if fn.is_file():
#         fn.unlink()
from fastai.vision.all import *
# Dataset toggle. MNIST_TINY keeps the TPU demo quick; set this to True to run
# on the full MNIST dataset instead. The dataset path, the split-folder names
# and the batch size must change together (full MNIST keeps its splits in
# 'training'/'testing' folders rather than GrandparentSplitter's defaults), so
# a single flag drives all three instead of three commented-out alternatives.
USE_FULL_MNIST = False

path = untar_data(URLs.MNIST if USE_FULL_MNIST else URLs.MNIST_TINY)

data = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=(
        GrandparentSplitter(train_name='training', valid_name='testing')
        if USE_FULL_MNIST
        else GrandparentSplitter()
    ),
    item_tfms=Resize(28),
    batch_tfms=[],
)

dls = data.dataloaders(path, bs=64 if USE_FULL_MNIST else 8)

learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)
# Make every layer group trainable before the XLA lr_find/fit runs below
# (cnn_learner presumably starts with the pretrained body frozen — confirm).
learner.unfreeze()
%%time
# IPython cell magic above: reports the wall/CPU time of this cell.
# Run the TPU-aware learning-rate finder; xla_lr_find is patched onto Learner
# elsewhere (not visible in this chunk).
learner.xla_lr_find()
# One epoch of training via the patched XLA fit (presumably the multi-core
# counterpart of Learner.fit — confirm).
learner.xla_fit(1)
# Re-run the finder with stop_div=False (presumably mirrors fastai's lr_find
# flag: don't stop the sweep early when the loss diverges — confirm) and show
# whatever it returns, expected to be the suggested learning rate.
suggest_lr = learner.xla_lr_find(stop_div=False)
print(suggest_lr)
learner.xla_fit(1)
# NOTE(review): the model was already unfrozen during setup and nothing visible
# here refreezes it, so this call looks redundant (but harmless).
learner.unfreeze()
suggest_lr = learner.xla_lr_find(stop_div=False)
print(suggest_lr)
# Three epochs with the one-cycle schedule, XLA variant.
learner.xla_fit_one_cycle(3)
# Notebook inspection only: these bare expressions display the device holding
# the model parameters and the device of the DataLoaders; in a plain script
# their values are discarded.
one_param(learner.model).device
learner.dls.device
learner.xla_lr_find()
# Fit one epoch with an explicitly chosen learning rate.
learner.xla_fit(1, lr=1e-2)