Torch Dataset and Dataloader compatible classes and functions for multi-core TPU training

Open In Colab

class TfmdTorchDS[source]

TfmdTorchDS(*args, **kwds) :: Dataset

A torch dataset compatible holder for items with x and y transforms

from fastcore.test import test_eq

def neg_tfm(o):
    "Demo x transform: negate the item."
    return -o

def double_tfm(o):
    "Demo y transform: double the item."
    return 2 * o

# Build a dataset over 0..9; indexing applies the x/y transforms,
# so item 5 should come back as (x_tfm(5), y_tfm(5)) == (-5, 10).
items = [*range(10)]
ds1 = TfmdTorchDS(items, x_tfm=neg_tfm, y_tfm=double_tfm)
test_eq(ds1[5], (-5, 10))

to_list[source]

to_list(o)

Return item `o` as a list (unchanged if `o` is already a list; an empty list if `o` is None)

has_setup[source]

has_setup(tfms)

Returns the last index at which a tfm in `tfms` has a `setup` method; returns -1 if no tfm has one

run_setups[source]

run_setups(tfms, items)

Run the `setup` methods of the tfms in `tfms` over all `items`

class TorchDatasetBuilder[source]

TorchDatasetBuilder(source, get_items, splitter, x_tfms, y_tfms, x_type_tfms=None, x_train_tfms=None, x_test_tfms=None, do_setup=False)

build torch compatible train and test datasets with transforms

class VocabularyMapper[source]

VocabularyMapper(vocab=None)

A simplified version of the fastai Categorize Transform

import torchvision as thv

# Plain torchvision transforms used as x_tfms in the MNIST demo below.
pil2tensor = thv.transforms.ToTensor()  # PIL image -> float tensor scaled to [0, 1]
resize28 = thv.transforms.Resize(28)    # resize image to 28 px
# Channel-wise normalization (same statistics as `cifar_norm` defined later in this file).
norm = thv.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))

from fastai.vision.core import PILImage
from fastai.data.transforms import get_image_files, GrandparentSplitter, parent_label
from fastai.data.external import untar_data, URLs

# Demo: build torch-compatible train/test datasets for MNIST_TINY
# using the TorchDatasetBuilder defined above.
path = untar_data(URLs.MNIST_TINY)
mnist_dset_builder =  TorchDatasetBuilder(
                source=path, 
                get_items=get_image_files, 
                splitter=GrandparentSplitter(),
                x_tfms=[resize28,pil2tensor,norm,], 
                y_tfms=[parent_label,VocabularyMapper(),],
                x_type_tfms=PILImage.create)

from fastcore.test import test_eq

# do_setup=True runs the tfm setups (here: builds the VocabularyMapper vocab).
train_ds, test_ds = mnist_dset_builder.get_datasets(do_setup=True)

test_eq(len(train_ds),709)
test_eq(len(test_ds),699)
# MNIST_TINY contains only the digits 3 and 7, so the vocab has two classes.
test_eq(mnist_dset_builder.y_tfms[1].vocab, ('3','7'))
test_eq(mnist_dset_builder.y_tfms[1].c, 2)
# Indexing a dataset item applies the y tfms / x tfms to the raw item in order.
test_eq(train_ds[0][1],mnist_dset_builder.y_tfms[1](parent_label(train_ds.items[0])))
test_eq(train_ds[0][0],norm(pil2tensor(resize28(PILImage.create(train_ds.items[0])))))
# import torch.utils.data as th_data
# from fastcore.basics import patch_to
# @patch_to(th_data.DataLoader)
# def to(self, device):
#     "move torch dataloader to device (for compatibility with fastai dataloader)"
#     self.device = device

make_torch_dataloaders[source]

make_torch_dataloaders(train_dataset, test_dataset, rank, world_size, bs, num_workers=4, distrib=True, sync_valid=False)

make torch-based distributed dataloaders from torch compatible datasets

class FileNamePatternLabeller[source]

FileNamePatternLabeller(pat_str, match=False)

Delayed action version of fastai RegexLabeller with file name selection

Test Model Training using Torch Dataloaders

