Code/AutoRig/Solve/SkinTokensSolver.cs

Solver for the "SkinTokens" token-rig format. It encodes a sampled point cloud with a learned (perceiver) mesh encoder, then runs a constrained two-phase greedy decode that yields a skeleton followed by per-bone skin codes. It decodes those codes into skin weights with an FSQ skin VAE, falling back to a geodesic skinner when the skin-token sequence is ragged. It returns a RigResult holding the skeleton and the weights.

Native Interop
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;
using AutoRig.Rig;
using AutoRig.Solve.Organic;
using AutoRig.Voxel;

namespace AutoRig.Solve;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// SkinTokens (TokenRig) inference, skeleton stage golden-verified vs PyTorch:
/// perceiver conditioning over the full 54000-point cloud, then a constrained
/// GREEDY two-phase decode on the Qwen3 backbone (the reference samples with
/// beams). Phase 1 is UniRig's skeleton grammar; the grammar EOS (258) switches
/// to phase 2, where FSQ skin-codebook ids stream until the unified EOS is
/// FORCED at exactly (length − switchIndex) == joints · tokensPerSkin —
/// reference behavior verbatim, including its off-by-one (the unified EOS
/// occupies the last skin slot). When the skin phase yields exactly
/// joints · tokensPerSkin slots, the codes are decoded into weights by the FSQ
/// skin-VAE (NeuralSkin); a ragged skin sequence falls back to the geodesic
/// organic skinner.
/// </summary>
public static class SkinTokensSolver
{
    /// <summary>Reference context is 3192 positions; 512 go to the mesh prefix.</summary>
    public const int MaxTokens = 2680;

    /// <summary>SkinTokens' cls map differs from UniRig's, but "articulation"
    /// (the open-corpus class the demo feeds every upload) lands on id 266.</summary>
    public const int ClsArticulation = UniRigTokenizer.TokenClsArticulationXl;

    /// <summary>Largest per-vertex bone row kept on the stack in NeuralSkin;
    /// larger rigs use a heap row so no bone is ever dropped.</summary>
    const int StackRowBones = 256;

    /// <summary>Optional progress sink; receives human-readable stage messages.</summary>
    public static Action<string> Progress { get; set; }

    /// <summary>
    /// Runs SkinTokens end to end on <paramref name="analysis"/>'s mesh:
    /// conditioning, two-phase greedy decode, skeleton reconstruction in mesh
    /// space, then neural (or fallback geodesic) skinning.
    /// </summary>
    /// <param name="analysis">Analysis result whose <c>Mesh</c> is rigged.</param>
    /// <param name="model">Loaded SkinTokens weights (encoder, decoder, skin VAE).</param>
    /// <returns>The skeleton (parent-before-child order) and per-vertex weights.</returns>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    /// <exception cref="FormatException">The decoded skeleton has fewer than
    /// two joints, has no root, or is not a single connected tree.</exception>
    public static RigResult Rig( AnalysisResult analysis, SkinTokensModel model )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( model );
        var mesh = analysis.Mesh;

        // ---- conditioning ----
        Progress?.Invoke( "input: sampling 54000-point cloud" );
        var prepared = SkinTokensInput.Prepare( mesh );
        Progress?.Invoke( "perceiver: encoding (512 queries over the full cloud)" );
        var latents = model.MeshEncoder.Encode(
            prepared.Points, prepared.Normals, prepared.SampledIndices );
        var prefix = TransformerOps.RmsNorm(
            latents.MatMul( model.OutputProjWeight.Transposed ).Add( model.OutputProjBias ),
            model.OutputNormGamma );

        // ---- two-phase constrained greedy decode ----
        var ids = new List<int> { UniRigTokenizer.TokenBos, ClsArticulation };
        var cache = new OptKvCache( model.Decoder.Layers.Length );
        var logits = model.Decoder.Forward(
            prefix.Concat( model.Decoder.EmbedTokens( ids.ToArray() ), 0 ),
            cache, positionStart: 0 );
        var switchAt = -1;

