Code/AutoRig/Dl/RigAnything/RigAnythingLoader.cs

Loader for the RigAnything model checkpoint. It reads named tensors from a provided dictionary and assembles a RigAnythingModel, parsing transformer blocks and diffusion residual blocks and mapping weights/biases into model structures.

File Access
namespace AutoRig.Dl.RigAnything;

using AutoRig.Dl;

/// <summary>
/// Assembles a <see cref="RigAnythingModel"/> from the named tensors of
/// riganything_ckpt.pt. Key names were verified against
/// dev/riganything-reference/ckpt_keys.json (123 f32 tensors under model.*).
/// Every transformer LayerNorm is γ-only (bias=False), so the matching beta
/// fields are deliberately left null.
/// </summary>
public static class RigAnythingLoader
{
    /// <summary>Key prefix applied when at least one tensor name starts with it.</summary>
    const string ModelPrefix = "model.";

    /// <summary>
    /// Builds the full model from a name → tensor map.
    /// </summary>
    /// <param name="tensors">Checkpoint tensors keyed by their checkpoint names,
    /// either bare or prefixed with "model.".</param>
    /// <returns>A populated <see cref="RigAnythingModel"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing, or no
    /// transformer / diffusion residual blocks are present.</exception>
    public static RigAnythingModel LoadModel( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        // A single "model."-prefixed key switches every lookup to the prefixed form.
        var hasModelPrefix = tensors.Keys.Any( key => key.StartsWith( ModelPrefix, StringComparison.Ordinal ) );
        var prefix = hasModelPrefix ? ModelPrefix : "";

        Tensor Fetch( string name ) => Get( tensors, prefix + name );

        var transformerBlocks = ReadTransformerBlocks( tensors, prefix );
        var diffusionResBlocks = ReadDiffusionResBlocks( tensors, prefix );

        // Looked up before the model object is created, so a missing start token
        // is reported ahead of any other top-level tensor.
        var startToken = Fetch( "start_token" );

        return new RigAnythingModel
        {
            PcTok0 = Fetch( "pc_tokenizer.0.weight" ),
            PcTok2 = Fetch( "pc_tokenizer.2.weight" ),
            JointTok0 = Fetch( "joint_tokenizer.0.weight" ),
            JointTok2 = Fetch( "joint_tokenizer.2.weight" ),
            JointMlp0 = Fetch( "joint_mlp.0.weight" ),
            JointMlp2 = Fetch( "joint_mlp.2.weight" ),
            FuseMlp0 = Fetch( "joint_fuse_mlp.0.weight" ),
            FuseMlp2 = Fetch( "joint_fuse_mlp.2.weight" ),
            SkinMlp0 = Fetch( "skinning_mlp.0.weight" ),
            SkinMlp2 = Fetch( "skinning_mlp.2.weight" ),
            ParentLnGamma = Fetch( "parents_decoder.0.weight" ),
            ParentLnBeta = null,
            Parent1 = Fetch( "parents_decoder.1.weight" ),
            Parent3 = Fetch( "parents_decoder.3.weight" ),
            // Rebuilt from the raw data with arguments (1, D); presumably the
            // target shape — confirm against Tensor.From.
            StartToken = Tensor.From( startToken.Data, 1, RigAnythingModel.D ),
            InputLnGamma = Fetch( "transformer_input_layernorm.weight" ),
            InputLnBeta = null,
            Blocks = transformerBlocks,
            Diffusion = ReadDiffusion( tensors, prefix, diffusionResBlocks ),
        };
    }

