Exploring vit features_only via new forward_intermediates() API, inspired by #2131 (#2136)


Merged · 5 commits · Apr 11, 2024
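This PR adds a forward_intermediates() API to ViT and related model families and uses it to support features_only feature extraction via a new FeatureGetterNet wrapper. A minimal sketch of the API the tests below exercise ('vit_base_patch16_224' is used purely for illustration):

```python
import torch
import timm

# Any model gaining the new API exposes forward_intermediates() directly.
model = timm.create_model('vit_base_patch16_224', pretrained=False)
model.eval()

# Returns the final output together with a list of intermediate feature maps.
output, intermediates = model.forward_intermediates(torch.randn(1, 3, 224, 224))
for feat in intermediates:
    print(feat.shape)  # one NCHW feature map per returned block
```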
79 changes: 75 additions & 4 deletions tests/test_models.py
@@ -47,11 +47,16 @@
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)

# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
FEAT_INTER_FILTERS = [
'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
]

# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
@@ -351,15 +356,46 @@ def test_model_forward_torchscript(model_name, batch_size):

@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
model = create_model(model_name, pretrained=False, features_only=True)
model.eval()
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
assert len(expected_channels) >= 3 # all models here should have at least 3 default feat levels

input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math

outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)
spatial_size = input_size[-2:]
for e, r, o in zip(expected_channels, expected_reduction, outputs):
assert e == o.shape[feat_axis]
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()


@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_intermediates_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
model = create_model(model_name, pretrained=False, features_only=True)
model.eval()
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()

input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
@@ -373,6 +409,41 @@ def test_model_forward_features(model_name, batch_size):
assert len(expected_channels) == len(outputs)
spatial_size = input_size[-2:]
for e, r, o in zip(expected_channels, expected_reduction, outputs):
assert e == o.shape[feat_axis]
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()


@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_intermediates(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
model = create_model(model_name, pretrained=False)
model.eval()
feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
expected_channels = feature_info.channels()
expected_reduction = feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6

input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math

output, intermediates = model.forward_intermediates(
torch.randn((batch_size, *input_size)),
)
assert len(expected_channels) == len(intermediates)
spatial_size = input_size[-2:]
for e, r, o in zip(expected_channels, expected_reduction, intermediates):
assert e == o.shape[feat_axis]
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
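For reference, the contract these tests encode — channel counts match feature_info.channels() and spatial dims track ceil(input / reduction) — can be checked by hand. A small sketch under the same assumptions (NCHW output; 'resnet50' chosen only as a familiar features_only-capable model):

```python
import math
import torch
import timm

model = timm.create_model('resnet50', pretrained=False, features_only=True)
model.eval()

x = torch.randn(1, 3, 224, 224)
outs = model(x)
for feat, ch, red in zip(outs, model.feature_info.channels(), model.feature_info.reduction()):
    assert feat.shape[1] == ch                          # channel dim for NCHW output
    assert feat.shape[-1] <= math.ceil(224 / red) + 1   # spatial dims track input / reduction
```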
20 changes: 18 additions & 2 deletions timm/layers/patch_embed.py
@@ -9,6 +9,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from typing import Callable, List, Optional, Tuple, Union

import torch
@@ -65,6 +66,21 @@ def __init__(
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
""" Return the input-to-feature reduction ratio (the patch stride), as a scalar or per-dim tuple. """
if as_scalar:
return max(self.patch_size)
else:
return self.patch_size

def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get grid (feature) size for given image size taking account of dynamic padding.
NOTE: must be torchscript compatible so using fixed tuple indexing
"""
if self.dynamic_img_pad:
return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
else:
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]

def forward(self, x):
B, C, H, W = x.shape
if self.img_size is not None:
@@ -127,13 +143,13 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
_assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")

x = self.proj(x)
grid_size = x.shape[-2:]
feat_size = x.shape[-2:]
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x, grid_size
return x, feat_size


def resample_patch_embed(
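The two helpers added above let feature-extraction consumers reason about output geometry without running the projection. A short sketch, assuming PatchEmbed accepts img_size=None together with the dynamic_img_pad flag referenced in the diff:

```python
import torch
from timm.layers import PatchEmbed

# With dynamic_img_pad, any input size maps to ceil(size / patch) patches.
embed = PatchEmbed(img_size=None, patch_size=16, embed_dim=768, dynamic_img_pad=True)

H, W = 250, 250
feat_h, feat_w = embed.dynamic_feat_size((H, W))  # ceil(250 / 16) -> (16, 16)
reduction = embed.feat_ratio()                    # max(patch_size) -> 16

x = torch.randn(1, 3, H, W)
out = embed(x)  # flatten=True by default: (1, feat_h * feat_w, 768)
assert out.shape[1] == feat_h * feat_w
```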
6 changes: 5 additions & 1 deletion timm/models/_builder.py
@@ -7,7 +7,7 @@
from torch import nn as nn
from torch.hub import load_state_dict_from_url

from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
@@ -428,8 +428,12 @@ def build_model_with_cfg(
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'dict':
feature_cls = FeatureDictNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
elif feature_cls == 'getter':
feature_cls = FeatureGetterNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
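End to end, the new 'getter' branch means that for the model families listed in FEAT_INTER_FILTERS, features_only=True wraps the model in FeatureGetterNet, which pulls features through forward_intermediates() rather than attaching hooks. A hedged usage sketch ('vit_small_patch16_224' is illustrative; the default out_indices come from the model's entrypoint config):

```python
import torch
import timm

model = timm.create_model('vit_small_patch16_224', pretrained=False, features_only=True)
model.eval()

feats = model(torch.randn(1, 3, 224, 224))
for f, r in zip(feats, model.feature_info.reduction()):
    # Non-hierarchical ViTs report the same reduction (patch stride) at every level.
    print(f.shape, r)
```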