Loader for the UniRig skeleton-stage model checkpoint. It maps tensors from a checkpoint dictionary onto the C# model types (the Perceiver mesh encoder, the OPT decoder layers, and the latent output projection) and throws FormatException when an expected tensor is missing.
namespace AutoRig.Dl.UniRig;
using AutoRig.Dl;
/// <summary>
/// The UniRig skeleton-stage networks, bound to checkpoint tensors and ready to run.
/// Every member is <c>required</c>, so an instance cannot be created half-initialised.
/// </summary>
public sealed class UniRigSkeletonModel
{
    /// <summary>Perceiver encoder bound from the checkpoint's <c>mesh_encoder</c> weights.</summary>
    public required PerceiverEncoder MeshEncoder;

    /// <summary>Weight of the latent output projection (512 → 1024).</summary>
    public required Tensor OutputProjWeight;

    /// <summary>Bias of the latent output projection.</summary>
    public required Tensor OutputProjBias;

    /// <summary>OPT decoder bound from the checkpoint's <c>transformer</c> weights.</summary>
    public required OptDecoder Decoder;
}
/// <summary>
/// Binds the UniRig skeleton checkpoint (Lightning zip-torch, keys verified in
/// dev/unirig-reference/skeleton_keys.json) onto the C# modules: the OPT decoder
/// (24 layers in the reference checkpoint) + Michelangelo perceiver + the latent output projection.
/// </summary>
/// <remarks>
/// Every lookup is strict. A missing tensor, an empty layer stack, or a layer stack whose
/// numbering has a gap or an incomplete layer raises <see cref="FormatException"/> instead of
/// yielding a partially initialised model.
/// </remarks>
public static class UniRigLoader
{
    /// <summary>Loads the skeleton-stage networks from a flat name → tensor map.</summary>
    /// <param name="tensors">
    /// Checkpoint tensors keyed by parameter name, either bare (<c>model.…</c>) or nested
    /// under a <c>state_dict.</c> prefix.
    /// </param>
    /// <returns>The bound mesh encoder, output projection and OPT decoder.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">An expected tensor is missing or a layer stack is malformed.</exception>
    public static UniRigSkeletonModel LoadSkeleton( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        // If any key carries the "state_dict." wrapper, assume all model keys do.
        var prefix = tensors.Keys.Any( k => k.StartsWith( "state_dict.", StringComparison.Ordinal ) )
            ? "state_dict."
            : "";
        return new UniRigSkeletonModel
        {
            MeshEncoder = LoadPerceiver( tensors, $"{prefix}model.mesh_encoder.encoder." ),
            OutputProjWeight = Get( tensors, $"{prefix}model.output_proj.weight" ),
            OutputProjBias = Get( tensors, $"{prefix}model.output_proj.bias" ),
            Decoder = LoadOpt( tensors, $"{prefix}model.transformer." ),
        };
    }

