Code/AutoRig/Dl/Puppeteer/PartFieldLoader.cs

Loader that builds a PartFieldModel from the weights embedded in a Puppeteer skin checkpoint: it reads the tensors under `point_embed.model.*` and binds them to the PVCNN point encoder, the UNet down/up blocks, the triplane transformer layers, and the remaining projection/MLP weights.

File Access
namespace AutoRig.Dl.Puppeteer;

using AutoRig.Dl;

/// <summary>
/// Binds PartField from the weights embedded in Puppeteer's skin
/// checkpoint (<c>point_embed.model.*</c>, verified in skin_keys.json).
/// </summary>
public static class PartFieldLoader
{
    /// <summary>
    /// Row width used to reshape the flat triplane <c>pos_embed</c> tensor; it is passed to
    /// <c>Tensor.From</c> as <c>(Length / PosEmbedWidth, PosEmbedWidth)</c>.
    /// </summary>
    private const int PosEmbedWidth = 1024;

    /// <summary>Builds a <see cref="PartFieldModel"/> from checkpoint tensors.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full state-dict name.</param>
    /// <param name="prefix">
    /// Text prepended to every key before lookup; <c>null</c> behaves like an empty prefix
    /// (string interpolation renders it as nothing).
    /// </param>
    /// <returns>A model with every PartField weight bound.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is <c>null</c>.</exception>
    /// <exception cref="FormatException">
    /// A required tensor is missing, no triplane transformer layers are present, or
    /// <c>pos_embed</c> cannot be reshaped to rows of <see cref="PosEmbedWidth"/>.
    /// </exception>
    public static PartFieldModel Load( IReadOnlyDictionary<string, Tensor> tensors, string prefix )
    {
        if ( tensors is null )
            throw new ArgumentNullException( nameof( tensors ) );

        var pf = $"{prefix}point_embed.model.";
        var enc = $"{pf}pvcnn.pc_encoder.encoder.0.";
        var unet = $"{pf}pvcnn.unet_encoder.";
        var tri = $"{pf}triplane_transformer.";

        // Read these before the initializer so a missing transformer/pos_embed is
        // reported first, as it was before the helpers were extracted.
        var triLayers = ReadTriLayers( tensors, tri );
        var posEmbed = ReadPosEmbed( tensors, $"{tri}pos_embed" );

        return new PartFieldModel
        {
            VoxConv0W = Get( tensors, $"{enc}voxel_layers.0.weight" ),
            VoxConv0B = Get( tensors, $"{enc}voxel_layers.0.bias" ),
            VoxConv3W = Get( tensors, $"{enc}voxel_layers.3.weight" ),
            VoxConv3B = Get( tensors, $"{enc}voxel_layers.3.bias" ),
            PointMlpW = Get( tensors, $"{enc}point_features.layers.0.weight" ),
            PointMlpB = Get( tensors, $"{enc}point_features.layers.0.bias" ),
            DownBlocks = new[]
            {
                ReadUnetBlock( tensors, $"{unet}down_convs.0.block.", 256, 32 ),
                ReadUnetBlock( tensors, $"{unet}down_convs.1.block.", 32, 64 ),
                ReadUnetBlock( tensors, $"{unet}down_convs.2.block.", 64, 128 ),
            },
            // Up blocks take the upsampled features concatenated with the matching skip.
            UpBlocks = new[]
            {
                ReadUnetBlock( tensors, $"{unet}up_convs.0.block.", 128 + 64, 64 ),
                ReadUnetBlock( tensors, $"{unet}up_convs.1.block.", 64 + 32, 32 ),
            },
            UpNorm1G = new[]
            {
                Get( tensors, $"{unet}up_convs.0.norm1.weight" ),
                Get( tensors, $"{unet}up_convs.1.norm1.weight" ),
            },
            UpNorm1B = new[]
            {
                Get( tensors, $"{unet}up_convs.0.norm1.bias" ),
                Get( tensors, $"{unet}up_convs.1.norm1.bias" ),
            },
            UnetNormOutG = Get( tensors, $"{unet}norm_out.weight" ),
            UnetNormOutB = Get( tensors, $"{unet}norm_out.bias" ),
            UnetFinalW = Get( tensors, $"{unet}conv_final.weight" ),
            UnetFinalB = Get( tensors, $"{unet}conv_final.bias" ),
            Down0W = Get( tensors, $"{tri}downsampler.0.weight" ),
            Down0B = Get( tensors, $"{tri}downsampler.0.bias" ),
            Down3W = Get( tensors, $"{tri}downsampler.3.weight" ),
            Down3B = Get( tensors, $"{tri}downsampler.3.bias" ),
            PosEmbed = posEmbed,
            TriLayers = triLayers,
            TriNormG = Get( tensors, $"{tri}transformer.norm.weight" ),
            TriNormB = Get( tensors, $"{tri}transformer.norm.bias" ),
            UpsamplerW = Get( tensors, $"{tri}upsampler.weight" ),
            UpsamplerB = Get( tensors, $"{tri}upsampler.bias" ),
            Mlp0W = Get( tensors, $"{tri}mlp.0.weight" ),
            Mlp0B = Get( tensors, $"{tri}mlp.0.bias" ),
            Mlp2W = Get( tensors, $"{tri}mlp.2.weight" ),
            Mlp2B = Get( tensors, $"{tri}mlp.2.bias" ),
        };
    }

