Multi-Core XLA Extensions for Inference
Importing this module enables the multi-core TPU implementation of inference.
from fastai_xla_extensions.multi_core.inference import *
# Dataset toggle. MNIST is the default; set USE_PETS = True to run the same
# pipeline on the Oxford-IIIT Pets images instead. This replaces the
# previously commented-out alternative pipeline.
USE_PETS = False

if USE_PETS:
    path = untar_data(URLs.PETS)/'images'
    # Label = filename prefix before the trailing "_<digits>.jpg".
    # The dot is escaped so it matches a literal "." rather than any character.
    pat = r'(.+)_\d+\.jpg$'
    data = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        get_y=using_attr(RegexLabeller(pat), 'name'),
        splitter=RandomSplitter(seed=42),
        item_tfms=Resize(224),
        batch_tfms=[Normalize.from_stats(*imagenet_stats)],
    )
else:
    path = untar_data(URLs.MNIST)
    # The labeller and splitter below expect images laid out as
    # <split>/<label>/<file>: the label is the parent folder name, and the
    # grandparent folder ('training' / 'testing') selects the split.
    data = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        get_y=parent_label,
        splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
        item_tfms=Resize(28),
        # ImageNet statistics, presumably to match the pretrained backbone.
        batch_tfms=[Normalize.from_stats(*imagenet_stats)],
    )

dls = data.dataloaders(path, bs=64)

# nn.CrossEntropyLoss() was the commented-out alternative here.
# NOTE(review): fastai's CrossEntropyLossFlat presumably also supplies the
# activation that get_preds applies to outputs — confirm before swapping.
loss_func = CrossEntropyLossFlat()

# resnet34 is a drop-in alternative backbone for this call.
learner = cnn_learner(dls, resnet18, metrics=accuracy, loss_func=loss_func,
                      concat_pool=False)

# Stage 1: train before unfreezing (the unfreeze() below implies the learner
# starts frozen).
learner.xla_fit_one_cycle(3, lr_max=slice(3e-2))

# Stage 2: unfreeze all layers and fine-tune. In fastai's convention,
# slice(lo, hi) spreads learning rates across layer groups.
learner.unfreeze()
learner.xla_fit_one_cycle(5, lr_max=slice(1e-6, 2e-4))

# Optional sanity check of the trained model: learner.validate()
%%time
res = learner.get_preds()
print(len(res))
print(res[0].shape, res[1].shape)
print(accuracy(*res))
res[1][:10]
%%time
res2 = learner.get_preds(reorder=False)
print(len(res2))
print(res2[0].shape, res2[1].shape)
res2[1][:10]
print(accuracy(*res2))
%%time
xla_res = learner.xla_get_preds(reorder=False)
print(len(xla_res))
(xla_res[0].shape, xla_res[1].shape)
xla_res[1][:10]
xla_res[0][:10]
print(accuracy(*xla_res))
%%time
xla_res2 = learner.xla_get_preds(reorder=True)
print(len(xla_res2))
(xla_res2[0].shape, xla_res2[1].shape)
xla_res2[1][:10]
xla_res2[0][:10]
print(accuracy(*xla_res2))
%cd /content
#test save preds and save targs to files per iter
xla_res3 = learner.xla_get_preds(save_preds='my_preds', save_targs='my_targs')
xla_res3
!ls -ald /content/my_preds*
!ls -ald /content/my_targs*
!rm -rf /content/my_preds*
!rm -rf /content/my_targs*
!ls -ald *