Code/AutoRig/Dl/Anymate/AnymateModel.cs

Neural rigging model code for the Anymate PointBERT encoder and its embedding helpers. PointBert.Encode runs farthest-point sampling, k-NN grouping, and a small PointNet per group, then passes the resulting group tokens through transformer (ViT) blocks and a projection MLP to produce a (513×768) latent tensor. AnymateEmbed provides a Fourier positional encoding and 25-component real spherical harmonics.

Native Interop
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;

namespace AutoRig.Dl.Anymate;

using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Anymate (Apache-2.0): three-stage neural rigging — joints, connectivity,
/// skin — each a PointBERT encoder (512 fps groups × knn 32, mini-PointNet
/// with BatchNorm running stats, 12 ViT blocks with positions re-added every
/// block) bridged 384→768, plus small Michelangelo-style decoders. Executable
/// spec: dev/anymate-reference/gen_golden_anymate.py.
/// </summary>
/// <remarks>
/// This class holds the weights of one encoder and implements its forward
/// pass in <see cref="Encode"/>. All weight members are <c>required</c> and
/// must be supplied through an object initializer.
/// </remarks>
public sealed class PointBert
{
    /// <summary>Number of point groups (tokens besides the cls token).</summary>
    public const int Groups = 512;
    /// <summary>Number of nearest neighbours gathered around each group center.</summary>
    public const int GroupSize = 32;
    /// <summary>Transformer width.</summary>
    public const int Dim = 384;
    /// <summary>Attention heads per transformer block.</summary>
    public const int Heads = 6;

    // Epsilon added to the BatchNorm running variance before the square root.
    private const float BnEpsilon = 1e-5f;

    // Mini-PointNet, stage 1 (per point): conv → BN → ReLU → conv.
    public required Tensor FirstConv0W, FirstConv0B;   // 6→128 (1×1)
    public required Tensor FirstBnMean, FirstBnVar, FirstBnG, FirstBnB;
    public required Tensor FirstConv3W, FirstConv3B;   // 128→256
    // Mini-PointNet, stage 2 (on [group max ‖ per-point] features): conv → BN → ReLU → conv.
    public required Tensor SecondConv0W, SecondConv0B; // 512→512
    public required Tensor SecondBnMean, SecondBnVar, SecondBnG, SecondBnB;
    public required Tensor SecondConv3W, SecondConv3B; // 512→256
    public required Tensor ReduceDimW, ReduceDimB;     // 256→384
    public required Tensor ClsToken, ClsPos;           // [384]
    // Two-layer MLP applied to the group-center coordinates.
    public required Tensor PosEmbed0W, PosEmbed0B, PosEmbed2W, PosEmbed2B;

    /// <summary>Weights of one pre-norm transformer block (attention + MLP).</summary>
    public sealed class VitBlock
    {
        public required Tensor Norm1G, Norm1B, QkvW, ProjW, ProjB;
        public required Tensor Norm2G, Norm2B, Fc1W, Fc1B, Fc2W, Fc2B;
    }
    public required VitBlock[] Blocks;
    public required Tensor NormG, NormB;
    public required Tensor Proj0W, Proj0B, Proj2W, Proj2B, Proj4W, Proj4B;   // 384→512→512→768

