AutoRig/Rig/SkeletonCleanup.cs

Post-rig cleanup utility for skeletons and skin weights. It takes a mesh and a RigResult, fixes invalid parent indices, cuts cycles, collapses multiple roots into one, merges zero-length bones, reindexes joints parent-before-child, trims and deduplicates names, and remaps skin weights to the cleaned skeleton. It returns a validated RigResult; malformed skeleton or weight data is repaired rather than rejected (only null arguments throw).

File AccessNative Interop
using AutoRig.Mesh;

namespace AutoRig.Rig;

// s&box compat: Vector3 lives in the global namespace and shadows imports.
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Post-rig cleanup: REPAIRS any skeleton + skin into a valid, tidy rig so the
/// exported model imports cleanly everywhere. Apart from rejecting null
/// arguments it does not throw on malformed data, and it is intended to be
/// idempotent (a second pass should find nothing left to repair). Feed it a
/// broken armature (out-of-range or cyclic parents, several roots, zero-length
/// or duplicate bones, NaN positions, junk weights) and it returns one that
/// passes RigSkeleton.Validate() AND SkinWeights.Validate(). Works on ANY
/// RigResult regardless of which model (or donor) produced it.
///
/// What it fixes:
///  - parent indices: out-of-range / self refs → detached to a root; each cycle
///    is broken by detaching one joint on it (joints hanging below a cycle keep
///    their parent). Forward refs are legal and are handled by reindexing.
///  - roots: collapses to exactly one (extra roots reparent under the biggest tree).
///  - zero-length bones: a joint sitting on its parent is merged into it.
///  - ordering: reindexed strictly parent-before-child.
///  - names: blank names replaced, whitespace trimmed, duplicates suffixed _1, _2, ...
///  - positions / hinge axes: non-finite components replaced with 0.
///  - weights: remapped through the merge, non-finite/negative dropped, kept to
///    the top 4 influences per vertex, renormalized; empty rows bind to the root.
/// </summary>
public static class SkeletonCleanup
{
    /// <summary>
    /// Repairs <paramref name="rig"/> against <paramref name="mesh"/> and returns a new,
    /// validated rig. The input rig is not modified.
    /// </summary>
    /// <param name="mesh">Mesh the skin weights refer to; its vertex count sizes the output weight arrays.</param>
    /// <param name="rig">Rig to clean. A missing or empty skeleton yields a single-root rig.</param>
    /// <returns>
    /// A new RigResult with a single-rooted, parent-before-child skeleton and at most four
    /// normalized influences per vertex. SolverName and Degraded are carried over (the
    /// single-root fallback forces Degraded = true) and a cleanup note is appended to Explanation.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="mesh"/> or <paramref name="rig"/> is null.</exception>
    public static RigResult Clean( RigMesh mesh, RigResult rig )
    {
        ArgumentNullException.ThrowIfNull( mesh );
        ArgumentNullException.ThrowIfNull( rig );

        var vertexCount = mesh.Positions.Length;
        var src = rig.Skeleton?.Joints;

        // Degenerate input: no joints at all → one root at the mesh centre, all
        // vertices rigidly bound to it. Nothing downstream can choke on that.
        if ( src is null || src.Count == 0 )
            return SingleRootFallback( mesh, rig );

        var n = src.Count;
        var pos = new Vector3[n];
        var parent = new int[n];
        var name = new string[n];
        var hinge = new Vector3[n];
        for ( var i = 0; i < n; i++ )
        {
            pos[i] = Finite( src[i].Position );
            parent[i] = src[i].Parent;
            name[i] = string.IsNullOrWhiteSpace( src[i].Name ) ? $"bone_{i}" : src[i].Name;
            // Hinge axes get the same NaN/Inf scrub as positions.
            hinge[i] = Finite( src[i].HingeAxis );
        }

        // 1. Sanitize parent references (range, self-parent). Afterwards every
        //    parent is either -1 or a valid index other than the joint itself.
        for ( var i = 0; i < n; i++ )
            if ( parent[i] == i || parent[i] < -1 || parent[i] >= n )
                parent[i] = -1;

        // 2. Break cycles. A joint lies ON a cycle iff walking its parent chain
        //    returns to it; a cycle has at most n members, so n hops suffice.
        //    Detaching the first such joint breaks that cycle, after which the
        //    other members' chains terminate normally. Joints that merely hang
        //    below a cycle never return to themselves and keep their parent.
        for ( var i = 0; i < n; i++ )
        {
            var cur = parent[i];
            for ( var steps = 0; cur != -1 && steps < n; steps++ )
            {
                if ( cur == i )
                {
                    parent[i] = -1;
                    break;
                }
                cur = parent[cur];
            }
        }

        // 3. Exactly one root. Pick the root with the largest tree (first one
        //    wins on ties); reparent the others under it so everything forms one
        //    connected skeleton. Attaching a root to a different tree cannot
        //    create a cycle.
        var roots = new List<int>();
        for ( var i = 0; i < n; i++ )
            if ( parent[i] == -1 )
                roots.Add( i );
        if ( roots.Count == 0 )        // impossible after step 2, but stay safe
        {
            parent[0] = -1;
            roots.Add( 0 );
        }
        var primary = roots[0];
        if ( roots.Count > 1 )
        {
            // One pass over all joints counts the size of every tree at once.
            var treeSize = new int[n];
            for ( var i = 0; i < n; i++ )
                treeSize[RootOf( i, parent )]++;
            foreach ( var r in roots )
                if ( treeSize[r] > treeSize[primary] )
                    primary = r;
            foreach ( var r in roots )
                if ( r != primary )
                    parent[r] = primary;
        }

        // 4. Merge zero-length bones (a joint coincident with its parent). Walk
        //    parents-before-children so a merged joint's children resolve upward.
        var order = TopoOrder( primary, parent, n );
        var mergedInto = new int[n];
        Array.Fill( mergedInto, -1 );
        var diagonal = Diagonal( pos );
        // Huge (but finite) coordinates can overflow the diagonal to +Inf; an
        // infinite eps would collapse every joint into the root, so use the floor.
        var eps = float.IsFinite( diagonal ) ? MathF.Max( diagonal * 1e-3f, 1e-6f ) : 1e-6f;
        foreach ( var i in order )
        {
            if ( i == primary )
                continue;
            var p = Resolve( parent[i], mergedInto );
            if ( Vector3.Distance( pos[i], pos[p] ) < eps )
                mergedInto[i] = p;
        }

        // 5. Reindex survivors (still parents-before-children thanks to `order`).
        var newIndex = new int[n];
        Array.Fill( newIndex, -1 );
        var survivors = new List<int>();
        foreach ( var i in order )
            if ( mergedInto[i] == -1 )
            {
                newIndex[i] = survivors.Count;
                survivors.Add( i );
            }

        var usedNames = new HashSet<string>( StringComparer.Ordinal );
        var skeleton = new RigSkeleton();
        foreach ( var s in survivors )
        {
            skeleton.Joints.Add( new RigJoint
            {
                Name = UniqueName( name[s], usedNames ),
                // A survivor's resolved parent is itself a survivor, so newIndex is valid.
                Parent = s == primary ? -1 : newIndex[Resolve( parent[s], mergedInto )],
                Position = pos[s],
                HingeAxis = hinge[s],
            } );
        }

        var weights = RemapWeights( rig.Weights, vertexCount, n, mergedInto, newIndex, newIndex[primary] );

        skeleton.Validate();
        weights.Validate( mesh, skeleton );

        return new RigResult
        {
            Skeleton = skeleton,
            Weights = weights,
            SolverName = rig.SolverName,
            Degraded = rig.Degraded,
            Explanation = Append( rig.Explanation, n, skeleton.Joints.Count ),
        };
    }

