Loader for RigNet model checkpoints. It maps tensors from a checkpoint dictionary into in-memory network objects (MeanshiftBundle, RootNet, BoneNet, SkinNet, etc.), folding BatchNorm into Linear where needed and throwing on missing keys.
namespace AutoRig.Dl.RigNet;
/// <summary>
/// Builds runnable networks from a RigNet checkpoint's tensors (see
/// docs/superpowers/rignet-port-notes.md for the verified name/shape layout).
/// BatchNorms are folded at load; missing keys throw with the exact name.
/// </summary>
public static class RigNetLoader
{
/// <summary>
/// The gcn_meanshift checkpoint: jointnet + masknet + bandwidth.
/// All three members are <c>required</c>, so an instance cannot be created
/// without supplying each of them.
/// </summary>
public sealed class MeanshiftBundle
{
/// <summary>Joint prediction network (filled from the <c>jointnet.</c> keys by <c>LoadMeanshift</c>).</summary>
public required JointPredNet JointNet;
/// <summary>Mask prediction network (filled from the <c>masknet.</c> keys by <c>LoadMeanshift</c>).</summary>
public required JointPredNet MaskNet;
/// <summary>Scalar bandwidth (<c>LoadMeanshift</c> stores the first element of the <c>bandwidth</c> tensor here).</summary>
public required float Bandwidth;
}
/// <summary>
/// Loads the gcn_meanshift checkpoint: the joint network, the mask network and
/// the scalar bandwidth. Keys may optionally be nested under <c>state_dict.</c>.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <returns>A bundle holding both networks and the bandwidth value.</returns>
/// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
/// <exception cref="FormatException">A required tensor is missing from the checkpoint.</exception>
public static MeanshiftBundle LoadMeanshift( IReadOnlyDictionary<string, Tensor> tensors )
{
    ArgumentNullException.ThrowIfNull( tensors );

    // Use the shared helper so every loader detects the `state_dict.` wrapper the same way.
    var prefix = DetectPrefix( tensors );

    // Only the first element of the bandwidth tensor is used.
    // NOTE(review): assumes the tensor holds at least one element - confirm against the checkpoint layout.
    var bandwidth = Get( tensors, $"{prefix}bandwidth" ).Data[0];

    return new MeanshiftBundle
    {
        JointNet = LoadJointPredNet( tensors, $"{prefix}jointnet.", outputChannels: 3, tanh: true ),
        MaskNet = LoadJointPredNet( tensors, $"{prefix}masknet.", outputChannels: 1, tanh: false ),
        Bandwidth = bandwidth,
    };
}
/// <summary>
/// Assembles a JointPredNet from the keys below <paramref name="prefix"/>:
/// <c>gcu_1</c>..<c>gcu_3</c>, <c>mlp_glb</c> and <c>mlp_tramsform</c>
/// (the misspelled transform name comes from the original checkpoint and must stay as is).
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <param name="prefix">Key prefix of this network, including its trailing dot.</param>
/// <param name="outputChannels">
/// Not read by this method.
/// NOTE(review): presumably the expected width of the output layer - confirm whether it should be validated.
/// </param>
/// <param name="tanh">Stored unchanged in <c>TanhOutput</c>.</param>
/// <exception cref="FormatException">A required tensor is missing from the checkpoint.</exception>
public static JointPredNet LoadJointPredNet(
    IReadOnlyDictionary<string, Tensor> tensors, string prefix, int outputChannels, bool tanh )
{
    // Sub-networks are loaded in a fixed order so the first missing key reported stays the same.
    var gcu1 = LoadGcu( tensors, $"{prefix}gcu_1." );
    var gcu2 = LoadGcu( tensors, $"{prefix}gcu_2." );
    var gcu3 = LoadGcu( tensors, $"{prefix}gcu_3." );

    var globalBlock = LoadBlock( tensors, $"{prefix}mlp_glb.0." );

    // mlp_tramsform: two folded blocks under index 0, plain Linear at index 2.
    var transformKey = $"{prefix}mlp_tramsform.";
    var transformFirst = LoadBlock( tensors, $"{transformKey}0.0." );
    var transformSecond = LoadBlock( tensors, $"{transformKey}0.1." );
    var outWeight = Get( tensors, $"{transformKey}2.weight" );
    var outBias = Get( tensors, $"{transformKey}2.bias" );

    return new JointPredNet
    {
        Gcu1 = gcu1,
        Gcu2 = gcu2,
        Gcu3 = gcu3,
        GlobalMlp = new Mlp { Layers = new[] { globalBlock } },
        TransformMlp = new Mlp { Layers = new[] { transformFirst, transformSecond } },
        TransformOut = new Linear { Weight = outWeight, Bias = outBias },
        TanhOutput = tanh,
    };
}
/// <summary>
/// Builds one Gcu from the keys below <paramref name="prefix"/>: the
/// <c>edge_conv_tpl</c> and <c>edge_conv_geo</c> edge convolutions (two folded
/// blocks each, under <c>nn.0.</c> and <c>nn.1.</c>) plus the single-block <c>mlp</c>.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <param name="prefix">Key prefix of this unit, including its trailing dot.</param>
/// <exception cref="FormatException">A required tensor is missing from the checkpoint.</exception>
public static Gcu LoadGcu( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
{
    // Both edge convolutions share the same two-block layout; only the branch name differs.
    EdgeConv LoadEdgeConv( string branch )
    {
        var first = LoadBlock( tensors, $"{prefix}{branch}.nn.0." );
        var second = LoadBlock( tensors, $"{prefix}{branch}.nn.1." );
        return new EdgeConv { Mlp = new Mlp { Layers = new[] { first, second } } };
    }

    var topology = LoadEdgeConv( "edge_conv_tpl" );
    var geodesic = LoadEdgeConv( "edge_conv_geo" );
    var fuseBlock = LoadBlock( tensors, $"{prefix}mlp.0." );

    return new Gcu
    {
        Topology = topology,
        Geodesic = geodesic,
        Fuse = new Mlp { Layers = new[] { fuseBlock } },
    };
}
/// <summary>
/// ROOTNET from the rootnet checkpoint. Keys may optionally be nested under
/// <c>state_dict.</c>; the prefix is detected from the dictionary.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <returns>The assembled network; Ratio, Radius and K are constants set here, not read from the checkpoint.</returns>
/// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
/// <exception cref="FormatException">A required tensor or MLP block is missing from the checkpoint.</exception>
public static RootNet LoadRootNet( IReadOnlyDictionary<string, Tensor> tensors )
{
    // Match LoadMeanshift: report a null dictionary by parameter name instead of
    // letting it surface as a NullReferenceException inside DetectPrefix.
    ArgumentNullException.ThrowIfNull( tensors );

    var prefix = DetectPrefix( tensors );
    var jointEncoder = $"{prefix}joint_encoder.";

    return new RootNet
    {
        Shape = LoadShapeEncoder( tensors, $"{prefix}shape_encoder." ),
        Sa1 = new SaModule
        {
            Ratio = 0.999f,
            Radius = 0.4f,
            LocalNn = LoadMlp( tensors, $"{jointEncoder}sa1_joint.conv.local_nn." ),
        },
        Sa2 = new SaModule
        {
            Ratio = 0.33f,
            Radius = 0.6f,
            LocalNn = LoadMlp( tensors, $"{jointEncoder}sa2_joint.conv.local_nn." ),
        },
        Sa3 = new GlobalSaModule { Nn = LoadMlp( tensors, $"{jointEncoder}sa3_joint.nn." ) },
        Fp3 = new FpModule { K = 1, Nn = LoadMlp( tensors, $"{jointEncoder}fp3_joint.nn." ) },
        Fp2 = new FpModule { K = 3, Nn = LoadMlp( tensors, $"{jointEncoder}fp2_joint.nn." ) },
        Fp1 = new FpModule { K = 3, Nn = LoadMlp( tensors, $"{jointEncoder}fp1_joint.nn." ) },
        // back_layers: MLP at index 0, plain Linear at index 1.
        BackMlp = LoadMlp( tensors, $"{prefix}back_layers.0." ),
        BackOut = new Linear
        {
            Weight = Get( tensors, $"{prefix}back_layers.1.weight" ),
            Bias = Get( tensors, $"{prefix}back_layers.1.bias" ),
        },
    };
}
/// <summary>
/// BONENET (PairCls) from the bonenet checkpoint. Keys may optionally be nested
/// under <c>state_dict.</c>; the prefix is detected from the dictionary.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <returns>The assembled network; Ratio and Radius are constants set here, not read from the checkpoint.</returns>
/// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
/// <exception cref="FormatException">A required tensor or MLP block is missing from the checkpoint.</exception>
public static BoneNet LoadBoneNet( IReadOnlyDictionary<string, Tensor> tensors )
{
    // Match LoadMeanshift: report a null dictionary by parameter name instead of
    // letting it surface as a NullReferenceException inside DetectPrefix.
    ArgumentNullException.ThrowIfNull( tensors );

    var prefix = DetectPrefix( tensors );
    var jointEncoder = $"{prefix}joint_encoder.";

    return new BoneNet
    {
        Shape = LoadShapeEncoder( tensors, $"{prefix}shape_encoder." ),
        Sa1 = new SaModule
        {
            Ratio = 0.999f,
            Radius = 0.4f,
            LocalNn = LoadMlp( tensors, $"{jointEncoder}sa1_module_joints.conv.local_nn." ),
        },
        Sa2 = new SaModule
        {
            Ratio = 0.33f,
            Radius = 0.6f,
            LocalNn = LoadMlp( tensors, $"{jointEncoder}sa2_module_joints.conv.local_nn." ),
        },
        Sa3 = new GlobalSaModule
        {
            Nn = LoadMlp( tensors, $"{jointEncoder}sa3_module_joints.nn." ),
        },
        ExpandPair = LoadMlp( tensors, $"{prefix}expand_joint_feature.0." ),
        // mix_transform: MLP at index 0, plain Linear at index 2.
        MixMlp = LoadMlp( tensors, $"{prefix}mix_transform.0." ),
        MixOut = new Linear
        {
            Weight = Get( tensors, $"{prefix}mix_transform.2.weight" ),
            Bias = Get( tensors, $"{prefix}mix_transform.2.bias" ),
        },
    };
}
/// <summary>
/// SKINNET from the skinnet checkpoint. Keys may optionally be nested under
/// <c>state_dict.</c>. The <c>cls_branch</c> uses flat indices: Linear at 0 and 3,
/// BatchNorm at 2 and 5, and the final Linear at 6.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <returns>The assembled network.</returns>
/// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
/// <exception cref="FormatException">A required tensor or MLP block is missing from the checkpoint.</exception>
public static SkinNet LoadSkinNet( IReadOnlyDictionary<string, Tensor> tensors )
{
    // Match LoadMeanshift: report a null dictionary by parameter name instead of
    // letting it surface as a NullReferenceException inside DetectPrefix.
    ArgumentNullException.ThrowIfNull( tensors );

    var prefix = DetectPrefix( tensors );
    var clsBranch = $"{prefix}cls_branch.";

    return new SkinNet
    {
        // `tranform` is the spelling used by the checkpoint keys; do not correct it.
        Transform1 = LoadMlp( tensors, $"{prefix}multi_layer_tranform1." ),
        Gcu1 = LoadGcu( tensors, $"{prefix}gcu1." ),
        Transform2 = LoadMlp( tensors, $"{prefix}multi_layer_tranform2." ),
        Gcu2 = LoadGcu( tensors, $"{prefix}gcu2." ),
        Gcu3 = LoadGcu( tensors, $"{prefix}gcu3." ),
        Cls0 = LoadFlatBlock( tensors, clsBranch, 0, 2 ),
        Cls1 = LoadFlatBlock( tensors, clsBranch, 3, 5 ),
        ClsOut = new Linear
        {
            Weight = Get( tensors, $"{clsBranch}6.weight" ),
            Bias = Get( tensors, $"{clsBranch}6.bias" ),
        },
    };
}
/// <summary>
/// Folds a Linear and a BatchNorm that sit at flat Sequential indices under the same
/// <paramref name="prefix"/>: the Linear at <paramref name="linear"/>, the BatchNorm at
/// <paramref name="bn"/>.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <param name="prefix">Key prefix of the Sequential, including its trailing dot.</param>
/// <param name="linear">Index of the Linear layer inside the Sequential.</param>
/// <param name="bn">Index of the BatchNorm layer inside the Sequential.</param>
/// <exception cref="FormatException">One of the six tensors is missing from the checkpoint.</exception>
static LinearReluBn LoadFlatBlock(
    IReadOnlyDictionary<string, Tensor> tensors, string prefix, int linear, int bn )
{
    // Fetch order is Linear first, then BatchNorm, so the first missing key reported is stable.
    var linearWeight = Get( tensors, $"{prefix}{linear}.weight" );
    var linearBias = Get( tensors, $"{prefix}{linear}.bias" );
    var bnGamma = Get( tensors, $"{prefix}{bn}.weight" );
    var bnBeta = Get( tensors, $"{prefix}{bn}.bias" );
    var bnMean = Get( tensors, $"{prefix}{bn}.running_mean" );
    var bnVar = Get( tensors, $"{prefix}{bn}.running_var" );

    return LinearReluBn.Fold(
        weight: linearWeight,
        bias: linearBias,
        gamma: bnGamma,
        beta: bnBeta,
        runningMean: bnMean,
        runningVar: bnVar );
}
/// <summary>
/// Builds a ShapeEncoder from the keys below <paramref name="prefix"/>:
/// <c>gcu_1</c>..<c>gcu_3</c> and the <c>mlp_glb</c> MLP.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <param name="prefix">Key prefix of the encoder, including its trailing dot.</param>
/// <exception cref="FormatException">A required tensor or MLP block is missing from the checkpoint.</exception>
public static ShapeEncoder LoadShapeEncoder(
    IReadOnlyDictionary<string, Tensor> tensors, string prefix )
{
    var first = LoadGcu( tensors, $"{prefix}gcu_1." );
    var second = LoadGcu( tensors, $"{prefix}gcu_2." );
    var third = LoadGcu( tensors, $"{prefix}gcu_3." );
    var global = LoadMlp( tensors, $"{prefix}mlp_glb." );

    return new ShapeEncoder
    {
        Gcu1 = first,
        Gcu2 = second,
        Gcu3 = third,
        GlobalMlp = global,
    };
}
/// <summary>
/// Loads the blocks <c>{prefix}0.</c>, <c>{prefix}1.</c>, ... in order and stops at the
/// first index whose <c>{prefix}{i}.0.weight</c> key is absent.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <param name="prefix">Key prefix of the MLP, including its trailing dot.</param>
/// <returns>An Mlp whose layers are the blocks found, in index order.</returns>
/// <exception cref="FormatException">
/// No block exists under <paramref name="prefix"/>, or a found block lacks one of its tensors.
/// </exception>
public static Mlp LoadMlp( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
{
    var blocks = new List<LinearReluBn>();

    // A block is considered present when its Linear weight key exists.
    for ( var index = 0; tensors.ContainsKey( $"{prefix}{index}.0.weight" ); index++ )
    {
        blocks.Add( LoadBlock( tensors, $"{prefix}{index}." ) );
    }

    if ( blocks.Count == 0 )
    {
        throw new FormatException( $"RigNet checkpoint has no MLP blocks under '{prefix}'." );
    }

    return new Mlp { Layers = blocks.ToArray() };
}
/// <summary>
/// Returns <c>"state_dict."</c> when at least one key starts with that text
/// (ordinal comparison), otherwise the empty string.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
static string DetectPrefix( IReadOnlyDictionary<string, Tensor> tensors )
{
    const string wrapperPrefix = "state_dict.";

    foreach ( var key in tensors.Keys )
    {
        // One matching key is enough to decide; stop scanning.
        if ( key.StartsWith( wrapperPrefix, StringComparison.Ordinal ) )
        {
            return wrapperPrefix;
        }
    }

    return "";
}
/// <summary>
/// Loads one Linear(+folded BN) block: the Linear tensors live under <c>{prefix}0.</c>
/// and the BatchNorm tensors under <c>{prefix}2.</c>.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <param name="prefix">Key prefix of the block, including its trailing dot.</param>
/// <exception cref="FormatException">One of the six tensors is missing from the checkpoint.</exception>
static LinearReluBn LoadBlock( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
{
    var linearKey = $"{prefix}0.";
    var normKey = $"{prefix}2.";

    return LinearReluBn.Fold(
        weight: Get( tensors, $"{linearKey}weight" ),
        bias: Get( tensors, $"{linearKey}bias" ),
        gamma: Get( tensors, $"{normKey}weight" ),
        beta: Get( tensors, $"{normKey}bias" ),
        runningMean: Get( tensors, $"{normKey}running_mean" ),
        runningVar: Get( tensors, $"{normKey}running_var" ) );
}
/// <summary>
/// Looks up <paramref name="key"/> in the checkpoint and returns its tensor.
/// </summary>
/// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
/// <param name="key">Exact tensor name to fetch.</param>
/// <exception cref="FormatException">The key is absent; the message contains the exact name.</exception>
static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
{
    if ( !tensors.TryGetValue( key, out var found ) )
    {
        throw new FormatException( $"RigNet checkpoint is missing tensor '{key}'." );
    }

    return found;
}
}