# Author: lukemelas (github username)
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# With adjustments and added comments by workingcoder (github username).
import collections
import math
import re
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
from backbone import MammothBackbone
# Download locations of the pretrained checkpoints (lukemelas/EfficientNet-PyTorch
# release 1.0), keyed by model name. Only b0-b7 are listed here.
url_map = {
    name: 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/' + filename
    for name, filename in (
        ('efficientnet-b0', 'efficientnet-b0-355c32eb.pth'),
        ('efficientnet-b1', 'efficientnet-b1-f1951068.pth'),
        ('efficientnet-b2', 'efficientnet-b2-8bb594d6.pth'),
        ('efficientnet-b3', 'efficientnet-b3-5fb5a3c3.pth'),
        ('efficientnet-b4', 'efficientnet-b4-6ed6700e.pth'),
        ('efficientnet-b5', 'efficientnet-b5-b6417697.pth'),
        ('efficientnet-b6', 'efficientnet-b6-c76e70fd.pth'),
        ('efficientnet-b7', 'efficientnet-b7-dcc49843.pth'),
    )
}
def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, verbose=True, advprop=False):
    """Load pretrained weights into ``model`` from a local file or a release URL.

    The published checkpoints call the final linear layer ``_fc`` while the
    EfficientNet in this module calls it ``classifier``; ``_fc.*`` keys are
    renamed to ``classifier.*`` before loading so both layouts are accepted.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet (a key of ``url_map``).
        weights_path (None, str or os.PathLike):
            path: pretrained weights file on the local disk.
            None: download the weights listed in ``url_map``.
        load_fc (bool): Whether to load pretrained weights for the final linear
            layer. When False those weights are discarded and the model keeps
            its own layer (e.g. one sized for a different number of classes).
        verbose (bool): Print a message once the weights are loaded.
        advprop (bool): Request weights trained with AdvProp (only meaningful
            when weights_path is None). No AdvProp URLs are defined in this
            module, so this raises instead of silently loading other weights.

    Raises:
        NotImplementedError: if ``advprop`` is True and ``weights_path`` is None.
        KeyError: if ``weights_path`` is None and ``model_name`` has no URL.
        AssertionError: if the checkpoint keys do not match the model.
    """
    if weights_path is not None:
        # Map to CPU so checkpoints saved on GPU load anywhere; load_state_dict
        # copies the tensors into the model's parameters afterwards.
        state_dict = torch.load(weights_path, map_location='cpu')
    else:
        if advprop:
            raise NotImplementedError('AdvProp pretrained weights are not available for download; '
                                      'pass weights_path to load them from disk.')
        state_dict = model_zoo.load_url(url_map[model_name], map_location='cpu')

    # Rename the upstream `_fc.*` entries to this model's `classifier.*`.
    for key in [k for k in state_dict if k.startswith('_fc.')]:
        state_dict['classifier.' + key[len('_fc.'):]] = state_dict.pop(key)

    fc_keys = {'classifier.weight', 'classifier.bias'}
    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        # Drop the pretrained head so a differently sized classifier does not
        # trigger a shape mismatch inside load_state_dict.
        for key in fc_keys:
            state_dict.pop(key, None)
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) <= fc_keys, \
            'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    if verbose:
        print('Loaded pretrained weights for {}'.format(model_name))
_DEFAULT_BLOCKS_ARGS = [
'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
'r1_k3_s11_e6_i192_o320_se0.25',
]
def get_width_and_height_from_size(x):
    """Expand an image size into a (height, width) pair.

    Args:
        x (int, tuple or list): Data size; a bare int describes a square image.
    Returns:
        size: ``(x, x)`` for an int, otherwise ``x`` itself, unchanged.
    Raises:
        TypeError: if x is neither an int nor a list/tuple.
    """
    if not isinstance(x, (int, list, tuple)):
        raise TypeError()
    return (x, x) if isinstance(x, int) else x
