AutoRig/Solve/UniRigSolver.cs

Solver that runs the UniRig skeleton prediction pipeline. It encodes a sampled point cloud with a mesh encoder, runs a constrained greedy autoregressive decode over UniRig tokens to produce joint positions, reconstructs a parent-ordered RigSkeleton, and computes interim skin weights via a geodesic voxel skinner.

Native Interop
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;
using AutoRig.Rig;
using AutoRig.Solve.Organic;
using AutoRig.Voxel;

namespace AutoRig.Solve;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// UniRig inference (skeleton stage, golden-verified vs PyTorch): point-cloud
/// perceiver conditioning, constrained GREEDY autoregressive skeleton decode
/// (deterministic; the reference samples with beams), detokenization back to mesh
/// space. Skin weights come from our geodesic voxel skinner until UniRig's skin
/// network is ported (documented interim).
/// </summary>
public static class UniRigSolver
{
    /// <summary>Decode budget: ~2900 tokens fit after the 1024-latent prefix; cap
    /// well below for interactivity (~256 bones). The decode loop stops adding
    /// tokens once this count is reached; a closing EOS may then be appended, so
    /// the final sequence can hold one token more than this value.</summary>
    public const int MaxTokens = 1024;

    /// <summary>Optional progress sink (stage names + decode ticks) — long runs
    /// were unobservable, which made gate diagnosis guesswork. May be left unset;
    /// every call site null-checks before invoking.</summary>
    public static Action<string> Progress { get; set; }

    /// <summary>
    /// Runs the full pipeline: prepare and encode the point cloud, greedily decode
    /// a token sequence under the tokenizer's constraints, detokenize it into
    /// joints and parent indices, rebuild a parent-before-child skeleton in mesh
    /// space, and skin the mesh against it.
    /// </summary>
    /// <param name="analysis">Analysis result whose <c>Mesh</c> is rigged.</param>
    /// <param name="model">Loaded UniRig skeleton model (mesh encoder, output
    /// projection and decoder).</param>
    /// <returns>A <see cref="RigResult"/> holding the decoded skeleton and the
    /// interim geodesic skin weights.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="analysis"/> or
    /// <paramref name="model"/> is null.</exception>
    /// <exception cref="FormatException">The decode yielded fewer than two joints,
    /// no root joint, or joint/parent arrays of different lengths.</exception>
    public static RigResult Rig( AnalysisResult analysis, UniRigSkeletonModel model )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( model );
        var mesh = analysis.Mesh;

        // ---- conditioning ----
        Progress?.Invoke( "input: sampling point cloud" );
        var prepared = UniRigInput.Prepare( mesh );
        Progress?.Invoke( "perceiver: encoding 4096-point cloud" );
        var latents = model.MeshEncoder.Encode(
            prepared.PrePoints, prepared.PreNormals, prepared.SampledIndices );
        // Project the encoder latents with the model's output projection; the
        // result is fed to the decoder as the conditioning prefix.
        var prefix = latents.MatMul( model.OutputProjWeight.Transposed )
            .Add( model.OutputProjBias );

        // ---- constrained greedy decode (no cls, like the reference's no_cls path) ----
        var ids = new List<int> { UniRigTokenizer.TokenBos };
        var cache = new OptKvCache( model.Decoder.Layers.Length );
        // First pass primes the KV cache with the whole prefix plus the BOS embedding.
        var logits = model.Decoder.Forward(
            prefix.Concat( model.Decoder.EmbedTokens( new[] { UniRigTokenizer.TokenBos } ), 0 ),
            cache, positionStart: 0 );

        while ( ids.Count < MaxTokens )
        {
            var allowed = UniRigTokenizer.NextPossibleTokens( ids );
            // logits.Data is read as a flat Shape[0] x Shape[1] buffer; only the
            // last row (the newest position) is scored.
            var lastRow = (logits.Shape[0] - 1) * logits.Shape[1];
            // Greedy pick: highest logit among the tokens the tokenizer allows next.
            // Ties keep the earliest allowed token because the comparison is strict.
            var best = allowed[0];
            foreach ( var id in allowed )
                if ( logits.Data[lastRow + id] > logits.Data[lastRow + best] )
                    best = id;
            ids.Add( best );
            if ( ids.Count % 64 == 0 )
                Progress?.Invoke( $"decode: {ids.Count} tokens" );
            if ( best == UniRigTokenizer.TokenEos )
                break;

            // Incremental step: feed only the new token, continuing from the
            // position the cache has reached.
            logits = model.Decoder.Forward(
                model.Decoder.EmbedTokens( new[] { best } ),
                cache, positionStart: cache.Length );
        }
        if ( ids[^1] != UniRigTokenizer.TokenEos )
            ids.Add( UniRigTokenizer.TokenEos );   // budget hit — close the sequence

