AutoRig/Dl/RigNet/RigNetLoader.cs

Loader for RigNet model checkpoints. It reads tensors from a dictionary and builds runnable network objects (MeanshiftBundle, RootNet, BoneNet, SkinNet, etc.). It maps expected tensor keys to model components, folds each BatchNorm into the preceding Linear layer, and throws a FormatException naming the key when a required tensor is missing.

File Access, Native Interop
namespace AutoRig.Dl.RigNet;

/// <summary>
/// Builds runnable networks from a RigNet checkpoint's tensors (see
/// docs/superpowers/rignet-port-notes.md for the verified name/shape layout).
/// BatchNorms are folded at load; missing keys throw with the exact name.
/// </summary>
/// <remarks>
/// The top-level loaders accept keys either bare or nested under a <c>state_dict.</c>
/// prefix; the prefix is used when any key starts with it. A missing tensor raises
/// <see cref="FormatException"/> whose message contains the full missing key.
/// </remarks>
public static class RigNetLoader
{
    /// <summary>The gcn_meanshift checkpoint: jointnet + masknet + bandwidth.</summary>
    public sealed class MeanshiftBundle
    {
        /// <summary>Built from the <c>jointnet.*</c> keys, loaded with a tanh output.</summary>
        public required JointPredNet JointNet;

        /// <summary>Built from the <c>masknet.*</c> keys, loaded without a tanh output.</summary>
        public required JointPredNet MaskNet;

        /// <summary>First element of the checkpoint's <c>bandwidth</c> tensor.</summary>
        public required float Bandwidth;
    }

    /// <summary>Loads the gcn_meanshift checkpoint (joint net, mask net and bandwidth).</summary>
    /// <param name="tensors">Checkpoint tensors keyed by parameter name.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static MeanshiftBundle LoadMeanshift( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = DetectPrefix( tensors );

