// Loader for PartFieldModel: reads tensor weights from a Puppeteer skin checkpoint dictionary and builds a PartFieldModel instance with its UNet blocks, triplane transformer layers, and remaining weight tensors.
namespace AutoRig.Dl.Puppeteer;
using AutoRig.Dl;
/// <summary>Binds PartField from the weights embedded in Puppeteer's skin
/// checkpoint (point_embed.model.*, verified in skin_keys.json).</summary>
public static class PartFieldLoader
{
    /// <summary>Column count used when reshaping <c>pos_embed</c> into a 2-D tensor.</summary>
    private const int PosEmbedWidth = 1024;

    /// <summary>Builds a <see cref="PartFieldModel"/> by looking up every weight it
    /// needs under <c>{prefix}point_embed.model.</c> in <paramref name="tensors"/>.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <param name="prefix">Text prepended to <c>point_embed.model.</c> when forming keys.</param>
    /// <returns>A model whose fields reference the tensors found in the dictionary;
    /// only <c>PosEmbed</c> is re-created (reshaped to <c>rows x 1024</c>).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">A required tensor is missing, no triplane
    /// transformer layer exists, or <c>pos_embed</c> cannot be reshaped to 1024 columns.</exception>
    public static PartFieldModel Load( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        if ( tensors is null )
            throw new ArgumentNullException( nameof( tensors ) );

        var pf = $"{prefix}point_embed.model.";
        var enc = $"{pf}pvcnn.pc_encoder.encoder.0.";
        var unet = $"{pf}pvcnn.unet_encoder.";
        var tri = $"{pf}triplane_transformer.";

        // Transformer layers are resolved first, then pos_embed, then everything else in
        // initializer order, so the first missing key reported stays the same as before.
        var layers = LoadTriLayers( tensors, tri );
        var posEmbedKey = $"{tri}pos_embed";
        var posEmbed = Get( tensors, posEmbedKey );

        return new PartFieldModel
        {
            VoxConv0W = Get( tensors, $"{enc}voxel_layers.0.weight" ),
            VoxConv0B = Get( tensors, $"{enc}voxel_layers.0.bias" ),
            VoxConv3W = Get( tensors, $"{enc}voxel_layers.3.weight" ),
            VoxConv3B = Get( tensors, $"{enc}voxel_layers.3.bias" ),
            PointMlpW = Get( tensors, $"{enc}point_features.layers.0.weight" ),
            PointMlpB = Get( tensors, $"{enc}point_features.layers.0.bias" ),
            // Channel counts are fixed here; they are not read from the checkpoint.
            DownBlocks = new[]
            {
                LoadUnetBlock( tensors, $"{unet}down_convs.0.block.", 256, 32 ),
                LoadUnetBlock( tensors, $"{unet}down_convs.1.block.", 32, 64 ),
                LoadUnetBlock( tensors, $"{unet}down_convs.2.block.", 64, 128 ),
            },
            UpBlocks = new[]
            {
                LoadUnetBlock( tensors, $"{unet}up_convs.0.block.", 128 + 64, 64 ),
                LoadUnetBlock( tensors, $"{unet}up_convs.1.block.", 64 + 32, 32 ),
            },
            UpNorm1G = new[]
            {
                Get( tensors, $"{unet}up_convs.0.norm1.weight" ),
                Get( tensors, $"{unet}up_convs.1.norm1.weight" ),
            },
            UpNorm1B = new[]
            {
                Get( tensors, $"{unet}up_convs.0.norm1.bias" ),
                Get( tensors, $"{unet}up_convs.1.norm1.bias" ),
            },
            UnetNormOutG = Get( tensors, $"{unet}norm_out.weight" ),
            UnetNormOutB = Get( tensors, $"{unet}norm_out.bias" ),
            UnetFinalW = Get( tensors, $"{unet}conv_final.weight" ),
            UnetFinalB = Get( tensors, $"{unet}conv_final.bias" ),
            Down0W = Get( tensors, $"{tri}downsampler.0.weight" ),
            Down0B = Get( tensors, $"{tri}downsampler.0.bias" ),
            Down3W = Get( tensors, $"{tri}downsampler.3.weight" ),
            Down3B = Get( tensors, $"{tri}downsampler.3.bias" ),
            PosEmbed = ReshapePosEmbed( posEmbed, posEmbedKey ),
            TriLayers = layers,
            TriNormG = Get( tensors, $"{tri}transformer.norm.weight" ),
            TriNormB = Get( tensors, $"{tri}transformer.norm.bias" ),
            UpsamplerW = Get( tensors, $"{tri}upsampler.weight" ),
            UpsamplerB = Get( tensors, $"{tri}upsampler.bias" ),
            Mlp0W = Get( tensors, $"{tri}mlp.0.weight" ),
            Mlp0B = Get( tensors, $"{tri}mlp.0.bias" ),
            Mlp2W = Get( tensors, $"{tri}mlp.2.weight" ),
            Mlp2B = Get( tensors, $"{tri}mlp.2.bias" ),
        };
    }

