AutoRig/Dl/Puppeteer/PartFieldModel.cs

A neural-network per-point feature extractor that reproduces the behavior of NVIDIA PartField. It voxelizes a point cloud, runs 3D convolutions, scatters the resulting point features onto three planes, processes them with a 3D-aware UNet and a triplane transformer, and samples a 448-d feature vector for each query point.

Native Interop · Networking · File Access
using AutoRig.Dl.Nn;

namespace AutoRig.Dl.Puppeteer;

using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// NVIDIA PartField's per-point feature extractor, bound from the weights
/// EMBEDDED in Puppeteer's skin checkpoint (point_embed.model.*). The
/// executable spec is dev/puppeteer-reference/gen_golden_partfield.py — every
/// quirk (double point scaling, plane orders, seam-crossing rolled convs) is
/// replicated from there, not "fixed". Pipeline: PVConv stage 0 (voxelize 32³
/// → Conv3d×2 → trilinear devoxelize + point MLP) → scatter-mean onto three
/// 128² planes → 3d-aware 2D UNet → triplane transformer (residual) → split
/// [64 sdf | 448 part] → bilinear plane sampling.
/// </summary>
public sealed class PartFieldModel
{
    /// <summary>Edge length of the cubic voxel grid used by PVConv stage 0.</summary>
    public const int VoxelRes = 32;
    /// <summary>Edge length, in cells, of each of the three feature planes.</summary>
    public const int PlaneRes = 128;
    /// <summary>Channel count of the per-point features and of the planes fed to the UNet.</summary>
    public const int PlaneChannels = 256;
    /// <summary>Number of feature values returned per query point by <see cref="Features"/>.</summary>
    public const int OutChannels = 448;

    // ---- PVConv stage 0 ----
    // Weights/biases of the two Conv3d layers applied to the voxel grid
    // (8 → PlaneChannels, then PlaneChannels → PlaneChannels).
    public required Tensor VoxConv0W, VoxConv0B, VoxConv3W, VoxConv3B;
    // Per-point linear layer; Features indexes the weight as [oc * 8 + ic],
    // i.e. a flattened [PlaneChannels, 8] matrix, with one bias per output channel.
    public required Tensor PointMlpW, PointMlpB;

    // ---- 3d-aware UNet ----
    /// <summary>
    /// Parameters of one residual block of the 3d-aware UNet. Read by
    /// <c>ResBlock</c> (norms, convs, shortcut) and <c>Aware3d</c> (per-plane convs).
    /// </summary>
    public sealed class UnetBlock
    {
        // First GroupNorm (gain, bias) and the conv that maps CIn → COut.
        public required Tensor Norm1G;
        public required Tensor Norm1B;
        public required Tensor Conv1W;
        public required Tensor Conv1B;

        // GroupNorm applied right before the cross-plane step.
        public required Tensor NormMidG;
        public required Tensor NormMidB;

        // Cross-plane convs, one entry per plane (3 plane convs, 1×1).
        public required Tensor[] AwareW;
        public required Tensor[] AwareB;

        // Last GroupNorm (gain, bias) and the closing COut → COut conv.
        public required Tensor Norm2G;
        public required Tensor Norm2B;
        public required Tensor Conv2W;
        public required Tensor Conv2B;

        // Projection for the residual path; left null when channels match,
        // in which case the block input is added back unchanged.
        public Tensor ShortcutW;
        public Tensor ShortcutB;

        // Input / output channel counts of the block.
        public required int CIn;
        public required int COut;
    }

    // Encoder blocks. Features average-pools (per plane) after every block
    // except the last one, and keeps each block's output as a skip tensor.
    public required UnetBlock[] DownBlocks;        // 3 (downsample after 0,1)
    // Decoder blocks. Each runs after a per-plane nearest upsample and a
    // channel-wise concat with the matching encoder skip.
    public required UnetBlock[] UpBlocks;          // 2
    // One gain/bias pair per decoder block, applied to the concatenated tensor.
    public required Tensor[] UpNorm1G, UpNorm1B;   // pre-block GN after concat
    // Output head of the UNet: GroupNorm → Swish → final Conv2d.
    public required Tensor UnetNormOutG, UnetNormOutB, UnetFinalW, UnetFinalB;

