AutoRig/Dl/MagicArticulate/MagicArticulateLoader.cs

Loader and model definition for MagicArticulate, which conditions an OPT decoder on a ShapeVAE point-cloud encoding. MagicArticulateModel holds the network's component tensors and provides EncodePrefix, which builds the decoder prefix from the point-cloud encoder output. MagicArticulateLoader constructs the model from a dictionary of named Tensor objects, loading the encoder, the VAE decode blocks, the OPT decoder layers, and the projection and embedding tables.

File AccessNetworking
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;

namespace AutoRig.Dl.MagicArticulate;

using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Network weights for MagicArticulate (Seed3D, Apache-2.0) as bound from the hier checkpoint.
/// It has four parts:
/// <list type="bullet">
/// <item>The ORIGINAL Michelangelo ShapeVAE-256 encoder: 257 LEARNED queries over all 8192
/// points, 12 heads, 8 encoder blocks.</item>
/// <item>A KL bottleneck (mean only) followed by a 16-block decode stack.</item>
/// <item>Two condition projections that together produce a 257×1024 decoder prefix.</item>
/// <item>A POST-LN opt-350m decoder with bone-position and cond-flag embedding tables.</item>
/// </list>
/// Vocabulary (131 ids): bos 0, eos 1, pad 2. Coordinate bins are id − 3, in 128 steps
/// over (−0.5, 0.5).
/// </summary>
public sealed class MagicArticulateModel
{
    /// <summary>Number of coordinate quantization bins.</summary>
    public const int NumDiscrete = 128;

    /// <summary>Beginning-of-sequence token id.</summary>
    public const int TokenBos = 0;
    /// <summary>End-of-sequence token id.</summary>
    public const int TokenEos = 1;
    /// <summary>Padding token id.</summary>
    public const int TokenPad = 2;

    /// <summary>Rows in the conditioning prefix (1 shape token + 256 latents).</summary>
    public const int CondLength = 257;

    /// <summary>Coordinate tokens per bone in the default token layout.</summary>
    public const int BonePerToken = 6;

    // ---- ShapeVAE encoder ----
    public required PerceiverEncoder Encoder;      // learned-query variant, heads 12
    public required Tensor Query;                  // [257, 768] learned queries

    // ---- KL bottleneck + VAE decode stack ----
    public required Tensor PreKlWeight;            // [128, 768]; the mean is the first 64 outputs
    public required Tensor PreKlBias;
    public required Tensor PostKlWeight;           // [768, 64]
    public required Tensor PostKlBias;
    public required PerceiverSelfBlock[] DecodeBlocks; // 16 blocks (VAE decoder stack)

    // ---- Condition projections (768 → 1024) ----
    public required Tensor CondHeadProjWeight;     // applied to the shape-embedding row
    public required Tensor CondHeadProjBias;
    public required Tensor CondProjWeight;         // applied to the decoded latents
    public required Tensor CondProjBias;

    // ---- OPT decoder and its extra embedding tables ----
    public required OptDecoder Decoder;            // PostLn, no final norm
    public required Tensor BoneTable;              // [9, 1024] token_embed_positions
    public required Tensor CondTable;              // [2, 1024] cond_embed

    /// <summary>
    /// Maps a coordinate bin back to a coordinate using
    /// <c>bin / 128 · (hi − lo) + lo</c> with lo = −0.5 and hi = 0.5.
    /// Bins are NOT centered; this matches the reference implementation.
    /// </summary>
    public static float Undiscretize( int bin ) => (float)bin / NumDiscrete - 0.5f;

    /// <summary>
    /// Returns the bone-position embedding row for the k-th generated token.
    /// <paramref name="k"/> is 1-based, and BOS sits at k = 1.
    /// Special tokens (bos/eos/pad) map to their own id. Coordinate tokens cycle through
    /// <c>(k − 2) % N + 3</c>, where N = <paramref name="bonePerToken"/>: 6 by default, and
    /// Puppeteer's joint-token mode uses 4.
    /// </summary>
    public static int BonePositionId( int k, int tokenId, int bonePerToken = BonePerToken )
        => tokenId switch
        {
            TokenBos or TokenEos or TokenPad => tokenId,
            _ => (k - 2) % bonePerToken + 3,
        };

