Loader for the RigAnything model checkpoint. It reads named tensors from a provided dictionary and assembles a RigAnythingModel, parsing transformer blocks and diffusion residual blocks and mapping weights/biases into model structures.
namespace AutoRig.Dl.RigAnything;
using AutoRig.Dl;
/// <summary>
/// Assembles a <see cref="RigAnythingModel"/> from the named tensors of
/// riganything_ckpt.pt (key names verified in
/// dev/riganything-reference/ckpt_keys.json: 123 f32 tensors under model.*).
/// Every transformer LayerNorm is γ-only (bias=False), so the matching beta
/// slots are left null.
/// </summary>
public static class RigAnythingLoader
{
    /// <summary>Maps checkpoint tensors onto a new <see cref="RigAnythingModel"/>.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by name, with or without a leading "model." prefix.</param>
    /// <returns>The model populated with the tensors found under the expected keys.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">
    /// No transformer block or no diffusion res block is present, or a required tensor key is missing.
    /// </exception>
    public static RigAnythingModel LoadModel( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        var prefix = DetectPrefix( tensors );

        // Lookup order below is deliberate: it decides which missing key gets reported first.
        var transformerBlocks = ReadTransformerBlocks( tensors, prefix );
        var diffusionBlocks = ReadDiffusionResBlocks( tensors, prefix );
        var startToken = At( "start_token" );

        return new RigAnythingModel
        {
            PcTok0 = At( "pc_tokenizer.0.weight" ),
            PcTok2 = At( "pc_tokenizer.2.weight" ),
            JointTok0 = At( "joint_tokenizer.0.weight" ),
            JointTok2 = At( "joint_tokenizer.2.weight" ),
            JointMlp0 = At( "joint_mlp.0.weight" ),
            JointMlp2 = At( "joint_mlp.2.weight" ),
            FuseMlp0 = At( "joint_fuse_mlp.0.weight" ),
            FuseMlp2 = At( "joint_fuse_mlp.2.weight" ),
            SkinMlp0 = At( "skinning_mlp.0.weight" ),
            SkinMlp2 = At( "skinning_mlp.2.weight" ),
            ParentLnGamma = At( "parents_decoder.0.weight" ),
            ParentLnBeta = null,
            Parent1 = At( "parents_decoder.1.weight" ),
            Parent3 = At( "parents_decoder.3.weight" ),
            // The start token's data is re-wrapped with dimensions (1, RigAnythingModel.D).
            StartToken = Tensor.From( startToken.Data, 1, RigAnythingModel.D ),
            InputLnGamma = At( "transformer_input_layernorm.weight" ),
            InputLnBeta = null,
            Blocks = transformerBlocks,
            Diffusion = new RigAnythingDiffusion
            {
                InputProjWeight = At( "diffloss.net.input_proj.weight" ),
                InputProjBias = At( "diffloss.net.input_proj.bias" ),
                Time0Weight = At( "diffloss.net.time_embed.mlp.0.weight" ),
                Time0Bias = At( "diffloss.net.time_embed.mlp.0.bias" ),
                Time2Weight = At( "diffloss.net.time_embed.mlp.2.weight" ),
                Time2Bias = At( "diffloss.net.time_embed.mlp.2.bias" ),
                CondWeight = At( "diffloss.net.cond_embed.weight" ),
                CondBias = At( "diffloss.net.cond_embed.bias" ),
                ResBlocks = diffusionBlocks,
                FinalAdaWeight = At( "diffloss.net.final_layer.adaLN_modulation.1.weight" ),
                FinalAdaBias = At( "diffloss.net.final_layer.adaLN_modulation.1.bias" ),
                FinalLinWeight = At( "diffloss.net.final_layer.linear.weight" ),
                FinalLinBias = At( "diffloss.net.final_layer.linear.bias" ),
            },
        };

        // Resolves a key relative to the detected prefix.
        Tensor At( string name ) => Get( tensors, prefix + name );
    }

    /// <summary>Returns "model." when at least one key carries that prefix, otherwise an empty string.</summary>
    private static string DetectPrefix( IReadOnlyDictionary<string, Tensor> tensors )
    {
        foreach ( var key in tensors.Keys )
        {
            if ( key.StartsWith( "model.", StringComparison.Ordinal ) )
            {
                return "model.";
            }
        }

        return "";
    }

    /// <summary>
    /// Collects consecutive transformer blocks starting at index 0; the scan stops at the
    /// first index whose attn.to_qkv.weight key is absent.
    /// </summary>
    private static RigAnythingBlock[] ReadTransformerBlocks( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        var found = new List<RigAnythingBlock>();
        for ( var index = 0; tensors.ContainsKey( $"{prefix}transformer.{index}.attn.to_qkv.weight" ); index++ )
        {
            var scope = $"{prefix}transformer.{index}.";
            var block = new RigAnythingBlock
            {
                Norm1Gamma = Get( tensors, scope + "norm1.weight" ),
                Norm1Beta = null,
                QkvWeight = Get( tensors, scope + "attn.to_qkv.weight" ),
                FcWeight = Get( tensors, scope + "attn.fc.weight" ),
                Norm2Gamma = Get( tensors, scope + "norm2.weight" ),
                Norm2Beta = null,
                Mlp0Weight = Get( tensors, scope + "mlp.mlp.0.weight" ),
                Mlp2Weight = Get( tensors, scope + "mlp.mlp.2.weight" ),
            };
            found.Add( block );
        }

        if ( found.Count == 0 )
        {
            throw new FormatException( "RigAnything checkpoint: no transformer blocks found." );
        }

        return found.ToArray();
    }

    /// <summary>
    /// Collects consecutive diffusion res blocks starting at index 0; the scan stops at the
    /// first index whose in_ln.weight key is absent.
    /// </summary>
    private static RigAnythingDiffusion.DiffResBlock[] ReadDiffusionResBlocks( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        var found = new List<RigAnythingDiffusion.DiffResBlock>();
        for ( var index = 0; tensors.ContainsKey( $"{prefix}diffloss.net.res_blocks.{index}.in_ln.weight" ); index++ )
        {
            var scope = $"{prefix}diffloss.net.res_blocks.{index}.";
            var block = new RigAnythingDiffusion.DiffResBlock
            {
                InLnGamma = Get( tensors, scope + "in_ln.weight" ),
                InLnBeta = Get( tensors, scope + "in_ln.bias" ),
                Mlp0Weight = Get( tensors, scope + "mlp.0.weight" ),
                Mlp0Bias = Get( tensors, scope + "mlp.0.bias" ),
                Mlp2Weight = Get( tensors, scope + "mlp.2.weight" ),
                Mlp2Bias = Get( tensors, scope + "mlp.2.bias" ),
                AdaWeight = Get( tensors, scope + "adaLN_modulation.1.weight" ),
                AdaBias = Get( tensors, scope + "adaLN_modulation.1.bias" ),
            };
            found.Add( block );
        }

        if ( found.Count == 0 )
        {
            throw new FormatException( "RigAnything checkpoint: no diffusion res blocks found." );
        }

        return found.ToArray();
    }

    /// <summary>Fetches a required tensor by exact key.</summary>
    /// <exception cref="FormatException">The key is not present in <paramref name="tensors"/>.</exception>
    private static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
    {
        if ( !tensors.TryGetValue( key, out var tensor ) )
        {
            throw new FormatException( $"RigAnything checkpoint is missing tensor '{key}'." );
        }

        return tensor;
    }
}