Code/AutoRig/Dl/UniRig/SkinTokensLoader.cs

Loader for the SkinTokens model. It maps tensors from a checkpoint dictionary into C# model objects: a Perceiver mesh encoder, a Qwen3 decoder (built from repeated layer tensors), output projection tensors, and an embedded VAE loader.

File Access, Networking
namespace AutoRig.Dl.UniRig;

using AutoRig.Dl;

/// <summary>
/// Holder for the SkinTokens (TokenRig) networks after their weights have been
/// bound from the grpo checkpoint.
/// </summary>
public sealed class SkinTokensModel
{
    /// <summary>Perceiver mesh encoder (8-block Michelangelo variant).</summary>
    public required PerceiverEncoder MeshEncoder;

    /// <summary>Weight of the output projection (Linear 512 → 896).</summary>
    public required Tensor OutputProjWeight;

    /// <summary>Bias of the output projection (Linear 512 → 896).</summary>
    public required Tensor OutputProjBias;

    /// <summary>Gamma of the output RMSNorm(896).</summary>
    public required Tensor OutputNormGamma;

    /// <summary>Qwen3 decoder that produces the token stream.</summary>
    public required Qwen3Decoder Decoder;

    /// <summary>FSQ skin decoder (checkpoint keys under vae.model.*).</summary>
    public required SkinVae Vae;

    /// <summary>
    /// Size of the skeleton grammar: ids in [0, SkeletonVocab) are skeleton tokens.
    /// The FSQ skin codebook follows, and the unified EOS is the last id of the
    /// whole vocabulary.
    /// </summary>
    public int SkeletonVocab
    {
        get { return UniRigTokenizer.VocabSize; }
    }

    /// <summary>
    /// The skeleton grammar's EOS (= 258), which doubles as the skeleton→skin
    /// phase switch (verified: switch_token_id = tokenizer.eos in tokenrig.py).
    /// </summary>
    public int SwitchToken
    {
        get { return UniRigTokenizer.TokenEos; }
    }

    /// <summary>
    /// Unified EOS: first dimension of the decoder's token embedding minus one,
    /// i.e. the last vocabulary id (verified: self.eos = vocab_size − 1 in tokenrig.py).
    /// </summary>
    public int UnifiedEos
    {
        get { return Decoder.TokenEmbedding.Shape[0] - 1; }
    }

    /// <summary>Tokens emitted per skin; 4 matches the "token_4" experiment config.</summary>
    public int TokensPerSkin { get; init; } = 4;
}

/// <summary>
/// Binds the SkinTokens grpo checkpoint (bf16 Lightning zip; keys verified in
/// dev/skintokens-reference/grpo_keys.json) onto the C# modules: the Qwen3
/// decoder (28 layers in the reference checkpoint; the count is discovered from
/// the keys), the 8-block perceiver, the RMS-normed latent projection, and the
/// in-checkpoint skin FSQ-VAE (delegated to <see cref="SkinVaeLoader"/>).
/// </summary>
public static class SkinTokensLoader
{
    /// <summary>
    /// Builds a <see cref="SkinTokensModel"/> from a flat name → tensor dictionary.
    /// Keys may optionally be namespaced under "state_dict."; the prefix is
    /// detected from the keys themselves.
    /// </summary>
    /// <param name="tensors">Checkpoint tensors keyed by their checkpoint names.</param>
    /// <returns>The model with every sub-network bound to its tensors.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">
    /// A required tensor is missing, no decoder layer is present, or the
    /// attention projection shapes are inconsistent with the head dimension.
    /// </exception>
    public static SkinTokensModel LoadModel( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        // If any key carries the "state_dict." namespace, all lookups below use it.
        var prefix = tensors.Keys.Any( k => k.StartsWith( "state_dict.", StringComparison.Ordinal ) )
            ? "state_dict."
            : "";

        return new SkinTokensModel
        {
            MeshEncoder = UniRigLoader.LoadPerceiver( tensors, $"{prefix}mesh_encoder.encoder." ),
            OutputProjWeight = Get( tensors, $"{prefix}output_proj.0.weight" ),
            OutputProjBias = Get( tensors, $"{prefix}output_proj.0.bias" ),
            OutputNormGamma = Get( tensors, $"{prefix}output_proj.1.weight" ),
            Decoder = LoadQwen3( tensors, $"{prefix}transformer." ),
            // NOTE(review): the VAE loader gets the raw dictionary without the detected
            // prefix — presumably it resolves "state_dict." on its own; confirm.
            Vae = SkinVaeLoader.Load( tensors ),
        };
    }