    // ---- triplane transformer ----
    // Two conv layers (each followed by ReLU + max-pool in TriplaneTransform)
    // that turn a PlaneRes² plane into a 32² grid of 1024-d tokens.
    public required Tensor Down0W, Down0B, Down3W, Down3B;
    // Added to the tokens element-wise; indexed as [token * 1024 + channel].
    public required Tensor PosEmbed;               // [3072, 1024]
    /// <summary>
    /// Parameters of one pre-norm transformer layer used by
    /// <c>TriplaneTransform</c>: attention sub-block followed by an MLP sub-block.
    /// </summary>
    public sealed class TriLayer
    {
        // Attention sub-block: LayerNorm, fused QKV projection, output
        // projection. Neither projection has a bias term in this model.
        public required Tensor Norm1G;
        public required Tensor Norm1B;
        public required Tensor InProj;
        public required Tensor OutProj;

        // MLP sub-block: LayerNorm, then linear → GELU → linear.
        public required Tensor Norm2G;
        public required Tensor Norm2B;
        public required Tensor Mlp0W;
        public required Tensor Mlp0B;
        public required Tensor Mlp3W;
        public required Tensor Mlp3B;
    }
    // Transformer layers, applied in array order over all 3·32·32 tokens.
    public required TriLayer[] TriLayers;
    // Final LayerNorm applied after the last transformer layer.
    public required Tensor TriNormG, TriNormB;
    // Brings the 32² token grid of each plane back to PlaneRes².
    public required Tensor UpsamplerW, UpsamplerB; // ConvTranspose 1024→512 k4 s4
    // Per-pixel MLP (linear → ReLU → linear) run on the UNet planes in
    // parallel with the transformer; its output is summed with the upsampled
    // transformer residual.
    public required Tensor Mlp0W, Mlp0B, Mlp2W, Mlp2B;   // 256→512→512

    // Optional stage-name callback; invoked with a short label before each
    // major stage of Features. May be left unset.
    public Action<string> Progress { get; set; }

