AutoRig/Solve/PuppeteerSolver.cs

Solver that runs the Puppeteer deep model to produce a skeleton and skin weights from an input mesh. It prepares input conditioning, performs greedy token decoding with a transformer, detokenizes into joints/bones, builds and validates a skeleton, then computes skin weights either with a neural skinning model or a geodesic voxel-based fallback.

Networking, Native Interop
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.MagicArticulate;
using AutoRig.Dl.Nn;
using AutoRig.Dl.Puppeteer;
using AutoRig.Dl.UniRig;
using AutoRig.Rig;
using AutoRig.Solve.Organic;
using AutoRig.Voxel;

namespace AutoRig.Solve;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Puppeteer inference (golden-verified vs PyTorch): MagicArticulate's
/// conditioning and post-LN OPT decode plus the target-aware positional
/// embedding, GREEDY (the reference default), joint-token detokenize (parents
/// carried IN the sequence). Skin weights come from Puppeteer's neural
/// skinning transformer when the model provides one (<c>PuppeteerModel.Skin</c>
/// non-null) and from our geodesic voxel skinner otherwise.
/// </summary>
public static class PuppeteerSolver
{
    /// <summary>generate_length = max_length 663 − cond_length 257.</summary>
    public const int MaxTokens = 406;

    /// <summary>Optional sink for human-readable progress messages. When neural
    /// skinning runs it is also assigned to the skinning model and its PartField
    /// encoder.</summary>
    public static Action<string> Progress { get; set; }

    /// <summary>
    /// Rigs the mesh of <paramref name="analysis"/> with Puppeteer: prepares the
    /// point-cloud conditioning, greedily decodes joint tokens, detokenizes them
    /// into joints and bones, builds and validates the skeleton, then computes
    /// skin weights (neural when <c>model.Skin</c> is present, geodesic voxel
    /// skinning otherwise).
    /// </summary>
    /// <param name="analysis">Analysis result whose <c>Mesh</c> is rigged.</param>
    /// <param name="model">Loaded Puppeteer model; its <c>Skin</c> may be null.</param>
    /// <returns>The predicted skeleton, skin weights and a short explanation.</returns>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    /// <exception cref="FormatException">The decoded sequence yields fewer than two joints.</exception>
    public static RigResult Rig( AnalysisResult analysis, PuppeteerModel model )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( model );
        var mesh = analysis.Mesh;

        Progress?.Invoke( "input: sampling 8192-point cloud" );
        var prepared = MagicArticulateInput.Prepare( mesh );

        var ids = GreedyDecode( model, prepared );
        var body = StripSpecialTokens( ids );
        Progress?.Invoke(
            $"decode: {ids.Count} tokens → {body.Count / PuppeteerModel.BonePerToken} joints" );

        var decoded = PuppeteerSkeleton.Detokenize( body );
        if ( decoded.Joints.Length < 2 )
            throw new FormatException(
                $"Puppeteer produced {decoded.Joints.Length} joint(s) - too few for a rig." );

        // Invert the PcScale/PcCenter transform and then the Scale/Center
        // normalization to bring decoded joints back into mesh space
        // (NeuralSkin applies the forward direction to mesh vertices).
        var joints = decoded.Joints
            .Select( p => ((p * prepared.PcScale + prepared.PcCenter) / prepared.Scale)
                + prepared.Center )
            .ToArray();

        var skeleton = MagicArticulateSolver.BuildTree(
            joints, decoded.Bones, decoded.Root, out var order );
        skeleton.Validate();

        SkinWeights weights;
        if ( model.Skin is not null )
        {
            Progress?.Invoke( $"skinning: neural transformer, {skeleton.Joints.Count} joints" );
            weights = NeuralSkin( model.Skin, mesh, prepared, decoded, order );
        }
        else
        {
            Progress?.Invoke( $"skinning: geodesic, {skeleton.Joints.Count} joints" );
            var resolution = mesh.TriangleCount < 50_000 ? 96 : 64;
            var grid = VoxelGrid.Build( mesh, resolution );
            weights = OrganicSkinner.Skin( mesh, grid, skeleton );
        }

