Code/AutoRig/Dl/RigNet/MstSkeleton.cs

Utility class for building and processing a symmetric skeleton graph for RigNet. It provides functions to symmetrize the joint set across the YZ plane (keeping the left side and mirroring it onto the right), sample points along bones, adjust pairwise costs based on voxel occupancy and middle-joint weighting, construct a cost matrix from pair probabilities, and run a symmetry-aware Prim's minimum spanning tree to produce parent indices and a root.

Native Interop
namespace AutoRig.Dl.RigNet;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// RigNet skeleton assembly, transcribed from utils/mst_utils.py: x-symmetrization
/// of the joint set (flip), pairwise cost shaping, and the symmetry-aware Prim MST.
/// Costs are double precision to match the numpy reference exactly.
/// </summary>
public static class MstSkeleton
{
    /// <summary>Joints within this distance of the YZ plane count as "middle".</summary>
    public const float SymmetryTolerance = 2e-2f;

    /// <summary>
    /// flip(): keeps the left joints (x &lt; -tol) and the middle joints (|x| &lt;= tol,
    /// with x snapped to 0), discards the original right joints, and rebuilds the right
    /// side as a mirror image of the left. The result is ordered left / middle / right so
    /// that the i-th left joint mirrors the i-th right joint — the pairing
    /// <see cref="PrimMstSymmetry"/> assumes.
    /// </summary>
    /// <param name="joints">Joint positions to symmetrize (not modified).</param>
    /// <returns>A new array holding left, middle, then mirrored-left joints.</returns>
    public static Vector3[] Flip( Vector3[] joints )
    {
        var left = joints.Where( j => j.X < -SymmetryTolerance ).ToArray();
        var middle = joints.Where( j => MathF.Abs( j.X ) <= SymmetryTolerance )
            .Select( j => new Vector3( 0f, j.Y, j.Z ) ).ToArray();
        var right = left.Select( j => new Vector3( -j.X, j.Y, j.Z ) ).ToArray();
        return left.Concat( middle ).Concat( right ).ToArray();
    }

    /// <summary>
    /// sample_on_bone: round(length / 0.01) evenly spaced points from
    /// <paramref name="parent"/> toward <paramref name="child"/>, so roughly 0.01 apart.
    /// The parent endpoint is excluded, but the last sample lands on the child endpoint
    /// (up to float rounding). Segments shorter than about 0.005 yield no samples.
    /// </summary>
    /// <param name="parent">Segment start (not sampled).</param>
    /// <param name="child">Segment end (sampled as the last point).</param>
    /// <returns>The sample points, ordered from parent to child.</returns>
    public static Vector3[] SampleOnBone( Vector3 parent, Vector3 child )
    {
        var ray = child - parent;
        // MathF.Round defaults to round-half-to-even, like np.round.
        var steps = (int)MathF.Round( ray.Length() / 0.01f );
        // Defensive: keeps the array length non-negative if the cast misbehaves (e.g. NaN length).
        var samples = new Vector3[Math.Max( steps, 0 )];
        var unitStep = ray / (steps + 1e-30f);
        for ( var i = 1; i <= steps; i++ )
            samples[i - 1] = parent + unitStep * i;
        return samples;
    }

    /// <summary>
    /// increase_cost_for_outside_bone: for every joint pair, samples the connecting
    /// segment (<see cref="SampleOnBone"/>) and, when more than one sample lies outside
    /// the mesh, overwrites the pair's cost with 2 × that outside-sample count. Then the
    /// cost between two middle joints is halved, whether or not it was overwritten.
    /// </summary>
    /// <param name="cost">Pairwise cost matrix, updated in place at both [i, j] and [j, i]
    /// (only i &lt; j pairs are evaluated, so it is expected to be symmetric).</param>
    /// <param name="joints">Joint positions indexing the rows/columns of <paramref name="cost"/>.</param>
    /// <param name="isInside">Voxel-occupancy test: true when a point lies inside the mesh.</param>
    public static void IncreaseCostForOutsideBone(
        double[,] cost, Vector3[] joints, Func<Vector3, bool> isInside )
    {
        for ( var i = 0; i < joints.Length; i++ )
            for ( var j = i + 1; j < joints.Length; j++ )
            {
                var outside = SampleOnBone( joints[i], joints[j] ).Count( s => !isInside( s ) );
                if ( outside > 1 )
                {
                    cost[i, j] = 2.0 * outside;
                    cost[j, i] = 2.0 * outside;
                }
                // Strict < here vs <= in Flip/PrimMstSymmetry; identical for Flip output,
                // whose middle joints have x == 0 exactly.
                if ( MathF.Abs( joints[i].X ) < SymmetryTolerance
                    && MathF.Abs( joints[j].X ) < SymmetryTolerance )
                {
                    cost[i, j] *= 0.5;
                    cost[j, i] *= 0.5;
                }
            }
    }