    /// <summary>Runs the whole PartField pipeline on <paramref name="cloud"/>
    /// and samples the resulting part-feature planes at
    /// <paramref name="queries"/>. Cloud + query points are in the reference's
    /// normalized space (≈[-0.5, 0.5]).</summary>
    /// <param name="cloud">Points that are voxelized and scattered onto the
    /// three feature planes.</param>
    /// <param name="queries">Points at which the final planes are sampled.</param>
    /// <returns>Row-major [queries.Length × 448] feature matrix: for each
    /// query, the sum over the three planes of channels 64..511.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="cloud"/> or
    /// <paramref name="queries"/> is null.</exception>
    public float[] Features( Vector3[] cloud, Vector3[] queries )
    {
        // Fail fast: queries is only dereferenced after the entire (expensive)
        // network has run, so a null would otherwise surface at the very end.
        ArgumentNullException.ThrowIfNull( cloud );
        ArgumentNullException.ThrowIfNull( queries );

        var n = cloud.Length;

        // ---- PVConv stage 0 (spec: everything ×2 twice, 2 zero channels) ----
        // feats is channel-major [8, n]: xyz·2, xyz·4, and two channels that
        // are never written (stay zero). coords is channel-major [3, n].
        var feats = new float[8 * n];
        var coords = new float[3 * n];
        for ( var i = 0; i < n; i++ )
        {
            var p = cloud[i];
            feats[0 * n + i] = p.X * 2;
            feats[1 * n + i] = p.Y * 2;
            feats[2 * n + i] = p.Z * 2;
            feats[3 * n + i] = p.X * 4;
            feats[4 * n + i] = p.Y * 4;
            feats[5 * n + i] = p.Z * 4;
            coords[0 * n + i] = p.X * 2;
            coords[1 * n + i] = p.Y * 2;
            coords[2 * n + i] = p.Z * 2;
        }

        // Map coords from [-1, 1] to continuous voxel coordinates clamped to
        // [0, r - 1]; the rounded values pick the voxel (x-major, z fastest).
        const int r = VoxelRes;
        var normCoords = new float[3 * n];
        var voxelIndex = new int[n];
        for ( var i = 0; i < n; i++ )
        {
            var nx = Math.Clamp( (coords[i] + 1) / 2f * r, 0, r - 1 );
            var ny = Math.Clamp( (coords[n + i] + 1) / 2f * r, 0, r - 1 );
            var nz = Math.Clamp( (coords[2 * n + i] + 1) / 2f * r, 0, r - 1 );
            normCoords[i] = nx;
            normCoords[n + i] = ny;
            normCoords[2 * n + i] = nz;
            voxelIndex[i] = (int)MathF.Round( nx ) * r * r
                + (int)MathF.Round( ny ) * r + (int)MathF.Round( nz );
        }

        // Average the point features that fall into each voxel. Empty voxels
        // hold a zero sum, so dividing them by the 1e-5 stand-in leaves zero.
        var voxel = new float[8 * r * r * r];
        var counts = new float[r * r * r];
        for ( var i = 0; i < n; i++ )
            counts[voxelIndex[i]] += 1f;
        for ( var ch = 0; ch < 8; ch++ )
            for ( var i = 0; i < n; i++ )
                voxel[ch * r * r * r + voxelIndex[i]] += feats[ch * n + i];
        for ( var ch = 0; ch < 8; ch++ )
            for ( var v = 0; v < r * r * r; v++ )
                voxel[ch * r * r * r + v] /= counts[v] > 0 ? counts[v] : 1e-5f;

        Progress?.Invoke( "partfield: voxel convs" );
        var h = ConvOps.Conv3d( voxel, 8, r, r, r, VoxConv0W, VoxConv0B, 1 );
        ConvOps.InstanceNorm( h, PlaneChannels, r * r * r, 1e-4f );
        ConvOps.LeakyRelu( h );
        h = ConvOps.Conv3d( h, PlaneChannels, r, r, r, VoxConv3W, VoxConv3B, 1 );
        ConvOps.InstanceNorm( h, PlaneChannels, r * r * r, 1e-4f );
        ConvOps.LeakyRelu( h );

        // Devoxelize: grid (x,y,z) → torch maps x→W(=z-major axis); spec order.
        // (2·v + 1) / r - 1 turns a voxel coordinate into a [-1, 1] sample
        // position that addresses voxel centers.
        var gx = new float[n];
        var gy = new float[n];
        var gz = new float[n];
        for ( var i = 0; i < n; i++ )
        {
            gx[i] = (normCoords[i] * 2 + 1f) / r - 1f;
            gy[i] = (normCoords[n + i] * 2 + 1f) / r - 1f;
            gz[i] = (normCoords[2 * n + i] * 2 + 1f) / r - 1f;
        }
        var devox = ConvOps.GridSample3d( h, PlaneChannels, r, r, r, gx, gy, gz ); // [n,256]

        // Point MLP: conv1d 1×1 = per-point linear over the 8 features.
        var fused = new float[PlaneChannels * n];
        var pointOut = new float[PlaneChannels * n];
        for ( var oc = 0; oc < PlaneChannels; oc++ )
            for ( var i = 0; i < n; i++ )
            {
                var sum = PointMlpB.Data[oc];
                for ( var ic = 0; ic < 8; ic++ )
                    sum += PointMlpW.Data[oc * 8 + ic] * feats[ic * n + i];
                pointOut[oc * n + i] = sum;
            }
        ConvOps.InstanceNorm( pointOut, PlaneChannels, n, 1e-5f );
        for ( var i = 0; i < pointOut.Length; i++ )
            if ( pointOut[i] < 0f )
                pointOut[i] = 0f;   // ReLU
        // Fuse voxel branch (point-major) with point branch (channel-major)
        // into a channel-major [256, n] tensor.
        for ( var oc = 0; oc < PlaneChannels; oc++ )
            for ( var i = 0; i < n; i++ )
                fused[oc * n + i] = devox[i * PlaneChannels + oc] + pointOut[oc * n + i];

        // ---- scatter-mean onto three planes (index = u + res·v) ----
        Progress?.Invoke( "partfield: scatter planes" );
        // Builds one [256, res, res] plane: selA/selB pick which cloud axes
        // (0 = X, 1 = Y, 2 = Z) become u and v. Uses the unscaled cloud
        // positions, not the ×2 coords above.
        float[] Plane( int selA, int selB )
        {
            const int reso = PlaneRes;
            var plane = new float[PlaneChannels * reso * reso];
            var count = new float[reso * reso];
            var cell = new int[n];
            for ( var i = 0; i < n; i++ )
            {
                float Axis( int a ) => a == 0 ? cloud[i].X : a == 1 ? cloud[i].Y : cloud[i].Z;
                // Shift to [0, 1) and clamp just below 1 so the cell index
                // never reaches reso.
                var u = Math.Clamp( Axis( selA ) / (1f + 1e-5f) + 0.5f, 0f, 1f - 1e-6f );
                var v = Math.Clamp( Axis( selB ) / (1f + 1e-5f) + 0.5f, 0f, 1f - 1e-6f );
                cell[i] = (int)(u * reso) + reso * (int)(v * reso);
                count[cell[i]] += 1f;
            }
            for ( var ch = 0; ch < PlaneChannels; ch++ )
                for ( var i = 0; i < n; i++ )
                    plane[ch * reso * reso + cell[i]] += fused[ch * n + i];
            for ( var ch = 0; ch < PlaneChannels; ch++ )
                for ( var q = 0; q < reso * reso; q++ )
                    if ( count[q] > 0 )
                        plane[ch * reso * reso + q] /= count[q];
            return plane;
        }
        var planes = new[] { Plane( 0, 1 ), Plane( 1, 2 ), Plane( 0, 2 ) };   // xy, yz, xz

        // ---- 3d-aware UNet over the rolled (3H × W) image ----
        Progress?.Invoke( "partfield: unet" );
        var rolled = Roll( planes, PlaneChannels, PlaneRes );
        var height = 3 * PlaneRes;
        var width = PlaneRes;
        var channels = PlaneChannels;
        var skips = new List<(float[] Data, int C, int H)>();
        for ( var b = 0; b < DownBlocks.Length; b++ )
        {
            rolled = ResBlock( DownBlocks[b], rolled, channels, height, width );
            channels = DownBlocks[b].COut;
            skips.Add( (rolled, channels, height) );
            // Halve the resolution after every encoder block but the last.
            if ( b < DownBlocks.Length - 1 )
            {
                rolled = PerPlane( rolled, channels, height, width, 1, 2,
                    ( p, c2, h2, w2 ) => ConvOps.AvgPool2( p, c2, h2, w2 ) );
                height /= 2;
                width /= 2;
            }
        }
        for ( var b = 0; b < UpBlocks.Length; b++ )
        {
            rolled = PerPlane( rolled, channels, height, width, 2, 1,
                ( p, c2, h2, w2 ) => ConvOps.UpsampleNearest2( p, c2, h2, w2 ) );
            height *= 2;
            width *= 2;
            // The deepest skip is at the bottleneck resolution and is not
            // reused; walk the remaining skips from deep to shallow.
            var skip = skips[skips.Count - 2 - b];
            // Channel-major layout makes a channel concat a plain append.
            var merged = new float[(channels + skip.C) * height * width];
            Array.Copy( rolled, 0, merged, 0, rolled.Length );
            Array.Copy( skip.Data, 0, merged, rolled.Length, skip.Data.Length );
            var mergedChannels = channels + skip.C;
            ConvOps.GroupNorm( merged, mergedChannels, height * width,
                Math.Min( 32, mergedChannels ), UpNorm1G[b], UpNorm1B[b] );
            rolled = ResBlock( UpBlocks[b], merged, mergedChannels, height, width );
            channels = UpBlocks[b].COut;
        }
        ConvOps.GroupNorm( rolled, channels, height * width,
            Math.Min( 32, channels ), UnetNormOutG, UnetNormOutB );
        ConvOps.Swish( rolled );
        rolled = ConvOps.Conv2d( rolled, channels, height, width, UnetFinalW, UnetFinalB, 0 );
        var unetPlanes = Unroll( rolled, PlaneChannels, PlaneRes );

        // ---- triplane transformer ----
        Progress?.Invoke( "partfield: triplane transformer" );
        var final = TriplaneTransform( unetPlanes );

        // ---- sample the part planes (channels 64..512) at query·2 ----
        var qn = queries.Length;
        var u0 = new float[qn]; var v0 = new float[qn];
        var u1 = new float[qn]; var v1 = new float[qn];
        var u2 = new float[qn]; var v2 = new float[qn];
        for ( var i = 0; i < qn; i++ )
        {
            var p = queries[i] * 2f;
            u0[i] = p.X; v0[i] = p.Y;   // plane xy at (x,y)
            u1[i] = p.Y; v1[i] = p.Z;   // plane yz at (y,z)
            u2[i] = p.X; v2[i] = p.Z;   // plane xz at (x,z)
        }
        const int fullC = 512;
        var s0 = ConvOps.GridSample2d( final[0], fullC, PlaneRes, PlaneRes, u0, v0, alignCorners: true );
        var s1 = ConvOps.GridSample2d( final[1], fullC, PlaneRes, PlaneRes, u1, v1, alignCorners: true );
        var s2 = ConvOps.GridSample2d( final[2], fullC, PlaneRes, PlaneRes, u2, v2, alignCorners: true );
        // Samples are point-major [qn, 512]; skip the first 64 channels of
        // each and sum the remaining 448 across the three planes.
        var output = new float[qn * OutChannels];
        for ( var i = 0; i < qn; i++ )
            for ( var ch = 0; ch < OutChannels; ch++ )
                output[i * OutChannels + ch] =
                    s0[i * fullC + 64 + ch] + s1[i * fullC + 64 + ch] + s2[i * fullC + 64 + ch];
        return output;
    }