    /// <summary>Binds one UNet block whose tensors all live under key prefix <paramref name="b"/>.</summary>
    /// <remarks>The <c>nin_shortcut</c> weight and bias are optional and left <c>null</c> when absent.</remarks>
    private static PartFieldModel.UnetBlock ReadUnetBlock( IReadOnlyDictionary<string, Tensor> tensors, string b, int cin, int cout ) => new()
    {
        Norm1G = Get( tensors, $"{b}norm1.weight" ),
        Norm1B = Get( tensors, $"{b}norm1.bias" ),
        Conv1W = Get( tensors, $"{b}conv1.weight" ),
        Conv1B = Get( tensors, $"{b}conv1.bias" ),
        NormMidG = Get( tensors, $"{b}norm_mid.weight" ),
        NormMidB = Get( tensors, $"{b}norm_mid.bias" ),
        AwareW = new[]
        {
            Get( tensors, $"{b}conv_3daware.plane_convs.0.weight" ),
            Get( tensors, $"{b}conv_3daware.plane_convs.1.weight" ),
            Get( tensors, $"{b}conv_3daware.plane_convs.2.weight" ),
        },
        AwareB = new[]
        {
            Get( tensors, $"{b}conv_3daware.plane_convs.0.bias" ),
            Get( tensors, $"{b}conv_3daware.plane_convs.1.bias" ),
            Get( tensors, $"{b}conv_3daware.plane_convs.2.bias" ),
        },
        Norm2G = Get( tensors, $"{b}norm2.weight" ),
        Norm2B = Get( tensors, $"{b}norm2.bias" ),
        Conv2W = Get( tensors, $"{b}conv2.weight" ),
        Conv2B = Get( tensors, $"{b}conv2.bias" ),
        ShortcutW = tensors.TryGetValue( $"{b}nin_shortcut.weight", out var sw ) ? sw : null,
        ShortcutB = tensors.TryGetValue( $"{b}nin_shortcut.bias", out var sb ) ? sb : null,
        CIn = cin,
        COut = cout,
    };

    /// <summary>
    /// Reads consecutively numbered <c>transformer.layers.{i}.*</c> entries under <paramref name="tri"/>,
    /// starting at 0 and stopping at the first index with no <c>norm1.weight</c>.
    /// </summary>
    /// <exception cref="FormatException">Layer 0 is absent, or a discovered layer is incomplete.</exception>
    private static PartFieldModel.TriLayer[] ReadTriLayers( IReadOnlyDictionary<string, Tensor> tensors, string tri )
    {
        var layers = new List<PartFieldModel.TriLayer>();
        while ( true )
        {
            var l = $"{tri}transformer.layers.{layers.Count}.";
            if ( !tensors.ContainsKey( $"{l}norm1.weight" ) )
                break;

            layers.Add( new PartFieldModel.TriLayer
            {
                Norm1G = Get( tensors, $"{l}norm1.weight" ),
                Norm1B = Get( tensors, $"{l}norm1.bias" ),
                InProj = Get( tensors, $"{l}self_attn.in_proj_weight" ),
                OutProj = Get( tensors, $"{l}self_attn.out_proj.weight" ),
                Norm2G = Get( tensors, $"{l}norm2.weight" ),
                Norm2B = Get( tensors, $"{l}norm2.bias" ),
                Mlp0W = Get( tensors, $"{l}mlp.0.weight" ),
                Mlp0B = Get( tensors, $"{l}mlp.0.bias" ),
                Mlp3W = Get( tensors, $"{l}mlp.3.weight" ),
                Mlp3B = Get( tensors, $"{l}mlp.3.bias" ),
            } );
        }

        if ( layers.Count == 0 )
            throw new FormatException( $"PartField: no triplane transformer layers found (expected '{tri}transformer.layers.0.norm1.weight')." );

        return layers.ToArray();
    }

    /// <summary>Reads the flat positional embedding and reshapes it into rows of <see cref="PosEmbedWidth"/>.</summary>
    /// <exception cref="FormatException">The tensor is missing, or its length is not a positive multiple of the width.</exception>
    private static Tensor ReadPosEmbed( IReadOnlyDictionary<string, Tensor> tensors, string key )
    {
        var posEmbed = Get( tensors, key );
        var length = posEmbed.Data.Length;

        // Length / PosEmbedWidth truncates, so a length that is zero or not a multiple of the
        // width would yield a row count inconsistent with the buffer. Fail with a clear message.
        if ( length == 0 || length % PosEmbedWidth != 0 )
            throw new FormatException( $"PartField: tensor '{key}' has {length} values, which is not a positive multiple of {PosEmbedWidth}." );

        return Tensor.From( posEmbed.Data, length / PosEmbedWidth, PosEmbedWidth );
    }

    /// <summary>Looks up a required tensor.</summary>
    /// <exception cref="FormatException"><paramref name="key"/> is not present in <paramref name="tensors"/>.</exception>
    internal static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"Puppeteer skin checkpoint is missing tensor '{key}'." );
}