        // Read before the networks, so a missing bandwidth is the first error reported.
        var bandwidth = Get( tensors, $"{prefix}bandwidth" ).Data[0];
        return new MeanshiftBundle
        {
            JointNet = LoadJointPredNet( tensors, $"{prefix}jointnet.", outputChannels: 3, tanh: true ),
            MaskNet = LoadJointPredNet( tensors, $"{prefix}masknet.", outputChannels: 1, tanh: false ),
            Bandwidth = bandwidth,
        };
    }

    /// <summary>JointPredNet from `{prefix}gcu_1..3 / mlp_glb / mlp_tramsform` keys
    /// (the transform name typo is the original's).</summary>
    /// <param name="tensors">Checkpoint tensors keyed by parameter name.</param>
    /// <param name="prefix">Key prefix, including its trailing dot (e.g. <c>jointnet.</c>).</param>
    /// <param name="outputChannels">
    /// Expected output width. NOTE(review): not read by this loader; the width is whatever
    /// the <c>mlp_tramsform.2</c> weight holds — consider validating it once the Tensor
    /// shape API is confirmed.
    /// </param>
    /// <param name="tanh">Stored as <c>TanhOutput</c> on the resulting network.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static JointPredNet LoadJointPredNet(
        IReadOnlyDictionary<string, Tensor> tensors, string prefix, int outputChannels, bool tanh )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        return new JointPredNet
        {
            Gcu1 = LoadGcu( tensors, $"{prefix}gcu_1." ),
            Gcu2 = LoadGcu( tensors, $"{prefix}gcu_2." ),
            Gcu3 = LoadGcu( tensors, $"{prefix}gcu_3." ),
            GlobalMlp = new Mlp { Layers = new[] { LoadBlock( tensors, $"{prefix}mlp_glb.0." ) } },
            // mlp_tramsform.0 is a nested two-block Sequential; index 2 is the plain output Linear.
            TransformMlp = new Mlp
            {
                Layers = new[]
                {
                    LoadBlock( tensors, $"{prefix}mlp_tramsform.0.0." ),
                    LoadBlock( tensors, $"{prefix}mlp_tramsform.0.1." ),
                },
            },
            TransformOut = new Linear
            {
                Weight = Get( tensors, $"{prefix}mlp_tramsform.2.weight" ),
                Bias = Get( tensors, $"{prefix}mlp_tramsform.2.bias" ),
            },
            TanhOutput = tanh,
        };
    }

    /// <summary>
    /// GCU from `{prefix}edge_conv_tpl.nn.*` (topology), `{prefix}edge_conv_geo.nn.*`
    /// (geodesic) and a single fuse block at `{prefix}mlp.0.`.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static Gcu LoadGcu( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        return new Gcu
        {
            Topology = LoadEdgeConv( tensors, $"{prefix}edge_conv_tpl." ),
            Geodesic = LoadEdgeConv( tensors, $"{prefix}edge_conv_geo." ),
            Fuse = new Mlp { Layers = new[] { LoadBlock( tensors, $"{prefix}mlp.0." ) } },
        };
    }

    /// <summary>EdgeConv whose MLP is the two blocks at `{prefix}nn.0.` and `{prefix}nn.1.`.</summary>
    static EdgeConv LoadEdgeConv( IReadOnlyDictionary<string, Tensor> tensors, string prefix ) => new()
    {
        Mlp = new Mlp
        {
            Layers = new[]
            {
                LoadBlock( tensors, $"{prefix}nn.0." ),
                LoadBlock( tensors, $"{prefix}nn.1." ),
            },
        },
    };

    /// <summary>ROOTNET from the rootnet checkpoint (keys under `state_dict.`).</summary>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static RootNet LoadRootNet( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = DetectPrefix( tensors );
        return new RootNet
        {
            Shape = LoadShapeEncoder( tensors, $"{prefix}shape_encoder." ),
            // Set-abstraction ratio/radius and FP neighbour counts are fixed here,
            // not read from the checkpoint.
            Sa1 = new SaModule
            {
                Ratio = 0.999f, Radius = 0.4f,
                LocalNn = LoadMlp( tensors, $"{prefix}joint_encoder.sa1_joint.conv.local_nn." ),
            },
            Sa2 = new SaModule
            {
                Ratio = 0.33f, Radius = 0.6f,
                LocalNn = LoadMlp( tensors, $"{prefix}joint_encoder.sa2_joint.conv.local_nn." ),
            },
            Sa3 = new GlobalSaModule { Nn = LoadMlp( tensors, $"{prefix}joint_encoder.sa3_joint.nn." ) },
            Fp3 = new FpModule { K = 1, Nn = LoadMlp( tensors, $"{prefix}joint_encoder.fp3_joint.nn." ) },
            Fp2 = new FpModule { K = 3, Nn = LoadMlp( tensors, $"{prefix}joint_encoder.fp2_joint.nn." ) },
            Fp1 = new FpModule { K = 3, Nn = LoadMlp( tensors, $"{prefix}joint_encoder.fp1_joint.nn." ) },
            BackMlp = LoadMlp( tensors, $"{prefix}back_layers.0." ),
            BackOut = new Linear
            {
                Weight = Get( tensors, $"{prefix}back_layers.1.weight" ),
                Bias = Get( tensors, $"{prefix}back_layers.1.bias" ),
            },
        };
    }

    /// <summary>BONENET (PairCls) from the bonenet checkpoint.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static BoneNet LoadBoneNet( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = DetectPrefix( tensors );
        return new BoneNet
        {
            Shape = LoadShapeEncoder( tensors, $"{prefix}shape_encoder." ),
            // Same fixed SA hyperparameters as ROOTNET; only the key names differ.
            Sa1 = new SaModule
            {
                Ratio = 0.999f, Radius = 0.4f,
                LocalNn = LoadMlp( tensors, $"{prefix}joint_encoder.sa1_module_joints.conv.local_nn." ),
            },
            Sa2 = new SaModule
            {
                Ratio = 0.33f, Radius = 0.6f,
                LocalNn = LoadMlp( tensors, $"{prefix}joint_encoder.sa2_module_joints.conv.local_nn." ),
            },
            Sa3 = new GlobalSaModule
            {
                Nn = LoadMlp( tensors, $"{prefix}joint_encoder.sa3_module_joints.nn." ),
            },
            ExpandPair = LoadMlp( tensors, $"{prefix}expand_joint_feature.0." ),
            MixMlp = LoadMlp( tensors, $"{prefix}mix_transform.0." ),
            MixOut = new Linear
            {
                Weight = Get( tensors, $"{prefix}mix_transform.2.weight" ),
                Bias = Get( tensors, $"{prefix}mix_transform.2.bias" ),
            },
        };
    }

    /// <summary>SKINNET from the skinnet checkpoint (cls_branch uses flat 0/2/3/5/6 keys).</summary>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static SkinNet LoadSkinNet( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = DetectPrefix( tensors );
        return new SkinNet
        {
            Transform1 = LoadMlp( tensors, $"{prefix}multi_layer_tranform1." ),
            Gcu1 = LoadGcu( tensors, $"{prefix}gcu1." ),
            Transform2 = LoadMlp( tensors, $"{prefix}multi_layer_tranform2." ),
            Gcu2 = LoadGcu( tensors, $"{prefix}gcu2." ),
            Gcu3 = LoadGcu( tensors, $"{prefix}gcu3." ),
            // cls_branch is one flat Sequential: Linear at 0/3, BatchNorm at 2/5, output Linear at 6.
            Cls0 = LoadFlatBlock( tensors, $"{prefix}cls_branch.", 0, 2 ),
            Cls1 = LoadFlatBlock( tensors, $"{prefix}cls_branch.", 3, 5 ),
            ClsOut = new Linear
            {
                Weight = Get( tensors, $"{prefix}cls_branch.6.weight" ),
                Bias = Get( tensors, $"{prefix}cls_branch.6.bias" ),
            },
        };
    }

    /// <summary>Linear+BN folded from flat Sequential indices (Lin at i, BN at j).</summary>
    /// <param name="linear">Sequential index whose <c>weight</c>/<c>bias</c> are the Linear parameters.</param>
    /// <param name="bn">Sequential index holding the BatchNorm affine parameters and running stats.</param>
    static LinearReluBn LoadFlatBlock(
        IReadOnlyDictionary<string, Tensor> tensors, string prefix, int linear, int bn )
        => LinearReluBn.Fold(
            weight: Get( tensors, $"{prefix}{linear}.weight" ),
            bias: Get( tensors, $"{prefix}{linear}.bias" ),
            gamma: Get( tensors, $"{prefix}{bn}.weight" ),
            beta: Get( tensors, $"{prefix}{bn}.bias" ),
            runningMean: Get( tensors, $"{prefix}{bn}.running_mean" ),
            runningVar: Get( tensors, $"{prefix}{bn}.running_var" ) );

    /// <summary>Shape encoder: three GCUs (`gcu_1..3`) plus the `mlp_glb` MLP.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    public static ShapeEncoder LoadShapeEncoder(
        IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        return new ShapeEncoder
        {
            Gcu1 = LoadGcu( tensors, $"{prefix}gcu_1." ),
            Gcu2 = LoadGcu( tensors, $"{prefix}gcu_2." ),
            Gcu3 = LoadGcu( tensors, $"{prefix}gcu_3." ),
            GlobalMlp = LoadMlp( tensors, $"{prefix}mlp_glb." ),
        };
    }

    /// <summary>Loads consecutive `{prefix}{i}.` blocks until the next index is missing.</summary>
    /// <remarks>A block index counts as present when `{prefix}{i}.0.weight` exists.</remarks>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">
    /// No block exists at index 0, or a present block is missing one of its tensors.
    /// </exception>
    public static Mlp LoadMlp( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var layers = new List<LinearReluBn>();
        while ( tensors.ContainsKey( $"{prefix}{layers.Count}.0.weight" ) )
            layers.Add( LoadBlock( tensors, $"{prefix}{layers.Count}." ) );
        if ( layers.Count == 0 )
            throw new FormatException( $"RigNet checkpoint has no MLP blocks under '{prefix}'." );
        return new Mlp { Layers = layers.ToArray() };
    }

    /// <summary>Returns <c>"state_dict."</c> if any key starts with it (ordinal), else <c>""</c>.</summary>
    static string DetectPrefix( IReadOnlyDictionary<string, Tensor> tensors )
        => tensors.Keys.Any( k => k.StartsWith( "state_dict.", StringComparison.Ordinal ) )
            ? "state_dict."
            : "";

    /// <summary>One Linear(+folded BN) block: `{prefix}0.*` Linear, `{prefix}2.*` BatchNorm.</summary>
    /// <remarks>Index 1 carries no tensors (presumably the ReLU — confirm against the port notes).</remarks>
    static LinearReluBn LoadBlock( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
        => LoadFlatBlock( tensors, prefix, linear: 0, bn: 2 );

    /// <summary>Looks up <paramref name="key"/>, throwing <see cref="FormatException"/> naming it if absent.</summary>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"RigNet checkpoint is missing tensor '{key}'." );
}