def calculate_output_image_size(input_image_size, stride):
    """Calculate the output image size of a 'SAME'-padded conv with a stride.

    Needed for static padding: each spatial size becomes ``ceil(size / stride)``.
    Thanks to mannatsingh for pointing this out.

    Args:
        input_image_size (None, int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride. A one-element
            sequence (as built by BlockDecoder) applies to both axes; a
            two-element one is (stride_h, stride_w).
    Returns:
        output_image_size: A list [H, W], or None if input_image_size is None.
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        # Use each axis' own stride (matches Conv2d*SamePadding's handling).
        stride_h, stride_w = stride[0], stride[-1]
    image_height = int(math.ceil(image_height / stride_h))
    image_width = int(math.ceil(image_width / stride_w))
    return [image_height, image_width]
def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of a batch during training.

    Each sample is dropped with probability ``p``. Survivors are scaled by
    ``1 / (1 - p)`` so the expected value is unchanged. Outside training the
    input is returned unchanged.

    Args:
        inputs (tensor): Input batch; dimension 0 is the batch dimension and
            the mask is broadcast over all remaining dimensions.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.
    Returns:
        output: Output after drop connection.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'
    if not training:
        return inputs
    if p == 1:
        # Every sample is dropped; the general formula would divide by zero (NaN).
        return torch.zeros_like(inputs)
    keep_prob = 1 - p
    # One draw per sample: floor(keep_prob + U[0, 1)) is 1 with probability
    # keep_prob and 0 otherwise.
    mask_shape = [inputs.shape[0]] + [1] * (inputs.dim() - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)
    return inputs / keep_prob * binary_tensor
def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier.

    Reads ``global_params.depth_coefficient``.

    Args:
        repeats (int): num_repeat to be calculated.
        global_params (namedtuple): Global params of the model.
    Returns:
        new repeat: ``ceil(depth_coefficient * repeats)``, or ``repeats``
        unchanged when the coefficient is unset (None or 0).
    """
    depth = global_params.depth_coefficient
    # Rounding up follows the official TensorFlow implementation.
    return int(math.ceil(depth * repeats)) if depth else repeats
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round it to the divisor.

    Reads ``width_coefficient``, ``depth_divisor`` and ``min_depth`` from
    global_params.

    Args:
        filters (int): Filters number to be calculated.
        global_params (namedtuple): Global params of the model.
    Returns:
        new_filters: The scaled count, a multiple of ``depth_divisor`` that is
        at least ``min_depth`` (or the divisor when min_depth is unset).
        ``filters`` is returned unchanged when width_coefficient is unset.
    """
    width = global_params.width_coefficient
    if not width:
        return filters
    divisor = global_params.depth_divisor
    # pay attention to this line when using min_depth: None/0 falls back to the divisor
    lower_bound = global_params.min_depth or divisor
    scaled = filters * width
    # Nearest multiple of the divisor (formula from the official TensorFlow code).
    rounded = max(lower_bound, int(scaled + divisor / 2) // divisor * divisor)
    if rounded < 0.9 * scaled:
        # Never round down by more than 10%.
        rounded += divisor
    return int(rounded)
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding, sized per input.

    No padding is fixed at construction time: ``forward`` measures each
    incoming tensor and pads it so the output spatial size is
    ``ceil(input / stride)``.
    """

    # 'SAME' padding derivation, with
    #   i: input width or height, s: stride, k: kernel size,
    #   d: dilation, p: total padding along that axis.
    # A Conv2d outputs o = floor((i + p - ((k - 1) * d + 1)) / s + 1).
    # Asking for o = ceil(i / s) and solving for p gives
    #   p = (o - 1) * s + (k - 1) * d + 1 - i   (clamped at 0);
    # when p is odd the extra pixel goes to the bottom/right side.

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0,
                         dilation=dilation, groups=groups, bias=bias)
        # Broadcast a one-element stride to both spatial axes.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    @staticmethod
    def _same_pad(size, kernel, stride, dilation):
        """Total 'SAME' padding needed along one spatial axis."""
        target = math.ceil(size / stride)  # change the output size according to stride
        return max((target - 1) * stride + (kernel - 1) * dilation + 1 - size, 0)

    def forward(self, x):
        height, width = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        pad_h = self._same_pad(height, kernel_h, self.stride[0], self.dilation[0])
        pad_w = self._same_pad(width, kernel_w, self.stride[1], self.dilation[1])
        if max(pad_h, pad_w) > 0:
            top, left = pad_h // 2, pad_w // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding for a fixed image size.

    The padding is derived once in ``__init__`` from ``image_size`` and stored
    as ``self.static_padding``: an ``nn.ZeroPad2d`` when padding is needed,
    otherwise an ``nn.Identity``. ``forward`` applies it, then convolves.
    """

    # Same formula as Conv2dDynamicSamePadding, evaluated once for image_size.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        # Broadcast a one-element stride to both spatial axes.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        if isinstance(image_size, int):
            in_h = in_w = image_size
        else:
            in_h, in_w = image_size
        k_h, k_w = self.weight.size()[-2:]

        totals = []
        for size, kernel, step, dil in zip((in_h, in_w), (k_h, k_w), self.stride, self.dilation):
            target = math.ceil(size / step)
            totals.append(max((target - 1) * step + (kernel - 1) * dil + 1 - size, 0))
        pad_h, pad_w = totals

        if pad_h > 0 or pad_w > 0:
            # (left, right, top, bottom); the odd pixel goes right/bottom.
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
                                                pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        padded = self.static_padding(x)
        return F.conv2d(padded, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def get_same_padding_conv2d(image_size=None):
    """Pick the 'SAME'-padding convolution class for a given image size.

    Static padding is used when an image size is known (it is necessary for
    ONNX exporting of models); otherwise padding is computed on every call.

    Args:
        image_size (int or tuple): Size of the image, or None.
    Returns:
        ``Conv2dDynamicSamePadding`` when image_size is None, otherwise
        ``Conv2dStaticSamePadding`` with image_size pre-bound via partial.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
