Code/AutoRig/Dl/Puppeteer/PuppeteerSkinModel.cs

Neural-network-based skinning transformer for a Puppeteer rig. It encodes per-point part features and shape features, computes point and joint embeddings, runs multi-head attention blocks (including joint self-attention with a TAJA graph-distance bias), applies GEGLU feed-forward layers, and outputs per-point weights softmaxed over joints.

Native Interop
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;

namespace AutoRig.Dl.Puppeteer;

using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Puppeteer's skinning transformer (skin checkpoint, blocks.0.* etc.):
/// PartField point features + sinusoidal point/joint embeddings through one
/// SkinningBlock — joint self-attention with the TAJA graph-distance bias,
/// four cross attentions against shape/point/joint streams, GEGLU FFNs — then
/// L2-normalized cosine similarity × |scale| softmaxed over joints.
/// Executable spec: dev/puppeteer-reference/gen_golden_skinhead.py.
/// </summary>
public sealed class PuppeteerSkinModel
{
    /// <summary>Row width assumed when reading the final point and joint streams.</summary>
    public const int D = 768;

    /// <summary>Number of attention heads used by every attention unit.</summary>
    public const int Heads = 8;

    /// <summary>
    /// Length of the stack-allocated softmax scratch row. Rows longer than
    /// this (more joints than the capacity) use a heap array instead.
    /// </summary>
    const int StackRowCapacity = 128;

    public required PartFieldModel PartField;
    public required PerceiverEncoder ShapeEncoder;         // Michelangelo (skin ckpt copy)
    public required Tensor ShapeQuery;                     // encoder.query [257,768]
    public required Tensor PreKlWeight, PreKlBias, PostKlWeight, PostKlBias;
    public required PerceiverSelfBlock[] ShapeDecodeBlocks;

    public required Tensor Proj0W, Proj0B, Proj2W, Proj2B; // 448→768 SiLU 768→768
    public required Tensor PointPeBasis, PointPeW, PointPeB;
    public required Tensor SkelBasis, SkelW, SkelB;
    public required Tensor Scale;                          // scalar

    /// <summary>
    /// Weights of one attention + feed-forward unit: pre-norm, optional
    /// context norm, query / key-value / output projections, and the GEGLU
    /// feed-forward (its own norm plus input and output linears).
    /// </summary>
    public sealed class AttentionUnit
    {
        public required Tensor NormG, NormB;
        public Tensor ContextNormG, ContextNormB;          // cross only
        public required Tensor ToQ, ToKv, ToOutW, ToOutB;
        public required Tensor FfNormG, FfNormB, FfInW, FfInB, FfOutW, FfOutB;
    }

    public required AttentionUnit SelfJ, CrossPs, CrossJs, CrossJp, CrossPj;
    public required Tensor RelPosEmbedding, RelPosProjW, RelPosProjB, RelPosScale;

    /// <summary>Optional progress callback; invoked with a short status message.</summary>
    public Action<string> Progress { get; set; }

    /// <summary>PointEmbed: sin/cos of block-diagonal frequency projections
    /// (basis [6, 24]) concatenated with the raw 6-dim input → linear.</summary>
    /// <param name="input6">Row-major input, 6 floats per row.</param>
    /// <param name="rows">Number of rows to embed.</param>
    /// <param name="basis">Frequency basis; read as 6 rows of <c>basis.Shape[1]</c> columns.</param>
    /// <param name="w">Linear weight, applied transposed.</param>
    /// <param name="b">Linear bias.</param>
    /// <returns>The embedded rows after the linear layer.</returns>
    static Tensor PointEmbed( float[] input6, int rows, Tensor basis, Tensor w, Tensor b )
    {
        var e = basis.Shape[1];
        // Per-row layout: [sin(proj) × e | cos(proj) × e | raw 6 inputs].
        var stride = 2 * e + 6;
        var embedded = new float[rows * stride];
        for ( var r = 0; r < rows; r++ )
        {
            var inBase = r * 6;
            var outBase = r * stride;
            for ( var col = 0; col < e; col++ )
            {
                float projection = 0;
                for ( var d = 0; d < 6; d++ )
                    projection += input6[inBase + d] * basis.Data[d * e + col];
                embedded[outBase + col] = MathF.Sin( projection );
                embedded[outBase + e + col] = MathF.Cos( projection );
            }
            for ( var d = 0; d < 6; d++ )
                embedded[outBase + 2 * e + d] = input6[inBase + d];
        }
        return Tensor.From( embedded, rows, stride )
            .MatMul( w.Transposed ).Add( b );
    }