    // ------------------------------------------------------------------ helpers

    /// <summary>
    /// Builds a one-joint rig ("root" at the mesh bounds centre) with every vertex fully
    /// bound to it. Used when the input has no usable skeleton; always marks the result Degraded.
    /// </summary>
    static RigResult SingleRootFallback( RigMesh mesh, RigResult rig )
    {
        // Scrub the centre so non-finite vertex data cannot produce a NaN root.
        var center = Finite( mesh.ComputeBounds().Center );
        var skeleton = new RigSkeleton();
        skeleton.Joints.Add( new RigJoint { Name = "root", Parent = -1, Position = center } );
        var count = mesh.Positions.Length;
        var indices = new int[count * 4];
        var weights = new float[count * 4];
        for ( var v = 0; v < count; v++ )
            weights[v * 4] = 1f;   // all influence on the single root (index 0)
        var skin = new SkinWeights { BoneIndices = indices, Weights = weights };
        skeleton.Validate();
        skin.Validate( mesh, skeleton );
        return new RigResult
        {
            Skeleton = skeleton,
            Weights = skin,
            SolverName = rig.SolverName,
            Degraded = true,
            // Joints may be null here (that is one way to reach this fallback).
            Explanation = Append( rig.Explanation, rig.Skeleton?.Joints?.Count ?? 0, 1 ),
        };
    }

