Loader for Anymate model checkpoints. It reads tensors from provided dictionaries for joint, conn, and skin checkpoints, validates decoder layout, and constructs an AnymateModel with nested Bert, Perceiver cross/self blocks, and various weight/bias tensors.
using AutoRig.Dl.UniRig;
namespace AutoRig.Dl.Anymate;
using AutoRig.Dl;
/// <summary>
/// Binds the three Anymate checkpoints (joint, conn, skin) to an <see cref="AnymateModel"/>.
/// Tensor names are the state_dict.* keys verified in dev/anymate-reference/keys_*.json.
/// </summary>
public static class AnymateLoader
{
    /// <summary>
    /// Builds an <see cref="AnymateModel"/> from the tensors of the joint, conn and skin checkpoints.
    /// </summary>
    /// <param name="joint">Tensors of the joint checkpoint, keyed by state-dict name.</param>
    /// <param name="conn">Tensors of the conn checkpoint, keyed by state-dict name.</param>
    /// <param name="skin">Tensors of the skin checkpoint, keyed by state-dict name.</param>
    /// <returns>A model whose encoders, decoder blocks and projection tensors are all populated.</returns>
    /// <exception cref="ArgumentNullException">Any of the three dictionaries is null.</exception>
    /// <exception cref="FormatException">
    /// A required tensor is missing, an encoder has no ViT blocks, or the decoder block counts
    /// do not match the expected layout.
    /// </exception>
    public static AnymateModel Load(
        IReadOnlyDictionary<string, Tensor> joint,
        IReadOnlyDictionary<string, Tensor> conn,
        IReadOnlyDictionary<string, Tensor> skin )
    {
        // Fail at the API boundary with the offending parameter name instead of a
        // NullReferenceException raised from somewhere inside a helper.
        if ( joint is null )
        {
            throw new ArgumentNullException( nameof( joint ) );
        }
        if ( conn is null )
        {
            throw new ArgumentNullException( nameof( conn ) );
        }
        if ( skin is null )
        {
            throw new ArgumentNullException( nameof( skin ) );
        }

        // The decoder stacks are variable-length: keep binding blocks while the next index exists.
        var connCross = CollectIndexed<PerceiverCrossBlock>( conn, "decoder.cross_attn.", "attn.c_q.weight", Cross );
        var connSelf = CollectIndexed<PerceiverSelfBlock>( conn, "decoder.self_attn.", "attn.c_qkv.weight", Self );
        var skinCross = CollectIndexed<PerceiverCrossBlock>( skin, "decoder.cross_attn.", "attn.c_q.weight", Cross );
        var skinCrossJoint = CollectIndexed<PerceiverCrossBlock>( skin, "decoder.cross_attn_joint.", "attn.c_q.weight", Cross );

        // Expected layout: conn has two self blocks per cross block; skin has one
        // joint-cross block per cross block; neither stack may be empty.
        if ( connCross.Count == 0 || connSelf.Count != connCross.Count * 2
            || skinCross.Count == 0 || skinCrossJoint.Count != skinCross.Count )
        {
            throw new FormatException( "Anymate checkpoints: unexpected decoder layout." );
        }

        return new AnymateModel
        {
            JointEncoder = Bert( joint ),
            ConnEncoder = Bert( conn ),
            SkinEncoder = Bert( skin ),

            // Joint decoder.
            JointQuery = Get( joint, "decoder.query" ),
            JointCross = Cross( joint, "decoder.cross_attn_decoder." ),
            JointLnPostG = Get( joint, "decoder.ln_post.weight" ),
            JointLnPostB = Get( joint, "decoder.ln_post.bias" ),
            JointOutW = Get( joint, "decoder.output_proj.weight" ),
            JointOutB = Get( joint, "decoder.output_proj.bias" ),

            // Conn decoder.
            ConnCoProjW = Get( conn, "decoder.co_proj.weight" ),
            ConnCoProjB = Get( conn, "decoder.co_proj.bias" ),
            ConnCross = connCross.ToArray(),
            ConnSelf = connSelf.ToArray(),
            ConnKProjW = Get( conn, "decoder.k_proj.weight" ),
            ConnKProjB = Get( conn, "decoder.k_proj.bias" ),
            ConnQProjW = Get( conn, "decoder.q_proj.weight" ),
            ConnQProjB = Get( conn, "decoder.q_proj.bias" ),
            ConnLn1G = Get( conn, "decoder.ln_1.weight" ),
            ConnLn1B = Get( conn, "decoder.ln_1.bias" ),
            ConnLn2G = Get( conn, "decoder.ln_2.weight" ),
            ConnLn2B = Get( conn, "decoder.ln_2.bias" ),

            // Skin decoder.
            SkinCoProjW = Get( skin, "decoder.co_proj.weight" ),
            SkinCoProjB = Get( skin, "decoder.co_proj.bias" ),
            SkinBoneProjW = Get( skin, "decoder.bone_proj.weight" ),
            SkinBoneProjB = Get( skin, "decoder.bone_proj.bias" ),
            SkinNormalProjW = Get( skin, "decoder.normal_proj.weight" ),
            SkinNormalProjB = Get( skin, "decoder.normal_proj.bias" ),
            SkinCross = skinCross.ToArray(),
            SkinCrossJoint = skinCrossJoint.ToArray(),
            SkinKProjW = Get( skin, "decoder.k_proj.weight" ),
            SkinKProjB = Get( skin, "decoder.k_proj.bias" ),
            SkinQProjW = Get( skin, "decoder.q_proj.weight" ),
            SkinQProjB = Get( skin, "decoder.q_proj.bias" ),
            SkinLn1G = Get( skin, "decoder.ln_1.weight" ),
            SkinLn1B = Get( skin, "decoder.ln_1.bias" ),
            SkinLn2G = Get( skin, "decoder.ln_2.weight" ),
            SkinLn2B = Get( skin, "decoder.ln_2.bias" ),
        };
    }