# Fields shared by all blocks of a model. Every field defaults to None except
# the last one, include_top, which defaults to True.
_GLOBAL_PARAMS_FIELDS = (
    'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format',
    'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor',
    'min_depth', 'survival_prob', 'relu_fn', 'batch_norm', 'use_se',
    'local_pooling', 'condconv_num_experts', 'clip_projection_output',
    'blocks_args', 'image_size', 'drop_connect_rate', 'include_top',
)
GlobalParams = collections.namedtuple(
    'GlobalParams', _GLOBAL_PARAMS_FIELDS,
    defaults=(None,) * (len(_GLOBAL_PARAMS_FIELDS) - 1) + (True,))  # include top

# Per-stage block settings (see BlockDecoder); every field defaults to None.
_BLOCK_ARGS_FIELDS = (
    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
    'expand_ratio', 'id_skip', 'strides', 'se_ratio', 'conv_type', 'fused_conv',
    'super_pixel', 'condconv', 'stride',
)
BlockArgs = collections.namedtuple(
    'BlockArgs', _BLOCK_ARGS_FIELDS,
    defaults=(None,) * len(_BLOCK_ARGS_FIELDS))
# Swish activation function
# Swish / SiLU, x * sigmoid(x): use the built-in nn.SiLU when this PyTorch
# provides it, otherwise define an equivalent module.
Swish = getattr(nn, 'SiLU', None)
if Swish is None:
    # For compatibility with old PyTorch versions
    class Swish(nn.Module):
        """Fallback Swish activation: x * sigmoid(x)."""

        def forward(self, x):
            return x * torch.sigmoid(x)
# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
    """Swish, x * sigmoid(x), as a custom autograd function.

    Only the input is saved for backward; the sigmoid is recomputed there
    from it.
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        (i,) = ctx.saved_tensors
        sig = torch.sigmoid(i)
        # d/di [i * sigmoid(i)] = sigmoid(i) * (1 + i * (1 - sigmoid(i)))
        return grad_output * (sig * (1 + i * (1 - sig)))
class MemoryEfficientSwish(nn.Module):
    """Module wrapper that applies ``SwishImplementation`` (x * sigmoid(x))."""

    def forward(self, x):
        swish = SwishImplementation.apply
        return swish(x)
class BlockDecoder(object):
    """Block Decoder for readability,
    straight from the official TensorFlow repository.

    Converts between BlockArgs and the compact string notation, e.g.
    'r1_k3_s11_e1_i32_o16_se0.25' means 1 repeat, kernel 3, stride 1 on both
    axes, expand ratio 1, 32 input / 16 output filters and SE ratio 0.25.
    A '_noskip' token disables the identity skip connection.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
        Returns:
            BlockArgs: The namedtuple defined in this module. ``stride`` is a
            one-element list, e.g. ``[2]``.
        Raises:
            AssertionError: if block_string is not a str or the stride is
                missing / not of the form 'sN' or 'sNN' with equal digits.
        """
        assert isinstance(block_string, str)
        options = {}
        for op in block_string.split('_'):
            # 'k3' -> ['k', '3', ''], 'se0.25' -> ['se', '0.25', '']; tokens
            # without a digit (such as 'noskip') yield a single item and are skipped.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: required, one digit or two identical digits.
        assert 's' in options and (
            len(options['s']) == 1 or
            (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument. The stride is read
                from ``strides`` when set, otherwise from ``stride`` (an int or
                a one-element list as produced by the decoder).
        Returns:
            block_string: A String form of BlockArgs.
        """
        stride = block.strides if block.strides is not None else block.stride
        if isinstance(stride, int):
            stride_h = stride_w = stride
        else:
            # [s] -> (s, s); [sh, sw] -> (sh, sw)
            stride_h, stride_w = stride[0], stride[-1]
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride_h, stride_w),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        # se_ratio is None when the string had no 'se' token.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.
        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(block_string) for block_string in string_list]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
def efficientnet_params(model_name):
    """Map an EfficientNet model name to its scaling coefficients.

    Args:
        model_name (str): Model name to be queried, e.g. 'efficientnet-b0'.
    Returns:
        A (width, depth, resolution, dropout) tuple.
    Raises:
        KeyError: if model_name is not in the table.
    """
    # name: (width_coefficient, depth_coefficient, resolution, dropout_rate)
    coefficients = {
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    return coefficients[model_name]
def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
    """Create BlockArgs and GlobalParams for an efficientnet model.

    The block specifications are the EfficientNet-B0 ones stored in
    ``_DEFAULT_BLOCKS_ARGS``. They are modified in the construction of the
    EfficientNet class according to the width/depth coefficients.

    Args:
        width_coefficient (float): filter multiplier (see round_filters).
        depth_coefficient (float): repeat multiplier (see round_repeats).
        image_size (int): input resolution; when set, static 'SAME' padding
            is used (see get_same_padding_conv2d).
        dropout_rate (float): dropout before the final linear layer.
        drop_connect_rate (float): stored in GlobalParams.drop_connect_rate.
        num_classes (int): output size of the final linear layer.
        include_top (bool): whether the model builds the final linear layer.
    Returns:
        blocks_args, global_params.
    """
    # Reuse the module-level B0 spec instead of duplicating the literal here.
    blocks_args = BlockDecoder.decode(_DEFAULT_BLOCKS_ARGS)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )
    return blocks_args, global_params
def efficientnet_tf(width_coefficient=None,
                    depth_coefficient=None,
                    dropout_rate=0.2,
                    survival_prob=0.8):
    """Create the GlobalParams of an efficientnet model in the TensorFlow layout.

    Unlike ``efficientnet()``, only GlobalParams is returned. The block
    specifications stay as strings in ``global_params.blocks_args`` and are
    meant to be decoded with BlockDecoder.

    Args:
        width_coefficient (float): filter multiplier.
        depth_coefficient (float): repeat multiplier.
        dropout_rate (float): dropout before the final linear layer.
        survival_prob (float): stored in GlobalParams.survival_prob.
    Returns:
        global_params.
    """
    settings = dict(
        blocks_args=_DEFAULT_BLOCKS_ARGS,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        survival_prob=survival_prob,
        data_format='channels_last',
        num_classes=1000,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None,
        use_se=True,
        clip_projection_output=False,
    )
    return GlobalParams(**settings)
