* adding a measure of uncertainty to predictions
from fastai.test_utils import synth_dbunch, synth_learner
# `contextlib.nullcontext` exists from Python 3.7 onward. On older interpreters fall back
# to `contextlib.suppress`: called with no exception types it suppresses nothing, so
# `nullcontext()` still behaves as a do-nothing context manager.
try:
    from contextlib import nullcontext  # Python 3.7+
except ImportError:
    from contextlib import suppress as nullcontext  # Python 3.6 and below
# CI fixtures: a synthetic DataLoaders/Learner pair stands in for the real bear classifier,
# so the checks can run without a trained model or image folders.
dls = synth_dbunch()
dls.vocab = [1]  # single-entry vocab, matching CATEGORIES below
learner = synth_learner(data=dls)
# Replace `no_bar` with a do-nothing context manager factory
# (presumably so `with learner.no_bar():` needs no progress-bar callback -- confirm).
learner.no_bar = nullcontext
# The synthetic train/valid loaders play the roles of the bear and pet loaders.
bears_dl, pets_dl = dls.train, dls.valid
# Sizes the shape checks expect for this synthetic setup.
N_SAMPLE, CATEGORIES, BS = 2, 1, 160
from fastai.learner import load_learner
from fastai.data.transforms import get_image_files
from fastai.data.external import Config
from fastai.vision.core import PILImage
import random
# Local fixtures: a real exported classifier plus bear/pet image folders, all resolved
# through the fastai Config paths on this machine.
cfg = Config()
# NOTE(review): the export is a pickle file -- only load exports from a trusted source.
learner = load_learner(cfg.model_path/'bears_classifier'/'export.pkl')
bear_path, pet_path = cfg.data_path/'bears', cfg.data_path/'pets'
bear_img_files, pet_img_files = get_image_files(bear_path), get_image_files(pet_path)

# Seed before any shuffling so the same files are picked on every run (presumably
# `shuffle` draws from Python's `random` -- confirm). The four shuffles below must stay
# in this order (pet, bear, pet, bear) or the selection changes.
random.seed(69420)
pet_img = PILImage.create(pet_img_files.shuffle()[0])
bear_img = PILImage.create(bear_img_files.shuffle()[0])
pet_items = pet_img_files.shuffle()[:20]
bear_items = bear_img_files.shuffle()[:20]

# Open every selected file as an image, then wrap both sets in test loaders built from
# the learner's own DataLoaders (num_workers=0 keeps loading in the main process).
pet_dset, bear_dset = pet_items.map(PILImage.create), bear_items.map(PILImage.create)
pets_dl = learner.dls.test_dl(pet_dset, num_workers=0)
bears_dl = learner.dls.test_dl(bear_dset, num_workers=0)
# Original note: xb.shape = torch.Size([20, 3, 224, 224])

# Sizes the shape checks expect for this local setup.
N_SAMPLE, CATEGORIES, BS = 2, 3, 20
from fastcore.test import *
# Run the sampled prediction pass over both loaders; only the bear output is checked below.
bear_res = learner.bayes_get_preds(dl=bears_dl, n_sample=N_SAMPLE)
pet_res = learner.bayes_get_preds(dl=pets_dl, n_sample=N_SAMPLE)

# The result has six entries under both the CI and the local setup.
test_eq(len(bear_res), 6)

# Expected shapes of entries 0-4, expressed through the setup constants:
#   CI    (N_SAMPLE=2, BS=160, CATEGORIES=1): [2, 160, 1], [160, 1], [160], [160], [160]
#   local (N_SAMPLE=2, BS=20,  CATEGORIES=3): [2, 20, 3],  [20, 3],  [20],  [20],  [20]
_expected_shapes = (
    [N_SAMPLE, BS, CATEGORIES],
    [BS, CATEGORIES],
    [BS],
    [BS],
    [BS],
)
for _idx, _want in enumerate(_expected_shapes):
    test_eq(bear_res[_idx].shape, _want)

# The last entry is only checked for length: one element per item (160 on CI, 20 locally).
test_eq(len(bear_res[5]), BS)