```python
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import EuroSAT100
```

# resnet

torchgeo ResNet model adapter utils
## make_resnet_model

`make_resnet_model(model: torch.nn.Module, n_out: int)`
Creates a ResNet model by cutting off the fully connected (fc) layer of a pretrained ResNet model and replacing it with a new head.

The new head is created by concatenating adaptive pooling layers with a linear layer followed by an activation function; the new head is then appended to the cut model.
**Parameters**

- `model` (torch.nn.Module): a pretrained ResNet model.
- `n_out` (int): the number of output classes.
**Returns**

- torch.nn.Module: the ResNet model with the new head.
| | Type | Details |
|---|---|---|
| model | Module | pretrained torchgeo model |
| n_out | int | number of outputs |
| Returns | Module | new model with a new head for finetuning |
## resnet_split

`resnet_split(m: torch.nn.Module)`
Splits the ResNet model parameters into parameter groups. Used by fastai for discriminative learning rates when finetuning.

**Parameters**

- `m` (nn.Module): the model.

**Returns**

- [torch.nn.Module]: a list of parameter groups.
| | Type | Details |
|---|---|---|
| m | Module | A model |
| Returns | [torch.nn.Module] | A list of parameter groups |
## Adapting a pretrained ResNet torchgeo model for a fastai Learner
```python
from torchgeo.models import ResNet18_Weights, resnet18
from torchgeo.datamodules import EuroSATDataModule

pretrained = resnet18(ResNet18_Weights.SENTINEL2_ALL_MOCO, num_classes=10)  # load pretrained weights
model = make_resnet_model(pretrained, n_out=10)

dblock = fv.DataBlock(
    blocks=(GeoImageBlock(), fv.CategoryBlock()),
    get_items=fv.get_image_files,
    splitter=fv.RandomSplitter(valid_pct=0.1, seed=42),
    get_y=fv.parent_label,
    item_tfms=fv.Resize(64),
    batch_tfms=[fv.Normalize.from_stats(EuroSATDataModule.mean, EuroSATDataModule.std)],
)

sat_path = fv.untar_data(EuroSAT100.url)
dls = dblock.dataloaders(sat_path, bs=64)
batch_size = 64
```
```python
num_workers = fv.defaults.cpus
# datamodule = EuroSATDataModule(root=sat_path, batch_size=batch_size, num_workers=num_workers, download=True)
# datamodule.prepare_data()

learn = fv.Learner(
    dls,
    model,
    loss_func=fv.CrossEntropyLossFlat(),
    metrics=[fv.error_rate, fv.accuracy],
    splitter=resnet_split,
)
```
```python
# `freeze` uses the parameter groups created by `resnet_split` to lock
# the parameters of the pretrained model except for the model head
learn.freeze()  # note: only the head parameter group is trainable (except BatchNorm layers, which are always trainable)
learn.summary()
```

```
Sequential (Input shape: 64 x 13 x 64 x 64)
============================================================================
Layer (type)         Output Shape         Param #    Trainable
============================================================================
                     64 x 64 x 32 x 32
Conv2d                                    40768      False
BatchNorm2d                               128        True
ReLU
____________________________________________________________________________
                     64 x 64 x 16 x 16
MaxPool2d
Conv2d                                    36864      False
BatchNorm2d                               128        True
Identity
ReLU
Identity
Conv2d                                    36864      False
BatchNorm2d                               128        True
ReLU
Conv2d                                    36864      False
BatchNorm2d                               128        True
Identity
ReLU
Identity
Conv2d                                    36864      False
BatchNorm2d                               128        True
ReLU
____________________________________________________________________________
                     64 x 128 x 8 x 8
Conv2d                                    73728      False
BatchNorm2d                               256        True
Identity
ReLU
Identity
Conv2d                                    147456     False
BatchNorm2d                               256        True
ReLU
Conv2d                                    8192       False
BatchNorm2d                               256        True
Conv2d                                    147456     False
BatchNorm2d                               256        True
Identity
ReLU
Identity
Conv2d                                    147456     False
BatchNorm2d                               256        True
ReLU
____________________________________________________________________________
                     64 x 256 x 4 x 4
Conv2d                                    294912     False
BatchNorm2d                               512        True
Identity
ReLU
Identity
Conv2d                                    589824     False
BatchNorm2d                               512        True
ReLU
Conv2d                                    32768      False
BatchNorm2d                               512        True
Conv2d                                    589824     False
BatchNorm2d                               512        True
Identity
ReLU
Identity
Conv2d                                    589824     False
BatchNorm2d                               512        True
ReLU
____________________________________________________________________________
                     64 x 512 x 2 x 2
Conv2d                                    1179648    False
BatchNorm2d                               1024       True
Identity
ReLU
Identity
Conv2d                                    2359296    False
BatchNorm2d                               1024       True
ReLU
Conv2d                                    131072     False
BatchNorm2d                               1024       True
Conv2d                                    2359296    False
BatchNorm2d                               1024       True
Identity
ReLU
Identity
Conv2d                                    2359296    False
BatchNorm2d                               1024       True
ReLU
____________________________________________________________________________
                     64 x 512 x 1 x 1
AdaptiveAvgPool2d
AdaptiveMaxPool2d
____________________________________________________________________________
                     64 x 1024
Flatten
BatchNorm1d                               2048       True
Dropout
____________________________________________________________________________
                     64 x 512
Linear                                    524288     True
ReLU
BatchNorm1d                               1024       True
Dropout
____________________________________________________________________________
                     64 x 10
Linear                                    5120       True
____________________________________________________________________________

Total params: 11,740,352
Total trainable params: 542,080
Total non-trainable params: 11,198,272
Optimizer used: <function Adam>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback
```
```python
learn.fine_tune(2)
```

| epoch | train_loss | valid_loss | error_rate | accuracy | time |
|---|---|---|---|---|---|
| 0 | 4.042061 | 2.311349 | 0.800000 | 0.200000 | 00:01 |

| epoch | train_loss | valid_loss | error_rate | accuracy | time |
|---|---|---|---|---|---|
| 0 | 3.859293 | 2.302754 | 1.000000 | 0.000000 | 00:02 |
| 1 | 3.652791 | 2.247182 | 0.900000 | 0.100000 | 00:02 |
```python
# unlock all weights and make the whole model trainable
learn.unfreeze()  # all parameters are now trainable
learn.summary()
```

```
Sequential (Input shape: 64 x 13 x 64 x 64)
============================================================================
Layer (type)         Output Shape         Param #    Trainable
============================================================================
                     64 x 64 x 32 x 32
Conv2d                                    40768      True
BatchNorm2d                               128        True
ReLU
____________________________________________________________________________
                     64 x 64 x 16 x 16
MaxPool2d
Conv2d                                    36864      True
BatchNorm2d                               128        True
Identity
ReLU
Identity
Conv2d                                    36864      True
BatchNorm2d                               128        True
ReLU
Conv2d                                    36864      True
BatchNorm2d                               128        True
Identity
ReLU
Identity
Conv2d                                    36864      True
BatchNorm2d                               128        True
ReLU
____________________________________________________________________________
                     64 x 128 x 8 x 8
Conv2d                                    73728      True
BatchNorm2d                               256        True
Identity
ReLU
Identity
Conv2d                                    147456     True
BatchNorm2d                               256        True
ReLU
Conv2d                                    8192       True
BatchNorm2d                               256        True
Conv2d                                    147456     True
BatchNorm2d                               256        True
Identity
ReLU
Identity
Conv2d                                    147456     True
BatchNorm2d                               256        True
ReLU
____________________________________________________________________________
                     64 x 256 x 4 x 4
Conv2d                                    294912     True
BatchNorm2d                               512        True
Identity
ReLU
Identity
Conv2d                                    589824     True
BatchNorm2d                               512        True
ReLU
Conv2d                                    32768      True
BatchNorm2d                               512        True
Conv2d                                    589824     True
BatchNorm2d                               512        True
Identity
ReLU
Identity
Conv2d                                    589824     True
BatchNorm2d                               512        True
ReLU
____________________________________________________________________________
                     64 x 512 x 2 x 2
Conv2d                                    1179648    True
BatchNorm2d                               1024       True
Identity
ReLU
Identity
Conv2d                                    2359296    True
BatchNorm2d                               1024       True
ReLU
Conv2d                                    131072     True
BatchNorm2d                               1024       True
Conv2d                                    2359296    True
BatchNorm2d                               1024       True
Identity
ReLU
Identity
Conv2d                                    2359296    True
BatchNorm2d                               1024       True
ReLU
____________________________________________________________________________
                     64 x 512 x 1 x 1
AdaptiveAvgPool2d
AdaptiveMaxPool2d
____________________________________________________________________________
                     64 x 1024
Flatten
BatchNorm1d                               2048       True
Dropout
____________________________________________________________________________
                     64 x 512
Linear                                    524288     True
ReLU
BatchNorm1d                               1024       True
Dropout
____________________________________________________________________________
                     64 x 10
Linear                                    5120       True
____________________________________________________________________________

Total params: 11,740,352
Total trainable params: 11,740,352
Total non-trainable params: 0
Optimizer used: <function Adam>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model unfrozen

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback
```
```python
# use discriminative learning rates across the parameter groups to give
# the "upper layers" higher learning rates while keeping the "lower layers"
# at lower learning rates, nearly freezing their weights
learn.fit_one_cycle(5, lr_max=slice(2.e-3, 8.e-6))
```

| epoch | train_loss | valid_loss | error_rate | accuracy | time |
|---|---|---|---|---|---|
| 0 | 2.139503 | 2.269313 | 0.900000 | 0.100000 | 00:02 |
| 1 | 2.047539 | 2.380141 | 1.000000 | 0.000000 | 00:03 |
| 2 | 2.059563 | 2.353606 | 0.900000 | 0.100000 | 00:02 |
| 3 | 2.036056 | 2.318051 | 0.900000 | 0.100000 | 00:02 |
| 4 | 1.905594 | 2.308228 | 0.900000 | 0.100000 | 00:03 |
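A slice passed as `lr_max` is expanded by fastai into one learning rate per parameter group, spread geometrically from the first endpoint (first group) to the second (last group). The sketch below re-implements that spreading (fastai's `even_mults` helper, as I recall it) assuming three parameter groups from `resnet_split`; it is illustrative, not fastai's code.

```python
def even_mults(start: float, stop: float, n: int) -> list:
    # geometric spacing from start to stop, one value per parameter group,
    # mirroring how fastai expands slice(lr_first, lr_last)
    mult = (stop / start) ** (1 / (n - 1))
    return [start * mult ** i for i in range(n)]

lrs = even_mults(2.e-3, 8.e-6, n=3)  # one lr per parameter group
print(lrs)
```

The middle group lands at the geometric mean of the two endpoints, roughly 1.3e-4 here.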