AutoRig/Dl/RigNet/JointPredNet.cs

Neural-network component implementing RigNet's JointPredNet. It runs three stacked GCU layers, passes their concatenated outputs through an MLP and max-pools the result over the vertex axis to obtain a single global feature, broadcasts that feature to every vertex, concatenates it with the per-vertex input and the three GCU outputs, feeds the result through a transform MLP followed by a final linear layer, and optionally applies tanh to the output.

Native Interop
namespace AutoRig.Dl.RigNet;

/// <summary>
/// Port of RigNet's JointPredNet (details in docs/superpowers/rignet-port-notes.md).
/// Three GCU stages widen per-vertex features in→64→256→512. Their 832-wide
/// concatenation is mapped by <see cref="GlobalMlp"/> (→1024) and max-reduced over the
/// vertex axis into one global feature row. That row is repeated for every vertex, joined
/// with the input and all three GCU outputs, and mapped through <see cref="TransformMlp"/>
/// (→1024→256) and <see cref="TransformOut"/>. The final result is tanh-squashed for the
/// offset head or left as raw logits for the attention head.
/// </summary>
public sealed class JointPredNet
{
    // Graph-convolution stages, applied one after another in Forward.
    public required Gcu Gcu1;
    public required Gcu Gcu2;
    public required Gcu Gcu3;

    // Applied to the concatenated GCU outputs before the vertex-axis max.
    public required Mlp GlobalMlp;

    // Per-vertex transform, split into hidden layers and the final projection.
    public required Mlp TransformMlp;   // all but the last layer
    public required Linear TransformOut;

    // When true, Forward returns tanh(TransformOut output) instead of the raw output.
    public required bool TanhOutput;

    /// <summary>
    /// Runs the network over one mesh.
    /// </summary>
    /// <param name="x">Per-vertex input features; its row count (Shape[0]) is used as the vertex count.</param>
    /// <param name="topologyEdges">First edge list, passed unchanged to every GCU stage.</param>
    /// <param name="geodesicEdges">Second edge list, passed unchanged to every GCU stage.</param>
    /// <returns>One output row per vertex, tanh-applied when <see cref="TanhOutput"/> is set.</returns>
    public Tensor Forward( Tensor x, (int From, int To)[] topologyEdges, (int From, int To)[] geodesicEdges )
    {
        // Stacked graph convolutions: each stage consumes the previous stage's output.
        var gcu1Out = Gcu1.Forward( x, topologyEdges, geodesicEdges );
        var gcu2Out = Gcu2.Forward( gcu1Out, topologyEdges, geodesicEdges );
        var gcu3Out = Gcu3.Forward( gcu2Out, topologyEdges, geodesicEdges );

        // Global branch: MLP over [gcu1 | gcu2 | gcu3] ([N, 832] → [N, 1024]),
        // then a max over dimension 0 collapses it to a single [1, 1024] row.
        var gcuStack = ConcatColumns( gcu1Out, gcu2Out, gcu3Out );
        var globalRow = GlobalMlp.Forward( gcuStack ).Max( 0 );

        // Tile the global row so each vertex row carries a copy of it.
        var globalTiled = RepeatRow( globalRow, x.Shape[0] );

        // Per-vertex branch: [global | input | gcu1 | gcu2 | gcu3] → hidden layers → projection.
        var perVertex = ConcatColumns( globalTiled, x, gcu1Out, gcu2Out, gcu3Out );
        var hiddenOut = TransformMlp.Forward( perVertex );
        var projected = TransformOut.Forward( hiddenOut );

        if ( !TanhOutput )
        {
            return projected;
        }

        return projected.Tanh();
    }

    // Concatenates tensors left to right along dimension 1, starting from `first`.
    private static Tensor ConcatColumns( Tensor first, params Tensor[] rest )
    {
        var joined = first;
        foreach ( var part in rest )
        {
            joined = joined.Concat( part, 1 );
        }
        return joined;
    }

    // Builds a [rowCount, width] tensor in which every row is a copy of the first
    // `width` values of `row.Data`, where width = row.Shape[1].
    private static Tensor RepeatRow( Tensor row, int rowCount )
    {
        var width = row.Shape[1];
        var tiled = new float[rowCount * width];
        for ( var i = 0; i < rowCount; i++ )
        {
            Array.Copy( row.Data, 0, tiled, i * width, width );
        }
        return Tensor.From( tiled, rowCount, width );
    }
}