    /// <summary>Binds the OPT decoder (layers, embeddings, final norm, LM head) rooted at <paramref name="prefix"/>.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by parameter name.</param>
    /// <param name="prefix">Key prefix of the OPT causal-LM module (ends with a dot).</param>
    /// <param name="heads">
    /// Attention head count. The loader does not read it from any tensor, so it is supplied
    /// here; 16 matches the UniRig decoder (OPT-350m family).
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="heads"/> is not positive.</exception>
    /// <exception cref="FormatException">No layers, a gapped/incomplete layer stack, or a missing tensor.</exception>
    static OptDecoder LoadOpt( IReadOnlyDictionary<string, Tensor> tensors, string prefix, int heads = 16 )
    {
        if ( heads <= 0 )
            throw new ArgumentOutOfRangeException( nameof( heads ), heads, "Attention head count must be positive." );

        var decoder = $"{prefix}model.decoder.";
        var layerRoot = $"{decoder}layers.";
        var layers = new List<OptLayer>();
        // Layers are numbered from 0; probe each index by its q_proj weight until one is absent.
        while ( tensors.ContainsKey( $"{layerRoot}{layers.Count}.self_attn.q_proj.weight" ) )
        {
            var l = $"{layerRoot}{layers.Count}.";
            layers.Add( new OptLayer
            {
                QWeight = Get( tensors, $"{l}self_attn.q_proj.weight" ),
                QBias = Get( tensors, $"{l}self_attn.q_proj.bias" ),
                KWeight = Get( tensors, $"{l}self_attn.k_proj.weight" ),
                KBias = Get( tensors, $"{l}self_attn.k_proj.bias" ),
                VWeight = Get( tensors, $"{l}self_attn.v_proj.weight" ),
                VBias = Get( tensors, $"{l}self_attn.v_proj.bias" ),
                OutWeight = Get( tensors, $"{l}self_attn.out_proj.weight" ),
                OutBias = Get( tensors, $"{l}self_attn.out_proj.bias" ),
                AttnLnGamma = Get( tensors, $"{l}self_attn_layer_norm.weight" ),
                AttnLnBeta = Get( tensors, $"{l}self_attn_layer_norm.bias" ),
                FfnLnGamma = Get( tensors, $"{l}final_layer_norm.weight" ),
                FfnLnBeta = Get( tensors, $"{l}final_layer_norm.bias" ),
                Fc1Weight = Get( tensors, $"{l}fc1.weight" ),
                Fc1Bias = Get( tensors, $"{l}fc1.bias" ),
                Fc2Weight = Get( tensors, $"{l}fc2.weight" ),
                Fc2Bias = Get( tensors, $"{l}fc2.bias" ),
            } );
        }
        if ( layers.Count == 0 )
            throw new FormatException( "UniRig checkpoint: no OPT decoder layers found." );
        // The probe loop stops at the first missing index; make sure nothing lies beyond it.
        EnsureNoLayersPast( tensors, layerRoot, layers.Count, "OPT decoder" );

        return new OptDecoder
        {
            TokenEmbedding = Get( tensors, $"{decoder}embed_tokens.weight" ),
            PositionEmbedding = Get( tensors, $"{decoder}embed_positions.weight" ),
            Layers = layers.ToArray(),
            FinalLnGamma = Get( tensors, $"{decoder}final_layer_norm.weight" ),
            FinalLnBeta = Get( tensors, $"{decoder}final_layer_norm.bias" ),
            // The LM head sits beside "model.", not under "model.decoder.".
            LmHeadWeight = Get( tensors, $"{prefix}lm_head.weight" ),
            Heads = heads,
        };
    }

