Code/AutoRig/Solve/MagicArticulateSolver.cs

Solver that runs a MagicArticulate deep model to predict a skeleton from an input mesh. It encodes a point cloud, performs greedy autoregressive token decoding, detokenizes to joints/bones, builds a rooted RigSkeleton, and computes skin weights via a voxel-based OrganicSkinner.

Native Interop
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.MagicArticulate;
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;
using AutoRig.Rig;
using AutoRig.Solve.Organic;
using AutoRig.Voxel;

namespace AutoRig.Solve;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// MagicArticulate inference (golden-verified vs PyTorch): ShapeVAE point-cloud
/// conditioning, GREEDY autoregressive bone decode (the reference's default is
/// also greedy, num_beams=1), detokenize + merge + connectivity repair back to
/// a rooted tree in mesh space. Skin weights come from our geodesic voxel
/// skinner (MagicArticulate predicts skeletons only).
/// </summary>
public static class MagicArticulateSolver
{
    /// <summary>generate_length = max_length 859 − cond_length 257.</summary>
    public const int MaxTokens = 602;

    public static Action<string> Progress { get; set; }

    /// <summary>
    /// Predicts a skeleton for the mesh in <paramref name="analysis"/> with
    /// <paramref name="model"/> and skins the mesh to that skeleton.
    /// </summary>
    /// <remarks>
    /// Steps: prepare the model input, encode the conditioning prefix, greedily
    /// decode at most <see cref="MaxTokens"/> tokens, detokenize them into
    /// joints and bones, map the joints back to mesh space, build a rooted
    /// skeleton and compute weights with <see cref="OrganicSkinner"/>.
    /// Each stage is reported through <see cref="Progress"/> when it is set.
    /// </remarks>
    /// <param name="analysis">Analysis result whose <c>Mesh</c> is rigged.</param>
    /// <param name="model">Model providing the encoder, decoder and embedding tables.</param>
    /// <returns>The skeleton, its skin weights and a short explanation.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="analysis"/> or <paramref name="model"/> is null.</exception>
    /// <exception cref="FormatException">
    /// The decoded tokens describe fewer than two joints.</exception>
    public static RigResult Rig( AnalysisResult analysis, MagicArticulateModel model )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( model );
        var mesh = analysis.Mesh;

        Progress?.Invoke( "input: sampling 8192-point cloud" );
        var prepared = MagicArticulateInput.Prepare( mesh );
        Progress?.Invoke( "shapevae: encoding prefix" );
        var prefix = model.EncodePrefix( prepared.Points, prepared.Normals );
        // Every prefix row gets row 0 of the condition table added to it.
        var cond = prefix.Add(
            model.CondTable.Slice( 0, 0, 1 ).RepeatRows( prefix.Shape[0] ) );

        // ---- greedy decode (reference default; BOS is PREDICTED, not fed) ----
        var ids = new List<int>( MaxTokens );
        var cache = new OptKvCache( model.Decoder.Layers.Length );
        var logits = model.Decoder.Forward( cond, cache, positionStart: 0 );

        while ( ids.Count < MaxTokens )
        {
            // Argmax over the last logits row; ties keep the lowest token id.
            var lastRow = (logits.Shape[0] - 1) * logits.Shape[1];
            var best = 0;
            for ( var id = 1; id < logits.Shape[1]; id++ )
                if ( logits.Data[lastRow + id] > logits.Data[lastRow + best] )
                    best = id;
            ids.Add( best );
            if ( best == MagicArticulateModel.TokenEos )
                break;
            if ( ids.Count % 64 == 0 )
                Progress?.Invoke( $"decode: {ids.Count} tokens" );

            // Budget exhausted: the logits of another step would never be
            // read, so skip the (expensive) decoder pass entirely.
            if ( ids.Count >= MaxTokens )
                break;

            var k = ids.Count;   // 1-based index of the token just added
            var embed = model.Decoder.EmbedTokens( new[] { best } )
                .Add( TransformerOps.Embed( model.BoneTable,
                    new[] { MagicArticulateModel.BonePositionId( k, best ) } ) )
                .Add( model.CondTable.Slice( 0, 1, 1 ) );
            logits = model.Decoder.Forward( embed, cache, positionStart: cache.Length );
        }

        // generate() strips the predicted BOS and the trailing EOS: drop a
        // leading BOS if present, then keep everything before the first EOS.
        var leadingBos = ids.Count > 0 && ids[0] == MagicArticulateModel.TokenBos ? 1 : 0;
        var body = ids
            .Skip( leadingBos )
            .TakeWhile( id => id != MagicArticulateModel.TokenEos )
            .ToList();
        Progress?.Invoke( $"decode: {ids.Count} tokens → {body.Count / 6} bones" );

        var decoded = MagicArticulateSkeleton.Detokenize( body );
        if ( decoded.Joints.Length < 2 )
            throw new FormatException(
                $"MagicArticulate produced {decoded.Joints.Length} joint(s) - too few for a rig." );