    // ---------------------------------------------------------------- pieces

    /// <summary>Stacks three channel-major [c, res, res] planes into a single
    /// channel-major [c, 3·res, res] image, plane 0 on top and plane 2 at the
    /// bottom of every channel.</summary>
    static float[] Roll( float[][] planes, int c, int res )
    {
        var planeSize = res * res;
        var channelStride = 3 * planeSize;
        var stacked = new float[c * channelStride];

        for ( var channel = 0; channel < c; channel++ )
        {
            var destination = channel * channelStride;
            for ( var plane = 0; plane < 3; plane++ )
            {
                Array.Copy( planes[plane], channel * planeSize,
                    stacked, destination, planeSize );
                destination += planeSize;
            }
        }

        return stacked;
    }

    /// <summary>Inverse of <c>Roll</c>: splits a channel-major [c, 3·res, res]
    /// image into three channel-major [c, res, res] planes.</summary>
    static float[][] Unroll( float[] rolled, int c, int res )
    {
        var planeSize = res * res;
        var channelStride = 3 * planeSize;
        var split = new float[3][];

        for ( var plane = 0; plane < split.Length; plane++ )
        {
            var target = new float[c * planeSize];
            var source = plane * planeSize;
            for ( var channel = 0; channel < c; channel++ )
            {
                Array.Copy( rolled, source, target, channel * planeSize, planeSize );
                source += channelStride;
            }
            split[plane] = target;
        }

        return split;
    }

