Learner method patches to invoke multi-core fit and other operations, all prefixed with xla_.
These provide an alternate way to run multi-core operations with minimal changes to existing fastai notebooks.
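For example, an existing notebook keeps its DataBlock and Learner setup and only swaps the fit call. A minimal sketch of the intended usage (the exact import that activates the patches is an assumption here; the multi_core.base module is discussed below):
import fastai_xla_extensions.multi_core.base   # assumed to attach the xla_* methods to Learner
# learner.fit_one_cycle(5, lr_max=2e-3)        # original single-device fastai call
# learner.xla_fit_one_cycle(5, lr_max=2e-3)    # multi-core TPU equivalent added by the patches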
Add master_cbs property to Learner
Master callbacks are callbacks that will be executed on the master ordinal (rank 0 thread) only.
This means existing fastai notebooks must be reviewed to check whether any additional callbacks they use could conflict when run on different threads at the same time.
Note that among the default callbacks (TrainEvalCallback, Recorder, ProgressCallback), only ProgressCallback causes this problem.
However, the fastai_xla_extensions.multi_core.base module already handles
this: when it is used (as it is by default), the ProgressCallback is attached only on the master ordinal thread.
Moreover, the Recorder callback is also handled: the fastai_xla_extensions.multi_core.base.SyncRecorderCallback collates validation losses and metrics across threads so that they are reported correctly at the end of each epoch.
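To illustrate the master-ordinal idea, here is a minimal sketch (not the package's implementation) of a callback that only does work on the rank 0 thread, assuming torch_xla's xla_model helper is available:
import torch_xla.core.xla_model as xm
from fastai.callback.core import Callback

class MasterOnlyPrint(Callback):
    "Hypothetical: prints only on the master ordinal (rank 0 thread)"
    def after_epoch(self):
        if xm.is_master_ordinal():  # True only for the rank 0 process/thread
            print(f'epoch {self.epoch} done (reported once, from the master ordinal)')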
from fastai.vision.all import *  # provides untar_data, DataBlock, cnn_learner, SaveModelCallback, etc.

path = untar_data(URLs.MNIST_TINY)
data = DataBlock(
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(),
    item_tfms=Resize(28),
    batch_tfms=[]
)
dls = data.dataloaders(path, bs=16)
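Optionally, the pipeline can be sanity-checked before the multi-core run with a standard fastai call (not TPU-specific):
dls.show_batch(max_n=9, figsize=(4, 4))  # preview a batch to verify labels and transforms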
# concat_pool must be False due to a TPU bug triggered by the fastai AdaptivePool layer
from fastai.vision.learner import cnn_learner
from torchvision.models.resnet import resnet18
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)
learner.add_master_cbs([SaveModelCallback(fname='best_model')])
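Registering SaveModelCallback as a master callback means only the rank 0 thread writes the 'best_model' checkpoint, avoiding the cross-thread conflicts described above.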
class PrintValuesCallback(Callback):
    order = 56 # after recorder, sync recorder, before save model callback  
    def after_epoch(self):
        print(f'final record: {self.learn.final_record}')
        vlen = len(self.recorder.values)
        print(f'values len: {vlen}')
        if vlen > 0:
            last_values = self.recorder.values[-1]  # metric values recorded for the last epoch
            len_last_values = len(last_values)
            print(f'values last row len: {len_last_values}')
            print(f'last row: {last_values}')
            if 'save_model' in L(self.cbs).attrgot('name'):
                save_model_idx = self.save_model.idx
                print(f'save_model idx: {save_model_idx}')
                if save_model_idx < len_last_values:
                    val = last_values[save_model_idx]
                    print(f'best_value: {val}')
        if 'sync_recorder' in L(self.cbs).attrgot('name'):
            sync_log = self.sync_recorder.sync_log
            len_sync_log = len(sync_log)
            print(f'sync rec sync_log len: {len_sync_log}')
            print(f'sync rec sync_log: {sync_log}')
            if len_sync_log > 0:
                print(f'sync rec sync_log[1:]: {sync_log[1:]}')
# cbs = [PrintValuesCallback(), SaveModelCallback(fname='best_model')]
cbs = [PrintValuesCallback()]
learner.xla_fit_one_cycle(5, lr_max=slice(2e-3), cbs=cbs)
res = learner.get_preds()
print(accuracy(*res))
learner.load('best_model')
res = learner.get_preds()
print(accuracy(*res))
learner.unfreeze()
one_param(learner.model).device
learner.xla_fit(n_epoch=5, lr=2e-3)
learner.validate()
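The xla_ methods also work with fastai DataLoaders built directly from plain PyTorch dataloaders, as in the following CIFAR-10 example.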
import torch
from pathlib import Path
FLAGS = {}
FLAGS['batch_size']  = 64
FLAGS['num_workers'] = 4
FLAGS['data_dir'] = Path('/content/data/cifar')
from torchvision import datasets, transforms
def get_dataset():
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        norm,
    ])
    train_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'],
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'],
        train=False,
        download=True,
        transform=transform_test)
    
    return train_dataset, test_dataset
train_dataset, test_dataset = get_dataset()
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=FLAGS['batch_size'],
#   sampler=train_sampler,
    shuffle=True,
    num_workers=FLAGS['num_workers'],
    drop_last=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS['batch_size'],
    shuffle=False,
    num_workers=FLAGS['num_workers'],
    drop_last=True)
# build fastai DataLoaders directly from the plain torch DataLoader objects
dls = DataLoaders(train_loader, test_loader)
learner = cnn_learner(dls, resnet18, metrics=accuracy,
                      n_out=10,                         # number of classes; not inferable from plain torch dataloaders
                      loss_func=nn.CrossEntropyLoss(),  # must also be set explicitly for the same reason
                      concat_pool=False)                # same TPU AdaptivePool workaround as above
learner.xla_fit(5, lr=2e-2)
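As in the MNIST example above, predictions can then be pulled back for a quick accuracy check:
res = learner.get_preds()
print(accuracy(*res))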