Code/AutoRig/Dl/Anymate/AnymateLoader.cs

Loader that binds Anymate model checkpoint tensors into a runtime AnymateModel structure. It inspects provided dictionary keys for expected encoder/decoder blocks, assembles PointBert and Perceiver block arrays, and throws FormatException on missing tensors or unexpected layout.

using AutoRig.Dl.UniRig;

namespace AutoRig.Dl.Anymate;

using AutoRig.Dl;

/// <summary>Binds the three Anymate checkpoints (state_dict.* keys verified in
/// dev/anymate-reference/keys_*.json).</summary>
/// <remarks>
/// Every tensor is looked up by its exact key; a missing key raises
/// <see cref="FormatException"/> naming that key. Repeated decoder/encoder blocks are
/// discovered by probing consecutive indices starting at 0 until a probe key is absent.
/// </remarks>
public static class AnymateLoader
{
    /// <summary>
    /// Builds an <see cref="AnymateModel"/> from the joint, connectivity and skinning
    /// checkpoint dictionaries.
    /// </summary>
    /// <param name="joint">Tensors of the joint checkpoint, keyed by name.</param>
    /// <param name="conn">Tensors of the connectivity checkpoint, keyed by name.</param>
    /// <param name="skin">Tensors of the skinning checkpoint, keyed by name.</param>
    /// <returns>The model with every field bound to a tensor from the matching dictionary.</returns>
    /// <exception cref="ArgumentNullException">Any of the three dictionaries is null.</exception>
    /// <exception cref="FormatException">
    /// A required tensor is missing, no ViT blocks are present in an encoder, or the
    /// decoder block counts do not match the layout checked below.
    /// </exception>
    public static AnymateModel Load(
        IReadOnlyDictionary<string, Tensor> joint,
        IReadOnlyDictionary<string, Tensor> conn,
        IReadOnlyDictionary<string, Tensor> skin )
    {
        // Fail up front with the offending parameter name instead of a
        // NullReferenceException from deep inside one of the binders.
        if ( joint is null ) throw new ArgumentNullException( nameof( joint ) );
        if ( conn is null ) throw new ArgumentNullException( nameof( conn ) );
        if ( skin is null ) throw new ArgumentNullException( nameof( skin ) );

        // Variable-length decoder stacks; order of discovery matches the order of validation below.
        var connCross = CollectBlocks<PerceiverCrossBlock>( conn, "decoder.cross_attn.", "attn.c_q.weight", Cross );
        var connSelf = CollectBlocks<PerceiverSelfBlock>( conn, "decoder.self_attn.", "attn.c_qkv.weight", Self );
        var skinCross = CollectBlocks<PerceiverCrossBlock>( skin, "decoder.cross_attn.", "attn.c_q.weight", Cross );
        var skinCrossJoint = CollectBlocks<PerceiverCrossBlock>( skin, "decoder.cross_attn_joint.", "attn.c_q.weight", Cross );

        // Accepted layout: at least one cross block per decoder, exactly two conn self
        // blocks per conn cross block, and one skin joint-cross block per skin cross block.
        if ( connCross.Length == 0 || connSelf.Length != connCross.Length * 2
            || skinCross.Length == 0 || skinCrossJoint.Length != skinCross.Length )
            throw new FormatException( "Anymate checkpoints: unexpected decoder layout." );

        return new AnymateModel
        {
            JointEncoder = Bert( joint ),
            ConnEncoder = Bert( conn ),
            SkinEncoder = Bert( skin ),
            JointQuery = Get( joint, "decoder.query" ),
            JointCross = Cross( joint, "decoder.cross_attn_decoder." ),
            JointLnPostG = Get( joint, "decoder.ln_post.weight" ),
            JointLnPostB = Get( joint, "decoder.ln_post.bias" ),
            JointOutW = Get( joint, "decoder.output_proj.weight" ),
            JointOutB = Get( joint, "decoder.output_proj.bias" ),
            ConnCoProjW = Get( conn, "decoder.co_proj.weight" ),
            ConnCoProjB = Get( conn, "decoder.co_proj.bias" ),
            ConnCross = connCross,
            ConnSelf = connSelf,
            ConnKProjW = Get( conn, "decoder.k_proj.weight" ),
            ConnKProjB = Get( conn, "decoder.k_proj.bias" ),
            ConnQProjW = Get( conn, "decoder.q_proj.weight" ),
            ConnQProjB = Get( conn, "decoder.q_proj.bias" ),
            ConnLn1G = Get( conn, "decoder.ln_1.weight" ),
            ConnLn1B = Get( conn, "decoder.ln_1.bias" ),
            ConnLn2G = Get( conn, "decoder.ln_2.weight" ),
            ConnLn2B = Get( conn, "decoder.ln_2.bias" ),
            SkinCoProjW = Get( skin, "decoder.co_proj.weight" ),
            SkinCoProjB = Get( skin, "decoder.co_proj.bias" ),
            SkinBoneProjW = Get( skin, "decoder.bone_proj.weight" ),
            SkinBoneProjB = Get( skin, "decoder.bone_proj.bias" ),
            SkinNormalProjW = Get( skin, "decoder.normal_proj.weight" ),
            SkinNormalProjB = Get( skin, "decoder.normal_proj.bias" ),
            SkinCross = skinCross,
            SkinCrossJoint = skinCrossJoint,
            SkinKProjW = Get( skin, "decoder.k_proj.weight" ),
            SkinKProjB = Get( skin, "decoder.k_proj.bias" ),
            SkinQProjW = Get( skin, "decoder.q_proj.weight" ),
            SkinQProjB = Get( skin, "decoder.q_proj.bias" ),
            SkinLn1G = Get( skin, "decoder.ln_1.weight" ),
            SkinLn1B = Get( skin, "decoder.ln_1.bias" ),
            SkinLn2G = Get( skin, "decoder.ln_2.weight" ),
            SkinLn2B = Get( skin, "decoder.ln_2.bias" ),
        };
    }

