// Neural-network feature extractor for per-point part embeddings: a PVConv pipeline, scatter onto triplanes, a 3D-aware UNet, a triplane transformer, and final plane sampling that produces 448-dimensional features per query point.
using AutoRig.Dl.Nn;
namespace AutoRig.Dl.Puppeteer;
using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// NVIDIA PartField's per-point feature extractor, bound from the weights
/// EMBEDDED in Puppeteer's skin checkpoint (point_embed.model.*). The
/// executable spec is dev/puppeteer-reference/gen_golden_partfield.py — every
/// quirk (double point scaling, plane orders, seam-crossing rolled convs) is
/// replicated from there, not "fixed". Pipeline: PVConv stage 0 (voxelize 32³
/// → Conv3d×2 → trilinear devoxelize + point MLP) → scatter-mean onto three
/// 128² planes → 3d-aware 2D UNet → triplane transformer (residual) → split
/// [64 sdf | 448 part] → bilinear plane sampling.
/// </summary>
public sealed class PartFieldModel
{
/// <summary>Edge length of the voxel grid used by PVConv stage 0.</summary>
public const int VoxelRes = 32;

/// <summary>Edge length of each of the three square feature planes.</summary>
public const int PlaneRes = 128;

/// <summary>Channel count of the fused point features and of every plane
/// entering and leaving the UNet.</summary>
public const int PlaneChannels = 256;

/// <summary>Number of feature channels returned per query point.</summary>
public const int OutChannels = 448;

// ---- PVConv stage 0 ----
// Two Conv3d layers over the voxel grid: 8 → PlaneChannels → PlaneChannels.
public required Tensor VoxConv0W;
public required Tensor VoxConv0B;
public required Tensor VoxConv3W;
public required Tensor VoxConv3B;

// Per-point linear layer; the weight is read as a row-major
// [PlaneChannels × 8] matrix, the bias has one entry per output channel.
public required Tensor PointMlpW;
public required Tensor PointMlpB;

// ---- 3d-aware UNet ----
/// <summary>Weights of a single UNet residual block (consumed by ResBlock
/// and Aware3d).</summary>
public sealed class UnetBlock
{
    // GroupNorm on the block input, followed by the first conv.
    public required Tensor Norm1G;
    public required Tensor Norm1B;
    public required Tensor Conv1W;
    public required Tensor Conv1B;

    // GroupNorm applied between the first conv and the cross-plane convs.
    public required Tensor NormMidG;
    public required Tensor NormMidB;

    // 3 plane convs (1×1), one weight/bias pair per plane.
    public required Tensor[] AwareW;
    public required Tensor[] AwareB;

    // GroupNorm + second conv closing the residual branch.
    public required Tensor Norm2G;
    public required Tensor Norm2B;
    public required Tensor Conv2W;
    public required Tensor Conv2B;

    // Shortcut projection; null when channels match.
    public Tensor ShortcutW;
    public Tensor ShortcutB;

    // Input / output channel counts of the block.
    public required int CIn;
    public required int COut;
}

// 3 blocks (downsample after 0,1)
public required UnetBlock[] DownBlocks;

// 2 blocks
public required UnetBlock[] UpBlocks;

// pre-block GN after concat, one pair per up block
public required Tensor[] UpNorm1G;
public required Tensor[] UpNorm1B;

// Output head of the UNet: GroupNorm, then the final conv.
public required Tensor UnetNormOutG;
public required Tensor UnetNormOutB;
public required Tensor UnetFinalW;
public required Tensor UnetFinalB;

// ---- triplane transformer ----
// The two convs of the plane downsampler.
public required Tensor Down0W;
public required Tensor Down0B;
public required Tensor Down3W;
public required Tensor Down3B;

// [3072, 1024] — one row per token
public required Tensor PosEmbed;

/// <summary>Weights of one transformer layer: pre-norm attention followed
/// by a pre-norm two-layer MLP.</summary>
public sealed class TriLayer
{
    // Attention: LayerNorm, fused q/k/v projection, output projection.
    public required Tensor Norm1G;
    public required Tensor Norm1B;
    public required Tensor InProj;
    public required Tensor OutProj;

