Neural rigging model code for Anymate. Defines a PointBert class that encodes 3D point clouds into transformer latents using FPS, KNN grouping, mini-PointNet feature extraction and ViT blocks, and a small AnymateEmbed helper with Fourier and spherical-harmonic embeddings.
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;
namespace AutoRig.Dl.Anymate;
using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// Anymate (Apache-2.0): three-stage neural rigging — joints, connectivity,
/// skin — each a PointBERT encoder (512 fps groups × knn 32, mini-PointNet
/// with BatchNorm running stats, 12 ViT blocks with positions re-added every
/// block) bridged 384→768, plus small Michelangelo-style decoders. Executable
/// spec: dev/anymate-reference/gen_golden_anymate.py.
/// </summary>
public sealed class PointBert
{
    /// <summary>Number of farthest-point-sampled centres; each one yields a group token.</summary>
    public const int Groups = 512;

    /// <summary>Nearest neighbours gathered around every centre.</summary>
    public const int GroupSize = 32;

    /// <summary>Token width used inside the ViT blocks.</summary>
    public const int Dim = 384;

    /// <summary>Head count handed to the attention op.</summary>
    public const int Heads = 6;

    /// <summary>Width of a group token leaving the mini-PointNet, before the ReduceDim projection.</summary>
    private const int TokenDim = 256;

    /// <summary>Epsilon added to the BatchNorm running variance at inference.</summary>
    private const float BnEpsilon = 1e-5f;

    // Mini-PointNet weights; every "conv" below is applied independently per point.
    public required Tensor FirstConv0W, FirstConv0B; // 6→128 (1×1)
    public required Tensor FirstBnMean, FirstBnVar, FirstBnG, FirstBnB;
    public required Tensor FirstConv3W, FirstConv3B; // 128→256
    public required Tensor SecondConv0W, SecondConv0B; // 512→512
    public required Tensor SecondBnMean, SecondBnVar, SecondBnG, SecondBnB;
    public required Tensor SecondConv3W, SecondConv3B; // 512→256

    // Token projection, class token and positional MLP.
    public required Tensor ReduceDimW, ReduceDimB; // 256→384
    public required Tensor ClsToken, ClsPos; // [384]
    public required Tensor PosEmbed0W, PosEmbed0B, PosEmbed2W, PosEmbed2B;

    /// <summary>Weights of one pre-norm transformer block (attention + MLP, each with a residual).</summary>
    public sealed class VitBlock
    {
        public required Tensor Norm1G, Norm1B, QkvW, ProjW, ProjB;
        public required Tensor Norm2G, Norm2B, Fc1W, Fc1B, Fc2W, Fc2B;
    }

    public required VitBlock[] Blocks;
    public required Tensor NormG, NormB;
    public required Tensor Proj0W, Proj0B, Proj2W, Proj2B, Proj4W, Proj4B; // 384→512→512→768