    /// <summary>Points in the bert branch's space (mesh normalized to [-1,1],
    /// then halved). Returns latents (513 × 768).</summary>
    /// <param name="points">Input point cloud; must contain at least one point.
    /// Clouds smaller than <see cref="Groups"/> or <see cref="GroupSize"/> are
    /// accepted: sampling and grouping then repeat existing points.</param>
    /// <returns>One row for the cls token followed by one row per group,
    /// after the transformer blocks, final norm and projection MLP.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="points"/> is empty.</exception>
    public Tensor Encode( Vector3[] points )
    {
        ArgumentNullException.ThrowIfNull( points );
        if ( points.Length == 0 )
            throw new ArgumentException( "At least one point is required.", nameof( points ) );

        var centers = SampleCenters( points );

        // The BatchNorm standard deviations depend only on the channel, so
        // compute them once instead of per point per group.
        var firstBnStd = BatchNormStd( FirstBnVar, 128 );
        var secondBnStd = BatchNormStd( SecondBnVar, 512 );

        // knn 32 + group tokens through the mini-PointNet. Each group writes
        // a disjoint 256-float slice of `tokens`.
        var tokens = new float[Groups * 256];
        Concurrency.For( 0, Groups, g =>
        {
            EncodeGroup( points, centers[g], firstBnStd, secondBnStd, tokens, g * 256 );
        } );

        var reduced = Tensor.From( tokens, Groups, 256 )
            .MatMul( ReduceDimW.Transposed ).Add( ReduceDimB );

        // positions of centers through the pos MLP; cls token + cls pos.
        var centerXyz = new float[Groups * 3];
        for ( var g = 0; g < Groups; g++ )
        {
            var center = points[centers[g]];
            centerXyz[g * 3] = center.X;
            centerXyz[g * 3 + 1] = center.Y;
            centerXyz[g * 3 + 2] = center.Z;
        }
        var pos = Tensor.From( centerXyz, Groups, 3 )
            .MatMul( PosEmbed0W.Transposed ).Add( PosEmbed0B );
        pos = TransformerOps.Gelu( pos ).MatMul( PosEmbed2W.Transposed ).Add( PosEmbed2B );

        // Row 0 is the cls token / cls position; rows 1..512 are the groups.
        var x = Tensor.From( ClsToken.Data, 1, Dim ).Concat( reduced, 0 );
        var fullPos = Tensor.From( ClsPos.Data, 1, Dim ).Concat( pos, 0 );

        foreach ( var block in Blocks )
        {
            // Positions are re-added EVERY block, and the residual branches
            // start from the position-augmented input.
            var xin = x.Add( fullPos );
            var h = TransformerOps.LayerNorm( xin, block.Norm1G, block.Norm1B );
            var qkv = h.MatMul( block.QkvW.Transposed );
            // The reference reshapes qkv as view(B,N,3,H,D): the factor 3 is
            // outermost, so q, k and v are the first, second and last
            // contiguous third of the columns.
            var q = qkv.Slice( 1, 0, Dim );
            var k = qkv.Slice( 1, Dim, Dim );
            var v = qkv.Slice( 1, 2 * Dim, Dim );
            var attended = TransformerOps.Attention( q, k, v, Heads, causal: false );
            x = xin.Add( attended.MatMul( block.ProjW.Transposed ).Add( block.ProjB ) );
            h = TransformerOps.LayerNorm( x, block.Norm2G, block.Norm2B );
            h = TransformerOps.Gelu( h.MatMul( block.Fc1W.Transposed ).Add( block.Fc1B ) );
            x = x.Add( h.MatMul( block.Fc2W.Transposed ).Add( block.Fc2B ) );
        }
        x = TransformerOps.LayerNorm( x, NormG, NormB );

        var bridge = TransformerOps.Gelu( x.MatMul( Proj0W.Transposed ).Add( Proj0B ) );
        bridge = TransformerOps.Gelu( bridge.MatMul( Proj2W.Transposed ).Add( Proj2B ) );
        return bridge.MatMul( Proj4W.Transposed ).Add( Proj4B );   // (513, 768)
    }

    /// <summary>Farthest-point sampling of <see cref="Groups"/> center indices
    /// (deterministic start 0 — reference uses a random start).</summary>
    private static int[] SampleCenters( Vector3[] points )
    {
        var centers = new int[Groups];
        // distance[i] = squared distance from point i to its closest chosen center.
        var distance = new float[points.Length];
        Array.Fill( distance, float.MaxValue );
        var farthest = 0;
        for ( var g = 0; g < Groups; g++ )
        {
            centers[g] = farthest;
            var c = points[farthest];
            var best = 0f;
            var bestAt = 0;
            for ( var i = 0; i < points.Length; i++ )
            {
                var d = Vector3.DistanceSquared( points[i], c );
                if ( d < distance[i] )
                    distance[i] = d;
                if ( distance[i] > best )
                {
                    best = distance[i];
                    bestAt = i;
                }
            }
            // Once every point is already a center all distances are 0 and
            // this falls back to index 0.
            farthest = bestAt;
        }
        return centers;
    }

    /// <summary>Per-channel sqrt(variance + epsilon) for BatchNorm inference.</summary>
    private static float[] BatchNormStd( Tensor variance, int channels )
    {
        var std = new float[channels];
        for ( var c = 0; c < channels; c++ )
            std[c] = MathF.Sqrt( variance.Data[c] + BnEpsilon );
        return std;
    }

    /// <summary>Index of the largest value; the first one wins on ties.</summary>
    private static int IndexOfMax( ReadOnlySpan<float> values )
    {
        var at = 0;
        for ( var k = 1; k < values.Length; k++ )
            if ( values[k] > values[at] )
                at = k;
        return at;
    }

