// Utility class that builds and processes a symmetric skeleton graph for a RigNet-style rig. It provides functions to mirror joint sets across the X plane, sample points along bones, adjust pairwise costs based on voxel occupancy and middle-joint handling, run a symmetry-aware Prim MST, and build a cost matrix from pair probabilities.
namespace AutoRig.Dl.RigNet;
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// RigNet skeleton assembly, transcribed from utils/mst_utils.py: x-symmetrization
/// of the joint set (flip), pairwise cost shaping, and the symmetry-aware Prim MST.
/// Costs are double precision to match the numpy reference exactly.
/// </summary>
public static class MstSkeleton
{
	/// <summary>Joints within this distance of the YZ plane count as "middle".</summary>
	public const float SymmetryTolerance = 2e-2f;

	/// <summary>
	/// flip(): keeps left (x &lt; -tol) and middle joints (|x| &lt;= tol, x snapped to 0),
	/// discards the original right side and replaces it with a mirror of the left.
	/// Returns joints ordered left / middle / right so the i-th left joint mirrors the
	/// i-th right joint.
	/// </summary>
	/// <param name="joints">Raw joint positions.</param>
	/// <returns>A new, x-symmetric joint array.</returns>
	public static Vector3[] Flip( Vector3[] joints )
	{
		var left = joints.Where( j => j.X < -SymmetryTolerance ).ToArray();
		var middle = joints
			.Where( j => MathF.Abs( j.X ) <= SymmetryTolerance )
			.Select( j => new Vector3( 0f, j.Y, j.Z ) )
			.ToArray();
		// The right side is rebuilt from the left so every left joint has an exact mirror.
		var right = left.Select( j => new Vector3( -j.X, j.Y, j.Z ) );
		return left.Concat( middle ).Concat( right ).ToArray();
	}

	/// <summary>
	/// sample_on_bone: round(|child - parent| / 0.01) points evenly spaced along the
	/// segment parent → child. The parent itself is excluded; the last sample lands on
	/// the child. Returns an empty array when the length rounds to zero steps.
	/// </summary>
	/// <param name="parent">Segment start (not sampled).</param>
	/// <param name="child">Segment end (the final sample).</param>
	public static Vector3[] SampleOnBone( Vector3 parent, Vector3 child )
	{
		var ray = child - parent;
		// MathF.Round rounds half to even by default, like np.round.
		var steps = (int)MathF.Round( ray.Length() / 0.01f );
		var samples = new Vector3[Math.Max( steps, 0 )];
		for ( var i = 1; i <= steps; i++ )
		{
			// The 1e-30 term mirrors the reference's `num_step + 1e-30` guard.
			samples[i - 1] = parent + ray / (steps + 1e-30f) * i;
		}
		return samples;
	}

	/// <summary>
	/// increase_cost_for_outside_bone: overwrite the cost of pairs whose connecting
	/// segment leaves the mesh (2 × outside sample count, applied only when more than
	/// one sample is outside), then halve costs between two middle joints.
	/// The matrix is modified in place, symmetrically.
	/// </summary>
	/// <param name="cost">Square cost matrix indexed like <paramref name="joints"/>.</param>
	/// <param name="joints">Joint positions.</param>
	/// <param name="isInside">Voxel-occupancy test for a sample point.</param>
	public static void IncreaseCostForOutsideBone(
		double[,] cost, Vector3[] joints, Func<Vector3, bool> isInside )
	{
		for ( var i = 0; i < joints.Length; i++ )
		{
			for ( var j = i + 1; j < joints.Length; j++ )
			{
				var outside = SampleOnBone( joints[i], joints[j] ).Count( s => !isInside( s ) );
				// A single outside sample is tolerated.
				if ( outside > 1 )
				{
					cost[i, j] = 2.0 * outside;
					cost[j, i] = 2.0 * outside;
				}
				if ( MathF.Abs( joints[i].X ) < SymmetryTolerance
					&& MathF.Abs( joints[j].X ) < SymmetryTolerance )
				{
					cost[i, j] *= 0.5;
					cost[j, i] *= 0.5;
				}
			}
		}
	}