    /// <summary>
    /// Binds consecutive indexed blocks <c>{stem}0.</c>, <c>{stem}1.</c>, ... for as long as
    /// the key <c>{stem}{index}.{probeSuffix}</c> is present in <paramref name="t"/>.
    /// </summary>
    /// <param name="t">Checkpoint dictionary to probe and read from.</param>
    /// <param name="stem">Key prefix up to (and including) the dot before the block index.</param>
    /// <param name="probeSuffix">Key suffix whose presence marks a block as existing.</param>
    /// <param name="bind">Binder invoked with the dictionary and the block prefix <c>{stem}{index}.</c>.</param>
    /// <returns>The bound blocks in index order; empty when index 0 is absent.</returns>
    static T[] CollectBlocks<T>(
        IReadOnlyDictionary<string, Tensor> t,
        string stem,
        string probeSuffix,
        Func<IReadOnlyDictionary<string, Tensor>, string, T> bind )
    {
        var blocks = new List<T>();
        while ( t.ContainsKey( $"{stem}{blocks.Count}.{probeSuffix}" ) )
            blocks.Add( bind( t, $"{stem}{blocks.Count}." ) );
        return blocks.ToArray();
    }

    /// <summary>
    /// Binds the <c>encoder.*</c> and <c>point_proj.*</c> tensors of one checkpoint into a
    /// <see cref="PointBert"/>.
    /// </summary>
    /// <param name="t">Checkpoint dictionary to read from.</param>
    /// <returns>The bound encoder, including all discovered ViT blocks.</returns>
    /// <exception cref="FormatException">No ViT block was found, or a required tensor is missing.</exception>
    static PointBert Bert( IReadOnlyDictionary<string, Tensor> t )
    {
        var blocks = new List<PointBert.VitBlock>();
        for ( var i = 0; ; i++ )
        {
            // Two key layouts are accepted per block index; the nested
            // "encoder.blocks.blocks.N." form wins when both are present.
            var nested = $"encoder.blocks.blocks.{i}.";
            var flat = $"encoder.blocks.{i}.";
            string b;
            if ( t.ContainsKey( $"{nested}norm1.weight" ) )
                b = nested;
            else if ( t.ContainsKey( $"{flat}norm1.weight" ) )
                b = flat;
            else
                break;

            blocks.Add( new PointBert.VitBlock
            {
                Norm1G = Get( t, $"{b}norm1.weight" ),
                Norm1B = Get( t, $"{b}norm1.bias" ),
                QkvW = Get( t, $"{b}attn.qkv.weight" ),
                ProjW = Get( t, $"{b}attn.proj.weight" ),
                ProjB = Get( t, $"{b}attn.proj.bias" ),
                Norm2G = Get( t, $"{b}norm2.weight" ),
                Norm2B = Get( t, $"{b}norm2.bias" ),
                Fc1W = Get( t, $"{b}mlp.fc1.weight" ),
                Fc1B = Get( t, $"{b}mlp.fc1.bias" ),
                Fc2W = Get( t, $"{b}mlp.fc2.weight" ),
                Fc2B = Get( t, $"{b}mlp.fc2.bias" ),
            } );
        }
        if ( blocks.Count == 0 )
            throw new FormatException( "Anymate checkpoint: no ViT blocks found." );

        return new PointBert
        {
            FirstConv0W = Get( t, "encoder.encoder.first_conv.0.weight" ),
            FirstConv0B = Get( t, "encoder.encoder.first_conv.0.bias" ),
            FirstBnMean = Get( t, "encoder.encoder.first_conv.1.running_mean" ),
            FirstBnVar = Get( t, "encoder.encoder.first_conv.1.running_var" ),
            FirstBnG = Get( t, "encoder.encoder.first_conv.1.weight" ),
            FirstBnB = Get( t, "encoder.encoder.first_conv.1.bias" ),
            FirstConv3W = Get( t, "encoder.encoder.first_conv.3.weight" ),
            FirstConv3B = Get( t, "encoder.encoder.first_conv.3.bias" ),
            SecondConv0W = Get( t, "encoder.encoder.second_conv.0.weight" ),
            SecondConv0B = Get( t, "encoder.encoder.second_conv.0.bias" ),
            SecondBnMean = Get( t, "encoder.encoder.second_conv.1.running_mean" ),
            SecondBnVar = Get( t, "encoder.encoder.second_conv.1.running_var" ),
            SecondBnG = Get( t, "encoder.encoder.second_conv.1.weight" ),
            SecondBnB = Get( t, "encoder.encoder.second_conv.1.bias" ),
            SecondConv3W = Get( t, "encoder.encoder.second_conv.3.weight" ),
            SecondConv3B = Get( t, "encoder.encoder.second_conv.3.bias" ),
            ReduceDimW = Get( t, "encoder.reduce_dim.weight" ),
            ReduceDimB = Get( t, "encoder.reduce_dim.bias" ),
            ClsToken = Get( t, "encoder.cls_token" ),
            ClsPos = Get( t, "encoder.cls_pos" ),
            PosEmbed0W = Get( t, "encoder.pos_embed.0.weight" ),
            PosEmbed0B = Get( t, "encoder.pos_embed.0.bias" ),
            PosEmbed2W = Get( t, "encoder.pos_embed.2.weight" ),
            PosEmbed2B = Get( t, "encoder.pos_embed.2.bias" ),
            Blocks = blocks.ToArray(),
            NormG = Get( t, "encoder.norm.weight" ),
            NormB = Get( t, "encoder.norm.bias" ),
            Proj0W = Get( t, "point_proj.0.weight" ),
            Proj0B = Get( t, "point_proj.0.bias" ),
            Proj2W = Get( t, "point_proj.2.weight" ),
            Proj2B = Get( t, "point_proj.2.bias" ),
            Proj4W = Get( t, "point_proj.4.weight" ),
            Proj4B = Get( t, "point_proj.4.bias" ),
        };
    }