    /// <summary>
    /// Encodes a point cloud into the 257×1024 decoder prefix.
    /// Row 0 is the projected shape embedding and rows 1.. are the projected, VAE-decoded
    /// latents. cond_embed[0] is NOT added here; the solver adds it, mirroring
    /// ShapeOPTDecoder's inputs_embeds path.
    /// </summary>
    public Tensor EncodePrefix( Vector3[] points, Vector3[] normals )
    {
        // Encoder: the learned queries cross-attend to the point features, then pass
        // through the self-attention stack and the ln_post layer norm.
        var pointFeatures = Encoder.Project( points, normals );
        var tokens = Encoder.Cross.Forward( Query, pointFeatures, Encoder.Heads );
        foreach ( var selfBlock in Encoder.SelfBlocks )
            tokens = selfBlock.Forward( tokens, Encoder.Heads );
        tokens = TransformerOps.LayerNorm( tokens, Encoder.LnPostGamma, Encoder.LnPostBeta );

        // Split off the first row (shape embedding) from the remaining latent rows.
        var rowCount = tokens.Shape[0];
        var shapeToken = tokens.Slice( 0, 0, 1 );
        var latentRows = tokens.Slice( 0, 1, rowCount - 1 );

        // Deterministic KL bottleneck: keep only the mean half of pre_kl's output
        // (the DiagonalGaussian mean), then post_kl and the VAE decode stack.
        var moments = Linear( latentRows, PreKlWeight, PreKlBias );
        var mean = moments.Slice( 1, 0, moments.Shape[1] / 2 );
        var decoded = Linear( mean, PostKlWeight, PostKlBias );
        foreach ( var decodeBlock in DecodeBlocks )
            decoded = decodeBlock.Forward( decoded, Encoder.Heads );

        // Project both parts to decoder width and stack them: [shape row ; latent rows].
        var headRow = Linear( shapeToken, CondHeadProjWeight, CondHeadProjBias );
        var latentPrefix = Linear( decoded, CondProjWeight, CondProjBias );
        return headRow.Concat( latentPrefix, 0 );
    }

    /// <summary>Affine map <c>input · weightᵀ + bias</c>, the PyTorch nn.Linear convention.</summary>
    private static Tensor Linear( Tensor input, Tensor weight, Tensor bias )
        => input.MatMul( weight.Transposed ).Add( bias );
}

/// <summary>
/// Binds a <see cref="MagicArticulateModel"/> from a flat map of state-dict names to tensors.
/// Key sets with or without a leading <c>model.</c> wrapper prefix are both accepted.
/// A missing tensor, or an empty block or layer stack, is reported as a
/// <see cref="FormatException"/>.
/// </summary>
public static class MagicArticulateLoader
{
    /// <summary>Wrapper prefix that some checkpoints put in front of every key.</summary>
    const string WrapperPrefix = "model.";

    /// <summary>Attention heads of the ShapeVAE encoder.</summary>
    const int ShapeVaeHeads = 12;

    /// <summary>Attention heads of the opt-350m decoder.</summary>
    const int OptHeads = 16;

    /// <summary>Loads the complete MagicArticulate network.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by state-dict name.</param>
    /// <returns>A model with every component bound.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">
    /// A required tensor is missing, or no VAE decode blocks / OPT layers are present.
    /// </exception>
    public static MagicArticulateModel LoadModel( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        // Decide on the wrapper prefix by probing a key every checkpoint must contain.
        // Treating ANY "model."-prefixed key as evidence lets one stray key redirect
        // every lookup into the wrong namespace.
        var prefix = tensors.ContainsKey( $"{WrapperPrefix}cond_proj.weight" ) ? WrapperPrefix : "";
        var shape = $"{prefix}point_encoder.model.shape_model.";

        // VAE decode stack: transformer.resblocks.0 .. N-1, found by probing ln_1 keys.
        // The argument is evaluated before Add, so Count is the index of the block being loaded.
        var decodeBlocks = new List<PerceiverSelfBlock>();
        while ( tensors.ContainsKey( $"{shape}transformer.resblocks.{decodeBlocks.Count}.ln_1.weight" ) )
            decodeBlocks.Add( SelfBlock( tensors, $"{shape}transformer.resblocks.{decodeBlocks.Count}." ) );
        if ( decodeBlocks.Count == 0 )
            throw new FormatException( "MagicArticulate checkpoint: no VAE decode blocks found." );

        var encoder = UniRigLoader.LoadPerceiver( tensors, $"{shape}encoder.", heads: ShapeVaeHeads );

        var opt = $"{prefix}transformer.";
        var optDecoder = $"{opt}model.decoder.";

        return new MagicArticulateModel
        {
            Encoder = encoder,
            Query = Get( tensors, $"{shape}encoder.query" ),
            PreKlWeight = Get( tensors, $"{shape}pre_kl.weight" ),
            PreKlBias = Get( tensors, $"{shape}pre_kl.bias" ),
            PostKlWeight = Get( tensors, $"{shape}post_kl.weight" ),
            PostKlBias = Get( tensors, $"{shape}post_kl.bias" ),
            DecodeBlocks = decodeBlocks.ToArray(),
            CondHeadProjWeight = Get( tensors, $"{prefix}cond_head_proj.weight" ),
            CondHeadProjBias = Get( tensors, $"{prefix}cond_head_proj.bias" ),
            CondProjWeight = Get( tensors, $"{prefix}cond_proj.weight" ),
            CondProjBias = Get( tensors, $"{prefix}cond_proj.bias" ),
            Decoder = LoadOpt( tensors, opt ),
            BoneTable = Get( tensors, $"{optDecoder}token_embed_positions.weight" ),
            CondTable = Get( tensors, $"{optDecoder}cond_embed.weight" ),
        };
    }

