resnet

torchgeo resnet model adapter utils

from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import EuroSAT100

make_resnet_model

 make_resnet_model (model:torch.nn.modules.module.Module, n_out:int)

Creates a ResNet model by cutting the fully connected (fc) layer of a pretrained ResNet model and replacing it with a new head.

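The new head concatenates the outputs of adaptive average pooling and adaptive max pooling, flattens them, and passes them through two BatchNorm/Dropout/Linear blocks with a ReLU in between. The final Linear layer outputs raw logits and has no activation. This matches the head layout produced by fastai's `create_head` with its default arguments (see the `learn.summary()` output below). The new head is then appended to the cut model.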

Parameters

  • model (torch.nn.Module): A pretrained ResNet model.
  • n_out (int): The number of output classes.

Returns

  • torch.nn.Module: The ResNet model with the new head.

         Type    Details
model    Module  pretrained torchgeo model
n_out    int     number of outputs
Returns  Module  new model with a new head for fine-tuning
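The cut-and-replace idea can be sketched in plain PyTorch. This is a minimal sketch, not this library's implementation. It assumes a timm-style ResNet, like the one torchgeo's `resnet18` returns, whose last two children are `global_pool` and `fc`. The head mirrors the layers listed in the `learn.summary()` output below. `ConcatPool`, `make_resnet_model_sketch`, and the dropout rates are hypothetical choices.

```python
import torch
import torch.nn as nn


class ConcatPool(nn.Module):
    "Concatenate adaptive max and average pooling along the channel axis."

    def __init__(self):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.max = nn.AdaptiveMaxPool2d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, 2*C, 1, 1)
        return torch.cat([self.max(x), self.avg(x)], dim=1)


def make_resnet_model_sketch(model: nn.Module, n_out: int, n_hidden: int = 512) -> nn.Module:
    "Hypothetical helper: drop global_pool + fc and append a concat-pool head."
    # assumption: timm layout, so the last two children are global_pool and fc
    body = nn.Sequential(*list(model.children())[:-2])
    nf = model.fc.in_features  # channels produced by the body (512 for resnet18)
    head = nn.Sequential(
        ConcatPool(),                            # -> (B, 2*nf, 1, 1)
        nn.Flatten(),                            # -> (B, 2*nf)
        nn.BatchNorm1d(2 * nf),
        nn.Dropout(0.25),                        # assumed dropout rate
        nn.Linear(2 * nf, n_hidden, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(n_hidden),
        nn.Dropout(0.5),                         # assumed dropout rate
        nn.Linear(n_hidden, n_out, bias=False),  # raw logits, no activation
    )
    return nn.Sequential(body, head)
```

With a 512-channel ResNet18 body the two bias-free Linear layers hold 1024 × 512 = 524,288 and 512 × 10 = 5,120 weights, the same counts the summary reports for the head.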

resnet_split

 resnet_split (m:torch.nn.modules.module.Module)

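Splits the ResNet model parameters into parameter groups. For the model built by `make_resnet_model` there are three groups, and the last one holds the new head.

fastai uses these groups to freeze and unfreeze parts of the model and to apply discriminative learning rates when fine-tuning.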

Parameters

  • m (nn.Module): Model

Returns

  • [torch.nn.Module]: A list of parameter groups

         Type      Details
m        Module    A model
Returns  [Module]  A list of parameter groups
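A splitter of this kind can be sketched in plain PyTorch. This sketch assumes the `nn.Sequential(body, head)` layout, with a timm-style body (`conv1, bn1, act1, maxpool, layer1 … layer4`). It also assumes a three-way split of stem plus `layer1`/`layer2`, then `layer3`/`layer4`, then the head, similar to fastai's built-in ResNet split. The actual group boundaries of `resnet_split` are not documented here. `resnet_split_sketch` and `make_optimizer` are hypothetical names.

```python
from typing import List, Sequence

import torch
import torch.nn as nn


def resnet_split_sketch(m: nn.Sequential) -> List[List[nn.Parameter]]:
    "Hypothetical 3-way split of nn.Sequential(body, head) into parameter groups."
    body, head = m[0], m[1]
    return [
        list(body[:6].parameters()),  # stem + layer1 + layer2 (assumed boundary)
        list(body[6:].parameters()),  # layer3 + layer4
        list(head.parameters()),      # the new head
    ]


def make_optimizer(m: nn.Sequential, lrs: Sequence[float]) -> torch.optim.Optimizer:
    "Discriminative learning rates in plain PyTorch: one lr per parameter group."
    groups = resnet_split_sketch(m)
    return torch.optim.Adam([{"params": g, "lr": lr} for g, lr in zip(groups, lrs)])
```

fastai expects a splitter to return lists of parameters rather than modules. Each list becomes one optimizer parameter group, and `freeze` and `slice` learning rates operate on those groups.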

Adapting a pretrained resnet torchgeo model for a fastai Learner

import fastai.vision.all as fv  # assumption: `fv` is fastai's vision namespace
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import EuroSAT100
from torchgeo.models import ResNet18_Weights, resnet18
# GeoImageBlock, make_resnet_model and resnet_split come from this package

# load a ResNet18 pretrained on all 13 Sentinel-2 bands (MoCo weights)
pretrained = resnet18(ResNet18_Weights.SENTINEL2_ALL_MOCO, num_classes=10)
# replace the fc layer with a new head for the 10 EuroSAT classes
model = make_resnet_model(pretrained, n_out=10)

batch_size = 64
dblock = fv.DataBlock(blocks=(GeoImageBlock(), fv.CategoryBlock()),
                      get_items=fv.get_image_files,
                      splitter=fv.RandomSplitter(valid_pct=0.1, seed=42),
                      get_y=fv.parent_label,
                      item_tfms=fv.Resize(64),
                      # normalize each band with torchgeo's EuroSAT statistics
                      batch_tfms=[fv.Normalize.from_stats(EuroSATDataModule.mean, EuroSATDataModule.std)],
                     )
sat_path = fv.untar_data(EuroSAT100.url)  # download and extract EuroSAT100
dls = dblock.dataloaders(sat_path, bs=batch_size)

learn = fv.Learner(
    dls,
    model,
    loss_func=fv.CrossEntropyLossFlat(),
    metrics=[fv.error_rate, fv.accuracy],
    splitter=resnet_split,  # parameter groups used by freeze/unfreeze and slice LRs
)
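`Normalize.from_stats` with 13-element `mean`/`std` normalizes each spectral band independently. The same per-band broadcast can be written in plain PyTorch as below. This is a sketch: `normalize_bands` is a hypothetical helper, and the statistics are placeholders, not torchgeo's real EuroSAT values.

```python
import torch


def normalize_bands(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    "Per-band normalization of a (B, C, H, W) batch using C-element mean/std."
    return (x - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)


# placeholder statistics for 13 bands (NOT torchgeo's EuroSAT values)
mean = torch.full((13,), 1000.0)
std = torch.full((13,), 500.0)
batch = torch.full((2, 13, 64, 64), 1500.0)
out = normalize_bands(batch, mean, std)  # each value: (1500 - 1000) / 500 = 1.0
```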
# `freeze` uses the parameter groups created by `resnet_split` to lock the
# pretrained body, leaving only the last group (the model head) trainable

learn.freeze()
# note: BatchNorm layers stay trainable even when frozen, because fastai's Learner defaults to `train_bn=True`
learn.summary()
Sequential (Input shape: 64 x 13 x 64 x 64)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     64 x 64 x 32 x 32   
Conv2d                                    40768      False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 64 x 16 x 16   
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 128 x 8 x 8    
Conv2d                                    73728      False     
BatchNorm2d                               256        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    8192       False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 256 x 4 x 4    
Conv2d                                    294912     False     
BatchNorm2d                               512        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    32768      False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 512 x 2 x 2    
Conv2d                                    1179648    False     
BatchNorm2d                               1024       True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    131072     False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     64 x 512 x 1 x 1    
AdaptiveAvgPool2d                                              
AdaptiveMaxPool2d                                              
____________________________________________________________________________
                     64 x 1024           
Flatten                                                        
BatchNorm1d                               2048       True      
Dropout                                                        
____________________________________________________________________________
                     64 x 512            
Linear                                    524288     True      
ReLU                                                           
BatchNorm1d                               1024       True      
Dropout                                                        
____________________________________________________________________________
                     64 x 10             
Linear                                    5120       True      
____________________________________________________________________________

Total params: 11,740,352
Total trainable params: 542,080
Total non-trainable params: 11,198,272

Optimizer used: <function Adam>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback
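The effect of `learn.freeze()` with `train_bn=True` can be approximated in plain PyTorch. This is a sketch assuming the `nn.Sequential(body, head)` layout. `freeze_body` and `count_trainable` are hypothetical helpers. fastai itself toggles `requires_grad` through its optimizer's parameter groups, not with code like this.

```python
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_body(model: nn.Sequential, train_bn: bool = True) -> None:
    "Freeze model[0] (body), keep model[1] (head) trainable; optionally keep BatchNorm trainable."
    for p in model[0].parameters():
        p.requires_grad_(False)
    for p in model[1].parameters():
        p.requires_grad_(True)
    if train_bn:
        # BatchNorm affine parameters stay trainable everywhere, body included
        for m in model.modules():
            if isinstance(m, BN_TYPES):
                for p in m.parameters():
                    p.requires_grad_(True)


def count_trainable(model: nn.Module) -> int:
    "Number of parameters that will receive gradients."
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

This also accounts for the summary's trainable count. The head holds 2,048 + 524,288 + 1,024 + 5,120 = 532,480 parameters, and the body's BatchNorm2d layers add 128 + 512 + 1,280 + 2,560 + 5,120 = 9,600, for a total of 542,080.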
learn.fine_tune(2)  # 1 epoch with the body frozen, then 2 epochs with the whole model unfrozen
epoch  train_loss  valid_loss  error_rate  accuracy  time
0      4.042061    2.311349    0.800000    0.200000  00:01

epoch  train_loss  valid_loss  error_rate  accuracy  time
0      3.859293    2.302754    1.000000    0.000000  00:02
1      3.652791    2.247182    0.900000    0.100000  00:02
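The two result tables come from the two phases of `fine_tune`. The sketch below writes that plan out as data. It assumes fastai's defaults (`base_lr=2e-3`, `freeze_epochs=1`, `lr_mult=100`) and is a description, not fastai's code. `fine_tune_phases` is a hypothetical name.

```python
def fine_tune_phases(epochs: int, base_lr: float = 2e-3,
                     freeze_epochs: int = 1, lr_mult: float = 100) -> list:
    "Hypothetical summary of the two fit_one_cycle calls made by Learner.fine_tune."
    # phase 1: body frozen, only the head (plus BatchNorm) trains
    phase1 = {"frozen": True, "epochs": freeze_epochs, "lr_max": slice(base_lr)}
    base_lr /= 2
    # phase 2: everything unfrozen, discriminative LRs from low (body) to high (head)
    phase2 = {"frozen": False, "epochs": epochs,
              "lr_max": slice(base_lr / lr_mult, base_lr)}
    return [phase1, phase2]


plan = fine_tune_phases(2)
```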
# unlock all weights and make the whole model trainable
learn.unfreeze()
# all parameters are now trainable
learn.summary()
Sequential (Input shape: 64 x 13 x 64 x 64)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     64 x 64 x 32 x 32   
Conv2d                                    40768      True      
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 64 x 16 x 16   
MaxPool2d                                                      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 128 x 8 x 8    
Conv2d                                    73728      True      
BatchNorm2d                               256        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    147456     True      
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    8192       True      
BatchNorm2d                               256        True      
Conv2d                                    147456     True      
BatchNorm2d                               256        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    147456     True      
BatchNorm2d                               256        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 256 x 4 x 4    
Conv2d                                    294912     True      
BatchNorm2d                               512        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    589824     True      
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    32768      True      
BatchNorm2d                               512        True      
Conv2d                                    589824     True      
BatchNorm2d                               512        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    589824     True      
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 512 x 2 x 2    
Conv2d                                    1179648    True      
BatchNorm2d                               1024       True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    2359296    True      
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    131072     True      
BatchNorm2d                               1024       True      
Conv2d                                    2359296    True      
BatchNorm2d                               1024       True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    2359296    True      
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     64 x 512 x 1 x 1    
AdaptiveAvgPool2d                                              
AdaptiveMaxPool2d                                              
____________________________________________________________________________
                     64 x 1024           
Flatten                                                        
BatchNorm1d                               2048       True      
Dropout                                                        
____________________________________________________________________________
                     64 x 512            
Linear                                    524288     True      
ReLU                                                           
BatchNorm1d                               1024       True      
Dropout                                                        
____________________________________________________________________________
                     64 x 10             
Linear                                    5120       True      
____________________________________________________________________________

Total params: 11,740,352
Total trainable params: 11,740,352
Total non-trainable params: 0

Optimizer used: <function Adam>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model unfrozen

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback
# use discriminative learning rates across parameter groups: the "upper layers"
# (the head) get the highest learning rate, while the "lower layers" get much
# lower ones, nearly freezing their weights. fastai assigns slice.start to the
# first (lowest) group and slice.stop to the last (head) group.
learn.fit_one_cycle(5, lr_max=slice(8e-6, 2e-3))
epoch  train_loss  valid_loss  error_rate  accuracy  time
0      2.139503    2.269313    0.900000    0.100000  00:02
1      2.047539    2.380141    1.000000    0.000000  00:03
2      2.059563    2.353606    0.900000    0.100000  00:02
3      2.036056    2.318051    0.900000    0.100000  00:02
4      1.905594    2.308228    0.900000    0.100000  00:03
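How a `slice` turns into one learning rate per parameter group can be sketched as follows. This follows fastai's rule as we understand it: geometric spacing from `start` to `stop` when `start` is set, otherwise `stop / 10` for every group but the last. `lrs_from_slice` is a hypothetical helper, not a fastai function.

```python
def lrs_from_slice(lr: slice, n_groups: int) -> list:
    "Per-group learning rates for a fastai-style slice (sketch of fastai's rule)."
    if n_groups == 1:
        return [lr.stop]
    if lr.start:
        # geometric spacing: first group gets start, last group gets stop
        step = (lr.stop / lr.start) ** (1 / (n_groups - 1))
        return [lr.start * step**i for i in range(n_groups)]
    # slice(lr): earlier groups get lr / 10, the last group (head) gets lr
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]


lrs = lrs_from_slice(slice(8e-6, 2e-3), 3)  # about 8e-6, 1.3e-4, 2e-3
```

With three groups, `slice(8e-6, 2e-3)` gives the stem the smallest rate and the head the largest. A reversed slice such as `slice(2e-3, 8e-6)` would give the head the smallest rate instead.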