AutoRig/Dl/UniRig/UniRigLoader.cs

Loader for a UniRig skeleton checkpoint. It reads named Tensor entries from a dictionary and binds them into C# model objects: a Perceiver encoder, an OPT decoder (multiple layers) and the output projection tensors.

File Access
namespace AutoRig.Dl.UniRig;

using AutoRig.Dl;

/// <summary>
/// The UniRig skeleton-stage networks after loading: one container that groups
/// the mesh encoder, the output projection tensors and the decoder.
/// </summary>
/// <remarks>
/// Every member is <c>required</c>, so an object initializer has to assign all
/// four of them.
/// </remarks>
public sealed class UniRigSkeletonModel
{
    /// <summary>Perceiver encoder used for the mesh input.</summary>
    public required PerceiverEncoder MeshEncoder;

    // Output projection, 512 → 1024.

    /// <summary>Weight tensor of the output projection.</summary>
    public required Tensor OutputProjWeight;

    /// <summary>Bias tensor of the output projection.</summary>
    public required Tensor OutputProjBias;

    /// <summary>OPT decoder.</summary>
    public required OptDecoder Decoder;
}

/// <summary>
/// Binds the UniRig skeleton checkpoint (Lightning zip-torch, keys verified in
/// dev/unirig-reference/skeleton_keys.json) onto the C# modules: 24-layer OPT
/// decoder + Michelangelo perceiver + the latent output projection.
/// </summary>
/// <remarks>
/// The number of decoder layers and perceiver self-attention blocks is not
/// hard-coded: it is discovered by probing the checkpoint for consecutive
/// indices starting at 0. Every structural problem with the checkpoint (a
/// missing tensor, no layers at all, or a gap in the layer indices) is reported
/// as a <see cref="FormatException"/>.
/// </remarks>
public static class UniRigLoader
{
    /// <summary>Builds the complete skeleton-stage model from a checkpoint's named tensors.</summary>
    /// <param name="tensors">
    /// Checkpoint tensors by key. If any key starts with <c>state_dict.</c>, that
    /// prefix is prepended to every lookup; otherwise keys are used bare.
    /// </param>
    /// <returns>The mesh encoder, output projection tensors and OPT decoder bound to the checkpoint.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is <c>null</c>.</exception>
    /// <exception cref="FormatException">
    /// A required tensor is missing, no layers/blocks were found, or the
    /// layer/block indices are not contiguous.
    /// </exception>
    public static UniRigSkeletonModel LoadSkeleton( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        var prefix = tensors.Keys.Any( k => k.StartsWith( "state_dict.", StringComparison.Ordinal ) )
            ? "state_dict."
            : "";

        return new UniRigSkeletonModel
        {
            MeshEncoder = LoadPerceiver( tensors, $"{prefix}model.mesh_encoder.encoder." ),
            OutputProjWeight = Get( tensors, $"{prefix}model.output_proj.weight" ),
            OutputProjBias = Get( tensors, $"{prefix}model.output_proj.bias" ),
            Decoder = LoadOpt( tensors, $"{prefix}model.transformer." ),
        };
    }

    /// <summary>Binds the OPT decoder whose keys live under <paramref name="prefix"/>.</summary>
    /// <param name="tensors">Checkpoint tensors by key.</param>
    /// <param name="prefix">
    /// Key prefix of the causal-LM wrapper; the decoder itself is read from
    /// <c>{prefix}model.decoder.</c> and the LM head from <c>{prefix}lm_head.weight</c>.
    /// </param>
    /// <exception cref="FormatException">
    /// No layer was found, a layer index is skipped, or a required tensor is missing.
    /// </exception>
    static OptDecoder LoadOpt( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        const string probe = "self_attn.q_proj.weight";
        var decoder = $"{prefix}model.decoder.";
        var layerRoot = $"{decoder}layers.";

        // A layer exists when its q_proj weight does; stop at the first missing index.
        var layers = new List<OptLayer>();
        while ( tensors.ContainsKey( $"{layerRoot}{layers.Count}.{probe}" ) )
            layers.Add( LoadOptLayer( tensors, $"{layerRoot}{layers.Count}." ) );

        if ( layers.Count == 0 )
            throw new FormatException( "UniRig checkpoint: no OPT decoder layers found." );

        // Without this, a checkpoint with a hole in its layer numbering would
        // load as a silently truncated decoder.
        EnsureContiguous( tensors, layerRoot, probe, layers.Count, "OPT decoder layer" );

        return new OptDecoder
        {
            TokenEmbedding = Get( tensors, $"{decoder}embed_tokens.weight" ),
            PositionEmbedding = Get( tensors, $"{decoder}embed_positions.weight" ),
            Layers = layers.ToArray(),
            FinalLnGamma = Get( tensors, $"{decoder}final_layer_norm.weight" ),
            FinalLnBeta = Get( tensors, $"{decoder}final_layer_norm.bias" ),
            LmHeadWeight = Get( tensors, $"{prefix}lm_head.weight" ),
            Heads = 16,   // OPT-350m family
        };
    }

