Neural network component for RigNet joint prediction. It runs three GCU layers, builds a global feature by applying GlobalMlp to their concatenated outputs and max-pooling over all vertices, broadcasts that feature back to every vertex, concatenates it with the input and the per-layer vertex features, and then applies a per-vertex MLP followed by a final linear layer, optionally applying tanh to the result.
namespace AutoRig.Dl.RigNet;
/// <summary>
/// Port of RigNet's JointPredNet (details in docs/superpowers/rignet-port-notes.md).
/// Three graph-convolution units widen the per-vertex features in→64→256→512. Their
/// outputs are concatenated (832 channels), passed through <see cref="GlobalMlp"/>
/// (→1024) and max-pooled over all vertices into a single global descriptor. That
/// descriptor is repeated for every vertex, joined with the input and the three GCU
/// outputs, and fed through the per-vertex transform MLP (→1024→256) and a final
/// linear layer. The result is squashed with tanh when <see cref="TanhOutput"/> is set
/// (offset head) and returned as raw logits otherwise (attention head).
/// </summary>
public sealed class JointPredNet
{
	/// <summary>First graph convolution unit; consumes the raw vertex input.</summary>
	public required Gcu Gcu1;

	/// <summary>Second graph convolution unit; consumes the output of <see cref="Gcu1"/>.</summary>
	public required Gcu Gcu2;

	/// <summary>Third graph convolution unit; consumes the output of <see cref="Gcu2"/>.</summary>
	public required Gcu Gcu3;

	/// <summary>MLP applied to the concatenated GCU outputs before the column-wise max pooling.</summary>
	public required Mlp GlobalMlp;

	/// <summary>Per-vertex transform MLP without its final layer (that layer is <see cref="TransformOut"/>).</summary>
	public required Mlp TransformMlp;

	/// <summary>Final linear layer of the per-vertex transform.</summary>
	public required Linear TransformOut;

	/// <summary>When true, tanh is applied to the final output; otherwise the output is returned unchanged.</summary>
	public required bool TanhOutput;

	/// <summary>Runs the network for every vertex.</summary>
	/// <param name="x">Per-vertex input features; its first dimension is taken as the vertex count.</param>
	/// <param name="topologyEdges">Edge list handed unchanged to each GCU.</param>
	/// <param name="geodesicEdges">Second edge list handed unchanged to each GCU.</param>
	/// <returns>
	/// The output of <see cref="TransformOut"/>, passed through tanh when <see cref="TanhOutput"/> is set.
	/// </returns>
	public Tensor Forward( Tensor x, (int From, int To)[] topologyEdges, (int From, int To)[] geodesicEdges )
	{
		// Stacked GCUs: each one consumes the previous one's features.
		var gcuA = Gcu1.Forward( x, topologyEdges, geodesicEdges );
		var gcuB = Gcu2.Forward( gcuA, topologyEdges, geodesicEdges );
		var gcuC = Gcu3.Forward( gcuB, topologyEdges, geodesicEdges );

		// Global branch: MLP over all GCU features, then max over the vertex axis (dim 0).
		var stacked = gcuA.Concat( gcuB, 1 ).Concat( gcuC, 1 );
		var pooled = GlobalMlp.Forward( stacked ).Max( 0 );

		// Give every vertex its own copy of the pooled descriptor so it can be concatenated column-wise.
		var globalPerVertex = TileRows( pooled, x.Shape[0] );

		// Column order matters to the trained weights: global, input, gcu1, gcu2, gcu3.
		var perVertexIn = globalPerVertex
			.Concat( x, 1 )
			.Concat( gcuA, 1 )
			.Concat( gcuB, 1 )
			.Concat( gcuC, 1 );

		var hidden = TransformMlp.Forward( perVertexIn );
		var logits = TransformOut.Forward( hidden );

		if ( TanhOutput )
		{
			return logits.Tanh();
		}

		return logits;
	}

	/// <summary>
	/// Builds a <paramref name="rowCount"/>-row tensor in which every row is a copy of the
	/// first <c>row.Shape[1]</c> values of <paramref name="row"/>'s data.
	/// </summary>
	private static Tensor TileRows( Tensor row, int rowCount )
	{
		var width = row.Shape[1];
		var tiled = new float[rowCount * width];

		for ( var r = 0; r < rowCount; r++ )
		{
			Array.Copy( row.Data, 0, tiled, r * width, width );
		}

		return Tensor.From( tiled, rowCount, width );
	}
}