# Source code for backbone.EfficientNet

# Author: lukemelas (github username)
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# With adjustments and added comments by workingcoder (github username).

import collections
import logging
import math
import re
from functools import partial

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo

from backbone import MammothBackbone

# Download URLs of the pretrained checkpoints (.pth files attached to release 1.0 of
# lukemelas/EfficientNet-PyTorch), keyed by model name. `load_pretrained_weights` reads
# this table when it is not given a local `weights_path`.
url_map = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}


def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, verbose=True, advprop=False):
    """Loads pretrained weights from weights path or download using url.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet (a key of ``url_map``).
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        verbose (bool): Whether to log a message once the weights have been loaded.
        advprop (bool): Whether to load pretrained weights trained with advprop
            (only meaningful when weights_path is None). ``url_map`` lists no
            advprop checkpoints, so requesting them without a local file is an error.

    Raises:
        NotImplementedError: If ``advprop`` is requested and no ``weights_path`` is given.
        AssertionError: If the checkpoint and the model disagree on parameter names.
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path, weights_only=True)
    else:
        if advprop:
            # Only the standard checkpoints are listed in url_map.
            raise NotImplementedError(
                'advprop weights are not listed in url_map; pass weights_path to load them from disk')
        state_dict = model_zoo.load_url(url_map[model_name])

    # The checkpoints store the final linear layer under '_fc'. A model that exposes
    # that layer as 'classifier' instead gets the two head entries renamed so that they
    # line up with its own parameter names.
    ckpt_head = '_fc.'
    model_head = ckpt_head
    if not hasattr(model, '_fc') and hasattr(model, 'classifier'):
        model_head = 'classifier.'
        state_dict = {
            (model_head + key[len(ckpt_head):] if key.startswith(ckpt_head) else key): value
            for key, value in state_dict.items()
        }
    head_keys = {model_head + 'weight', model_head + 'bias'}

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        # Drop the pretrained head so a freshly initialised one (other class count) is kept.
        for key in head_keys:
            state_dict.pop(key, None)
        ret = model.load_state_dict(state_dict, strict=False)
        # Only the head may be missing (nothing is missing when the model has no head).
        assert set(ret.missing_keys) <= head_keys, \
            'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    if verbose:
        logging.info('Loaded pretrained weights for {}'.format(model_name))
# Block specification strings, one per stage, in the notation parsed by
# `BlockDecoder._decode_block_string`. `efficientnet_tf` stores this list as the
# `blocks_args` field of the GlobalParams it returns.
_DEFAULT_BLOCKS_ARGS = [
    'r1_k3_s11_e1_i32_o16_se0.25',
    'r2_k3_s22_e6_i16_o24_se0.25',
    'r2_k5_s22_e6_i24_o40_se0.25',
    'r3_k3_s22_e6_i40_o80_se0.25',
    'r3_k5_s11_e6_i80_o112_se0.25',
    'r4_k5_s22_e6_i112_o192_se0.25',
    'r1_k3_s11_e6_i192_o320_se0.25',
]
def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    An int is treated as a square size and expanded to ``(x, x)``; a list or
    tuple is returned unchanged (it is expected to already be ``(H, W)``).

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: A tuple or list (H,W).

    Raises:
        TypeError: If ``x`` is neither an int, a list nor a tuple.
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError('size must be an int, list or tuple, got {}'.format(type(x).__name__))
def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.

    Needed to chain statically padded convolutions: each spatial dimension is
    divided by the stride and rounded up.

    Args:
        input_image_size (int, tuple or list): Size of input image, or None.
        stride (int, tuple or list): Conv2d operation's stride; for a sequence
            only the first element is used.

    Returns:
        output_image_size: A list [H,W], or None when no input size is given.
    """
    if input_image_size is None:
        return None
    height, width = get_width_and_height_from_size(input_image_size)
    step = stride if isinstance(stride, int) else stride[0]
    return [int(math.ceil(extent / step)) for extent in (height, width)]