from fastai.vision.all import *
# from fastai_xla_extensions.multi_core.base import *
# from fastai_xla_extensions.misc_utils import * # patch _BaseOptimizer.__get_state__ and __setstate__
from my_timesaver_utils.profiling import *
from my_timesaver_utils.profiling_callback import *
from fastai.learner import Learner
from fastai.metrics import accuracy

def train_torch_model(rank):
    """Per-core TPU training loop, run in each process spawned by xmp.spawn.

    Reads module globals prepared before spawning: FLAGS (config dict),
    WRAPPED_MODEL (xmp.MpModelWrapper around the model), DSET_BUILDER
    (TorchDatasetBuilder), LOSS_FUNC and OPT_FUNC. `rank` is this
    process's index in the TPU world.
    """
    torch.manual_seed(1)
    # Block here until every spawned core reaches this point.
    xm.rendezvous('start_train_torch_model')
    # Scale learning rate to num cores
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
    IS_PROFILING = FLAGS['is_profiling']
    SYNC_VALID = FLAGS['sync_valid']

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    # Materialize the shared wrapped model onto this core's XLA device.
    model = WRAPPED_MODEL.to(device)
    bs = FLAGS['batch_size']
    world_size = xm.xrt_world_size()
    moms =(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum'])
    wd = FLAGS['weight_decay']
    num_workers = FLAGS['num_workers']

    # Profile dataset construction for this rank (my_timesaver_utils profiling).
    if IS_PROFILING:
        rec_name = 'rank' + str(rank) + '_dset_build'
        print(f'start {rec_name}')
        start_record(rec_name)
    dsets = DSET_BUILDER.get_datasets()
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')

    # Profile dataloader construction for this rank.
    if IS_PROFILING:
        rec_name2 = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name2}')
        start_record(rec_name2)
    dls = make_torch_dataloaders(*dsets, 
                                  rank=rank, 
                                  world_size=world_size, 
                                  bs=bs,
                                  num_workers=num_workers,
                                  sync_valid=SYNC_VALID,
                                 )

    if IS_PROFILING:
        end_record(rec_name2)
        print_prof_data(rec_name2)
        print(f'finished {rec_name2}')

    xm.master_print('build learner')
    learner = Learner(dls, model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      wd=wd,
                      moms=moms
                      )

    # Attach multi-core XLA support; only rank 0 also collects learner-level profiling.
    learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    if rank == 0 and IS_PROFILING:
        learner.to_my_profile()

    epochs = FLAGS['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()

    # Profile the fit itself for this rank.
    if IS_PROFILING:
        rec_name3 = 'rank' + str(rank) + '_run_fit'
        print(f'start {rec_name3}')
        start_record(rec_name3)
    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))

    if IS_PROFILING:
        end_record(rec_name3)
        print_prof_data(rec_name3)
        print(f'finished {rec_name3}')

    learner.save('stage-1')
    if rank == 0 and IS_PROFILING:
        learner.my_profile.print_stats()
    xm.mark_step() 
# Start training processes
def _mp_fn2(rank, flags):
    """xmp.spawn entry point: publish `flags` as the module-global FLAGS
    in this process, then run the training loop for `rank`."""
    global FLAGS
    FLAGS = flags
    train_torch_model(rank)
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats
import torch.nn as nn
# Module globals consumed by train_torch_model in every spawned process.
LOSS_FUNC = nn.CrossEntropyLoss()
from fastai.optimizer import Adam
OPT_FUNC = Adam
from fastai.data.transforms import RandomSplitter
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet34
import os
# Define Parameters
FLAGS = {}
# FLAGS['batch_size'] = 1024
FLAGS['sync_valid'] = True
FLAGS['is_profiling'] = True
FLAGS['batch_size'] = 64
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 1e-3  # base LR; scaled by world size inside train_torch_model
FLAGS['image_size'] = 224
FLAGS['momentum'] = 0.85
FLAGS['weight_decay'] = 2e-3
FLAGS['num_epochs'] = 5
# All 8 cores when a TPU is attached (TPU_NAME set), else a single process.
FLAGS['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1

# FLAGS['num_cores'] = 1 
ARCH = resnet34
USE_DBLOCK = False
from pathlib import Path
from fastcore.xtras import *
PATH = untar_data(URLs.PETS)/'images'
# PATH = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)
# ImageNet channel statistics (matches the pretrained resnet backbone below).
imagenet_norm = thv.transforms.Normalize(
    mean=(0.485, 0.456, 0.406), 
    std=(0.229, 0.224, 0.225))

# CIFAR-10 channel statistics (not referenced by the visible code below).
cifar_norm = thv.transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465), 
    std=(0.2023, 0.1994, 0.2010))