	/// <summary>
	/// primMST_symmetry: Prim's MST over the cost matrix, mirroring every left/right
	/// attachment to the opposite side as it happens (heuristic, like the original).
	/// The root moves to the nearest middle joint if needed. Only costs &gt; 0 are
	/// treated as edges.
	/// </summary>
	/// <param name="cost">Square cost matrix indexed like <paramref name="joints"/>.</param>
	/// <param name="initId">Requested root joint index.</param>
	/// <param name="joints">Joint positions (ideally the output of <see cref="Flip"/>).</param>
	/// <returns>Parent index per joint (-1 for the root) and the final root id.
	/// An empty joint set yields an empty parent array and the unchanged root id.</returns>
	public static (int[] Parents, int RootId) PrimMstSymmetry(
		double[,] cost, int initId, Vector3[] joints )
	{
		var n = joints.Length;
		if ( n == 0 )
		{
			// Nothing to connect.
			return (Array.Empty<int>(), initId);
		}

		var left = Enumerable.Range( 0, n ).Where( i => joints[i].X < -SymmetryTolerance ).ToArray();
		var middle = Enumerable.Range( 0, n )
			.Where( i => MathF.Abs( joints[i].X ) <= SymmetryTolerance ).ToArray();
		var right = Enumerable.Range( 0, n ).Where( i => joints[i].X > SymmetryTolerance ).ToArray();

		// i-th left pairs with i-th right (flip() guarantees matching order/count).
		// Only left/right joints become keys, so a successful lookup also implies the
		// joint is off the middle plane.
		var mirror = new Dictionary<int, int>();
		for ( var i = 0; i < Math.Min( left.Length, right.Length ); i++ )
		{
			mirror[left[i]] = right[i];
			mirror[right[i]] = left[i];
		}

		// Re-root on the nearest middle joint (first one wins on ties).
		if ( middle.Length > 0 && !middle.Contains( initId ) )
		{
			var nearest = middle[0];
			foreach ( var m in middle )
			{
				if ( Vector3.Distance( joints[m], joints[initId] )
					< Vector3.Distance( joints[nearest], joints[initId] ) )
				{
					nearest = m;
				}
			}
			initId = nearest;
		}

		var key = new double[n];
		Array.Fill( key, double.MaxValue );
		var parent = new int[n];
		Array.Fill( parent, int.MinValue ); // "None" — distinct from -1 (root)
		var inTree = new bool[n];
		key[initId] = 0;
		parent[initId] = -1;

		int MinKey()
		{
			var best = -1;
			var bestKey = double.MaxValue;
			for ( var v = 0; v < n; v++ )
			{
				if ( !inTree[v] && key[v] < bestKey )
				{
					bestKey = key[v];
					best = v;
				}
			}
			return best;
		}

		// Number of joints not yet in the tree; each iteration adds u and possibly u2.
		var remaining = n;
		while ( remaining > 0 )
		{
			var u = MinKey();
			if ( u < 0 )
			{
				// No tree joint has a positive cost to any remaining joint — the reference
				// would crash here; attach the first such joint to the root so the
				// skeleton stays a tree.
				u = Array.IndexOf( inTree, false );
				key[u] = 0;
				parent[u] = initId;
			}

			var u2 = -1;
			if ( parent[u] >= 0 && mirror.TryGetValue( u, out var mirrored ) && !inTree[mirrored] )
			{
				// Mirror the attachment: the parent maps through the mirror unless it
				// has none (a middle joint is its own mirror image).
				var mirroredParent = mirror.TryGetValue( parent[u], out var mp ) ? mp : parent[u];

				// Only hang u2 off a joint already in the tree. When the root is not a
				// middle joint (no middle joints exist), the root's mirror is still outside
				// the tree; attaching u2 to it would let that joint later attach beneath
				// u2 and close a cycle. In that case u2 is left to normal Prim selection.
				if ( inTree[mirroredParent] )
				{
					u2 = mirrored;
					inTree[u2] = true;
					parent[u2] = mirroredParent;
					key[u2] = cost[u2, mirroredParent];
					remaining--;
				}
			}

			inTree[u] = true;
			remaining--;

			for ( var v = 0; v < n; v++ )
			{
				if ( cost[u, v] > 0 && !inTree[v] && key[v] > cost[u, v] )
				{
					key[v] = cost[u, v];
					parent[v] = u;
				}
				if ( u2 >= 0 && cost[u2, v] > 0 && !inTree[v] && key[v] > cost[u2, v] )
				{
					key[v] = cost[u2, v];
					parent[v] = u2;
				}
			}
		}
		return (parent, initId);
	}

	/// <summary>
	/// predict_skeleton's cost construction: symmetric pair probability matrix →
	/// -log(p + 1e-10) connectivity cost. Pairs not listed get p = 0
	/// (cost = -log(1e-10) ≈ 23.03).
	/// </summary>
	/// <remarks>
	/// A probability of exactly 1 yields a slightly negative cost, which
	/// <see cref="PrimMstSymmetry"/> does not treat as an edge (it only follows costs &gt; 0).
	/// </remarks>
	/// <param name="jointCount">Matrix dimension.</param>
	/// <param name="pairs">Joint index pairs; <paramref name="probabilities"/>[p] belongs to pairs[p].</param>
	/// <param name="probabilities">Connection probability per pair.</param>
	public static double[,] BuildCostMatrix( int jointCount, (int A, int B)[] pairs, float[] probabilities )
	{
		var prob = new double[jointCount, jointCount];
		for ( var p = 0; p < pairs.Length; p++ )
		{
			prob[pairs[p].A, pairs[p].B] = probabilities[p];
			prob[pairs[p].B, pairs[p].A] = probabilities[p];
		}
		var cost = new double[jointCount, jointCount];
		for ( var i = 0; i < jointCount; i++ )
		{
			for ( var j = 0; j < jointCount; j++ )
			{
				cost[i, j] = -Math.Log( prob[i, j] + 1e-10 );
			}
		}
		return cost;
	}
}