def drop_connect(inputs, p, training):
    """Drop connect.

    During training, zeroes whole samples of the batch with probability ``p``
    and rescales the surviving samples by ``1 / (1 - p)``.

    Args:
        inputs (tensor: BCWH): Input of this structure.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: Output after drop connection (``inputs`` itself when not training).
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs

    keep_prob = 1 - p
    if keep_prob == 0:
        # Every sample is dropped; the general formula would compute x / 0 * 0 = NaN.
        return torch.zeros_like(inputs)

    batch_size = inputs.shape[0]

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1):
    # keep_prob + U[0, 1) lands in [1, 2) with probability keep_prob, and floor() maps it to 1
    random_tensor = keep_prob + torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    return inputs / keep_prob * binary_tensor
def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the model's depth multiplier.

    Reads depth_coefficient from global_params.

    Args:
        repeats (int): num_repeat to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: ``repeats`` unchanged when no multiplier is set, otherwise
        ``ceil(multiplier * repeats)`` as an int.
    """
    depth_multiplier = global_params.depth_coefficient
    if depth_multiplier:
        # follow the formula transferred from official TensorFlow implementation
        return int(math.ceil(depth_multiplier * repeats))
    return repeats
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round it to the divisor.

    Reads width_coefficient, depth_divisor and min_depth from global_params.

    Args:
        filters (int): Filters number to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new_filters: ``filters`` unchanged when no multiplier is set, otherwise
        the scaled count rounded to a multiple of depth_divisor.
    """
    width_multiplier = global_params.width_coefficient
    if not width_multiplier:
        return filters

    divisor = global_params.depth_divisor
    # min_depth falls back to the divisor when it is unset (or zero)
    lower_bound = global_params.min_depth or divisor
    scaled = filters * width_multiplier

    # follow the formula transferred from official TensorFlow implementation
    rounded = max(lower_bound, int(scaled + divisor / 2) // divisor * divisor)
    if rounded < 0.9 * scaled:
        # prevent rounding by more than 10%
        rounded += divisor
    return int(rounded)
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a dynamic image size.

    The amount of padding is computed from the actual input in every forward call.
    """

    # Tips for 'SAME' mode padding.
    #     Given the following:
    #         i: width or height
    #         s: stride
    #         k: kernel size
    #         d: dilation
    #         p: padding
    #     Output after Conv2d:
    #         o = floor((i+p-((k-1)*d+1))/s+1)
    # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
    # => p = (i-1)*s+((k-1)*d+1)-i

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        # the built-in padding is fixed to 0; padding is applied manually in forward
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        in_h, in_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride

        # change the output size according to stride ! ! !
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)

        extra_h = max((out_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - in_h, 0)
        extra_w = max((out_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if extra_h > 0 or extra_w > 0:
            # split the padding between both sides; the odd pixel goes to the right/bottom
            left, top = extra_w // 2, extra_h // 2
            x = F.pad(x, [left, extra_w - left, top, extra_h - top])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.

    The padding module is built once in the constructor from ``image_size``
    and then applied in forward.
    """

    # With the same calculation as Conv2dDynamicSamePadding

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        if isinstance(image_size, int):
            in_h, in_w = image_size, image_size
        else:
            in_h, in_w = image_size
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)

        extra_h = max((out_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - in_h, 0)
        extra_w = max((out_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if extra_h > 0 or extra_w > 0:
            left, top = extra_w // 2, extra_h // 2
            self.static_padding = nn.ZeroPad2d((left, extra_w - left, top, extra_h - top))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        padded = self.static_padding(x)
        return F.conv2d(padded, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def get_same_padding_conv2d(image_size=None):
    """Pick the 'SAME'-padding convolution that matches the available size information.

    With a known image size the statically padded variant is returned (pre-bound
    to that size); without one the dynamically padded class is returned.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple(
    'GlobalParams',
    [
        'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format',
        'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor',
        'min_depth', 'survival_prob', 'relu_fn', 'batch_norm', 'use_se',
        'local_pooling', 'condconv_num_experts', 'clip_projection_output',
        'blocks_args', 'image_size', 'drop_connect_rate', 'include_top',
    ],
)

# Parameters for an individual model block
BlockArgs = collections.namedtuple(
    'BlockArgs',
    [
        'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
        'expand_ratio', 'id_skip', 'strides', 'se_ratio', 'conv_type',
        'fused_conv', 'super_pixel', 'condconv', 'stride',
    ],
)

# Every field defaults to None, except the last GlobalParams field (include_top) which defaults to True
GlobalParams.__new__.__defaults__ = tuple(None for _ in GlobalParams._fields[:-1]) + (True,)
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)

# Swish activation function
try:
    Swish = nn.SiLU
except AttributeError:
    # For compatibility with old PyTorch versions
    class Swish(nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)

# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
    """Swish (x * sigmoid(x)) as a custom autograd function.

    Only the input is saved for the backward pass; the sigmoid is recomputed
    there instead of being kept from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input = ctx.saved_tensors[0]
        gate = torch.sigmoid(saved_input)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        local_grad = gate * (1 + saved_input * (1 - gate))
        return grad_output * local_grad
class MemoryEfficientSwish(nn.Module):
    """Module form of the Swish activation that delegates to SwishImplementation."""

    def forward(self, x):
        activated = SwishImplementation.apply(x)
        return activated
class BlockDecoder(object):
    """Block Decoder for readability,
    straight from the official TensorFlow repository.

    Converts between the compact string notation of a block
    (e.g. 'r1_k3_s11_e1_i32_o16_se0.25') and BlockArgs namedtuples.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The decoded block; the stride is stored in its ``stride``
            field as a one-element list.
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split('_'):
            # split each token into its alphabetic key and the value starting at the first digit
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: a single digit, or two identical digits
        stride = options.get('s', '')
        assert len(stride) == 1 or (len(stride) == 2 and stride[0] == stride[1]), \
            'stride must be one digit or two equal digits, got {!r}'.format(stride)

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(stride[0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument.

        Returns:
            block_string: A String form of BlockArgs.
        """
        # `_decode_block_string` fills `stride` and leaves `strides` unset, so fall back
        # to `stride` when `strides` is missing; an int or one-element stride is used
        # for both dimensions.
        stride = getattr(block, 'strides', None)
        if stride is None:
            stride = getattr(block, 'stride', None)
        if isinstance(stride, int):
            stride = [stride]

        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride[0], stride[-1]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
        ]
        # se_ratio is None for blocks without squeeze-and-excitation
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(block_string) for block_string in string_list]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    NOTE(review): this function is defined a second time further down in this
    module with an identical table; the later definition is the one that stays
    bound to the name.

    Args:
        model_name (str): Model name to be queried.

    Returns:
        A (width, depth, res, dropout) tuple for ``model_name``.
    """
    coefficients_by_name = {
        # Coefficients:   width,depth,res,dropout
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    coefficients = coefficients_by_name[model_name]
    return coefficients
def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
    """Create BlockArgs and GlobalParams for efficientnet model.

    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)
        include_top (bool)

        Meaning as the name suggests; each value is stored in the GlobalParams
        field of the same name.

    Returns:
        blocks_args, global_params.
    """
    # Blocks args for the whole model(efficientnet-b0 by default)
    # It will be modified in the construction of EfficientNet Class according to model
    block_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    decoded_blocks = BlockDecoder.decode(block_strings)

    shared_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return decoded_blocks, shared_params
def efficientnet_tf(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, survival_prob=0.8):
    """Build the TensorFlow-style GlobalParams of an efficientnet model.

    Unlike `efficientnet`, only a GlobalParams is returned; the (still encoded)
    default block strings are stored in its ``blocks_args`` field.
    """
    settings = dict(
        blocks_args=_DEFAULT_BLOCKS_ARGS,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        survival_prob=survival_prob,
        data_format='channels_last',
        num_classes=1000,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None,
        use_se=True,
        clip_projection_output=False,
    )
    return GlobalParams(**settings)
def get_model_params_tf(model_name, override_params):
    """Get the block args and global params for a given model (TensorFlow-style).

    Args:
        model_name (str): Model's name; must start with 'efficientnet'.
        override_params (dict or None): GlobalParams fields to replace.

    Returns:
        blocks_args, global_params: the decoded ``global_params.blocks_args``
        and the (possibly overridden) GlobalParams.

    Raises:
        NotImplementedError: If ``model_name`` does not start with 'efficientnet'.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)

    width_coefficient, depth_coefficient, _, dropout_rate = efficientnet_params(model_name)
    # The TF-style builder is required here: it takes dropout_rate as its third
    # positional argument and returns a single GlobalParams carrying `blocks_args`.
    # `efficientnet` instead returns a (blocks_args, global_params) tuple and expects
    # image_size in third position.
    global_params = efficientnet_tf(width_coefficient, depth_coefficient, dropout_rate)

    if override_params:
        # _replace raises if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)

    blocks_args = BlockDecoder.decode(global_params.blocks_args)

    logging.info('EFFNET LOGGING: global_params= %s', global_params)
    return blocks_args, global_params
def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.

    Args:
        model_name (str): Model's name.
        override_params (dict): A dict to modify global_params.

    Returns:
        blocks_args, global_params

    Raises:
        NotImplementedError: If ``model_name`` does not start with 'efficientnet'.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))

    width, depth, resolution, dropout = efficientnet_params(model_name)
    # note: all models have drop connect rate = 0.2
    blocks_args, global_params = efficientnet(
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        image_size=resolution,
    )

    if override_params:
        # _replace raises if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
def efficientnet_params(model_name):
    """Look up the scaling coefficients of an efficientnet variant by name.

    Returns the (width_coefficient, depth_coefficient, resolution, dropout_rate)
    tuple registered for ``model_name``; an unknown name raises KeyError.
    """
    scaling_table = {
        # (width_coefficient, depth_coefficient, resolution, dropout_rate)
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    entry = scaling_table[model_name]
    return entry
# Model names accepted by `EfficientNet._check_model_name_is_valid`.
# NOTE(review): `efficientnet_params` also has an entry for 'efficientnet-l2', which is
# not listed here — confirm whether that omission is intentional.
VALID_MODELS = (
    'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
    'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
    'efficientnet-b8',
)
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Layers, in order: optional 1x1 expansion, depthwise convolution, optional
    squeeze-and-excitation, 1x1 projection, optional residual connection.

    Args:
        block_args (namedtuple): BlockArgs describing this block.
        global_params (namedtuple): GlobalParams shared by the whole model.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        se_ratio = self._block_args.se_ratio
        self.has_se = (se_ratio is not None) and (0 < se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        def batch_norm(channels):
            return nn.BatchNorm2d(num_features=channels, momentum=self._bn_mom, eps=self._bn_eps)

        # Expansion phase (Inverted Bottleneck)
        in_filters = self._block_args.input_filters
        expanded_filters = in_filters * self._block_args.expand_ratio
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(
                in_channels=in_filters, out_channels=expanded_filters, kernel_size=1, bias=False)
            self._bn0 = batch_norm(expanded_filters)
            # a 1x1 convolution with stride 1 leaves image_size unchanged

        # Depthwise convolution phase
        kernel = self._block_args.kernel_size
        stride = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=expanded_filters,
            out_channels=expanded_filters,
            groups=expanded_filters,  # groups makes it depthwise
            kernel_size=kernel,
            stride=stride,
            bias=False)
        self._bn1 = batch_norm(expanded_filters)
        image_size = calculate_output_image_size(image_size, stride)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            # the SE branch runs on globally pooled (1x1) features
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            squeezed_filters = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(
                in_channels=expanded_filters, out_channels=squeezed_filters, kernel_size=1)
            self._se_expand = Conv2d(
                in_channels=squeezed_filters, out_channels=expanded_filters, kernel_size=1)

        # Pointwise convolution phase
        projected_filters = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(
            in_channels=expanded_filters, out_channels=projected_filters, kernel_size=1, bias=False)
        self._bn2 = batch_norm(projected_filters)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float): Drop connect rate (between 0 and 1); a
                falsy value disables drop connect.

        Returns:
            Output of this block after processing.
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation
        if self.has_se:
            gate = F.adaptive_avg_pool2d(x, 1)
            gate = self._se_expand(self._swish(self._se_reduce(gate)))
            x = torch.sigmoid(gate) * x

        # Pointwise Convolution
        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect
        # NOTE(review): `stride == 1` only holds when stride is the int 1. BlockDecoder
        # stores the stride as a one-element list, so this is true only for block args
        # whose stride was replaced by 1 (as done for repeated blocks in
        # EfficientNet.__init__) — confirm that this is the intended behaviour.
        args = self._block_args
        use_residual = self.id_skip and args.stride == 1 and args.input_filters == args.output_filters
        if not use_residual:
            return x

        # The combination of skip connection and drop connect brings about stochastic depth.
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        return x + inputs  # skip connection

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
class EfficientNet(MammothBackbone):
    """EfficientNet model.
       Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.
        hookme (bool): Stored on the instance as ``self.hookme``; not read by
            any method of this class.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)

    Example:
        >>> import torch
        >>> from backbone.EfficientNet import EfficientNet
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
        >>> model.eval()
        >>> outputs = model(inputs)
    """

    def __init__(self, blocks_args=None, global_params=None, hookme=False):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self.hookme = hookme
        self._blocks_args = blocks_args

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                # repeated blocks use stride 1, so image_size stays the same
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        if self._global_params.include_top:
            self._dropout = nn.Dropout(self._global_params.dropout_rate)
            self.classifier = nn.Linear(out_channels, self._global_params.num_classes)

        # set activation to memory efficient swish by default
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Use convolution layer to extract features
        at every reduction level, plus the output of the head.

        An entry is recorded each time a block reduces the spatial height (the
        features before that block are stored), once more after the last block,
        and finally for the head output.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Dictionary of last intermediate features, keyed 'reduction_<i>'
            with i counting up from 1.

            Example:
                >>> import torch
                >>> from backbone.EfficientNet import EfficientNet
                >>> inputs = torch.rand(1, 3, 224, 224)
                >>> model = EfficientNet.from_pretrained('efficientnet-b0')
                >>> endpoints = model.extract_endpoints(inputs)
                >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
                >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
                >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
                >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
                >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
                >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
        """
        endpoints = {}

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
            elif idx == len(self._blocks) - 1:
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x

        return endpoints

    def extract_features(self, inputs):
        """use convolution layer to extract feature .

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of the final convolution layer in the efficientnet model.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def activations_hook(self, grad):
        """Store ``grad`` on the instance as ``self.gradients``."""
        self.gradients = grad

    def forward(self, inputs, returnt='out'):
        """EfficientNet's forward function.
           Calls extract_features to extract features, applies final linear layer, and returns logits.

        Args:
            inputs (tensor): Input tensor.
            returnt (str): What to return:
                'features': the pooled and flattened features;
                'out': the classifier output;
                'full': the tuple (classifier output, pooled features).
                When the model was built without a top, the un-pooled output of
                extract_features takes the place of the classifier output.

        Returns:
            Output of this model after processing.

        Raises:
            NotImplementedError: If ``returnt`` is not one of the values above.
        """
        # Convolution layers
        x = self.extract_features(inputs)

        # Pooling and final linear layer
        feats = self._avg_pooling(x).flatten(start_dim=1)

        if returnt == 'features':
            return feats

        if self._global_params.include_top:
            x = self._dropout(feats)
            x = self.classifier(x)

        if returnt == 'out':
            return x
        elif returnt == 'full':
            return (x, feats)
        raise NotImplementedError("Unknown return type")

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            An efficientnet model.

        Raises:
            ValueError: If ``model_name`` is not one of VALID_MODELS.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Create an efficientnet model according to name and load pretrained weights.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
                This module lists no advprop checkpoint URLs, so requesting
                them without a local ``weights_path`` raises.
            in_channels (int): Input data's channel number.
            num_classes (int):
                Number of categories for classification.
                It controls the output size for final linear layer.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            A pretrained efficientnet model.

        Raises:
            NotImplementedError: If ``advprop`` is set while ``weights_path`` is None.
        """
        if advprop and weights_path is None:
            raise NotImplementedError(
                'advprop weights are not listed in url_map; pass weights_path to load them from disk')
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        # The pretrained head is only usable when the number of classes matches the
        # checkpoint's 1000. `advprop` is deliberately not forwarded: it only selects
        # the download source, which is handled by the check above.
        load_pretrained_weights(model, model_name, weights_path=weights_path,
                                load_fc=(num_classes == 1000))
        # the stem is replaced only after loading, so the pretrained weights fit the 3-channel stem
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            Input image size (resolution).
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name.

        Args:
            model_name (str): Name for efficientnet.

        Raises:
            ValueError: If ``model_name`` is not one of VALID_MODELS.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        The stem convolution is replaced by a newly created one, so weights
        previously held by the stem are discarded.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
def mammoth_efficientnet(nclasses: int, model_name: str, pretrained=False):
    """
    Instantiates an EfficientNet network.

    Args:
        nclasses: number of output classes
        model_name: name of the efficientnet variant to build
        pretrained: if truthy, build through `EfficientNet.from_pretrained`,
            otherwise through `EfficientNet.from_name`

    Returns:
        EfficientNet network
    """
    build = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
    return build(model_name=model_name, num_classes=nclasses)