// Loader for the SkinTokens model: binds tensors from a checkpoint into typed C# model objects. It constructs a SkinTokensModel by loading a Perceiver encoder, Qwen3 decoder layers, projection weights and a VAE from a dictionary of named Tensor objects.
namespace AutoRig.Dl.UniRig;
using AutoRig.Dl;
/// <summary>
/// Container for the SkinTokens (TokenRig) networks once their weights have
/// been bound from the grpo checkpoint.
/// </summary>
public sealed class SkinTokensModel
{
    /// <summary>Mesh encoder (8-block Michelangelo perceiver variant).</summary>
    public required PerceiverEncoder MeshEncoder;

    /// <summary>Weight of the latent projection (Linear 512 → 896).</summary>
    public required Tensor OutputProjWeight;

    /// <summary>Bias of the latent projection (Linear 512 → 896).</summary>
    public required Tensor OutputProjBias;

    /// <summary>Gain of the normalisation that follows the projection (RMSNorm(896)).</summary>
    public required Tensor OutputNormGamma;

    /// <summary>Autoregressive Qwen3 decoder.</summary>
    public required Qwen3Decoder Decoder;

    /// <summary>FSQ skin decoder (checkpoint keys under vae.model.*).</summary>
    public required SkinVae Vae;

    // Vocabulary layout: ids [0, SkeletonVocab) are the skeleton grammar, whose
    // EOS (= 258) doubles as the skeleton→skin PHASE SWITCH; the FSQ skin
    // codebook follows; the UNIFIED EOS is the very last id (verified:
    // switch_token_id = tokenizer.eos, self.eos = vocab_size − 1 in tokenrig.py).

    /// <summary>Number of ids reserved for the skeleton grammar.</summary>
    public int SkeletonVocab
    {
        get { return UniRigTokenizer.VocabSize; }
    }

    /// <summary>Skeleton EOS id, which also acts as the skeleton→skin phase switch.</summary>
    public int SwitchToken
    {
        get { return UniRigTokenizer.TokenEos; }
    }

    /// <summary>Unified EOS id: the last row index of the decoder's token embedding.</summary>
    public int UnifiedEos
    {
        get { return Decoder.TokenEmbedding.Shape[0] - 1; }
    }

    /// <summary>Tokens emitted per skin; defaults to 4 (the "token_4" experiment config).</summary>
    public int TokensPerSkin { get; init; } = 4;
}
/// <summary>
/// Binds the SkinTokens grpo checkpoint (bf16 Lightning zip; keys verified in
/// dev/skintokens-reference/grpo_keys.json) onto the C# modules: the 28-layer
/// Qwen3 decoder (layer count is discovered from the keys present), the
/// 8-block perceiver, the RMS-normed latent projection, and the skin VAE,
/// which is bound through <see cref="SkinVaeLoader"/>.
/// </summary>
public static class SkinTokensLoader
{
    /// <summary>
    /// Builds a <see cref="SkinTokensModel"/> from a flat map of checkpoint tensors.
    /// </summary>
    /// <param name="tensors">
    /// Checkpoint tensors keyed by parameter name. Keys may optionally all be
    /// nested under a <c>state_dict.</c> prefix; this is detected automatically.
    /// </param>
    /// <returns>The model with every network bound to its tensors.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">
    /// A tensor required by this loader is missing, no decoder layers are
    /// present, or the attention projection sizes are inconsistent with the
    /// head dimension.
    /// </exception>
    public static SkinTokensModel LoadModel( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        // If any key carries the "state_dict." wrapper, assume all model keys do.
        var prefix = tensors.Keys.Any( k => k.StartsWith( "state_dict.", StringComparison.Ordinal ) )
            ? "state_dict."
            : "";

        return new SkinTokensModel
        {
            MeshEncoder = UniRigLoader.LoadPerceiver( tensors, $"{prefix}mesh_encoder.encoder." ),
            OutputProjWeight = Get( tensors, $"{prefix}output_proj.0.weight" ),
            OutputProjBias = Get( tensors, $"{prefix}output_proj.0.bias" ),
            OutputNormGamma = Get( tensors, $"{prefix}output_proj.1.weight" ),
            Decoder = LoadQwen3( tensors, $"{prefix}transformer." ),
            // NOTE(review): the VAE loader receives the raw map without the detected
            // prefix — presumably it resolves its own key prefix; confirm against
            // SkinVaeLoader.Load.
            Vae = SkinVaeLoader.Load( tensors ),
        };
    }

