Code/AutoRig/Solve/AnymateSolver.cs

Solver that runs the Anymate neural pipeline to produce a rig and skin weights from an analyzed mesh. Samples a point cloud, runs model stages to predict joint locations, parent links, and per-vertex skin weights, builds a BFS-ordered skeleton and weight tables, and returns a RigResult.

Networking, File Access
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.Anymate;
using AutoRig.Dl.UniRig;
using AutoRig.Rig;

namespace AutoRig.Solve;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Anymate inference (golden-verified vs PyTorch): three neural stages —
/// joint samples clustered to joints, per-joint parent classification
/// (self = root), and per-vertex skin weights over bones. The PointBERT
/// encoders see the cloud halved (the reference's bert branch); the decoders
/// see raw normalized coordinates. Fully neural rig, no geodesic fallback
/// needed unless the skeleton degenerates.
/// </summary>
public static class AnymateSolver
{
    /// <summary>Number of points requested from the cloud sampler.</summary>
    public const int NumSamples = 8192;

    /// <summary>Upper bound on how many mesh vertices are sent to the skin stage as queries.</summary>
    public const int MaxSkinQueries = 4096;

    /// <summary>
    /// Optional progress sink. Invoked with short status strings while
    /// <c>Rig</c> runs, and assigned to the model's own <c>Progress</c> on every call.
    /// </summary>
    public static Action<string> Progress { get; set; }

    /// <summary>
    /// Runs the joint, parent and skin stages of <paramref name="model"/> on the
    /// analyzed mesh and assembles the outputs into a BFS-ordered skeleton plus
    /// up to four normalized bone influences per vertex.
    /// </summary>
    /// <param name="analysis">Analysis result whose <c>Mesh</c> is rigged.</param>
    /// <param name="model">Model supplying the <c>Joints</c>, <c>Parents</c> and <c>Skin</c> stages.</param>
    /// <returns>
    /// A <see cref="RigResult"/> carrying the validated skeleton and weights, with
    /// <c>SolverName</c> "deep-learning/anymate" and <c>Degraded</c> set to false.
    /// </returns>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    /// <exception cref="FormatException">The joint stage returned fewer than two joints.</exception>
    public static RigResult Rig( AnalysisResult analysis, AnymateModel model )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( model );
        var mesh = analysis.Mesh;
        model.Progress = Progress;

        // ---- normalized cloud: [-1,1] (max-extent = 2), like normalize_mesh ----
        // The message is derived from NumSamples so it cannot drift from the constant.
        Progress?.Invoke( $"input: sampling {NumSamples}-point cloud" );
        // The sampler's second output (normals) is not used by this solver.
        var (cloud, _, center, halfExtent) =
            UniRigInput.SampleCloud( mesh, NumSamples, vertexSamples: 0 );
        var halved = new Vector3[cloud.Length];
        for ( var i = 0; i < cloud.Length; i++ )
            halved[i] = cloud[i] / 2f;   // the bert branch's /2

        // ---- stage 1: joints ----
        var joints = model.Joints( halved );
        if ( joints.Count < 2 )
            throw new FormatException(
                $"Anymate produced {joints.Count} joint(s) - too few for a rig." );
        Progress?.Invoke( $"anymate: {joints.Count} joints" );

        // ---- stage 2: parents (self = root; extra roots re-attach to first) ----
        var parents = model.Parents( halved, joints );
        var root = -1;
        for ( var j = 0; j < joints.Count; j++ )
        {
            if ( parents[j] != j )
                continue;
            if ( root < 0 )
                root = j;          // first self-parented joint becomes the root
            else
                parents[j] = root; // any later self-parented joint hangs off it
        }
        if ( root < 0 )
        {
            // No joint claimed to be its own parent: promote joint 0.
            root = 0;
            parents[0] = 0;
        }

        // Break any cycles (network output is unconstrained): walk each chain
        // towards the root and cut the first revisited joint over to the root.
        // A single scratch set is reused for every walk.
        var seen = new HashSet<int>();
        for ( var j = 0; j < joints.Count; j++ )
        {
            seen.Clear();
            seen.Add( j );
            var node = j;
            while ( parents[node] != node && node != root )
            {
                node = parents[node];
                if ( !seen.Add( node ) )
                {
                    parents[node] = root;   // cycle → cut to root
                    break;
                }
            }
        }
        Progress?.Invoke( $"anymate: root {root}" );

        // ---- stage 3: neural skin over bones ----
        // One bone per non-root joint; bone index → child joint.
        var boneJoint = new List<int>( joints.Count );
        for ( var j = 0; j < joints.Count; j++ )
            if ( parents[j] != j )
                boneJoint.Add( j );
        var boneCount = boneJoint.Count;

        // Each bone is packed as six floats: parent xyz followed by child xyz.
        var bones6 = new float[boneCount * 6];
        for ( var b = 0; b < boneCount; b++ )
        {
            var child = boneJoint[b];
            var parent = joints[parents[child]];
            var offset = b * 6;
            bones6[offset + 0] = parent.X;
            bones6[offset + 1] = parent.Y;
            bones6[offset + 2] = parent.Z;
            bones6[offset + 3] = joints[child].X;
            bones6[offset + 4] = joints[child].Y;
            bones6[offset + 5] = joints[child].Z;
        }