    /// <summary>Runs a 2× resize op on each of the three planes of a rolled
    /// tensor separately and re-stacks the results, so the op never mixes
    /// pixels across triplane seams (like the reference's (c tri) rearrange).
    /// </summary>
    /// <param name="rolled">Channel-major [c, height, width] rolled image.</param>
    /// <param name="c">Channel count.</param>
    /// <param name="height">Rolled height (three stacked planes).</param>
    /// <param name="width">Plane width.</param>
    /// <param name="scaleNumerator">Output/input size ratio numerator
    /// (1 with denominator 2 for pooling, 2 with denominator 1 for upsampling).</param>
    /// <param name="scaleDenominator">Output/input size ratio denominator.</param>
    /// <param name="op">Resize op called as (planeData, c, planeHeight, width);
    /// must return a channel-major tensor of the scaled size.</param>
    /// <returns>Channel-major rolled tensor of the scaled size.</returns>
    static float[] PerPlane( float[] rolled, int c, int height, int width,
        int scaleNumerator, int scaleDenominator,
        Func<float[], int, int, int, float[]> op )
    {
        var planeH = height / 3;
        var inPlaneSize = planeH * width;
        var inChannelStride = height * width;

        var outH = planeH * scaleNumerator / scaleDenominator;
        var outW = width * scaleNumerator / scaleDenominator;
        var outPlaneSize = outH * outW;
        var outChannelStride = 3 * outPlaneSize;

        var restacked = new float[c * outChannelStride];
        for ( var plane = 0; plane < 3; plane++ )
        {
            // Cut this plane out of every channel.
            var single = new float[c * inPlaneSize];
            for ( var channel = 0; channel < c; channel++ )
            {
                Array.Copy( rolled, channel * inChannelStride + plane * inPlaneSize,
                    single, channel * inPlaneSize, inPlaneSize );
            }

            var scaled = op( single, c, planeH, width );

            // Drop the resized plane back into its slot of every channel.
            for ( var channel = 0; channel < c; channel++ )
            {
                Array.Copy( scaled, channel * outPlaneSize,
                    restacked, channel * outChannelStride + plane * outPlaneSize, outPlaneSize );
            }
        }
        return restacked;
    }

