# Zhipeng: init project (966d9af)
"""
CapsNeck: efficient capsule-style neck blocks for Ultralytics YAML models.
Design intent:
- Keep capsule semantics (type/channel grouping + routing-style fusion).
- Stay lightweight and export-friendly for detection training/inference.
- Avoid expensive iterative EM/dynamic routing inside the neck path.
This neck is "capsule-style" rather than a full matrix-capsule network:
1) CapsProj : CNN feature -> packed capsules (K types * D dims)
2) CapsAlign : scale alignment between pyramid levels (no global context)
3) CapsRoute : efficient self-routing proxy across sources (softmax source gating)
4) CapsDecode: packed capsules -> standard feature map for Detect
5) CapsuleTap: optional pass-through cache hook for analysis/aux losses
Note:
- Routing here is source-level and single-step by default (iters=1), chosen for speed.
- If stronger capsule routing is needed, it should be added in the head where cost is lower.
"""
from __future__ import annotations
from typing import List, Optional, Tuple, Union
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.nn.modules import C3k2, Conv, DWConv
# -------------------------
# 1) CapsProj
# -------------------------
class CapsProj(nn.Module):
    """Project a standard feature map into packed capsule channels.

    A single ``C3k2`` block maps the input to ``K`` capsule types, each
    occupying ``D`` pose channels plus one activation channel.

    Input:  x [B, C, H, W]
    Output: u [B, K*(D+1), H, W]

    Args:
        c1: input channel count.
        K: number of capsule types.
        D: capsule pose dimension per type.

    Raises:
        ValueError: if ``K`` or ``D`` is not positive.
    """

    def __init__(self, c1: int, K: int = 4, D: int = 16):
        super().__init__()
        self.K = int(K)
        self.D = int(D)
        # Fail early with a clear message instead of building a zero/negative-width block.
        if min(self.K, self.D) <= 0:
            raise ValueError('CapsProj expects positive K/D values.')
        # Each capsule type carries D pose channels and 1 activation channel.
        self.c_out = self.K * (self.D + 1)
        # Use a single C3k2 block as the capsule projection operator.
        self.map = C3k2(c1, self.c_out, n=1, c3k=False, e=0.5, g=1, shortcut=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the packed capsule tensor produced by the projection block."""
        return self.map(x)
# -------------------------
# 2) CapsAlign (no context)
# -------------------------
class CapsAlign(nn.Module):
    """Bring packed capsules from one pyramid level to another.

    One op is applied per level of difference between source and target:
    - source deeper than target: ``nn.Upsample(scale_factor=2, mode='nearest')``
    - source shallower than target: a 3x3 stride-2 ``Conv`` with ``down_groups`` groups
    - same level: the input is returned untouched.

    Args:
        c1: input/output channel count.
        src_level: source pyramid level in {3,4,5}.
        tgt_level: target pyramid level in {3,4,5}.
        down_groups: groups for the downsample Conv; must divide ``c1``.
            Use the capsule-type count K to keep each capsule block isolated.

    Raises:
        ValueError: on a level outside {3,4,5} or a ``down_groups`` that is
            < 1 or does not divide ``c1``.
    """

    def __init__(self, c1: int, src_level: int, tgt_level: int, down_groups: int = 1):
        super().__init__()
        self.c1 = int(c1)
        self.src_level = int(src_level)
        self.tgt_level = int(tgt_level)
        self.down_groups = int(down_groups)

        allowed_levels = (3, 4, 5)
        if not (self.src_level in allowed_levels and self.tgt_level in allowed_levels):
            raise ValueError("CapsAlign levels must be in {3,4,5}.")
        # The ">= 1" test runs first so the modulo never divides by zero.
        groups_ok = self.down_groups >= 1 and self.c1 % self.down_groups == 0
        if not groups_ok:
            raise ValueError(f"CapsAlign down_groups={self.down_groups} must divide c1={self.c1}.")

        # Positive gap: source is coarser (go up); negative gap: source is finer (go down).
        level_gap = self.src_level - self.tgt_level
        if level_gap == 0:
            self.mode = 'identity'
            stages = []
        elif level_gap > 0:
            self.mode = 'up'
            stages = [nn.Upsample(scale_factor=2, mode='nearest') for _ in range(level_gap)]
        else:
            self.mode = 'down'
            stages = [Conv(self.c1, self.c1, 3, 2, g=self.down_groups) for _ in range(-level_gap)]
        self.ops = nn.ModuleList(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the configured resampling ops in sequence (none for identity)."""
        if self.mode != 'identity':
            for stage in self.ops:
                x = stage(x)
        return x
# -------------------------
# 3) CapsRoute (light, parser-friendly)
# -------------------------
class ConvSelfRouting(nn.Module):
    """Grouped-conv gating over the capsule types of a packed tensor.

    The input is mixed by a grouped conv (one group per capsule type), a
    grouped 1x1 conv then yields one logit per type, and the mixed capsules
    are scaled by the softmax of those logits taken across types.

    The output keeps the input layout ``[B, K_in*(P_in+1), H, W]``;
    ``K_out``/``P_out`` are stored and only used to compute ``c_out``.

    Args:
        K_in: input capsule type count.
        P_in: input pose dimension.
        K_out: output capsule type count.
        P_out: output pose dimension.
        kernel_size: grouped conv kernel for local capsule mixing.

    Raises:
        ValueError: if any K/P value is not positive (constructor) or the
            input channel count differs from ``K_in*(P_in+1)`` (forward).
    """

    def __init__(self, K_in: int, P_in: int, K_out: int, P_out: int, kernel_size: int = 3):
        super().__init__()
        self.K_in, self.P_in = int(K_in), int(P_in)
        self.K_out, self.P_out = int(K_out), int(P_out)
        if any(v <= 0 for v in (self.K_in, self.P_in, self.K_out, self.P_out)):
            raise ValueError('ConvSelfRouting expects positive K/P values.')
        self.c_in = self.K_in * (self.P_in + 1)
        self.c_out = self.K_out * (self.P_out + 1)

        ksize = int(kernel_size)
        # groups=K_in keeps every capsule type in its own conv group.
        self.mix = nn.Conv2d(
            self.c_in,
            self.c_in,
            kernel_size=ksize,
            stride=1,
            padding=ksize // 2,
            groups=self.K_in,
            bias=False,
        )
        # One gating logit per capsule type, computed from that type's channels only.
        self.gate = nn.Conv2d(
            self.c_in,
            self.K_in,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=self.K_in,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate mixed capsules by a per-location softmax over capsule types."""
        # x: [B,C,H,W], C = K_in*(P_in+1)
        batch, channels, height, width = x.shape
        if channels != self.c_in:
            raise ValueError(f'ConvSelfRouting expected C={self.c_in}, got C={channels}')
        mixed = self.mix(x)
        type_weights = self.gate(mixed).reshape(batch, self.K_in, height, width).softmax(dim=1)
        per_type = mixed.reshape(batch, self.K_in, self.P_in + 1, height, width)
        # Broadcast each type's weight over its pose + act channels.
        gated = per_type * type_weights.unsqueeze(2)
        return gated.reshape(batch, self.c_in, height, width)
class SelfRouting(nn.Module):
    """Single-step pose-transform routing on a packed capsule tensor.

    Every input capsule type casts a vote for every output type through a
    learned pose transform; routing weights come from a learned linear gate
    on the input pose, softmaxed over output types and scaled by the
    sigmoid of the input activation.

    Args:
        K_in: input capsule type count.
        P_in: input pose dimension.
        K_out: output capsule type count.
        P_out: output pose dimension.

    Input:
        x: [B, K_in*(P_in+1), H, W]
    Output:
        y: [B, K_out*(P_out+1), H, W]

    Raises:
        ValueError: non-positive K/P (constructor) or wrong channel count (forward).
        TypeError: input is not 4-D (forward).
    """

    def __init__(self, K_in: int, P_in: int, K_out: int, P_out: int):
        super().__init__()
        self.K_in, self.P_in = int(K_in), int(P_in)
        self.K_out, self.P_out = int(K_out), int(P_out)
        if any(v <= 0 for v in (self.K_in, self.P_in, self.K_out, self.P_out)):
            raise ValueError('SelfRouting expects positive K/P values.')
        self.c_in = self.K_in * (self.P_in + 1)
        self.c_out = self.K_out * (self.P_out + 1)
        # Added to the routing normaliser so the division below never hits zero.
        self.eps = 1e-6

        # Pose transform per (input type, output type) pair.
        self.W_pose = nn.Parameter(torch.empty(self.K_in, self.K_out, self.P_in, self.P_out))
        nn.init.kaiming_uniform_(self.W_pose, a=math.sqrt(5))
        # Gate weight and bias start at zero, so initial routing logits are all zero
        # (uniform softmax over output types).
        self.W_gate = nn.Parameter(torch.zeros(self.K_in, self.K_out, self.P_in))
        self.b_gate = nn.Parameter(torch.zeros(1, self.K_in, self.K_out, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route packed input capsules to packed output capsules."""
        # x: [B, C, H, W], C = K_in*(P_in+1)
        if x.ndim != 4:
            raise TypeError(f'SelfRouting expects [B,C,H,W], got {tuple(x.shape)}')
        batch, channels, height, width = x.shape
        if channels != self.c_in:
            raise ValueError(f'SelfRouting expected C={self.c_in}, got C={channels}')

        # Packed layout is interleaved per type: [pose(P), act(1)].
        capsules = x.reshape(batch, self.K_in, self.P_in + 1, height, width)
        pose_in, act_logit = capsules.split([self.P_in, 1], dim=2)
        act_in = act_logit.sigmoid()  # [B, K_in, 1, H, W]

        # votes: [B, K_in, K_out, H, W, P_out]
        votes = torch.einsum('bkphw,kopq->bkohwq', pose_in, self.W_pose)
        # route_logits / route_prob: [B, K_in, K_out, H, W]
        route_logits = torch.einsum('bkphw,kop->bkohw', pose_in, self.W_gate) + self.b_gate
        route_prob = route_logits.softmax(dim=2)

        # Activation-weighted routing, normalised over input types.
        weighted = route_prob * act_in
        norm = weighted.sum(dim=1, keepdim=True) + self.eps  # [B, 1, K_out, H, W]
        coeff = weighted / norm

        # Weighted vote average -> [B, K_out, H, W, P_out] -> [B, K_out, P_out, H, W]
        pose_out = (coeff.unsqueeze(-1) * votes).sum(dim=1).permute(0, 1, 4, 2, 3)
        # The normaliser itself is emitted as the output activation: [B, K_out, 1, H, W]
        act_out = norm.squeeze(1).unsqueeze(2)

        # Keep interleaved packed output: [pose(P_out), act(1)] per capsule type.
        packed = torch.cat([pose_out, act_out], dim=2)
        return packed.reshape(batch, self.c_out, height, width)
class HybridRoute1(nn.Module):
    """Conv-heavy replacement for SelfRouting with lightweight capsule-aware gating.

    Pose channels are projected to the output capsule layout by a grouped
    1x1 ``Conv``; a 1x1 conv over the full packed input produces one sigmoid
    gate per output type that scales the projected pose; output activations
    come from a 1x1 conv over the input activation channels followed by a
    sigmoid.

    Args:
        K_in: input capsule type count.
        P_in: input pose dimension.
        K_out: output capsule type count.
        P_out: output pose dimension.

    Input:
        x: [B, K_in*(P_in+1), H, W]
    Output:
        y: [B, K_out*(P_out+1), H, W]

    Raises:
        ValueError: non-positive K/P (constructor) or wrong channel count (forward).
        TypeError: input is not 4-D (forward).
    """

    def __init__(self, K_in: int, P_in: int, K_out: int, P_out: int):
        super().__init__()
        self.K_in = int(K_in)
        self.P_in = int(P_in)
        self.K_out = int(K_out)
        self.P_out = int(P_out)
        # Same guard as SelfRouting: reject degenerate sizes before any layer is built.
        if min(self.K_in, self.P_in, self.K_out, self.P_out) <= 0:
            raise ValueError('HybridRoute1 expects positive K/P values.')
        self.c_in = self.K_in * (self.P_in + 1)
        self.c_out = self.K_out * (self.P_out + 1)

        pose_in = self.K_in * self.P_in
        pose_out = self.K_out * self.P_out
        # gcd(K_in, K_out) divides both pose_in and pose_out, so it is always a
        # valid group count; it is >= 1 because both K values are positive.
        vote_groups = math.gcd(self.K_in, self.K_out)
        self.vote_proj = Conv(pose_in, pose_out, 1, 1, g=vote_groups)
        self.gate_proj = nn.Conv2d(self.c_in, self.K_out, kernel_size=1, stride=1, padding=0, bias=True)
        self.act_proj = nn.Conv2d(self.K_in, self.K_out, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map packed input capsules to packed output capsules."""
        if x.ndim != 4:
            raise TypeError(f'HybridRoute1 expects [B,C,H,W], got {tuple(x.shape)}')
        b, c, h, w = x.shape
        if c != self.c_in:
            raise ValueError(f'HybridRoute1 expected C={self.c_in}, got C={c}')

        # Packed layout per type: [pose(P_in), act(1)].
        x_caps = x.reshape(b, self.K_in, self.P_in + 1, h, w)
        pose = x_caps[:, :, :self.P_in].reshape(b, self.K_in * self.P_in, h, w)
        act = x_caps[:, :, self.P_in].contiguous()  # [B, K_in, H, W], raw (no sigmoid)

        pose_votes = self.vote_proj(pose).reshape(b, self.K_out, self.P_out, h, w)
        # One gate per output type, broadcast over that type's pose channels.
        gate = self.gate_proj(x).sigmoid().unsqueeze(2)
        pose_out = pose_votes * gate
        act_out = self.act_proj(act).sigmoid().unsqueeze(2)  # [B, K_out, 1, H, W]

        # Re-pack as [pose(P_out), act(1)] per output type.
        out = torch.cat([pose_out, act_out], dim=2).reshape(b, self.c_out, h, w)
        return out
class CapsRoute(nn.Module):
    """Fuse several packed-capsule sources by concatenation and routing.

    Sources are concatenated along the channel axis (no pre-projection),
    passed through a grouped 3x3 ``Conv`` (one group per concatenated
    capsule type), routed to ``K_out`` types by ``SelfRouting`` and finished
    with a grouped 3x3 ``Conv`` (one group per output type).

    Args:
        K_in: list of input capsule type counts per source.
        P_in: list of input pose dimensions per source; all must be identical
            because the sources are concatenated directly.
        K_out: target output capsule type count.
        P_out: target output pose dimension.
        kernel_size, pre_k, post_k, pre_groups, post_groups: accepted but not
            used by this implementation.

    Raises:
        ValueError: fewer than two sources, mismatched list lengths,
            non-positive K/P values, or differing ``P_in`` entries.
    """

    def __init__(
        self,
        K_in: Union[List[int], Tuple[int, ...]],
        P_in: Union[List[int], Tuple[int, ...]],
        K_out: int,
        P_out: int,
        kernel_size: int = 3,
        pre_k: int = 3,
        post_k: int = 3,
        pre_groups: Optional[int] = None,
        post_groups: Optional[int] = None,
    ):
        super().__init__()
        self.K_in_list = [int(v) for v in K_in]
        self.P_in_list = [int(v) for v in P_in]

        source_count = len(self.K_in_list)
        if source_count < 2 or source_count != len(self.P_in_list):
            raise ValueError('CapsRoute expects K_in/P_in lists with same length >= 2.')
        if min(self.K_in_list + self.P_in_list) <= 0:
            raise ValueError('CapsRoute expects positive K_in/P_in values.')
        # Direct capsule concat requires a shared pose dimension.
        if len(set(self.P_in_list)) != 1:
            raise ValueError('CapsRoute direct concat requires all P_in to be identical.')

        self.num_sources = source_count
        self.P_cat = int(self.P_in_list[0])
        self.K_cat = int(sum(self.K_in_list))
        self.c_cat = self.K_cat * (self.P_cat + 1)

        self.K_out = int(K_out)
        self.P_out = int(P_out)
        if min(self.K_out, self.P_out) <= 0:
            raise ValueError('CapsRoute expects positive K_out/P_out values.')
        self.c_out = self.K_out * (self.P_out + 1)

        # Grouped Conv before routing (used in place of ConvSelfRouting):
        # C = K_cat * (P_cat + 1), groups = K_cat.
        self.conv_route = Conv(self.c_cat, self.c_cat, 3, 1, g=self.K_cat)
        self.route1 = SelfRouting(K_in=self.K_cat, P_in=self.P_cat, K_out=self.K_out, P_out=self.P_out)
        # Grouped Conv after routing: C = K_out * (P_out + 1), groups = K_out.
        self.spagg = Conv(self.c_out, self.c_out, 3, 1, g=self.K_out)

    def forward(self, xs: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]) -> torch.Tensor:
        """Validate, concatenate and route the sources.

        Args:
            xs: one tensor per source; source ``i`` must have
                ``K_in[i]*(P_in[i]+1)`` channels and all must share H, W.

        Returns:
            Tensor produced by ``spagg(route1(conv_route(cat(xs))))``.

        Raises:
            TypeError: ``xs`` is not a list/tuple.
            ValueError: wrong source count, channel count or spatial size.
        """
        if not isinstance(xs, (list, tuple)):
            raise TypeError(f'CapsRoute expects list/tuple inputs, got {type(xs)}')
        if len(xs) != self.num_sources:
            raise ValueError(f'CapsRoute expected {self.num_sources} sources, got {len(xs)}')

        # The first source defines the reference spatial size.
        ref_hw = (int(xs[0].shape[-2]), int(xs[0].shape[-1]))
        for idx, (src, k, p) in enumerate(zip(xs, self.K_in_list, self.P_in_list)):
            want_c = k * (p + 1)
            got_c = int(src.shape[1])
            if got_c != want_c:
                raise ValueError(f'CapsRoute source-{idx} expected C={want_c} from K_in/P_in, got C={got_c}')
            if (int(src.shape[-2]), int(src.shape[-1])) != ref_hw:
                raise ValueError('CapsRoute inputs must share H,W. Use CapsAlign before routing.')

        x_cat = torch.cat(list(xs), dim=1)  # [B, K_cat*(P+1), H, W]
        return self.spagg(self.route1(self.conv_route(x_cat)))
class CapsRoutev2(CapsRoute):
    """CapsRoute with per-capsule pose refinement and act residual update.

    The forward pass concatenates the sources, applies the inherited
    ``conv_route`` and ``route1``, splits the routed tensor into pose and
    activation channels, refines the pose with a grouped ``C3k2`` block and
    adds a pose-derived delta to the activation before re-packing.

    Optional per-stage wall-clock profiling is switched on by setting
    ``profile_route = True``; results are read with ``get_route_profile()``
    and cleared with ``reset_route_profile()``.

    Args:
        K_in, P_in, K_out, P_out: see ``CapsRoute``.
        kernel_size, pre_k, post_k, pre_groups, post_groups: kept for
            YAML/API compatibility; forwarded to the parent and otherwise unused.
    """

    # Stage timers accumulated (in milliseconds) while profiling is enabled.
    _STAGE_KEYS = (
        'cat_ms',
        'conv_route_ms',
        'route1_ms',
        'pose_refine_ms',
        'act_from_pose_ms',
        'pack_ms',
    )

    def __init__(
        self,
        K_in: Union[List[int], Tuple[int, ...]],
        P_in: Union[List[int], Tuple[int, ...]],
        K_out: int,
        P_out: int,
        kernel_size: int = 3,
        pre_k: int = 3,
        post_k: int = 3,
        pre_groups: Optional[int] = None,
        post_groups: Optional[int] = None,
    ):
        super().__init__(K_in, P_in, K_out, P_out, kernel_size, pre_k, post_k, pre_groups, post_groups)
        self.profile_route = False
        self._route_profile = self._new_route_profile()

        deep_stage = self.K_out >= 64
        pose_ch = self.K_out * self.P_out
        # Match YOLO26 neck style:
        # - shallow/mid stages: C3k2(n=2, c3k=True, attn=False)
        # - deep stage: C3k2(n=1, c3k=True, attn=True)
        pose_e = 0.5 if (self.P_out % 2 == 0) else 1.0
        self.pose_refine = C3k2(
            pose_ch,
            pose_ch,
            n=1 if deep_stage else 2,
            c3k=True,
            e=pose_e,
            attn=deep_stage,
            g=self.K_out,
            shortcut=True,
        )
        self.act_from_pose = Conv(pose_ch, self.K_out, 1, 1, g=self.K_out)
        # NOTE(review): act_alpha is registered but never read in forward(), and the
        # inherited ``spagg`` is not called by this forward() either. Both are kept so
        # existing checkpoints still load; confirm whether they should be wired in.
        self.act_alpha = nn.Parameter(torch.tensor(0.1))

    @classmethod
    def _new_route_profile(cls) -> dict:
        """Return a zeroed profile dict: one entry per stage plus a call counter."""
        profile = {key: 0.0 for key in cls._STAGE_KEYS}
        profile['calls'] = 0.0
        return profile

    @staticmethod
    def _sync_profile() -> None:
        """Wait for queued CUDA work (when CUDA is available) so timings cover it."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _ensure_route_profile_state(self) -> None:
        """Create the profiling attributes if this instance does not have them yet."""
        if not hasattr(self, "profile_route"):
            self.profile_route = False
        if not hasattr(self, "_route_profile"):
            self._route_profile = self._new_route_profile()

    def reset_route_profile(self) -> None:
        """Zero every accumulated timer and the call counter."""
        self._ensure_route_profile_state()
        for k in self._route_profile:
            self._route_profile[k] = 0.0

    def get_route_profile(self) -> dict:
        """Return accumulated timings plus totals and per-call averages.

        The result contains the raw accumulators, ``total_ms``, one
        ``<stage>_avg_ms`` per stage and ``total_avg_ms``. Averages divide
        by the call count, floored at 1 to avoid division by zero.
        """
        self._ensure_route_profile_state()
        calls = max(float(self._route_profile.get('calls', 0.0)), 1.0)
        total = sum(self._route_profile[key] for key in self._STAGE_KEYS)
        out = dict(self._route_profile)
        out['total_ms'] = total
        for key in self._STAGE_KEYS:
            # 'cat_ms' -> 'cat_avg_ms'
            out[key[:-2] + 'avg_ms'] = self._route_profile[key] / calls
        out['total_avg_ms'] = total / calls
        return out

    def _timed(self, key: str, fn, *args):
        """Call ``fn(*args)``; when profiling, add its synchronized wall time to ``key``."""
        if not getattr(self, "profile_route", False):
            return fn(*args)
        t0 = time.perf_counter()
        result = fn(*args)
        self._sync_profile()
        self._route_profile[key] += (time.perf_counter() - t0) * 1000.0
        return result

    def _update_act(self, act: torch.Tensor, pose_flat: torch.Tensor) -> torch.Tensor:
        """Add the pose-derived activation delta to the routed activation."""
        return act + self.act_from_pose(pose_flat)

    def _pack(self, pose_flat: torch.Tensor, act_final: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Re-interleave pose and activation as [pose(P_out), act(1)] per capsule type."""
        b = pose_flat.shape[0]
        pose_pack = pose_flat.reshape(b, self.K_out, self.P_out, h, w)
        return torch.cat([pose_pack, act_final.unsqueeze(2)], dim=2).reshape(b, self.c_out, h, w)

    def forward(self, xs: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]) -> torch.Tensor:
        """Fuse the sources and return refined packed capsules.

        Args:
            xs: one tensor per source; source ``i`` must have
                ``K_in[i]*(P_in[i]+1)`` channels and all must share H, W.

        Returns:
            Tensor of shape [B, K_out*(P_out+1), H, W].

        Raises:
            TypeError: ``xs`` is not a list/tuple.
            ValueError: wrong source count, channel count or spatial size.
        """
        if not isinstance(xs, (list, tuple)):
            raise TypeError(f'CapsRoutev2 expects list/tuple inputs, got {type(xs)}')
        if len(xs) != self.num_sources:
            raise ValueError(f'CapsRoutev2 expected {self.num_sources} sources, got {len(xs)}')
        h, w = int(xs[0].shape[-2]), int(xs[0].shape[-1])
        cat_parts = []
        for i, x in enumerate(xs):
            expected_c = self.K_in_list[i] * (self.P_in_list[i] + 1)
            if int(x.shape[1]) != expected_c:
                raise ValueError(f'CapsRoutev2 source-{i} expected C={expected_c}, got C={int(x.shape[1])}')
            if int(x.shape[-2]) != h or int(x.shape[-1]) != w:
                raise ValueError('CapsRoutev2 inputs must share H,W. Use CapsAlign before routing.')
            cat_parts.append(x)

        self._ensure_route_profile_state()
        if getattr(self, "profile_route", False):
            self._route_profile['calls'] += 1.0
            # Drain previously queued work so it is not charged to the first stage.
            self._sync_profile()

        x_cat = self._timed('cat_ms', torch.cat, cat_parts, 1)  # [B, K_cat*(P+1), H, W]
        conv_out = self._timed('conv_route_ms', self.conv_route, x_cat)
        routed = self._timed('route1_ms', self.route1, conv_out)  # [B, K_out*(P_out+1), H, W]

        b = routed.shape[0]
        # Packed layout by type: [pose(P), act(1)] repeated K times.
        caps = routed.reshape(b, self.K_out, self.P_out + 1, h, w)
        pose = caps[:, :, :self.P_out].contiguous()  # [B, K_out, P_out, H, W]
        act = caps[:, :, self.P_out].contiguous()  # [B, K_out, H, W]

        # Grouped pose refinement across type blocks (equivalent to per-type grouped processing).
        pose_flat = pose.reshape(b, self.K_out * self.P_out, h, w)
        pose_flat = self._timed('pose_refine_ms', self.pose_refine, pose_flat)
        act_final = self._timed('act_from_pose_ms', self._update_act, act, pose_flat)
        return self._timed('pack_ms', self._pack, pose_flat, act_final, h, w)
# -------------------------
# 3) CapsRoute variants, continued (CapsRoutev3, CapsRoutev4); 4) CapsDecode follows them
# -------------------------
class CapsRoutev3(CapsRoute):
    """CapsRoute with DS-style lightweight pose refinement and act residual update.

    The forward pass concatenates the sources, applies the inherited
    ``conv_route`` and ``route1``, splits the routed tensor into pose and
    activation channels, adds a residual refinement (three type-grouped
    ``Conv`` layers: 1x1, 3x3, 1x1) to the pose and adds a pose-derived
    delta to the activation before re-packing.

    Optional per-stage wall-clock profiling is switched on by setting
    ``profile_route = True``; results are read with ``get_route_profile()``
    and cleared with ``reset_route_profile()``.

    Args:
        K_in, P_in, K_out, P_out: see ``CapsRoute``.
        kernel_size, pre_k, post_k, pre_groups, post_groups: forwarded to
            the parent and otherwise unused here.
    """

    # Stage timers accumulated (in milliseconds) while profiling is enabled.
    _STAGE_KEYS = (
        'cat_ms',
        'conv_route_ms',
        'route1_ms',
        'pose_refine_ms',
        'act_from_pose_ms',
        'pack_ms',
    )

    def __init__(
        self,
        K_in: Union[List[int], Tuple[int, ...]],
        P_in: Union[List[int], Tuple[int, ...]],
        K_out: int,
        P_out: int,
        kernel_size: int = 3,
        pre_k: int = 3,
        post_k: int = 3,
        pre_groups: Optional[int] = None,
        post_groups: Optional[int] = None,
    ):
        super().__init__(K_in, P_in, K_out, P_out, kernel_size, pre_k, post_k, pre_groups, post_groups)
        self.profile_route = False
        self._route_profile = self._new_route_profile()

        pose_ch = self.K_out * self.P_out
        # Keep refinement fully type-grouped to preserve capsule semantics:
        # each capsule type only mixes its own pose channels.
        self.pose_refine = nn.Sequential(
            Conv(pose_ch, pose_ch, 1, 1, g=self.K_out),
            Conv(pose_ch, pose_ch, 3, 1, g=self.K_out),
            Conv(pose_ch, pose_ch, 1, 1, g=self.K_out),
        )
        self.act_from_pose = Conv(pose_ch, self.K_out, 1, 1, g=self.K_out)
        # NOTE(review): act_alpha is registered but never read in forward(), and the
        # inherited ``spagg`` is not called by this forward() either. Both are kept so
        # existing checkpoints still load; confirm whether they should be wired in.
        self.act_alpha = nn.Parameter(torch.tensor(0.1))

    @classmethod
    def _new_route_profile(cls) -> dict:
        """Return a zeroed profile dict: one entry per stage plus a call counter."""
        profile = {key: 0.0 for key in cls._STAGE_KEYS}
        profile['calls'] = 0.0
        return profile

    @staticmethod
    def _sync_profile() -> None:
        """Wait for queued CUDA work (when CUDA is available) so timings cover it."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _ensure_route_profile_state(self) -> None:
        """Create the profiling attributes if this instance does not have them yet."""
        if not hasattr(self, "profile_route"):
            self.profile_route = False
        if not hasattr(self, "_route_profile"):
            self._route_profile = self._new_route_profile()

    def reset_route_profile(self) -> None:
        """Zero every accumulated timer and the call counter."""
        self._ensure_route_profile_state()
        for k in self._route_profile:
            self._route_profile[k] = 0.0

    def get_route_profile(self) -> dict:
        """Return accumulated timings plus totals and per-call averages.

        The result contains the raw accumulators, ``total_ms``, one
        ``<stage>_avg_ms`` per stage and ``total_avg_ms``. Averages divide
        by the call count, floored at 1 to avoid division by zero.
        """
        self._ensure_route_profile_state()
        calls = max(float(self._route_profile.get('calls', 0.0)), 1.0)
        total = sum(self._route_profile[key] for key in self._STAGE_KEYS)
        out = dict(self._route_profile)
        out['total_ms'] = total
        for key in self._STAGE_KEYS:
            # 'cat_ms' -> 'cat_avg_ms'
            out[key[:-2] + 'avg_ms'] = self._route_profile[key] / calls
        out['total_avg_ms'] = total / calls
        return out

    def _timed(self, key: str, fn, *args):
        """Call ``fn(*args)``; when profiling, add its synchronized wall time to ``key``."""
        if not getattr(self, "profile_route", False):
            return fn(*args)
        t0 = time.perf_counter()
        result = fn(*args)
        self._sync_profile()
        self._route_profile[key] += (time.perf_counter() - t0) * 1000.0
        return result

    def _refine_pose(self, pose_flat: torch.Tensor) -> torch.Tensor:
        """Residual pose refinement: input plus the grouped conv stack's output."""
        return pose_flat + self.pose_refine(pose_flat)

    def _update_act(self, act: torch.Tensor, pose_flat: torch.Tensor) -> torch.Tensor:
        """Add the pose-derived activation delta to the routed activation."""
        return act + self.act_from_pose(pose_flat)

    def _pack(self, pose_flat: torch.Tensor, act_final: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Re-interleave pose and activation as [pose(P_out), act(1)] per capsule type."""
        b = pose_flat.shape[0]
        pose_pack = pose_flat.reshape(b, self.K_out, self.P_out, h, w)
        return torch.cat([pose_pack, act_final.unsqueeze(2)], dim=2).reshape(b, self.c_out, h, w)

    def forward(self, xs: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]) -> torch.Tensor:
        """Fuse the sources and return refined packed capsules.

        Args:
            xs: one tensor per source; source ``i`` must have
                ``K_in[i]*(P_in[i]+1)`` channels and all must share H, W.

        Returns:
            Tensor of shape [B, K_out*(P_out+1), H, W].

        Raises:
            TypeError: ``xs`` is not a list/tuple.
            ValueError: wrong source count, channel count or spatial size.
        """
        if not isinstance(xs, (list, tuple)):
            raise TypeError(f'CapsRoutev3 expects list/tuple inputs, got {type(xs)}')
        if len(xs) != self.num_sources:
            raise ValueError(f'CapsRoutev3 expected {self.num_sources} sources, got {len(xs)}')
        h, w = int(xs[0].shape[-2]), int(xs[0].shape[-1])
        cat_parts = []
        for i, x in enumerate(xs):
            expected_c = self.K_in_list[i] * (self.P_in_list[i] + 1)
            if int(x.shape[1]) != expected_c:
                raise ValueError(f'CapsRoutev3 source-{i} expected C={expected_c}, got C={int(x.shape[1])}')
            if int(x.shape[-2]) != h or int(x.shape[-1]) != w:
                raise ValueError('CapsRoutev3 inputs must share H,W. Use CapsAlign before routing.')
            cat_parts.append(x)

        self._ensure_route_profile_state()
        if getattr(self, "profile_route", False):
            self._route_profile['calls'] += 1.0
            # Drain previously queued work so it is not charged to the first stage.
            self._sync_profile()

        x_cat = self._timed('cat_ms', torch.cat, cat_parts, 1)  # [B, K_cat*(P+1), H, W]
        conv_out = self._timed('conv_route_ms', self.conv_route, x_cat)
        routed = self._timed('route1_ms', self.route1, conv_out)  # [B, K_out*(P_out+1), H, W]

        b = routed.shape[0]
        # Packed layout by type: [pose(P), act(1)] repeated K times.
        caps = routed.reshape(b, self.K_out, self.P_out + 1, h, w)
        pose = caps[:, :, :self.P_out].contiguous()  # [B, K_out, P_out, H, W]
        act = caps[:, :, self.P_out].contiguous()  # [B, K_out, H, W]

        pose_flat = pose.reshape(b, self.K_out * self.P_out, h, w)
        pose_flat = self._timed('pose_refine_ms', self._refine_pose, pose_flat)
        act_final = self._timed('act_from_pose_ms', self._update_act, act, pose_flat)
        return self._timed('pack_ms', self._pack, pose_flat, act_final, h, w)
class CapsRoutev4(CapsRoutev2):
    """CapsRoutev2 with conv-heavy HybridRoute1 to reduce routing overhead.

    Construction is delegated to ``CapsRoutev2``; afterwards the ``route1``
    submodule is replaced by a ``HybridRoute1`` of the same K/P
    configuration. Arguments are the same as for ``CapsRoutev2``.
    """

    def __init__(
        self,
        K_in: Union[List[int], Tuple[int, ...]],
        P_in: Union[List[int], Tuple[int, ...]],
        K_out: int,
        P_out: int,
        kernel_size: int = 3,
        pre_k: int = 3,
        post_k: int = 3,
        pre_groups: Optional[int] = None,
        post_groups: Optional[int] = None,
    ):
        super().__init__(
            K_in,
            P_in,
            K_out,
            P_out,
            kernel_size,
            pre_k,
            post_k,
            pre_groups,
            post_groups,
        )
        # Swap the routing stage built by the parent for the conv-based one.
        hybrid = HybridRoute1(
            K_in=self.K_cat,
            P_in=self.P_cat,
            K_out=self.K_out,
            P_out=self.P_out,
        )
        self.route1 = hybrid
class CapsDecode(nn.Module):
    """Decode routed capsule features into a standard feature map for Detect.

    Applies a bias-free 1x1 convolution, batch normalisation and SiLU.

    Input:  y [B, c1, H, W]
    Output: f [B, c2, H, W]

    Args:
        c1: input channels.
        c2: output channels (e.g., 256/512/1024).
    """

    def __init__(self, c1: int, c2: int):
        super().__init__()
        # The conv has no bias; BatchNorm follows immediately.
        self.conv = nn.Conv2d(c1, c2, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``act(bn(conv(x)))``."""
        projected = self.conv(x)
        normalised = self.bn(projected)
        return self.act(normalised)
# -------------------------
# 5) CapsuleTap
# -------------------------
class CapsuleTap(nn.Module):
    """Pass-through hook that can cache the tensor flowing through it.

    The input is always returned unchanged (same object). A detached
    reference is stored in ``last_x`` only when caching is enabled, the
    module is in training mode, and neither scripting nor tracing is active.

    Args:
        tag: string identifier ("F3"/"F4"/"F5").
        K, D: capsule hyperparams (metadata only).
        cache_enabled: if True, cache during training.
    """

    def __init__(self, tag: str = "F", K: int = 4, D: int = 16, cache_enabled: bool = True):
        super().__init__()
        self.tag = str(tag)
        self.K, self.D = int(K), int(D)
        self.cache_enabled = bool(cache_enabled)
        # Most recently cached tensor, or None when nothing is cached.
        self.last_x: Optional[torch.Tensor] = None

    def clear_cache(self) -> None:
        """Drop the cached tensor."""
        self.last_x = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Optionally cache ``x.detach()`` and return ``x`` untouched."""
        skip_cache = (
            (not self.cache_enabled)
            or (not self.training)
            or torch.jit.is_scripting()
            or torch.jit.is_tracing()
        )
        if not skip_cache:
            self.last_x = x.detach()
        return x