    /// <summary>
    /// Encodes a point cloud into transformer latents: farthest-point sampling picks
    /// <see cref="Groups"/> centres, each centre's <see cref="GroupSize"/> nearest
    /// points go through the mini-PointNet to form a token, and the tokens (plus a
    /// class token) run through the ViT blocks and the bridge MLP.
    /// </summary>
    /// <param name="points">Points in the bert branch's space (mesh normalized to
    /// [-1,1], then halved). Must contain at least one point.</param>
    /// <returns>Latents (513 × 768): the class token row followed by one row per group.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="points"/> is empty.</exception>
    public Tensor Encode( Vector3[] points )
    {
        ArgumentNullException.ThrowIfNull( points );
        if ( points.Length == 0 )
            throw new ArgumentException( "At least one point is required.", nameof( points ) );

        var centers = SampleCenters( points );

        // knn 32 + group tokens through the mini-PointNet. Every group writes only
        // its own TokenDim-wide slice of the shared buffer.
        var tokens = new float[Groups * TokenDim];
        Concurrency.For( 0, Groups, g =>
        {
            EncodeGroup( points, points[centers[g]], tokens.AsSpan( g * TokenDim, TokenDim ) );
        } );

        var reduced = Tensor.From( tokens, Groups, TokenDim )
            .MatMul( ReduceDimW.Transposed ).Add( ReduceDimB );

        // positions of centers through the pos MLP; cls token + cls pos.
        var centerXyz = new float[Groups * 3];
        for ( var g = 0; g < Groups; g++ )
        {
            centerXyz[g * 3] = points[centers[g]].X;
            centerXyz[g * 3 + 1] = points[centers[g]].Y;
            centerXyz[g * 3 + 2] = points[centers[g]].Z;
        }
        var pos = Tensor.From( centerXyz, Groups, 3 )
            .MatMul( PosEmbed0W.Transposed ).Add( PosEmbed0B );
        pos = TransformerOps.Gelu( pos ).MatMul( PosEmbed2W.Transposed ).Add( PosEmbed2B );

        var x = Tensor.From( ClsToken.Data, 1, Dim ).Concat( reduced, 0 );
        var fullPos = Tensor.From( ClsPos.Data, 1, Dim ).Concat( pos, 0 );

        foreach ( var block in Blocks )
        {
            var xin = x.Add( fullPos ); // positions re-added EVERY block
            var h = TransformerOps.LayerNorm( xin, block.Norm1G, block.Norm1B );
            var qkv = h.MatMul( block.QkvW.Transposed );
            // The reference views the fused projection as (N, 3, H, D) with the "3"
            // outermost, so q, k and v are the first, second and third Dim-wide
            // column ranges of each row.
            var q = qkv.Slice( 1, 0, Dim );
            var k = qkv.Slice( 1, Dim, Dim );
            var v = qkv.Slice( 1, 2 * Dim, Dim );
            var attended = TransformerOps.Attention( q, k, v, Heads, causal: false );
            // The residual is taken from xin (x + positions), not from the bare x.
            x = xin.Add( attended.MatMul( block.ProjW.Transposed ).Add( block.ProjB ) );
            h = TransformerOps.LayerNorm( x, block.Norm2G, block.Norm2B );
            h = TransformerOps.Gelu( h.MatMul( block.Fc1W.Transposed ).Add( block.Fc1B ) );
            x = x.Add( h.MatMul( block.Fc2W.Transposed ).Add( block.Fc2B ) );
        }

        x = TransformerOps.LayerNorm( x, NormG, NormB );
        var bridge = TransformerOps.Gelu( x.MatMul( Proj0W.Transposed ).Add( Proj0B ) );
        bridge = TransformerOps.Gelu( bridge.MatMul( Proj2W.Transposed ).Add( Proj2B ) );
        return bridge.MatMul( Proj4W.Transposed ).Add( Proj4B ); // (513, 768)
    }

    /// <summary>
    /// Farthest-point sampling of <see cref="Groups"/> centre indices. Starts
    /// deterministically at index 0 (the reference uses a random start). When there
    /// are fewer distinct points than groups, indices repeat.
    /// </summary>
    private static int[] SampleCenters( Vector3[] points )
    {
        var n = points.Length;
        var centers = new int[Groups];
        // distance[i] = squared distance from point i to its closest chosen centre.
        var distance = new float[n];
        Array.Fill( distance, float.MaxValue );
        var farthest = 0;
        for ( var g = 0; g < Groups; g++ )
        {
            centers[g] = farthest;
            var c = points[farthest];
            var best = 0f;
            var bestAt = 0;
            for ( var i = 0; i < n; i++ )
            {
                var d = Vector3.DistanceSquared( points[i], c );
                if ( d < distance[i] )
                    distance[i] = d;
                if ( distance[i] > best )
                {
                    best = distance[i];
                    bestAt = i;
                }
            }
            farthest = bestAt;
        }
        return centers;
    }