    /// <summary>One residual UNet block on a rolled [cin, height, width]
    /// tensor: GN → Swish → conv, GN → Swish → cross-plane conv, GN → Swish →
    /// conv, plus the (optionally projected) input added back.</summary>
    /// <param name="block">Weights for this block.</param>
    /// <param name="x">Input tensor; it is not modified.</param>
    /// <param name="cin">Channel count of <paramref name="x"/>.</param>
    /// <param name="height">Rolled height (three stacked planes).</param>
    /// <param name="width">Plane width.</param>
    /// <returns>New [block.COut, height, width] tensor.</returns>
    float[] ResBlock( UnetBlock block, float[] x, int cin, int height, int width )
    {
        var pixels = height * width;
        var cout = block.COut;

        // The norm/activation ops work in place, so run the main branch on a
        // private copy and keep x intact for the residual path.
        var branch = (float[])x.Clone();

        ConvOps.GroupNorm( branch, cin, pixels, Math.Min( 32, cin ), block.Norm1G, block.Norm1B );
        ConvOps.Swish( branch );
        branch = ConvOps.Conv2d( branch, cin, height, width, block.Conv1W, block.Conv1B, 1 );

        ConvOps.GroupNorm( branch, cout, pixels, Math.Min( 32, cout ), block.NormMidG, block.NormMidB );
        ConvOps.Swish( branch );
        branch = Aware3d( block, branch, cout, height, width );

        ConvOps.GroupNorm( branch, cout, pixels, Math.Min( 32, cout ), block.Norm2G, block.Norm2B );
        ConvOps.Swish( branch );
        branch = ConvOps.Conv2d( branch, cout, height, width, block.Conv2W, block.Conv2B, 1 );

        // Residual path: identity unless the block carries a projection.
        var residual = block.ShortcutW is null
            ? x
            : ConvOps.Conv2d( x, cin, height, width, block.ShortcutW, block.ShortcutB, 0 );

        for ( var index = 0; index < branch.Length; index++ )
        {
            branch[index] += residual[index];
        }
        return branch;
    }