    /// <summary>
    /// Pre-norm multi-head attention of <paramref name="x"/> against
    /// <paramref name="context"/>, followed by the output projection. The
    /// residual connection is added by the caller.
    /// </summary>
    /// <param name="unit">Weights for this attention unit.</param>
    /// <param name="x">Query stream.</param>
    /// <param name="context">Key/value stream; <c>null</c> means attend to <paramref name="x"/> itself.</param>
    /// <param name="relPos">Optional additive per-head bias; when <c>null</c> plain attention is used.</param>
    static Tensor Attend( AttentionUnit unit, Tensor x, Tensor context, Tensor relPos )
    {
        // PreNorm: query stream through `norm`; the context (which for the
        // joint SELF attention is the raw joint stream again) through its own
        // `norm_context` — two different LayerNorms in the reference.
        var h = TransformerOps.LayerNorm( x, unit.NormG, unit.NormB );
        context ??= x;
        // Without context-norm weights the keys/values come from the already
        // normalized query stream and `context` is not used at all.
        // NOTE(review): that also applies when a distinct context was passed
        // in — confirm every cross unit is loaded with ContextNormG/B.
        var ctx = unit.ContextNormG is null
            ? h
            : TransformerOps.LayerNorm( context, unit.ContextNormG, unit.ContextNormB );
        var q = h.MatMul( unit.ToQ.Transposed );
        var kv = ctx.MatMul( unit.ToKv.Transposed );
        // ToKv yields keys and values side by side; each half is as wide as q.
        var inner = q.Shape[1];
        var k = kv.Slice( 1, 0, inner );
        var v = kv.Slice( 1, inner, inner );
        var attended = relPos is null
            ? TransformerOps.Attention( q, k, v, Heads, causal: false )
            : AttentionWithBias( q, k, v, relPos );
        return attended.MatMul( unit.ToOutW.Transposed ).Add( unit.ToOutB );
    }

    /// <summary>MHA with an additive per-head bias (TAJA): bias [Nq, Nk, heads].</summary>
    /// <param name="q">Queries, one row per query, heads laid out contiguously along the columns.</param>
    /// <param name="k">Keys, same column layout as <paramref name="q"/>.</param>
    /// <param name="v">Values, same column layout as <paramref name="q"/>.</param>
    /// <param name="bias">Bias read as <c>bias.Data[(i * Nk + j) * Heads + head]</c>.</param>
    /// <returns>Attention output with one row per query.</returns>
    static Tensor AttentionWithBias( Tensor q, Tensor k, Tensor v, Tensor bias )
    {
        var nq = q.Shape[0];
        var nk = k.Shape[0];
        var headDim = q.Shape[1] / Heads;
        var width = Heads * headDim;
        var scale = 1f / MathF.Sqrt( headDim );
        var result = new float[nq * width];

        // One scratch row reused for every (head, query) pair. stackalloc
        // inside the loops would keep growing the stack until the method
        // returns, so it is allocated once here; rows that do not fit the
        // fixed stack buffer fall back to the heap.
        Span<float> scratch = nk <= StackRowCapacity
            ? stackalloc float[StackRowCapacity]
            : new float[nk];
        var row = scratch[..nk];

        for ( var head = 0; head < Heads; head++ )
        {
            var hb = head * headDim;
            for ( var i = 0; i < nq; i++ )
            {
                // Scaled dot-product scores plus bias; track the max for a
                // numerically stable softmax.
                var max = float.MinValue;
                for ( var j = 0; j < nk; j++ )
                {
                    float dot = 0;
                    for ( var d = 0; d < headDim; d++ )
                        dot += q.Data[i * width + hb + d]
                             * k.Data[j * width + hb + d];
                    row[j] = dot * scale + bias.Data[(i * nk + j) * Heads + head];
                    max = MathF.Max( max, row[j] );
                }
                float total = 0;
                for ( var j = 0; j < nk; j++ )
                {
                    row[j] = MathF.Exp( row[j] - max );
                    total += row[j];
                }
                // Softmax-weighted sum of the value rows for this head.
                for ( var j = 0; j < nk; j++ )
                {
                    var weight = row[j] / total;
                    for ( var d = 0; d < headDim; d++ )
                        result[i * width + hb + d] +=
                            weight * v.Data[j * width + hb + d];
                }
            }
        }
        return Tensor.From( result, nq, width );
    }

    /// <summary>
    /// GEGLU feed-forward: LayerNorm → input linear → split into value and
    /// gate halves → value × GELU(gate) → output linear. The residual
    /// connection is added by the caller.
    /// </summary>
    /// <param name="unit">Unit supplying the feed-forward weights.</param>
    /// <param name="x">Input stream.</param>
    static Tensor Geglu( AttentionUnit unit, Tensor x )
    {
        var h = TransformerOps.LayerNorm( x, unit.FfNormG, unit.FfNormB )
            .MatMul( unit.FfInW.Transposed ).Add( unit.FfInB );
        var rows = h.Shape[0];
        var width = h.Shape[1];
        var half = width / 2;
        var gated = new float[rows * half];
        for ( var r = 0; r < rows; r++ )
            for ( var col = 0; col < half; col++ )
            {
                // First half of each row is the value, second half the gate.
                var value = h.Data[r * width + col];
                var gate = h.Data[r * width + half + col];
                // erf-form GELU: 0.5 · g · (1 + erf(g / √2)); 0.70710678 ≈ 1/√2.
                var gelu = gate * 0.5f * (1f + TransformerOps.Erf( gate * 0.70710678f ));
                gated[r * half + col] = value * gelu;
            }
        return Tensor.From( gated, rows, half )
            .MatMul( unit.FfOutW.Transposed ).Add( unit.FfOutB );
    }

    /// <summary>Michelangelo shape conditioning: encode_latents + KL mean +
    /// decode stack, cat[shape embed, decoded] — process_point_feature WITHOUT
    /// cond projections (stays 768).</summary>
    /// <param name="points">Surface sample positions passed to the shape encoder.</param>
    /// <param name="normals">Normals passed alongside <paramref name="points"/>.</param>
    /// <returns>The shape-embed row followed by the decoded latent rows.</returns>
    public Tensor ShapeFeature( Vector3[] points, Vector3[] normals )
    {
        var data = ShapeEncoder.Project( points, normals );
        var x = ShapeQuery;
        x = ShapeEncoder.Cross.Forward( x, data, ShapeEncoder.Heads );
        foreach ( var block in ShapeEncoder.SelfBlocks )
            x = block.Forward( x, ShapeEncoder.Heads );
        x = TransformerOps.LayerNorm( x, ShapeEncoder.LnPostGamma, ShapeEncoder.LnPostBeta );

        // Row 0 is the shape embedding; the remaining rows are the latents.
        var shapeEmbed = x.Slice( 0, 0, 1 );
        var latents = x.Slice( 0, 1, x.Shape[0] - 1 );
        var moments = latents.MatMul( PreKlWeight.Transposed ).Add( PreKlBias );
        // Keep only the first half of the moment columns (the KL mean).
        var zq = moments.Slice( 1, 0, moments.Shape[1] / 2 );
        var decoded = zq.MatMul( PostKlWeight.Transposed ).Add( PostKlBias );
        foreach ( var block in ShapeDecodeBlocks )
            decoded = block.Forward( decoded, ShapeEncoder.Heads );
        return shapeEmbed.Concat( decoded, 0 );   // (257, 768)
    }

