AutoRig/Dl/RigNet/RootNet.cs

Neural network model code for rigging. ShapeEncoder encodes mesh geometry with three GCUs followed by a global MLP and max-pooling. RootNet scores each candidate joint as the skeleton root by running PointNet++-style SA/FP layers over the joints and fusing the result with the shape code. BoneNet scores candidate joint pairs for connectivity by building per-pair features and fusing them with the global joint-cloud code and the shape code to produce logits.

Networking, File Access
namespace AutoRig.Dl.RigNet;

/// <summary>
/// Whole-mesh encoder shared by ROOTNET and BONENET. Three GCUs are chained over the
/// vertex positions; their outputs are concatenated along dimension 1, passed through
/// a global MLP, and max-pooled down to a single feature row (batch of 1).
/// </summary>
public sealed class ShapeEncoder
{
    public required Gcu Gcu1;
    public required Gcu Gcu2;
    public required Gcu Gcu3;
    public required Mlp GlobalMlp;

    /// <summary>Encodes one mesh into its pooled shape code.</summary>
    /// <param name="pos">Vertex positions fed to the first GCU.</param>
    /// <param name="tplEdges">First edge set (tpl), forwarded unchanged to every GCU.</param>
    /// <param name="geoEdges">Second edge set (geo), forwarded unchanged to every GCU.</param>
    /// <returns>The result of <c>PointNet.GlobalMaxPool</c> over the global MLP output.</returns>
    public Tensor Forward( Tensor pos, (int From, int To)[] tplEdges, (int From, int To)[] geoEdges )
    {
        // Each GCU consumes the previous one's output; all three are kept for the skip concat.
        var first = Gcu1.Forward( pos, tplEdges, geoEdges );
        var second = Gcu2.Forward( first, tplEdges, geoEdges );
        var third = Gcu3.Forward( second, tplEdges, geoEdges );

        var stacked = first.Concat( second, 1 ).Concat( third, 1 );
        var perVertex = GlobalMlp.Forward( stacked );
        return PointNet.GlobalMaxPool( perVertex );
    }
}

/// <summary>
/// ROOTNET (models/ROOT_GCN.py): assigns every candidate joint a score for being the
/// skeleton root. The joints are encoded by a PointNet++ set-abstraction /
/// feature-propagation stack whose only input feature is |x| (distance to the symmetry
/// plane); the per-joint result is concatenated with the mesh shape code and scored by
/// the back layers. Mirrors shuffle=False (mst_generate.getInitId): root = argmax(logits).
/// </summary>
public sealed class RootNet
{
    public required ShapeEncoder Shape;
    public required SaModule Sa1;
    public required SaModule Sa2;
    public required GlobalSaModule Sa3;
    public required FpModule Fp3;
    public required FpModule Fp2;
    public required FpModule Fp1;
    public required Mlp BackMlp;
    public required Linear BackOut;

    /// <summary>Scores each row of <paramref name="joints"/> as a root candidate.</summary>
    /// <param name="pos">Mesh vertex positions for the shape encoder.</param>
    /// <param name="tplEdges">First mesh edge set, forwarded to the shape encoder.</param>
    /// <param name="geoEdges">Second mesh edge set, forwarded to the shape encoder.</param>
    /// <param name="joints">Candidate joints; read with a row stride of 3 (x, y, z).</param>
    /// <returns>
    /// [J, 1] root logits. Sigmoid is monotonic, so argmax over these logits selects
    /// the same root as argmax over the sigmoid probabilities.
    /// </returns>
    public Tensor Forward(
        Tensor pos, (int From, int To)[] tplEdges, (int From, int To)[] geoEdges, Tensor joints )
    {
        var count = joints.Shape[0];
        var planeDistance = SymmetryPlaneDistance( joints, count );

        // Encoder: two local set-abstraction levels, then a global one.
        var (level1, at1) = Sa1.Forward( planeDistance, joints );
        var (level2, at2) = Sa2.Forward( level1, at1 );
        var (level3, at3) = Sa3.Forward( level2!, at2 );

        // Decoder: propagate features back to every joint via the skip connections.
        var up2 = Fp3.Forward( level3, at3, level2, at2 );
        var up1 = Fp2.Forward( up2, at2, level1, at1 );
        var perJoint = Fp1.Forward( up1, at1, planeDistance, joints );

        var shapeCode = Shape.Forward( pos, tplEdges, geoEdges );
        var merged = shapeCode.RepeatRows( count ).Concat( perJoint, 1 );
        return BackOut.Forward( BackMlp.Forward( merged ) );
    }

    /// <summary>Builds the [count, 1] input feature |x| from the first column of each joint row.</summary>
    private static Tensor SymmetryPlaneDistance( Tensor joints, int count )
    {
        var values = new float[count];
        for ( int j = 0, offset = 0; j < count; j++, offset += 3 )
            values[j] = MathF.Abs( joints.Data[offset] );
        return Tensor.From( values, count, 1 );
    }
}