image_size = FLAGS['image_size']
splitter = RandomSplitter(seed=42)
# Label is the file-name prefix before the trailing _<digits>.jpg (Pets breed).
pat = r'(.+)_\d+.jpg$'
fname_labeller = FileNamePatternLabeller(pat)

DSET_BUILDER = TorchDatasetBuilder(
    PATH, 
    get_items=get_image_files,
    splitter=splitter,
    x_tfms=[thv.transforms.Resize((image_size,image_size)), thv.transforms.ToTensor(), imagenet_norm],
    y_tfms=[fname_labeller, VocabularyMapper(),],
    x_type_tfms=PILImage.create,
) 
# Build the label vocab once on the master process, with profiling.
start_record('master_vocab_setup')
DSET_BUILDER.setup(get_image_files(PATH),do_setup=True)
end_record('master_vocab_setup')
print_prof_data('master_vocab_setup')
clear_prof_data()
N_OUT = DSET_BUILDER.y_tfms[1].c     
Function master_vocab_setup called 1 times.
Execution time max: 0.060, average: 0.060
# N_OUT is the number of classes, taken from the vocab built by DSET_BUILDER.setup.
# Validate with an explicit raise rather than `assert`: asserts are stripped
# when Python runs with -O, which would silently skip this check.
if N_OUT is None or N_OUT <= 0:
    raise ValueError(f'N_OUT {N_OUT} should be > 0')
# ImageNet-pretrained resnet body with a plain (non-concat) pooling head.
custom_model = create_cnn_model(ARCH, N_OUT, 
                                pretrained=True,
                                concat_pool=False)
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth

# Only instantiate model weights once in memory.
# MpModelWrapper lets the spawned processes share a single host copy of the weights.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)
%%time
# Turn off per-rank profiling for the full run, then fork one process per
# core; each process executes _mp_fn2(rank, FLAGS).
FLAGS['is_profiling'] = False
# !rm -f /content/models/stage-1.pth
xmp.spawn(_mp_fn2, args=(FLAGS,), nprocs=FLAGS['num_cores'],
        start_method='fork')
build learner
start running fit
start fit
epoch train_loss valid_loss accuracy time
0 0.805420 2.441340 0.434570 01:43
1 0.683048 1.726209 0.626953 01:21
2 0.603717 0.506599 0.843750 01:16
3 0.516238 0.362987 0.888672 01:14
4 0.431123 0.281401 0.904297 01:15
CPU times: user 103 ms, sys: 117 ms, total: 220 ms
Wall time: 7min 27s
# Re-validate in a single (non-distributed) process: rank 0, world_size 1.
mdsets = DSET_BUILDER.get_datasets()
mdls = make_torch_dataloaders(*mdsets,
                                rank=0,
                                world_size=1,
                                bs=FLAGS['batch_size'],
                                num_workers=FLAGS['num_workers']
                                )
mlearner = Learner(mdls, custom_model, 
                    loss_func=LOSS_FUNC, 
                    opt_func=OPT_FUNC, 
                    metrics=accuracy, 
                    wd=FLAGS['weight_decay'],
                    moms=(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum']))
# Load the weights saved by the TPU training run (learner.save('stage-1')).
mlearner.load('stage-1');
mlearner.dls.device
from fastai.torch_core import one_param
# Check which device the model parameters actually live on.
one_param(mlearner.model).device
device(type='cpu')
%%time
# Run validation; prints [valid_loss, accuracy].
valid_metrics = mlearner.validate();print(valid_metrics)
[0.27262669801712036, 0.91236412525177]
CPU times: user 3min 26s, sys: 2.92 s, total: 3min 28s
Wall time: 3min 32s