From 5fdc0b4e93dabf970f24667befe50207e9f37681 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 7 Apr 2024 11:24:45 -0700 Subject: [PATCH 1/5] Exploring vit features_only using get_intermediate_layers() as per #2131 --- timm/models/_builder.py | 6 ++- timm/models/_features.py | 85 ++++++++++++++++++++++++++----- timm/models/_features_fx.py | 16 ++++-- timm/models/vision_transformer.py | 8 +-- 4 files changed, 94 insertions(+), 21 deletions(-) diff --git a/timm/models/_builder.py b/timm/models/_builder.py index e6150b9a81..c1ad5c2df5 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -7,7 +7,7 @@ from torch import nn as nn from torch.hub import load_state_dict_from_url -from timm.models._features import FeatureListNet, FeatureHookNet +from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet from timm.models._features_fx import FeatureGraphNet from timm.models._helpers import load_state_dict from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf @@ -428,8 +428,12 @@ def build_model_with_cfg( feature_cls = feature_cls.lower() if 'hook' in feature_cls: feature_cls = FeatureHookNet + elif feature_cls == 'dict': + feature_cls = FeatureDictNet elif feature_cls == 'fx': feature_cls = FeatureGraphNet + elif feature_cls == 'getter': + feature_cls = FeatureGetterNet else: assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) diff --git a/timm/models/_features.py b/timm/models/_features.py index 7ef51809bc..cc4068d447 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -11,7 +11,7 @@ from collections import OrderedDict, defaultdict from copy import deepcopy from functools import partial -from typing import Dict, List, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -23,9 +23,24 @@ __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] +def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: + if isinstance(x, int): + # if indices is an int, take last N features + return tuple(range(-x, 0)) + return tuple(x) + + +OutIndicesT = Union[int, Tuple[int, ...]] + + class FeatureInfo: - def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + def __init__( + self, + feature_info: List[Dict], + out_indices: OutIndicesT, + ): + out_indices = _out_indices_as_tuple(out_indices) prev_reduction = 1 for i, fi in enumerate(feature_info): # sanity check the mandatory fields, there may be additional fields depending on the model @@ -37,14 +52,15 @@ def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): self.out_indices = out_indices self.info = feature_info - def from_other(self, out_indices: Tuple[int]): + def from_other(self, out_indices: OutIndicesT): + out_indices = _out_indices_as_tuple(out_indices) return FeatureInfo(deepcopy(self.info), out_indices) - def get(self, key, idx=None): + def get(self, key: str, idx: Optional[Union[int, List[int]]] = None): """ Get value by key at specified index (indices) if idx == None, returns value for key at each output index if idx is an integer, return value for that feature module index (ignoring output indices) - if idx is a list/tupple, return value for each module index (ignoring output indices) + if idx is a list/tuple, return value for each module index (ignoring output indices) """ if idx is None: return [self.info[i][key] for 
i in self.out_indices] @@ -53,7 +69,7 @@ def get(self, key, idx=None): else: return self.info[idx][key] - def get_dicts(self, keys=None, idx=None): + def get_dicts(self, keys: Optional[List[str]] = None, idx: Optional[Union[int, List[int]]] = None): """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) """ if idx is None: @@ -66,17 +82,17 @@ def get_dicts(self, keys=None, idx=None): else: return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} - def channels(self, idx=None): + def channels(self, idx: Optional[Union[int, List[int]]] = None): """ feature channels accessor """ return self.get('num_chs', idx) - def reduction(self, idx=None): + def reduction(self, idx: Optional[Union[int, List[int]]] = None): """ feature reduction (output stride) accessor """ return self.get('reduction', idx) - def module_name(self, idx=None): + def module_name(self, idx: Optional[Union[int, List[int]]] = None): """ feature module name accessor """ return self.get('module', idx) @@ -146,7 +162,7 @@ def _module_list(module, flatten_sequential=False): return ml -def _get_feature_info(net, out_indices): +def _get_feature_info(net, out_indices: OutIndicesT): feature_info = getattr(net, 'feature_info') if isinstance(feature_info, FeatureInfo): return feature_info.from_other(out_indices) @@ -182,7 +198,7 @@ class FeatureDictNet(nn.ModuleDict): def __init__( self, model: nn.Module, - out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + out_indices: OutIndicesT = (0, 1, 2, 3, 4), out_map: Sequence[Union[int, str]] = None, output_fmt: str = 'NCHW', feature_concat: bool = False, @@ -257,7 +273,7 @@ class FeatureListNet(FeatureDictNet): def __init__( self, model: nn.Module, - out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + out_indices: OutIndicesT = (0, 1, 2, 3, 4), output_fmt: str = 'NCHW', feature_concat: bool = False, flatten_sequential: bool = False, @@ -298,8 +314,8 @@ class FeatureHookNet(nn.ModuleDict): def __init__( self, model: nn.Module, - out_indices: Tuple[int, ...] 
= (0, 1, 2, 3, 4), - out_map: Sequence[Union[int, str]] = None, + out_indices: OutIndicesT = (0, 1, 2, 3, 4), + out_map: Optional[Sequence[Union[int, str]]] = None, return_dict: bool = False, output_fmt: str = 'NCHW', no_rewrite: bool = False, @@ -366,3 +382,44 @@ def forward(self, x): x = module(x) out = self.hooks.get_output(x.device) return out if self.return_dict else list(out.values()) + + +class FeatureGetterNet(nn.ModuleDict): + """ FeatureGetterNet + + Wrap models with a feature getter method, like 'get_intermediate_layers' + + """ + def __init__( + self, + model: nn.Module, + out_indices: OutIndicesT = 4, + out_map: Optional[Sequence[Union[int, str]]] = None, + return_dict: bool = False, + output_fmt: str = 'NCHW', + ): + super().__init__() + self.model = model + self.feature_info = _get_feature_info(model, out_indices) + self.out_indices = out_indices + self.out_map = out_map + self.return_dict = return_dict + self.output_fmt = output_fmt + + def forward(self, *args, **kwargs): + """ + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + """ + out = self.model.get_intermediate_layers( + *args, + n=self.out_indices, + reshape=True, + **kwargs, + ) + return out diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index c48c13b7fc..e67d1f257f 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -1,7 +1,7 @@ """ PyTorch FX Based Feature Extraction Helpers Using https://pytorch.org/vision/stable/feature_extraction.html """ -from typing import Callable, List, Dict, Union, Type +from typing import Callable, Dict, List, Optional, Union, Tuple, Type import torch from torch import nn @@ -103,7 +103,12 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str class FeatureGraphNet(nn.Module): """ A FX Graph based feature extractor that works with the model feature_info metadata """ - def __init__(self, model, out_indices, out_map=None): + def __init__( + self, + model: nn.Module, + out_indices: Tuple[int, ...], + out_map: Optional[Dict] = None, + ): super().__init__() assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' self.feature_info = _get_feature_info(model, out_indices) @@ -129,7 +134,12 @@ class GraphExtractNet(nn.Module): return_nodes: node names to return features from (dict or list) squeeze_out: if only one output, and output in list format, flatten to single tensor """ - def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): + def __init__( + self, + model: nn.Module, + return_nodes: Union[Dict[str, str], List[str]], + squeeze_out: bool = True, + ): super().__init__() self.squeeze_out = squeeze_out self.graph_module = create_feature_extractor(model, return_nodes) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index ce65ee4ae2..b57104ac2e 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -473,6 +473,7 @@ def __init__( self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False + self.feature_info = [] embed_args = {} if dynamic_img_size: @@ -520,6 +521,8 @@ def __init__( mlp_layer=mlp_layer, ) for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in 
range(depth)] self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head @@ -1770,9 +1773,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer: - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - + out_indices = kwargs.pop('out_indices', 3) if 'flexi' in variant: # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. @@ -1791,6 +1792,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) pretrained, pretrained_filter_fn=_filter_fn, pretrained_strict=strict, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) From 679daef76a22415ac1cb666d970c1543699ecf51 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Apr 2024 21:29:16 -0700 Subject: [PATCH 2/5] More forward_intermediates() & features_only work * forward_intermediates() added to beit, deit, eva, mvitv2, twins, vit, vit_sam * add features_only to forward intermediates to allow just intermediate features * fix #2060 * fix #1374 * fix #657 --- timm/models/_features.py | 77 +++++++++++---- timm/models/beit.py | 102 ++++++++++++++++++-- timm/models/deit.py | 4 +- timm/models/eva.py | 93 +++++++++++++++++-- timm/models/mvitv2.py | 78 +++++++++++++--- timm/models/twins.py | 62 ++++++++++++- timm/models/vision_transformer.py | 129 ++++++++++++++++++-------- timm/models/vision_transformer_sam.py | 68 +++++++++++--- 8 files changed, 512 insertions(+), 101 deletions(-) diff --git a/timm/models/_features.py b/timm/models/_features.py index cc4068d447..fa108798ad 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -11,7 +11,7 @@ from collections import OrderedDict, defaultdict from copy import deepcopy from functools import partial -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union import torch import torch.nn as nn @@ -20,7 +20,39 @@ from timm.layers import Format -__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] +__all__ = [ + 'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet', + 'feature_take_indices' +] + + +def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[Set[int], int]: + if isinstance(n, int): + assert n >= 0 + take_indices = {x for x in range(num_blocks - n, num_blocks)} + else: + take_indices = {num_blocks + idx if idx < 0 else idx for idx in n} + return take_indices, max(take_indices) + + +def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]: + if isinstance(n, int): + assert n >= 0 + take_indices = [num_blocks - n + i for i in range(n)] + elif isinstance(n, tuple): + # splitting this up is silly, but needed for torchscript type resolution of n + take_indices = [num_blocks + idx if idx < 0 else idx for idx in n] + else: + take_indices = [num_blocks + idx if idx < 0 else idx for idx in n] + return take_indices, max(take_indices) + + +def feature_take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]: + if torch.jit.is_scripting(): + return _take_indices_jit(n, num_blocks) + else: + # NOTE non-jit returns Set[int] instead 
of List[int] but torchscript can't handle that anno + return _take_indices(n, num_blocks) def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: @@ -397,29 +429,38 @@ def __init__( out_map: Optional[Sequence[Union[int, str]]] = None, return_dict: bool = False, output_fmt: str = 'NCHW', + norm: bool = False, + prune: bool = True, ): + """ + + Args: + model: Model to wrap. + out_indices: Indices of features to extract. + out_map: Remap feature names for dict output (WIP, not supported). + return_dict: Return features as dictionary instead of list (WIP, not supported). + norm: Apply final model norm to all output features (if possible). + """ super().__init__() - self.model = model + if prune and hasattr(model, 'prune_intermediate_layers'): + model.prune_intermediate_layers( + out_indices, + prune_norm=not norm, + ) self.feature_info = _get_feature_info(model, out_indices) + self.model = model self.out_indices = out_indices self.out_map = out_map self.return_dict = return_dict self.output_fmt = output_fmt + self.norm = norm - def forward(self, *args, **kwargs): - """ - def get_intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, - reshape: bool = False, - return_prefix_tokens: bool = False, - norm: bool = False, - """ - out = self.model.get_intermediate_layers( - *args, + def forward(self, x): + features = self.model.forward_intermediates( + x, n=self.out_indices, - reshape=True, - **kwargs, + norm=self.norm, + output_fmt=self.output_fmt, + features_only=True, ) - return out + return features diff --git a/timm/models/beit.py b/timm/models/beit.py index 0167099ce7..46d460d3d2 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -39,7 +39,7 @@ # --------------------------------------------------------' import math -from typing import Callable, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -52,8 +52,8 @@ from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import generate_default_cfgs, register_model -from .vision_transformer import checkpoint_filter_fn __all__ = ['Beit'] @@ -333,6 +333,8 @@ def __init__( window_size=self.patch_embed.grid_size if use_rel_pos_bias else None, ) for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) @@ -398,6 +400,93 @@ def reset_classifier(self, num_classes, global_pool=None): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + n: Optional[Union[int, List[int], Tuple[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
+
+        Args:
+            x: Input image tensor
+            n: Take last n blocks if n is an int; if n is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when the last desired intermediate is hit
+            output_fmt: Shape of intermediate feature outputs
+            features_only: Only return intermediate features
+        Returns:
+            A tuple of (final features, list of intermediates), or just the list if features_only is True
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        num_blocks = len(self.blocks)
+        if n is None:
+            n = num_blocks
+        take_indices, max_index = feature_take_indices(n, num_blocks)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x, shared_rel_pos_bias=rel_pos_bias)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape == True => BCHW output format
+            patch_size = self.patch_embed.patch_size
+            H = int(math.ceil(height / patch_size[0]))
+            W = int(math.ceil(width / patch_size[1]))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if features_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+ """ + take_indices, max_index = feature_take_indices(n, len(self.blocks)) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.head = nn.Identity() + def forward_features(self, x): x = self.patch_embed(x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) @@ -547,14 +636,13 @@ def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antia def _create_beit(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for BEiT models.') - + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Beit, variant, pretrained, - # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes pretrained_filter_fn=_beit_checkpoint_filter_fn, - **kwargs) + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) return model diff --git a/timm/models/deit.py b/timm/models/deit.py index f80087e80d..9400549dcc 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -119,14 +119,14 @@ def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor: def _create_deit(variant, pretrained=False, distilled=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') + out_indices = kwargs.pop('out_indices', 3) model_cls = VisionTransformerDistilled if distilled else VisionTransformer model = build_model_with_cfg( model_cls, variant, pretrained, pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True), + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) return model diff --git a/timm/models/eva.py b/timm/models/eva.py index 82fff28acf..fe121b0032 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -24,9 +24,8 @@ """ # EVA models Copyright (c) 2022 BAAI-Vision # EVA02 models Copyright (c) 2023 BAAI-Vision - import math -from typing import Callable, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -39,6 +38,7 @@ to_2tuple, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import generate_default_cfgs, register_model __all__ = ['Eva'] @@ -469,6 +469,8 @@ def __init__( init_values=init_values, ) for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) @@ -559,6 +561,85 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices) return x, rot_pos_embed + def forward_intermediates( + self, + x: torch.Tensor, + n: Optional[Union[int, List[int], Tuple[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
+        Args:
+            x: Input image tensor
+            n: Take last n blocks if n is an int; if n is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when the last desired intermediate is hit
+            output_fmt: Shape of intermediate feature outputs
+            features_only: Only return intermediate features
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        num_blocks = len(self.blocks)
+        if n is None:
+            n = num_blocks
+        take_indices, max_index = feature_take_indices(n, num_blocks)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x, rope=rot_pos_embed)
+            if i in take_indices:
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape == True => BCHW output format
+            patch_size = self.patch_embed.patch_size
+            H = int(math.ceil(height / patch_size[0]))
+            W = int(math.ceil(width / patch_size[1]))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if features_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+ """ + take_indices, max_index = feature_take_indices(n, len(self.blocks)) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.head = nn.Identity() + def forward_features(self, x): x = self.patch_embed(x) x, rot_pos_embed = self._pos_embed(x) @@ -663,13 +744,13 @@ def checkpoint_filter_fn( def _create_eva(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Eva models.') - + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Eva, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) return model diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 9d035fd65a..579aa87e62 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -26,6 +26,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._registry import register_model, register_model_deprecations, generate_default_cfgs @@ -747,8 +748,10 @@ def __init__( num_stages = len(cfg.embed_dim) feat_size = patch_dims + curr_stride = max(cfg.patch_stride) dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] self.stages = nn.ModuleList() + self.feature_info = [] for i in range(num_stages): if cfg.expand_attn: dim_out = cfg.embed_dim[i] @@ -775,6 +778,8 @@ def __init__( norm_layer=norm_layer, drop_path=dpr[i], ) + curr_stride *= max(cfg.stride_q[i]) + self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)] embed_dim = dim_out feat_size = stage.feat_size self.stages.append(stage) @@ -829,6 +834,51 @@ def reset_classifier(self, num_classes, global_pool=None): ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) ])) + def forward_intermediates( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.' 
+ reshape = output_fmt == 'NCHW' + intermediates = [] + num_stages = len(self.stages) # block list is two-tiered, first tier == stage + if n is None: + n = num_stages + take_indices, max_index = feature_take_indices(n, num_stages) + + # FIXME slice block/pos_block if < max + + # forward pass + x, feat_size = self.patch_embed(x) + B = x.shape[0] + if self.cls_token is not None: + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + for i, stage in enumerate(self.stages): + x, feat_size = stage(x, feat_size) + if i in take_indices: + if norm and i == (len(self.stages) - 1): + x_inter = self.norm(x) # applying final norm last intermediate + else: + x_inter = x + if reshape: + x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2) + intermediates.append(x_inter) + + if features_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + def forward_features(self, x): x, feat_size = self.patch_embed(x) B, N, C = x.shape @@ -862,6 +912,18 @@ def forward(self, x): def checkpoint_filter_fn(state_dict, model): if 'stages.0.blocks.0.norm1.weight' in state_dict: + # native checkpoint, look for rel_pos interpolations + for k in state_dict.keys(): + if 'rel_pos' in k: + rel_pos = state_dict[k] + dest_rel_pos_shape = model.state_dict()[k].shape + if rel_pos.shape[0] != dest_rel_pos_shape[0]: + rel_pos_resized = torch.nn.functional.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=dest_rel_pos_shape[0], + mode="linear", + ) + state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0) return state_dict import re @@ -892,16 +954,6 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('head.projection', 'head.fc') out_dict[k] = v - # for k, v in state_dict.items(): - # if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: - # # To resize pos embedding when using model at different size from pretrained weights - # v = resize_pos_embed( - # v, - # model.pos_embed, - # 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), - # model.patch_embed.grid_size - # ) - return out_dict @@ -948,16 +1000,14 @@ def checkpoint_filter_fn(state_dict, model): def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Multiscale Vision Transformer models.') - + out_indices = kwargs.pop('out_indices', 4) return build_model_with_cfg( MultiScaleVit, variant, pretrained, model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], pretrained_filter_fn=checkpoint_filter_fn, - feature_cfg=dict(flatten_sequential=True), + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) diff --git a/timm/models/twins.py b/timm/models/twins.py index 3cd25fb433..feba8e37b1 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -13,7 +13,7 @@ # -------------------------------------------------------- import math from functools import partial -from typing import Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,6 +22,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import 
register_notrace_module from ._registry import register_model, generate_default_cfgs from .vision_transformer import Attention @@ -324,6 +325,7 @@ def __init__( patch_size = 2 self.blocks = nn.ModuleList() + self.feature_info = [] dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule cur = 0 for k in range(len(depths)): @@ -339,6 +341,7 @@ def __init__( ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])], ) self.blocks.append(_block) + self.feature_info += [dict(module=f'block.{k}', num_chs=embed_dims[k], reduction=2**(2+k))] cur += depths[k] self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) @@ -401,6 +404,53 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() + def forward_intermediates( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + assert output_fmt == 'NCHW', 'Output shape for Twins must be NCHW.' + intermediates = [] + num_stages = len(self.blocks) # block list is two-tiered, first tier == stage + if n is None: + n = num_stages + take_indices, max_index = feature_take_indices(n, num_stages) + + # FIXME slice block/pos_block if < max + + # forward pass + B, _, height, width = x.shape + for i, (embed, drop, blocks, pos_blk) in enumerate(zip( + self.patch_embeds, self.pos_drops, self.blocks, self.pos_block) + ): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + if i in take_indices: + intermediates.append(x) + else: + if i in take_indices: + # only last feature can be normed + x_feat = self.norm(x) if norm else x + intermediates.append(x_feat.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()) + + if features_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + def forward_features(self, x): B = x.shape[0] for i, (embed, drop, blocks, pos_blk) in enumerate( @@ -429,10 +479,12 @@ def forward(self, x): def _create_twins(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - model = build_model_with_cfg(Twins, variant, pretrained, **kwargs) + out_indices = kwargs.pop('out_indices', 4) + model = build_model_with_cfg( + Twins, variant, pretrained, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index b57104ac2e..24225206e0 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -45,6 +45,7 @@ trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ get_act_layer, get_norm_layer, LayerType from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -473,7 +474,6 @@ def __init__( self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False - 
self.feature_info = []
 
         embed_args = {}
         if dynamic_img_size:
@@ -631,58 +631,111 @@ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
 
         return self.pos_drop(x)
 
-    def _intermediate_layers(
+    def forward_intermediates(
             self,
             x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-    ) -> List[torch.Tensor]:
-        outputs, num_blocks = [], len(self.blocks)
-        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
-        last_index_to_take = max(take_indices)
+            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            features_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            n: Take last n blocks if n is an int; if n is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when the last desired intermediate is hit
+            output_fmt: Shape of intermediate feature outputs
+            features_only: Only return intermediate features
+        Returns:
+            A tuple of (final features, list of intermediates), or just the list if features_only is True
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        num_blocks = len(self.blocks)
+        if n is None:
+            n = num_blocks
+        take_indices, max_index = feature_take_indices(n, num_blocks)
 
         # forward pass
+        B, _, height, width = x.shape
         x = self.patch_embed(x)
         x = self._pos_embed(x)
         x = self.patch_drop(x)
         x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
             x = blk(x)
             if i in take_indices:
-                outputs.append(x)
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape == True => BCHW output format
+            patch_size = self.patch_embed.patch_size
+            H = int(math.ceil(height / patch_size[0]))
+            W = int(math.ceil(width / patch_size[1]))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
 
-        return outputs
+        if features_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+ """ + take_indices, max_index = feature_take_indices(n, len(self.blocks)) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + if self.attn_pool is not None: + self.attn_pool = None + self.fc_norm = nn.Identity() + self.head = nn.Identity() def get_intermediate_layers( self, x: torch.Tensor, - n: Union[int, Sequence] = 1, + n: Union[int, List[int], Tuple[int]] = 1, reshape: bool = False, return_prefix_tokens: bool = False, norm: bool = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - """ Intermediate layer accessor (NOTE: This is a WIP experiment). - Inspired by DINO / DINOv2 interface + ) -> List[torch.Tensor]: + """ Intermediate layer accessor inspired by DINO / DINOv2 interface. + NOTE: This API is for backwards compat, favour using forward_intermediates() directly. """ - # take last n blocks if n is an int, if in is a sequence, select by matching indices - outputs = self._intermediate_layers(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs] - outputs = [out[:, self.num_prefix_tokens:] for out in outputs] - - if reshape: - patch_size = self.patch_embed.patch_size - batch, _, height, width = x.size() - outputs = [ - out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1) - .permute(0, 3, 1, 2) - .contiguous() - for out in outputs - ] - - if return_prefix_tokens: - return tuple(zip(outputs, prefix_tokens)) - return tuple(outputs) + return self.forward_intermediates( + x, n, + return_prefix_tokens=return_prefix_tokens, + norm=norm, + output_fmt='NCHW' if reshape else 'NLC', + features_only=True, + ) def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) @@ -2485,7 +2538,7 @@ def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransfo def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-S/14 for DINOv2 """ - model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518) + model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5) model = _create_vision_transformer( 'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2495,7 +2548,7 @@ def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/14 for DINOv2 """ - model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518) + model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5) model = _create_vision_transformer( 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2505,7 +2558,7 @@ def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransfo def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-L/14 for DINOv2 """ - model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518) + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5) model = _create_vision_transformer( 'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2521,7 +2574,7 @@ def vit_giant_patch14_dinov2(pretrained: bool = False, 
**kwargs) -> VisionTransf # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 model_args = dict( patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, - mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU + mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU ) model = _create_vision_transformer( 'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 59b354fb3d..1171c8b939 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -11,21 +11,22 @@ """ import logging from functools import partial -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from torch.jit import Final - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\ +from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \ Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn +from torch.jit import Final + from ._builder import build_model_with_cfg +from ._features import feature_take_indices +from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model -from ._features_fx import register_notrace_function # model_registry will add each entrypoint fn to this __all__ = ['VisionTransformerSAM'] @@ -343,8 +344,7 @@ def __init__( attn_drop_rate: float = 0., drop_path_rate: float = 0., weight_init: str = '', - embed_layer: Callable = partial( - PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False), + embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False), norm_layer: Optional[Callable] = nn.LayerNorm, act_layer: Optional[Callable] = nn.GELU, block_fn: Callable = Block, @@ -469,6 +469,8 @@ def __init__( rope=self.rope_window if i not in global_attn_indexes else self.rope_global, ) for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] if neck_chans: self.neck = nn.Sequential( @@ -536,6 +538,52 @@ def get_classifier(self): def reset_classifier(self, num_classes=0, global_pool=None): self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.' 
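# NOTE: a sketch of calling the shared interface above directly; not from the
# patch itself. The registry name is illustrative, and the argument is named n
# at this point in the series (it is renamed to indices in patch 4 below):
import timm
import torch

model = timm.create_model('vit_base_patch16_224').eval()
x = torch.randn(1, 3, 224, 224)
final, intermediates = model.forward_intermediates(x, n=(-2, -1))  # final features plus last two blocks
feats = model.forward_intermediates(x, n=2, features_only=True)    # the same intermediates, list only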
+ intermediates = [] + num_blocks = len(self.blocks) + if n is None: + n = num_blocks + take_indices, max_index = feature_take_indices(n, num_blocks) + + # forward pass, collect intermediates + x = self.patch_embed(x) + if self.pos_embed is not None: + # dynamically resize abs pos embedding if needed + x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3]) + x = self.pos_drop(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # make output BCHW + if norm: + # norm is intertwined with neck convs so apply both, changes the dim + # FIXME only apply to final? Need experiments + intermediates.append(self.neck(x.permute(0, 3, 1, 2))) + else: + intermediates.append(x.permute(0, 3, 1, 2)) + + if features_only: + return intermediates + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x, intermediates + def forward_features(self, x): x = self.patch_embed(x) if self.pos_embed is not None: @@ -618,15 +666,13 @@ def _cfg(url='', **kwargs): def _create_vision_transformer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError( - 'features_only not implemented for Vision Transformer models.') - + out_indices = kwargs.pop('out_indices', 3) return build_model_with_cfg( VisionTransformerSAM, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) From ef9c6fb84633d82664372e5435c758300b748c16 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Apr 2024 21:54:59 -0700 Subject: [PATCH 3/5] forward_head(), consistent pre_logits handling to reduce likelihood of people manually replacing .head module having issues --- timm/models/byobnet.py | 2 +- timm/models/cspnet.py | 2 +- timm/models/focalnet.py | 2 +- timm/models/gcvit.py | 2 +- timm/models/maxxvit.py | 2 +- timm/models/regnet.py | 2 +- timm/models/resnetv2.py | 2 +- timm/models/vovnet.py | 2 +- timm/models/xception_aligned.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index a504b7262b..a2ff00958c 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -1251,7 +1251,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 21b2cd344a..d02acfb06d 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -710,7 +710,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index b8d90db0eb..07410da4ff 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -464,7 +464,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff 
--git a/timm/models/gcvit.py b/timm/models/gcvit.py index 29536a7dd2..653bc370aa 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -501,7 +501,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 6283443ce5..cdddf61917 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -1258,7 +1258,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 4ece9f4c01..bc73f54067 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -525,7 +525,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 63fb203326..017c32964b 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -469,7 +469,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index f4a06065e8..8e9d1679d2 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -259,7 +259,7 @@ def forward_features(self, x): return self.stages(x) def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index e4b284255c..1656e72bf1 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -286,7 +286,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) From 4b2565e4cb6068b3f38ee2a4464e55b2e6f47946 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 10 Apr 2024 15:10:09 -0700 Subject: [PATCH 4/5] More forward_intermediates() / FeatureGetterNet work * include relpos vit * refactor reduction / size calcs so hybrid vits work and dynamic_img_size works * fix -ve feature indices when pruning * fix mvitv2 w/ class token * refine naming * add tests --- tests/test_models.py | 71 +++++++++++++++++ timm/layers/patch_embed.py | 20 ++++- timm/models/_features.py | 29 +++++-- timm/models/beit.py | 27 +++---- timm/models/eva.py | 27 +++---- timm/models/mvitv2.py | 26 +++++-- timm/models/twins.py | 22 ++++-- timm/models/vision_transformer.py | 31 ++++---- timm/models/vision_transformer_hybrid.py | 30 +++++++- timm/models/vision_transformer_relpos.py | 98 
++++++++++++++++++++++-- timm/models/vision_transformer_sam.py | 44 +++++++++-- 11 files changed, 339 insertions(+), 86 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index a6411a7856..c9d7f20cd7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -47,6 +47,11 @@ torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(False) +# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper +FEAT_INTER_FILTERS = [ + 'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*' +] + # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', @@ -380,6 +385,72 @@ def test_model_forward_features(model_name, batch_size): assert not torch.isnan(o).any() +@pytest.mark.features +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, include_tags=True)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_intermediates_features(model_name, batch_size): + """Run a single forward pass with each model in feature extraction mode""" + model = create_model(model_name, pretrained=False, features_only=True) + model.eval() + print(model.feature_info.out_indices) + expected_channels = model.feature_info.channels() + expected_reduction = model.feature_info.reduction() + + input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE) + if max(input_size) > MAX_FFEAT_SIZE: + pytest.skip("Fixed input size model > limit.") + output_fmt = getattr(model, 'output_fmt', 'NCHW') + feat_axis = get_channel_dim(output_fmt) + spatial_axis = get_spatial_dim(output_fmt) + import math + + outputs = model(torch.randn((batch_size, *input_size))) + assert len(expected_channels) == len(outputs) + spatial_size = input_size[-2:] + for e, r, o in zip(expected_channels, expected_reduction, outputs): + print(o.shape) + assert e == o.shape[feat_axis] + assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1 + assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1 + assert o.shape[0] == batch_size + assert not torch.isnan(o).any() + + +@pytest.mark.features +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, include_tags=True)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_intermediates(model_name, batch_size): + """Run a single forward pass with each model in feature extraction mode""" + model = create_model(model_name, pretrained=False) + model.eval() + feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info)) + expected_channels = feature_info.channels() + expected_reduction = feature_info.reduction() + assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 + + input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE) + if max(input_size) > MAX_FFEAT_SIZE: + pytest.skip("Fixed input size model > limit.") + output_fmt = getattr(model, 'output_fmt', 'NCHW') + feat_axis = get_channel_dim(output_fmt) + spatial_axis = get_spatial_dim(output_fmt) + import math + + output, intermediates = model.forward_intermediates( + torch.randn((batch_size, *input_size)), + ) + assert len(expected_channels) == len(intermediates) + spatial_size = input_size[-2:] + for e, r, o in zip(expected_channels, expected_reduction, intermediates): + assert e == 
o.shape[feat_axis] + assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1 + assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1 + assert o.shape[0] == batch_size + assert not torch.isnan(o).any() + + def _create_fx_model(model, train=False): # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index 5970828528..3f148944e2 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -9,6 +9,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ import logging +import math from typing import Callable, List, Optional, Tuple, Union import torch @@ -65,6 +66,21 @@ def __init__( self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: + if as_scalar: + return max(self.patch_size) + else: + return self.patch_size + + def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: + """ Get grid (feature) size for given image size taking account of dynamic padding. + NOTE: must be torchscript compatible so using fixed tuple indexing + """ + if self.dynamic_img_pad: + return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1]) + else: + return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] + def forward(self, x): B, C, H, W = x.shape if self.img_size is not None: @@ -127,13 +143,13 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]: _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).") x = self.proj(x) - grid_size = x.shape[-2:] + feat_size = x.shape[-2:] if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC elif self.output_fmt != Format.NCHW: x = nchw_to(x, self.output_fmt) x = self.norm(x) - return x, grid_size + return x, feat_size def resample_patch_embed( diff --git a/timm/models/_features.py b/timm/models/_features.py index fa108798ad..565f1dd8ac 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -26,7 +26,10 @@ ] -def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[Set[int], int]: +def _take_indices( + num_blocks: int, + n: Optional[Union[int, List[int], Tuple[int]]], +) -> Tuple[Set[int], int]: if isinstance(n, int): assert n >= 0 take_indices = {x for x in range(num_blocks - n, num_blocks)} @@ -35,7 +38,10 @@ def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tupl return take_indices, max(take_indices) -def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]: +def _take_indices_jit( + num_blocks: int, + n: Union[int, List[int], Tuple[int]], +) -> Tuple[List[int], int]: if isinstance(n, int): assert n >= 0 take_indices = [num_blocks - n + i for i in range(n)] @@ -47,12 +53,17 @@ def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> return take_indices, max(take_indices) -def feature_take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]: +def feature_take_indices( + num_blocks: int, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, +) -> Tuple[List[int], int]: + if indices is None: 
+ indices = num_blocks # all blocks if None if torch.jit.is_scripting(): - return _take_indices_jit(n, num_blocks) + return _take_indices_jit(num_blocks, indices) else: # NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno - return _take_indices(n, num_blocks) + return _take_indices(num_blocks, indices) def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: @@ -443,10 +454,12 @@ def __init__( """ super().__init__() if prune and hasattr(model, 'prune_intermediate_layers'): - model.prune_intermediate_layers( + # replace out_indices after they've been normalized, -ve indices will be invalid after prune + out_indices = model.prune_intermediate_layers( out_indices, prune_norm=not norm, ) + out_indices = list(out_indices) self.feature_info = _get_feature_info(model, out_indices) self.model = model self.out_indices = out_indices @@ -458,9 +471,9 @@ def __init__( def forward(self, x): features = self.model.forward_intermediates( x, - n=self.out_indices, + indices=self.out_indices, norm=self.norm, output_fmt=self.output_fmt, - features_only=True, + intermediates_only=True, ) return features diff --git a/timm/models/beit.py b/timm/models/beit.py index 46d460d3d2..19bf2c587a 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -302,6 +302,7 @@ def __init__( embed_dim=embed_dim, ) num_patches = self.patch_embed.num_patches + r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) @@ -334,7 +335,7 @@ def __init__( ) for i in range(depth)]) self.feature_info = [ - dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) @@ -403,33 +404,30 @@ def reset_classifier(self, num_classes, global_pool=None): def forward_intermediates( self, x: torch.Tensor, - n: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = True, output_fmt: str = 'NCHW', - features_only: bool = False, + intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: x: Input image tensor - n: Take last n blocks if n is an int, if in is a sequence, select by matching indices + indices: Take last n blocks if an int, if is a sequence, select by matching indices return_prefix_tokens: Return both prefix and spatial intermediate tokens norm: Apply norm layer to all intermediates stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs - features_only: Only return intermediate features + intermediates_only: Only return intermediate features Returns: """ assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' 
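# --- Editor's example (not part of the patch) --------------------------------
# A minimal standalone sketch of the index-resolution rule implemented by
# feature_take_indices() above: an int selects the last n blocks, None selects
# all blocks, and a sequence selects matching (possibly negative) indices.
# Names below are illustrative only, not timm API.
from typing import List, Optional, Sequence, Tuple, Union


def resolve_take_indices(
        num_blocks: int,
        indices: Optional[Union[int, Sequence[int]]] = None,
) -> Tuple[List[int], int]:
    if indices is None:
        indices = num_blocks  # take all blocks
    if isinstance(indices, int):
        assert 0 <= indices <= num_blocks
        take = list(range(num_blocks - indices, num_blocks))
    else:
        # negative indices count back from the final block
        take = [num_blocks + i if i < 0 else i for i in indices]
    return take, max(take)


assert resolve_take_indices(12, 3) == ([9, 10, 11], 11)    # last 3 blocks
assert resolve_take_indices(12, (0, -1)) == ([0, 11], 11)  # first and last
assert resolve_take_indices(4) == ([0, 1, 2, 3], 3)        # all blocks
# ------------------------------------------------------------------------------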
reshape = output_fmt == 'NCHW' intermediates = [] - num_blocks = len(self.blocks) - if n is None: - n = num_blocks - take_indices, max_index = feature_take_indices(n, num_blocks) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) # forward pass B, _, height, width = x.shape @@ -455,16 +453,14 @@ def forward_intermediates( prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] if reshape: - # reshape == True => BCHW output format - patch_size = self.patch_embed.patch_size - H = int(math.ceil(height / patch_size[0])) - W = int(math.ceil(width / patch_size[1])) + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] if not torch.jit.is_scripting() and return_prefix_tokens: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(intermediates, prefix_tokens)) - if features_only: + if intermediates_only: return intermediates x = self.norm(x) @@ -479,13 +475,14 @@ def prune_intermediate_layers( ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(n, len(self.blocks)) + take_indices, max_index = feature_take_indices(len(self.blocks), n) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.fc_norm = nn.Identity() self.head = nn.Identity() + return take_indices def forward_features(self, x): x = self.patch_embed(x) diff --git a/timm/models/eva.py b/timm/models/eva.py index fe121b0032..416bc9511d 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -424,6 +424,7 @@ def __init__( **embed_args, ) num_patches = self.patch_embed.num_patches + r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None @@ -470,7 +471,7 @@ def __init__( ) for i in range(depth)]) self.feature_info = [ - dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) @@ -564,30 +565,27 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: def forward_intermediates( self, x: torch.Tensor, - n: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = True, output_fmt: str = 'NCHW', - features_only: bool = False, + intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. 
Args: x: Input image tensor - n: Take last n blocks if n is an int, if in is a sequence, select by matching indices + indices: Take last n blocks if an int, if is a sequence, select by matching indices return_prefix_tokens: Return both prefix and spatial intermediate tokens norm: Apply norm layer to all intermediates stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs - features_only: Only return intermediate features + intermediates_only: Only return intermediate features """ assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' intermediates = [] - num_blocks = len(self.blocks) - if n is None: - n = num_blocks - take_indices, max_index = feature_take_indices(n, num_blocks) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) # forward pass B, _, height, width = x.shape @@ -608,16 +606,14 @@ def forward_intermediates( prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] if reshape: - # reshape == True => BCHW output format - patch_size = self.patch_embed.patch_size - H = int(math.ceil(height / patch_size[0])) - W = int(math.ceil(width / patch_size[1])) + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] if not torch.jit.is_scripting() and return_prefix_tokens: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(intermediates, prefix_tokens)) - if features_only: + if intermediates_only: return intermediates x = self.norm(x) @@ -632,13 +628,14 @@ def prune_intermediate_layers( ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(n, len(self.blocks)) + take_indices, max_index = feature_take_indices(len(self.blocks), n) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.fc_norm = nn.Identity() self.head = nn.Identity() + return take_indices def forward_features(self, x): x = self.patch_embed(x) diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 579aa87e62..5d14646888 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -837,19 +837,28 @@ def reset_classifier(self, num_classes, global_pool=None): def forward_intermediates( self, x: torch.Tensor, - n: Union[int, List[int], Tuple[int]] = None, + indices: Union[int, List[int], Tuple[int]] = None, norm: bool = False, stop_early: bool = True, output_fmt: str = 'NCHW', - features_only: bool = False, + intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.' 
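# --- Editor's example (not part of the patch) --------------------------------
# Calling convention after the renames in these hunks: `indices` replaces `n`
# and `intermediates_only` replaces `features_only`. A sketch only; it assumes
# a timm build that already includes this patch.
import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # take the last two blocks, reshaped to BCHW feature maps
    feats = model.forward_intermediates(
        x,
        indices=2,
        output_fmt='NCHW',
        intermediates_only=True,
    )
for f in feats:
    print(f.shape)  # e.g. torch.Size([1, 768, 14, 14])
# ------------------------------------------------------------------------------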
reshape = output_fmt == 'NCHW' intermediates = [] - num_stages = len(self.stages) # block list is two-tiered, first tier == stage - if n is None: - n = num_stages - take_indices, max_index = feature_take_indices(n, num_stages) + take_indices, max_index = feature_take_indices(len(self.stages), indices) # FIXME slice block/pos_block if < max @@ -869,10 +878,13 @@ def forward_intermediates( else: x_inter = x if reshape: + if self.cls_token is not None: + # possible to allow return of class tokens, TBD + x_inter = x_inter[:, 1:] x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2) intermediates.append(x_inter) - if features_only: + if intermediates_only: return intermediates x = self.norm(x) diff --git a/timm/models/twins.py b/timm/models/twins.py index feba8e37b1..24f7b80102 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -407,18 +407,26 @@ def _init_weights(self, m): def forward_intermediates( self, x: torch.Tensor, - n: Union[int, List[int], Tuple[int]] = None, + indices: Union[int, List[int], Tuple[int]] = None, norm: bool = False, stop_early: bool = True, output_fmt: str = 'NCHW', - features_only: bool = False, + intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ assert output_fmt == 'NCHW', 'Output shape for Twins must be NCHW.' intermediates = [] - num_stages = len(self.blocks) # block list is two-tiered, first tier == stage - if n is None: - n = num_stages - take_indices, max_index = feature_take_indices(n, num_stages) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) # FIXME slice block/pos_block if < max @@ -444,7 +452,7 @@ def forward_intermediates( x_feat = self.norm(x) if norm else x intermediates.append(x_feat.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()) - if features_only: + if intermediates_only: return intermediates x = self.norm(x) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 24225206e0..c20564baa2 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -27,7 +27,7 @@ import math from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List +from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List try: from typing import Literal except ImportError: @@ -489,6 +489,7 @@ def __init__( **embed_args, ) num_patches = self.patch_embed.num_patches + r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None @@ -522,7 +523,7 @@ def __init__( ) for i in range(depth)]) self.feature_info = [ - dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # 
Classifier Head @@ -634,33 +635,30 @@ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: def forward_intermediates( self, x: torch.Tensor, - n: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = True, output_fmt: str = 'NCHW', - features_only: bool = False, + intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: x: Input image tensor - n: Take last n blocks if n is an int, if in is a sequence, select by matching indices + indices: Take last n blocks if int, all if None, select matching indices if sequence return_prefix_tokens: Return both prefix and spatial intermediate tokens norm: Apply norm layer to all intermediates stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs - features_only: Only return intermediate features + intermediates_only: Only return intermediate features Returns: """ assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' intermediates = [] - num_blocks = len(self.blocks) - if n is None: - n = num_blocks - take_indices, max_index = feature_take_indices(n, num_blocks) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) # forward pass B, _, height, width = x.shape @@ -684,16 +682,14 @@ def forward_intermediates( prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] if reshape: - # reshape == True => BCHW output format - patch_size = self.patch_embed.patch_size - H = int(math.ceil(height / patch_size[0])) - W = int(math.ceil(width / patch_size[1])) + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] if not torch.jit.is_scripting() and return_prefix_tokens: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(intermediates, prefix_tokens)) - if features_only: + if intermediates_only: return intermediates x = self.norm(x) @@ -708,7 +704,7 @@ def prune_intermediate_layers( ): """ Prune layers not required for specified intermediates. 
""" - take_indices, max_index = feature_take_indices(n, len(self.blocks)) + take_indices, max_index = feature_take_indices(len(self.blocks), n) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() @@ -717,6 +713,7 @@ def prune_intermediate_layers( self.attn_pool = None self.fc_norm = nn.Identity() self.head = nn.Identity() + return take_indices def get_intermediate_layers( self, @@ -734,7 +731,7 @@ def get_intermediate_layers( return_prefix_tokens=return_prefix_tokens, norm=norm, output_fmt='NCHW' if reshape else 'NLC', - features_only=True, + intermediates_only=True, ) def forward_features(self, x: torch.Tensor) -> torch.Tensor: diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 16e3d1b7c3..25dd9c275a 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -13,8 +13,9 @@ Hacked together by / Copyright 2020, Ross Wightman """ +import math from functools import partial -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -41,6 +42,7 @@ def __init__( img_size=224, patch_size=1, feature_size=None, + feature_ratio=None, in_chans=3, embed_dim=768, bias=True, @@ -68,15 +70,20 @@ def __init__( feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) + feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)]) else: + feature_size = to_2tuple(feature_size) + feature_ratio = to_2tuple(feature_ratio or 16) if hasattr(self.backbone, 'feature_info'): feature_dim = self.backbone.feature_info.channels()[-1] else: feature_dim = self.backbone.num_features if not dynamic_img_pad: assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 - self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) + self.feature_size = feature_size + self.feature_ratio = feature_ratio + self.grid_size = tuple([f // p for f, p in zip(self.feature_size, self.patch_size)]) self.num_patches = self.grid_size[0] * self.grid_size[1] if output_fmt is not None: self.flatten = False @@ -90,6 +97,25 @@ def __init__( self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: + total_reduction = ( + self.feature_ratio[0] * self.patch_size[0], + self.feature_ratio[1] * self.patch_size[1] + ) + if as_scalar: + return max(total_reduction) + else: + return total_reduction + + def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: + """ Get feature grid size taking account dynamic padding and backbone network feat reduction + """ + feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1]) + if self.dynamic_img_pad: + return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1]) + else: + return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1] + def forward(self, x): x = self.backbone(x) if isinstance(x, (list, tuple)): diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index ea8cf0ea1d..8461ada39a 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -7,7 +7,7 @@ import logging import math from functools import partial -from typing import Optional, Tuple, Type, Union +from typing import List, Optional, Tuple, Type, Union 
try: from typing import Literal @@ -22,6 +22,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply from ._registry import generate_default_cfgs, register_model from .vision_transformer import get_init_weights_vit @@ -297,6 +298,7 @@ def __init__( embed_dim=embed_dim, ) feat_size = self.patch_embed.grid_size + r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens) if rel_pos_type.startswith('mlp'): @@ -332,6 +334,8 @@ def __init__( act_layer=act_layer, ) for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity() # Classifier Head @@ -384,6 +388,88 @@ def reset_classifier(self, num_classes: int, global_pool=None): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + x = self.patch_embed(x) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x, shared_rel_pos=shared_rel_pos) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. 
class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + if reshape: + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + if not torch.jit.is_scripting() and return_prefix_tokens: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + n: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), n) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.head = nn.Identity() + return take_indices + def forward_features(self, x): x = self.patch_embed(x) if self.cls_token is not None: @@ -412,10 +498,12 @@ def forward(self, x): def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs) + out_indices = kwargs.pop('out_indices', 3) + model = build_model_with_cfg( + VisionTransformerRelPos, variant, pretrained, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) return model diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 1171c8b939..d4d974a0d5 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -408,6 +408,8 @@ def __init__( bias=not pre_norm, # disable bias if pre-norm is used ) grid_size = self.patch_embed.grid_size + r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size + if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim)) @@ -470,7 +472,7 @@ def __init__( ) for i in range(depth)]) self.feature_info = [ - dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] if neck_chans: self.neck = nn.Sequential( @@ -541,18 +543,27 @@ def reset_classifier(self, num_classes=0, global_pool=None): def forward_intermediates( self, x: torch.Tensor, - n: Union[int, List[int], Tuple[int]] = None, + indices: Union[int, List[int], Tuple[int]] = None, norm: bool = False, stop_early: bool = True, output_fmt: str = 'NCHW', - features_only: bool = False, + intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
+ + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.' intermediates = [] - num_blocks = len(self.blocks) - if n is None: - n = num_blocks - take_indices, max_index = feature_take_indices(n, num_blocks) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) # forward pass, collect intermediates x = self.patch_embed(x) @@ -577,13 +588,30 @@ def forward_intermediates( else: intermediates.append(x.permute(0, 3, 1, 2)) - if features_only: + if intermediates_only: return intermediates x = self.neck(x.permute(0, 3, 1, 2)) return x, intermediates + def prune_intermediate_layers( + self, + n: Union[int, List[int], Tuple[int]] = None, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), n) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + # neck is being treated as equivalent to final norm here + self.neck = nn.Identity() + if prune_head: + self.head = nn.Identity() + return take_indices + def forward_features(self, x): x = self.patch_embed(x) if self.pos_embed is not None: From fe3cf542faad532b8690421a4c374338bd95c7f9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 10 Apr 2024 21:14:02 -0700 Subject: [PATCH 5/5] Fix / improve tests for features --- tests/test_models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index c9d7f20cd7..21f37a76a3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -55,8 +55,8 @@ # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', - 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', - 'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', + 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*' ] NUM_NON_STD = len(NON_STD_FILTERS) @@ -356,7 +356,7 @@ def test_model_forward_torchscript(model_name, batch_size): @pytest.mark.features @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_features(model_name, batch_size): """Run a single forward pass with each model in feature extraction mode""" @@ -364,7 +364,7 @@ def test_model_forward_features(model_name, batch_size): model.eval() expected_channels = model.feature_info.channels() expected_reduction = model.feature_info.reduction() - assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 + assert 
len(expected_channels) >= 3 # all models here should have at least 3 default feat levels input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE) if max(input_size) > MAX_FFEAT_SIZE: @@ -387,7 +387,7 @@ def test_model_forward_features(model_name, batch_size): @pytest.mark.features @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, include_tags=True)) +@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_intermediates_features(model_name, batch_size): """Run a single forward pass with each model in feature extraction mode""" @@ -419,7 +419,7 @@ def test_model_forward_intermediates_features(model_name, batch_size): @pytest.mark.features @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, include_tags=True)) +@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_intermediates(model_name, batch_size): """Run a single forward pass with each model in feature extraction mode"""
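# --- Editor's example (not part of the patch) --------------------------------
# What the updated tests exercise end to end: with this patch, ViT-family
# models can be built with features_only=True, which routes through the new
# 'getter' feature class (FeatureGetterNet -> forward_intermediates()).
# A sketch only; it assumes a timm build that already includes this patch.
import torch
import timm

model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    features_only=True,
    out_indices=3,  # take the last three blocks
).eval()
print(model.feature_info.channels())   # e.g. [768, 768, 768]
print(model.feature_info.reduction())  # e.g. [16, 16, 16]

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # NCHW maps, e.g. torch.Size([1, 768, 14, 14])
# ------------------------------------------------------------------------------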