    /// <summary>
    /// Builds one group token: gathers the <see cref="GroupSize"/> nearest points to
    /// <paramref name="center"/>, runs them through the mini-PointNet and writes the
    /// max-pooled TokenDim-wide feature into <paramref name="token"/>.
    /// </summary>
    private void EncodeGroup( Vector3[] points, Vector3 center, Span<float> token )
    {
        var n = points.Length;
        Span<int> nearest = stackalloc int[GroupSize];
        Span<float> nearDistance = stackalloc float[GroupSize];
        var used = 0;
        // Slot of the farthest neighbour kept so far; -1 means "needs recomputing".
        // It only changes when a slot is overwritten, so it is rescanned lazily
        // instead of once per candidate point.
        var worst = -1;
        for ( var i = 0; i < n; i++ )
        {
            var d = Vector3.DistanceSquared( points[i], center );
            if ( used < GroupSize )
            {
                nearest[used] = i;
                nearDistance[used++] = d;
                continue;
            }
            if ( worst < 0 )
                worst = IndexOfMax( nearDistance );
            if ( d < nearDistance[worst] )
            {
                nearest[worst] = i;
                nearDistance[worst] = d;
                worst = -1;
            }
        }
        // Fewer points than slots: pad with a point that is already in the group so
        // the max-pools below are unaffected. Done explicitly rather than relying on
        // stackalloc memory happening to be zeroed.
        if ( used < GroupSize )
            nearest[used..].Fill( nearest[0] );

        // group features: [xyz − center, 0,0,0] per point → conv1d chain. The last
        // three input channels are zero, so only weight columns 0..2 contribute.
        Span<float> h128 = stackalloc float[GroupSize * 128];
        for ( var p = 0; p < GroupSize; p++ )
        {
            var rel = points[nearest[p]] - center;
            for ( var oc = 0; oc < 128; oc++ )
            {
                var sum = FirstConv0B.Data[oc]
                    + FirstConv0W.Data[oc * 6 + 0] * rel.X
                    + FirstConv0W.Data[oc * 6 + 1] * rel.Y
                    + FirstConv0W.Data[oc * 6 + 2] * rel.Z;
                // BN inference + ReLU
                sum = (sum - FirstBnMean.Data[oc])
                    / MathF.Sqrt( FirstBnVar.Data[oc] + BnEpsilon )
                    * FirstBnG.Data[oc] + FirstBnB.Data[oc];
                h128[p * 128 + oc] = MathF.Max( sum, 0f );
            }
        }

        // 128→256 per point, tracking the per-channel max over the group.
        Span<float> h256 = stackalloc float[GroupSize * 256];
        Span<float> global256 = stackalloc float[256];
        global256.Fill( float.MinValue );
        for ( var p = 0; p < GroupSize; p++ )
            for ( var oc = 0; oc < 256; oc++ )
            {
                var sum = FirstConv3B.Data[oc];
                for ( var ic = 0; ic < 128; ic++ )
                    sum += FirstConv3W.Data[oc * 128 + ic] * h128[p * 128 + ic];
                h256[p * 256 + oc] = sum;
                if ( sum > global256[oc] )
                    global256[oc] = sum;
            }

        // Second stage input is [group max (256) | per-point feature (256)] → 512 → 256,
        // max-pooled over the group into the final token.
        Span<float> out256 = stackalloc float[TokenDim];
        out256.Fill( float.MinValue );
        Span<float> h512 = stackalloc float[512];
        for ( var p = 0; p < GroupSize; p++ )
        {
            for ( var oc = 0; oc < 512; oc++ )
            {
                var sum = SecondConv0B.Data[oc];
                // Both halves are accumulated in one interleaved pass; the summation
                // order is kept as-is so results stay comparable with existing goldens.
                for ( var ic = 0; ic < 256; ic++ )
                    sum += SecondConv0W.Data[oc * 512 + ic] * global256[ic]
                        + SecondConv0W.Data[oc * 512 + 256 + ic] * h256[p * 256 + ic];
                sum = (sum - SecondBnMean.Data[oc])
                    / MathF.Sqrt( SecondBnVar.Data[oc] + BnEpsilon )
                    * SecondBnG.Data[oc] + SecondBnB.Data[oc];
                h512[oc] = MathF.Max( sum, 0f );
            }
            for ( var oc = 0; oc < TokenDim; oc++ )
            {
                var sum = SecondConv3B.Data[oc];
                for ( var ic = 0; ic < 512; ic++ )
                    sum += SecondConv3W.Data[oc * 512 + ic] * h512[ic];
                if ( sum > out256[oc] )
                    out256[oc] = sum;
            }
        }
        out256.CopyTo( token );
    }

