Neural network modules for skeleton/root/bone prediction. Defines ShapeEncoder that runs three GCUs and a global MLP then max-pools, RootNet that scores candidate joints as skeleton roots using PointNet++ style SA/FP stacks fused with shape code, and BoneNet that scores candidate joint pairs using pair features fused with joint and shape codes.
namespace AutoRig.Dl.RigNet;
/// <summary>
/// Mesh-level encoder used by both <see cref="RootNet"/> and <see cref="BoneNet"/>.
/// Three GCUs are chained over the vertex positions. Their three outputs are concatenated
/// and passed through a global MLP, and the result is max-pooled to a single feature row
/// (batch of 1).
/// </summary>
public sealed class ShapeEncoder
{
    public required Gcu Gcu1;
    public required Gcu Gcu2;
    public required Gcu Gcu3;
    public required Mlp GlobalMlp;

    /// <summary>Encodes the whole mesh into one pooled shape code.</summary>
    /// <param name="pos">Mesh vertex positions fed to the first GCU.</param>
    /// <param name="tplEdges">First edge set, passed unchanged to every GCU.</param>
    /// <param name="geoEdges">Second edge set, passed unchanged to every GCU.</param>
    /// <returns>The globally max-pooled output of <see cref="GlobalMlp"/>.</returns>
    public Tensor Forward( Tensor pos, (int From, int To)[] tplEdges, (int From, int To)[] geoEdges )
    {
        // Each GCU consumes the previous GCU's output; all three are kept for the skip-concat.
        var shallow = Gcu1.Forward( pos, tplEdges, geoEdges );
        var middle = Gcu2.Forward( shallow, tplEdges, geoEdges );
        var deep = Gcu3.Forward( middle, tplEdges, geoEdges );

        // Concatenate along dimension 1 (presumably the feature axis), in shallow -> deep order.
        var stacked = shallow.Concat( middle, 1 ).Concat( deep, 1 );
        var perVertex = GlobalMlp.Forward( stacked );
        return PointNet.GlobalMaxPool( perVertex );
    }
}
/// <summary>
/// ROOTNET (models/ROOT_GCN.py): scores each candidate joint as the skeleton root.
/// The joint encoder is a PointNet++ SA/FP stack over the joints. Its only input feature
/// is |x|, the distance to the x = 0 symmetry plane. The back layers fuse the per-joint
/// features with the shape code, which is repeated once per joint.
/// Run with shuffle=False semantics (mst_generate.getInitId): root = argmax(logits).
/// </summary>
public sealed class RootNet
{
    public required ShapeEncoder Shape;
    public required SaModule Sa1;
    public required SaModule Sa2;
    public required GlobalSaModule Sa3;
    public required FpModule Fp3;
    public required FpModule Fp2;
    public required FpModule Fp1;
    public required Mlp BackMlp;
    public required Linear BackOut;

    /// <summary>Computes one root logit per candidate joint.</summary>
    /// <param name="pos">Mesh vertex positions, consumed only by the shape encoder.</param>
    /// <param name="tplEdges">First edge set passed to the shape encoder's GCUs.</param>
    /// <param name="geoEdges">Second edge set passed to the shape encoder's GCUs.</param>
    /// <param name="joints">[J, 3] candidate joint positions, stored row-major.</param>
    /// <returns>[J, 1] root logits (sigmoid+argmax picks the root).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="joints"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="joints"/> is not 3 columns wide.</exception>
    public Tensor Forward(
        Tensor pos, (int From, int To)[] tplEdges, (int From, int To)[] geoEdges, Tensor joints )
    {
        ArgumentNullException.ThrowIfNull( joints );
        // The |x| extraction below reads Data with a fixed stride of 3. Without this check,
        // a tensor of any other width would give silently wrong features instead of an error.
        if ( joints.Shape[1] != 3 )
        {
            throw new ArgumentException(
                $"Expected joints of shape [J, 3], got width {joints.Shape[1]}.", nameof( joints ) );
        }

        var jointCount = joints.Shape[0];

        // |x| of each joint: its proximity to the symmetry plane.
        var absX = new float[jointCount];
        for ( var j = 0; j < jointCount; j++ )
        {
            absX[j] = MathF.Abs( joints.Data[j * 3] );
        }
        var x0 = Tensor.From( absX, jointCount, 1 );

        // Set-abstraction (down) path.
        var (x1, pos1) = Sa1.Forward( x0, joints );
        var (x2, pos2) = Sa2.Forward( x1, pos1 );
        var (x3, pos3) = Sa3.Forward( x2!, pos2 );

        // Feature-propagation (up) path back to every joint, with skip inputs from each SA level.
        var f3 = Fp3.Forward( x3, pos3, x2, pos2 );
        var f2 = Fp2.Forward( f3, pos2, x1, pos1 );
        var jointFeature = Fp1.Forward( f2, pos1, x0, joints );

        // The shape code goes first in the concat. This column order must match the layout
        // BackMlp was trained with.
        var shapeFeature = Shape.Forward( pos, tplEdges, geoEdges ).RepeatRows( jointCount );
        var fused = BackMlp.Forward( shapeFeature.Concat( jointFeature, 1 ) );
        return BackOut.Forward( fused );
    }
}
/// <summary>
/// BONENET (models/PairCls_GCN.py, PairCls): scores each candidate joint pair for
/// connectivity. The pair feature is an MLP over (joint_a, joint_b, dist, inside-proportion).
/// It is fused with the global shape code and the joint-cloud code.
/// Uses permute_joints=False semantics: joint A always precedes joint B in the pair row.
/// </summary>
public sealed class BoneNet
{
    /// <summary>Width of one packed pair row: A xyz (3) + B xyz (3) + distance + inside-proportion.</summary>
    private const int PairInputWidth = 8;