    /// <summary>
    /// primMST_symmetry: Prim's MST over the cost matrix, mirroring every left/right
    /// attachment to the opposite side as it happens (heuristic, like the original — the
    /// result is not guaranteed to be symmetric). Only strictly positive costs count as
    /// edges. If <paramref name="initId"/> is not a middle joint, the root moves to the
    /// nearest middle joint (when one exists).
    /// </summary>
    /// <param name="cost">Pairwise cost matrix, at least n × n for n joints.</param>
    /// <param name="initId">Preferred root joint index.</param>
    /// <param name="joints">Joint positions; left/right joints are paired by order
    /// (i-th left with i-th right), as produced by <see cref="Flip"/>.</param>
    /// <returns>Parent index per joint (-1 for the root) and the final root id.</returns>
    /// <exception cref="ArgumentException"><paramref name="cost"/> is smaller than n × n.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="initId"/> is not a valid joint index.</exception>
    public static (int[] Parents, int RootId) PrimMstSymmetry(
        double[,] cost, int initId, Vector3[] joints )
    {
        ArgumentNullException.ThrowIfNull( cost );
        ArgumentNullException.ThrowIfNull( joints );
        var n = joints.Length;
        if ( cost.GetLength( 0 ) < n || cost.GetLength( 1 ) < n )
            throw new ArgumentException( $"Cost matrix must be at least {n}x{n}.", nameof( cost ) );
        if ( initId < 0 || initId >= n )
            throw new ArgumentOutOfRangeException( nameof( initId ), initId, "Root id must index into joints." );

        var left = Enumerable.Range( 0, n ).Where( i => joints[i].X < -SymmetryTolerance ).ToArray();
        var middle = Enumerable.Range( 0, n )
            .Where( i => MathF.Abs( joints[i].X ) <= SymmetryTolerance ).ToArray();
        var right = Enumerable.Range( 0, n ).Where( i => joints[i].X > SymmetryTolerance ).ToArray();

        // i-th left pairs with i-th right (Flip() guarantees matching order/count). Only
        // left/right joints are keys, so a middle joint acts as its own mirror image.
        var mirror = new Dictionary<int, int>();
        for ( var i = 0; i < Math.Min( left.Length, right.Length ); i++ )
        {
            mirror[left[i]] = right[i];
            mirror[right[i]] = left[i];
        }

        // Move the root onto the nearest middle joint; on ties the first one wins.
        if ( !middle.Contains( initId ) && middle.Length > 0 )
        {
            var origin = joints[initId];
            var nearest = middle[0];
            var nearestDistance = Vector3.Distance( joints[nearest], origin );
            foreach ( var m in middle )
            {
                var distance = Vector3.Distance( joints[m], origin );
                if ( distance < nearestDistance )
                {
                    nearest = m;
                    nearestDistance = distance;
                }
            }
            initId = nearest;
        }

        var key = new double[n];
        Array.Fill( key, double.MaxValue );
        var parent = new int[n];
        Array.Fill( parent, int.MinValue );   // "None" — distinct from -1 (root)
        var inTree = new bool[n];
        key[initId] = 0;
        parent[initId] = -1;

        // Lowest-key joint not yet in the tree, or -1 when every remaining key is still
        // double.MaxValue (no positive-cost edge reaches any of them).
        int MinKey()
        {
            var best = -1;
            var bestKey = double.MaxValue;
            for ( var v = 0; v < n; v++ )
                if ( !inTree[v] && key[v] < bestKey )
                {
                    bestKey = key[v];
                    best = v;
                }
            return best;
        }

        while ( inTree.Contains( false ) )
        {
            var u = MinKey();
            if ( u < 0 )
            {
                // No remaining joint is reachable through a positive cost — the reference
                // would crash here; attach the first one to the root so the skeleton
                // stays a tree.
                u = Array.IndexOf( inTree, false );
                key[u] = 0;
                parent[u] = initId;
            }

            // Mirror the attachment of a side joint whose mirror is still detached: the
            // parent maps through the mirror unless it is a middle joint (its own image).
            var u2 = -1;
            if ( parent[u] >= 0 && mirror.TryGetValue( u, out var mirrored ) && !inTree[mirrored] )
            {
                var mirroredParent = mirror.TryGetValue( parent[u], out var mp ) ? mp : parent[u];
                // Only mirror onto a parent that is already attached. Otherwise (e.g. the
                // root is a side joint because no middle joint exists, so the root's mirror
                // is never attached up front) the mirrored joint would hang off a detached
                // joint that the relaxation below could point back at it — a cycle.
                // Skipping leaves the mirrored joint to plain Prim.
                if ( inTree[mirroredParent] )
                {
                    u2 = mirrored;
                    inTree[u2] = true;
                    parent[u2] = mirroredParent;
                    key[u2] = cost[u2, mirroredParent];
                }
            }

            inTree[u] = true;
            for ( var v = 0; v < n; v++ )
            {
                if ( cost[u, v] > 0 && !inTree[v] && key[v] > cost[u, v] )
                {
                    key[v] = cost[u, v];
                    parent[v] = u;
                }
                if ( u2 >= 0 && cost[u2, v] > 0 && !inTree[v] && key[v] > cost[u2, v] )
                {
                    key[v] = cost[u2, v];
                    parent[v] = u2;
                }
            }
        }
        return (parent, initId);
    }

