Open In Colab

Multi Core XLA Extensions for inference

Multi-core TPU implementation for inference is enabled by importing this module.

from fastai_xla_extensions.multi_core.inference import *
WARNING:root:Waiting for TPU to be start up with version pytorch-1.7...
WARNING:root:Waiting for TPU to be start up with version pytorch-1.7...
WARNING:root:TPU has started up successfully with version pytorch-1.7

Implement Multi Core TPU Inference

Learner.inner_get_preds[source]

Learner.inner_get_preds(ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None, inner=False, reorder=True, cbs=None, **kwargs)

setup_inference_args[source]

setup_inference_args(rank, inference_args)

save_pred_results[source]

save_pred_results(rank, results)

xla_run_inference[source]

xla_run_inference(rank, learner_args, add_args, inference_args, ctrl_args)

reload_pred_results[source]

reload_pred_results(num_files, n_samples)

Learner.pre_xla_inference[source]

Learner.pre_xla_inference()

Learner.post_xla_inference[source]

Learner.post_xla_inference(ctrl_args)

prep_inference_args[source]

prep_inference_args(**kwargs)

Learner.xla_get_preds[source]

Learner.xla_get_preds(ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None, inner=False, reorder=True, cbs=None, num_cores=8, start_method='fork', master_cbs=None, save_preds=None, save_targs=None, concat_dim=0)

Test out the code