    public required ShapeEncoder Shape;
    public required SaModule Sa1;
    public required SaModule Sa2;
    public required GlobalSaModule Sa3;
    public required Mlp ExpandPair;
    public required Mlp MixMlp;
    public required Linear MixOut;

    /// <summary>Computes one connectivity logit per candidate pair.</summary>
    /// <param name="pos">Mesh vertex positions, consumed only by the shape encoder.</param>
    /// <param name="tplEdges">First edge set passed to the shape encoder's GCUs.</param>
    /// <param name="geoEdges">Second edge set passed to the shape encoder's GCUs.</param>
    /// <param name="joints">[J, 3] joint positions, stored row-major.</param>
    /// <param name="pairs">Candidate joint index pairs; each index must be in [0, J).</param>
    /// <param name="pairAttr">[P, 2] rows of (distance, inside-proportion). This is the
    /// original's pair_attr[:, :-1]; the constant 1 column is dropped before use.</param>
    /// <returns>[P, 1] connectivity logits (sigmoid gives probabilities).</returns>
    /// <exception cref="ArgumentNullException">Any of <paramref name="joints"/>, <paramref name="pairs"/>
    /// or <paramref name="pairAttr"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="joints"/> is not [J, 3], or
    /// <paramref name="pairAttr"/> is not [P, 2].</exception>
    /// <exception cref="ArgumentOutOfRangeException">A pair references a joint outside [0, J).</exception>
    public Tensor Forward(
        Tensor pos, (int From, int To)[] tplEdges, (int From, int To)[] geoEdges,
        Tensor joints, (int A, int B)[] pairs, Tensor pairAttr )
    {
        // Fail fast, before any network work, on layouts the strided reads below would misinterpret.
        ValidateInputs( joints, pairs, pairAttr );

        // Joint-cloud code: SA stack with no input features (positions only), globally pooled.
        var (x1, pos1) = Sa1.Forward( null, joints );
        var (x2, pos2) = Sa2.Forward( x1, pos1 );
        var (jointGlobal, _) = Sa3.Forward( x2!, pos2 );
        var shapeFeature = Shape.Forward( pos, tplEdges, geoEdges );

        var pairCount = pairs.Length;
        var pairFeature = ExpandPair.Forward( BuildPairInput( joints, pairs, pairAttr ) );

        // Column order is (shape, joint cloud, pair). It must match the layout MixMlp was trained with.
        var fused = shapeFeature.RepeatRows( pairCount )
            .Concat( jointGlobal.RepeatRows( pairCount ), 1 )
            .Concat( pairFeature, 1 );
        return MixOut.Forward( MixMlp.Forward( fused ) );
    }

    /// <summary>Rejects inputs whose layout would make the stride-3 joint reads or stride-2 attribute reads misaligned.</summary>
    private static void ValidateInputs( Tensor joints, (int A, int B)[] pairs, Tensor pairAttr )
    {
        ArgumentNullException.ThrowIfNull( joints );
        ArgumentNullException.ThrowIfNull( pairs );
        ArgumentNullException.ThrowIfNull( pairAttr );

        if ( joints.Shape[1] != 3 )
        {
            throw new ArgumentException(
                $"Expected joints of shape [J, 3], got width {joints.Shape[1]}.", nameof( joints ) );
        }

        // Suppose pairAttr still has its trailing constant column ([P, 3]) or a different row count.
        // Reading it with stride 2 would then silently scramble every row after the first.
        if ( pairAttr.Shape[0] != pairs.Length || pairAttr.Shape[1] != 2 )
        {
            throw new ArgumentException(
                $"Expected pairAttr of shape [{pairs.Length}, 2], got [{pairAttr.Shape[0]}, {pairAttr.Shape[1]}].",
                nameof( pairAttr ) );
        }

        var jointCount = joints.Shape[0];
        for ( var p = 0; p < pairs.Length; p++ )
        {
            var (a, b) = pairs[p];
            // The unsigned compare rejects negative indices and indices >= jointCount in one test.
            if ( (uint)a >= (uint)jointCount || (uint)b >= (uint)jointCount )
            {
                throw new ArgumentOutOfRangeException(
                    nameof( pairs ), $"Pair {p} ({a}, {b}) references a joint outside [0, {jointCount}).");
            }
        }
    }

    /// <summary>Packs one row per pair: (joint A xyz, joint B xyz, distance, inside-proportion).</summary>
    private static Tensor BuildPairInput( Tensor joints, (int A, int B)[] pairs, Tensor pairAttr )
    {
        var pairInput = new float[pairs.Length * PairInputWidth];
        for ( var p = 0; p < pairs.Length; p++ )
        {
            var row = p * PairInputWidth;
            var (a, b) = pairs[p];
            for ( var c = 0; c < 3; c++ )
            {
                pairInput[row + c] = joints.Data[a * 3 + c];
                pairInput[row + 3 + c] = joints.Data[b * 3 + c];
            }
            pairInput[row + 6] = pairAttr.Data[p * 2];
            pairInput[row + 7] = pairAttr.Data[p * 2 + 1];
        }
        return Tensor.From( pairInput, pairs.Length, PairInputWidth );
    }
}