Base module for Multi TPU Core implementation
Multi-core TPU implementation is enabled by importing this module.
from fastai_xla_extensions.multi_core.base import *
from fastai.torch_core import TensorBase, TensorImage, TensorCategory
from fastai.data.core import TfmdDL
n_batches = 10
bs = 6
world_size = 8
# setup a dataloader as base dl for tpu
items = [(TensorImage(torch.tensor(i).float()), TensorCategory(i)) for i in range(n_batches * bs * world_size)]
dl = TfmdDL(items, bs=bs, shuffle=True)
assert len(dl) == n_batches * world_size
b0 = next(iter(dl))
assert isinstance(b0[0], TensorImage)
assert isinstance(b0[1], TensorCategory)
tpu_dl = TPUDistributedDL(dl, rank=0, world_size=world_size)
# the batches of the base dl are divided across all ranks
assert len(tpu_dl) == n_batches
tpu_b0 = next(iter(tpu_dl))
# the types of each batch (x,y) have been reverted to torch tensors
# and are no longer Tensor subclasses (e.g. TensorBase)
assert isinstance(tpu_b0[0], torch.Tensor)
assert isinstance(tpu_b0[1], torch.Tensor)
assert not isinstance(tpu_b0[0], TensorBase)
assert not isinstance(tpu_b0[1], TensorBase)
# add tests to make sure all items are retrieved per epoch
# create dl for each rank across all ranks
tpu_dls = [TPUDistributedDL(dl, rank=rank, world_size=world_size) for rank in range(world_size)]
rank_batches = [list(tpu_dl) for tpu_dl in tpu_dls]
# TODO: check that no two ranks contain common items
# TODO: check that all items in dl are accounted for in the tpu_dls across all ranks
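A hedged sketch of the two TODO checks above, using the rank_batches built earlier. It assumes (as in this toy dataset) that the target y uniquely identifies each item and that the total item count divides evenly by world_size, so no padding or duplication is expected.
# collect the set of targets seen by each rank over one epoch
seen_per_rank = [set(torch.cat([yb for _, yb in batches]).tolist())
                 for batches in rank_batches]
# no two ranks should share an item (duplication only happens when the item
# count is not divisible by the number of ranks, which is not the case here)
for i in range(world_size):
    for j in range(i + 1, world_size):
        assert not (seen_per_rank[i] & seen_per_rank[j]), f'ranks {i} and {j} overlap'
# every item in the base dl is seen by some rank at least once per epoch
assert set().union(*seen_per_rank) == set(range(n_batches * bs * world_size))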
The XLATrainingCallback is responsible for the following functions:
- sets the epoch on either the torch dataloader sampler or the TPU distributed DL before each epoch. This ensures that for each epoch, the samples in each batch are the same across all ranks, while each rank picks only its own subset of batches. The TPUDistributedDL (and the torch distributed sampler) ensures that all the samples (with some duplication if the number of samples is not exactly divisible by the number of ranks) are seen by one of the dataloaders across the ranks at least once per epoch.
- wraps the dataloader (either training or validation) with the XLA Parallel Loader (torch_xla.distributed.parallel_loader.ParallelLoader) before each training or validation run.
- sidesteps the call to opt.step and instead calls xm.optimizer_step(opt) to sync the model gradients across all the ranks.
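A minimal sketch of the XLA-specific pieces described above, outside the fastai training loop. It assumes a spawned process where device = xm.xla_device(), and it assumes the distributed dataloader exposes a set_epoch method; the real XLATrainingCallback wires these steps into the fastai callbacks and handles both the torch sampler and the TPU distributed DL.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def sketch_one_epoch(model, tpu_dl, opt, loss_func, device, epoch):
    # keep per-epoch shuffling consistent across ranks (set_epoch is an assumption here)
    tpu_dl.set_epoch(epoch)
    # wrap the dataloader with the XLA parallel loader before the run;
    # per_device_loader yields batches already placed on the XLA device
    para_loader = pl.ParallelLoader(tpu_dl, [device]).per_device_loader(device)
    for xb, yb in para_loader:
        loss = loss_func(model(xb), yb)
        loss.backward()
        xm.optimizer_step(opt)  # replaces opt.step(): syncs gradients across ranks, then steps
        opt.zero_grad()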
Helper Functions for SyncRecorderCallback
The Learner.save method has been patched to use the torch_xla method xm.save, which saves the weights of the model on the TPU device. Moreover, xm.save only saves the weights on the master ordinal rank process by default, ensuring that only one copy of the model is written to a file. _This is fine, since the xm.optimizer_step done on each training batch synchronizes the weights across all ranks anyway._
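A hedged sketch of what such a patch might look like (simplified; the library's actual patch may handle the optimizer state and the rendezvous tag differently).
from fastcore.basics import patch
from fastcore.xtras import join_path_file
from fastai.learner import Learner
import torch_xla.core.xla_model as xm

@patch
def save(self: Learner, file, rendezvous=True, **kwargs):
    "Illustrative only: write the checkpoint with xm.save (master ordinal only by default)."
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    xm.save(self.model.state_dict(), file)
    if rendezvous: xm.rendezvous('learner_save')  # 'learner_save' is a hypothetical tag
    return file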
def run_dataloader_loop(rank):
torch.manual_seed(1)
print(f'xla {rank} start run_dataloader_loop')
xm.rendezvous('start_run_dataloader_loop')
# Scale learning rate to num cores
learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
SYNC_VALID = FLAGS['sync_valid']
IS_PROFILING = FLAGS['is_profiling']
# Get loss function, optimizer, and model
device = xm.xla_device()
model = WRAPPED_MODEL.to(device)
bs = FLAGS['batch_size']
world_size = xm.xrt_world_size()
if IS_PROFILING:
rec_name = 'rank' + str(rank) + '_dataloader_build'
print(f'start {rec_name}')
start_record(rec_name)
# dls = make_fastai_dataloaders(
# DATA,
# PATH,
# rank=rank,
# world_size=world_size,
# sync_valid=SYNC_VALID,
# bs=bs,)
dls = DATA.dataloaders(PATH, bs=bs)
# distrib_dls = build_distributed_dataloaders(dls, rank, world_size,
# sync_valid=True)
dl = dls.train
tpu_dl = TPUDistributedDL(dl,rank=rank,world_size=world_size)
print(f'xla: {rank} fake_l.num_workers {tpu_dl.fake_l.num_workers}')
do_one_loop(tpu_dl, rank, world_size, device, wrap_parallel=False)
if IS_PROFILING:
end_record(rec_name)
print_prof_data(rec_name)
print(f'finished {rec_name}')
xm.mark_step()
print(f'xla {rank} completed run_dataloader_loop')
# print_prof_data()
def train_model(rank):
torch.manual_seed(1)
xm.rendezvous('start_train_model')
print(f'xla {rank} start train model')
SYNC_VALID = FLAGS['sync_valid']
IS_PROFILING = FLAGS['is_profiling']
# Get loss function, optimizer, and model
device = xm.xla_device()
bs = FLAGS['batch_size']
world_size = xm.xrt_world_size()
if IS_PROFILING:
rec_name = 'rank' + str(rank) + '_dataloader_build'
print(f'start {rec_name}')
start_record(rec_name)
dls = make_fastai_dataloaders(
DATA,
PATH,
rank=rank,
world_size=world_size,
sync_valid=SYNC_VALID,
bs=bs,)
if IS_PROFILING:
end_record(rec_name)
print_prof_data(rec_name)
print(f'finished {rec_name}')
model = WRAPPED_MODEL.to(device)
moms =(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum'])
wd = FLAGS['weight_decay']
xm.master_print('build learner')
learner = Learner(dls, model,
loss_func=LOSS_FUNC,
opt_func=OPT_FUNC,
metrics=accuracy,
wd=wd,
moms=moms)
learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
if IS_PROFILING and rank == 0:
learner.to_my_profile()
# Scale learning rate to num cores
learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
epochs = FLAGS['num_epochs']
xm.master_print('start running fit')
learner.unfreeze()
if IS_PROFILING:
rec_name3 = 'rank' + str(rank) + '_run_fit'
print(f'start {rec_name3}')
start_record(rec_name3)
learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))
if IS_PROFILING:
end_record(rec_name3)
print_prof_data(rec_name3)
print(f'finished {rec_name3}')
xm.rendezvous('end_train_model')
learner.save('stage-1', rendezvous=False)
if rank == 0 and IS_PROFILING :
learner.my_profile.print_stats()
def train_mnist_model(rank):
torch.manual_seed(1)
xm.rendezvous('start_train_mnist_model')
print(f'xla {rank} start train mnist model')
SYNC_VALID = FLAGS2['sync_valid']
device = xm.xla_device()
bs = FLAGS2['batch_size']
world_size = xm.xrt_world_size()
dls = make_fastai_dataloaders(
DATA2,
PATH2,
rank=rank,
world_size=world_size,
sync_valid=SYNC_VALID,
bs=bs,)
model = WRAPPED_MODEL2.to(device)
moms =(FLAGS2['momentum'],FLAGS2['momentum'],FLAGS2['momentum'])
wd = FLAGS2['weight_decay']
xm.master_print('build learner')
learner = Learner(dls, model,
loss_func=LOSS_FUNC,
opt_func=OPT_FUNC,
metrics=accuracy,
wd=wd,
moms=moms)
learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
# Scale learning rate to num cores
learning_rate = FLAGS2['learning_rate'] * xm.xrt_world_size()
epochs = FLAGS2['num_epochs']
xm.master_print('start running fit')
learner.unfreeze()
learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))
xm.rendezvous('end_train_mnist_model')
learner.save('mnist-stage-1', rendezvous=False)
xm.mark_step()
This is the main method that runs the training. It includes some profiling code to measure the building of the dataloaders and the running of the fit method.
At the end of the spawned processes, the master ordinal process saves the model to a temporary file (see the Learner.save patch above).
The saved model is then loaded by the main process so that it contains the trained weights updated by the spawned training processes.
# Start training processes
def _mp_fn(rank, flags):
global FLAGS
FLAGS = flags
train_model(rank)
# Start dataloader processes
def _mp_fn2(rank, flags):
global FLAGS
FLAGS = flags
run_dataloader_loop(rank)
# Start training processes
def _mp_fn3(rank, flags):
global FLAGS2
FLAGS2 = flags
train_mnist_model(rank)
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats
from fastcore.basics import using_attr
from fastai.data.transforms import RegexLabeller, CategoryMap
import torch.nn as nn
LOSS_FUNC = nn.CrossEntropyLoss()
from fastai.optimizer import Adam
OPT_FUNC = Adam
from fastai.data.transforms import RandomSplitter
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet34, resnet18
import os
# Define Parameters
FLAGS = {}
# FLAGS['batch_size'] = 1024
FLAGS['sync_valid'] = True
FLAGS['is_profiling'] = True
FLAGS['batch_size'] = 64
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 1e-3
FLAGS['image_size'] = 224
FLAGS['momentum'] = 0.85
FLAGS['weight_decay'] = 2e-3
FLAGS['num_epochs'] = 5
FLAGS['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1
ARCH = resnet34
import os
# Define Parameters
FLAGS2 = {}
FLAGS2['batch_size'] = 1024
FLAGS2['sync_valid'] = True
# FLAGS2['batch_size'] = 64
FLAGS2['num_workers'] = 4
FLAGS2['learning_rate'] = 1e-3
FLAGS2['image_size'] = 28
FLAGS2['momentum'] = 0.85
FLAGS2['weight_decay'] = 2e-3
FLAGS2['num_epochs'] = 5
FLAGS2['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1
ARCH2 = resnet18
from pathlib import Path
from fastcore.xtras import *
import torch_xla.distributed.xla_multiprocessing as xmp
PATH = untar_data(URLs.PETS)/'images'
PATH2 = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)
pat = r'(.+)_\d+.jpg$'
fname_labeller = using_attr(RegexLabeller(pat),'name')
splitter=RandomSplitter(seed=42)
DATA = DataBlock(
blocks=(ImageBlock, CategoryBlock),
get_items=get_image_files,
get_y=fname_labeller,
splitter=splitter,
item_tfms=[Resize(FLAGS['image_size']),],
batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
vocab = CategoryMap(get_image_files(PATH).map(fname_labeller))
N_OUT = len(vocab)
DATA2 = DataBlock(
blocks=(ImageBlock, CategoryBlock),
get_items=get_image_files,
get_y=parent_label,
splitter=GrandparentSplitter(train_name='training',valid_name='testing'),
item_tfms=[Resize(FLAGS2['image_size']),],
batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
vocab2 = CategoryMap(get_image_files(PATH2).map(parent_label))
N_OUT2 = len(vocab2)
assert N_OUT is not None and N_OUT > 0,f'N_OUT {N_OUT} should be > 0'
assert N_OUT2 is not None and N_OUT2 > 0,f'N_OUT2 {N_OUT2} should be > 0'
The model is created by the main process and wrapped in xmp.MpModelWrapper. This reduces memory usage by not keeping multiple copies of the model in the spawned processes.
custom_model = create_cnn_model(ARCH, N_OUT,
pretrained=True,
concat_pool=False)
custom_model2 = create_cnn_model(ARCH2, N_OUT2,
pretrained=True,
concat_pool=False)
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)
WRAPPED_MODEL2 = xmp.MpModelWrapper(custom_model2)
#colab
%%time
xmp.spawn(_mp_fn2, args=(FLAGS,), nprocs=FLAGS['num_cores'],
start_method='fork')
%%time
FLAGS['is_profiling'] = False
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
start_method='fork')
# %%time
xmp.spawn(_mp_fn3, args=(FLAGS2,), nprocs=FLAGS2['num_cores'],
start_method='fork')
mdls = DATA.dataloaders(PATH, bs=FLAGS['batch_size'])
mlearner = Learner(mdls, custom_model,
loss_func=LOSS_FUNC,
opt_func=OPT_FUNC,
metrics=accuracy,
wd=FLAGS['weight_decay'],
moms=(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum']))
# load trained weights from multi core tpu training
if Path('models/stage-1.pth').is_file():
mlearner.load('stage-1')
mlearner.dls.device
from fastai.torch_core import one_param
one_param(mlearner.model).device
%%time
valid_metrics = mlearner.validate();print(valid_metrics)
import os
from pathlib import Path
FLAGS3 = {}
FLAGS3['batch_size'] = 64
FLAGS3['num_workers'] = 4
FLAGS3['data_dir'] = Path('/content/data/cifar')
FLAGS3['sync_valid'] = True
FLAGS3['learning_rate'] = 1e-3
FLAGS3['image_size'] = 28
FLAGS3['momentum'] = 0.85
FLAGS3['weight_decay'] = 2e-3
FLAGS3['num_epochs'] = 5
FLAGS3['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1
ARCH3 = resnet18
from torchvision import datasets, transforms
def get_dataset():
norm = transforms.Normalize(
mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
transform_train = transforms.Compose([
transforms.RandomCrop(FLAGS3['image_size'], padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
norm,
])
transform_test = transforms.Compose([
transforms.Resize((FLAGS3['image_size'],FLAGS3['image_size'])),
transforms.ToTensor(),
norm,
])
train_dataset = datasets.CIFAR10(
root=FLAGS3['data_dir'],
train=True,
download=True,
transform=transform_train)
test_dataset = datasets.CIFAR10(
root=FLAGS3['data_dir'],
train=False,
download=True,
transform=transform_test)
return train_dataset, test_dataset
train_dataset, test_dataset = get_dataset()
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=FLAGS3['batch_size'],
# sampler=train_sampler,
shuffle=True,
num_workers=FLAGS3['num_workers'],
drop_last=True)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=FLAGS3['batch_size'],
shuffle=False,
num_workers=FLAGS3['num_workers'],
drop_last=True)
# fastai dls using torch dataloaders
CIFAR_DLS = DataLoaders(train_loader, test_loader)
N_OUT3 = 10 # cifar10 has 10 labels
custom_model3 = create_cnn_model(ARCH3, N_OUT3,
pretrained=True,
concat_pool=False)
WRAPPED_MODEL3 = xmp.MpModelWrapper(custom_model3)
class PrintDevicesCallback(Callback):
order = -6 # before XLATrainingCallback
def before_train(self):
self.print_device()
def before_validate(self):
self.print_device()
def print_device(self):
if self.learn.epoch == 0:
print(f'train: {self.learn.training} xla {self.learn.xla_rank}: dl.type: {type(self.learn.dl)} dl.device {self.learn.dl.device} model.device: {one_param(self.learn.model).device}')
def train_cifar_model(rank):
torch.manual_seed(1)
xm.rendezvous('start_train_cifar_model')
print(f'xla {rank} start train cifar model')
SYNC_VALID = FLAGS3['sync_valid']
device = xm.xla_device()
bs = FLAGS3['batch_size']
world_size = xm.xrt_world_size()
dls = build_distributed_dataloaders(CIFAR_DLS,
rank,
world_size,
sync_valid=SYNC_VALID)
model = WRAPPED_MODEL3.to(device)
moms =(FLAGS3['momentum'],FLAGS3['momentum'],FLAGS3['momentum'])
wd = FLAGS3['weight_decay']
xm.master_print('build learner')
learner = Learner(dls, model,
loss_func=LOSS_FUNC,
opt_func=OPT_FUNC,
metrics=accuracy,
wd=wd,
moms=moms)
learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
# Scale learning rate to num cores
learning_rate = FLAGS3['learning_rate'] * xm.xrt_world_size()
epochs = FLAGS3['num_epochs']
xm.master_print('start running fit')
learner.unfreeze()
cbs = [PrintDevicesCallback()]
learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10), cbs=cbs)
xm.rendezvous('end_train_cifar_model')
learner.save('cifar-stage-1', rendezvous=False)
xm.mark_step()
# Start training processes
def _mp_fn4(rank, flags):
global FLAGS3
FLAGS3 = flags
train_cifar_model(rank)
# %%time
xmp.spawn(_mp_fn4, args=(FLAGS3,), nprocs=FLAGS3['num_cores'],
start_method='fork')
class NumIterCancelCallback(Callback):
order = 20
def __init__(self, num_iters=0, on_train=True, on_valid=True):
store_attr()
def before_fit(self):
if not getattr(self.learn,'inner_xla',False):
return # skip if not spawned
self.my_iter = 0
def after_batch(self):
if not getattr(self.learn,'inner_xla',False):
return # skip if not spawned
if self.learn.training and not self.on_train:
return
if not self.learn.training and not self.on_valid:
return
if self.num_iters == 0:
return
self.my_iter += 1
if self.my_iter > self.num_iters:
self.synced_cancel.trigger_cancel_fit()
# path = untar_data(URLs.MNIST_TINY)
path = untar_data(URLs.MNIST)
data = DataBlock(
blocks=(ImageBlock,CategoryBlock),
get_items=get_image_files,
get_y=parent_label,
# splitter=GrandparentSplitter(),
splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
item_tfms=Resize(28),
batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
# MDLS = data.dataloaders(path, bs=8)
MDLS = data.dataloaders(path, bs=64)
ARCH = resnet18
custom_model = create_cnn_model(ARCH, n_out=MDLS.c, concat_pool=False)
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)
# print(f'xla {rank}: start train')
# xm.rendezvous('start_train_model')
# world_size = xm.xrt_world_size()
# device = xm.xla_device()
# dls = build_distributed_dataloaders(MDLS, rank, world_size, sync_valid=True)
# model = WRAPPED_MODEL.to(device)
# learner = Learner(dls, model, loss_func=nn.CrossEntropyLoss(),
# opt_func=Adam,
# metrics=accuracy)
# learner.to_multi_xla(device, rank, sync_valid=True)
# learner.fit(5, lr=2e-3)
# learner.save('stage-1', rendezvous=False)
# xm.rendezvous('end_train_model')
# print(f'xla {rank}: end train')
# xmp.spawn(train_model, args=(), nprocs=8,
# start_method='fork')
%cd /content/fastai_xla_extensions
from nbdev.export2html import notebook2html
notebook2html()