    /// <summary>
    /// Binds the Qwen3 decoder found under <paramref name="prefix"/>. Layers are
    /// discovered by probing consecutive indices until a layer's q_proj weight
    /// is absent; head counts are derived from the first layer's tensor shapes.
    /// </summary>
    /// <exception cref="FormatException">
    /// No layer is found, a required tensor is missing, or the q/k projection
    /// row counts are not positive multiples of the head dimension.
    /// </exception>
    static Qwen3Decoder LoadQwen3( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        var model = $"{prefix}model.";
        var layers = new List<Qwen3Layer>();

        // The first missing q_proj ends the scan; any other tensor missing from a
        // present layer is reported by Get.
        while ( tensors.ContainsKey( $"{model}layers.{layers.Count}.self_attn.q_proj.weight" ) )
        {
            var l = $"{model}layers.{layers.Count}.";
            layers.Add( new Qwen3Layer
            {
                QWeight = Get( tensors, $"{l}self_attn.q_proj.weight" ),
                KWeight = Get( tensors, $"{l}self_attn.k_proj.weight" ),
                VWeight = Get( tensors, $"{l}self_attn.v_proj.weight" ),
                OWeight = Get( tensors, $"{l}self_attn.o_proj.weight" ),
                QNormGamma = Get( tensors, $"{l}self_attn.q_norm.weight" ),
                KNormGamma = Get( tensors, $"{l}self_attn.k_norm.weight" ),
                InputLnGamma = Get( tensors, $"{l}input_layernorm.weight" ),
                PostAttnLnGamma = Get( tensors, $"{l}post_attention_layernorm.weight" ),
                GateWeight = Get( tensors, $"{l}mlp.gate_proj.weight" ),
                UpWeight = Get( tensors, $"{l}mlp.up_proj.weight" ),
                DownWeight = Get( tensors, $"{l}mlp.down_proj.weight" ),
            } );
        }
        if ( layers.Count == 0 )
            throw new FormatException( "SkinTokens checkpoint: no Qwen3 layers found." );

        // Dims from the weights themselves: head_dim = q_norm length.
        var first = layers[0];
        var headDim = first.QNormGamma.Shape[0];
        var qRows = first.QWeight.Shape[0];
        var kRows = first.KWeight.Shape[0];

        // Reject shapes that would divide by zero or silently truncate the head
        // counts, so a malformed checkpoint fails here rather than at inference.
        if ( headDim <= 0 )
            throw new FormatException( $"SkinTokens checkpoint: invalid Qwen3 head_dim {headDim} (from q_norm)." );
        if ( qRows <= 0 || kRows <= 0 || qRows % headDim != 0 || kRows % headDim != 0 )
            throw new FormatException(
                $"SkinTokens checkpoint: q_proj/k_proj rows ({qRows}/{kRows}) are not positive multiples of head_dim {headDim}." );

        return new Qwen3Decoder
        {
            TokenEmbedding = Get( tensors, $"{model}embed_tokens.weight" ),
            Layers = layers.ToArray(),
            FinalNormGamma = Get( tensors, $"{model}norm.weight" ),
            // lm_head sits beside "model." under the transformer prefix, not inside it.
            LmHeadWeight = Get( tensors, $"{prefix}lm_head.weight" ),
            QHeads = qRows / headDim,
            KvHeads = kRows / headDim,
        };
    }

    /// <summary>Looks up a required tensor by its exact checkpoint key.</summary>
    /// <exception cref="FormatException">The key is not present in <paramref name="tensors"/>.</exception>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"SkinTokens checkpoint is missing tensor '{key}'." );
}