        // ---- back to mesh space ----
        // Applies PcScale/PcCenter first, then divides by Scale and adds Center,
        // all taken from the prepared input.
        var joints = decoded.Joints
            .Select( p => ((p * prepared.PcScale + prepared.PcCenter) / prepared.Scale)
                + prepared.Center )
            .ToArray();

        var skeleton = BuildTree( joints, decoded.Bones, decoded.Root );
        skeleton.Validate();

        Progress?.Invoke( $"skinning: {skeleton.Joints.Count} joints" );
        // Meshes under 50,000 triangles get the finer 96 grid, larger ones 64.
        var resolution = mesh.TriangleCount < 50_000 ? 96 : 64;
        var grid = VoxelGrid.Build( mesh, resolution );
        var weights = OrganicSkinner.Skin( mesh, grid, skeleton );

        return new RigResult
        {
            Skeleton = skeleton,
            Weights = weights,
            SolverName = "deep-learning/magicarticulate",
            Degraded = false,
            Explanation = $"MagicArticulate predicted {skeleton.Joints.Count} joints "
                + "(greedy decode; weights from geodesic skinning).",
        };
    }

    /// <summary>
    /// Builds a rooted skeleton from <paramref name="joints"/> and
    /// <paramref name="bones"/> by breadth-first traversal starting at
    /// <paramref name="root"/>. Convenience overload that forwards to the
    /// four-argument <c>BuildTree</c> and discards its joint-index mapping.
    /// </summary>
    /// <param name="joints">Joint positions, indexed by joint id.</param>
    /// <param name="bones">Bones as (parent, child) joint-id pairs.</param>
    /// <param name="root">Joint id the traversal starts from.</param>
    /// <returns>The skeleton produced by the four-argument overload.</returns>
    internal static RigSkeleton BuildTree( Vector3[] joints, (int Parent, int Child)[] bones, int root )
    {
        return BuildTree( joints, bones, root, out _ );
    }

    /// <summary>
    /// Builds a rooted skeleton by breadth-first traversal from
    /// <paramref name="root"/> and exposes the original-joint → BFS-index
    /// order (Puppeteer's neural skin weights need it to remap bone ids).
    /// </summary>
    /// <remarks>
    /// Bones are traversed as undirected edges, so a joint's parent in the
    /// result is the joint it was first reached from, which need not be the
    /// bone's <c>Parent</c> end. Joints are emitted in BFS order: the root is
    /// index 0 and is named "root", every other joint is named after its
    /// original index. Joints that are not connected to the root are left out
    /// of the skeleton.
    /// </remarks>
    /// <param name="joints">Joint positions, indexed by joint id.</param>
    /// <param name="bones">Bones as (parent, child) joint-id pairs.</param>
    /// <param name="root">Joint id the traversal starts from.</param>
    /// <param name="order">Receives, for each original joint id, its index in
    /// the returned skeleton's joint list, or -1 when the joint was not
    /// reached from <paramref name="root"/>.</param>
    /// <returns>The skeleton with joints in BFS order.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="joints"/> or <paramref name="bones"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="root"/> is not a valid joint id.</exception>
    /// <exception cref="ArgumentException">
    /// A bone references a joint id outside <paramref name="joints"/>.</exception>
    internal static RigSkeleton BuildTree(
        Vector3[] joints, (int Parent, int Child)[] bones, int root, out int[] order )
    {
        ArgumentNullException.ThrowIfNull( joints );
        ArgumentNullException.ThrowIfNull( bones );
        if ( root < 0 || root >= joints.Length )
            throw new ArgumentOutOfRangeException( nameof( root ), root,
                $"Root must be a joint index in [0, {joints.Length})." );

        var adjacency = new List<int>[joints.Length];
        for ( var i = 0; i < joints.Length; i++ )
            adjacency[i] = new List<int>();
        foreach ( var (p, c) in bones )
        {
            if ( p < 0 || p >= joints.Length || c < 0 || c >= joints.Length )
                throw new ArgumentException(
                    $"Bone ({p}, {c}) references a joint outside [0, {joints.Length}).",
                    nameof( bones ) );
            adjacency[p].Add( c );
            adjacency[c].Add( p );
        }

        // -1 marks "not in the skeleton"; the default 0 would be
        // indistinguishable from the root's index.
        order = new int[joints.Length];
        Array.Fill( order, -1 );

        var skeleton = new RigSkeleton();
        var visited = new bool[joints.Length];
        var queue = new Queue<(int Joint, int Parent)>();
        queue.Enqueue( (root, -1) );
        visited[root] = true;
        var emitted = 0;
        while ( queue.Count > 0 )
        {
            var (j, parent) = queue.Dequeue();
            order[j] = emitted++;
            skeleton.Joints.Add( new RigJoint
            {
                Name = parent < 0 ? "root" : $"bone_{j}",
                // A parent is always dequeued before its children, so its
                // BFS index is already known here.
                Parent = parent < 0 ? -1 : order[parent],
                Position = joints[j],
            } );
            foreach ( var neighbor in adjacency[j] )
                if ( !visited[neighbor] )
                {
                    visited[neighbor] = true;
                    queue.Enqueue( (neighbor, j) );
                }
        }
        return skeleton;
    }
}