    /// <summary>Binds the Michelangelo perceiver (input projection, one cross block, self blocks, post norm).</summary>
    /// <param name="tensors">Checkpoint tensors keyed by parameter name.</param>
    /// <param name="prefix">Key prefix of the perceiver encoder (ends with a dot).</param>
    /// <param name="heads">Attention head count stored on the encoder; defaults to 8.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="heads"/> is not positive.</exception>
    /// <exception cref="FormatException">No self blocks, a gapped/incomplete block stack, or a missing tensor.</exception>
    internal static PerceiverEncoder LoadPerceiver(
        IReadOnlyDictionary<string, Tensor> tensors, string prefix, int heads = 8 )
    {
        ArgumentNullException.ThrowIfNull( tensors );
        if ( heads <= 0 )
            throw new ArgumentOutOfRangeException( nameof( heads ), heads, "Attention head count must be positive." );

        PerceiverSelfBlock SelfBlock( string b ) => new()
        {
            QkvWeight = Get( tensors, $"{b}attn.c_qkv.weight" ),
            ProjWeight = Get( tensors, $"{b}attn.c_proj.weight" ),
            ProjBias = Get( tensors, $"{b}attn.c_proj.bias" ),
            Ln1Gamma = Get( tensors, $"{b}ln_1.weight" ),
            Ln1Beta = Get( tensors, $"{b}ln_1.bias" ),
            Ln2Gamma = Get( tensors, $"{b}ln_2.weight" ),
            Ln2Beta = Get( tensors, $"{b}ln_2.bias" ),
            FcWeight = Get( tensors, $"{b}mlp.c_fc.weight" ),
            FcBias = Get( tensors, $"{b}mlp.c_fc.bias" ),
            Fc2Weight = Get( tensors, $"{b}mlp.c_proj.weight" ),
            Fc2Bias = Get( tensors, $"{b}mlp.c_proj.bias" ),
        };

        var blockRoot = $"{prefix}self_attn.resblocks.";
        var blocks = new List<PerceiverSelfBlock>();
        // Blocks are numbered from 0; probe each index by its fused QKV weight.
        while ( tensors.ContainsKey( $"{blockRoot}{blocks.Count}.attn.c_qkv.weight" ) )
            blocks.Add( SelfBlock( $"{blockRoot}{blocks.Count}." ) );
        if ( blocks.Count == 0 )
            throw new FormatException( "UniRig checkpoint: no perceiver self-attention blocks found." );
        EnsureNoLayersPast( tensors, blockRoot, blocks.Count, "perceiver self-attention" );

        var cross = $"{prefix}cross_attn.";
        return new PerceiverEncoder
        {
            Heads = heads,
            InputProjWeight = Get( tensors, $"{prefix}input_proj.weight" ),
            InputProjBias = Get( tensors, $"{prefix}input_proj.bias" ),
            Cross = new PerceiverCrossBlock
            {
                QWeight = Get( tensors, $"{cross}attn.c_q.weight" ),
                KvWeight = Get( tensors, $"{cross}attn.c_kv.weight" ),
                ProjWeight = Get( tensors, $"{cross}attn.c_proj.weight" ),
                ProjBias = Get( tensors, $"{cross}attn.c_proj.bias" ),
                Ln1Gamma = Get( tensors, $"{cross}ln_1.weight" ),
                Ln1Beta = Get( tensors, $"{cross}ln_1.bias" ),
                Ln2Gamma = Get( tensors, $"{cross}ln_2.weight" ),
                Ln2Beta = Get( tensors, $"{cross}ln_2.bias" ),
                Ln3Gamma = Get( tensors, $"{cross}ln_3.weight" ),
                Ln3Beta = Get( tensors, $"{cross}ln_3.bias" ),
                FcWeight = Get( tensors, $"{cross}mlp.c_fc.weight" ),
                FcBias = Get( tensors, $"{cross}mlp.c_fc.bias" ),
                Fc2Weight = Get( tensors, $"{cross}mlp.c_proj.weight" ),
                Fc2Bias = Get( tensors, $"{cross}mlp.c_proj.bias" ),
            },
            SelfBlocks = blocks.ToArray(),
            LnPostGamma = Get( tensors, $"{prefix}ln_post.weight" ),
            LnPostBeta = Get( tensors, $"{prefix}ln_post.bias" ),
        };
    }

    /// <summary>
    /// Guards the probe loops above. It throws when any key under <paramref name="layerRoot"/>
    /// names a numeric layer index at or beyond <paramref name="count"/>. That means the loop
    /// stopped at a missing or incomplete layer while tensors for that or later layers are still
    /// present. Without this check the model would load silently truncated.
    /// </summary>
    /// <param name="tensors">Checkpoint tensors keyed by parameter name.</param>
    /// <param name="layerRoot">Prefix up to and including the dot before the layer index.</param>
    /// <param name="count">Number of contiguous layers (0..count-1) that were bound.</param>
    /// <param name="what">Human-readable stack name used in the error message.</param>
    /// <exception cref="FormatException">A stray layer key was found.</exception>
    static void EnsureNoLayersPast(
        IReadOnlyDictionary<string, Tensor> tensors, string layerRoot, int count, string what )
    {
        foreach ( var key in tensors.Keys )
        {
            if ( !key.StartsWith( layerRoot, StringComparison.Ordinal ) )
                continue;
            var rest = key.AsSpan( layerRoot.Length );
            var dot = rest.IndexOf( '.' );
            if ( dot <= 0 )
                continue;
            if ( int.TryParse( rest[..dot], System.Globalization.NumberStyles.None,
                    System.Globalization.CultureInfo.InvariantCulture, out var index )
                && index >= count )
            {
                throw new FormatException(
                    $"UniRig checkpoint: {what} stack stops after {count} layer(s), but tensor '{key}' belongs to a missing-probe or later layer (gap or incomplete layer)." );
            }
        }
    }

    /// <summary>Looks up a required tensor.</summary>
    /// <exception cref="FormatException"><paramref name="key"/> is not present.</exception>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"UniRig checkpoint is missing tensor '{key}'." );
}