    /// <summary>
    /// Rebuilds 4-slot skin weights for the cleaned skeleton. Each source influence is
    /// routed through the merge map to its surviving joint's new index. Out-of-range bones
    /// go to the root. Non-finite or non-positive weights are dropped. The heaviest four
    /// influences are kept and renormalized to sum to 1. Rows with nothing left bind fully
    /// to the root. If the source arrays are missing or not sized vertexCount * 4, every
    /// vertex binds to the root.
    /// </summary>
    static SkinWeights RemapWeights(
        SkinWeights src, int vertexCount, int oldJointCount,
        int[] mergedInto, int[] newIndex, int rootNew )
    {
        var indices = new int[vertexCount * 4];
        var weights = new float[vertexCount * 4];
        var hasSrc = src?.BoneIndices is { } bi && src.Weights is { } wv
            && bi.Length == vertexCount * 4 && wv.Length == vertexCount * 4;

        // Scratch buffers reused across vertices. Sums are kept in double so that
        // several huge-but-finite float weights cannot overflow to +Inf (which would
        // turn into NaN when normalizing).
        var acc = new Dictionary<int, double>();
        var top = new List<KeyValuePair<int, double>>( 4 );

        for ( var v = 0; v < vertexCount; v++ )
        {
            acc.Clear();
            if ( hasSrc )
            {
                for ( var k = 0; k < 4; k++ )
                {
                    var w = src.Weights[v * 4 + k];
                    if ( !float.IsFinite( w ) || w <= 0f )
                        continue;
                    var ob = src.BoneIndices[v * 4 + k];
                    var nb = ob < 0 || ob >= oldJointCount
                        ? rootNew
                        : newIndex[Resolve( ob, mergedInto )];
                    // Several old bones can collapse onto one survivor: sum them.
                    acc.TryGetValue( nb, out var cur );
                    acc[nb] = cur + w;
                }
            }
            if ( acc.Count == 0 )
                acc[rootNew] = 1.0;

            // Top 4 influences, renormalized to sum 1.
            top.Clear();
            top.AddRange( acc );
            top.Sort( ByWeightDescending );
            var take = Math.Min( 4, top.Count );
            var total = 0.0;
            for ( var k = 0; k < take; k++ )
                total += top[k].Value;
            if ( !( total > 0.0 ) ) { indices[v * 4] = rootNew; weights[v * 4] = 1f; continue; }
            for ( var k = 0; k < take; k++ )
            {
                indices[v * 4 + k] = top[k].Key;
                weights[v * 4 + k] = (float)( top[k].Value / total );
            }
        }
        return new SkinWeights { BoneIndices = indices, Weights = weights };
    }

    /// <summary>
    /// Orders influences heaviest first. Ties are broken by bone index, because
    /// List.Sort is unstable and the output should not depend on sort internals.
    /// </summary>
    static int ByWeightDescending( KeyValuePair<int, double> a, KeyValuePair<int, double> b )
    {
        var byWeight = b.Value.CompareTo( a.Value );
        return byWeight != 0 ? byWeight : a.Key.CompareTo( b.Key );
    }

