AutoRig/Dl/MagicArticulate/MagicArticulateSkeleton.cs

Utility for processing MagicArticulate skeleton data. Converts tokenized coordinate bins into directed bones and unique joint positions, merges near-duplicate joints, ensures connectivity, and compacts the result into a root, joint array, and bone index pairs.

Native Interop
namespace AutoRig.Dl.MagicArticulate;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// MagicArticulate's detokenize + post-processing (skeletongen.detokenize +
/// save_utils.pred_joints_and_bones / merge_duplicate_joints_and_fix_bones /
/// ensure_skeleton_connectivity, hier order): coordinate bins → bones →
/// unique joints (lexicographic, like np.unique) → tolerance merge with the
/// root preserved → connectivity repair. Bones are DIRECTED (parent, child);
/// root = first bone's parent.
/// </summary>
public static class MagicArticulateSkeleton
{
    /// <summary>Joints strictly closer than this (Euclidean) are merged into one.</summary>
    public const float MergeTolerance = 0.0025f;

    /// <summary>During connectivity repair the closest cross-component joint pair
    /// is merged when strictly closer than this; otherwise the two components are
    /// bridged with a new bone.</summary>
    const float ConnectMergeThreshold = MergeTolerance * 8;

    /// <summary>Post-processed skeleton: compacted joint positions, directed bones
    /// as (parent, child) indices into <see cref="Joints"/>, and the root index.</summary>
    public sealed class Decoded
    {
        public required Vector3[] Joints;
        public required (int Parent, int Child)[] Bones;
        public required int Root;
    }

    /// <summary>ids = generated stream with BOS/EOS already stripped;
    /// specials are dropped, the rest are bins (id − 3), groups of 6 = a bone.</summary>
    /// <param name="ids">Token ids; ids below 3 are treated as specials and ignored.</param>
    /// <returns>The post-processed skeleton (see <see cref="FromBonePairs"/>).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
    /// <exception cref="FormatException">No complete 6-bin bone, or only self-loop bones.</exception>
    public static Decoded Detokenize( IReadOnlyList<int> ids )
    {
        ArgumentNullException.ThrowIfNull( ids );

        var bins = ids.Where( id => id >= 3 ).Select( id => id - 3 ).ToList();
        var boneCount = bins.Count / 6;   // ref returns None on ragged length; we truncate
        if ( boneCount == 0 )
            throw new FormatException( "MagicArticulate produced no complete bones." );

        var pairs = new List<(Vector3 A, Vector3 B)>( boneCount );
        for ( var b = 0; b < boneCount; b++ )
            pairs.Add( (Point( bins, b * 6 ), Point( bins, b * 6 + 3 )) );
        return FromBonePairs( pairs );
    }

    /// <summary>Shared post-processing from raw (parent, child) coordinate pairs
    /// (Puppeteer's joint-token mode feeds this directly).</summary>
    /// <param name="pairs">Bones as (parent position, child position) pairs, in hier order.</param>
    /// <returns>Deduplicated, tolerance-merged, connected skeleton.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="pairs"/> is null.</exception>
    /// <exception cref="FormatException"><paramref name="pairs"/> is empty, or every bone is a self-loop.</exception>
    public static Decoded FromBonePairs( IReadOnlyList<(Vector3 A, Vector3 B)> pairs )
    {
        ArgumentNullException.ThrowIfNull( pairs );
        var boneCount = pairs.Count;
        // Public entry point: report an empty input accurately instead of as "only self-loops".
        if ( boneCount == 0 )
            throw new FormatException( "MagicArticulate produced no complete bones." );

        // pred_joints_and_bones: np.unique(axis=0) = LEXICOGRAPHIC row sort with
        // inverse indices; self-loop bones dropped.
        var all = new List<Vector3>();
        for ( var b = 0; b < boneCount; b++ )
            all.Add( pairs[b].A );
        for ( var b = 0; b < boneCount; b++ )
            all.Add( pairs[b].B );
        var unique = all.Distinct().OrderBy( p => p.X ).ThenBy( p => p.Y ).ThenBy( p => p.Z ).ToArray();
        var index = new Dictionary<Vector3, int>();
        for ( var i = 0; i < unique.Length; i++ )
            index[unique[i]] = i;

        // hier order: the first bone's parent is the root, even if that bone is a self-loop.
        var root = index[pairs[0].A];
        var bones = new List<(int, int)>();
        for ( var b = 0; b < boneCount; b++ )
        {
            var parent = index[pairs[b].A];
            var child = index[pairs[b].B];
            if ( parent != child )
                bones.Add( (parent, child) );
        }
        if ( bones.Count == 0 )
            throw new FormatException( "MagicArticulate produced only self-loop bones." );

        var (joints, merged, newRoot) = Merge( unique, bones, root );
        (joints, merged, newRoot) = Connect( joints, merged, newRoot, ConnectMergeThreshold );
        return new Decoded { Joints = joints, Bones = merged.ToArray(), Root = newRoot };
    }

    /// <summary>Builds one joint position from three consecutive bins starting at <paramref name="at"/>.</summary>
    static Vector3 Point( List<int> bins, int at ) => new(
        MagicArticulateModel.Undiscretize( bins[at] ),
        MagicArticulateModel.Undiscretize( bins[at + 1] ),
        MagicArticulateModel.Undiscretize( bins[at + 2] ) );

