Single TPU Core Extensions
Use this module if running on a single TPU Core as the main process.
from fastai_xla_extensions.core import *
XLAOptimProxy is a class that overrides the step method to call the PyTorch XLA function xm.optimizer_step, which synchronizes the XLA graph. All other calls to XLAOptimProxy are simply forwarded to the internal self.opt instance.
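Conceptually, such a proxy can be sketched as follows. This is a minimal, hypothetical sketch (the class name OptimProxySketch is made up for illustration and this is not the library's actual implementation); it assumes torch_xla is installed.

import torch_xla.core.xla_model as xm

class OptimProxySketch:
    "Wrap an optimizer so that step() goes through xm.optimizer_step."
    def __init__(self, opt, barrier=True):
        self.opt = opt
        self.barrier = barrier
    def step(self):
        # xm.optimizer_step runs the wrapped optimizer's step and marks the XLA step,
        # which compiles and executes the lazily built graph
        xm.optimizer_step(self.opt, barrier=self.barrier)
    def __getattr__(self, name):
        # every other attribute or method lookup is forwarded to the real optimizer
        return getattr(self.opt, name)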
DeviceMoverTransform is a simple transform that moves the batch input from the CPU to the XLA device.
This is in lieu of the normal DataLoader mechanism, where dls.device is set to the XLA device before any of the batch transforms in the dataloaders are run.
Unfortunately, the AffineCoordTfm used for data augmentation (all the batch Zoom, Warp and Rotate augmentations) causes a problem when run on the TPU: some affine operations are not currently implemented in PyTorch XLA, which triggers a lowering of the XLA tensors to the CPU to perform the affine operation and causes a massive slowdown, even much slower than just doing the affine transform on the CPU in the first place.
The solution is to postpone moving the input batch to the TPU until after the affine transformations, by setting dls.device to None, which is done in the before_fit method of the XLAOptCallback.
These functions modify the batch transforms pipeline to add a device mover transform that moves the batch input sample to the TPU. Since the automatic device move has been disabled (by setting dls.device to None), all batch transforms prior to the device mover transform are executed on the CPU by default.
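A minimal sketch of what such a device mover transform could look like (hypothetical name, not the library's actual implementation; it assumes torch_xla is installed):

import torch
import torch_xla.core.xla_model as xm
from fastai.vision.all import Transform

class DeviceMoverSketch(Transform):
    "Move batch tensors to the XLA device partway through the after_batch pipeline."
    order = 100  # assumed ordering, so it runs after the CPU-only affine transforms
    def __init__(self, device=None):
        self.device = device if device is not None else xm.xla_device()
    def encodes(self, x: torch.Tensor):
        return x.to(self.device)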
These internal functions just check the types of Transform in the after_batch pipeline.
This inserts a batch transform for a dataloader at the index location idx.
This will add a device mover transform to the batch transforms if any of them trigger a lowering from the TPU to the CPU. Currently identified transforms that cause this are the AffineCoordTfm and RandomResizedCropGPU transforms.
If none of these transforms are present, dls.device is set to the XLA device so that when TrainEvalCallback.before_fit is called, the model is also moved to the TPU.
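For illustration, such a check could look like the following (a hypothetical helper, not the library's actual function):

from fastai.vision.all import AffineCoordTfm, RandomResizedCropGPU

def has_lowering_tfms_sketch(dl):
    "Return True if the after_batch pipeline contains a transform known to trigger a lowering."
    return any(isinstance(tfm, (AffineCoordTfm, RandomResizedCropGPU)) for tfm in dl.after_batch.fs)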
This callback replaces the learner's opt with an instance of XLAOptimProxy that proxies the original opt at the start of the fit method, and restores the original opt after the fit.
It also sets the dataloaders.device and the learn.model to use a TPU core, using the device returned by the xm.xla_device() method.
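A rough sketch of such a callback, building on the OptimProxySketch above (hypothetical code, not the library's actual implementation):

from fastai.vision.all import Callback

class XLAOptCallbackSketch(Callback):
    "Swap the learner's optimizer for a proxy during fit and restore it afterwards."
    def __init__(self, barrier=True):
        self.barrier = barrier
    def before_fit(self):
        if not isinstance(self.learn.opt, OptimProxySketch):
            self.learn.opt = OptimProxySketch(self.learn.opt, barrier=self.barrier)
    def after_fit(self):
        if isinstance(self.learn.opt, OptimProxySketch):
            self.learn.opt = self.learn.opt.opt  # put back the original optimizer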
Learner.to_xla makes sure the model and dataloaders have been moved to the XLA device prior to creating the optimizer: it sets opt to None, which forces a call to create_opt in the fit methods after the model has already been moved to the TPU device by this method.
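Ignoring the device mover handling described above, the setup could be sketched roughly as follows (hypothetical code building on the earlier sketches; the real Learner.to_xla also decides between inserting a DeviceMoverTransform and setting dls.device):

from fastai.vision.all import Learner
from fastcore.basics import patch
import torch_xla.core.xla_model as xm

@patch
def to_xla_sketch(self: Learner, device=None):
    "Hypothetical sketch: move the model to a TPU core and defer optimizer creation."
    device = device if device is not None else xm.xla_device()
    self.model.to(device)                # move the model before the optimizer exists
    self.opt = None                      # force create_opt to run inside fit, after the move
    self.add_cb(XLAOptCallbackSketch())  # route opt.step through xm.optimizer_step
    return self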
from fastai.vision.all import *
from fastai.callback.training import GradientAccumulation
Also, import the fastai_xla_extensions.core package as follows:
from fastai_xla_extensions.core import *
Load data
path = untar_data(URLs.MNIST_TINY)
Path.BASE_PATH = path
Create datablock
datablock = DataBlock(
blocks=(ImageBlock,CategoryBlock),
get_items=get_image_files,
get_y=parent_label,
splitter=GrandparentSplitter(),
item_tfms=Resize(28),
batch_tfms=aug_transforms(do_flip=False, min_scale=0.8) # trigger usage of RandomResizedCropGPU
# batch_tfms=[]
)
Set the dataloaders to load the batches on the CPU
dls = datablock.dataloaders(path)
dls.device
Note that at this point, the dataloaders are still using the CPU device.
dls.train.after_batch.fs
Also, note that using the batch transforms (batch_tfms) adds the Warp transform (actually any zoom, warp, resize or rotate transform, all of which are subclasses of AffineCoordTfm) and the RandomResizedCropGPU transform, which unfortunately aren't supported by XLA yet and trigger a "lowering" (i.e. the input tensor is moved back to the CPU, the operation is performed there, and the result is moved back to the TPU afterwards), which causes a massive performance slowdown.
As a workaround, these batch transform operations are done on the CPU first.
See the discussion later on the DeviceMoverTransform, which is added if any of these transforms are present.
dls.show_batch()
Any operations such as show_batch are run on the CPU, and the fastai_xla_extensions module does not modify any base fastai functionality.
Create the Learner
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)
learner.dls.train.after_batch.fs
Again, this shows the same list of transforms as the one previously listed, with the addition of the Normalize transform, which is added by cnn_learner (default parameter normalize=True) using the normalization statistics of the dataset the model was pre-trained on.
learner.to_xla(xm.xla_device());
Calling the to_xla() method sets the Learner up for TPU model training.
This includes adding the XLAOptCallback, which makes the XLA method xm.optimizer_step be called instead of the opt.step method, as well as moving the model to the TPU.
It also either adds a DeviceMoverTransform to the batch transforms pipeline (which moves the input batch sample to the TPU after performing any batch transforms that are AffineCoordTfm-based, such as zoom, warp, rotate and resize, or a RandomResizedCropGPU transform), or sets the DataLoaders device to the TPU if no batch transforms that trigger a lowering are present in the data loading pipeline.
learner.dls.device is None
learner.dls.train.after_batch.fs
learner.dls.valid.after_batch.fs
The learner object should have an xla_opt attribute, which confirms that XLAOptCallback has been added to the list of callbacks for this learner.
learner.xla_opt
learner.xla_opt.barrier
learner.dls.device is None
learner.opt is None
one_param(learner.model).device
has_affinecoord_tfm(learner.dls)
has_devicemover_tfm(learner.dls.train)
learner.show_training_loop()
learner.opt is None
learner.dls.device
class CheckXLADeviceCallback(Callback):
    "Print the devices used by the dataloaders, model and optimizer during training."
    def before_fit(self):
        print(f'dls device: {self.dls.device} model device: {one_param(self.learn.model).device}')
        if self.learn.opt is not None:
            param = first(self.learn.opt.all_params())[0]
            print(f'opt param device: {param.device}')
    def before_epoch(self):
        print(f'dls device: {self.dls.device} model device: {one_param(self.learn.model).device}')
This CheckXLADeviceCallback will check which devices are used by the dataloaders and model during training.
Run fine_tune to train the model.
learner.fine_tune(6,freeze_epochs=4, cbs=CheckXLADeviceCallback())
learner.dls.train.after_batch.fs
learner.detach_xla()
learner.dls.device is None
learner.opt is None
one_param(learner.model).device
learner.save('stage-1')
learner = cnn_learner(dls,resnet18, metrics=accuracy, concat_pool=False)
learner.load('stage-1')
learner.to_xla()
learner.dls.device is None
learner.fit_flat_cos(1,cbs=CheckXLADeviceCallback())
one_param(learner.model).device
learner.save('stage-2')
learner.load('stage-2')
learner.fine_tune(6, freeze_epochs=4, cbs=CheckXLADeviceCallback())
learner.dls.train.after_batch.fs
learner.save('stage-4')
learner.lr_find()
learner.fit_one_cycle(5, lr_max=slice(2e-2))
learner.save('stage-5')
The GradientAccumulation callback (which raises CancelBatchException) should still work.
An alternative design for the XLA Opt Callback, raising CancelBatchException in the after_backward method (after executing xm.optimizer_step and opt.zero_grad), would interfere with the GradientAccumulation callback (which raises CancelBatchException in the after_backward method to skip the gradient updates in order to accumulate the gradients).
The current design (add/remove the XLAOptimProxy during the before_fit and after_fit callback lifecycle methods) is less disruptive and more compatible with other callbacks.
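For illustration only, the rejected alternative would have looked roughly like this (hypothetical code, not part of the library):

from fastai.vision.all import Callback, CancelBatchException
import torch_xla.core.xla_model as xm

class AltXLAOptCallbackSketch(Callback):
    "Hypothetical sketch of the rejected design: step the optimizer in after_backward."
    def after_backward(self):
        xm.optimizer_step(self.learn.opt)  # step the optimizer and sync the XLA graph here
        self.learn.opt.zero_grad()
        # Cancelling the batch here skips the training loop's own step/zero_grad, but it also
        # competes with GradientAccumulation, which relies on raising CancelBatchException in
        # after_backward to skip the update while gradients accumulate.
        raise CancelBatchException()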
learner.fit_one_cycle(4,cbs=[GradientAccumulation(n_acc=2),])
The validation loss has more or less plateaued, so this looks OK.
learner.recorder.plot_loss()
Plot moms and lr across batches/epochs
learner.recorder.plot_sched()
Get Classification Interpretation for more details on model performance
interp = ClassificationInterpretation.from_learner(learner)
Plot confusion matrix
interp.plot_confusion_matrix()
Samples where the model was most confused
interp.plot_top_losses(12)
End of Notebook