    /// <summary>Fills <paramref name="nearest"/> with the indices of the
    /// <see cref="GroupSize"/> points closest to <paramref name="center"/>
    /// (unordered).</summary>
    private static void SelectNearest( Vector3[] points, Vector3 center, Span<int> nearest )
    {
        // With fewer than GroupSize points the unused slots stay at index 0,
        // i.e. they repeat a point that is already part of the group.
        nearest.Clear();
        Span<float> nearDistance = stackalloc float[GroupSize];
        var used = 0;
        var worst = 0;
        for ( var i = 0; i < points.Length; i++ )
        {
            var d = Vector3.DistanceSquared( points[i], center );
            if ( used < GroupSize )
            {
                nearest[used] = i;
                nearDistance[used++] = d;
                if ( used == GroupSize )
                    worst = IndexOfMax( nearDistance );
            }
            else if ( d < nearDistance[worst] )
            {
                // Replace the current worst neighbour; only now can the
                // worst slot change, so rescan here rather than per point.
                nearest[worst] = i;
                nearDistance[worst] = d;
                worst = IndexOfMax( nearDistance );
            }
        }
    }

    /// <summary>Builds one group token: gathers the neighbours of the center,
    /// runs the mini-PointNet over them and writes the 256 max-pooled
    /// features to <paramref name="tokens"/> at <paramref name="tokenOffset"/>.
    /// Uses roughly 52 KiB of stack scratch.</summary>
    private void EncodeGroup( Vector3[] points, int centerIndex, float[] firstBnStd, float[] secondBnStd, float[] tokens, int tokenOffset )
    {
        var center = points[centerIndex];
        Span<int> nearest = stackalloc int[GroupSize];
        SelectNearest( points, center, nearest );

        // group features: [xyz − center, 0,0,0] per point → conv1d chain.
        // Only the first three of the six input channels are non-zero.
        Span<float> h128 = stackalloc float[GroupSize * 128];
        for ( var p = 0; p < GroupSize; p++ )
        {
            var rel = points[nearest[p]] - center;
            for ( var oc = 0; oc < 128; oc++ )
            {
                var sum = FirstConv0B.Data[oc]
                    + FirstConv0W.Data[oc * 6 + 0] * rel.X
                    + FirstConv0W.Data[oc * 6 + 1] * rel.Y
                    + FirstConv0W.Data[oc * 6 + 2] * rel.Z;
                // BN inference + ReLU
                sum = (sum - FirstBnMean.Data[oc])
                    / firstBnStd[oc]
                    * FirstBnG.Data[oc] + FirstBnB.Data[oc];
                h128[p * 128 + oc] = MathF.Max( sum, 0f );
            }
        }

        // 128→256 per point, tracking the per-channel max over the group.
        Span<float> h256 = stackalloc float[GroupSize * 256];
        Span<float> global256 = stackalloc float[256];
        global256.Fill( float.MinValue );
        for ( var p = 0; p < GroupSize; p++ )
            for ( var oc = 0; oc < 256; oc++ )
            {
                var sum = FirstConv3B.Data[oc];
                for ( var ic = 0; ic < 128; ic++ )
                    sum += FirstConv3W.Data[oc * 128 + ic] * h128[p * 128 + ic];
                h256[p * 256 + oc] = sum;
                if ( sum > global256[oc] )
                    global256[oc] = sum;
            }

        // Second stage input per point is [group max (256) ‖ point feature (256)].
        Span<float> out256 = stackalloc float[256];
        out256.Fill( float.MinValue );
        Span<float> h512 = stackalloc float[512];
        for ( var p = 0; p < GroupSize; p++ )
        {
            for ( var oc = 0; oc < 512; oc++ )
            {
                var sum = SecondConv0B.Data[oc];
                for ( var ic = 0; ic < 256; ic++ )
                    sum += SecondConv0W.Data[oc * 512 + ic] * global256[ic]
                         + SecondConv0W.Data[oc * 512 + 256 + ic] * h256[p * 256 + ic];
                // BN inference + ReLU
                sum = (sum - SecondBnMean.Data[oc])
                    / secondBnStd[oc]
                    * SecondBnG.Data[oc] + SecondBnB.Data[oc];
                h512[oc] = MathF.Max( sum, 0f );
            }
            // 512→256, max-pooled over the group's points.
            for ( var oc = 0; oc < 256; oc++ )
            {
                var sum = SecondConv3B.Data[oc];
                for ( var ic = 0; ic < 512; ic++ )
                    sum += SecondConv3W.Data[oc * 512 + ic] * h512[ic];
                if ( sum > out256[oc] )
                    out256[oc] = sum;
            }
        }
        for ( var oc = 0; oc < 256; oc++ )
            tokens[tokenOffset + oc] = out256[oc];
    }
}