        var decoded = UniRigTokenizer.Detokenize( ids );
        if ( decoded.Joints.Length < 2 )
            throw new FormatException(
                $"UniRig produced {decoded.Joints.Length} joint(s) - too few for a rig." );

        // ---- skeleton back to mesh space, parent-before-child ----
        var joints = decoded.Joints
            .Select( p => p * prepared.Scale + prepared.Center ).ToArray();
        var (skeleton, _) = BuildOrdered( joints, decoded.Parents );
        skeleton.Validate();

        // ---- interim skinning: geodesic voxel weights on the decoded skeleton ----
        Progress?.Invoke(
            $"skinning: {skeleton.Joints.Count} joints (decoded {ids.Count} tokens)" );
        // Meshes under 50k triangles get the finer 96 grid; larger ones use 64.
        var resolution = mesh.TriangleCount < 50_000 ? 96 : 64;
        var grid = VoxelGrid.Build( mesh, resolution );
        var weights = OrganicSkinner.Skin( mesh, grid, skeleton );

        return new RigResult
        {
            Skeleton = skeleton,
            Weights = weights,
            SolverName = "deep-learning/unirig",
            Degraded = false,
            Explanation = $"UniRig predicted {skeleton.Joints.Count} joints "
                + "(greedy decode; weights from geodesic skinning).",
        };
    }

    /// <summary>
    /// Re-indexes the decoded joints breadth-first from the root so that every
    /// joint is emitted after its parent. The root is the first joint whose parent
    /// index is -1. Joints that cannot be reached from that root (additional
    /// roots, out-of-range or cyclic parent links) are not emitted.
    /// </summary>
    /// <param name="joints">Joint positions, indexed as decoded.</param>
    /// <param name="parents">Parent index per joint in the decoded indexing; -1
    /// marks a root.</param>
    /// <returns>The ordered skeleton, and a map from decoded joint index to its
    /// index in the skeleton (-1 for joints that were not emitted).</returns>
    /// <exception cref="FormatException">The arrays differ in length, or no joint
    /// has a parent index of -1.</exception>
    static (RigSkeleton Skeleton, int[] Order) BuildOrdered( Vector3[] joints, int[] parents )
    {
        if ( joints.Length != parents.Length )
            throw new FormatException(
                $"UniRig decoded {joints.Length} joint(s) but {parents.Length} parent index(es)." );

        // Array.IndexOf yields -1 when nothing matches; without this guard that
        // value would be used as an array index below.
        var root = Array.IndexOf( parents, -1 );
        if ( root < 0 )
            throw new FormatException(
                "UniRig skeleton has no root joint (no parent index of -1)." );

        // Group children by parent once instead of rescanning the parent array
        // for every dequeued joint. ToLookup keeps source order inside each group,
        // so siblings are still visited in ascending decoded index, and an absent
        // key yields an empty sequence.
        var childrenOf = Enumerable.Range( 0, parents.Length )
            .Where( c => c != root )
            .ToLookup( c => parents[c] );

        var order = new int[joints.Length];
        Array.Fill( order, -1 );   // unreached joints must not alias slot 0 (the root)
        var skeleton = new RigSkeleton();
        var queue = new Queue<int>();
        queue.Enqueue( root );
        var emitted = 0;
        while ( queue.Count > 0 )
        {
            var j = queue.Dequeue();
            order[j] = emitted++;
            skeleton.Joints.Add( new RigJoint
            {
                Name = j == root ? "root" : $"bone_{j}",
                // A joint is only enqueued after its parent was dequeued, so the
                // parent's new index is already recorded here.
                Parent = j == root ? -1 : order[parents[j]],
                Position = joints[j],
            } );
            foreach ( var child in childrenOf[j] )
                queue.Enqueue( child );
        }
        return (skeleton, order);
    }
}