Code/AutoRig/Dl/Puppeteer/PuppeteerSkinLoader.cs

Loader for a Puppeteer skin model: it reads named Tensor entries from a dictionary and constructs a PuppeteerSkinModel by mapping tensor keys to model substructures, loading the Perceiver blocks and the attention/feed-forward components.

File AccessNetworking
using AutoRig.Dl.UniRig;

namespace AutoRig.Dl.Puppeteer;

using AutoRig.Dl;

/// <summary>
/// Binds the whole skinning stack from puppeteer_skin.pth (596 tensors:
/// PartField + Michelangelo + the block, all in one file) into a
/// <see cref="PuppeteerSkinModel"/>.
/// </summary>
public static class PuppeteerSkinLoader
{
    /// <summary>Attention head count passed to the shape encoder's Perceiver.</summary>
    const int ShapeEncoderHeads = 12;

    /// <summary>Builds a <see cref="PuppeteerSkinModel"/> from a name → tensor map.</summary>
    /// <param name="tensors">
    /// Checkpoint tensors keyed by parameter name. If any key starts with
    /// <c>model.</c>, that prefix is assumed for every lookup; otherwise keys
    /// are used unprefixed.
    /// </param>
    /// <returns>The fully bound skinning model.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <remarks>
    /// Only the transformer block at <c>blocks.0</c> is bound; any further
    /// <c>blocks.N</c> entries are not read. Missing-key behaviour is whatever
    /// <c>PartFieldLoader.Get</c> does (presumably throws — confirm).
    /// </remarks>
    public static PuppeteerSkinModel Load( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        // Some checkpoints prefix every key with "model." (presumably saved from a
        // wrapper module); detect it once and apply it to every lookup below.
        var prefix = tensors.Keys.Any( k => k.StartsWith( "model.", StringComparison.Ordinal ) )
            ? "model."
            : "";
        var shape = $"{prefix}point_encoder.model.shape_model.";
        var block0 = $"{prefix}blocks.0";

        // Binds one attention sub-layer ("{block0}.{attn}") together with its paired
        // feed-forward sub-layer ("{block0}.{ff}"). The pairing is explicit so an
        // unknown attention name can never silently alias its own keys as FF keys.
        // Context norms are read only when `cross` is true; otherwise they are null.
        PuppeteerSkinModel.AttentionUnit Unit( string attn, string ff, bool cross )
        {
            var a = $"{block0}.{attn}";
            var f = $"{block0}.{ff}";
            return new()
            {
                NormG = Get( tensors, $"{a}.norm.weight" ),
                NormB = Get( tensors, $"{a}.norm.bias" ),
                ContextNormG = cross ? Get( tensors, $"{a}.norm_context.weight" ) : null,
                ContextNormB = cross ? Get( tensors, $"{a}.norm_context.bias" ) : null,
                ToQ = Get( tensors, $"{a}.fn.to_q.weight" ),
                ToKv = Get( tensors, $"{a}.fn.to_kv.weight" ),
                ToOutW = Get( tensors, $"{a}.fn.to_out.weight" ),
                ToOutB = Get( tensors, $"{a}.fn.to_out.bias" ),
                FfNormG = Get( tensors, $"{f}.norm.weight" ),
                FfNormB = Get( tensors, $"{f}.norm.bias" ),
                FfInW = Get( tensors, $"{f}.fn.net.0.weight" ),
                FfInB = Get( tensors, $"{f}.fn.net.0.bias" ),
                FfOutW = Get( tensors, $"{f}.fn.net.2.weight" ),
                FfOutB = Get( tensors, $"{f}.fn.net.2.bias" ),
            };
        }

        // Shape-decoder resblocks are numbered contiguously from 0; stop at the
        // first index whose ln_1.weight is absent.
        var decodeBlocks = new List<PerceiverSelfBlock>();
        for ( var i = 0; tensors.ContainsKey( $"{shape}transformer.resblocks.{i}.ln_1.weight" ); i++ )
        {
            decodeBlocks.Add( LoadSelfBlock( tensors, $"{shape}transformer.resblocks.{i}." ) );
        }

        return new PuppeteerSkinModel
        {
            PartField = PartFieldLoader.Load( tensors, prefix ),
            ShapeEncoder = UniRigLoader.LoadPerceiver( tensors, $"{shape}encoder.", heads: ShapeEncoderHeads ),
            ShapeQuery = Get( tensors, $"{shape}encoder.query" ),
            PreKlWeight = Get( tensors, $"{shape}pre_kl.weight" ),
            PreKlBias = Get( tensors, $"{shape}pre_kl.bias" ),
            PostKlWeight = Get( tensors, $"{shape}post_kl.weight" ),
            PostKlBias = Get( tensors, $"{shape}post_kl.bias" ),
            ShapeDecodeBlocks = decodeBlocks.ToArray(),
            Proj0W = Get( tensors, $"{prefix}proj.0.weight" ),
            Proj0B = Get( tensors, $"{prefix}proj.0.bias" ),
            Proj2W = Get( tensors, $"{prefix}proj.2.weight" ),
            Proj2B = Get( tensors, $"{prefix}proj.2.bias" ),
            PointPeBasis = Get( tensors, $"{prefix}point_embed_pe.basis" ),
            PointPeW = Get( tensors, $"{prefix}point_embed_pe.mlp.weight" ),
            PointPeB = Get( tensors, $"{prefix}point_embed_pe.mlp.bias" ),
            SkelBasis = Get( tensors, $"{prefix}skeleton_condition.basis" ),
            SkelW = Get( tensors, $"{prefix}skeleton_condition.mlp.weight" ),
            SkelB = Get( tensors, $"{prefix}skeleton_condition.mlp.bias" ),
            Scale = Get( tensors, $"{prefix}scale" ),
            SelfJ = Unit( "self_attn_j", "ff_j", cross: true ),   // context_dim set → norm_context exists
            CrossPs = Unit( "cross_attn_ps", "ff_ps", cross: true ),
            CrossJs = Unit( "cross_attn_js", "ff_js", cross: true ),
            CrossJp = Unit( "cross_attn_jp", "ff_jp", cross: true ),
            CrossPj = Unit( "cross_attn_pj", "ff_pj", cross: true ),
            RelPosEmbedding = Get( tensors, $"{block0}.rel_pos_embedding.weight" ),
            RelPosProjW = Get( tensors, $"{block0}.rel_pos_proj.weight" ),
            RelPosProjB = Get( tensors, $"{block0}.rel_pos_proj.bias" ),
            RelPosScale = Get( tensors, $"{block0}.rel_pos_scale" ),
        };
    }