/// <summary>Embedding helpers shared by the Anymate stages: a Fourier
/// encoding with π-scaled frequencies and a 25-term real spherical-harmonic
/// basis (used for normals).</summary>
public static class AnymateEmbed
{
    /// <summary>Fourier features of a 3-vector, 51 values in total.</summary>
    /// <param name="p">Vector to encode.</param>
    /// <returns>
    /// Indices 0..2 hold the raw coordinates, 3..26 the sines and 27..50 the
    /// cosines of coordinate · 2^f · π for f = 0..7. Within each trig half
    /// the layout is axis-major: eight frequencies for X, then Y, then Z.
    /// </returns>
    public static float[] FourierPi( Vector3 p )
    {
        const int frequencies = 8;
        const int sinBase = 3;
        const int cosBase = sinBase + 3 * frequencies;

        var embedding = new float[cosBase + 3 * frequencies];
        embedding[0] = p.X;
        embedding[1] = p.Y;
        embedding[2] = p.Z;

        for ( var axis = 0; axis < 3; axis++ )
        {
            var coordinate = axis switch
            {
                0 => p.X,
                1 => p.Y,
                _ => p.Z,
            };
            var slot = axis * frequencies;
            for ( var octave = 0; octave < frequencies; octave++, slot++ )
            {
                var angle = coordinate * (1 << octave) * MathF.PI;
                embedding[sinBase + slot] = MathF.Sin( angle );
                embedding[cosBase + slot] = MathF.Cos( angle );
            }
        }
        return embedding;
    }

    /// <summary>Real spherical harmonics, bands 0..4 = 25 components.</summary>
    /// <param name="d">Direction to evaluate. It is used as given (not
    /// normalized here).</param>
    /// <returns>25 values ordered band by band: 1, 3, 5, 7 and 9 terms.</returns>
    public static float[] Sh25( Vector3 d )
    {
        // NOTE(review): the polynomials carry constant terms (e.g. "5·z² − 1")
        // that presumably stand in for r² = 1, i.e. they assume a unit-length
        // direction — confirm callers pass normalized vectors.
        var x = d.X;
        var y = d.Y;
        var z = d.Z;
        var xx = x * x;
        var yy = y * y;
        var zz = z * z;

        // Polynomial factors shared between several components.
        var xxMinusYy = xx - yy;
        var threeXxMinusYy = 3 * xx - yy;
        var xxMinusThreeYy = xx - 3 * yy;
        var fiveZzMinus1 = 5 * zz - 1;
        var sevenZzMinus1 = 7 * zz - 1;
        var sevenZzMinus3 = 7 * zz - 3;

        var sh = new float[25];

        // band 0
        sh[0] = 0.28209479f;

        // band 1
        sh[1] = 0.48860252f * y;
        sh[2] = 0.48860252f * z;
        sh[3] = 0.48860252f * x;

        // band 2
        sh[4] = 1.09254843f * x * y;
        sh[5] = 1.09254843f * y * z;
        sh[6] = 0.94617470f * zz - 0.31539157f;
        sh[7] = 1.09254843f * x * z;
        sh[8] = 0.54627422f * xxMinusYy;

        // band 3
        sh[9] = 0.59004359f * y * threeXxMinusYy;
        sh[10] = 2.89061144f * x * y * z;
        sh[11] = 0.45704580f * y * fiveZzMinus1;
        sh[12] = 0.37317633f * z * (5 * zz - 3);
        sh[13] = 0.45704580f * x * fiveZzMinus1;
        sh[14] = 1.44530572f * z * xxMinusYy;
        sh[15] = 0.59004359f * x * xxMinusThreeYy;

        // band 4
        sh[16] = 2.50334294f * x * y * xxMinusYy;
        sh[17] = 1.77013077f * y * z * threeXxMinusYy;
        sh[18] = 0.94617470f * x * y * sevenZzMinus1;
        sh[19] = 0.66904654f * y * z * sevenZzMinus3;
        sh[20] = 0.10578555f * (35 * zz * zz - 30 * zz + 3);
        sh[21] = 0.66904654f * x * z * sevenZzMinus3;
        sh[22] = 0.47308735f * xxMinusYy * sevenZzMinus1;
        sh[23] = 1.77013077f * x * z * xxMinusThreeYy;
        sh[24] = 0.62583574f * (xx * xxMinusThreeYy - yy * threeXxMinusYy);

        return sh;
    }
}