Learner method patches to invoke multi-core fit and other operations, prefixed by xla_. These provide an alternate way to run multi-core operations with minimal changes to existing fastai notebooks.
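For example, instead of calling fit_one_cycle, a notebook calls the xla_-prefixed equivalent on the same Learner; everything else stays the same (this is exactly how xla_fit_one_cycle and xla_fit are used in the examples below):

# standard single-device fastai call:
# learner.fit_one_cycle(5, lr_max=slice(2e-3))
# multi-core TPU equivalent provided by the patched Learner method:
learner.xla_fit_one_cycle(5, lr_max=slice(2e-3))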
Add master_cbs property to Learner
Master callbacks are callbacks that are executed only on the master ordinal (the rank 0 thread). This means existing fastai notebooks must be checked for additional callbacks that could cause conflicts if run on different threads at the same time.
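For example, a callback that writes to disk, such as SaveModelCallback, should be attached as a master callback so that only the rank 0 thread saves the model (this pattern is used in the example below):

# run SaveModelCallback on the master ordinal (rank 0) only
learner.add_master_cbs([SaveModelCallback(fname='best_model')])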
Note that of the default callbacks (TrainEvalCallback, Recorder, ProgressCallback), only ProgressCallback causes this problem. However, the fastai_xla_extensions.multi_core.base module already handles this, so that when it is used (which it is, by default), the ProgressCallback is attached only on the master ordinal thread. Moreover, the Recorder callback is also handled: validation losses and metrics are collated across threads by fastai_xla_extensions.multi_core.base.SyncRecorderCallback, so they are reported correctly at the end of each epoch.
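To illustrate the collation step (this is a minimal sketch of the idea, not the actual SyncRecorderCallback implementation), per-thread values can be gathered and reduced across ordinals with torch_xla's xm.mesh_reduce:

import torch_xla.core.xla_model as xm

# sketch: gather a locally computed validation loss from all TPU cores
# and average it, so every thread (including the master) sees the same value
def collate_valid_loss(local_valid_loss):
    return xm.mesh_reduce('valid_loss', local_valid_loss,
                          lambda vals: sum(vals) / len(vals))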
from fastai.vision.all import *
# the fastai_xla_extensions multi-core modules must also be imported so that
# the xla_ Learner patches and master callbacks are registered (see the package docs)

path = untar_data(URLs.MNIST_TINY)
data = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(),
    item_tfms=Resize(28),
    batch_tfms=[]
)
dls = data.dataloaders(path, bs=16)
from fastai.vision.learner import cnn_learner
from torchvision.models.resnet import resnet18

# concat_pool must be False due to a TPU bug that is triggered when using fastai's AdaptiveConcatPool2d
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)
learner.add_master_cbs([SaveModelCallback(fname='best_model')])
class PrintValuesCallback(Callback):
    order = 56  # run after Recorder and SyncRecorderCallback, before SaveModelCallback
    def after_epoch(self):
        print(f'final record: {self.learn.final_record}')
        vlen = len(self.recorder.values)
        print(f'values len: {vlen}')
        if vlen > 0:
            last_idx = self.recorder.values[-1]  # last row of recorded values
            len_last_idx = len(last_idx)
            print(f'values last idx len: {len_last_idx}')
            print(f'last idx: {last_idx}')
            if 'save_model' in L(self.cbs).attrgot('name'):
                save_model_idx = self.save_model.idx  # index of the monitored metric
                print(f'save_model idx: {save_model_idx}')
                if save_model_idx < len_last_idx:
                    val = self.recorder.values[-1][self.save_model.idx]
                    print(f'best_value: {val}')
        if 'sync_recorder' in L(self.cbs).attrgot('name'):
            sync_log = self.sync_recorder.sync_log
            len_sync_log = len(sync_log)
            print(f'sync rec sync_log len: {len_sync_log}')
            print(f'sync rec sync_log: {sync_log}')
            if len_sync_log > 0:
                print(f'sync rec sync_log[1:]: {sync_log[1:]}')
# cbs = [PrintValuesCallback(), SaveModelCallback(fname='best_model')]
cbs = [PrintValuesCallback()]
learner.xla_fit_one_cycle(5, lr_max=slice(2e-3), cbs=cbs)
res = learner.get_preds()
print(accuracy(*res))
learner.load('best_model')  # reload the best checkpoint saved by SaveModelCallback
res = learner.get_preds()
print(accuracy(*res))
learner.unfreeze()
one_param(learner.model).device  # check which device the model parameters are on
learner.xla_fit(n_epoch=5, lr=2e-3)
learner.validate()
# Second example: CIFAR10 with plain PyTorch datasets and dataloaders
from pathlib import Path

FLAGS = {}
FLAGS['batch_size'] = 64
FLAGS['num_workers'] = 4
FLAGS['data_dir'] = Path('/content/data/cifar')
from torchvision import datasets, transforms

def get_dataset():
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        norm,
    ])
    train_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'],
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'],
        train=False,
        download=True,
        transform=transform_test)
    return train_dataset, test_dataset
train_dataset, test_dataset = get_dataset()
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=FLAGS['batch_size'],
    # sampler=train_sampler,
    shuffle=True,
    num_workers=FLAGS['num_workers'],
    drop_last=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS['batch_size'],
    shuffle=False,
    num_workers=FLAGS['num_workers'],
    drop_last=True)
# fastai dls using torch dataloaders
dls = DataLoaders(train_loader, test_loader)
learner = cnn_learner(dls, resnet18, metrics=accuracy,
                      n_out=10,
                      loss_func=nn.CrossEntropyLoss(),
                      concat_pool=False)
learner.xla_fit(5, lr=2e-2)