/// <summary>
/// BONENET (models/PairCls_GCN.py, PairCls): scores each candidate joint pair for
/// connectivity. Pair feature = MLP over (joint_a, joint_b, dist, inside-proportion);
/// fused with the global shape and joint-cloud codes. permute_joints=False semantics.
/// </summary>
public sealed class BoneNet
{
    /// <summary>Width of one raw pair row: joint A xyz, joint B xyz, distance, inside-proportion.</summary>
    private const int PairWidth = 8;

    public required ShapeEncoder Shape;
    public required SaModule Sa1;
    public required SaModule Sa2;
    public required GlobalSaModule Sa3;
    public required Mlp ExpandPair;
    public required Mlp MixMlp;
    public required Linear MixOut;

    /// <summary>Scores every candidate pair in <paramref name="pairs"/>.</summary>
    /// <param name="pos">Mesh vertex positions for the shape encoder.</param>
    /// <param name="tplEdges">First mesh edge set, forwarded to the shape encoder.</param>
    /// <param name="geoEdges">Second mesh edge set, forwarded to the shape encoder.</param>
    /// <param name="joints">[J, 3] candidate joint positions.</param>
    /// <param name="pairs">Candidate joint index pairs; every index must lie in [0, J).</param>
    /// <param name="pairAttr">[P, 2] rows of (distance, inside-proportion) — the
    /// original's pair_attr[:, :-1] (the constant 1 column is dropped before use).</param>
    /// <returns>[P, 1] connectivity logits (sigmoid gives probabilities).</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="joints"/> does not have 3 columns, or <paramref name="pairAttr"/>
    /// is not [pairs.Length, 2] (e.g. the constant column was not dropped).
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">A pair references a joint index outside [0, J).</exception>
    public Tensor Forward(
        Tensor pos, (int From, int To)[] tplEdges, (int From, int To)[] geoEdges,
        Tensor joints, (int A, int B)[] pairs, Tensor pairAttr )
    {
        // Fail fast: a shape mismatch below would otherwise silently misalign the pair features.
        ValidateInputs( joints, pairs, pairAttr );

        var (x1, pos1) = Sa1.Forward( null, joints );
        var (x2, pos2) = Sa2.Forward( x1, pos1 );
        var (jointGlobal, _) = Sa3.Forward( x2!, pos2 );

        var shapeFeature = Shape.Forward( pos, tplEdges, geoEdges );
        var pairFeature = ExpandPair.Forward( BuildPairInput( joints, pairs, pairAttr ) );

        var fused = shapeFeature.RepeatRows( pairs.Length )
            .Concat( jointGlobal.RepeatRows( pairs.Length ), 1 )
            .Concat( pairFeature, 1 );
        return MixOut.Forward( MixMlp.Forward( fused ) );
    }

    /// <summary>Checks the tensor shapes and pair indices that <see cref="BuildPairInput"/> relies on.</summary>
    private static void ValidateInputs( Tensor joints, (int A, int B)[] pairs, Tensor pairAttr )
    {
        if ( joints.Shape[1] != 3 )
            throw new ArgumentException(
                $"Expected joints of shape [J, 3] but got {joints.Shape[1]} columns.", nameof( joints ) );

        if ( pairAttr.Shape[0] != pairs.Length || pairAttr.Shape[1] != 2 )
            throw new ArgumentException(
                $"Expected pairAttr of shape [{pairs.Length}, 2] but got [{pairAttr.Shape[0]}, {pairAttr.Shape[1]}]; " +
                "drop the constant last column of pair_attr before calling.", nameof( pairAttr ) );

        var jointCount = joints.Shape[0];
        for ( var p = 0; p < pairs.Length; p++ )
        {
            var (a, b) = pairs[p];
            if ( a < 0 || a >= jointCount || b < 0 || b >= jointCount )
                throw new ArgumentOutOfRangeException(
                    nameof( pairs ), $"Pair {p} = ({a}, {b}) references a joint outside [0, {jointCount})." );
        }
    }

    /// <summary>Packs each pair as [joint A xyz, joint B xyz, distance, inside-proportion] into a [P, 8] tensor.</summary>
    private static Tensor BuildPairInput( Tensor joints, (int A, int B)[] pairs, Tensor pairAttr )
    {
        var rows = new float[pairs.Length * PairWidth];
        for ( var p = 0; p < pairs.Length; p++ )
        {
            var row = p * PairWidth;
            var (a, b) = pairs[p];
            for ( var c = 0; c < 3; c++ )
            {
                rows[row + c] = joints.Data[a * 3 + c];
                rows[row + 3 + c] = joints.Data[b * 3 + c];
            }
            rows[row + 6] = pairAttr.Data[p * 2];
            rows[row + 7] = pairAttr.Data[p * 2 + 1];
        }
        return Tensor.From( rows, pairs.Length, PairWidth );
    }
}