        while ( ids.Count < MaxTokens )
        {
            // Only the last position's logits drive the next token.
            var lastRow = (logits.Shape[0] - 1) * logits.Shape[1];
            int best;
            if ( switchAt < 0 )
            {
                // Phase 1: argmax restricted to the skeleton grammar.
                var allowed = UniRigTokenizer.NextPossibleTokens( ids );
                best = allowed[0];
                foreach ( var id in allowed )
                    if ( logits.Data[lastRow + id] > logits.Data[lastRow + best] )
                        best = id;
            }
            else if ( ForcesUnifiedEos( ids, switchAt, model.TokensPerSkin ) )
            {
                best = model.UnifiedEos;
            }
            else
            {
                best = model.SwitchToken;   // argmax over [258, unifiedEos)
                for ( var id = model.SwitchToken; id < model.UnifiedEos; id++ )
                    if ( logits.Data[lastRow + id] > logits.Data[lastRow + best] )
                        best = id;
            }

            ids.Add( best );
            if ( switchAt < 0 && best == model.SwitchToken )
            {
                switchAt = ids.Count - 1;
                Progress?.Invoke( $"decode: skeleton closed at {ids.Count} tokens; skin phase" );
            }
            if ( best == model.UnifiedEos )
                break;
            if ( ids.Count % 64 == 0 )
                Progress?.Invoke( $"decode: {ids.Count} tokens" );

            logits = model.Decoder.Forward(
                model.Decoder.EmbedTokens( new[] { best } ),
                cache, positionStart: cache.Length );
        }
        if ( switchAt < 0 )
        {
            ids.Add( UniRigTokenizer.TokenEos );   // budget hit — close the skeleton
            switchAt = ids.Count - 1;
        }

        var decoded = UniRigTokenizer.Detokenize( ids.Take( switchAt + 1 ).ToList() );
        if ( decoded.Joints.Length < 2 )
            throw new FormatException(
                $"SkinTokens produced {decoded.Joints.Length} joint(s) - too few for a rig." );

        // ---- skeleton back to mesh space, parent-before-child ----
        var joints = decoded.Joints
            .Select( p => p * prepared.Scale + prepared.Center ).ToArray();
        var (skeleton, order) = BuildOrdered( joints, decoded.Parents );
        skeleton.Validate();

        // ---- NEURAL skinning: FSQ codes → VAE → per-vertex weights ----
        var jointCount = decoded.Joints.Length;
        var skinIds = ids.Skip( switchAt + 1 )
            .Take( jointCount * model.TokensPerSkin )
            .Select( id => id - model.SkeletonVocab )
            .ToList();
        SkinWeights weights = null;
        if ( skinIds.Count == jointCount * model.TokensPerSkin )
        {
            Progress?.Invoke( $"skinning: FSQ VAE, {jointCount} bones" );
            weights = NeuralSkin( model, prepared, mesh, skinIds, jointCount, order );
        }

        // Remember which path produced the weights so the explanation is truthful.
        var neural = weights is not null;
        if ( !neural )
        {
            // Reference behavior on a ragged sequence: no skin prediction.
            Progress?.Invoke( "skinning: geodesic fallback (ragged skin tokens)" );
            var resolution = mesh.TriangleCount < 50_000 ? 96 : 64;
            var grid = VoxelGrid.Build( mesh, resolution );
            weights = OrganicSkinner.Skin( mesh, grid, skeleton );
        }

        return new RigResult
        {
            Skeleton = skeleton,
            Weights = weights,
            SolverName = "deep-learning/skintokens",
            Degraded = false,
            Explanation = neural
                ? $"SkinTokens predicted {skeleton.Joints.Count} joints "
                    + "with NEURAL skin weights (FSQ skin-VAE; greedy decode)."
                : $"SkinTokens predicted {skeleton.Joints.Count} joints; "
                    + "skin tokens were ragged, so weights come from the geodesic skinner (greedy decode).",
        };
    }

    /// <summary>tokenrig.decode: per-bone FSQ codes → VAE sigmoids on query
    /// points → normalize across bones (asset.normalize_skin) → 4 influences.
    /// Queries are mesh vertices (a strided subset past 4096, IDW-8 transfer —
    /// the reference queries its sampled cloud and kd-transfers likewise).</summary>
    /// <param name="skinIds">jointCount · TokensPerSkin codebook ids, bone-major.</param>
    /// <param name="order">Maps decoder joint index → BFS skeleton index.</param>
    /// <returns>Up to four influences per vertex, weights summing to 1.</returns>
    static SkinWeights NeuralSkin(
        SkinTokensModel model, SkinTokensInput.Prepared prepared,
        AutoRig.Mesh.RigMesh mesh, List<int> skinIds, int jointCount, int[] order )
    {
        // Cond latents: 384 queries fps'd from a strided 1536-point subset
        // (the reference subsamples unseeded — deterministic here).
        const int CondTokens = 384;
        var pre = new int[CondTokens * 4];
        for ( var i = 0; i < pre.Length; i++ )
            pre[i] = (int)((long)i * prepared.Points.Length / pre.Length);
        var preXyz = new float[pre.Length * 3];
        for ( var i = 0; i < pre.Length; i++ )
        {
            var p = prepared.Points[pre[i]];
            preXyz[i * 3] = p.X;
            preXyz[i * 3 + 1] = p.Y;
            preXyz[i * 3 + 2] = p.Z;
        }
        var fps = AutoRig.Dl.RigNet.PointNet.FarthestPointSample(
            Tensor.From( preXyz, pre.Length, 3 ), ratio: 0.25f );
        var condIndices = fps.Select( i => pre[i] ).ToArray();
        var cond = model.Vae.EncodeCond( prepared.Points, prepared.Normals, condIndices );

        // Query points: normalized mesh vertices (subset when large).
        var vertexCount = mesh.Positions.Length;
        var queryCount = Math.Min( vertexCount, 4096 );
        var queryVertex = new int[queryCount];
        for ( var i = 0; i < queryCount; i++ )
            queryVertex[i] = (int)((long)i * vertexCount / queryCount);
        var queryPoints = new Vector3[queryCount];
        var queryNormals = new Vector3[queryCount];
        for ( var i = 0; i < queryCount; i++ )
        {
            var v = queryVertex[i];
            queryPoints[i] = (mesh.Positions[v] - prepared.Center) / prepared.Scale;
            queryNormals[i] = v < mesh.Normals.Length ? mesh.Normals[v] : Vector3.UnitZ;
        }
        var queryEmbed = SkinVae.Embed( queryPoints, queryNormals );

        // Per bone: codes → weights on the query set.
        var perBone = new float[jointCount][];
        for ( var j = 0; j < jointCount; j++ )
        {
            var codes = model.Vae.DequantizeCodes(
                skinIds.Skip( j * model.TokensPerSkin ).Take( model.TokensPerSkin ).ToList() );
            perBone[j] = model.Vae.DecodeBone( codes, cond, queryEmbed );
            Progress?.Invoke( $"skinning: bone {j + 1}/{jointCount}" );
        }

        // Per vertex: nearest queried vertices (IDW-8 when subsetted), then
        // normalize across bones and keep the engine's 4 influences.
        var boneIndices = new int[vertexCount * 4];
        var boneWeights = new float[vertexCount * 4];
        Concurrency.For( 0, vertexCount, v =>
        {
            // One slot per bone. The previous fixed 64-slot row silently gave
            // every bone past index 63 zero weight; size it to the rig instead.
            Span<float> row = jointCount <= StackRowBones
                ? stackalloc float[jointCount]
                : new float[jointCount];
            if ( queryCount == vertexCount )
            {
                // Every vertex was queried directly (query index == vertex index).
                for ( var j = 0; j < jointCount; j++ )
                    row[j] = perBone[j][v];
            }
            else
            {
                var p = (mesh.Positions[v] - prepared.Center) / prepared.Scale;
                Span<int> nearest = stackalloc int[8];
                Span<float> distance = stackalloc float[8];
                var used = 0;
                for ( var q = 0; q < queryCount; q++ )
                {
                    var d = Vector3.DistanceSquared( p, queryPoints[q] );
                    if ( used < 8 )
                    {
                        nearest[used] = q;
                        distance[used++] = d;
                    }
                    else
                    {
                        // Replace the current farthest of the 8 kept neighbours.
                        var worst = 0;
                        for ( var k = 1; k < 8; k++ )
                            if ( distance[k] > distance[worst] )
                                worst = k;
                        if ( d < distance[worst] )
                        {
                            nearest[worst] = q;
                            distance[worst] = d;
                        }
                    }
                }
                // Squared distances become inverse-distance weights in place.
                float totalInverse = 0;
                for ( var k = 0; k < used; k++ )
                {
                    distance[k] = 1f / (MathF.Sqrt( distance[k] ) + 1e-8f);
                    totalInverse += distance[k];
                }
                for ( var j = 0; j < jointCount; j++ )
                {
                    float value = 0;
                    for ( var k = 0; k < used; k++ )
                        value += distance[k] * perBone[j][nearest[k]];
                    row[j] = value / totalInverse;
                }
            }

            // normalize_skin + top-4 (engine limit), remapped to BFS order.
            Span<int> top = stackalloc int[4];
            Span<float> topWeight = stackalloc float[4];
            var kept = 0;
            for ( var j = 0; j < jointCount; j++ )
            {
                if ( kept < 4 )
                {
                    top[kept] = j;
                    topWeight[kept++] = row[j];
                }
                else
                {
                    var weakest = 0;
                    for ( var k = 1; k < 4; k++ )
                        if ( topWeight[k] < topWeight[weakest] )
                            weakest = k;
                    if ( row[j] > topWeight[weakest] )
                    {
                        top[weakest] = j;
                        topWeight[weakest] = row[j];
                    }
                }
            }
            float total = 0;
            for ( var k = 0; k < kept; k++ )
                total += topWeight[k];
            if ( total <= 1e-8f )
            {
                // No bone claims this vertex: bind it fully to the first kept bone.
                topWeight[0] = 1f;
                total = 1f;
                kept = 1;
            }
            for ( var k = 0; k < kept; k++ )
            {
                boneIndices[v * 4 + k] = order[top[k]];
                boneWeights[v * 4 + k] = topWeight[k] / total;
            }
        } );

        return new SkinWeights { BoneIndices = boneIndices, Weights = boneWeights };
    }

    /// <summary>The reference's VocabSwitchingLogitsProcessor termination rule:
    /// with 258 at <paramref name="switchAt"/> in the FULL sequence (start
    /// tokens included), the unified EOS is banned until
    /// (length − switchAt) == joints · tokensPerSkin, then forced.</summary>
    internal static bool ForcesUnifiedEos(
        IReadOnlyList<int> ids, int switchAt, int tokensPerSkin )
        => ids.Count - switchAt
            == UniRigTokenizer.BonesInSequence( ids ) * tokensPerSkin;

    /// <summary>
    /// Re-emits the decoded joints breadth-first from the root so every parent
    /// precedes its children. Children of a joint are visited in ascending
    /// decoder index.
    /// </summary>
    /// <param name="joints">Joint positions in mesh space, decoder order.</param>
    /// <param name="parents">Parent decoder index per joint; -1 marks the root.</param>
    /// <returns>The ordered skeleton and the decoder-index → BFS-index map.</returns>
    /// <exception cref="FormatException">No joint has parent -1, or some joints
    /// are not reachable from the root (disconnected or cyclic parents).</exception>
    static (RigSkeleton Skeleton, int[] Order) BuildOrdered( Vector3[] joints, int[] parents )
    {
        var order = new int[joints.Length];
        var skeleton = new RigSkeleton();
        var root = Array.IndexOf( parents, -1 );
        if ( root < 0 )
            throw new FormatException( "SkinTokens skeleton has no root joint (no parent of -1)." );

        var queue = new Queue<int>();
        queue.Enqueue( root );
        var emitted = 0;
        while ( queue.Count > 0 )
        {
            var j = queue.Dequeue();
            order[j] = emitted++;
            skeleton.Joints.Add( new RigJoint
            {
                Name = j == root ? "root" : $"bone_{j}",
                Parent = j == root ? -1 : order[parents[j]],
                Position = joints[j],
            } );
            for ( var c = 0; c < parents.Length; c++ )
                if ( c != root && parents[c] == j )
                    queue.Enqueue( c );
        }

        // An unreached joint would keep order 0 (silently aliasing the root's
        // weights) and be missing from the skeleton — reject instead.
        if ( emitted != joints.Length )
            throw new FormatException(
                $"SkinTokens skeleton is disconnected: {joints.Length - emitted} of {joints.Length} joint(s) unreachable from the root." );
        return (skeleton, order);
    }
}