    /// <summary>
    /// Follows the merge map from joint <paramref name="i"/> to the joint it was
    /// (transitively) merged into; negative inputs are returned unchanged. The walk is
    /// bounded by the array length.
    /// </summary>
    static int Resolve( int i, int[] mergedInto )
    {
        if ( i < 0 )
            return i;
        var guard = 0;
        while ( mergedInto[i] != -1 && guard++ < mergedInto.Length )
            i = mergedInto[i];
        return i;
    }

    /// <summary>
    /// Returns the top of joint <paramref name="i"/>'s parent chain (the last joint
    /// before a -1 parent). The walk is bounded by the joint count. Assumes every
    /// entry of <paramref name="parent"/> is -1 or a valid index.
    /// </summary>
    static int RootOf( int i, int[] parent )
    {
        for ( var steps = 0; parent[i] != -1 && steps < parent.Length; steps++ )
            i = parent[i];
        return i;
    }

    /// <summary>
    /// Breadth-first order starting at <paramref name="root"/>, so every joint appears
    /// after its parent. Joints not reachable from the root are appended at the end.
    /// </summary>
    static List<int> TopoOrder( int root, int[] parent, int n )
    {
        var children = new List<int>[n];
        for ( var i = 0; i < n; i++ )
            children[i] = new List<int>();
        for ( var i = 0; i < n; i++ )
            if ( parent[i] != -1 && parent[i] != i )
                children[parent[i]].Add( i );

        var order = new List<int>();
        var queue = new Queue<int>();
        queue.Enqueue( root );
        var seen = new bool[n];
        seen[root] = true;
        while ( queue.Count > 0 )
        {
            var j = queue.Dequeue();
            order.Add( j );
            foreach ( var c in children[j] )
                if ( !seen[c] )
                {
                    seen[c] = true;
                    queue.Enqueue( c );
                }
        }
        // Any node not reached (defensive) is appended as its own tail.
        for ( var i = 0; i < n; i++ )
            if ( !seen[i] )
                order.Add( i );
        return order;
    }

    /// <summary>Length of the diagonal of the axis-aligned bounds of <paramref name="pos"/>; 0 when empty.</summary>
    static float Diagonal( Vector3[] pos )
    {
        if ( pos.Length == 0 )
            return 0f;
        var min = pos[0];
        var max = pos[0];
        foreach ( var p in pos )
        {
            min = Vector3.Min( min, p );
            max = Vector3.Max( max, p );
        }
        return Vector3.Distance( min, max );
    }

    /// <summary>Replaces each NaN or infinite component of <paramref name="v"/> with 0.</summary>
    static Vector3 Finite( Vector3 v )
        => new(
            float.IsFinite( v.X ) ? v.X : 0f,
            float.IsFinite( v.Y ) ? v.Y : 0f,
            float.IsFinite( v.Z ) ? v.Z : 0f );

    /// <summary>
    /// Trims <paramref name="raw"/> (blank → "bone") and appends _1, _2, ... until the
    /// name is not yet in <paramref name="used"/>. Records the returned name in <paramref name="used"/>.
    /// </summary>
    static string UniqueName( string raw, HashSet<string> used )
    {
        var baseName = string.IsNullOrWhiteSpace( raw ) ? "bone" : raw.Trim();
        var candidate = baseName;
        var suffix = 1;
        while ( !used.Add( candidate ) )
            candidate = $"{baseName}_{suffix++}";
        return candidate;
    }

    /// <summary>Appends a short joint-count note to an existing explanation (or starts one).</summary>
    static string Append( string explanation, int before, int after )
    {
        var note = before == after
            ? $"Cleanup: {after} joints validated."
            : $"Cleanup: {before} → {after} joints (merged/repaired).";
        return string.IsNullOrEmpty( explanation ) ? note : $"{explanation} {note}";
    }
}