    /// <summary>Returns the tensor stored under <paramref name="key"/>.</summary>
    /// <exception cref="FormatException">The key is not present in <paramref name="tensors"/>.</exception>
    internal static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"Puppeteer skin checkpoint is missing tensor '{key}'." );

    /// <summary>Reads one UNet block whose tensor keys start with <paramref name="b"/>.</summary>
    private static PartFieldModel.UnetBlock LoadUnetBlock(
        IReadOnlyDictionary<string, Tensor> tensors, string b, int cin, int cout ) => new()
    {
        Norm1G = Get( tensors, $"{b}norm1.weight" ),
        Norm1B = Get( tensors, $"{b}norm1.bias" ),
        Conv1W = Get( tensors, $"{b}conv1.weight" ),
        Conv1B = Get( tensors, $"{b}conv1.bias" ),
        NormMidG = Get( tensors, $"{b}norm_mid.weight" ),
        NormMidB = Get( tensors, $"{b}norm_mid.bias" ),
        AwareW = new[]
        {
            Get( tensors, $"{b}conv_3daware.plane_convs.0.weight" ),
            Get( tensors, $"{b}conv_3daware.plane_convs.1.weight" ),
            Get( tensors, $"{b}conv_3daware.plane_convs.2.weight" ),
        },
        AwareB = new[]
        {
            Get( tensors, $"{b}conv_3daware.plane_convs.0.bias" ),
            Get( tensors, $"{b}conv_3daware.plane_convs.1.bias" ),
            Get( tensors, $"{b}conv_3daware.plane_convs.2.bias" ),
        },
        Norm2G = Get( tensors, $"{b}norm2.weight" ),
        Norm2B = Get( tensors, $"{b}norm2.bias" ),
        Conv2W = Get( tensors, $"{b}conv2.weight" ),
        Conv2B = Get( tensors, $"{b}conv2.bias" ),
        // The shortcut tensors are optional: absent keys yield null instead of an error.
        ShortcutW = tensors.TryGetValue( $"{b}nin_shortcut.weight", out var sw ) ? sw : null,
        ShortcutB = tensors.TryGetValue( $"{b}nin_shortcut.bias", out var sb ) ? sb : null,
        CIn = cin,
        COut = cout,
    };

    /// <summary>Collects triplane transformer layers 0, 1, 2, ... until a layer's
    /// <c>norm1.weight</c> key is absent; fails when not even layer 0 exists.</summary>
    private static PartFieldModel.TriLayer[] LoadTriLayers(
        IReadOnlyDictionary<string, Tensor> tensors, string tri )
    {
        var layers = new List<PartFieldModel.TriLayer>();
        while ( tensors.ContainsKey( $"{tri}transformer.layers.{layers.Count}.norm1.weight" ) )
        {
            var l = $"{tri}transformer.layers.{layers.Count}.";
            layers.Add( new PartFieldModel.TriLayer
            {
                Norm1G = Get( tensors, $"{l}norm1.weight" ),
                Norm1B = Get( tensors, $"{l}norm1.bias" ),
                InProj = Get( tensors, $"{l}self_attn.in_proj_weight" ),
                OutProj = Get( tensors, $"{l}self_attn.out_proj.weight" ),
                Norm2G = Get( tensors, $"{l}norm2.weight" ),
                Norm2B = Get( tensors, $"{l}norm2.bias" ),
                Mlp0W = Get( tensors, $"{l}mlp.0.weight" ),
                Mlp0B = Get( tensors, $"{l}mlp.0.bias" ),
                Mlp3W = Get( tensors, $"{l}mlp.3.weight" ),
                Mlp3B = Get( tensors, $"{l}mlp.3.bias" ),
            } );
        }

        if ( layers.Count == 0 )
            throw new FormatException( "PartField: no triplane transformer layers found." );

        return layers.ToArray();
    }

    /// <summary>Re-wraps the <c>pos_embed</c> data as a <c>rows x 1024</c> tensor, rejecting
    /// element counts that integer division would otherwise silently truncate.</summary>
    private static Tensor ReshapePosEmbed( Tensor posEmbed, string key )
    {
        var length = posEmbed.Data.Length;
        if ( length == 0 || length % PosEmbedWidth != 0 )
            throw new FormatException(
                $"PartField: tensor '{key}' has {length} elements; expected a positive multiple of {PosEmbedWidth}." );

        return Tensor.From( posEmbed.Data, length / PosEmbedWidth, PosEmbedWidth );
    }
}