def get_model_params_tf(model_name, override_params):
    """Get the block args and global params for a given model (TF-style params).

    Args:
        model_name (str): Model's name, e.g. 'efficientnet-b0'.
        override_params (dict or None): GlobalParams fields to override.
    Returns:
        blocks_args, global_params
    Raises:
        NotImplementedError: if model_name does not start with 'efficientnet'.
        ValueError: if override_params has fields not included in GlobalParams.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)
    width_coefficient, depth_coefficient, _, dropout_rate = efficientnet_params(model_name)
    # efficientnet_tf (not efficientnet) is required here: efficientnet()
    # returns a (blocks_args, global_params) tuple and takes image_size as its
    # third positional argument, so the dropout rate would land there.
    global_params = efficientnet_tf(width_coefficient, depth_coefficient, dropout_rate)
    if override_params:
        # ValueError will be raised here if override_params has fields not included
        # in global_params.
        global_params = global_params._replace(**override_params)
    blocks_args = BlockDecoder.decode(global_params.blocks_args)
    print('EFFNET LOGGING: global_params= {}'.format(global_params))
    return blocks_args, global_params
def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.

    Args:
        model_name (str): Model's name.
        override_params (dict): A dict to modify global_params.
    Returns:
        blocks_args, global_params
    Raises:
        NotImplementedError: if model_name does not start with 'efficientnet'.
        ValueError: if override_params has fields not included in GlobalParams.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
    width, depth, resolution, dropout = efficientnet_params(model_name)
    # note: all models have drop connect rate = 0.2 (efficientnet()'s default)
    blocks_args, global_params = efficientnet(
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        image_size=resolution,
    )
    if override_params:
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
def efficientnet_params(model_name):
    """Get efficientnet params based on model name.

    NOTE: this re-defines an identical ``efficientnet_params`` declared earlier
    in this module. As the later definition, it is the one bound once the
    module has been imported.

    Args:
        model_name (str): e.g. 'efficientnet-b0'.
    Returns:
        (width_coefficient, depth_coefficient, resolution, dropout_rate)
    Raises:
        KeyError: for an unknown model name.
    """
    by_name = {
        # (width_coefficient, depth_coefficient, resolution, dropout_rate)
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    return by_name[model_name]
# Names accepted by EfficientNet.from_name / get_image_size:
# 'efficientnet-b0' through 'efficientnet-b8'.
VALID_MODELS = tuple('efficientnet-b{}'.format(i) for i in range(9))
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    The block runs an optional 1x1 expansion, then a depthwise conv, then an
    optional squeeze-and-excitation, then a 1x1 projection. An identity skip
    (with optional drop connect) is added when id_skip is set, the stride is 1
    and the channel count is unchanged.

    Args:
        block_args (namedtuple): BlockArgs, defined in this module.
        global_params (namedtuple): GlobalParams, defined in this module.
        image_size (tuple or list): [image_height, image_width] of the block's
            input. When given, convolutions use static 'SAME' padding.

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            # A 1x1, stride-1 conv leaves image_size unchanged.

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            # Squeezed width is derived from the block's *input* filters.
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def _has_unit_stride(self):
        """Return True when the depthwise stride is 1 on every axis.

        BlockDecoder stores the stride as a one-element list (e.g. ``[1]``),
        while EfficientNet rewrites repeated blocks with a plain ``1``. Both
        forms must count as stride 1.
        """
        stride = self._block_args.stride
        if isinstance(stride, (list, tuple)):
            return all(step == 1 for step in stride)
        return stride == 1

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float or None): Drop connect rate (between 0 and
                1), applied only on the skip-connection path; falsy disables it.
        Returns:
            Output of this block after processing.
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)
        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._has_unit_stride() and input_filters == output_filters:
            # The combination of skip connection and drop connect brings about stochastic depth.
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
class EfficientNet(MammothBackbone):
    """EfficientNet model.

    Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.
        hookme (bool): Stored as ``self.hookme``; not read anywhere in this class.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)

    Example:
        >>> import torch
        >>> from efficientnet.model import EfficientNet
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
        >>> model.eval()
        >>> outputs = model(inputs)
    """

    def __init__(self, blocks_args=None, global_params=None, hookme=False):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self.hookme = hookme
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum = 1 - TensorFlow momentum)
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Scale filters by the width multiplier and repeats by the depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )
            # The first block of a stage takes care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
                # stride is 1 here, so image_size is unchanged

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        if self._global_params.include_top:
            self._dropout = nn.Dropout(self._global_params.dropout_rate)
            self.classifier = nn.Linear(out_channels, self._global_params.num_classes)

        # set activation to memory efficient swish by default
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def activations_hook(self, grad):
        """Tensor-hook callback: store ``grad`` in ``self.gradients``."""
        self.gradients = grad

    def extract_features(self, inputs):
        """Run the stem, all MBConv blocks and the head convolution.

        ``forward`` depends on this method, which the class did not define.

        Args:
            inputs (tensor): Input tensor.
        Returns:
            The head's feature map (before pooling).
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks: the drop connect rate is scaled linearly with block index.
        num_blocks = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        return x

    def forward(self, inputs, returnt='out'):
        """EfficientNet's forward function.

        Calls extract_features to extract features, pools them, and (when
        include_top) applies dropout and the final linear layer.

        Args:
            inputs (tensor): Input tensor.
            returnt (str): What to return:
                'features' -> pooled, flattened features;
                'out'      -> classifier output;
                'full'     -> (classifier output, pooled features).
                When include_top is False there is no classifier, and the
                'out' value is the un-pooled output of extract_features.
        Returns:
            Output of this model after processing.
        Raises:
            NotImplementedError: for any other ``returnt``.
        """
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer
        feats = self._avg_pooling(x).flatten(start_dim=1)
        if returnt == 'features':
            return feats
        if self._global_params.include_top:
            x = self._dropout(feats)
            x = self.classifier(x)
        if returnt == 'out':
            return x
        elif returnt == 'full':
            return (x, feats)
        raise NotImplementedError("Unknown return type")

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'
        Returns:
            An efficientnet model.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Create an efficientnet model according to name and load pretrained weights.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
                No AdvProp download URLs exist in this module, so requesting
                them raises NotImplementedError.
            in_channels (int): Input data's channel number. A value other
                than 3 replaces the stem conv after loading.
            num_classes (int):
                Number of categories for classification.
                It controls the output size for final linear layer; the
                pretrained final layer is only loaded when it equals 1000.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'
        Returns:
            A pretrained efficientnet model.
        Raises:
            NotImplementedError: if advprop is True and weights_path is None.
        """
        if advprop and weights_path is None:
            raise NotImplementedError('AdvProp pretrained weights are not available for download; '
                                      'pass weights_path to load them from disk.')
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        load_pretrained_weights(model, model_name, weights_path=weights_path,
                                load_fc=(num_classes == 1000))
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.
        Returns:
            Input image size (resolution).
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validate a model name.

        Args:
            model_name (str): Name for efficientnet.
        Raises:
            ValueError: if model_name is not in VALID_MODELS.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        The replacement stem conv is freshly initialised.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
def mammoth_efficientnet(nclasses: int, model_name: str, pretrained: bool = False):
    """Instantiate an EfficientNet backbone for Mammoth.

    Args:
        nclasses: number of output classes (size of the final linear layer)
        model_name: one of VALID_MODELS, e.g. 'efficientnet-b0'
        pretrained: if True, build via EfficientNet.from_pretrained, which
            downloads the weights listed in url_map. The pretrained final
            layer is only loaded when nclasses == 1000.
    Returns:
        EfficientNet network
    """
    print(model_name)
    if pretrained:
        return EfficientNet.from_pretrained(model_name=model_name, num_classes=nclasses)
    return EfficientNet.from_name(model_name=model_name, num_classes=nclasses)