    /// <summary>
    /// Binds consecutively numbered blocks "{stem}0.", "{stem}1.", ... until the probe key of the
    /// next index is absent from <paramref name="t"/>.
    /// </summary>
    /// <param name="t">Checkpoint tensors.</param>
    /// <param name="stem">Key prefix that precedes the block index, including the trailing dot.</param>
    /// <param name="probe">Key suffix whose presence signals that the block at that index exists.</param>
    /// <param name="bind">Factory that reads one block given the dictionary and the block prefix.</param>
    /// <returns>The bound blocks in index order; empty when index 0 is absent.</returns>
    static List<T> CollectIndexed<T>(
        IReadOnlyDictionary<string, Tensor> t,
        string stem,
        string probe,
        Func<IReadOnlyDictionary<string, Tensor>, string, T> bind )
    {
        var found = new List<T>();
        while ( true )
        {
            var prefix = $"{stem}{found.Count}.";
            if ( !t.ContainsKey( prefix + probe ) )
            {
                return found;
            }
            found.Add( bind( t, prefix ) );
        }
    }

    /// <summary>
    /// Reads the point encoder of one checkpoint: the first/second conv stacks, dimension
    /// reduction, cls token and position embedding, the ViT blocks, the final norm and the
    /// point projection.
    /// </summary>
    /// <param name="t">Checkpoint tensors.</param>
    /// <returns>A fully populated <see cref="PointBert"/>.</returns>
    /// <exception cref="FormatException">No ViT block is present, or a required tensor is missing.</exception>
    static PointBert Bert( IReadOnlyDictionary<string, Tensor> t )
    {
        var blocks = new List<PointBert.VitBlock>();
        while ( true )
        {
            // Two key layouts are accepted per block index; the nested one wins when both exist.
            var index = blocks.Count;
            var nested = $"encoder.blocks.blocks.{index}.";
            var flat = $"encoder.blocks.{index}.";

            string b;
            if ( t.ContainsKey( nested + "norm1.weight" ) )
            {
                b = nested;
            }
            else if ( t.ContainsKey( flat + "norm1.weight" ) )
            {
                b = flat;
            }
            else
            {
                break;
            }

            blocks.Add( new PointBert.VitBlock
            {
                Norm1G = Get( t, $"{b}norm1.weight" ),
                Norm1B = Get( t, $"{b}norm1.bias" ),
                QkvW = Get( t, $"{b}attn.qkv.weight" ),
                ProjW = Get( t, $"{b}attn.proj.weight" ),
                ProjB = Get( t, $"{b}attn.proj.bias" ),
                Norm2G = Get( t, $"{b}norm2.weight" ),
                Norm2B = Get( t, $"{b}norm2.bias" ),
                Fc1W = Get( t, $"{b}mlp.fc1.weight" ),
                Fc1B = Get( t, $"{b}mlp.fc1.bias" ),
                Fc2W = Get( t, $"{b}mlp.fc2.weight" ),
                Fc2B = Get( t, $"{b}mlp.fc2.bias" ),
            } );
        }

        if ( blocks.Count == 0 )
        {
            throw new FormatException( "Anymate checkpoint: no ViT blocks found." );
        }

        return new PointBert
        {
            FirstConv0W = Get( t, "encoder.encoder.first_conv.0.weight" ),
            FirstConv0B = Get( t, "encoder.encoder.first_conv.0.bias" ),
            FirstBnMean = Get( t, "encoder.encoder.first_conv.1.running_mean" ),
            FirstBnVar = Get( t, "encoder.encoder.first_conv.1.running_var" ),
            FirstBnG = Get( t, "encoder.encoder.first_conv.1.weight" ),
            FirstBnB = Get( t, "encoder.encoder.first_conv.1.bias" ),
            FirstConv3W = Get( t, "encoder.encoder.first_conv.3.weight" ),
            FirstConv3B = Get( t, "encoder.encoder.first_conv.3.bias" ),
            SecondConv0W = Get( t, "encoder.encoder.second_conv.0.weight" ),
            SecondConv0B = Get( t, "encoder.encoder.second_conv.0.bias" ),
            SecondBnMean = Get( t, "encoder.encoder.second_conv.1.running_mean" ),
            SecondBnVar = Get( t, "encoder.encoder.second_conv.1.running_var" ),
            SecondBnG = Get( t, "encoder.encoder.second_conv.1.weight" ),
            SecondBnB = Get( t, "encoder.encoder.second_conv.1.bias" ),
            SecondConv3W = Get( t, "encoder.encoder.second_conv.3.weight" ),
            SecondConv3B = Get( t, "encoder.encoder.second_conv.3.bias" ),
            ReduceDimW = Get( t, "encoder.reduce_dim.weight" ),
            ReduceDimB = Get( t, "encoder.reduce_dim.bias" ),
            ClsToken = Get( t, "encoder.cls_token" ),
            ClsPos = Get( t, "encoder.cls_pos" ),
            PosEmbed0W = Get( t, "encoder.pos_embed.0.weight" ),
            PosEmbed0B = Get( t, "encoder.pos_embed.0.bias" ),
            PosEmbed2W = Get( t, "encoder.pos_embed.2.weight" ),
            PosEmbed2B = Get( t, "encoder.pos_embed.2.bias" ),
            Blocks = blocks.ToArray(),
            NormG = Get( t, "encoder.norm.weight" ),
            NormB = Get( t, "encoder.norm.bias" ),
            Proj0W = Get( t, "point_proj.0.weight" ),
            Proj0B = Get( t, "point_proj.0.bias" ),
            Proj2W = Get( t, "point_proj.2.weight" ),
            Proj2B = Get( t, "point_proj.2.bias" ),
            Proj4W = Get( t, "point_proj.4.weight" ),
            Proj4B = Get( t, "point_proj.4.bias" ),
        };
    }