    /// <summary>merge_duplicate_joints_and_fix_bones: greedy first-wins grouping
    /// within tolerance (root becomes its group's representative), remap, drop
    /// self-loops + duplicate bones, compact to used joints.</summary>
    static (Vector3[], List<(int, int)>, int) Merge(
        Vector3[] joints, List<(int, int)> bones, int root )
    {
        var map = new int[joints.Length];
        var used = new bool[joints.Length];
        for ( var i = 0; i < joints.Length; i++ )
        {
            if ( used[i] )
                continue;
            // Group = i plus every still-unclaimed later joint within tolerance of joint i itself.
            var group = new List<int> { i };
            for ( var j = i + 1; j < joints.Length; j++ )
                if ( !used[j] && Vector3.Distance( joints[i], joints[j] ) < MergeTolerance )
                {
                    group.Add( j );
                    used[j] = true;
                }
            used[i] = true;
            // Keep the root as representative so its position survives the merge.
            var representative = group.Contains( root ) ? root : group[0];
            foreach ( var member in group )
                map[member] = representative;
        }

        return Compact( joints, RemapBones( bones, x => map[x] ), map[root] );
    }

    /// <summary>ensure_skeleton_connectivity: repeatedly join the two closest
    /// components — merge the closest joint pair under the threshold, otherwise
    /// bridge them with a bone.</summary>
    /// <remarks>Each iteration reduces the component count by one, so the loop terminates.</remarks>
    static (Vector3[], List<(int, int)>, int) Connect(
        Vector3[] joints, List<(int, int)> bones, int root, float threshold )
    {
        while ( true )
        {
            var components = Components( joints.Length, bones );
            if ( components.Count <= 1 )
                return (joints, bones, root);

            var (a, b, distance) = ClosestPair( joints, components );
            if ( distance < threshold )
            {
                // Merge b into a.
                // NOTE(review): when b is the root, the root moves to a's position, whereas
                // Merge keeps the root as representative — confirm this matches the reference.
                (joints, bones, root) = Compact(
                    joints, RemapBones( bones, x => x == b ? a : x ), root == b ? a : root );
            }
            else
            {
                // NOTE(review): the bridge is always oriented (a → b), a from the lower-numbered
                // component, regardless of which side holds the root — confirm consumers accept this.
                bones.Add( (a, b) );
            }
        }
    }

    /// <summary>Brute-force search for the closest joint pair lying in two different
    /// components; A comes from the lower-numbered component and ties keep the first
    /// pair found. The first candidate is always accepted, so a valid pair is returned
    /// even when every distance is non-finite.</summary>
    static (int A, int B, float Distance) ClosestPair( Vector3[] joints, List<List<int>> components )
    {
        var best = (A: -1, B: -1, Distance: float.MaxValue);
        for ( var i = 0; i < components.Count; i++ )
            for ( var j = i + 1; j < components.Count; j++ )
                foreach ( var a in components[i] )
                    foreach ( var b in components[j] )
                    {
                        var d = Vector3.Distance( joints[a], joints[b] );
                        if ( best.A < 0 || d < best.Distance )
                            best = (a, b, d);
                    }
        return best;
    }

    /// <summary>Applies <paramref name="map"/> to both endpoints of every bone, dropping
    /// bones that collapse to a self-loop and exact (parent, child) duplicates; the
    /// first occurrence wins and the original order is preserved.</summary>
    static List<(int, int)> RemapBones( List<(int, int)> bones, Func<int, int> map )
    {
        var remapped = new List<(int, int)>();
        var seen = new HashSet<(int, int)>();
        foreach ( var (p, c) in bones )
        {
            var bone = (map( p ), map( c ));
            if ( bone.Item1 != bone.Item2 && seen.Add( bone ) )
                remapped.Add( bone );
        }
        return remapped;
    }

    /// <summary>Keeps only joints referenced by some bone or by the root (in their
    /// original relative order) and renumbers the bones and root to match.</summary>
    static (Vector3[], List<(int, int)>, int) Compact(
        Vector3[] joints, List<(int, int)> bones, int root )
    {
        var usedIndices = new SortedSet<int> { root };
        foreach ( var (p, c) in bones )
        {
            usedIndices.Add( p );
            usedIndices.Add( c );
        }
        var oldToNew = new Dictionary<int, int>();
        var compactJoints = new Vector3[usedIndices.Count];
        var next = 0;
        foreach ( var old in usedIndices )
        {
            oldToNew[old] = next;
            compactJoints[next++] = joints[old];
        }
        var compactBones = bones.Select( b => (oldToNew[b.Item1], oldToNew[b.Item2]) ).ToList();
        return (compactJoints, compactBones, oldToNew[root]);
    }

    /// <summary>Connected components of the bone graph treated as undirected, found by
    /// BFS. Every joint in [0, jointCount) lands in exactly one component; components
    /// are ordered by their lowest joint index.</summary>
    static List<List<int>> Components( int jointCount, List<(int, int)> bones )
    {
        var adjacency = new List<int>[jointCount];
        for ( var i = 0; i < jointCount; i++ )
            adjacency[i] = new List<int>();
        foreach ( var (p, c) in bones )
        {
            adjacency[p].Add( c );
            adjacency[c].Add( p );
        }

        var visited = new bool[jointCount];
        var components = new List<List<int>>();
        for ( var i = 0; i < jointCount; i++ )
        {
            if ( visited[i] )
                continue;
            var component = new List<int>();
            var queue = new Queue<int>();
            queue.Enqueue( i );
            visited[i] = true;
            while ( queue.Count > 0 )
            {
                var node = queue.Dequeue();
                component.Add( node );
                foreach ( var neighbor in adjacency[node] )
                    if ( !visited[neighbor] )
                    {
                        visited[neighbor] = true;
                        queue.Enqueue( neighbor );
                    }
            }
            components.Add( component );
        }
        return components;
    }
}