    /// <summary>
    /// Collects transformer blocks for indices 0, 1, 2, … until the block's
    /// attn.to_qkv.weight key is absent. Throws when none are found.
    /// </summary>
    static RigAnythingBlock[] ReadTransformerBlocks( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        var collected = new List<RigAnythingBlock>();
        for ( var index = 0; ; index++ )
        {
            var scope = $"{prefix}transformer.{index}.";
            if ( !tensors.ContainsKey( scope + "attn.to_qkv.weight" ) )
                break;

            Tensor Part( string leaf ) => Get( tensors, scope + leaf );

            collected.Add( new RigAnythingBlock
            {
                Norm1Gamma = Part( "norm1.weight" ),
                Norm1Beta = null,
                QkvWeight = Part( "attn.to_qkv.weight" ),
                FcWeight = Part( "attn.fc.weight" ),
                Norm2Gamma = Part( "norm2.weight" ),
                Norm2Beta = null,
                Mlp0Weight = Part( "mlp.mlp.0.weight" ),
                Mlp2Weight = Part( "mlp.mlp.2.weight" ),
            } );
        }

        if ( collected.Count == 0 )
            throw new FormatException( "RigAnything checkpoint: no transformer blocks found." );

        return collected.ToArray();
    }

    /// <summary>
    /// Collects diffusion residual blocks for indices 0, 1, 2, … until the
    /// block's in_ln.weight key is absent. Throws when none are found.
    /// </summary>
    static RigAnythingDiffusion.DiffResBlock[] ReadDiffusionResBlocks( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        var collected = new List<RigAnythingDiffusion.DiffResBlock>();
        for ( var index = 0; ; index++ )
        {
            var scope = $"{prefix}diffloss.net.res_blocks.{index}.";
            if ( !tensors.ContainsKey( scope + "in_ln.weight" ) )
                break;

            Tensor Part( string leaf ) => Get( tensors, scope + leaf );

            collected.Add( new RigAnythingDiffusion.DiffResBlock
            {
                InLnGamma = Part( "in_ln.weight" ),
                InLnBeta = Part( "in_ln.bias" ),
                Mlp0Weight = Part( "mlp.0.weight" ),
                Mlp0Bias = Part( "mlp.0.bias" ),
                Mlp2Weight = Part( "mlp.2.weight" ),
                Mlp2Bias = Part( "mlp.2.bias" ),
                AdaWeight = Part( "adaLN_modulation.1.weight" ),
                AdaBias = Part( "adaLN_modulation.1.bias" ),
            } );
        }

        if ( collected.Count == 0 )
            throw new FormatException( "RigAnything checkpoint: no diffusion res blocks found." );

        return collected.ToArray();
    }

    /// <summary>
    /// Binds the diffloss.net.* tensors (input projection, time/cond embeddings,
    /// final layer) around the already-loaded residual blocks.
    /// </summary>
    static RigAnythingDiffusion ReadDiffusion(
        IReadOnlyDictionary<string, Tensor> tensors,
        string prefix,
        RigAnythingDiffusion.DiffResBlock[] resBlocks )
    {
        var scope = prefix + "diffloss.net.";

        Tensor Part( string leaf ) => Get( tensors, scope + leaf );

        return new RigAnythingDiffusion
        {
            InputProjWeight = Part( "input_proj.weight" ),
            InputProjBias = Part( "input_proj.bias" ),
            Time0Weight = Part( "time_embed.mlp.0.weight" ),
            Time0Bias = Part( "time_embed.mlp.0.bias" ),
            Time2Weight = Part( "time_embed.mlp.2.weight" ),
            Time2Bias = Part( "time_embed.mlp.2.bias" ),
            CondWeight = Part( "cond_embed.weight" ),
            CondBias = Part( "cond_embed.bias" ),
            ResBlocks = resBlocks,
            FinalAdaWeight = Part( "final_layer.adaLN_modulation.1.weight" ),
            FinalAdaBias = Part( "final_layer.adaLN_modulation.1.bias" ),
            FinalLinWeight = Part( "final_layer.linear.weight" ),
            FinalLinBias = Part( "final_layer.linear.bias" ),
        };
    }

    /// <summary>
    /// Returns the tensor stored under <paramref name="key"/>, or throws a
    /// <see cref="FormatException"/> naming the missing key.
    /// </summary>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
    {
        if ( tensors.TryGetValue( key, out var found ) )
            return found;

        throw new FormatException( $"RigAnything checkpoint is missing tensor '{key}'." );
    }
}