Code/AutoRig/Dl/RigNet/RigNetLoader.cs

Loader for RigNet model checkpoints: it maps tensors from a checkpoint dictionary into in-memory network objects (MeanshiftBundle, RootNet, BoneNet, SkinNet, etc.), folding BatchNorm into Linear where needed and throwing on missing keys.

File Access
namespace AutoRig.Dl.RigNet;

/// <summary>
/// Builds runnable networks from a RigNet checkpoint's tensors (see
/// docs/superpowers/rignet-port-notes.md for the verified name/shape layout).
/// BatchNorms are folded at load; missing keys throw a <see cref="FormatException"/>
/// carrying the exact key name. Every public entry point rejects a null tensor
/// dictionary with <see cref="ArgumentNullException"/>.
/// </summary>
public static class RigNetLoader
{
    /// <summary>The gcn_meanshift checkpoint: jointnet + masknet + bandwidth.</summary>
    public sealed class MeanshiftBundle
    {
        /// <summary>Network read from the `jointnet.` keys (loaded with tanh output enabled).</summary>
        public required JointPredNet JointNet;

        /// <summary>Network read from the `masknet.` keys (loaded with tanh output disabled).</summary>
        public required JointPredNet MaskNet;

        /// <summary>First element of the checkpoint's `bandwidth` tensor.</summary>
        public required float Bandwidth;
    }

    /// <summary>
    /// Loads the meanshift bundle: the `bandwidth` scalar plus the `jointnet.` and
    /// `masknet.` sub-networks, with or without a leading `state_dict.` prefix.
    /// </summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <returns>The two networks and the bandwidth value.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static MeanshiftBundle LoadMeanshift( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = DetectPrefix( tensors );