    /// <summary>RODIN cross-plane 1×1: plane i's conv input = [own, broadcast
    /// row-mean of plane j, broadcast col-mean of plane k]; plane 2 transposed
    /// in and out (order 'xz').</summary>
    /// <param name="block">Supplies the three per-plane conv weights/biases
    /// (<c>AwareW</c>/<c>AwareB</c>).</param>
    /// <param name="rolled">Channel-major rolled tensor [c, height, width]
    /// holding three stacked planes.</param>
    /// <param name="c">Channel count of <paramref name="rolled"/> and of the
    /// result.</param>
    /// <param name="height">Rolled height; each plane is height / 3 rows.</param>
    /// <param name="width">Plane width.</param>
    /// <returns>New rolled tensor with the same layout as the input.</returns>
    /// <remarks>The broadcast indexing uses plane j's row count as a column
    /// index of plane i and plane k's column count as a row index of plane i,
    /// so it relies on the planes being square.</remarks>
    float[] Aware3d( UnetBlock block, float[] rolled, int c, int height, int width )
    {
        // Split the rolled image into three channel-major planes.
        var planeH = height / 3;
        var tri = new float[3][];
        for ( var t = 0; t < 3; t++ )
        {
            tri[t] = new float[c * planeH * width];
            for ( var ch = 0; ch < c; ch++ )
                Array.Copy( rolled, ch * height * width + t * planeH * width,
                    tri[t], ch * planeH * width, planeH * width );
        }

        // plane 2 (xz) → zx
        tri[2] = Transpose2d( tri[2], c, planeH, width );

        // Per-plane (rows, cols) after the transpose of plane 2.
        var dims = new (int H, int W)[] { (planeH, width), (planeH, width), (width, planeH) };
        var outs = new float[3][];
        for ( var i = 0; i < 3; i++ )
        {
            // Cyclic neighbours: 0 → (1, 2), 1 → (2, 0), 2 → (0, 1).
            var j = (i + 1) % 3;
            var k = (i + 2) % 3;
            var (hi, wi) = dims[i];
            // cat = [own | j-pool | k-pool], 3·c channels; the first c
            // channels are plane i itself.
            var cat = new float[3 * c * hi * wi];
            Array.Copy( tri[i], 0, cat, 0, tri[i].Length );

            // jpool: mean over j's W → transposed to a row, repeated down rows.
            var (hj, wj) = dims[j];
            for ( var ch = 0; ch < c; ch++ )
                for ( var row = 0; row < hj; row++ )
                {
                    float mean = 0;
                    for ( var col = 0; col < wj; col++ )
                        mean += tri[j][ch * hj * wj + row * wj + col];
                    mean /= wj;
                    // j's row index becomes the column index in plane i.
                    for ( var rr = 0; rr < hi; rr++ )
                        cat[(c + ch) * hi * wi + rr * wi + row] = mean;
                }
            // kpool: mean over k's H → a column, repeated across columns.
            var (hk, wk) = dims[k];
            for ( var ch = 0; ch < c; ch++ )
                for ( var col = 0; col < wk; col++ )
                {
                    float mean = 0;
                    for ( var row = 0; row < hk; row++ )
                        mean += tri[k][ch * hk * wk + row * wk + col];
                    mean /= hk;
                    // k's column index becomes the row index in plane i.
                    for ( var cc = 0; cc < wi; cc++ )
                        cat[(2 * c + ch) * hi * wi + col * wi + cc] = mean;
                }
            outs[i] = ConvOps.Conv2d( cat, 3 * c, hi, wi, block.AwareW[i], block.AwareB[i], 0 );
        }
        // Undo the transpose so plane 2 is back in its original orientation.
        outs[2] = Transpose2d( outs[2], c, width, planeH );

        // Re-stack the three planes into the rolled layout.
        var merged = new float[c * height * width];
        for ( var ch = 0; ch < c; ch++ )
            for ( var t = 0; t < 3; t++ )
                Array.Copy( outs[t], ch * planeH * width,
                    merged, ch * height * width + t * planeH * width, planeH * width );
        return merged;
    }

    /// <summary>Transposes the two spatial axes of a channel-major [c, h, w]
    /// tensor, returning a new channel-major [c, w, h] tensor.</summary>
    static float[] Transpose2d( float[] x, int c, int h, int w )
    {
        var planeSize = w * h;
        var flipped = new float[c * planeSize];

        for ( var channel = 0; channel < c; channel++ )
        {
            var offset = channel * planeSize;
            // Walk the destination in order: its rows are the source columns.
            for ( var column = 0; column < w; column++ )
            {
                for ( var row = 0; row < h; row++ )
                {
                    flipped[offset + column * h + row] = x[offset + row * w + column];
                }
            }
        }

        return flipped;
    }