    /// <summary>
    /// Loads the OPT causal-LM decoder found under <paramref name="prefix"/>: the layers, the
    /// token/position embeddings and the LM head.
    /// </summary>
    /// <exception cref="FormatException">No layers were found, or a required tensor is missing.</exception>
    static OptDecoder LoadOpt( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        var d = $"{prefix}model.decoder.";

        // Layers 0 .. N-1, found by probing each layer's q_proj weight.
        var layers = new List<OptLayer>();
        while ( tensors.ContainsKey( $"{d}layers.{layers.Count}.self_attn.q_proj.weight" ) )
        {
            var l = $"{d}layers.{layers.Count}.";
            layers.Add( new OptLayer
            {
                QWeight = Get( tensors, $"{l}self_attn.q_proj.weight" ),
                QBias = Get( tensors, $"{l}self_attn.q_proj.bias" ),
                KWeight = Get( tensors, $"{l}self_attn.k_proj.weight" ),
                KBias = Get( tensors, $"{l}self_attn.k_proj.bias" ),
                VWeight = Get( tensors, $"{l}self_attn.v_proj.weight" ),
                VBias = Get( tensors, $"{l}self_attn.v_proj.bias" ),
                OutWeight = Get( tensors, $"{l}self_attn.out_proj.weight" ),
                OutBias = Get( tensors, $"{l}self_attn.out_proj.bias" ),
                AttnLnGamma = Get( tensors, $"{l}self_attn_layer_norm.weight" ),
                AttnLnBeta = Get( tensors, $"{l}self_attn_layer_norm.bias" ),
                FfnLnGamma = Get( tensors, $"{l}final_layer_norm.weight" ),
                FfnLnBeta = Get( tensors, $"{l}final_layer_norm.bias" ),
                Fc1Weight = Get( tensors, $"{l}fc1.weight" ),
                Fc1Bias = Get( tensors, $"{l}fc1.bias" ),
                Fc2Weight = Get( tensors, $"{l}fc2.weight" ),
                Fc2Bias = Get( tensors, $"{l}fc2.bias" ),
            } );
        }
        if ( layers.Count == 0 )
            throw new FormatException( "MagicArticulate checkpoint: no OPT layers found." );

        var tokenEmbedding = Get( tensors, $"{d}embed_tokens.weight" );
        var positionEmbedding = Get( tensors, $"{d}embed_positions.weight" );

        // The HF OPT config ties lm_head to embed_tokens by default (tie_word_embeddings).
        // Exports that deduplicate shared tensors (safetensors / save_pretrained) keep only
        // embed_tokens, so use the tied weight when lm_head is absent instead of rejecting
        // the checkpoint.
        var lmHead = tensors.TryGetValue( $"{prefix}lm_head.weight", out var head )
            ? head
            : tokenEmbedding;

        return new OptDecoder
        {
            TokenEmbedding = tokenEmbedding,
            PositionEmbedding = positionEmbedding,
            Layers = layers.ToArray(),
            LmHeadWeight = lmHead,
            Heads = OptHeads,
            PostLn = true,   // real opt-350m: do_layer_norm_before=False, no final norm
        };
    }

    /// <summary>
    /// Loads one ShapeVAE residual attention block rooted at <paramref name="b"/>
    /// (attn.c_qkv / attn.c_proj, ln_1 / ln_2, mlp.c_fc / mlp.c_proj).
    /// No qkv bias is read.
    /// </summary>
    static PerceiverSelfBlock SelfBlock( IReadOnlyDictionary<string, Tensor> tensors, string b )
        => new()
        {
            QkvWeight = Get( tensors, $"{b}attn.c_qkv.weight" ),
            ProjWeight = Get( tensors, $"{b}attn.c_proj.weight" ),
            ProjBias = Get( tensors, $"{b}attn.c_proj.bias" ),
            Ln1Gamma = Get( tensors, $"{b}ln_1.weight" ),
            Ln1Beta = Get( tensors, $"{b}ln_1.bias" ),
            Ln2Gamma = Get( tensors, $"{b}ln_2.weight" ),
            Ln2Beta = Get( tensors, $"{b}ln_2.bias" ),
            FcWeight = Get( tensors, $"{b}mlp.c_fc.weight" ),
            FcBias = Get( tensors, $"{b}mlp.c_fc.bias" ),
            Fc2Weight = Get( tensors, $"{b}mlp.c_proj.weight" ),
            Fc2Bias = Get( tensors, $"{b}mlp.c_proj.bias" ),
        };

    /// <summary>Looks up a required tensor.</summary>
    /// <exception cref="FormatException"><paramref name="key"/> is not present.</exception>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"MagicArticulate checkpoint is missing tensor '{key}'." );
}