        // Looked up before the networks so a checkpoint without it fails on this key first.
        // NOTE(review): assumes the bandwidth tensor holds at least one element — confirm.
        var bandwidth = Get( tensors, $"{prefix}bandwidth" ).Data[0];
        return new MeanshiftBundle
        {
            JointNet = LoadJointPredNet( tensors, $"{prefix}jointnet.", outputChannels: 3, tanh: true ),
            MaskNet = LoadJointPredNet( tensors, $"{prefix}masknet.", outputChannels: 1, tanh: false ),
            Bandwidth = bandwidth,
        };
    }

    /// <summary>JointPredNet from `{prefix}gcu_1..3 / mlp_glb / mlp_tramsform` keys
    /// (the transform name typo is the original's).</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <param name="prefix">Key prefix of the network, including its trailing dot.</param>
    /// <param name="outputChannels">
    /// NOTE(review): not read by this loader; the output layer is taken as-is from
    /// `mlp_tramsform.2.*`. Kept so existing call sites keep compiling.
    /// </param>
    /// <param name="tanh">Stored verbatim in <c>TanhOutput</c>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static JointPredNet LoadJointPredNet(
        IReadOnlyDictionary<string, Tensor> tensors, string prefix, int outputChannels, bool tanh )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        return new JointPredNet
        {
            Gcu1 = LoadGcu( tensors, $"{prefix}gcu_1." ),
            Gcu2 = LoadGcu( tensors, $"{prefix}gcu_2." ),
            Gcu3 = LoadGcu( tensors, $"{prefix}gcu_3." ),
            GlobalMlp = LoadFixedMlp( tensors, $"{prefix}mlp_glb.", count: 1 ),
            TransformMlp = LoadFixedMlp( tensors, $"{prefix}mlp_tramsform.0.", count: 2 ),
            TransformOut = LoadLinear( tensors, $"{prefix}mlp_tramsform.2." ),
            TanhOutput = tanh,
        };
    }

    /// <summary>
    /// One Gcu: two-block edge convolutions under `edge_conv_tpl.nn.` and
    /// `edge_conv_geo.nn.`, plus a single fuse block under `mlp.0.`.
    /// </summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <param name="prefix">Key prefix of the unit, including its trailing dot.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static Gcu LoadGcu( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        return new Gcu
        {
            Topology = new EdgeConv { Mlp = LoadFixedMlp( tensors, $"{prefix}edge_conv_tpl.nn.", count: 2 ) },
            Geodesic = new EdgeConv { Mlp = LoadFixedMlp( tensors, $"{prefix}edge_conv_geo.nn.", count: 2 ) },
            Fuse = LoadFixedMlp( tensors, $"{prefix}mlp.", count: 1 ),
        };
    }

    /// <summary>ROOTNET from the rootnet checkpoint (keys under `state_dict.`).</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static RootNet LoadRootNet( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = DetectPrefix( tensors );

        // Ratio, Radius and K are fixed here; they are not read from the checkpoint.
        return new RootNet
        {
            Shape = LoadShapeEncoder( tensors, $"{prefix}shape_encoder." ),
            Sa1 = new SaModule
            {
                Ratio = 0.999f, Radius = 0.4f,
                LocalNn = LoadMlp( tensors, $"{prefix}joint_encoder.sa1_joint.conv.local_nn." ),
            },
            Sa2 = new SaModule
            {
                Ratio = 0.33f, Radius = 0.6f,
                LocalNn = LoadMlp( tensors, $"{prefix}joint_encoder.sa2_joint.conv.local_nn." ),
            },
            Sa3 = new GlobalSaModule { Nn = LoadMlp( tensors, $"{prefix}joint_encoder.sa3_joint.nn." ) },
            Fp3 = new FpModule { K = 1, Nn = LoadMlp( tensors, $"{prefix}joint_encoder.fp3_joint.nn." ) },
            Fp2 = new FpModule { K = 3, Nn = LoadMlp( tensors, $"{prefix}joint_encoder.fp2_joint.nn." ) },
            Fp1 = new FpModule { K = 3, Nn = LoadMlp( tensors, $"{prefix}joint_encoder.fp1_joint.nn." ) },
            BackMlp = LoadMlp( tensors, $"{prefix}back_layers.0." ),
            BackOut = LoadLinear( tensors, $"{prefix}back_layers.1." ),
        };
    }

    /// <summary>BONENET (PairCls) from the bonenet checkpoint.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static BoneNet LoadBoneNet( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = DetectPrefix( tensors );

        // Ratio and Radius are fixed here; they are not read from the checkpoint.
        return new BoneNet
        {
            Shape = LoadShapeEncoder( tensors, $"{prefix}shape_encoder." ),
            Sa1 = new SaModule
            {
                Ratio = 0.999f, Radius = 0.4f,
                LocalNn = LoadMlp( tensors, $"{prefix}joint_encoder.sa1_module_joints.conv.local_nn." ),
            },
            Sa2 = new SaModule
            {
                Ratio = 0.33f, Radius = 0.6f,
                LocalNn = LoadMlp( tensors, $"{prefix}joint_encoder.sa2_module_joints.conv.local_nn." ),
            },
            Sa3 = new GlobalSaModule
            {
                Nn = LoadMlp( tensors, $"{prefix}joint_encoder.sa3_module_joints.nn." ),
            },
            ExpandPair = LoadMlp( tensors, $"{prefix}expand_joint_feature.0." ),
            MixMlp = LoadMlp( tensors, $"{prefix}mix_transform.0." ),
            MixOut = LoadLinear( tensors, $"{prefix}mix_transform.2." ),
        };
    }

    /// <summary>SKINNET from the skinnet checkpoint (cls_branch uses flat 0/2/3/5/6 keys).</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static SkinNet LoadSkinNet( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = DetectPrefix( tensors );
        var clsBranch = $"{prefix}cls_branch.";
        return new SkinNet
        {
            // "tranform" is the spelling used by the checkpoint keys; do not correct it.
            Transform1 = LoadMlp( tensors, $"{prefix}multi_layer_tranform1." ),
            Gcu1 = LoadGcu( tensors, $"{prefix}gcu1." ),
            Transform2 = LoadMlp( tensors, $"{prefix}multi_layer_tranform2." ),
            Gcu2 = LoadGcu( tensors, $"{prefix}gcu2." ),
            Gcu3 = LoadGcu( tensors, $"{prefix}gcu3." ),
            Cls0 = LoadFlatBlock( tensors, clsBranch, 0, 2 ),
            Cls1 = LoadFlatBlock( tensors, clsBranch, 3, 5 ),
            ClsOut = LoadLinear( tensors, $"{clsBranch}6." ),
        };
    }

    /// <summary>Linear+BN folded from flat Sequential indices (Lin at i, BN at j).</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <param name="prefix">Key prefix shared by both layers, including its trailing dot.</param>
    /// <param name="linear">Index of the Linear layer under <paramref name="prefix"/>.</param>
    /// <param name="bn">Index of the BatchNorm layer under <paramref name="prefix"/>.</param>
    static LinearReluBn LoadFlatBlock(
        IReadOnlyDictionary<string, Tensor> tensors, string prefix, int linear, int bn )
        => LinearReluBn.Fold(
            weight: Get( tensors, $"{prefix}{linear}.weight" ),
            bias: Get( tensors, $"{prefix}{linear}.bias" ),
            gamma: Get( tensors, $"{prefix}{bn}.weight" ),
            beta: Get( tensors, $"{prefix}{bn}.bias" ),
            runningMean: Get( tensors, $"{prefix}{bn}.running_mean" ),
            runningVar: Get( tensors, $"{prefix}{bn}.running_var" ) );

    /// <summary>Shape encoder: three Gcus (`gcu_1..3`) plus the `mlp_glb.` MLP.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <param name="prefix">Key prefix of the encoder, including its trailing dot.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static ShapeEncoder LoadShapeEncoder(
        IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        return new ShapeEncoder
        {
            Gcu1 = LoadGcu( tensors, $"{prefix}gcu_1." ),
            Gcu2 = LoadGcu( tensors, $"{prefix}gcu_2." ),
            Gcu3 = LoadGcu( tensors, $"{prefix}gcu_3." ),
            GlobalMlp = LoadMlp( tensors, $"{prefix}mlp_glb." ),
        };
    }

    /// <summary>Loads consecutive `{prefix}{i}.` blocks until the next index is missing.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <param name="prefix">Key prefix of the MLP, including its trailing dot.</param>
    /// <returns>An MLP with one layer per block found, in index order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">
    /// No block exists under <paramref name="prefix"/>, or a found block lacks one of its tensors.
    /// </exception>
    public static Mlp LoadMlp( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        var layers = new List<LinearReluBn>();

        // A block is considered present when its Linear weight (`{i}.0.weight`) exists.
        while ( tensors.ContainsKey( $"{prefix}{layers.Count}.0.weight" ) )
        {
            layers.Add( LoadBlock( tensors, $"{prefix}{layers.Count}." ) );
        }

        if ( layers.Count == 0 )
        {
            throw new FormatException( $"RigNet checkpoint has no MLP blocks under '{prefix}'." );
        }

        return new Mlp { Layers = layers.ToArray() };
    }

    /// <summary>Returns `state_dict.` when any key starts with it, otherwise the empty string.</summary>
    static string DetectPrefix( IReadOnlyDictionary<string, Tensor> tensors )
        => tensors.Keys.Any( k => k.StartsWith( "state_dict.", StringComparison.Ordinal ) )
            ? "state_dict."
            : "";

    /// <summary>One Linear(+folded BN) block: `{prefix}0.*` Linear, `{prefix}2.*` BatchNorm.</summary>
    static LinearReluBn LoadBlock( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
        => LoadFlatBlock( tensors, prefix, linear: 0, bn: 2 );

    /// <summary>
    /// MLP made of exactly <paramref name="count"/> blocks `{prefix}0.` .. `{prefix}{count-1}.`,
    /// loaded in index order; unlike <see cref="LoadMlp"/>, every block is required.
    /// </summary>
    static Mlp LoadFixedMlp( IReadOnlyDictionary<string, Tensor> tensors, string prefix, int count )
    {
        var layers = new LinearReluBn[count];
        for ( var i = 0; i < count; i++ )
        {
            layers[i] = LoadBlock( tensors, $"{prefix}{i}." );
        }

        return new Mlp { Layers = layers };
    }

    /// <summary>Plain Linear layer from `{prefix}weight` and `{prefix}bias` (no BatchNorm folding).</summary>
    static Linear LoadLinear( IReadOnlyDictionary<string, Tensor> tensors, string prefix ) => new()
    {
        Weight = Get( tensors, $"{prefix}weight" ),
        Bias = Get( tensors, $"{prefix}bias" ),
    };

    /// <summary>Looks up one tensor by exact key, throwing with the key name when it is absent.</summary>
    /// <exception cref="FormatException"><paramref name="key"/> is not in <paramref name="tensors"/>.</exception>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"RigNet checkpoint is missing tensor '{key}'." );
}