    /// <summary>Reads one Perceiver cross-attention block whose keys start with <paramref name="b"/>.</summary>
    /// <param name="t">Checkpoint tensors.</param>
    /// <param name="b">Key prefix of the block, including the trailing dot.</param>
    /// <exception cref="FormatException">A tensor of the block is missing.</exception>
    static PerceiverCrossBlock Cross( IReadOnlyDictionary<string, Tensor> t, string b ) => new()
    {
        QWeight = Get( t, $"{b}attn.c_q.weight" ),
        KvWeight = Get( t, $"{b}attn.c_kv.weight" ),
        ProjWeight = Get( t, $"{b}attn.c_proj.weight" ),
        ProjBias = Get( t, $"{b}attn.c_proj.bias" ),
        Ln1Gamma = Get( t, $"{b}ln_1.weight" ),
        Ln1Beta = Get( t, $"{b}ln_1.bias" ),
        Ln2Gamma = Get( t, $"{b}ln_2.weight" ),
        Ln2Beta = Get( t, $"{b}ln_2.bias" ),
        Ln3Gamma = Get( t, $"{b}ln_3.weight" ),
        Ln3Beta = Get( t, $"{b}ln_3.bias" ),
        FcWeight = Get( t, $"{b}mlp.c_fc.weight" ),
        FcBias = Get( t, $"{b}mlp.c_fc.bias" ),
        Fc2Weight = Get( t, $"{b}mlp.c_proj.weight" ),
        Fc2Bias = Get( t, $"{b}mlp.c_proj.bias" ),
    };

    /// <summary>Reads one Perceiver self-attention block whose keys start with <paramref name="b"/>.</summary>
    /// <param name="t">Checkpoint tensors.</param>
    /// <param name="b">Key prefix of the block, including the trailing dot.</param>
    /// <exception cref="FormatException">A tensor of the block is missing.</exception>
    static PerceiverSelfBlock Self( IReadOnlyDictionary<string, Tensor> t, string b ) => new()
    {
        QkvWeight = Get( t, $"{b}attn.c_qkv.weight" ),
        ProjWeight = Get( t, $"{b}attn.c_proj.weight" ),
        ProjBias = Get( t, $"{b}attn.c_proj.bias" ),
        Ln1Gamma = Get( t, $"{b}ln_1.weight" ),
        Ln1Beta = Get( t, $"{b}ln_1.bias" ),
        Ln2Gamma = Get( t, $"{b}ln_2.weight" ),
        Ln2Beta = Get( t, $"{b}ln_2.bias" ),
        FcWeight = Get( t, $"{b}mlp.c_fc.weight" ),
        FcBias = Get( t, $"{b}mlp.c_fc.bias" ),
        Fc2Weight = Get( t, $"{b}mlp.c_proj.weight" ),
        Fc2Bias = Get( t, $"{b}mlp.c_proj.bias" ),
    };

    /// <summary>Returns the tensor stored under <paramref name="key"/>.</summary>
    /// <param name="t">Checkpoint tensors.</param>
    /// <param name="key">Exact state-dict key to look up.</param>
    /// <exception cref="FormatException">The key is not present; the message names the key.</exception>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> t, string key )
    {
        if ( t.TryGetValue( key, out var tensor ) )
        {
            return tensor;
        }
        throw new FormatException( $"Anymate checkpoint is missing tensor '{key}'." );
    }
}