    /// <summary>
    /// Binds the Qwen3 decoder found under <paramref name="prefix"/>. Layers are
    /// read from consecutive indices starting at 0 until the first index whose
    /// <c>self_attn.q_proj.weight</c> is absent.
    /// </summary>
    /// <param name="tensors">Checkpoint tensors keyed by parameter name.</param>
    /// <param name="prefix">
    /// Key prefix of the transformer; layer, embedding and final-norm keys live
    /// under <c>{prefix}model.</c>, the LM head directly under <c>{prefix}</c>.
    /// </param>
    /// <returns>The decoder with head counts derived from the weight shapes.</returns>
    /// <exception cref="FormatException">
    /// No layer is present, a required tensor is missing, or the q/k projection
    /// row counts are not positive multiples of the head dimension.
    /// </exception>
    static Qwen3Decoder LoadQwen3( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        var model = $"{prefix}model.";
        var layers = new List<Qwen3Layer>();

        while ( tensors.ContainsKey( $"{model}layers.{layers.Count}.self_attn.q_proj.weight" ) )
        {
            var l = $"{model}layers.{layers.Count}.";
            layers.Add( new Qwen3Layer
            {
                QWeight = Get( tensors, $"{l}self_attn.q_proj.weight" ),
                KWeight = Get( tensors, $"{l}self_attn.k_proj.weight" ),
                VWeight = Get( tensors, $"{l}self_attn.v_proj.weight" ),
                OWeight = Get( tensors, $"{l}self_attn.o_proj.weight" ),
                QNormGamma = Get( tensors, $"{l}self_attn.q_norm.weight" ),
                KNormGamma = Get( tensors, $"{l}self_attn.k_norm.weight" ),
                InputLnGamma = Get( tensors, $"{l}input_layernorm.weight" ),
                PostAttnLnGamma = Get( tensors, $"{l}post_attention_layernorm.weight" ),
                GateWeight = Get( tensors, $"{l}mlp.gate_proj.weight" ),
                UpWeight = Get( tensors, $"{l}mlp.up_proj.weight" ),
                DownWeight = Get( tensors, $"{l}mlp.down_proj.weight" ),
            } );
        }

        if ( layers.Count == 0 )
        {
            throw new FormatException( "SkinTokens checkpoint: no Qwen3 layers found." );
        }

        // Dims from the weights themselves: head_dim = q_norm length, and the
        // head counts are the q/k projection row counts divided by head_dim.
        var first = layers[0];
        var headDim = first.QNormGamma.Shape[0];
        var qRows = first.QWeight.Shape[0];
        var kRows = first.KWeight.Shape[0];

        // Reject inconsistent shapes up front: a zero head_dim would otherwise
        // surface as DivideByZeroException, and a non-multiple row count would
        // be silently truncated into a wrong head count.
        if ( headDim <= 0 || qRows <= 0 || kRows <= 0 || qRows % headDim != 0 || kRows % headDim != 0 )
        {
            throw new FormatException(
                $"SkinTokens checkpoint: q_proj/k_proj row counts ({qRows}/{kRows}) are not positive multiples of head_dim ({headDim})." );
        }

        return new Qwen3Decoder
        {
            TokenEmbedding = Get( tensors, $"{model}embed_tokens.weight" ),
            Layers = layers.ToArray(),
            FinalNormGamma = Get( tensors, $"{model}norm.weight" ),
            LmHeadWeight = Get( tensors, $"{prefix}lm_head.weight" ),
            QHeads = qRows / headDim,
            KvHeads = kRows / headDim,
        };
    }

    /// <summary>Looks up a tensor that the checkpoint is required to contain.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by parameter name.</param>
    /// <param name="key">Full key of the tensor to fetch.</param>
    /// <returns>The tensor stored under <paramref name="key"/>.</returns>
    /// <exception cref="FormatException">The key is not present in the map.</exception>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"SkinTokens checkpoint is missing tensor '{key}'." );
}