        // Evenly strided subset of the mesh vertices, capped at MaxSkinQueries.
        var vertexCount = mesh.Positions.Length;
        var queryCount = Math.Min( vertexCount, MaxSkinQueries );
        var queryVertex = new int[queryCount];
        var queryPoints = new Vector3[queryCount];
        var queryNormals = new Vector3[queryCount];
        for ( var i = 0; i < queryCount; i++ )
        {
            // long arithmetic keeps i * vertexCount from overflowing int.
            var v = (int)((long)i * vertexCount / queryCount);
            queryVertex[i] = v;
            queryPoints[i] = (mesh.Positions[v] - center) / halfExtent;   // [-1,1]
            queryNormals[i] = v < mesh.Normals.Length
                ? mesh.Normals[v]
                : Vector3.UnitZ;   // fallback when the mesh has fewer normals than positions
        }
        var perQuery = model.Skin( halved, queryPoints, queryNormals, bones6, boneCount );

        // ---- skeleton (denormalize) + weights (bone → child joint, BFS order) ----
        var world = new Vector3[joints.Count];
        for ( var j = 0; j < joints.Count; j++ )
            world[j] = joints[j] * halfExtent + center;

        // order[j] = index of model joint j inside skeleton.Joints.
        var order = new int[joints.Count];
        var skeleton = new RigSkeleton();
        var visited = new bool[joints.Count];
        var queue = new Queue<int>();
        queue.Enqueue( root );
        visited[root] = true;
        var emitted = 0;
        while ( queue.Count > 0 )
        {
            var j = queue.Dequeue();
            order[j] = emitted++;
            skeleton.Joints.Add( new RigJoint
            {
                Name = j == root ? "root" : $"bone_{j}",
                // A parent is always dequeued before its children, so order[parent] is set.
                Parent = j == root ? -1 : order[parents[j]],
                Position = world[j],
            } );
            for ( var c = 0; c < joints.Count; c++ )
                if ( !visited[c] && parents[c] == j && c != j )
                {
                    visited[c] = true;
                    queue.Enqueue( c );
                }
        }
        // unreachable joints (disconnected branches): attach to root,
        // which is always emitted first and therefore sits at skeleton index 0.
        for ( var j = 0; j < joints.Count; j++ )
            if ( !visited[j] )
            {
                order[j] = skeleton.Joints.Count;
                skeleton.Joints.Add( new RigJoint
                {
                    Name = $"bone_{j}",
                    Parent = 0,
                    Position = world[j],
                } );
            }
        skeleton.Validate();

        // Four influence slots per vertex; unused slots stay at index 0 / weight 0.
        var boneIndices = new int[vertexCount * 4];
        var boneWeights = new float[vertexCount * 4];
        Concurrency.For( 0, vertexCount, v =>
        {
            // nearest query (IDW not needed: strided queries are dense).
            var p = (mesh.Positions[v] - center) / halfExtent;
            var nearest = NearestQuery( p, queryPoints );

            // Keep the four largest bone weights of that query.
            Span<int> top = stackalloc int[4];
            Span<float> topW = stackalloc float[4];
            var kept = 0;
            for ( var b = 0; b < boneCount; b++ )
            {
                var weight = perQuery[nearest * boneCount + b];
                if ( kept < 4 )
                {
                    top[kept] = b;
                    topW[kept++] = weight;
                    continue;
                }

                // All slots taken: replace the weakest one if this bone beats it.
                var weakest = 0;
                for ( var k = 1; k < 4; k++ )
                    if ( topW[k] < topW[weakest] )
                        weakest = k;
                if ( weight > topW[weakest] )
                {
                    top[weakest] = b;
                    topW[weakest] = weight;
                }
            }

            float total = 0;
            for ( var k = 0; k < kept; k++ )
                total += topW[k];
            if ( total <= 1e-8f )
            {
                // Effectively no influence: bind fully to the bone in slot 0.
                topW[0] = 1f;
                total = 1f;
                kept = 1;
            }

            // Remap bone → child joint → skeleton index and normalize to sum 1.
            for ( var k = 0; k < kept; k++ )
            {
                boneIndices[v * 4 + k] = order[boneJoint[top[k]]];
                boneWeights[v * 4 + k] = topW[k] / total;
            }
        } );
        var weights = new SkinWeights { BoneIndices = boneIndices, Weights = boneWeights };
        weights.Validate( mesh, skeleton );

        return new RigResult
        {
            Skeleton = skeleton,
            Weights = weights,
            SolverName = "deep-learning/anymate",
            Degraded = false,
            Explanation = $"Anymate predicted {skeleton.Joints.Count} joints, "
                + "connectivity, and NEURAL skin weights (three-stage pipeline).",
        };
    }

    /// <summary>
    /// Linear scan for the entry of <paramref name="queries"/> closest to
    /// <paramref name="point"/> by squared distance. Ties keep the lowest index;
    /// an empty array yields 0.
    /// </summary>
    private static int NearestQuery( Vector3 point, Vector3[] queries )
    {
        var nearest = 0;
        var best = float.MaxValue;
        for ( var q = 0; q < queries.Length; q++ )
        {
            var d = Vector3.DistanceSquared( point, queries[q] );
            if ( d < best )
            {
                best = d;
                nearest = q;
            }
        }
        return nearest;
    }
}