# Torch Dataset and Dataloader compatible classes and functions for multi-core TPU training
from fastcore.test import test_eq
def neg_tfm(o):
    "Item transform: return the arithmetic negation of `o`."
    return -o
def double_tfm(o):
    "Item transform: return `o` doubled."
    return 2 * o
# Smoke test: a TfmdTorchDS applies x_tfm to build the input and y_tfm to
# build the target, so ds1[5] should be (neg_tfm(5), double_tfm(5)) == (-5, 10).
items = list(range(10))  # toy integer items 0..9
ds1 = TfmdTorchDS(items, x_tfm=neg_tfm, y_tfm=double_tfm)
test_eq(ds1[5],(-5,10))
import torchvision as thv
# Torchvision transforms reused as x_tfms by the MNIST dataset builder below.
pil2tensor = thv.transforms.ToTensor()  # PIL image -> float tensor scaled to [0,1]
resize28 = thv.transforms.Resize(28)  # resize so the smaller edge is 28 px
# NOTE(review): these mean/std values are the CIFAR-10 channel stats, but the
# transform is applied to MNIST_TINY below — confirm this is intended.
norm = thv.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
from fastai.vision.core import PILImage
from fastai.data.transforms import get_image_files, GrandparentSplitter, parent_label
from fastai.data.external import untar_data, URLs
# Download the tiny MNIST sample (3s vs 7s) and build train/test datasets with
# a TorchDatasetBuilder: items come from get_image_files, the split follows the
# grandparent directory (train/valid), inputs go through resize->tensor->norm,
# and targets go through parent_label -> VocabularyMapper (label -> class index).
path = untar_data(URLs.MNIST_TINY)  # downloads on first use
mnist_dset_builder = TorchDatasetBuilder(
source=path,
get_items=get_image_files,
splitter=GrandparentSplitter(),
x_tfms=[resize28,pil2tensor,norm,],
y_tfms=[parent_label,VocabularyMapper(),],
x_type_tfms=PILImage.create)
from fastcore.test import test_eq
train_ds, test_ds = mnist_dset_builder.get_datasets(do_setup=True)
# Expected sizes and vocabulary for MNIST_TINY: 709 train / 699 valid images,
# two classes labelled '3' and '7'.
test_eq(len(train_ds),709)
test_eq(len(test_ds),699)
test_eq(mnist_dset_builder.y_tfms[1].vocab, ('3','7'))
test_eq(mnist_dset_builder.y_tfms[1].c, 2)
# Each dataset item must equal the hand-applied transform pipelines: the target
# is the vocab-mapped parent label, the input is norm(pil2tensor(resize28(img))).
test_eq(train_ds[0][1],mnist_dset_builder.y_tfms[1](parent_label(train_ds.items[0])))
test_eq(train_ds[0][0],norm(pil2tensor(resize28(PILImage.create(train_ds.items[0])))))
# import torch.utils.data as th_data
# from fastcore.basics import patch_to
# @patch_to(th_data.DataLoader)
# def to(self, device):
# "move torch dataloader to device (for compatibility with fastai dataloader)"
# self.device = device
from fastai.vision.all import *
# from fastai_xla_extensions.multi_core.base import *
# from fastai_xla_extensions.misc_utils import * # patch _BaseOptimizer.__get_state__ and __setstate__
from my_timesaver_utils.profiling import *
from my_timesaver_utils.profiling_callback import *
from fastai.learner import Learner
from fastai.metrics import accuracy
def train_torch_model(rank):
    """Per-process TPU training loop, executed once on each spawned core.

    Parameters
    ----------
    rank : int
        Ordinal of this worker process, as passed by `xmp.spawn`.

    Relies on module-level globals set up in the parent process before the
    fork: FLAGS (config dict), WRAPPED_MODEL (xmp.MpModelWrapper around the
    model), DSET_BUILDER, LOSS_FUNC and OPT_FUNC. Saves the trained weights
    as checkpoint 'stage-1' via `learner.save`.

    NOTE: the notebook export had lost all indentation, making this function
    invalid Python; the body below reconstructs the obvious block structure
    (each profiling `if IS_PROFILING:` guard wraps only its record calls).
    """
    torch.manual_seed(1)  # same seed on every core so model init/shuffling agree
    xm.rendezvous('start_train_torch_model')  # wait until all cores reach here
    # Scale learning rate to num cores
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
    IS_PROFILING = FLAGS['is_profiling']
    SYNC_VALID = FLAGS['sync_valid']
    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)  # materialize the shared wrapped model on this core
    bs = FLAGS['batch_size']
    world_size = xm.xrt_world_size()
    moms = (FLAGS['momentum'], FLAGS['momentum'], FLAGS['momentum'])
    wd = FLAGS['weight_decay']
    num_workers = FLAGS['num_workers']
    # --- build datasets (optionally timed with my_timesaver_utils profiling) ---
    if IS_PROFILING:
        rec_name = 'rank' + str(rank) + '_dset_build'
        print(f'start {rec_name}')
        start_record(rec_name)
    dsets = DSET_BUILDER.get_datasets()
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')
    # --- build per-rank distributed dataloaders ---
    if IS_PROFILING:
        rec_name2 = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name2}')
        start_record(rec_name2)
    dls = make_torch_dataloaders(*dsets,
                                 rank=rank,
                                 world_size=world_size,
                                 bs=bs,
                                 num_workers=num_workers,
                                 sync_valid=SYNC_VALID,
                                 )
    if IS_PROFILING:
        end_record(rec_name2)
        print_prof_data(rec_name2)
        print(f'finished {rec_name2}')
    xm.master_print('build learner')
    learner = Learner(dls, model,
                      loss_func=LOSS_FUNC,
                      opt_func=OPT_FUNC,
                      metrics=accuracy,
                      wd=wd,
                      moms=moms
                      )
    # Project extension: attach multi-core XLA behaviour to the fastai Learner.
    learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    if rank == 0 and IS_PROFILING:
        learner.to_my_profile()  # detailed profiling only on the master process
    epochs = FLAGS['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()  # train all layers, not just the head
    # --- run the training loop (optionally timed) ---
    if IS_PROFILING:
        rec_name3 = 'rank' + str(rank) + '_run_fit'
        print(f'start {rec_name3}')
        start_record(rec_name3)
    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))
    if IS_PROFILING:
        end_record(rec_name3)
        print_prof_data(rec_name3)
        print(f'finished {rec_name3}')
    learner.save('stage-1')  # checkpoint read back later by the master process
    if rank == 0 and IS_PROFILING:
        learner.my_profile.print_stats()
    xm.mark_step()  # flush pending XLA ops before the process exits