    /// <summary>
    /// predict_skeleton's cost construction: symmetric pair probability matrix →
    /// -log(p + 1e-10) connectivity cost. Unlisted pairs (including the diagonal) have
    /// p = 0, i.e. cost -log(1e-10) ≈ 23.03; a pair listed more than once keeps its last
    /// probability.
    /// NOTE(review): p == 1 yields a slightly negative cost, which
    /// <see cref="PrimMstSymmetry"/> treats as "no edge" (it only follows costs &gt; 0).
    /// Kept as transcribed — confirm whether that is intended.
    /// </summary>
    /// <param name="jointCount">Number of joints (matrix size).</param>
    /// <param name="pairs">Joint index pairs; each is written to both [A, B] and [B, A].</param>
    /// <param name="probabilities">Connection probability per pair, in the order of <paramref name="pairs"/>.</param>
    /// <returns>A jointCount × jointCount cost matrix.</returns>
    /// <exception cref="ArgumentException"><paramref name="probabilities"/> has fewer entries than <paramref name="pairs"/>.</exception>
    public static double[,] BuildCostMatrix( int jointCount, (int A, int B)[] pairs, float[] probabilities )
    {
        ArgumentNullException.ThrowIfNull( pairs );
        ArgumentNullException.ThrowIfNull( probabilities );
        if ( probabilities.Length < pairs.Length )
            throw new ArgumentException(
                $"Expected {pairs.Length} probabilities, got {probabilities.Length}.", nameof( probabilities ) );

        // Cost of an unlisted pair (probability 0).
        var unconnected = -Math.Log( 1e-10 );
        var cost = new double[jointCount, jointCount];
        for ( var i = 0; i < jointCount; i++ )
            for ( var j = 0; j < jointCount; j++ )
                cost[i, j] = unconnected;

        for ( var p = 0; p < pairs.Length; p++ )
        {
            var (a, b) = pairs[p];
            var pairCost = -Math.Log( probabilities[p] + 1e-10 );
            cost[a, b] = pairCost;
            cost[b, a] = pairCost;
        }
        return cost;
    }
}