        return new RigResult
        {
            Skeleton = skeleton,
            Weights = weights,
            SolverName = "deep-learning/puppeteer",
            Degraded = false,
            Explanation = $"Puppeteer predicted {skeleton.Joints.Count} joints "
                + (model.Skin is not null
                    ? "with NEURAL skin weights (skinning transformer)."
                    : "(greedy decode; weights from geodesic skinning)."),
        };
    }

    /// <summary>
    /// Encodes the conditioning prefix and greedily decodes at most
    /// <see cref="MaxTokens"/> token ids. BOS is predicted by the decoder, not
    /// fed. Decoding stops right after EOS is emitted, so the returned list
    /// ends with EOS when the model produced one.
    /// </summary>
    static List<int> GreedyDecode( PuppeteerModel model, MagicArticulateInput.Prepared prepared )
    {
        var core = model.Core;

        Progress?.Invoke( "shapevae: encoding prefix" );
        var prefix = core.EncodePrefix( prepared.Points, prepared.Normals );
        // Every prefix row gets condition row 0 and target-position row 0.
        var cond = prefix
            .Add( core.CondTable.Slice( 0, 0, 1 ).RepeatRows( prefix.Shape[0] ) )
            .Add( model.TargetPosEmbed.Slice( 0, 0, 1 ).RepeatRows( prefix.Shape[0] ) );

        // ---- greedy decode (BOS predicted, not fed) ----
        var ids = new List<int>();
        var cache = new OptKvCache( core.Decoder.Layers.Length );
        var logits = core.Decoder.Forward( cond, cache, positionStart: 0 );

        while ( ids.Count < MaxTokens )
        {
            // Argmax over the vocabulary on the last logits row (first max wins).
            var vocab = logits.Shape[1];
            var lastRow = (logits.Shape[0] - 1) * vocab;
            var best = 0;
            for ( var id = 1; id < vocab; id++ )
                if ( logits.Data[lastRow + id] > logits.Data[lastRow + best] )
                    best = id;
            ids.Add( best );
            if ( best == MagicArticulateModel.TokenEos )
                break;
            if ( ids.Count % 64 == 0 )
                Progress?.Invoke( $"decode: {ids.Count} tokens" );

            var k = ids.Count;   // 1-based index of the token just added
            var embed = core.Decoder.EmbedTokens( new[] { best } )
                .Add( TransformerOps.Embed( core.BoneTable, new[]
                {
                    MagicArticulateModel.BonePositionId( k, best, PuppeteerModel.BonePerToken ),
                } ) )
                .Add( core.CondTable.Slice( 0, 1, 1 ) );
            // target-aware: token k ≥ 2 gets row 1 + (k−2)/BonePerToken (BOS exempt);
            // rows beyond the table are skipped.
            if ( k >= 2 )
            {
                var row = 1 + (k - 2) / PuppeteerModel.BonePerToken;
                if ( row < model.TargetPosEmbed.Shape[0] )
                    embed = embed.Add( model.TargetPosEmbed.Slice( 0, row, 1 ) );
            }
            logits = core.Decoder.Forward( embed, cache, positionStart: cache.Length );
        }

        return ids;
    }

    /// <summary>Drops a leading BOS (if the decoder predicted one) and everything
    /// from the first EOS onward, leaving only joint tokens.</summary>
    static List<int> StripSpecialTokens( List<int> ids )
    {
        var start = ids.Count > 0 && ids[0] == MagicArticulateModel.TokenBos ? 1 : 0;
        return ids
            .Skip( start )
            .TakeWhile( id => id != MagicArticulateModel.TokenEos )
            .ToList();
    }

    /// <summary>The reference skinning pipeline (skin_data.py + main.py):
    /// weights predicted on an 8192-point cloud in the normalized space, each
    /// mesh vertex takes its nearest sample's weights, top-4 kept (slots are not
    /// sorted by weight) and renormalized. Skeleton input rows =
    /// [parentPos | jointPos] (root parent slot zero); TAJA bias uses joint-graph
    /// hop distances. Never returns null.</summary>
    static SkinWeights NeuralSkin(
        AutoRig.Dl.Puppeteer.PuppeteerSkinModel skin,
        AutoRig.Mesh.RigMesh mesh, MagicArticulateInput.Prepared prepared,
        AutoRig.Dl.MagicArticulate.MagicArticulateSkeleton.Decoded decoded,
        int[] order )
    {
        // 8192 samples in the SAME normalized space as the decoded joints:
        // Prepare()'s cloud (±0.9995) then the residual bounds transform.
        var (cloud, normals, _, _) = UniRigInput.SampleCloud(
            mesh, 8192, vertexSamples: 0 );
        var n = cloud.Length;
        var samples = new Vector3[n];
        var points6 = new float[n * 6];   // per sample: [x, y, z, nx, ny, nz]
        for ( var i = 0; i < n; i++ )
        {
            samples[i] = (cloud[i] * MagicArticulateInput.ScaleFactor - prepared.PcCenter)
                / prepared.PcScale;
            points6[i * 6 + 0] = samples[i].X;
            points6[i * 6 + 1] = samples[i].Y;
            points6[i * 6 + 2] = samples[i].Z;
            points6[i * 6 + 3] = normals[i].X;
            points6[i * 6 + 4] = normals[i].Y;
            points6[i * 6 + 5] = normals[i].Z;
        }

        skin.Progress = skin.PartField.Progress = PuppeteerSolver.Progress;
        var partFeatures = skin.PartField.Features( samples, samples );
        var shape = skin.ShapeFeature( samples, normals );

        // bone_coor rows in the same normalized space; graph hop distances.
        var jointCount = decoded.Joints.Length;
        var joints6 = SkeletonRows( decoded );
        var graphDist = HopDistances( jointCount, decoded.Bones );

        var perSample = skin.Weights(
            points6, n, partFeatures, joints6, jointCount, graphDist, shape );

        // Vertices take the nearest sample's weights, top-4 renormalized.
        var vertexCount = mesh.Positions.Length;
        var boneIndices = new int[vertexCount * 4];
        var boneWeights = new float[vertexCount * 4];
        AutoRig.Dl.Concurrency.For( 0, vertexCount, v =>
        {
            // Same forward normalization the samples live in.
            var p = ((mesh.Positions[v] - prepared.Center) * prepared.Scale
                - prepared.PcCenter) / prepared.PcScale;
            var nearest = NearestSample( p, samples );

            // Keep the 4 strongest joints by replacing the weakest kept slot.
            Span<int> top = stackalloc int[4];
            Span<float> topW = stackalloc float[4];
            var kept = 0;
            for ( var j = 0; j < jointCount; j++ )
            {
                var weight = perSample[nearest * jointCount + j];
                if ( kept < 4 )
                {
                    top[kept] = j;
                    topW[kept++] = weight;
                }
                else
                {
                    var weakest = 0;
                    for ( var k = 1; k < 4; k++ )
                        if ( topW[k] < topW[weakest] )
                            weakest = k;
                    if ( weight > topW[weakest] )
                    {
                        top[weakest] = j;
                        topW[weakest] = weight;
                    }
                }
            }
            float total = 0;
            for ( var k = 0; k < kept; k++ )
                total += topW[k];
            // Degenerate (all ~zero) weights: bind fully to the first kept slot.
            if ( total <= 1e-8f )
            {
                topW[0] = 1f;
                total = 1f;
                kept = 1;
            }
            for ( var k = 0; k < kept; k++ )
            {
                boneIndices[v * 4 + k] = order[top[k]];
                boneWeights[v * 4 + k] = topW[k] / total;
            }
        } );

        return new SkinWeights { BoneIndices = boneIndices, Weights = boneWeights };
    }

    /// <summary>
    /// Builds the skinning model's skeleton input: one
    /// <c>[parentX, parentY, parentZ, x, y, z]</c> row per decoded joint, in the
    /// decoded (normalized) space. A joint's parent is taken from the first bone
    /// that lists it as child. The root, and any joint no bone points to, keep a
    /// zero parent slot.
    /// </summary>
    static float[] SkeletonRows( AutoRig.Dl.MagicArticulate.MagicArticulateSkeleton.Decoded decoded )
    {
        var jointCount = decoded.Joints.Length;
        var parents = new int[jointCount];
        Array.Fill( parents, -1 );
        foreach ( var (p, c) in decoded.Bones )
            if ( parents[c] < 0 && c != decoded.Root )
                parents[c] = p;

        var rows = new float[jointCount * 6];
        for ( var j = 0; j < jointCount; j++ )
        {
            if ( parents[j] >= 0 )
            {
                var parent = decoded.Joints[parents[j]];
                rows[j * 6 + 0] = parent.X;
                rows[j * 6 + 1] = parent.Y;
                rows[j * 6 + 2] = parent.Z;
            }
            var joint = decoded.Joints[j];
            rows[j * 6 + 3] = joint.X;
            rows[j * 6 + 4] = joint.Y;
            rows[j * 6 + 5] = joint.Z;
        }
        return rows;
    }

    /// <summary>Index of the sample closest to <paramref name="point"/> by squared
    /// Euclidean distance (first wins on ties); 0 when <paramref name="samples"/>
    /// is empty.</summary>
    static int NearestSample( Vector3 point, Vector3[] samples )
    {
        var nearest = 0;
        var best = float.MaxValue;
        for ( var s = 0; s < samples.Length; s++ )
        {
            var d = Vector3.DistanceSquared( point, samples[s] );
            if ( d < best )
            {
                best = d;
                nearest = s;
            }
        }
        return nearest;
    }

    /// <summary>
    /// All-pairs hop counts over the undirected joint graph given by
    /// <paramref name="bones"/>, computed with one breadth-first search per joint.
    /// NOTE(review): pairs that are not connected keep the default 0 (the same
    /// value as the diagonal). This only matters when the decoded bones do not
    /// form a single connected tree. Confirm what the TAJA bias expects there.
    /// </summary>
    static int[,] HopDistances( int jointCount, (int Parent, int Child)[] bones )
    {
        var adjacency = new List<int>[jointCount];
        for ( var i = 0; i < jointCount; i++ )
            adjacency[i] = new List<int>();
        foreach ( var (p, c) in bones )
        {
            adjacency[p].Add( c );
            adjacency[c].Add( p );
        }
        var distances = new int[jointCount, jointCount];
        for ( var start = 0; start < jointCount; start++ )
        {
            var seen = new bool[jointCount];
            var queue = new Queue<(int, int)>();
            queue.Enqueue( (start, 0) );
            seen[start] = true;
            while ( queue.Count > 0 )
            {
                var (node, depth) = queue.Dequeue();
                distances[start, node] = depth;
                foreach ( var neighbor in adjacency[node] )
                    if ( !seen[neighbor] )
                    {
                        seen[neighbor] = true;
                        queue.Enqueue( (neighbor, depth + 1) );
                    }
            }
        }
        return distances;
    }

}