# Start training processes
def _mp_fn2(rank, flags):
    """Multiprocessing entry point handed to `xmp.spawn`.

    Publishes the parent-process `flags` dict as the module-level FLAGS
    global (train_torch_model reads all its settings from it), then runs
    the training loop for this process's `rank`.

    NOTE: the notebook export had lost the function body's indentation;
    reconstructed here.
    """
    global FLAGS
    FLAGS = flags
    train_torch_model(rank)
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats
import torch.nn as nn
# Loss and optimizer factory shared by the spawned training processes and the
# master-process Learner built later in this file.
LOSS_FUNC = nn.CrossEntropyLoss()  # multi-class classification loss
from fastai.optimizer import Adam
OPT_FUNC = Adam  # passed to Learner as opt_func
from fastai.data.transforms import RandomSplitter
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet34
import os
# Define Parameters
# Runtime configuration dict; each spawned worker receives it via xmp.spawn's
# args and republishes it as the module-level FLAGS global in _mp_fn2.
FLAGS = {
    'sync_valid': True,       # forwarded to make_torch_dataloaders / to_multi_xla
    'is_profiling': True,     # enables the start_record/end_record timing blocks
    'batch_size': 64,         # (a 1024 batch size was tried previously)
    'num_workers': 4,         # dataloader worker count
    'learning_rate': 1e-3,    # base LR; scaled by world size inside the workers
    'image_size': 224,        # square resize target for the Pets images
    'momentum': 0.85,         # used for all three Learner `moms` entries
    'weight_decay': 2e-3,
    'num_epochs': 5,
    # All 8 TPU cores when a TPU is attached (TPU_NAME set), else 1 process.
    'num_cores': 8 if os.environ.get('TPU_NAME', None) else 1,
}
ARCH = resnet34  # backbone architecture handed to create_cnn_model below
# NOTE(review): USE_DBLOCK is never read in this chunk — presumably toggles a
# fastai DataBlock pipeline elsewhere in the notebook; confirm before removing.
USE_DBLOCK = False
from pathlib import Path
from fastcore.xtras import *
PATH = untar_data(URLs.PETS)/'images'  # Oxford-IIIT Pets images (downloads on first use)
# PATH = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)
# Channel normalization stats; only imagenet_norm is used by DSET_BUILDER below.
imagenet_norm = thv.transforms.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))
# CIFAR-10 stats kept as an unused alternative.
cifar_norm = thv.transforms.Normalize(
mean=(0.4914, 0.4822, 0.4465),
std=(0.2023, 0.1994, 0.2010))
image_size = FLAGS['image_size']  # square side length for the Resize transform
splitter = RandomSplitter(seed=42)  # fixed seed: identical split in every process
# Label = everything in the filename before the trailing "_<number>.jpg".
pat = r'(.+)_\d+.jpg$'
fname_labeller = FileNamePatternLabeller(pat)  # project helper; presumably applies `pat` to each file name — confirm
# Dataset builder for Pets: mirrors the MNIST builder earlier in the file but
# with ImageNet normalization and filename-pattern labelling.
DSET_BUILDER = TorchDatasetBuilder(
PATH,
get_items=get_image_files,
splitter=splitter,
x_tfms=[thv.transforms.Resize((image_size,image_size)), thv.transforms.ToTensor(), imagenet_norm],
y_tfms=[fname_labeller, VocabularyMapper(),],
x_type_tfms=PILImage.create,
)
# Run (and time) the label-vocabulary setup once in the master process.
# NOTE(review): presumably so forked workers inherit the vocab instead of each
# recomputing it — confirm against TorchDatasetBuilder.setup's semantics.
start_record('master_vocab_setup')
DSET_BUILDER.setup(get_image_files(PATH),do_setup=True)
end_record('master_vocab_setup')
print_prof_data('master_vocab_setup')
clear_prof_data()
N_OUT = DSET_BUILDER.y_tfms[1].c  # number of classes found by VocabularyMapper
assert N_OUT is not None and N_OUT > 0,f'N_OUT {N_OUT} should be > 0'
# Pretrained CNN head/body for N_OUT classes (concat_pool=False: plain pooling).
custom_model = create_cnn_model(ARCH, N_OUT,
pretrained=True,
concat_pool=False)
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)
# %%time  (IPython cell magic from the notebook source — not valid in a .py file)
FLAGS['is_profiling'] = False  # disable per-rank profiling for this run
# !rm -f /content/models/stage-1.pth
# Fork one worker per core; each runs _mp_fn2 -> train_torch_model and the
# workers collectively save the 'stage-1' checkpoint.
xmp.spawn(_mp_fn2, args=(FLAGS,), nprocs=FLAGS['num_cores'],
start_method='fork')
# Rebuild datasets/dataloaders in the master process (rank 0, world size 1)
# to evaluate the checkpoint written by the spawned training run.
mdsets = DSET_BUILDER.get_datasets()
mdls = make_torch_dataloaders(*mdsets,
rank=0,
world_size=1,
bs=FLAGS['batch_size'],
num_workers=FLAGS['num_workers']
)
# Learner mirrors the configuration used inside train_torch_model.
mlearner = Learner(mdls, custom_model,
loss_func=LOSS_FUNC,
opt_func=OPT_FUNC,
metrics=accuracy,
wd=FLAGS['weight_decay'],
moms=(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum']))
mlearner.load('stage-1');  # load the weights saved during multi-core training
mlearner.dls.device  # notebook display expression: dataloaders' device
from fastai.torch_core import one_param
one_param(mlearner.model).device  # notebook display expression: model weights' device
# %%time  (IPython cell magic from the notebook source — not valid in a .py file)
valid_metrics = mlearner.validate();print(valid_metrics)  # validation loss + accuracy of the loaded checkpoint