"""
Copyright 2020 Magdalena Rybicka
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
from jsonargparse import ArgumentParser, ActionParser
from .spinenet import *
# Registry of the available SpineNet variants: lowercase type name -> network
# class (the classes are expected to come from the `.spinenet` star import).
# The key order is kept as declared because the keys are also offered as the
# `--spinenet-type` command-line choices.
spinenet_dict = dict(
    spinenet49=SpineNet49,
    spinenet49s=SpineNet49S,
    spinenet96=SpineNet96,
    spinenet143=SpineNet143,
    spinenet190=SpineNet190,
    lspinenet49=LSpineNet49,
    lspinenet49_subpixel=LSpineNet49_subpixel,
    lspinenet49_bilinear=LSpineNet49_bilinear,
    lspinenet49_5=LSpineNet49_5,
    lspine2net49=LSpine2Net49,
    selspine2net49=SELSpine2Net49,
    tselspine2net49=TSELSpine2Net49,
    spine2net49=Spine2Net49,
    sespine2net49=SESpine2Net49,
    tsespine2net49=TSESpine2Net49,
    spine2net49s=Spine2Net49S,
    sespine2net49s=SESpine2Net49S,
    tsespine2net49s=TSESpine2Net49S,
    lr0_sp53=LR0_SP53,
    r0_sp53=R0_SP53,
    spinenet49_concat_time=SpineNet49_concat_time,
)
class SpineNetFactory(object):
    """Factory for the SpineNet networks registered in `spinenet_dict`.

    Groups three helpers: `create` instantiates a network by type name,
    `filter_args` reduces an arbitrary keyword dictionary to the arguments
    `create` understands, and `add_class_args` declares the matching
    command-line options.
    """

    @staticmethod
    def create(
        spinenet_type,
        in_channels,
        output_levels=[3, 4, 5, 6, 7],
        endpoints_num_filters=256,
        resample_alpha=0.5,
        block_repeats=1,
        filter_size_scale=1.0,
        conv_channels=64,
        base_channels=64,
        out_units=0,
        hid_act={"name": "relu6", "inplace": True},
        out_act=None,
        in_kernel_size=7,
        in_stride=2,
        zero_init_residual=False,
        groups=1,
        dropout_rate=0,
        norm_layer=None,
        norm_before=True,
        do_maxpool=True,
        in_norm=True,
        se_r=16,
        in_feats=None,
        res2net_scale=4,
        res2net_width_factor=1,
    ):
        """Instantiate the SpineNet variant registered under `spinenet_type`.

        Args:
          spinenet_type: key of `spinenet_dict` selecting the network class.
          in_channels: number of input channels, passed positionally to the
            network class.
          output_levels, endpoints_num_filters, resample_alpha, block_repeats,
          filter_size_scale, conv_channels, base_channels, out_units, hid_act,
          out_act, in_kernel_size, in_stride, zero_init_residual, groups,
          dropout_rate, norm_layer, norm_before, do_maxpool, in_norm, se_r,
          in_feats, res2net_scale, res2net_width_factor: forwarded as keyword
            arguments of the same name to the network class.

        Returns:
          The object returned by the selected network class.

        Raises:
          Exception: if `spinenet_type` is not a key of `spinenet_dict`.
        """
        try:
            spinenet_class = spinenet_dict[spinenet_type]
        except (KeyError, TypeError) as err:
            # KeyError: unknown name; TypeError: unhashable value used as key.
            raise Exception(
                "%s is not valid SpineNet network" % (spinenet_type,)
            ) from err

        # The list/dict defaults above are created once and shared by every
        # call. Hand the network its own shallow copies so that an in-place
        # change downstream cannot leak into later calls.
        if isinstance(output_levels, list):
            output_levels = list(output_levels)
        if isinstance(hid_act, dict):
            hid_act = dict(hid_act)

        spinenet = spinenet_class(
            in_channels,
            output_levels=output_levels,
            endpoints_num_filters=endpoints_num_filters,
            resample_alpha=resample_alpha,
            block_repeats=block_repeats,
            filter_size_scale=filter_size_scale,
            conv_channels=conv_channels,
            base_channels=base_channels,
            out_units=out_units,
            hid_act=hid_act,
            out_act=out_act,
            in_kernel_size=in_kernel_size,
            in_stride=in_stride,
            zero_init_residual=zero_init_residual,
            groups=groups,
            dropout_rate=dropout_rate,
            norm_layer=norm_layer,
            norm_before=norm_before,
            do_maxpool=do_maxpool,
            in_norm=in_norm,
            se_r=se_r,
            in_feats=in_feats,
            res2net_scale=res2net_scale,
            res2net_width_factor=res2net_width_factor,
        )
        return spinenet

    @staticmethod
    def filter_args(**kwargs):
        """Keep only the keyword arguments accepted by `create`.

        The negative command-line flags are translated first:
        `norm_after` becomes `norm_before = not norm_after` and
        `no_maxpool` becomes `do_maxpool = not no_maxpool`. Every key that
        is not an argument of `create` is then discarded.

        Args:
          **kwargs: arbitrary keyword arguments, e.g. parsed CLI options.

        Returns:
          New dict holding the subset of `kwargs` that `create` accepts.
        """
        if "norm_after" in kwargs:
            kwargs["norm_before"] = not kwargs.pop("norm_after")

        if "no_maxpool" in kwargs:
            kwargs["do_maxpool"] = not kwargs.pop("no_maxpool")

        valid_args = (
            "spinenet_type",
            "in_channels",
            "output_levels",
            "endpoints_num_filters",
            "resample_alpha",
            "block_repeats",
            "filter_size_scale",
            "conv_channels",
            "base_channels",
            "out_units",
            "hid_act",
            "out_act",
            "in_kernel_size",
            "in_stride",
            "zero_init_residual",
            "groups",
            "dropout_rate",
            "in_norm",
            "norm_layer",
            "norm_before",
            "do_maxpool",
            "se_r",
            "res2net_scale",
            "res2net_width_factor",
            "in_feats",
        )
        return {k: kwargs[k] for k in valid_args if k in kwargs}

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Declare the SpineNet command-line options.

        Args:
          parser: parser that receives the options.
          prefix: when not None, the options are added to a fresh
            `ArgumentParser` that is attached to `parser` as the single
            option `--<prefix>` through `ActionParser`; otherwise they are
            added to `parser` directly.
        """
        # Function-scope import: only needed to name the conflict error below.
        import argparse

        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        parser.add_argument(
            "--spinenet-type",
            type=str.lower,
            default="spinenet49",
            choices=list(spinenet_dict),
            help="SpineNet type",
        )

        parser.add_argument(
            "--in-channels", default=1, type=int, help="number of input channels"
        )

        parser.add_argument(
            "--conv-channels",
            default=64,
            type=int,
            help="number of output channels in input convolution ",
        )

        parser.add_argument(
            "--base-channels",
            default=64,
            type=int,
            help="base channels of first SpineNet block",
        )

        parser.add_argument(
            "--in-kernel-size",
            default=7,
            type=int,
            help="kernel size of first convolution",
        )

        parser.add_argument(
            "--in-stride", default=2, type=int, help="stride of first convolution"
        )

        parser.add_argument(
            "--groups",
            default=1,
            type=int,
            help="number of groups in residual blocks convolutions",
        )

        # The options wrapped in try/except below are registered best-effort:
        # if the parser already has an option with the same name, the
        # conflict is ignored and the existing definition is kept
        # (presumably declared by another component sharing the parser).
        try:
            parser.add_argument(
                "--norm-layer",
                default=None,
                choices=[
                    "batch-norm",
                    "group-norm",
                    "instance-norm",
                    "instance-norm-affine",
                    "layer-norm",
                ],
                help="type of normalization layer",
            )
        except argparse.ArgumentError:
            pass

        parser.add_argument(
            "--in-norm",
            default=False,
            action="store_true",
            help="batch normalization at the input",
        )

        parser.add_argument(
            "--no-maxpool",
            default=False,
            action="store_true",
            help="don't do max pooling after first convolution",
        )

        parser.add_argument(
            "--zero-init-residual",
            default=False,
            action="store_true",
            help="Zero-initialize the last BN in each residual branch",
        )

        parser.add_argument(
            "--se-r",
            default=16,
            type=int,
            help="squeeze ratio in squeeze-excitation blocks",
        )

        parser.add_argument(
            "--res2net-scale", default=4, type=int, help="scale parameter for res2net"
        )

        parser.add_argument(
            "--res2net-width-factor",
            default=1,
            type=float,
            help="multiplicative factor for the internal width of res2net",
        )

        try:
            parser.add_argument("--hid-act", default="relu6", help="hidden activation")
        except argparse.ArgumentError:
            pass

        try:
            parser.add_argument(
                "--norm-after",
                default=False,
                action="store_true",
                help="batch normalizaton after activation",
            )
        except argparse.ArgumentError:
            pass

        try:
            parser.add_argument("--dropout-rate", default=0, type=float, help="dropout")
        except argparse.ArgumentError:
            pass

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    # Alias under a second name for the same option-declaring method.
    add_argparse_args = add_class_args