    /// <summary>Binds one cross-attention block whose keys start with <paramref name="b"/>.</summary>
    /// <param name="t">Checkpoint dictionary to read from.</param>
    /// <param name="b">Block key prefix, including the trailing dot.</param>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    static PerceiverCrossBlock Cross( IReadOnlyDictionary<string, Tensor> t, string b ) => new()
    {
        QWeight = Get( t, $"{b}attn.c_q.weight" ),
        KvWeight = Get( t, $"{b}attn.c_kv.weight" ),
        ProjWeight = Get( t, $"{b}attn.c_proj.weight" ),
        ProjBias = Get( t, $"{b}attn.c_proj.bias" ),
        Ln1Gamma = Get( t, $"{b}ln_1.weight" ),
        Ln1Beta = Get( t, $"{b}ln_1.bias" ),
        Ln2Gamma = Get( t, $"{b}ln_2.weight" ),
        Ln2Beta = Get( t, $"{b}ln_2.bias" ),
        Ln3Gamma = Get( t, $"{b}ln_3.weight" ),
        Ln3Beta = Get( t, $"{b}ln_3.bias" ),
        FcWeight = Get( t, $"{b}mlp.c_fc.weight" ),
        FcBias = Get( t, $"{b}mlp.c_fc.bias" ),
        Fc2Weight = Get( t, $"{b}mlp.c_proj.weight" ),
        Fc2Bias = Get( t, $"{b}mlp.c_proj.bias" ),
    };

    /// <summary>Binds one self-attention block whose keys start with <paramref name="b"/>.</summary>
    /// <param name="t">Checkpoint dictionary to read from.</param>
    /// <param name="b">Block key prefix, including the trailing dot.</param>
    /// <exception cref="FormatException">A required tensor is missing.</exception>
    static PerceiverSelfBlock Self( IReadOnlyDictionary<string, Tensor> t, string b ) => new()
    {
        QkvWeight = Get( t, $"{b}attn.c_qkv.weight" ),
        ProjWeight = Get( t, $"{b}attn.c_proj.weight" ),
        ProjBias = Get( t, $"{b}attn.c_proj.bias" ),
        Ln1Gamma = Get( t, $"{b}ln_1.weight" ),
        Ln1Beta = Get( t, $"{b}ln_1.bias" ),
        Ln2Gamma = Get( t, $"{b}ln_2.weight" ),
        Ln2Beta = Get( t, $"{b}ln_2.bias" ),
        FcWeight = Get( t, $"{b}mlp.c_fc.weight" ),
        FcBias = Get( t, $"{b}mlp.c_fc.bias" ),
        Fc2Weight = Get( t, $"{b}mlp.c_proj.weight" ),
        Fc2Bias = Get( t, $"{b}mlp.c_proj.bias" ),
    };

    /// <summary>Returns the tensor stored under <paramref name="key"/>.</summary>
    /// <param name="t">Checkpoint dictionary to read from.</param>
    /// <param name="key">Exact tensor key.</param>
    /// <exception cref="FormatException">The key is not present; the message names the key.</exception>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> t, string key )
        => t.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"Anymate checkpoint is missing tensor '{key}'." );
}