# Download MNIST and build the DataBlock pipeline: image inputs, labels taken
# from the parent folder name, and a train/valid split driven by the
# 'training'/'testing' grandparent directories of each image file.
path = untar_data(URLs.MNIST)
# path = untar_data(URLs.PETS)/'images'
data = DataBlock(
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
    item_tfms=Resize(28),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
# Alternative dataset (Oxford-IIIT Pets) kept for reference:
# pat = r'(.+)_\d+.jpg$'
# data = DataBlock(
#     blocks=(ImageBlock,CategoryBlock),
#     get_items=get_image_files,
#     get_y=using_attr(RegexLabeller(pat),'name'),
#     splitter=RandomSplitter(seed=42),
#     item_tfms=Resize(224),
#     batch_tfms=[Normalize.from_stats(*imagenet_stats)]
# )
dls = data.dataloaders(path, bs=64)
# loss_func=nn.CrossEntropyLoss()
loss_func=CrossEntropyLossFlat()  # fastai wrapper that flattens preds/targets before CE
# NOTE(review): concat_pool=False presumably chosen for XLA/TPU compatibility
# of the pooling head — confirm against the multi_core module's requirements.
learner = cnn_learner(dls, resnet18, metrics=accuracy, loss_func=loss_func, concat_pool=False)
# learner = cnn_learner(dls, resnet34, metrics=accuracy, loss_func=loss_func, concat_pool=False)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth

# Train the (frozen, head-only) model for 3 epochs with one-cycle scheduling
# on the TPU via the multi-core xla_fit_one_cycle extension.
learner.xla_fit_one_cycle(3, lr_max=slice(3e-2))
start fit
epoch train_loss valid_loss accuracy time
0 0.561700 0.177223 0.944400 02:33
1 0.224954 0.091147 0.972200 02:23
2 0.144529 0.077342 0.976500 02:35
# Unfreeze the full network and fine-tune with discriminative learning rates
# (small lr for early layers, larger for the head).
learner.unfreeze()
 
 
learner.xla_fit_one_cycle(5,lr_max=slice(1e-6,2e-4))
start fit
epoch train_loss valid_loss accuracy time
0 0.091381 0.066814 0.978300 02:45
1 0.055593 0.067667 0.978500 02:38
2 0.083875 0.060253 0.981000 02:37
3 0.083661 0.059680 0.979900 02:41
4 0.078803 0.058656 0.980800 02:35
 
# Baseline: standard single-device Learner.get_preds, timed so it can be
# compared with the multi-core xla_get_preds runs further down.
# learner.validate()
%%time
res = learner.get_preds()
CPU times: user 35.7 s, sys: 1.77 s, total: 37.5 s
Wall time: 39.2 s
# Sanity checks: (preds, targets) pair, 10k validation samples x 10 classes.
print(len(res))
print(res[0].shape, res[1].shape)
2
torch.Size([10000, 10]) torch.Size([10000])
print(accuracy(*res))
TensorBase(0.9816)
res[1][:10]
TensorCategory([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
# Same call with reorder=False — results are left in dataloader batch order.
%%time
res2 = learner.get_preds(reorder=False)
CPU times: user 35.9 s, sys: 2.02 s, total: 37.9 s
Wall time: 39.4 s
print(len(res2))
print(res2[0].shape, res2[1].shape)
2
torch.Size([10000, 10]) torch.Size([10000])
res2[1][:10]
TensorCategory([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
print(accuracy(*res2))
TensorBase(0.9816)
# Multi-core TPU inference (default num_cores=8), without reordering the
# gathered results. Accuracy matches the single-device baseline closely
# (0.9812 vs 0.9816) with noticeably less wall time.
%%time
xla_res = learner.xla_get_preds(reorder=False)
start fit
CPU times: user 48.8 ms, sys: 117 ms, total: 166 ms
Wall time: 28.6 s
print(len(xla_res))
2
(xla_res[0].shape, xla_res[1].shape)
(torch.Size([10000, 10]), torch.Size([10000]))
xla_res[1][:10]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
xla_res[0][:10]
tensor([[9.9997e-01, 1.2037e-06, 2.2819e-07, 9.6082e-08, 5.1155e-08, 2.5031e-07,
         2.1356e-05, 7.4675e-07, 1.8351e-07, 4.7405e-06],
        [9.9949e-01, 9.3674e-07, 9.9061e-07, 2.3419e-07, 4.0992e-07, 6.5995e-07,
         5.0583e-04, 8.3670e-08, 7.9095e-07, 2.8556e-06],
        [8.7443e-01, 1.1864e-03, 1.4634e-05, 5.5702e-05, 3.1255e-03, 3.9448e-02,
         8.0691e-02, 1.3605e-04, 2.0612e-04, 7.0657e-04],
        [9.9999e-01, 9.1822e-07, 7.4076e-08, 1.3242e-07, 3.8235e-08, 9.4999e-08,
         5.3028e-06, 8.0097e-07, 3.6222e-08, 9.3102e-07],
        [9.9996e-01, 6.0875e-07, 9.9508e-08, 1.6203e-06, 2.5617e-07, 1.3026e-06,
         2.4092e-05, 1.4711e-06, 3.8553e-06, 3.5397e-06],
        [9.9967e-01, 2.5601e-06, 7.9427e-06, 7.0779e-07, 2.9737e-05, 6.6313e-07,
         2.7755e-04, 8.1741e-07, 1.5646e-06, 5.0701e-06],
        [9.9988e-01, 4.0136e-06, 4.6879e-06, 1.0828e-06, 1.6936e-06, 2.4147e-06,
         7.2721e-05, 3.8035e-06, 1.4383e-06, 3.1155e-05],
        [9.9994e-01, 5.1368e-07, 1.0097e-06, 1.6207e-07, 4.7854e-07, 1.5919e-07,
         1.1031e-05, 8.6520e-06, 3.0814e-07, 3.5142e-05],
        [9.9932e-01, 9.1196e-06, 2.2332e-06, 5.8100e-06, 1.6907e-05, 4.3462e-05,
         5.6675e-04, 1.5917e-05, 5.3773e-06, 1.0780e-05],
        [9.9997e-01, 4.1387e-06, 1.7671e-07, 6.1776e-08, 6.7920e-08, 9.2069e-08,
         1.2195e-05, 3.7874e-07, 1.7996e-07, 8.4062e-06]])
print(accuracy(*xla_res))
TensorBase(0.9812)
# Same multi-core inference with reorder=True — results are restored to
# dataset order. Note the extra wall time (46.3 s vs 28.6 s) the reorder costs.
%%time
xla_res2 = learner.xla_get_preds(reorder=True)
start fit
CPU times: user 70.6 ms, sys: 118 ms, total: 189 ms
Wall time: 46.3 s
print(len(xla_res2))
2
(xla_res2[0].shape, xla_res2[1].shape)
(torch.Size([10000, 10]), torch.Size([10000]))
xla_res2[1][:10]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
xla_res2[0][:10]
tensor([[9.9997e-01, 1.2037e-06, 2.2819e-07, 9.6082e-08, 5.1155e-08, 2.5031e-07,
         2.1356e-05, 7.4675e-07, 1.8351e-07, 4.7405e-06],
        [9.9949e-01, 9.3674e-07, 9.9061e-07, 2.3419e-07, 4.0992e-07, 6.5995e-07,
         5.0583e-04, 8.3670e-08, 7.9095e-07, 2.8556e-06],
        [8.7443e-01, 1.1864e-03, 1.4634e-05, 5.5702e-05, 3.1255e-03, 3.9448e-02,
         8.0691e-02, 1.3605e-04, 2.0612e-04, 7.0657e-04],
        [9.9999e-01, 9.1822e-07, 7.4076e-08, 1.3242e-07, 3.8235e-08, 9.4999e-08,
         5.3028e-06, 8.0097e-07, 3.6222e-08, 9.3102e-07],
        [9.9996e-01, 6.0875e-07, 9.9508e-08, 1.6203e-06, 2.5617e-07, 1.3026e-06,
         2.4092e-05, 1.4711e-06, 3.8553e-06, 3.5397e-06],
        [9.9967e-01, 2.5601e-06, 7.9427e-06, 7.0779e-07, 2.9737e-05, 6.6313e-07,
         2.7755e-04, 8.1741e-07, 1.5646e-06, 5.0701e-06],
        [9.9988e-01, 4.0136e-06, 4.6879e-06, 1.0828e-06, 1.6936e-06, 2.4147e-06,
         7.2721e-05, 3.8035e-06, 1.4383e-06, 3.1155e-05],
        [9.9994e-01, 5.1368e-07, 1.0097e-06, 1.6207e-07, 4.7854e-07, 1.5919e-07,
         1.1031e-05, 8.6520e-06, 3.0814e-07, 3.5142e-05],
        [9.9932e-01, 9.1196e-06, 2.2332e-06, 5.8100e-06, 1.6907e-05, 4.3462e-05,
         5.6675e-04, 1.5917e-05, 5.3773e-06, 1.0780e-05],
        [9.9997e-01, 4.1387e-06, 1.7671e-07, 6.1776e-08, 6.7920e-08, 9.2069e-08,
         1.2195e-05, 3.7874e-07, 1.7996e-07, 8.4062e-06]])
print(accuracy(*xla_res2))
TensorBase(0.9812)
# With save_preds/save_targs set, results are written to disk instead of
# being returned: the call yields lists of Nones, and the ls output below
# shows one 'my_predsN'/'my_targsN' directory per spawned process
# (8, matching the default num_cores=8). Clean up afterwards.
%cd /content
#test save preds and save targs to files per iter
xla_res3 = learner.xla_get_preds(save_preds='my_preds', save_targs='my_targs')
/content
start fit
xla_res3
[(#8) [None,None,None,None,None,None,None,None],
 (#8) [None,None,None,None,None,None,None,None]]
!ls -ald /content/my_preds*
drwxr-xr-x 2 root root 4096 Mar 10 18:53 /content/my_preds0
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_preds1
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_preds2
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_preds3
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_preds4
drwxr-xr-x 2 root root 4096 Mar 10 18:53 /content/my_preds5
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_preds6
drwxr-xr-x 2 root root 4096 Mar 10 18:53 /content/my_preds7
!ls -ald /content/my_targs*
drwxr-xr-x 2 root root 4096 Mar 10 18:53 /content/my_targs0
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_targs1
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_targs2
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_targs3
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_targs4
drwxr-xr-x 2 root root 4096 Mar 10 18:53 /content/my_targs5
drwxr-xr-x 2 root root 4096 Mar 10 18:54 /content/my_targs6
drwxr-xr-x 2 root root 4096 Mar 10 18:53 /content/my_targs7
# Remove the generated prediction/target directories.
!rm -rf /content/my_preds*
!rm -rf /content/my_targs*
!ls -ald *
lrwxrwxrwx 1 root root   18 Mar 10 17:35 data -> /root/.fastai/data
drwx------ 5 root root 4096 Mar 10 17:32 drive
lrwxrwxrwx 1 root root   44 Mar 10 17:35 fastai_xla_extensions -> /content/drive/MyDrive/fastai_xla_extensions
lrwxrwxrwx 1 root root   19 Mar 10 17:35 models -> /root/.torch/models