    /// <summary>Triplane transformer stage: downsamples each plane to a 32×32
    /// token grid, runs all 3·32·32 tokens through the transformer layers,
    /// upsamples the result back to plane resolution and adds it to a
    /// per-pixel MLP of the input planes.</summary>
    /// <param name="planes">Three channel-major [PlaneChannels, PlaneRes,
    /// PlaneRes] planes (the UNet output).</param>
    /// <returns>Three channel-major [512, PlaneRes, PlaneRes] planes.</returns>
    float[][] TriplaneTransform( float[][] planes )
    {
        const int reso = PlaneRes;
        const int dim = 1024;
        const int low = 32;
        // downsample each plane 128→32 via conv+relu+maxpool ×2
        var lowPlanes = new float[3][];
        for ( var t = 0; t < 3; t++ )
        {
            var z = ConvOps.Conv2d( planes[t], PlaneChannels, reso, reso, Down0W, Down0B, 1 );
            for ( var i = 0; i < z.Length; i++ )
                if ( z[i] < 0f )
                    z[i] = 0f;
            z = ConvOps.MaxPool2( z, dim, reso, reso );
            z = ConvOps.Conv2d( z, dim, reso / 2, reso / 2, Down3W, Down3B, 1 );
            for ( var i = 0; i < z.Length; i++ )
                if ( z[i] < 0f )
                    z[i] = 0f;
            lowPlanes[t] = ConvOps.MaxPool2( z, dim, reso / 2, reso / 2 );
        }

        // tokens: (3·32·32, 1024) in nihwd order + pos embed
        // (channel-major plane data is regrouped into one row per pixel).
        var tokenCount = 3 * low * low;
        var tokens = new float[tokenCount * dim];
        for ( var t = 0; t < 3; t++ )
            for ( var y = 0; y < low; y++ )
                for ( var x0 = 0; x0 < low; x0++ )
                {
                    var token = t * low * low + y * low + x0;
                    for ( var ch = 0; ch < dim; ch++ )
                        tokens[token * dim + ch] =
                            lowPlanes[t][ch * low * low + y * low + x0]
                            + PosEmbed.Data[token * dim + ch];
                }

        // Pre-norm transformer: x += Attn(LN(x)); x += MLP(LN(x)).
        var x = Tensor.From( tokens, tokenCount, dim );
        foreach ( var layer in TriLayers )
        {
            var normed = TransformerOps.LayerNorm( x, layer.Norm1G, layer.Norm1B, 1e-6f );
            // Fused projection: columns [0, dim) = q, [dim, 2·dim) = k,
            // [2·dim, 3·dim) = v. No bias is added.
            var qkv = normed.MatMul( layer.InProj.Transposed );
            var q = qkv.Slice( 1, 0, dim );
            var k = qkv.Slice( 1, dim, dim );
            var v = qkv.Slice( 1, 2 * dim, dim );
            // NOTE(review): 16 is presumably the head count — confirm against
            // TransformerOps.Attention.
            var attended = TransformerOps.Attention( q, k, v, 16, causal: false )
                .MatMul( layer.OutProj.Transposed );
            x = x.Add( attended );
            normed = TransformerOps.LayerNorm( x, layer.Norm2G, layer.Norm2B, 1e-6f );
            var mlp = TransformerOps.Gelu(
                normed.MatMul( layer.Mlp0W.Transposed ).Add( layer.Mlp0B ) );
            x = x.Add( mlp.MatMul( layer.Mlp3W.Transposed ).Add( layer.Mlp3B ) );
        }
        x = TransformerOps.LayerNorm( x, TriNormG, TriNormB, 1e-6f );

        // upsample residual per plane + parallel per-pixel MLP on the input
        var result = new float[3][];
        for ( var t = 0; t < 3; t++ )
        {
            // Tokens of plane t back to channel-major [dim, low, low].
            var lowT = new float[dim * low * low];
            for ( var y = 0; y < low; y++ )
                for ( var x0 = 0; x0 < low; x0++ )
                {
                    var token = t * low * low + y * low + x0;
                    for ( var ch = 0; ch < dim; ch++ )
                        lowT[ch * low * low + y * low + x0] = x.Data[token * dim + ch];
                }
            var residual = ConvOps.ConvTranspose2dBlock( lowT, dim, low, low, UpsamplerW, UpsamplerB );

            var plane = new float[512 * reso * reso];
            // NOTE(review): the lambda captures the loop variable t; this is
            // only correct if Concurrency.For returns after all iterations
            // have finished — confirm.
            Concurrency.For( 0, reso * reso, pix =>
            {
                // Per-pixel scratch for the hidden layer (256 → 512, ReLU).
                Span<float> hbuf = stackalloc float[512];
                for ( var oc = 0; oc < 512; oc++ )
                {
                    var sum = Mlp0B.Data[oc];
                    for ( var ic = 0; ic < PlaneChannels; ic++ )
                        sum += Mlp0W.Data[oc * PlaneChannels + ic]
                            * planes[t][ic * reso * reso + pix];
                    hbuf[oc] = MathF.Max( sum, 0f );
                }
                // Second layer (512 → 512) plus the upsampled transformer
                // residual; each pixel writes only its own output slots.
                for ( var oc = 0; oc < 512; oc++ )
                {
                    var sum = Mlp2B.Data[oc];
                    for ( var ic = 0; ic < 512; ic++ )
                        sum += Mlp2W.Data[oc * 512 + ic] * hbuf[ic];
                    plane[oc * reso * reso + pix] = sum + residual[oc * reso * reso + pix];
                }
            } );
            result[t] = plane;
        }
        return result;
    }
}