    // MLP: LayerNorm, first linear (+GELU), second linear.
    public required Tensor Norm2G;
    public required Tensor Norm2B;
    public required Tensor Mlp0W;
    public required Tensor Mlp0B;
    public required Tensor Mlp3W;
    public required Tensor Mlp3B;
}

public required TriLayer[] TriLayers;

// LayerNorm applied after the last transformer layer.
public required Tensor TriNormG;
public required Tensor TriNormB;

// ConvTranspose 1024→512 k4 s4
public required Tensor UpsamplerW;
public required Tensor UpsamplerB;

// Per-pixel MLP on the transformer input: 256→512→512
public required Tensor Mlp0W;
public required Tensor Mlp0B;
public required Tensor Mlp2W;
public required Tensor Mlp2B;

/// <summary>Optional callback that receives a short label each time
/// <c>Features</c> enters a new pipeline stage.</summary>
public Action<string> Progress { get; set; }
/// <summary>
/// Runs the full PartField pipeline on <paramref name="cloud"/> and samples
/// the resulting part planes at <paramref name="queries"/>. Both inputs are
/// expected in the reference's normalized space (≈[-0.5, 0.5]).
/// </summary>
/// <param name="cloud">Surface points that are encoded into the triplanes.</param>
/// <param name="queries">Positions at which part features are sampled.</param>
/// <returns>Row-major [queries.Length × 448] features (one row per query).</returns>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
public float[] Features( Vector3[] cloud, Vector3[] queries )
{
    ArgumentNullException.ThrowIfNull( cloud );
    ArgumentNullException.ThrowIfNull( queries );

    // Nothing to sample: the result is empty regardless of the cloud, so
    // skip the (very expensive) network evaluation altogether.
    if ( queries.Length == 0 )
        return new float[0];

    var fused = EncodePoints( cloud );

    // ---- scatter-mean onto three planes (index = u + res·v) ----
    Progress?.Invoke( "partfield: scatter planes" );
    var planes = new[]
    {
        ScatterMeanPlane( cloud, fused, 0, 1 ), // xy
        ScatterMeanPlane( cloud, fused, 1, 2 ), // yz
        ScatterMeanPlane( cloud, fused, 0, 2 ), // xz
    };

    Progress?.Invoke( "partfield: unet" );
    var unetPlanes = RunUnet( planes );

    Progress?.Invoke( "partfield: triplane transformer" );
    var final = TriplaneTransform( unetPlanes );

    return SamplePartFeatures( final, queries );
}

/// <summary>PVConv stage 0: voxelizes the cloud, runs the two voxel convs,
/// devoxelizes per point and adds the per-point MLP branch. Returns the
/// fused features channel-major as [PlaneChannels × cloud.Length].</summary>
float[] EncodePoints( Vector3[] cloud )
{
    const int r = VoxelRes;
    const int cells = r * r * r;
    const int pointChannels = 8;
    var n = cloud.Length;

    // Spec: everything ×2 twice, 2 zero channels. Channels 0..2 double as
    // the voxelization coordinates; channels 6 and 7 are left at zero.
    var feats = new float[pointChannels * n];
    var normCoords = new float[3 * n];
    var voxelIndex = new int[n];
    for ( var i = 0; i < n; i++ )
    {
        var p = cloud[i];
        var x2 = p.X * 2;
        var y2 = p.Y * 2;
        var z2 = p.Z * 2;
        feats[0 * n + i] = x2;
        feats[1 * n + i] = y2;
        feats[2 * n + i] = z2;
        feats[3 * n + i] = p.X * 4;
        feats[4 * n + i] = p.Y * 4;
        feats[5 * n + i] = p.Z * 4;

        // Map [-1, 1] to continuous grid coordinates, clamped to the grid.
        var nx = Math.Clamp( (x2 + 1) / 2f * r, 0, r - 1 );
        var ny = Math.Clamp( (y2 + 1) / 2f * r, 0, r - 1 );
        var nz = Math.Clamp( (z2 + 1) / 2f * r, 0, r - 1 );
        normCoords[i] = nx;
        normCoords[n + i] = ny;
        normCoords[2 * n + i] = nz;
        voxelIndex[i] = (int)MathF.Round( nx ) * r * r
            + (int)MathF.Round( ny ) * r + (int)MathF.Round( nz );
    }

    // Average the point features that fall into each voxel.
    var voxel = new float[pointChannels * cells];
    var counts = new float[cells];
    for ( var i = 0; i < n; i++ )
        counts[voxelIndex[i]] += 1f;
    for ( var ch = 0; ch < pointChannels; ch++ )
        for ( var i = 0; i < n; i++ )
            voxel[ch * cells + voxelIndex[i]] += feats[ch * n + i];
    for ( var ch = 0; ch < pointChannels; ch++ )
        for ( var v = 0; v < cells; v++ )
            voxel[ch * cells + v] /= counts[v] > 0 ? counts[v] : 1e-5f;

    Progress?.Invoke( "partfield: voxel convs" );
    var h = ConvOps.Conv3d( voxel, pointChannels, r, r, r, VoxConv0W, VoxConv0B, 1 );
    ConvOps.InstanceNorm( h, PlaneChannels, cells, 1e-4f );
    ConvOps.LeakyRelu( h );
    h = ConvOps.Conv3d( h, PlaneChannels, r, r, r, VoxConv3W, VoxConv3B, 1 );
    ConvOps.InstanceNorm( h, PlaneChannels, cells, 1e-4f );
    ConvOps.LeakyRelu( h );

    // Devoxelize: grid (x,y,z) → torch maps x→W(=z-major axis); spec order.
    var gx = new float[n];
    var gy = new float[n];
    var gz = new float[n];
    for ( var i = 0; i < n; i++ )
    {
        gx[i] = (normCoords[i] * 2 + 1f) / r - 1f;
        gy[i] = (normCoords[n + i] * 2 + 1f) / r - 1f;
        gz[i] = (normCoords[2 * n + i] * 2 + 1f) / r - 1f;
    }
    var devox = ConvOps.GridSample3d( h, PlaneChannels, r, r, r, gx, gy, gz ); // [n,256]

    // Point MLP: conv1d 1×1 = per-point linear over the 8 features.
    var pointOut = new float[PlaneChannels * n];
    for ( var oc = 0; oc < PlaneChannels; oc++ )
        for ( var i = 0; i < n; i++ )
        {
            var sum = PointMlpB.Data[oc];
            for ( var ic = 0; ic < pointChannels; ic++ )
                sum += PointMlpW.Data[oc * pointChannels + ic] * feats[ic * n + i];
            pointOut[oc * n + i] = sum;
        }
    ConvOps.InstanceNorm( pointOut, PlaneChannels, n, 1e-5f );
    for ( var i = 0; i < pointOut.Length; i++ )
        if ( pointOut[i] < 0f )
            pointOut[i] = 0f; // ReLU

    // Fuse both branches; devox is point-major, pointOut is channel-major.
    var fused = new float[PlaneChannels * n];
    for ( var oc = 0; oc < PlaneChannels; oc++ )
        for ( var i = 0; i < n; i++ )
            fused[oc * n + i] = devox[i * PlaneChannels + oc] + pointOut[oc * n + i];
    return fused;
}

/// <summary>Scatter-means the fused point features onto one PlaneRes²
/// plane whose (u, v) axes are the cloud components selA and selB
/// (0 = X, 1 = Y, anything else = Z). Cells without points stay zero.</summary>
static float[] ScatterMeanPlane( Vector3[] cloud, float[] fused, int selA, int selB )
{
    const int reso = PlaneRes;
    const int cells = reso * reso;
    var n = cloud.Length;

    var plane = new float[PlaneChannels * cells];
    var count = new float[cells];
    var cell = new int[n];
    for ( var i = 0; i < n; i++ )
    {
        var u = Math.Clamp( AxisOf( cloud[i], selA ) / (1f + 1e-5f) + 0.5f, 0f, 1f - 1e-6f );
        var v = Math.Clamp( AxisOf( cloud[i], selB ) / (1f + 1e-5f) + 0.5f, 0f, 1f - 1e-6f );
        cell[i] = (int)(u * reso) + reso * (int)(v * reso);
        count[cell[i]] += 1f;
    }
    for ( var ch = 0; ch < PlaneChannels; ch++ )
        for ( var i = 0; i < n; i++ )
            plane[ch * cells + cell[i]] += fused[ch * n + i];
    for ( var ch = 0; ch < PlaneChannels; ch++ )
        for ( var q = 0; q < cells; q++ )
            if ( count[q] > 0 )
                plane[ch * cells + q] /= count[q];
    return plane;
}

/// <summary>Selects one component of a point by axis index.</summary>
static float AxisOf( Vector3 p, int axis ) => axis == 0 ? p.X : axis == 1 ? p.Y : p.Z;

/// <summary>Runs the 3d-aware UNet over the rolled (3H × W) image built
/// from the three planes and returns the three output planes.</summary>
float[][] RunUnet( float[][] planes )
{
    var rolled = Roll( planes, PlaneChannels, PlaneRes );
    var height = 3 * PlaneRes;
    var width = PlaneRes;
    var channels = PlaneChannels;

    // Encoder: every block output is kept as a skip connection; all but
    // the last block are followed by a per-plane 2× average pool.
    var skips = new List<(float[] Data, int C)>();
    for ( var b = 0; b < DownBlocks.Length; b++ )
    {
        rolled = ResBlock( DownBlocks[b], rolled, channels, height, width );
        channels = DownBlocks[b].COut;
        skips.Add( (rolled, channels) );
        if ( b < DownBlocks.Length - 1 )
        {
            rolled = PerPlane( rolled, channels, height, width, 1, 2,
                ( p, c2, h2, w2 ) => ConvOps.AvgPool2( p, c2, h2, w2 ) );
            height /= 2;
            width /= 2;
        }
    }

    // Decoder: upsample, concatenate the matching skip along channels,
    // normalize, then run the up block.
    for ( var b = 0; b < UpBlocks.Length; b++ )
    {
        rolled = PerPlane( rolled, channels, height, width, 2, 1,
            ( p, c2, h2, w2 ) => ConvOps.UpsampleNearest2( p, c2, h2, w2 ) );
        height *= 2;
        width *= 2;

        // The deepest skip is the tensor being upsampled, so start one above it.
        var skip = skips[skips.Count - 2 - b];
        var mergedChannels = channels + skip.C;
        var merged = new float[mergedChannels * height * width];
        Array.Copy( rolled, 0, merged, 0, rolled.Length );
        Array.Copy( skip.Data, 0, merged, rolled.Length, skip.Data.Length );
        ConvOps.GroupNorm( merged, mergedChannels, height * width,
            Math.Min( 32, mergedChannels ), UpNorm1G[b], UpNorm1B[b] );
        rolled = ResBlock( UpBlocks[b], merged, mergedChannels, height, width );
        channels = UpBlocks[b].COut;
    }

    ConvOps.GroupNorm( rolled, channels, height * width,
        Math.Min( 32, channels ), UnetNormOutG, UnetNormOutB );
    ConvOps.Swish( rolled );
    rolled = ConvOps.Conv2d( rolled, channels, height, width, UnetFinalW, UnetFinalB, 0 );
    return Unroll( rolled, PlaneChannels, PlaneRes );
}

/// <summary>Samples the part channels (64..512) of the three final planes
/// at query·2 and sums the three samples per channel.</summary>
static float[] SamplePartFeatures( float[][] planes, Vector3[] queries )
{
    const int fullC = 512;
    const int partOffset = 64; // the first 64 channels are the sdf split

    var qn = queries.Length;
    var u0 = new float[qn]; var v0 = new float[qn];
    var u1 = new float[qn]; var v1 = new float[qn];
    var u2 = new float[qn]; var v2 = new float[qn];
    for ( var i = 0; i < qn; i++ )
    {
        var p = queries[i] * 2f;
        u0[i] = p.X; v0[i] = p.Y; // plane xy at (x,y)
        u1[i] = p.Y; v1[i] = p.Z; // plane yz at (y,z)
        u2[i] = p.X; v2[i] = p.Z; // plane xz at (x,z)
    }

    var s0 = ConvOps.GridSample2d( planes[0], fullC, PlaneRes, PlaneRes, u0, v0, alignCorners: true );
    var s1 = ConvOps.GridSample2d( planes[1], fullC, PlaneRes, PlaneRes, u1, v1, alignCorners: true );
    var s2 = ConvOps.GridSample2d( planes[2], fullC, PlaneRes, PlaneRes, u2, v2, alignCorners: true );

    var output = new float[qn * OutChannels];
    for ( var i = 0; i < qn; i++ )
        for ( var ch = 0; ch < OutChannels; ch++ )
        {
            var src = i * fullC + partOffset + ch;
            output[i * OutChannels + ch] = s0[src] + s1[src] + s2[src];
        }
    return output;
}
// ---------------------------------------------------------------- pieces
/// <summary>Stacks three channel-major planes (each c × res × res) into a
/// single channel-major tensor of shape c × (3·res) × res, plane t occupying
/// rows [t·res, (t+1)·res) of every channel.</summary>
static float[] Roll( float[][] planes, int c, int res )
{
    var planeSize = res * res;
    var channelStride = 3 * planeSize;
    var stacked = new float[c * channelStride];
    for ( var t = 0; t < 3; t++ )
    {
        var source = planes[t];
        for ( var ch = 0; ch < c; ch++ )
        {
            var from = ch * planeSize;
            var to = ch * channelStride + t * planeSize;
            Array.Copy( source, from, stacked, to, planeSize );
        }
    }
    return stacked;
}
/// <summary>Inverse of Roll: splits a channel-major c × (3·res) × res tensor
/// back into three channel-major c × res × res planes.</summary>
static float[][] Unroll( float[] rolled, int c, int res )
{
    var planeSize = res * res;
    var channelStride = 3 * planeSize;
    var split = new float[3][];
    for ( var t = 0; t < 3; t++ )
    {
        var target = new float[c * planeSize];
        for ( var ch = 0; ch < c; ch++ )
        {
            var from = ch * channelStride + t * planeSize;
            Array.Copy( rolled, from, target, ch * planeSize, planeSize );
        }
        split[t] = target;
    }
    return split;
}
/// <summary>Runs a resize operator on each of the three planes of a rolled
/// tensor separately and re-stacks the results, so the operator never mixes
/// pixels across triplane seams (mirrors the reference's (c tri) rearrange).
/// The ratio <paramref name="scaleNumerator"/>/<paramref name="scaleDenominator"/>
/// gives the expected output scale: 1/2 for pooling, 2/1 for upsampling.</summary>
/// <param name="rolled">Channel-major tensor, c channels of height × width.</param>
/// <param name="c">Channel count.</param>
/// <param name="height">Total stacked height (three planes).</param>
/// <param name="width">Plane width.</param>
/// <param name="scaleNumerator">Numerator of the size ratio.</param>
/// <param name="scaleDenominator">Denominator of the size ratio.</param>
/// <param name="op">Operator called as op(plane, c, planeHeight, width); must
/// return a channel-major array at the scaled size.</param>
/// <returns>The re-stacked, resized tensor.</returns>
static float[] PerPlane( float[] rolled, int c, int height, int width,
    int scaleNumerator, int scaleDenominator,
    Func<float[], int, int, int, float[]> op )
{
    var planeH = height / 3;
    var inPlaneSize = planeH * width;
    var inChannelStride = height * width;

    var outH = planeH * scaleNumerator / scaleDenominator;
    var outW = width * scaleNumerator / scaleDenominator;
    var outPlaneSize = outH * outW;
    var outChannelStride = 3 * outPlaneSize;

    var restacked = new float[c * outChannelStride];
    for ( var t = 0; t < 3; t++ )
    {
        // Gather plane t of every channel into a contiguous buffer.
        var single = new float[c * inPlaneSize];
        for ( var ch = 0; ch < c; ch++ )
        {
            var from = ch * inChannelStride + t * inPlaneSize;
            Array.Copy( rolled, from, single, ch * inPlaneSize, inPlaneSize );
        }

        var resized = op( single, c, planeH, width );

        // Scatter the resized plane back into its slot of the stack.
        for ( var ch = 0; ch < c; ch++ )
        {
            var to = ch * outChannelStride + t * outPlaneSize;
            Array.Copy( resized, ch * outPlaneSize, restacked, to, outPlaneSize );
        }
    }
    return restacked;
}
/// <summary>One UNet residual block: (GroupNorm → Swish → Conv) →
/// (GroupNorm → Swish → cross-plane Aware3d) → (GroupNorm → Swish → Conv),
/// added to the input, which is projected through the block's shortcut conv
/// when one is present.</summary>
/// <param name="block">Weights for this block.</param>
/// <param name="x">Channel-major input of cin × height × width; not modified.</param>
/// <param name="cin">Input channel count.</param>
/// <param name="height">Stacked height of the rolled tensor.</param>
/// <param name="width">Width of the rolled tensor.</param>
/// <returns>A new array holding the block output.</returns>
float[] ResBlock( UnetBlock block, float[] x, int cin, int height, int width )
{
    var pixels = height * width;
    var cout = block.COut;

    // Work on a copy: the norms run in place and x is still needed for the skip.
    var branch = new float[x.Length];
    Array.Copy( x, branch, x.Length );

    ConvOps.GroupNorm( branch, cin, pixels, Math.Min( 32, cin ), block.Norm1G, block.Norm1B );
    ConvOps.Swish( branch );
    branch = ConvOps.Conv2d( branch, cin, height, width, block.Conv1W, block.Conv1B, 1 );

    ConvOps.GroupNorm( branch, cout, pixels, Math.Min( 32, cout ), block.NormMidG, block.NormMidB );
    ConvOps.Swish( branch );
    branch = Aware3d( block, branch, cout, height, width );

    ConvOps.GroupNorm( branch, cout, pixels, Math.Min( 32, cout ), block.Norm2G, block.Norm2B );
    ConvOps.Swish( branch );
    branch = ConvOps.Conv2d( branch, cout, height, width, block.Conv2W, block.Conv2B, 1 );

    // Identity skip unless the block carries a projection.
    var skip = block.ShortcutW is null
        ? x
        : ConvOps.Conv2d( x, cin, height, width, block.ShortcutW, block.ShortcutB, 0 );

    for ( var idx = 0; idx < branch.Length; idx++ )
        branch[idx] += skip[idx];
    return branch;
}
/// <summary>RODIN cross-plane 1×1: plane i's conv input = [own, broadcast
/// row-mean of plane j, broadcast col-mean of plane k] with j = (i+1)%3 and
/// k = (i+2)%3, concatenated along channels (3·c inputs); plane 2 transposed
/// in and out (order 'xz').</summary>
/// <param name="block">Supplies the per-plane conv weights and biases
/// (AwareW[i] / AwareB[i]).</param>
/// <param name="rolled">Channel-major tensor of <paramref name="c"/> channels,
/// each height × width, the three planes stacked along the height axis.
/// Not modified.</param>
/// <param name="c">Channel count of <paramref name="rolled"/>; the re-stacking
/// at the end assumes each conv also outputs c channels.</param>
/// <param name="height">Stacked height; each plane has height / 3 rows.</param>
/// <param name="width">Width of every plane.</param>
/// <returns>A new c × height × width array in the same rolled layout.</returns>
float[] Aware3d( UnetBlock block, float[] rolled, int c, int height, int width )
{
var planeH = height / 3;
// Split the rolled tensor into three contiguous channel-major planes.
var tri = new float[3][];
for ( var t = 0; t < 3; t++ )
{
tri[t] = new float[c * planeH * width];
for ( var ch = 0; ch < c; ch++ )
Array.Copy( rolled, ch * height * width + t * planeH * width,
tri[t], ch * planeH * width, planeH * width );
}
// plane 2 (xz) → zx
tri[2] = Transpose2d( tri[2], c, planeH, width );
// Per-plane (rows, cols) after the transpose of plane 2.
// NOTE(review): the broadcast writes below index a column by j's row and a
// row by k's column, so they only line up when planeH == width (true for
// every call visible in this file); non-square planes would be written
// at mismatched positions — confirm before reusing elsewhere.
var dims = new (int H, int W)[] { (planeH, width), (planeH, width), (width, planeH) };
var outs = new float[3][];
for ( var i = 0; i < 3; i++ )
{
var j = (i + 1) % 3;
var k = (i + 2) % 3;
var (hi, wi) = dims[i];
// Conv input: channels [0,c) own plane, [c,2c) jpool, [2c,3c) kpool.
var cat = new float[3 * c * hi * wi];
Array.Copy( tri[i], 0, cat, 0, tri[i].Length );
// jpool: mean over j's W → transposed to a row, repeated down rows.
var (hj, wj) = dims[j];
for ( var ch = 0; ch < c; ch++ )
for ( var row = 0; row < hj; row++ )
{
float mean = 0;
for ( var col = 0; col < wj; col++ )
mean += tri[j][ch * hj * wj + row * wj + col];
mean /= wj;
// j's row index becomes the column index in plane i.
for ( var rr = 0; rr < hi; rr++ )
cat[(c + ch) * hi * wi + rr * wi + row] = mean;
}
// kpool: mean over k's H → a column, repeated across columns.
var (hk, wk) = dims[k];
for ( var ch = 0; ch < c; ch++ )
for ( var col = 0; col < wk; col++ )
{
float mean = 0;
for ( var row = 0; row < hk; row++ )
mean += tri[k][ch * hk * wk + row * wk + col];
mean /= hk;
// k's column index becomes the row index in plane i.
for ( var cc = 0; cc < wi; cc++ )
cat[(2 * c + ch) * hi * wi + col * wi + cc] = mean;
}
outs[i] = ConvOps.Conv2d( cat, 3 * c, hi, wi, block.AwareW[i], block.AwareB[i], 0 );
}
// Undo the transpose of plane 2 so all outputs are planeH × width again.
outs[2] = Transpose2d( outs[2], c, width, planeH );
// Re-stack the three planes along the height axis.
var merged = new float[c * height * width];
for ( var ch = 0; ch < c; ch++ )
for ( var t = 0; t < 3; t++ )
Array.Copy( outs[t], ch * planeH * width,
merged, ch * height * width + t * planeH * width, planeH * width );
return merged;
}
/// <summary>Transposes the two spatial axes of a channel-major tensor:
/// input is c × h × w, output is c × w × h. The input is not modified.</summary>
static float[] Transpose2d( float[] x, int c, int h, int w )
{
    var area = h * w;
    var flipped = new float[c * w * h];
    for ( var ch = 0; ch < c; ch++ )
    {
        var origin = ch * area;
        // Walk the destination in row-major order (w rows of h entries).
        for ( var col = 0; col < w; col++ )
        {
            for ( var row = 0; row < h; row++ )
            {
                flipped[origin + col * h + row] = x[origin + row * w + col];
            }
        }
    }
    return flipped;
}
/// <summary>Triplane transformer with residual: each input plane
/// (PlaneChannels × 128 × 128) is reduced to 32 × 32 tokens of width 1024,
/// all 3·32·32 tokens pass through the transformer layers together, and the
/// result is upsampled per plane and added to a per-pixel MLP of the
/// original plane.</summary>
/// <param name="planes">Three channel-major planes; not modified.</param>
/// <returns>Three channel-major planes of 512 × 128 × 128.</returns>
float[][] TriplaneTransform( float[][] planes )
{
    const int fullRes = PlaneRes;
    const int halfRes = fullRes / 2;
    const int tokenRes = 32;
    const int tokenDim = 1024;
    const int outChannels = 512;
    const int headCount = 16;
    const int cellsPerPlane = tokenRes * tokenRes;
    const int pixelsPerPlane = fullRes * fullRes;
    const int tokenCount = 3 * cellsPerPlane;

    // In-place ReLU, written as a comparison so NaN and -0 pass through untouched.
    static void ReluInPlace( float[] values )
    {
        for ( var i = 0; i < values.Length; i++ )
            if ( values[i] < 0f )
                values[i] = 0f;
    }

    // downsample each plane 128→32 via conv+relu+maxpool ×2
    var coarse = new float[3][];
    for ( var t = 0; t < 3; t++ )
    {
        var stage = ConvOps.Conv2d( planes[t], PlaneChannels, fullRes, fullRes, Down0W, Down0B, 1 );
        ReluInPlace( stage );
        stage = ConvOps.MaxPool2( stage, tokenDim, fullRes, fullRes );
        stage = ConvOps.Conv2d( stage, tokenDim, halfRes, halfRes, Down3W, Down3B, 1 );
        ReluInPlace( stage );
        coarse[t] = ConvOps.MaxPool2( stage, tokenDim, halfRes, halfRes );
    }

    // tokens: (3·32·32, 1024) in nihwd order + pos embed
    var tokenData = new float[tokenCount * tokenDim];
    for ( var token = 0; token < tokenCount; token++ )
    {
        var source = coarse[token / cellsPerPlane];
        var cell = token % cellsPerPlane;
        for ( var ch = 0; ch < tokenDim; ch++ )
            tokenData[token * tokenDim + ch] =
                source[ch * cellsPerPlane + cell]
                + PosEmbed.Data[token * tokenDim + ch];
    }

    // Pre-norm transformer: attention and MLP each add onto the token stream.
    var stream = Tensor.From( tokenData, tokenCount, tokenDim );
    for ( var l = 0; l < TriLayers.Length; l++ )
    {
        var layer = TriLayers[l];

        var attnIn = TransformerOps.LayerNorm( stream, layer.Norm1G, layer.Norm1B, 1e-6f );
        var packed = attnIn.MatMul( layer.InProj.Transposed );
        var query = packed.Slice( 1, 0, tokenDim );
        var key = packed.Slice( 1, tokenDim, tokenDim );
        var value = packed.Slice( 1, 2 * tokenDim, tokenDim );
        var attnOut = TransformerOps.Attention( query, key, value, headCount, causal: false )
            .MatMul( layer.OutProj.Transposed );
        stream = stream.Add( attnOut );

        var mlpIn = TransformerOps.LayerNorm( stream, layer.Norm2G, layer.Norm2B, 1e-6f );
        var hidden = TransformerOps.Gelu(
            mlpIn.MatMul( layer.Mlp0W.Transposed ).Add( layer.Mlp0B ) );
        stream = stream.Add( hidden.MatMul( layer.Mlp3W.Transposed ).Add( layer.Mlp3B ) );
    }
    stream = TransformerOps.LayerNorm( stream, TriNormG, TriNormB, 1e-6f );

    // upsample residual per plane + parallel per-pixel MLP on the input
    var output = new float[3][];
    for ( var t = 0; t < 3; t++ )
    {
        // Token-major → channel-major for this plane's 32×32 grid.
        var coarseOut = new float[tokenDim * cellsPerPlane];
        for ( var cell = 0; cell < cellsPerPlane; cell++ )
        {
            var token = t * cellsPerPlane + cell;
            for ( var ch = 0; ch < tokenDim; ch++ )
                coarseOut[ch * cellsPerPlane + cell] = stream.Data[token * tokenDim + ch];
        }
        var residual = ConvOps.ConvTranspose2dBlock( coarseOut, tokenDim, tokenRes, tokenRes, UpsamplerW, UpsamplerB );

        var input = planes[t];
        var plane = new float[outChannels * pixelsPerPlane];
        Concurrency.For( 0, pixelsPerPlane, pix =>
        {
            // First layer (ReLU) into a per-pixel scratch buffer.
            Span<float> scratch = stackalloc float[outChannels];
            for ( var oc = 0; oc < outChannels; oc++ )
            {
                var acc = Mlp0B.Data[oc];
                for ( var ic = 0; ic < PlaneChannels; ic++ )
                    acc += Mlp0W.Data[oc * PlaneChannels + ic]
                        * input[ic * pixelsPerPlane + pix];
                scratch[oc] = MathF.Max( acc, 0f );
            }
            // Second layer plus the upsampled transformer residual.
            for ( var oc = 0; oc < outChannels; oc++ )
            {
                var acc = Mlp2B.Data[oc];
                for ( var ic = 0; ic < outChannels; ic++ )
                    acc += Mlp2W.Data[oc * outChannels + ic] * scratch[ic];
                plane[oc * pixelsPerPlane + pix] = acc + residual[oc * pixelsPerPlane + pix];
            }
        } );
        output[t] = plane;
    }
    return output;
}
}