    /// <summary>
    /// Binds one shape-decoder resblock whose keys all start with <paramref name="b"/>
    /// (which must already end in a trailing '.').
    /// </summary>
    static PerceiverSelfBlock LoadSelfBlock(
        IReadOnlyDictionary<string, Tensor> tensors, string b ) => new()
    {
        QkvWeight = Get( tensors, $"{b}attn.c_qkv.weight" ),
        ProjWeight = Get( tensors, $"{b}attn.c_proj.weight" ),
        ProjBias = Get( tensors, $"{b}attn.c_proj.bias" ),
        Ln1Gamma = Get( tensors, $"{b}ln_1.weight" ),
        Ln1Beta = Get( tensors, $"{b}ln_1.bias" ),
        Ln2Gamma = Get( tensors, $"{b}ln_2.weight" ),
        Ln2Beta = Get( tensors, $"{b}ln_2.bias" ),
        FcWeight = Get( tensors, $"{b}mlp.c_fc.weight" ),
        FcBias = Get( tensors, $"{b}mlp.c_fc.bias" ),
        Fc2Weight = Get( tensors, $"{b}mlp.c_proj.weight" ),
        Fc2Bias = Get( tensors, $"{b}mlp.c_proj.bias" ),
    };

    /// <summary>
    /// Tensor lookup shared with the PartField loader so missing-key handling
    /// is identical across both loaders.
    /// </summary>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => PartFieldLoader.Get( tensors, key );
}