    /// <summary>Index of the first maximal element of <paramref name="values"/>.</summary>
    private static int IndexOfMax( ReadOnlySpan<float> values )
    {
        var at = 0;
        for ( var k = 1; k < values.Length; k++ )
            if ( values[k] > values[at] )
                at = k;
        return at;
    }
}
/// <summary>Embedding helpers shared by the Anymate stages: a π-scaled Fourier
/// embedding of a 3-vector and a 25-component real spherical-harmonic basis.</summary>
public static class AnymateEmbed
{
    /// <summary>
    /// Fourier embedding with π included in the frequencies. Output layout (51 floats):
    /// [0..2] the raw x, y, z; [3..26] sin(c·2^k·π) and [27..50] cos(c·2^k·π), both
    /// ordered axis-major (x, y, z) with k = 0..7 inside each axis.
    /// </summary>
    /// <param name="p">Vector to embed.</param>
    /// <returns>A freshly allocated 51-element array.</returns>
    public static float[] FourierPi( Vector3 p )
    {
        const int octaves = 8;
        const int sineBase = 3;
        const int cosineBase = sineBase + 3 * octaves;

        var embedding = new float[cosineBase + 3 * octaves];
        embedding[0] = p.X;
        embedding[1] = p.Y;
        embedding[2] = p.Z;

        // One pass over the 24 (axis, octave) pairs; slot == axis * 8 + octave.
        for ( var slot = 0; slot < 3 * octaves; slot++ )
        {
            var axis = slot / octaves;
            var octave = slot % octaves;
            var coordinate = axis switch
            {
                0 => p.X,
                1 => p.Y,
                _ => p.Z,
            };
            var angle = coordinate * (1 << octave) * MathF.PI;
            embedding[sineBase + slot] = MathF.Sin( angle );
            embedding[cosineBase + slot] = MathF.Cos( angle );
        }

        return embedding;
    }

    /// <summary>
    /// Real spherical harmonics, 5 levels (l = 0..4) = 25 components, evaluated as
    /// polynomials in the components of <paramref name="d"/>. The vector is used as
    /// given; no normalisation is applied here.
    /// </summary>
    /// <param name="d">Direction to evaluate the basis at.</param>
    /// <returns>A freshly allocated 25-element array, bands stored consecutively.</returns>
    public static float[] Sh25( Vector3 d )
    {
        float x = d.X, y = d.Y, z = d.Z;
        float x2 = x * x, y2 = y * y, z2 = z * z;

        var sh = new float[25];

        // l = 0
        sh[0] = 0.28209479f;

        // l = 1
        sh[1] = 0.48860252f * y;
        sh[2] = 0.48860252f * z;
        sh[3] = 0.48860252f * x;

        // l = 2
        sh[4] = 1.09254843f * x * y;
        sh[5] = 1.09254843f * y * z;
        sh[6] = 0.94617470f * z2 - 0.31539157f;
        sh[7] = 1.09254843f * x * z;
        sh[8] = 0.54627422f * (x2 - y2);

        // l = 3
        sh[9] = 0.59004359f * y * (3 * x2 - y2);
        sh[10] = 2.89061144f * x * y * z;
        sh[11] = 0.45704580f * y * (5 * z2 - 1);
        sh[12] = 0.37317633f * z * (5 * z2 - 3);
        sh[13] = 0.45704580f * x * (5 * z2 - 1);
        sh[14] = 1.44530572f * z * (x2 - y2);
        sh[15] = 0.59004359f * x * (x2 - 3 * y2);

        // l = 4
        sh[16] = 2.50334294f * x * y * (x2 - y2);
        sh[17] = 1.77013077f * y * z * (3 * x2 - y2);
        sh[18] = 0.94617470f * x * y * (7 * z2 - 1);
        sh[19] = 0.66904654f * y * z * (7 * z2 - 3);
        sh[20] = 0.10578555f * (35 * z2 * z2 - 30 * z2 + 3);
        sh[21] = 0.66904654f * x * z * (7 * z2 - 3);
        sh[22] = 0.47308735f * (x2 - y2) * (7 * z2 - 1);
        sh[23] = 1.77013077f * x * z * (x2 - 3 * y2);
        sh[24] = 0.62583574f * (x2 * (x2 - 3 * y2) - y2 * (3 * x2 - y2));

        return sh;
    }
}