    /// <summary>Binds one OPT decoder layer from the keys under <paramref name="l"/> (ends with a dot).</summary>
    static OptLayer LoadOptLayer( IReadOnlyDictionary<string, Tensor> tensors, string l ) => new()
    {
        QWeight = Get( tensors, $"{l}self_attn.q_proj.weight" ),
        QBias = Get( tensors, $"{l}self_attn.q_proj.bias" ),
        KWeight = Get( tensors, $"{l}self_attn.k_proj.weight" ),
        KBias = Get( tensors, $"{l}self_attn.k_proj.bias" ),
        VWeight = Get( tensors, $"{l}self_attn.v_proj.weight" ),
        VBias = Get( tensors, $"{l}self_attn.v_proj.bias" ),
        OutWeight = Get( tensors, $"{l}self_attn.out_proj.weight" ),
        OutBias = Get( tensors, $"{l}self_attn.out_proj.bias" ),
        AttnLnGamma = Get( tensors, $"{l}self_attn_layer_norm.weight" ),
        AttnLnBeta = Get( tensors, $"{l}self_attn_layer_norm.bias" ),
        FfnLnGamma = Get( tensors, $"{l}final_layer_norm.weight" ),
        FfnLnBeta = Get( tensors, $"{l}final_layer_norm.bias" ),
        Fc1Weight = Get( tensors, $"{l}fc1.weight" ),
        Fc1Bias = Get( tensors, $"{l}fc1.bias" ),
        Fc2Weight = Get( tensors, $"{l}fc2.weight" ),
        Fc2Bias = Get( tensors, $"{l}fc2.bias" ),
    };

    /// <summary>Binds the perceiver encoder whose keys live under <paramref name="prefix"/>.</summary>
    /// <param name="tensors">Checkpoint tensors by key.</param>
    /// <param name="prefix">Key prefix of the encoder, ending with a dot.</param>
    /// <param name="heads">Attention head count stored on the returned encoder; defaults to 8.</param>
    /// <exception cref="FormatException">
    /// No self-attention block was found, a block index is skipped, or a required tensor is missing.
    /// </exception>
    internal static PerceiverEncoder LoadPerceiver(
        IReadOnlyDictionary<string, Tensor> tensors, string prefix, int heads = 8 )
    {
        const string probe = "attn.c_qkv.weight";
        var blockRoot = $"{prefix}self_attn.resblocks.";

        // A block exists when its fused qkv weight does; stop at the first missing index.
        var blocks = new List<PerceiverSelfBlock>();
        while ( tensors.ContainsKey( $"{blockRoot}{blocks.Count}.{probe}" ) )
            blocks.Add( LoadSelfBlock( tensors, $"{blockRoot}{blocks.Count}." ) );

        if ( blocks.Count == 0 )
            throw new FormatException( "UniRig checkpoint: no perceiver self-attention blocks found." );

        // Same guard as for the decoder: refuse a checkpoint with skipped block indices.
        EnsureContiguous( tensors, blockRoot, probe, blocks.Count, "perceiver self-attention block" );

        return new PerceiverEncoder
        {
            Heads = heads,
            InputProjWeight = Get( tensors, $"{prefix}input_proj.weight" ),
            InputProjBias = Get( tensors, $"{prefix}input_proj.bias" ),
            Cross = LoadCrossBlock( tensors, $"{prefix}cross_attn." ),
            SelfBlocks = blocks.ToArray(),
            LnPostGamma = Get( tensors, $"{prefix}ln_post.weight" ),
            LnPostBeta = Get( tensors, $"{prefix}ln_post.bias" ),
        };
    }

