Learner method patches that invoke multi-core fit and other operations, all prefixed with xla_.

These provide an alternate way to run multi-core operations with minimal changes to existing fastai notebooks.

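For example, where a single-core fastai notebook calls fit_one_cycle, the multi-core equivalent is xla_fit_one_cycle. A minimal sketch of the naming pattern, assuming a Learner named learner has already been built:

# single-core fastai call
learner.fit_one_cycle(5, lr_max=2e-3)
# multi-core TPU equivalent patched in by fastai_xla_extensions.multi_core.base
learner.xla_fit_one_cycle(5, lr_max=2e-3, num_cores=8)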
Add master_cbs property to Learner

Master callbacks are callbacks that will be executed on the master ordinal (rank 0 thread) only.

This means existing fastai notebooks must be checked to see whether any additional callbacks they use could cause conflicts when run on different threads at the same time.

Note that of the default callbacks (TrainEvalCallback, Recorder, ProgressCallback), only ProgressCallback causes this problem.

However, the fastai_xla_extensions.multi_core.base module already handles this: when it is used (which it is, by default), the ProgressCallback is attached only on the master ordinal thread.

Moreover, the Recorder callback is also handled: the fastai_xla_extensions.multi_core.base.SyncRecorderCallback collates validation losses and metrics across threads so that they are reported correctly at the end of each epoch.

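A short usage sketch, assuming a Learner named learner has been built: callbacks registered through add_master_cbs are attached only on the rank 0 process, so checkpoints and per-epoch reporting are not duplicated across TPU cores.

from fastai.callback.tracker import SaveModelCallback
# master callbacks run only on the master ordinal (rank 0) thread
learner.add_master_cbs([SaveModelCallback(fname='best_model')])
learner.master_cbs  # list the callbacks currently registered to run on rank 0 only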
Learner.master_cbs[source]

list all cbs to be run on the master ordinal thread

Learner.add_master_cbs[source]

Learner.add_master_cbs(cbs)

add master callbacks

Learner.add_master_cb[source]

Learner.add_master_cb(cb)

add a master callback

Learner.remove_master_cb[source]

Learner.remove_master_cb(cb)

remove a cb from master callbacks

Learner.remove_master_cbs[source]

Learner.remove_master_cbs(cbs)

remove callbacks from master callbacks

Learner.grab_master_cbs[source]

Learner.grab_master_cbs(cb_cls)

find instance of cb_cls in master_cbs

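A hedged usage sketch of the callback management helpers above, assuming a Learner named learner:

from fastai.callback.tracker import SaveModelCallback
cb = SaveModelCallback(fname='best_model')
learner.add_master_cb(cb)                   # register a single master callback
learner.grab_master_cbs(SaveModelCallback)  # find the instance of a callback class among master_cbs
learner.remove_master_cb(cb)                # unregister it again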
Utility methods to implement XLA fit methods

make_xla_child_learner[source]

make_xla_child_learner(rank, sync_valid, learner_args, add_args, ctrl_args)

create a learner using passed parameters

setup_fit_cbs[source]

setup_fit_cbs(rank, fit_args)

add master cbs to cbs fit args if rank 0

xla_run_method[source]

xla_run_method(rank, fit_method, learner_args, add_args, fit_args, ctrl_args)

run fit method on spawned process

Learner.pack_learner_args[source]

Learner.pack_learner_args()

pack learner args into dict to pass to spawned process

Learner.reload_child_model[source]

Learner.reload_child_model()

reload model built by spawned processes

Learner.delete_tmp_files[source]

Learner.delete_tmp_files()

remove files created by spawned process prior to potentially recreating them

Learner.pre_xla_fit[source]

Learner.pre_xla_fit(ctrl_args={})

prepare learner for running spawned processes

Learner.post_xla_fit[source]

Learner.post_xla_fit(ctrl_args)

clean up learner after running spawned processes

prep_fit_args[source]

prep_fit_args(n_epoch, master_cbs, **kwargs)

prepare fit method args for running spawned processes

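These helpers are mostly internal. The outline below is a simplified sketch of how an xla_fit* call wires them together, not the actual implementation: the return values of pack_learner_args and prep_fit_args, and the exact arguments passed to xmp.spawn, are assumptions marked in the comments.

import torch_xla.distributed.xla_multiprocessing as xmp
from fastai.learner import Learner

ctrl_args = {}
learner.pre_xla_fit(ctrl_args)                        # prepare the learner before spawning processes
learner_args, add_args = learner.pack_learner_args()  # assumption: returns the packed learner args
fit_args = prep_fit_args(n_epoch=5, master_cbs=None)  # assumption: returns the packed fit args
# spawn one process per TPU core; xmp.spawn passes the rank as the first argument to xla_run_method
xmp.spawn(xla_run_method,
          args=(Learner.fit, learner_args, add_args, fit_args, ctrl_args),  # assumption: base fit method plus packed args
          nprocs=8, start_method='fork')
learner.reload_child_model()                          # reload the model built by the spawned processes
learner.post_xla_fit(ctrl_args)                       # clean up files created for the run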
XLA fit methods

Learner.xla_fit[source]

Learner.xla_fit(n_epoch, num_cores=8, start_method='fork', master_cbs=None, lr=None, wd=None, reset_opt=False)

call fit in a multicore tpu environment

Learner.xla_fit_one_cycle[source]

Learner.xla_fit_one_cycle(n_epoch, num_cores=8, start_method='fork', master_cbs=None, lr_max=None, div=25.0, div_final=100000.0, pct_start=0.25, wd=None, moms=None, reset_opt=False)

call fit_one_cycle in a multicore tpu environment

Learner.xla_fit_flat_cos[source]

Learner.xla_fit_flat_cos(n_epoch, num_cores=8, start_method='fork', master_cbs=None, lr=None, div_final=100000.0, pct_start=0.75, wd=None, reset_opt=False)

call fit_flat_cos in a multicore tpu environment

prep_fit_sgdr_args[source]

prep_fit_sgdr_args(n_cycles, cycle_len, master_cbs, **kwargs)

prepare fit_sgdr method args for running spawned processes

Learner.xla_fit_sgdr[source]

Learner.xla_fit_sgdr(n_cycles, cycle_len, num_cores=8, start_method='fork', master_cbs=None, lr_max=None, cycle_mult=2, reset_opt=False, wd=None)

call fit_sgdr in multicore tpu environment

prep_finetune_args[source]

prep_finetune_args(epochs, master_cbs, **kwargs)

prepare finetune method args for running spawned processes

Learner.xla_fine_tune[source]

Learner.xla_fine_tune(epochs, num_cores=8, start_method='fork', master_cbs=None, base_lr=0.002, freeze_epochs=1, lr_mult=100, pct_start=0.3, div=5.0, lr_max=None, div_final=100000.0, wd=None, moms=None, reset_opt=False)

call fine_tune in multicore tpu environment

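A short usage sketch for the two variants not demonstrated in the example below, assuming a Learner named learner, an 8-core TPU runtime, and that SaveModelCallback has been imported:

# SGDR: 3 cycles, the first 1 epoch long, each subsequent cycle twice as long (cycle_mult=2 by default)
learner.xla_fit_sgdr(n_cycles=3, cycle_len=1, lr_max=2e-3)
# fine_tune: 1 frozen epoch followed by 3 unfrozen epochs, checkpointing on the master ordinal only
learner.xla_fine_tune(3, freeze_epochs=1, base_lr=2e-3,
                      master_cbs=[SaveModelCallback(fname='best_model')])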
Example: Train MNIST

from fastai.vision.all import *  # provides untar_data, URLs, DataBlock, image blocks, transforms and metrics

path = untar_data(URLs.MNIST_TINY)
data = DataBlock(
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(),
    item_tfms=Resize(28),
    batch_tfms=[]
)
dls = data.dataloaders(path, bs=16)
# concat_pool must be false due to a TPU bug that is triggered if using fastai AdaptivePool
from fastai.vision.learner import cnn_learner
from torchvision.models.resnet import resnet18
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth

learner.add_master_cbs([SaveModelCallback(fname='best_model')])
class PrintValuesCallback(Callback):
    order = 56 # after recorder, sync recorder, before save model callback  
    def after_epoch(self):
        print(f'final record: {self.learn.final_record}')
        vlen = len(self.recorder.values)
        print(f'values len: {vlen}')
        if vlen > 0:   
            last_idx = self.recorder.values[-1]  
            len_last_idx = len(last_idx)
            print(f'values last idx len: {len_last_idx}')
            print(f'last idx: {last_idx}')
            if 'save_model' in L(self.cbs).attrgot('name'):
                save_model_idx = self.save_model.idx
                print(f'save_model idx: {save_model_idx}')     
                if save_model_idx < len_last_idx:
                    val = self.recorder.values[-1][self.save_model.idx]
                    print(f'best_value: {val}')
        if 'sync_recorder' in L(self.cbs).attrgot('name'):
            sync_log = self.sync_recorder.sync_log
            len_sync_log = len(sync_log)
            print(f'sync rec sync_log len: {len_sync_log}')
            print(f'sync rec sync_log: {sync_log}')

            if len_sync_log > 0:
                print(f'sync rec sync_log[1:]: {sync_log[1:]}')
# cbs = [PrintValuesCallback(), SaveModelCallback(fname='best_model')]
cbs = [PrintValuesCallback()]
learner.xla_fit_one_cycle(5,lr_max=slice(2e-3))
start fit
epoch train_loss valid_loss accuracy time
0 0.230687 0.693944 0.602273 00:21
1 0.201443 0.468439 0.839489 00:03
2 0.215106 0.533784 0.754261 00:03
3 0.234269 0.617997 0.671875 00:04
4 0.256085 0.771770 0.555398 00:04
Better model found at epoch 0 with valid_loss value: 0.6939442753791809.
Better model found at epoch 1 with valid_loss value: 0.4684391915798187.
res = learner.get_preds()
print(accuracy(*res))
TensorBase(0.8426)
learner.load('best_model')
/usr/local/lib/python3.7/dist-packages/fastai/learner.py:56: UserWarning: Saved filed doesn't contain an optimizer state.
  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")
<fastai.learner.Learner at 0x7f5e8fa13810>
res = learner.get_preds()
print(accuracy(*res))
TensorBase(0.8426)
learner.unfreeze()
one_param(learner.model).device
device(type='cpu')
learner.xla_fit(n_epoch=5, lr=2e-3)
start fit
epoch train_loss valid_loss accuracy time
0 0.079039 0.057682 0.981534 00:20
1 0.070004 0.602258 0.867898 00:03
2 0.078288 1.280675 0.713068 00:04
3 0.080101 1.281230 0.857955 00:04
4 0.081996 0.056090 0.984375 00:04
Better model found at epoch 0 with valid_loss value: 0.05768230929970741.
Better model found at epoch 4 with valid_loss value: 0.056089673191308975.
learner.validate()
(#2) [0.06356370449066162,0.9856938719749451]

Train using torch datasets and dataloaders

import torch
from torch import nn
from pathlib import Path
from fastai.data.core import DataLoaders

FLAGS = {}
FLAGS['batch_size']  = 64
FLAGS['num_workers'] = 4
FLAGS['data_dir'] = Path('/content/data/cifar')
from torchvision import datasets, transforms
def get_dataset():
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        norm,
    ])
    train_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'],
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'],
        train=False,
        download=True,
        transform=transform_test)
    
    return train_dataset, test_dataset
train_dataset, test_dataset = get_dataset()
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar/cifar-10-python.tar.gz
Extracting /content/data/cifar/cifar-10-python.tar.gz to /content/data/cifar
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=FLAGS['batch_size'],
#   sampler=train_sampler,
    shuffle=True,
    num_workers=FLAGS['num_workers'],
    drop_last=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS['batch_size'],
    shuffle=False,
    num_workers=FLAGS['num_workers'],
    drop_last=True)
# fastai dls using torch dataloaders
dls = DataLoaders(train_loader, test_loader)
learner = cnn_learner(dls, resnet18, metrics=accuracy, 
                      n_out=10, 
                      loss_func=nn.CrossEntropyLoss(),
                      concat_pool=False 
                      )

learner.xla_fit(5,lr=2e-2)
start fit
epoch train_loss valid_loss accuracy time
0 1.280729 1.212466 0.571477 01:28
1 1.139668 1.096565 0.610852 01:16
2 1.053340 1.212897 0.596054 01:17
3 1.004425 1.009118 0.655374 01:17
4 0.979024 0.906682 0.682767 01:13