Skip to content

Instantly share code, notes, and snippets.

@gau-nernst
Last active December 22, 2021 13:52
Show Gist options
  • Save gau-nernst/d5828aafaa93071f5c23e0e78b58f984 to your computer and use it in GitHub Desktop.
Save gau-nernst/d5828aafaa93071f5c23e0e78b58f984 to your computer and use it in GitHub Desktop.
Torchvision feature extractor
import torch
from torch import nn
from torchvision.models import resnet, mobilenet, efficientnet
from torchvision.models.feature_extraction import create_feature_extractor
class _Extractor(nn.Module):
    """Wrap a backbone so that calling it returns the outputs of selected graph nodes.

    Args:
        backbone: model traced with ``create_feature_extractor``.
        node_names: names of the graph nodes whose outputs ``forward`` returns.

    Attributes:
        feat_extractor: the traced module returned by ``create_feature_extractor``.
        out_channels: tuple holding ``shape[1]`` of each returned feature, measured
            once on a random ``(1, 3, 224, 224)`` probe input at construction time.
    """

    def __init__(self, backbone, node_names):
        super().__init__()
        self.feat_extractor = create_feature_extractor(backbone, node_names)

        # Run the probe in eval mode so layers whose behaviour depends on the
        # train/eval flag (e.g. BatchNorm, Dropout) do not react to the random
        # input. Then restore the previous mode: the original code left the
        # extractor stuck in eval mode while this wrapper still reported
        # training=True.
        was_training = self.feat_extractor.training
        self.feat_extractor.eval()
        with torch.no_grad():
            features = self.feat_extractor(torch.rand(1, 3, 224, 224))
        self.feat_extractor.train(was_training)

        self.out_channels = tuple(f.shape[1] for f in features.values())

    def forward(self, x):
        """Return the selected node outputs as a list.

        The order is that of the dict produced by the traced extractor.
        """
        # A list (instead of a live ``dict_values`` view) can be indexed and is
        # understood by container-aware utilities such as DataParallel's gather.
        return list(self.feat_extractor(x).values())
class ResNetExtractor(_Extractor):
    """Feature extractor over a torchvision ResNet variant.

    Args:
        name: factory name looked up in ``torchvision.models.resnet``
            (an unknown name raises ``KeyError``).
        pretrained: forwarded to the factory.

    Returns the top-level ``relu`` node (presumably the stem activation) followed
    by the outputs of ``layer1`` through ``layer4``.
    """

    def __init__(self, name, pretrained=False):
        # vars(module) is the module's __dict__, so this is the same lookup.
        factory = vars(resnet)[name]
        model = factory(pretrained=pretrained, progress=False)
        stem_and_stages = ["relu"] + [f"layer{k}" for k in range(1, 5)]
        super().__init__(model, stem_and_stages)
class MobileNetExtractor(_Extractor):
    """Feature extractor over a torchvision MobileNetV2 / MobileNetV3 variant.

    For every block in ``backbone.features`` flagged with ``_is_cn`` (in torchvision
    these are the strided inverted-residual blocks; verify for other builds), the
    feature is taken before the resolution drops. That is the output of the
    block's 1x1 expansion conv when it has one; otherwise it is the block's input.
    The output of the last ``features`` module is appended at the end.

    Args:
        name: factory name looked up in ``torchvision.models.mobilenet``
            (an unknown name raises ``KeyError``).
        pretrained: forwarded to the factory.
    """

    def __init__(self, name, pretrained=False):
        backbone = mobilenet.__dict__[name](pretrained=pretrained, progress=False)
        features = backbone.features
        # MobileNetV2 blocks keep their layers in ``.conv``; V3 blocks in ``.block``.
        block_name = "conv" if name == "mobilenet_v2" else "block"

        node_names = []
        for i, block in enumerate(features):
            if not getattr(block, "_is_cn", False):
                continue
            first_layer = getattr(block, block_name)[0]
            # modules() yields first_layer itself first, so this also works when
            # first_layer is a bare Conv2d rather than a conv-norm-act Sequential.
            first_conv = next(
                m for m in first_layer.modules() if isinstance(m, nn.Conv2d)
            )
            if first_conv.kernel_size == (1, 1) or i == 0:
                # Expansion conv present: its output is still at the input resolution.
                node_names.append(f"features.{i}.{block_name}.0")
            else:
                # No expansion (e.g. the first block of mobilenet_v3_small): layer 0
                # is the strided depthwise conv, so use the block's input instead.
                node_names.append(f"features.{i - 1}")
        node_names.append(f"features.{len(features) - 1}")
        super().__init__(backbone, node_names)
class EfficientNetExtractor(_Extractor):
    """Feature extractor over a torchvision EfficientNet variant.

    Args:
        name: factory name looked up in ``torchvision.models.efficientnet``
            (an unknown name raises ``KeyError``).
        pretrained: forwarded to the factory.

    Returns the output of layer 0 inside the first block of ``features`` stages
    2, 3, 4 and 6 (intended to be the 1x1 expansion conv), followed by the output
    of the last ``features`` module.
    """

    def __init__(self, name, pretrained=False):
        model = vars(efficientnet)[name](pretrained=pretrained, progress=False)
        # NOTE(review): stage indices are hard-coded; presumably these are the
        # stages that open with a downsampling block. Confirm against torchvision's
        # EfficientNet config if the model family changes.
        nodes = [f"features.{stage}.0.block.0" for stage in (2, 3, 4, 6)]
        nodes.append(f"features.{len(model.features) - 1}")
        super().__init__(model, nodes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment