Code/AutoRig/Dl/RigAnything/RigAnythingModel.cs

Model and diffusion classes for the RigAnything autoregressive-diffusion rigging network. Defines the transformer blocks, point-cloud/joint tokenization and KV-cache handling, joint fusion and input-token composition, the parent and skin-weight scoring MLPs, and a MAR/DiT-style diffusion head with a cosine noise schedule and ancestral sampling.

Native Interop
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;

namespace AutoRig.Dl.RigAnything;

using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Weights of one transformer block of RigAnything (Adobe, SIGGRAPH TOG 2025,
/// NONCOMMERCIAL research license).
/// </summary>
/// <remarks>
/// Network overview: autoregressive diffusion rigging. A 12-block transformer
/// (d 1024, 64 heads × 16 dims, pre-LN, γ-only LayerNorm, no biases anywhere)
/// attends a point-cloud prefix bidirectionally and joints causally. Each
/// joint's xyz is SAMPLED from a MAR-style AdaLN diffusion head (cosine
/// schedule, 1000 steps respaced to 50). Parents come from a pairwise MLP
/// (self-parent = stop). Skin weights come from a pairwise point×joint MLP;
/// this is the catalog's first port with LEARNED skin weights.
/// NOTE(review): the Beta fields exist even though the norms are described as
/// γ-only. Presumably they may be null; confirm against the weight loader.
/// </remarks>
public sealed class RigAnythingBlock
{
    /// <summary>Pre-attention LayerNorm scale.</summary>
    public required Tensor Norm1Gamma;

    /// <summary>Pre-attention LayerNorm shift.</summary>
    public required Tensor Norm1Beta;

    /// <summary>Fused Q/K/V projection, [3d, d], head-major fused.</summary>
    public required Tensor QkvWeight;

    /// <summary>Attention output projection, [d, d].</summary>
    public required Tensor FcWeight;

    /// <summary>Pre-MLP LayerNorm scale.</summary>
    public required Tensor Norm2Gamma;

    /// <summary>Pre-MLP LayerNorm shift.</summary>
    public required Tensor Norm2Beta;

    /// <summary>MLP expansion d → 4d (followed by GELU).</summary>
    public required Tensor Mlp0Weight;

    /// <summary>MLP contraction 4d → d.</summary>
    public required Tensor Mlp2Weight;
}

/// <summary>
/// RigAnything network weights plus the stage-by-stage forward helpers the
/// autoregressive sampler drives: point-cloud and joint tokenizers, the
/// KV-cached transformer trunk, joint fusion / input-token composition, and
/// the pairwise parent and skin scoring MLPs. Every projection in this class
/// is a bias-free <c>x · Wᵀ</c>.
/// </summary>
public sealed class RigAnythingModel
{
    /// <summary>Transformer width d.</summary>
    public const int D = 1024;

    /// <summary>Attention head count; d_head = <see cref="D"/> / Heads = 16.</summary>
    public const int Heads = 64;   // d_head 16

    /// <summary>Rows in <see cref="IndexEmbedding"/>; joint indices must lie in [0, MaxJoints).</summary>
    public const int MaxJoints = 64;

    public required Tensor PcTok0, PcTok2;        // 6→512 LeakyReLU 512→1024
    public required Tensor JointTok0, JointTok2;  // 3→512→1024
    public required Tensor JointMlp0, JointMlp2;  // 4d→2d→d
    public required Tensor FuseMlp0, FuseMlp2;    // 3d→2d→d
    public required Tensor SkinMlp0, SkinMlp2;    // 2d→d→1
    public required Tensor ParentLnGamma, ParentLnBeta;   // LN(2d)
    public required Tensor Parent1, Parent3;      // 2d→d LeakyReLU d→1
    public required Tensor StartToken;            // [1, d]
    public required Tensor InputLnGamma, InputLnBeta;
    public required RigAnythingBlock[] Blocks;
    public required RigAnythingDiffusion Diffusion;

    /// <summary>Fixed sinusoidal joint-index table (64 × 1024): sin on even
    /// columns, cos on odd, rate 1/10000^(2⌊i/2⌋/d). Built lazily on first
    /// access and cached.</summary>
    public Tensor IndexEmbedding => _indexEmbedding ??= BuildIndexEmbedding();
    Tensor _indexEmbedding;

    static Tensor BuildIndexEmbedding()
    {
        var data = new float[MaxJoints * D];
        for ( var pos = 0; pos < MaxJoints; pos++ )
            for ( var i = 0; i < D; i++ )
            {
                // 2 * (i / 2) is integer division on purpose: column pairs share a rate.
                var rate = 1.0 / Math.Pow( 10000.0, 2 * (i / 2) / (double)D );
                var angle = pos * rate;
                data[pos * D + i] = i % 2 == 0
                    ? (float)Math.Sin( angle )
                    : (float)Math.Cos( angle );
            }
        return Tensor.From( data, MaxJoints, D );
    }

    /// <summary>Throws when <paramref name="index"/> has no row in <see cref="IndexEmbedding"/>.</summary>
    static void ThrowIfJointIndexOutOfRange( int index, string paramName )
    {
        // The unsigned compare folds the negative check into the upper-bound check.
        if ( (uint)index >= MaxJoints )
            throw new ArgumentOutOfRangeException(
                paramName, index, $"Joint index must be in [0, {MaxJoints})." );
    }

    /// <summary>Two-layer bias-free MLP: (x · w0ᵀ) → LeakyReLU → · w2ᵀ.</summary>
    internal static Tensor LeakyMlp( Tensor x, Tensor w0, Tensor w2 )
    {
        var h = x.MatMul( w0.Transposed );
        LeakyRelu( h.Data );
        return h.MatMul( w2.Transposed );
    }

    /// <summary>In-place LeakyReLU with negative slope 0.01 (PyTorch's nn.LeakyReLU default).</summary>
    internal static void LeakyRelu( float[] data )
    {
        for ( var i = 0; i < data.Length; i++ )
            if ( data[i] < 0f )
                data[i] *= 0.01f;
    }

    /// <summary>
    /// Tokenizes the point cloud: each row is [x, y, z, nx, ny, nz] pushed
    /// through the 6→512→1024 LeakyReLU MLP.
    /// </summary>
    /// <param name="points">Point positions; one output token per point.</param>
    /// <param name="normals">Per-point normals; must have at least <c>points.Length</c> entries.</param>
    /// <returns>Point tokens, (points.Length, <see cref="D"/>).</returns>
    /// <exception cref="ArgumentNullException">Either array is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="normals"/> is shorter than <paramref name="points"/>.</exception>
    public Tensor PcTokenize( Vector3[] points, Vector3[] normals )
    {
        ArgumentNullException.ThrowIfNull( points );
        ArgumentNullException.ThrowIfNull( normals );
        if ( normals.Length < points.Length )
            throw new ArgumentException(
                $"Expected one normal per point ({points.Length}), got {normals.Length}.",
                nameof( normals ) );

        var input = new float[points.Length * 6];
        for ( var i = 0; i < points.Length; i++ )
        {
            input[i * 6 + 0] = points[i].X;
            input[i * 6 + 1] = points[i].Y;
            input[i * 6 + 2] = points[i].Z;
            input[i * 6 + 3] = normals[i].X;
            input[i * 6 + 4] = normals[i].Y;
            input[i * 6 + 5] = normals[i].Z;
        }
        return LeakyMlp( Tensor.From( input, points.Length, 6 ), PcTok0, PcTok2 );
    }

    /// <summary>Tokenizes joint positions through the 3→512→1024 LeakyReLU MLP.</summary>
    /// <returns>Joint position tokens, (joints.Count, <see cref="D"/>).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="joints"/> is null.</exception>
    public Tensor JointTokenize( IReadOnlyList<Vector3> joints )
    {
        ArgumentNullException.ThrowIfNull( joints );

        var input = new float[joints.Count * 3];
        for ( var i = 0; i < joints.Count; i++ )
        {
            input[i * 3 + 0] = joints[i].X;
            input[i * 3 + 1] = joints[i].Y;
            input[i * 3 + 2] = joints[i].Z;
        }
        return LeakyMlp( Tensor.From( input, joints.Count, 3 ), JointTok0, JointTok2 );
    }

    /// <summary>
    /// Runs new tokens through all blocks and appends their K/V to the cache.
    /// The FIRST call must be the point-cloud prefix + start token (pcCount
    /// set). Pc rows attend pc rows only (bidirectional), and the start row
    /// attends everything. Later single-token calls (pcCount ≤ 0) attend the
    /// whole cache, which is exactly the reference mask. Input LN is applied
    /// per call, as the reference does.
    /// NOTE(review): "start row attends everything" relies on
    /// TransformerOps.Attention aligning causal queries to the END of the key
    /// sequence. Confirm in TransformerOps.
    /// </summary>
    /// <param name="newTokens">Rows to append; presumably (rows, <see cref="D"/>).</param>
    /// <param name="cache">Per-layer K/V cache, mutated in place.</param>
    /// <param name="pcCount">Number of leading point-cloud rows; values ≤ 0 mean "no prefix".</param>
    /// <returns>Block outputs for the new rows.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="pcCount"/> exceeds the row count.</exception>
    public Tensor Forward( Tensor newTokens, OptKvCache cache, int pcCount )
    {
        var rowCount = newTokens.Shape[0];
        if ( pcCount > rowCount )
            throw new ArgumentOutOfRangeException(
                nameof( pcCount ), pcCount,
                $"pcCount cannot exceed the number of new tokens ({rowCount})." );

        var x = TransformerOps.LayerNorm( newTokens, InputLnGamma, InputLnBeta );
        for ( var l = 0; l < Blocks.Length; l++ )
        {
            var b = Blocks[l];
            var h = TransformerOps.LayerNorm( x, b.Norm1Gamma, b.Norm1Beta );
            var qkv = h.MatMul( b.QkvWeight.Transposed );
            // Slice(dim, start, length): split the fused projection into Q | K | V.
            var q = qkv.Slice( 1, 0, D );
            var k = qkv.Slice( 1, D, D );
            var v = qkv.Slice( 1, 2 * D, D );

            cache.Keys[l] = cache.Keys[l] is null ? k : cache.Keys[l].Concat( k, 0 );
            cache.Values[l] = cache.Values[l] is null ? v : cache.Values[l].Concat( v, 0 );

            Tensor attended;
            if ( pcCount > 0 )
            {
                // First pass: the pc block is bidirectional over pc keys only. Any
                // trailing rows (start token, joints) attend everything so far.
                // The pc keys are taken from the head of the cache, which is why
                // this must be the first call on an empty cache.
                var pcOut = TransformerOps.Attention(
                    q.Slice( 0, 0, pcCount ),
                    cache.Keys[l].Slice( 0, 0, pcCount ),
                    cache.Values[l].Slice( 0, 0, pcCount ),
                    Heads, causal: false );
                var restCount = rowCount - pcCount;
                // Prefix-only call: there are no trailing rows to attend or concatenate.
                attended = restCount == 0
                    ? pcOut
                    : pcOut.Concat( TransformerOps.Attention(
                        q.Slice( 0, pcCount, restCount ),
                        cache.Keys[l], cache.Values[l], Heads, causal: true ), 0 );
            }
            else
            {
                attended = TransformerOps.Attention(
                    q, cache.Keys[l], cache.Values[l], Heads, causal: true );
            }

            x = x.Add( attended.MatMul( b.FcWeight.Transposed ) );
            h = TransformerOps.LayerNorm( x, b.Norm2Gamma, b.Norm2Beta );
            h = TransformerOps.Gelu( h.MatMul( b.Mlp0Weight.Transposed ) );
            x = x.Add( h.MatMul( b.Mlp2Weight.Transposed ) );
        }
        return x;
    }

    /// <summary>joint_fuse_mlp(cat[transformer token, index embed, position token]).</summary>
    /// <remarks>Only the first <see cref="D"/> values of each token tensor are read.</remarks>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="jointIndex"/> is outside [0, <see cref="MaxJoints"/>).</exception>
    public Tensor FuseJoint( Tensor transformerToken, int jointIndex, Tensor positionToken )
    {
        ThrowIfJointIndexOutOfRange( jointIndex, nameof( jointIndex ) );

        var input = new float[3 * D];
        Array.Copy( transformerToken.Data, 0, input, 0, D );
        Array.Copy( IndexEmbedding.Data, jointIndex * D, input, D, D );
        Array.Copy( positionToken.Data, 0, input, 2 * D, D );
        return LeakyMlp( Tensor.From( input, 1, 3 * D ), FuseMlp0, FuseMlp2 );
    }

    /// <summary>joint_mlp(cat[index embed, pos token, parent index embed, parent pos token]).</summary>
    /// <remarks>Only the first <see cref="D"/> values of each token tensor are read.</remarks>
    /// <exception cref="ArgumentOutOfRangeException">Either index is outside [0, <see cref="MaxJoints"/>).</exception>
    public Tensor ComposeInputToken( int jointIndex, Tensor positionToken, int parentIndex, Tensor parentPositionToken )
    {
        ThrowIfJointIndexOutOfRange( jointIndex, nameof( jointIndex ) );
        ThrowIfJointIndexOutOfRange( parentIndex, nameof( parentIndex ) );

        var input = new float[4 * D];
        Array.Copy( IndexEmbedding.Data, jointIndex * D, input, 0, D );
        Array.Copy( positionToken.Data, 0, input, D, D );
        Array.Copy( IndexEmbedding.Data, parentIndex * D, input, 2 * D, D );
        Array.Copy( parentPositionToken.Data, 0, input, 3 * D, D );
        return LeakyMlp( Tensor.From( input, 1, 4 * D ), JointMlp0, JointMlp2 );
    }

    /// <summary>parents_decoder(cat[current joint feature, candidate feature]) per candidate.</summary>
    /// <returns>One raw (un-normalized) score per candidate, in candidate order;
    /// empty when there are no candidates.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="candidateFeatures"/> is null.</exception>
    public float[] ParentScores( Tensor currentFeature, IReadOnlyList<Tensor> candidateFeatures )
    {
        ArgumentNullException.ThrowIfNull( candidateFeatures );

        var rows = candidateFeatures.Count;
        if ( rows == 0 )
            return Array.Empty<float>();

        var input = new float[rows * 2 * D];
        for ( var r = 0; r < rows; r++ )
        {
            Array.Copy( currentFeature.Data, 0, input, r * 2 * D, D );
            Array.Copy( candidateFeatures[r].Data, 0, input, r * 2 * D + D, D );
        }
        var x = TransformerOps.LayerNorm(
            Tensor.From( input, rows, 2 * D ), ParentLnGamma, ParentLnBeta );
        x = x.MatMul( Parent1.Transposed );
        LeakyRelu( x.Data );
        return x.MatMul( Parent3.Transposed ).Data;
    }

    /// <summary>skinning_mlp(cat[pc token, joint feature]) → raw score per point.</summary>
    /// <param name="pcTokens">Point tokens read as contiguous rows of <see cref="D"/> floats.</param>
    /// <param name="jointFeature">Joint feature; its first <see cref="D"/> values are paired with every point.</param>
    /// <returns>One raw score per point row.</returns>
    public float[] SkinScores( Tensor pcTokens, Tensor jointFeature )
    {
        var rows = pcTokens.Shape[0];
        var input = new float[rows * 2 * D];
        for ( var r = 0; r < rows; r++ )
        {
            Array.Copy( pcTokens.Data, r * D, input, r * 2 * D, D );
            Array.Copy( jointFeature.Data, 0, input, r * 2 * D + D, D );
        }
        return LeakyMlp( Tensor.From( input, rows, 2 * D ), SkinMlp0, SkinMlp2 ).Data;
    }
}

/// <summary>
/// The MAR/DiT diffusion head (SimpleMLPAdaLN) plus the IDDPM machinery it
/// samples with: cosine ᾱ over 1000 steps respaced to 50, epsilon prediction,
/// LEARNED_RANGE variance, and ancestral sampling. The reference draws noise
/// from torch's RNG. This port draws Box–Muller Gaussians from the caller's
/// System.Random instead. That is documented as acceptable because sampling is
/// stochastic in the reference too, and every deterministic sub-part is
/// golden-tested.
/// </summary>
public sealed class RigAnythingDiffusion
{
    /// <summary>Length of the original (training) cosine schedule.</summary>
    public const int TrainSteps = 1000;
    /// <summary>Number of respaced steps actually run by <see cref="Sample"/>.</summary>
    public const int SampleSteps = 50;

    public required Tensor InputProjWeight, InputProjBias;    // 3 → 1024
    public required Tensor Time0Weight, Time0Bias, Time2Weight, Time2Bias;
    public required Tensor CondWeight, CondBias;              // 1024 → 1024
    public required DiffResBlock[] ResBlocks;                 // 3
    public required Tensor FinalAdaWeight, FinalAdaBias;      // 1024 → 2048
    public required Tensor FinalLinWeight, FinalLinBias;      // 1024 → 6

    /// <summary>
    /// One AdaLN residual block: LN → modulate(shift, scale) → Linear → SiLU →
    /// Linear, added back to the input through a learned gate.
    /// </summary>
    public sealed class DiffResBlock
    {
        public required Tensor InLnGamma, InLnBeta;           // eps 1e-6
        public required Tensor Mlp0Weight, Mlp0Bias, Mlp2Weight, Mlp2Bias;
        public required Tensor AdaWeight, AdaBias;            // 1024 → 3072
    }

    /// <summary>Net forward: x (1,3), ORIGINAL timestep t, cond c (1,1024) → (1,6).</summary>
    /// <param name="x">Noisy sample; must be a single row (the modulation code below only touches row 0).</param>
    /// <param name="t">Timestep on the 0..999 training scale, not the respaced index.</param>
    /// <param name="c">Conditioning feature from the transformer.</param>
    /// <returns>Columns 0..2 are the predicted noise ε. Columns 3..5 are the
    /// LEARNED_RANGE variance interpolants consumed by <see cref="Sample"/>.</returns>
    public Tensor Net( Tensor x, int t, Tensor c )
    {
        x = x.MatMul( InputProjWeight.Transposed ).Add( InputProjBias );

        // Sinusoidal timestep embedding: 256 dims = 128 cos then 128 sin.
        const int half = 128;
        var freq = new float[256];
        for ( var i = 0; i < half; i++ )
        {
            var f = Math.Exp( -Math.Log( 10000.0 ) * i / half ) * t;
            freq[i] = (float)Math.Cos( f );
            freq[half + i] = (float)Math.Sin( f );
        }
        // Timestep MLP: Linear → SiLU → Linear.
        var temb = Tensor.From( freq, 1, 256 )
            .MatMul( Time0Weight.Transposed ).Add( Time0Bias );
        temb = TransformerOps.Silu( temb );
        temb = temb.MatMul( Time2Weight.Transposed ).Add( Time2Bias );

        // Combined conditioning y = time embedding + projected cond feature.
        var y = temb.Add( c.MatMul( CondWeight.Transposed ).Add( CondBias ) );

        foreach ( var block in ResBlocks )
        {
            // mod packs three width-sized chunks: [shift | scale | gate].
            var mod = TransformerOps.Silu( y )
                .MatMul( block.AdaWeight.Transposed ).Add( block.AdaBias );
            var h = TransformerOps.LayerNorm( x, block.InLnGamma, block.InLnBeta, eps: 1e-6f );
            Modulate( h, mod, 0, 1 );
            h = h.MatMul( block.Mlp0Weight.Transposed ).Add( block.Mlp0Bias );
            h = TransformerOps.Silu( h );
            h = h.MatMul( block.Mlp2Weight.Transposed ).Add( block.Mlp2Bias );
            // Gate chunk starts at offset 2·width. Indexing mod by the flat i is
            // only valid because x is a single row.
            for ( var i = 0; i < x.Data.Length; i++ )   // x + gate·h
                h.Data[i] = x.Data[i] + mod.Data[2 * x.Shape[1] + i] * h.Data[i];
            x = h;
        }

        // Final layer: affine-free LayerNorm, modulated by [shift | scale], then a linear map to 6 outputs.
        var finalMod = TransformerOps.Silu( y )
            .MatMul( FinalAdaWeight.Transposed ).Add( FinalAdaBias );
        var normed = TransformerOps.LayerNorm( x, null, null, eps: 1e-6f );
        Modulate( normed, finalMod, 0, 1 );
        return normed.MatMul( FinalLinWeight.Transposed ).Add( FinalLinBias );
    }

    /// <summary>x = x·(1+scale) + shift with (shift, scale) packed in mod rows.</summary>
    /// <remarks>In place. Only the first <c>x.Shape[1]</c> entries of x are
    /// touched, so x is assumed to be a single row. shiftChunk and scaleChunk
    /// select which width-sized chunk of mod supplies each term.</remarks>
    static void Modulate( Tensor x, Tensor mod, int shiftChunk, int scaleChunk )
    {
        var d = x.Shape[1];
        for ( var i = 0; i < d; i++ )
            x.Data[i] = x.Data[i] * (1f + mod.Data[scaleChunk * d + i])
                + mod.Data[shiftChunk * d + i];
    }

    // ---- cosine schedule, respaced (IDDPM space_timesteps semantics) ----

    /// <summary>Cosine ᾱ(t) = cos²((t + s)/(1 + s) · π/2) with s = 0.008, for t in [0, 1].</summary>
    static double AlphaBar( double t )
    {
        var v = Math.Cos( (t + 0.008) / 1.008 * Math.PI / 2 );
        return v * v;
    }

    /// <summary>Respaced timestep map + per-step ᾱ (index 0 = smallest t).</summary>
    /// <returns>Map[i] is the original 0..999 timestep used at respaced step i.
    /// AlphaCum[i] is the 1000-step cumulative ᾱ at that timestep.</returns>
    internal static (int[] Map, double[] AlphaCum) BuildSchedule()
    {
        // 1000-step cosine betas → cumulative alphas.
        var alphaCum1000 = new double[TrainSteps];
        double running = 1;
        for ( var i = 0; i < TrainSteps; i++ )
        {
            // β_i = 1 − ᾱ(t_{i+1}) / ᾱ(t_i), capped at 0.999.
            var beta = Math.Min(
                1 - AlphaBar( (i + 1) / (double)TrainSteps ) / AlphaBar( i / (double)TrainSteps ),
                0.999 );
            running *= 1 - beta;
            alphaCum1000[i] = running;
        }

        // space_timesteps(1000, "50"): one section, frac stride (1000−1)/(50−1).
        // Math.Round defaults to round-half-to-even.
        var map = new int[SampleSteps];
        var stride = (TrainSteps - 1) / (double)(SampleSteps - 1);
        for ( var i = 0; i < SampleSteps; i++ )
            map[i] = (int)Math.Round( stride * i );

        // The respaced schedule keeps the original cumulative ᾱ at each kept timestep.
        var alphaCum = new double[SampleSteps];
        for ( var i = 0; i < SampleSteps; i++ )
            alphaCum[i] = alphaCum1000[map[i]];
        return (map, alphaCum);
    }

    // Lazily built on the first Sample call and reused afterwards.
    (int[] Map, double[] AlphaCum)? _schedule;

    /// <summary>Ancestral sampling of one joint's xyz from cond feature c.</summary>
    /// <param name="c">Conditioning feature passed unchanged to <see cref="Net"/> at every step.</param>
    /// <param name="rng">Source of the initial noise and the per-step noise.</param>
    /// <returns>The denoised sample after all <see cref="SampleSteps"/> steps.</returns>
    public Vector3 Sample( Tensor c, Random rng )
    {
        _schedule ??= BuildSchedule();
        var (map, alphaCum) = _schedule.Value;

        // Start from pure Gaussian noise x_T.
        var x = new float[3];
        for ( var i = 0; i < 3; i++ )
            x[i] = Gaussian( rng );

        // Walk the respaced schedule from the noisiest step down to step 0.
        for ( var step = SampleSteps - 1; step >= 0; step-- )
        {
            var aCum = alphaCum[step];
            var aPrev = step > 0 ? alphaCum[step - 1] : 1.0;
            var beta = 1 - aCum / aPrev;   // respaced β for this step
            var posteriorVar = beta * (1 - aPrev) / (1 - aCum);
            var minLog = Math.Log( Math.Max( posteriorVar, 1e-20 ) );
            var maxLog = Math.Log( beta );
            // At step 0, aPrev = 1 makes posteriorVar 0. Substitute log β so
            // logVar stays finite. The value is never used because the
            // `step > 0` branch below adds no noise at step 0.
            if ( step == 0 )
                minLog = maxLog;   // posterior variance clipped at t=0 (unused: no noise)

            // The network sees the ORIGINAL timestep, not the respaced index.
            var output = Net( Tensor.From( x, 1, 3 ), map[step], c );
            for ( var i = 0; i < 3; i++ )
            {
                var eps = output.Data[i];
                var v = output.Data[3 + i];
                // LEARNED_RANGE: v ∈ [−1, 1] interpolates log-variance between the
                // posterior variance (min) and β (max).
                var frac = (v + 1) / 2;
                var logVar = frac * maxLog + (1 - frac) * minLog;

                // x̂_0 from the ε prediction, then the posterior mean of q(x_{t−1} | x_t, x̂_0).
                var x0 = (x[i] - Math.Sqrt( 1 - aCum ) * eps) / Math.Sqrt( aCum );
                var c1 = beta * Math.Sqrt( aPrev ) / (1 - aCum);
                var c2 = (1 - aPrev) * Math.Sqrt( aCum / aPrev ) / (1 - aCum);
                var mean = c1 * x0 + c2 * x[i];
                x[i] = (float)(step > 0
                    ? mean + Math.Exp( 0.5 * logVar ) * Gaussian( rng )
                    : mean);
            }
        }
        return new Vector3( x[0], x[1], x[2] );
    }

    /// <summary>Standard normal draw via Box–Muller.</summary>
    /// <remarks>u1 = 1 − NextDouble() lies in (0, 1], so Log(u1) is finite.</remarks>
    static float Gaussian( Random rng )
    {
        var u1 = 1.0 - rng.NextDouble();
        var u2 = rng.NextDouble();
        return (float)(Math.Sqrt( -2.0 * Math.Log( u1 ) ) * Math.Cos( 2.0 * Math.PI * u2 ));
    }
}