    /// <summary>Per-sample-point weights over joints. points6/joints6 rows =
    /// [x,y,z,nx,ny,nz] / [parentPos, jointPos]; graphDist = hop distances.</summary>
    /// <param name="points6">Row-major point input, 6 floats per point.</param>
    /// <param name="pointCount">Number of sample points.</param>
    /// <param name="partFeatures448">Row-major part features, 448 floats per point.</param>
    /// <param name="joints6">Row-major joint input, 6 floats per joint.</param>
    /// <param name="jointCount">Number of joints.</param>
    /// <param name="graphDist">Joint-to-joint distances, indexed [i, j]; clamped to 0..9 before lookup.</param>
    /// <param name="shapeFeature">Shape conditioning used as context by the shape cross attentions.</param>
    /// <returns>
    /// Row-major array of <paramref name="pointCount"/> × <paramref name="jointCount"/>
    /// weights; each point's row is a softmax over the joints.
    /// </returns>
    public float[] Weights(
        float[] points6, int pointCount, float[] partFeatures448,
        float[] joints6, int jointCount, int[,] graphDist, Tensor shapeFeature )
    {
        // point stream: proj(partfield) + PointEmbed
        var pf = Tensor.From( partFeatures448, pointCount, 448 );
        var h = pf.MatMul( Proj0W.Transposed ).Add( Proj0B );
        h = TransformerOps.Silu( h );
        var point = h.MatMul( Proj2W.Transposed ).Add( Proj2B )
            .Add( PointEmbed( points6, pointCount, PointPeBasis, PointPeW, PointPeB ) );

        var joint = PointEmbed( joints6, jointCount, SkelBasis, SkelW, SkelB );

        // TAJA bias: embedding[clamp(dist,0,9)] → proj → × scale, [J,J,heads]
        var e = RelPosEmbedding.Shape[1];
        var bias = new float[jointCount * jointCount * Heads];
        for ( var i = 0; i < jointCount; i++ )
            for ( var j = 0; j < jointCount; j++ )
            {
                var d = Math.Clamp( graphDist[i, j], 0, 9 );
                for ( var head = 0; head < Heads; head++ )
                {
                    float value = RelPosProjB.Data[head];
                    for ( var col = 0; col < e; col++ )
                        value += RelPosProjW.Data[head * e + col]
                            * RelPosEmbedding.Data[d * e + col];
                    bias[(i * jointCount + j) * Heads + head] = value * RelPosScale.Data[0];
                }
            }
        var relPos = Tensor.From( bias, jointCount * jointCount, Heads );

        // Each step is residual attention followed by a residual GEGLU FFN.
        Progress?.Invoke( "skin head: attention block" );
        var je = joint.Add( Attend( SelfJ, joint, null, relPos ) );       // joints ↔ joints (+TAJA)
        je = je.Add( Geglu( SelfJ, je ) );
        var pc = point.Add( Attend( CrossPs, point, shapeFeature, null ) ); // points → shape
        pc = pc.Add( Geglu( CrossPs, pc ) );
        var jc = je.Add( Attend( CrossJs, je, shapeFeature, null ) );     // joints → shape
        jc = jc.Add( Geglu( CrossJs, jc ) );
        var jr = jc.Add( Attend( CrossJp, jc, pc, null ) );               // joints → points
        jr = jr.Add( Geglu( CrossJp, jr ) );
        var pfin = pc.Add( Attend( CrossPj, pc, jr, null ) );             // points → joints
        pfin = pfin.Add( Geglu( CrossPj, pfin ) );

        // cosine similarity × |scale| → softmax over joints
        var weights = new float[pointCount * jointCount];
        var scale = MathF.Abs( Scale.Data[0] ) + 1e-9f;
        // Joint L2 norms are shared by every point, so compute them once.
        var jointNorms = new float[jointCount];
        for ( var j = 0; j < jointCount; j++ )
        {
            float sum = 0;
            for ( var d = 0; d < D; d++ )
            {
                var value = jr.Data[j * D + d];
                sum += value * value;
            }
            jointNorms[j] = MathF.Sqrt( sum ) + 1e-12f;   // epsilon avoids divide-by-zero
        }
        Concurrency.For( 0, pointCount, i =>
        {
            float pn = 0;
            for ( var d = 0; d < D; d++ )
            {
                var value = pfin.Data[i * D + d];
                pn += value * value;
            }
            pn = MathF.Sqrt( pn ) + 1e-12f;

            // Per-point scratch row: stack buffer when the joints fit,
            // heap array otherwise, so the joint count is not capped.
            Span<float> row = jointCount <= StackRowCapacity
                ? stackalloc float[StackRowCapacity]
                : new float[jointCount];
            var max = float.MinValue;
            for ( var j = 0; j < jointCount; j++ )
            {
                float dot = 0;
                for ( var d = 0; d < D; d++ )
                    dot += pfin.Data[i * D + d] * jr.Data[j * D + d];
                row[j] = dot / (pn * jointNorms[j]) * scale;
                max = MathF.Max( max, row[j] );
            }
            float total = 0;
            for ( var j = 0; j < jointCount; j++ )
            {
                row[j] = MathF.Exp( row[j] - max );
                total += row[j];
            }
            for ( var j = 0; j < jointCount; j++ )
                weights[i * jointCount + j] = row[j] / total;
        } );
        return weights;
    }
}