    /// <summary>Binds one perceiver self-attention block from the keys under <paramref name="b"/> (ends with a dot).</summary>
    static PerceiverSelfBlock LoadSelfBlock( IReadOnlyDictionary<string, Tensor> tensors, string b ) => new()
    {
        QkvWeight = Get( tensors, $"{b}attn.c_qkv.weight" ),
        ProjWeight = Get( tensors, $"{b}attn.c_proj.weight" ),
        ProjBias = Get( tensors, $"{b}attn.c_proj.bias" ),
        Ln1Gamma = Get( tensors, $"{b}ln_1.weight" ),
        Ln1Beta = Get( tensors, $"{b}ln_1.bias" ),
        Ln2Gamma = Get( tensors, $"{b}ln_2.weight" ),
        Ln2Beta = Get( tensors, $"{b}ln_2.bias" ),
        FcWeight = Get( tensors, $"{b}mlp.c_fc.weight" ),
        FcBias = Get( tensors, $"{b}mlp.c_fc.bias" ),
        Fc2Weight = Get( tensors, $"{b}mlp.c_proj.weight" ),
        Fc2Bias = Get( tensors, $"{b}mlp.c_proj.bias" ),
    };

    /// <summary>Binds the perceiver cross-attention block from the keys under <paramref name="cross"/> (ends with a dot).</summary>
    static PerceiverCrossBlock LoadCrossBlock( IReadOnlyDictionary<string, Tensor> tensors, string cross ) => new()
    {
        QWeight = Get( tensors, $"{cross}attn.c_q.weight" ),
        KvWeight = Get( tensors, $"{cross}attn.c_kv.weight" ),
        ProjWeight = Get( tensors, $"{cross}attn.c_proj.weight" ),
        ProjBias = Get( tensors, $"{cross}attn.c_proj.bias" ),
        Ln1Gamma = Get( tensors, $"{cross}ln_1.weight" ),
        Ln1Beta = Get( tensors, $"{cross}ln_1.bias" ),
        Ln2Gamma = Get( tensors, $"{cross}ln_2.weight" ),
        Ln2Beta = Get( tensors, $"{cross}ln_2.bias" ),
        Ln3Gamma = Get( tensors, $"{cross}ln_3.weight" ),
        Ln3Beta = Get( tensors, $"{cross}ln_3.bias" ),
        FcWeight = Get( tensors, $"{cross}mlp.c_fc.weight" ),
        FcBias = Get( tensors, $"{cross}mlp.c_fc.bias" ),
        Fc2Weight = Get( tensors, $"{cross}mlp.c_proj.weight" ),
        Fc2Bias = Get( tensors, $"{cross}mlp.c_proj.bias" ),
    };

    /// <summary>
    /// Verifies that no key of the form <c>{root}{index}.…</c> has an index at or
    /// beyond <paramref name="count"/>, i.e. that counting stopped because the
    /// entries ran out and not because one index was skipped or incomplete.
    /// </summary>
    /// <param name="tensors">Checkpoint tensors by key.</param>
    /// <param name="root">Key prefix up to and including the dot before the index.</param>
    /// <param name="probeSuffix">Per-entry key suffix that was used to detect an entry; only used in the message.</param>
    /// <param name="count">Number of consecutive entries that were loaded (indices 0 .. count-1).</param>
    /// <param name="what">Human-readable name of the entry kind; only used in the message.</param>
    /// <exception cref="FormatException">A key with an index &gt;= <paramref name="count"/> exists.</exception>
    static void EnsureContiguous(
        IReadOnlyDictionary<string, Tensor> tensors, string root, string probeSuffix, int count, string what )
    {
        foreach ( var key in tensors.Keys )
        {
            if ( !key.StartsWith( root, StringComparison.Ordinal ) )
                continue;

            // The index is the segment between the root and the next dot.
            var rest = key.AsSpan( root.Length );
            var dot = rest.IndexOf( '.' );
            if ( dot <= 0 )
                continue;

            // Digits only, culture-independent; non-numeric segments are not indexed entries.
            var isIndex = int.TryParse(
                rest[..dot],
                System.Globalization.NumberStyles.None,
                System.Globalization.CultureInfo.InvariantCulture,
                out var index );
            if ( isIndex && index >= count )
                throw new FormatException(
                    $"UniRig checkpoint: found '{key}' but {what} {count} has no '{probeSuffix}'; "
                    + $"{what} indices must be contiguous from 0." );
        }
    }

    /// <summary>Looks up a required tensor.</summary>
    /// <param name="tensors">Checkpoint tensors by key.</param>
    /// <param name="key">Exact key to fetch.</param>
    /// <returns>The tensor stored under <paramref name="key"/>.</returns>
    /// <exception cref="FormatException">The checkpoint has no entry for <paramref name="key"/>.</exception>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"UniRig checkpoint is missing tensor '{key}'." );
}