Base module for the multi-core TPU implementation

Multi-core TPU support is enabled by importing this module.

from fastai_xla_extensions.multi_core.base import *

Utility Conversion Functions

Functions for converting tensors and computing batches across ranks

revert_tensor[source]

revert_tensor(o)

Remove tensor subclass and revert to torch.Tensor

recast2tensor[source]

recast2tensor(o)

Recast fastai.torch_core.TensorBase subclassed tensors to torch.Tensors
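
A quick illustration of the documented behavior of these two helpers (a sketch; the assertions mirror the descriptions above rather than test the exact implementation):

import torch
from fastai.torch_core import TensorBase, TensorImage

x = TensorImage(torch.ones(3))
plain = revert_tensor(x)
assert not isinstance(plain, TensorBase)   # the subclass is stripped...
assert isinstance(plain, torch.Tensor)     # ...leaving a plain torch.Tensor
assert torch.equal(plain, torch.ones(3))   # values are unchanged

y = recast2tensor(x)                       # same idea for TensorBase subclasses
assert isinstance(y, torch.Tensor) and not isinstance(y, TensorBase)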

round_to_multiple[source]

round_to_multiple(number, multiple)

round a number up to the nearest multiple, used to pad batch sample counts so they fill all cores
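
For example, if a batch of 50 samples must be spread over 8 cores, the count is padded up to the next multiple of 8. A minimal sketch of the idea (the actual implementation may differ):

import math

def round_to_multiple_sketch(number, multiple):
    # round `number` up to the nearest multiple of `multiple`
    return math.ceil(number / multiple) * multiple

assert round_to_multiple_sketch(50, 8) == 56
assert round_to_multiple_sketch(48, 8) == 48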

class TPUDistributedDL[source]

TPUDistributedDL(dl, rank, world_size, seed=42) :: TfmdDL

A TfmdDL which splits a batch into equal-sized pieces for each TPU core. It also recasts the output of a batch from a TensorBase subclass to a regular torch.Tensor, since the XLA parallel loader does not seem to be compatible with tensor subclasses. The implementation is based on @tmabraham's TPUDistributedDL here: https://github.com/tmabraham/fastai_tpu/blob/master/fastai_v2/tpu_distributed_dl.py

from fastai.torch_core import TensorBase, TensorImage, TensorCategory
from fastai.data.core import TfmdDL

n_batches = 10
bs = 6
world_size = 8
# set up a dataloader to use as the base dl for the TPU distributed dl
items = [(TensorImage(torch.tensor(i).float()), TensorCategory(i)) for i in range(n_batches * bs * world_size)]
dl = TfmdDL(items, bs=bs, shuffle=True)
assert len(dl) == n_batches * world_size
b0 = next(iter(dl))
assert isinstance(b0[0], TensorImage)
assert isinstance(b0[1],TensorCategory)
tpu_dl = TPUDistributedDL(dl, rank=0, world_size=world_size)
# the batches of the base dl are divided equally across all ranks
assert len(tpu_dl) == n_batches
tpu_b0 = next(iter(tpu_dl))
# the types of each batch (x,y) have been reverted to torch tensors
# and are no longer Tensor subclasses (e.g. TensorBase)
assert isinstance(tpu_b0[0], torch.Tensor)
assert isinstance(tpu_b0[1], torch.Tensor)
assert not isinstance(tpu_b0[0], TensorBase)
assert not isinstance(tpu_b0[1], TensorBase)
# add tests to make sure all items are retrieved per epoch
# create dl for each rank across all ranks
tpu_dls = [TPUDistributedDL(dl, rank=rank, world_size=world_size) for rank in range(world_size)]
rank_batches = [list(tpu_dl) for tpu_dl in tpu_dls]
# TODO: check that the ranks don't contain common items
# TODO: check that all items in dl are accounted for across the tpu_dls of all ranks
# (one way to implement both checks is sketched below)
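
Here is one way the two TODO checks could be written. This is a sketch: it assumes the category label uniquely identifies each item (which holds for the toy items above) and that no padding duplicates occur, since the item count divides evenly by world_size.

# labels seen by each rank during one epoch
seen_per_rank = [set(int(y) for _, yb in batches for y in yb) for batches in rank_batches]
# 1. no two ranks should see the same item
for i in range(world_size):
    for j in range(i + 1, world_size):
        assert not (seen_per_rank[i] & seen_per_rank[j]), f'ranks {i} and {j} share items'
# 2. every item of the base dl should be seen by some rank
assert set().union(*seen_per_rank) == set(range(n_batches * bs * world_size))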

Torch Dataloader Patches

These patches make torch dataloaders more compatible with fastai dataloaders, enabling them to run inside fastai Learners and work with fastai training calls.

DataLoader.__setattr__[source]

DataLoader.__setattr__(attr, val)

remove sampler and batch_sampler from the list of attrs which should not be set after init

DataLoader.after_batch[source]

return empty pipeline when fastai learner looks for after_batch

DataLoader.bs[source]

return fastai synonym for torch batch size

DataLoader.device[source]

return null device

DataLoader.to[source]

DataLoader.to(device)

add impl for to(device)

DataLoader.set_distributed_sampler[source]

DataLoader.set_distributed_sampler(rank, world_size)

replace sampler with torch distributed sampler
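
Taken together, these patches are meant to let a plain torch DataLoader behave enough like a fastai one to run inside a Learner. A minimal usage sketch (the commented values reflect the descriptions above; treat them as assumptions rather than guarantees):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
dl = DataLoader(ds, batch_size=8, num_workers=0)

print(dl.bs)           # 8 -- the fastai synonym for batch_size
print(dl.device)       # None -- the "null device"
print(dl.after_batch)  # an empty pipeline
# swap in a torch DistributedSampler for rank 0 of a world size of 8
dl.set_distributed_sampler(rank=0, world_size=8)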

Component Functions for Multi Core TPU Training

build_distributed_dataloaders[source]

build_distributed_dataloaders(dls, rank, world_size, sync_valid=False)

Wrap the dataloaders with TPU-aware distributed dataloaders

make_fastai_dataloaders[source]

make_fastai_dataloaders(datablock, source, rank, world_size, device=None, path='.', sync_valid=False, verbose=False)

create fastai dataloaders from a datablock and wrap TPU distributed dataloaders around them

wrap_parallel_loader[source]

wrap_parallel_loader(loader, device)

wraps a TPU distributed loader, or a torch dataloader with a distributed sampler, with the XLA parallel loader
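
Under the hood this corresponds to the standard torch_xla pattern shown below. This is a sketch of the equivalent raw torch_xla calls, included so the wrapper's purpose is concrete; wrap_parallel_loader packages something like this for both loader types.

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
loader = DataLoader(TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,))), batch_size=8)
xla_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
for xb, yb in xla_loader:
    pass  # each batch arrives already placed on the XLA device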

class XLATrainingCallback[source]

XLATrainingCallback(device, rank=0, sync_valid=False) :: Callback

A callback for training as a spawned process on multi-core TPUs

The XLATrainingCallback is responsible for the following functions:

  • sets the epoch on either the torch dataloader sampler or the TPU distributed DL before each epoch. This ensures that for each epoch the shuffled samples are the same across all ranks, while each rank picks only its own subset of batches.

    The TPUDistributedDL (and the torch distributed sampler) ensures that all the samples (with some duplication if the number of samples is not exactly divisible by the number of ranks) are seen by one of the dataloaders across the ranks at least once per epoch.

  • wraps the dataloader (either training or validation) with the XLA Parallel Loader (torch_xla.distributed.parallel_loader.ParallelLoader) before each training or validation run.
  • sidesteps the call to opt.step and instead calls xm.optimizer_step(opt) to sync the model gradients across all the ranks (see the sketch after this list).
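
Put together, the callback's responsibilities correspond to the standard torch_xla training pattern below. This is a simplified, self-contained sketch using a toy model and a plain torch dataloader, not the callback's actual code:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def sketch_one_epoch(epoch, rank=0, world_size=1):
    device = xm.xla_device()
    ds = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
    dl = DataLoader(ds, batch_size=8, sampler=sampler)
    model = nn.Linear(10, 2).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    loss_func = nn.CrossEntropyLoss()

    sampler.set_epoch(epoch)  # 1. same shuffle on every rank, varied each epoch
    xla_dl = pl.ParallelLoader(dl, [device]).per_device_loader(device)  # 2. XLA parallel loader
    for xb, yb in xla_dl:
        loss_func(model(xb), yb).backward()
        xm.optimizer_step(opt)  # 3. all-reduce gradients across ranks, then step
        opt.zero_grad()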

Helper Functions for SyncRecorderCallback

pack_metric[source]

pack_metric(metrics)

extract counts and totals from avg metrics and avg losses into a list

make_tensor[source]

make_tensor(o, device)

convert a scalar or tensor into a float tensor and move it to device

pack_metrics[source]

pack_metrics(all_metrics, device)

pack train and valid metrics into a list of float tensors and move them to device

restore_metrics[source]

restore_metrics(reduced_metrics, all_metrics)

restore list of float tensors (count and values) back into train and valid metrics
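
These helpers exist so the per-rank metric state (counts and totals) can be reduced with a single collective call and then restored. Independent of the exact packing format (an implementation detail), the underlying pattern looks like this sketch:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# per-rank running state of an averaged metric: (count, total)
count = torch.tensor(32.0, device=device)   # samples seen on this rank
total = torch.tensor(27.0, device=device)   # e.g. correct predictions on this rank
# sum the counts and totals over all ranks in one collective op
count, total = xm.all_reduce(xm.REDUCE_SUM, [count, total])
global_accuracy = (total / count).item()    # metric computed from the global totals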

class SyncedAvgSmoothLoss[source]

SyncedAvgSmoothLoss(beta=0.98) :: AvgSmoothLoss

Smooth average of the losses (exponentially weighted with beta) synced across all ranks

class SyncRecorderCallback[source]

SyncRecorderCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

A Callback to sync the metrics from each rank and update the statistics accordingly so they display correctly in the progress callback

Learner Patches and Helper Functions for Multi Core TPU Extensions

xm_save[source]

xm_save(data, file_or_path, master_only=True, global_master=False, rendezvous=True)

Saves the input data into a file.

The saved data is transferred to PyTorch CPU device before being saved, so a following torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views it's recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

Args:

  • data: The input data to be saved. Any nested combination of Python objects (lists, tuples, sets, dicts, ...).
  • file_or_path: The destination for the data saving operation. Either a file path or a Python file object. If master_only is False, the path or file objects must point to different destinations, as otherwise all the writes from the same host will override each other.
  • master_only (bool, optional): Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part in the replication; otherwise all the replicas on the same host will write to the same location. Default: True
  • global_master (bool, optional): When master_only is True, this flag controls whether every host's master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False
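
A minimal usage sketch, assuming it is called from inside a spawned process. The rendezvous parameter shown in the signature above appears to be the notable difference from the stock torch_xla xm.save, and is what the SaveModelCallback._save patch below relies on.

import torch.nn as nn
import torch_xla.core.xla_model as xm

model = nn.Linear(10, 2).to(xm.xla_device())
# only the master ordinal writes the file; tensors are moved to CPU before saving
xm_save({'model': model.state_dict()}, 'checkpoint.pth')
# rendezvous=False skips the synchronization barrier around the save
# (this is what SaveModelCallback._save below uses)
xm_save({'model': model.state_dict()}, 'checkpoint.pth', rendezvous=False)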

SaveModelCallback._save[source]

SaveModelCallback._save(name)

save best model using rendezvous=False

Learner.save[source]

Learner.save(file, with_opt=True, pickle_protocol=2)

The Learner.save has been patched to use the torch_xla method xm.save, which saves the weights of the model on the TPU device. Moreover, xm.save only saves the weights on the master ordinal rank process by default, ensuring that only one copy of the model is written to a file. This is fine, since the xm.optimizer_step call made on each training batch synchronizes the weights across all ranks anyway.

Learner.to_multi_xla[source]

Learner.to_multi_xla(device, rank, sync_valid=False)

Sets up the learner on the spawned process for multi core TPU training

do_one_loop[source]

do_one_loop(dl, rank, world_size, device, wrap_parallel=True)

test one loop for a tpu distributed dataloader

Test out the code

def run_dataloader_loop(rank):
    torch.manual_seed(1)
    print(f'xla {rank} start run_dataloader_loop')
    xm.rendezvous('start_run_dataloader_loop')
    # Scale learning rate to num cores
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
    SYNC_VALID = FLAGS['sync_valid']
    IS_PROFILING = FLAGS['is_profiling']
    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    bs = FLAGS['batch_size']
    world_size = xm.xrt_world_size()
    if IS_PROFILING:
        rec_name = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name}')
        start_record(rec_name)

    # dls = make_fastai_dataloaders(
    #                         DATA, 
    #                         PATH, 
    #                         rank=rank, 
    #                         world_size=world_size, 
    #                         sync_valid=SYNC_VALID,
    #                         bs=bs,)
    dls = DATA.dataloaders(PATH, bs=bs)
    # distrib_dls = build_distributed_dataloaders(dls, rank, world_size, 
    #                                            sync_valid=True)
    dl = dls.train
    tpu_dl = TPUDistributedDL(dl,rank=rank,world_size=world_size)
    print(f'xla: {rank} fake_l.num_workers {tpu_dl.fake_l.num_workers}')
    do_one_loop(tpu_dl, rank, world_size, device, wrap_parallel=False)
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')

    xm.mark_step()
    print(f'xla {rank} completed run_dataloader_loop')
    # print_prof_data()
def train_model(rank):
    torch.manual_seed(1)
    xm.rendezvous('start_train_model')
    print(f'xla {rank} start train model')

    
    SYNC_VALID = FLAGS['sync_valid']
    IS_PROFILING = FLAGS['is_profiling']
    # Get loss function, optimizer, and model
    device = xm.xla_device()

    bs = FLAGS['batch_size']
    world_size = xm.xrt_world_size()
    if IS_PROFILING:
        rec_name = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name}')
        start_record(rec_name)

    dls = make_fastai_dataloaders(
                            DATA, 
                            PATH, 
                            rank=rank, 
                            world_size=world_size, 
                            sync_valid=SYNC_VALID,
                            bs=bs,)
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')
    model = WRAPPED_MODEL.to(device)
    moms =(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum'])
    wd = FLAGS['weight_decay']

    xm.master_print('build learner')
    learner = Learner(dls, model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      wd=wd,
                      moms=moms)
                      
    learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    if IS_PROFILING and rank == 0:
        learner.to_my_profile()

    # Scale learning rate to num cores
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
                               
    epochs = FLAGS['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()
    if IS_PROFILING:
        rec_name3 = 'rank' + str(rank) + '_run_fit'
        print(f'start {rec_name3}')
        start_record(rec_name3)

    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))
    if IS_PROFILING:
        end_record(rec_name3)
        print_prof_data(rec_name3)
        print(f'finished {rec_name3}')
    xm.rendezvous('end_train_model')
    learner.save('stage-1', rendezvous=False)
    if rank == 0 and IS_PROFILING :
        learner.my_profile.print_stats()

    
def train_mnist_model(rank):
    torch.manual_seed(1)
    xm.rendezvous('start_train_mnist_model')
    print(f'xla {rank} start train mnist model')    
    SYNC_VALID = FLAGS2['sync_valid']
    device = xm.xla_device()

    bs = FLAGS2['batch_size']
    world_size = xm.xrt_world_size()

    dls = make_fastai_dataloaders(
                            DATA2, 
                            PATH2, 
                            rank=rank, 
                            world_size=world_size, 
                            sync_valid=SYNC_VALID,
                            bs=bs,)
    model = WRAPPED_MODEL2.to(device)
    moms =(FLAGS2['momentum'],FLAGS2['momentum'],FLAGS2['momentum'])
    wd = FLAGS2['weight_decay']

    xm.master_print('build learner')
    learner = Learner(dls, model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      wd=wd,
                      moms=moms)
                      
    learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    # Scale learning rate to num cores
    learning_rate = FLAGS2['learning_rate'] * xm.xrt_world_size()
                               
    epochs = FLAGS2['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()

    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))
    xm.rendezvous('end_train_mnist_model')
    learner.save('mnist-stage-1', rendezvous=False)
    xm.mark_step()  
    

This is the main method that runs the training.

It includes some profiling code to measure the building of the dataloaders and running of the fit methods.

At the end of the spawned processes, the master ordinal process saves the model to a temporary file (see the Learner.save patch above).

The saved model is then loaded by the main process, so that the main-process model contains the trained weights produced by the spawned training processes.

# Start training processes
def _mp_fn(rank, flags):
    global FLAGS
    FLAGS = flags
    train_model(rank)
# Start dataloader processes
def _mp_fn2(rank, flags):
    global FLAGS
    FLAGS = flags
    run_dataloader_loop(rank)
# Start training processes
def _mp_fn3(rank, flags):
    global FLAGS2
    FLAGS2 = flags
    train_mnist_model(rank)
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats
from fastcore.basics import using_attr
from fastai.data.transforms import RegexLabeller, CategoryMap
import torch.nn as nn
LOSS_FUNC = nn.CrossEntropyLoss()
from fastai.optimizer import Adam
OPT_FUNC = Adam
from fastai.data.transforms import RandomSplitter
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet34, resnet18
import os
# Define Parameters
FLAGS = {}
# FLAGS['batch_size'] = 1024
FLAGS['sync_valid'] = True
FLAGS['is_profiling'] = True
FLAGS['batch_size'] = 64
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 1e-3
FLAGS['image_size'] = 224
FLAGS['momentum'] = 0.85
FLAGS['weight_decay'] = 2e-3
FLAGS['num_epochs'] = 5
FLAGS['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1 
ARCH = resnet34
import os
# Define Parameters
FLAGS2 = {}
FLAGS2['batch_size'] = 1024
FLAGS2['sync_valid'] = True
# FLAGS2['batch_size'] = 64
FLAGS2['num_workers'] = 4
FLAGS2['learning_rate'] = 1e-3
FLAGS2['image_size'] = 28
FLAGS2['momentum'] = 0.85
FLAGS2['weight_decay'] = 2e-3
FLAGS2['num_epochs'] = 5
FLAGS2['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1 
ARCH2 = resnet18
from pathlib import Path
from fastcore.xtras import *
import torch_xla.distributed.xla_multiprocessing as xmp
PATH = untar_data(URLs.PETS)/'images'
PATH2 = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)
pat = r'(.+)_\d+.jpg$'
fname_labeller = using_attr(RegexLabeller(pat),'name') 
splitter=RandomSplitter(seed=42)
DATA = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=fname_labeller,
    splitter=splitter,
    item_tfms=[Resize(FLAGS['image_size']),],
    batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
vocab = CategoryMap(get_image_files(PATH).map(fname_labeller))
N_OUT = len(vocab)
DATA2 = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='training',valid_name='testing'),
    item_tfms=[Resize(FLAGS2['image_size']),],
    batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
vocab2 = CategoryMap(get_image_files(PATH2).map(parent_label))
N_OUT2 = len(vocab2)
assert N_OUT is not None and N_OUT > 0,f'N_OUT {N_OUT} should be > 0'
assert N_OUT2 is not None and N_OUT2 > 0,f'N_OUT2 {N_OUT2} should be > 0'

The model is created by the main process and wrapped with xmp.MpModelWrapper. This reduces memory usage by avoiding multiple copies of the model weights in the spawned processes.

custom_model = create_cnn_model(ARCH, N_OUT, 
                                pretrained=True,
                                concat_pool=False)
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth

custom_model2 = create_cnn_model(ARCH2, N_OUT2, 
                                pretrained=True,
                                concat_pool=False)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth

# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)
WRAPPED_MODEL2 = xmp.MpModelWrapper(custom_model2)
#colab
%%time
xmp.spawn(_mp_fn2, args=(FLAGS,), nprocs=FLAGS['num_cores'],
        start_method='fork')
%%time
FLAGS['is_profiling'] = False
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
        start_method='fork')
xla 1 start train model
xla 6 start train model
xla 0 start train model
xla 4 start train model
xla 2 start train model
xla 7 start train model
xla 3 start train model
xla 5 start train model
build learner
start running fit
start fit
epoch train_loss valid_loss accuracy time
0 0.762775 1.240811 0.677703 01:49
1 0.637322 1.133922 0.727027 01:26
2 0.584464 0.496379 0.867568 01:24
3 0.511225 0.345086 0.900000 01:32
4 0.437666 0.297154 0.924324 01:28
CPU times: user 128 ms, sys: 135 ms, total: 263 ms
Wall time: 8min 10s
# %%time
xmp.spawn(_mp_fn3, args=(FLAGS2,), nprocs=FLAGS2['num_cores'],
        start_method='fork')
xla 6 start train mnist model
xla 0 start train mnist model
xla 7 start train mnist model
xla 2 start train mnist model
xla 1 start train mnist model
xla 4 start train mnist model
xla 3 start train mnist model
xla 5 start train mnist model
build learner
start running fit
start fit
epoch train_loss valid_loss accuracy time
0 0.464617 1.163002 0.603000 01:56
1 0.314348 0.510301 0.880600 01:44
2 0.272418 0.127753 0.970900 01:48
3 0.242397 0.058145 0.985600 01:53
4 0.215875 0.044948 0.987800 01:46
mdls = DATA.dataloaders(PATH, bs=FLAGS['batch_size'])
mlearner = Learner(mdls, custom_model, 
                    loss_func=LOSS_FUNC, 
                    opt_func=OPT_FUNC, 
                    metrics=accuracy, 
                    wd=FLAGS['weight_decay'],
                    moms=(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum']))
# load trained weights from multi core tpu training
if Path('models/stage-1.pth').is_file():
    mlearner.load('stage-1')
mlearner.dls.device
device(type='cpu')
from fastai.torch_core import one_param
one_param(mlearner.model).device
device(type='cpu')
%%time
valid_metrics = mlearner.validate();print(valid_metrics)
[0.3011414706707001, 0.918809175491333]
CPU times: user 3min 28s, sys: 2.9 s, total: 3min 31s
Wall time: 3min 33s
import os
from pathlib import Path
FLAGS3 = {}
FLAGS3['batch_size']  = 64
FLAGS3['num_workers'] = 4
FLAGS3['data_dir'] = Path('/content/data/cifar')
FLAGS3['sync_valid'] = True
FLAGS3['learning_rate'] = 1e-3
FLAGS3['image_size'] = 28
FLAGS3['momentum'] = 0.85
FLAGS3['weight_decay'] = 2e-3
FLAGS3['num_epochs'] = 5
FLAGS3['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1 
ARCH3 = resnet18
from torchvision import datasets, transforms
def get_dataset():
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(FLAGS3['image_size'], padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        transforms.Resize((FLAGS3['image_size'],FLAGS3['image_size'])),                                 
        transforms.ToTensor(),
        norm,
    ])
    train_dataset = datasets.CIFAR10(
        root=FLAGS3['data_dir'],
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = datasets.CIFAR10(
        root=FLAGS3['data_dir'],
        train=False,
        download=True,
        transform=transform_test)
    
    return train_dataset, test_dataset
train_dataset, test_dataset = get_dataset()
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar/cifar-10-python.tar.gz
Extracting /content/data/cifar/cifar-10-python.tar.gz to /content/data/cifar
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=FLAGS3['batch_size'],
#   sampler=train_sampler,
    shuffle=True,
    num_workers=FLAGS3['num_workers'],
    drop_last=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS3['batch_size'],
    shuffle=False,
    num_workers=FLAGS3['num_workers'],
    drop_last=True)
# fastai dls using torch dataloaders
CIFAR_DLS = DataLoaders(train_loader, test_loader)
N_OUT3 = 10 # cifar10 has 10 labels
custom_model3 = create_cnn_model(ARCH3, N_OUT3, 
                                pretrained=True,
                                concat_pool=False)
WRAPPED_MODEL3 = xmp.MpModelWrapper(custom_model3)
class PrintDevicesCallback(Callback):
    order = -6 # before XLATrainingCallback
    def before_train(self):
        self.print_device()
    def before_validate(self):
        self.print_device()
    def print_device(self):
        if self.learn.epoch == 0:
            print(f'train: {self.learn.training} xla {self.learn.xla_rank}: dl.type: {type(self.learn.dl)} dl.device {self.learn.dl.device} model.device: {one_param(self.learn.model).device}')
def train_cifar_model(rank):
    torch.manual_seed(1)
    xm.rendezvous('start_train_cifar_model')
    print(f'xla {rank} start train cifar model')    
    SYNC_VALID = FLAGS3['sync_valid']
    device = xm.xla_device()

    bs = FLAGS3['batch_size']
    world_size = xm.xrt_world_size()

    dls = build_distributed_dataloaders(CIFAR_DLS, 
                                        rank, 
                                        world_size, 
                                        sync_valid=SYNC_VALID) 
    model = WRAPPED_MODEL3.to(device)
    moms =(FLAGS3['momentum'],FLAGS3['momentum'],FLAGS3['momentum'])
    wd = FLAGS3['weight_decay']

    xm.master_print('build learner')
    learner = Learner(dls, model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      wd=wd,
                      moms=moms)
                      
    learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    # Scale learning rate to num cores
    learning_rate = FLAGS3['learning_rate'] * xm.xrt_world_size()
                               
    epochs = FLAGS3['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()
    cbs = [PrintDevicesCallback()]
    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10), cbs=cbs)
    xm.rendezvous('end_train_cifar_model')
    learner.save('cifar-stage-1', rendezvous=False)
    xm.mark_step()  
    
# Start training processes
def _mp_fn4(rank, flags):
    global FLAGS3
    FLAGS3 = flags
    train_cifar_model(rank)
# %%time
xmp.spawn(_mp_fn4, args=(FLAGS3,), nprocs=FLAGS3['num_cores'],
        start_method='fork')
xla 3 start train cifar model
xla 1 start train cifar model
xla 5 start train cifar model
xla 4 start train cifar model
xla 6 start train cifar model
xla 2 start train cifar model
xla 0 start train cifar model
xla 7 start train cifar model
train: True xla 3: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: True xla 5: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: True xla 4: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: True xla 1: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: True xla 2: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
build learner
start running fit
start fit
epoch train_loss valid_loss accuracy time
0 1.841237 1.288982 0.562738 01:41
1 1.276613 0.989299 0.660244 01:31
2 1.007693 0.830739 0.713980 01:40
3 0.858846 0.755246 0.740147 01:33
4 0.775464 0.729593 0.749387 01:28
train: True xla 0: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:1
train: True xla 6: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: True xla 7: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: False xla 0: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:1
train: False xla 1: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: False xla 4: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: False xla 5: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: False xla 2: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: False xla 3: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: False xla 7: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0
train: False xla 6: dl.type: <class 'torch.utils.data.dataloader.DataLoader'> dl.device None model.device: xla:0

class NumIterCancelCallback(Callback):
    order = 20
    def __init__(self, num_iters=0, on_train=True, on_valid=True):
        store_attr()

    def before_fit(self):

        if not getattr(self.learn,'inner_xla',False):
            return # skip if not spawned

        self.my_iter = 0

    def after_batch(self):

        if not getattr(self.learn,'inner_xla',False):
            return # skip if not spawned

        if self.learn.training and not self.on_train:
            return
    
        if not self.learn.training and not self.on_valid:
            return

        if self.num_iters == 0:
            return

        self.my_iter += 1
        if self.my_iter > self.num_iters:
            self.synced_cancel.trigger_cancel_fit()
 
# path = untar_data(URLs.MNIST_TINY)
path = untar_data(URLs.MNIST)
data = DataBlock(
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    # splitter=GrandparentSplitter(),
    splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
    item_tfms=Resize(28),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
# MDLS = data.dataloaders(path, bs=8)
MDLS = data.dataloaders(path, bs=64)
ARCH = resnet18
custom_model = create_cnn_model(ARCH, n_out=MDLS.c, concat_pool=False)
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)
#     print(f'xla {rank}: start train')
#     xm.rendezvous('start_train_model')
#     world_size = xm.xrt_world_size()
#     device = xm.xla_device()
#     dls = build_distributed_dataloaders(MDLS, rank, world_size, sync_valid=True)
#     model = WRAPPED_MODEL.to(device)
#     learner = Learner(dls, model, loss_func=nn.CrossEntropyLoss(), 
#                       opt_func=Adam,
#                       metrics=accuracy)
#     learner.to_multi_xla(device, rank, sync_valid=True)
#     learner.fit(5, lr=2e-3)
#     learner.save('stage-1', rendezvous=False)
#     xm.rendezvous('end_train_model')
#     print(f'xla {rank}: end train')
# xmp.spawn(train_model, args=(), nprocs=8,
#         start_method='fork')
%cd /content/fastai_xla_extensions
from nbdev.export2html import notebook2html
notebook2html()
/content/drive/MyDrive/fastai_xla_extensions
converting: /content/drive/MyDrive/fastai_xla_extensions/nbs/03_multi_core.base.ipynb