Code/AutoRig/Solve/DeepLearningSolver.cs

DeepLearningSolver contains the deep-learning rigging inference pipeline. It uses a RigNetBundle of neural networks (loadable from the official checkpoints zip) and builds a decimated, normalized proxy mesh. It then predicts joints, builds the skeleton hierarchy with a symmetry-aware minimum spanning tree (MST), and computes skinning weights from volumetric geodesic features. It returns a RigResult containing the skeleton and skin weights.

Tags: Networking, File Access
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.RigNet;
using AutoRig.Rig;
using AutoRig.Voxel;

namespace AutoRig.Solve;

using Vector3 = System.Numerics.Vector3;

/// <summary>All four RigNet networks, loaded once and reused across rigs.</summary>
public sealed class RigNetBundle
{
    /// <summary>Joint network loaded from the <c>gcn_meanshift</c> checkpoint.</summary>
    public required JointPredNet JointNet;

    /// <summary>Mask network loaded from the <c>gcn_meanshift</c> checkpoint.</summary>
    public required JointPredNet MaskNet;

    /// <summary>Mean-shift bandwidth loaded with the <c>gcn_meanshift</c> networks.</summary>
    public required float Bandwidth;

    /// <summary>Network loaded from the <c>rootnet</c> checkpoint.</summary>
    public required RootNet RootNet;

    /// <summary>Network loaded from the <c>bonenet</c> checkpoint.</summary>
    public required BoneNet BoneNet;

    /// <summary>Network loaded from the <c>skinnet</c> checkpoint.</summary>
    public required SkinNet SkinNet;

    /// <summary>
    /// Loads from the official checkpoints zip (as downloaded from the catalog):
    /// entries `checkpoints/{gcn_meanshift,rootnet,bonenet,skinnet}/model_best.pth.tar`.
    /// An entry matches when its path ends with <c>{name}/model_best.pth.tar</c> and
    /// <c>{name}</c> is a whole path segment; any leading directories are ignored.
    /// </summary>
    /// <param name="zipBytes">Raw bytes of the checkpoints zip.</param>
    /// <returns>A bundle holding all four loaded networks.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="zipBytes"/> is null.</exception>
    /// <exception cref="FormatException">A required checkpoint entry is missing from the zip.</exception>
    public static RigNetBundle FromCheckpointsZip( byte[] zipBytes )
    {
        ArgumentNullException.ThrowIfNull( zipBytes );
        var entries = TorchCheckpoint.ReadArchive( zipBytes );

        // True when `path` ends with `suffix` and the suffix starts a path segment,
        // so e.g. "rootnet/..." cannot match ".../my_rootnet/...".
        static bool IsEntryFor( string path, string suffix ) =>
            path.EndsWith( suffix, StringComparison.Ordinal )
            && ( path.Length == suffix.Length || path[path.Length - suffix.Length - 1] == '/' );

        IReadOnlyDictionary<string, Tensor> Net( string name )
        {
            var suffix = $"{name}/model_best.pth.tar";
            var entry = entries.Keys.FirstOrDefault( k => IsEntryFor( k, suffix ) )
                ?? throw new FormatException( $"Checkpoints zip has no '{name}/model_best.pth.tar'." );
            return TorchCheckpoint.Parse( entries[entry] );
        }

        var meanshift = RigNetLoader.LoadMeanshift( Net( "gcn_meanshift" ) );
        return new RigNetBundle
        {
            JointNet = meanshift.JointNet,
            MaskNet = meanshift.MaskNet,
            Bandwidth = meanshift.Bandwidth,
            RootNet = RigNetLoader.LoadRootNet( Net( "rootnet" ) ),
            BoneNet = RigNetLoader.LoadBoneNet( Net( "bonenet" ) ),
            SkinNet = RigNetLoader.LoadSkinNet( Net( "skinnet" ) ),
        };
    }
}

/// <summary>
/// The deep-learning solver: quick_start.py's inference pipeline in pure C#.
/// Joints from jointnet+masknet meanshift clustering, hierarchy from
/// rootnet/bonenet with the symmetry-aware MST, weights from skinnet over
/// volumetric geodesic features — computed on a ≤3k-vertex proxy in RigNet's
/// normalized y-up space, then transferred back to the full mesh.
/// </summary>
public static class DeepLearningSolver
{
    /// <summary>Vertex budget for the decimated proxy mesh the networks run on.</summary>
    public const int ProxyVertexBudget = 3000;

    /// <summary>Default density threshold forwarded to mean-shift joint extraction.</summary>
    public const float DefaultDensityThreshold = 1e-5f;

    /// <summary>Resolution passed to <c>VoxelGrid.Build</c> when voxelizing the normalized proxy.</summary>
    public const int VoxelResolution = 88;

    /// <summary>Maximum number of joint influences written per vertex.</summary>
    const int MaxInfluences = 4;

    /// <summary>
    /// Rigs the mesh in <paramref name="analysis"/> with the RigNet networks in <paramref name="net"/>.
    /// </summary>
    /// <param name="analysis">Analysis result whose <c>Mesh</c> is rigged.</param>
    /// <param name="net">Loaded RigNet networks.</param>
    /// <param name="densityThreshold">Density threshold forwarded to mean-shift joint extraction.</param>
    /// <returns>
    /// The skeleton (mapped back to mesh space) and, per full-mesh vertex, up to four joint
    /// influences whose weights sum to 1.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="analysis"/> or <paramref name="net"/> is null.</exception>
    /// <exception cref="FormatException">
    /// No joint candidate lands inside the mesh, fewer than two joints are produced,
    /// or the predicted hierarchy is not a single tree covering every joint.
    /// </exception>
    public static RigResult Rig( AnalysisResult analysis, RigNetBundle net,
        float densityThreshold = DefaultDensityThreshold )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( net );
        var mesh = analysis.Mesh;

        // ---- proxy in normalized space ----
        var (proxy, vertexMap) = RigNetInput.Decimate( mesh, ProxyVertexBudget );
        var (normalized, pivot, scale) = RigNetInput.Normalize( proxy.Positions );
        var proxyNormalized = new Mesh.RigMesh
        {
            SourceName = proxy.SourceName,
            Positions = normalized,
            Normals = proxy.Normals,
            Uvs = proxy.Uvs,
            Triangles = proxy.Triangles,
            TriangleTags = proxy.TriangleTags,
            Tags = proxy.Tags,
        };
        var vertexCount = normalized.Length;

        var tplEdges = RigNetInput.TopologyEdges( vertexCount, proxy.Triangles );
        var surfaceGeodesic = SurfaceGeodesic.Build( proxyNormalized );
        var geoEdges = RigNetInput.GeodesicEdges( surfaceGeodesic, vertexCount );
        var voxels = VoxelGrid.Build( proxyNormalized, VoxelResolution );

        var pos = Tensor.From( Flatten( normalized ), vertexCount, 3 );

        // ---- joints (predict_joints) ----
        var offsets = net.JointNet.Forward( pos, tplEdges, geoEdges );
        var attention = net.MaskNet.Forward( pos, tplEdges, geoEdges ).Sigmoid();

        // Keep only displaced vertices that land inside the solid voxelization.
        var insideCandidates = new List<Vector3>();
        var insideAttention = new List<float>();
        for ( var v = 0; v < vertexCount; v++ )
        {
            var candidate = normalized[v] + new Vector3(
                offsets.Data[v * 3], offsets.Data[v * 3 + 1], offsets.Data[v * 3 + 2] );
            if ( IsInside( candidate, voxels ) )
            {
                insideCandidates.Add( candidate );
                insideAttention.Add( attention.Data[v] );
            }
        }
        if ( insideCandidates.Count == 0 )
            throw new FormatException( "RigNet found no joint candidates inside the mesh." );

        var clustered = MeanShift.ExtractJoints(
            insideCandidates.ToArray(), insideAttention.ToArray(), net.Bandwidth, densityThreshold );
        var joints = MstSkeleton.Flip( clustered );
        if ( joints.Length < 2 )
            throw new FormatException( $"RigNet produced {joints.Length} joint(s) - too few for a rig." );

        // ---- hierarchy (predict_skeleton) ----
        // Every unordered joint pair, with two attributes each: length and inside fraction.
        var pairCapacity = joints.Length * ( joints.Length - 1 ) / 2;
        var pairs = new List<(int A, int B)>( pairCapacity );
        var pairAttr = new List<float>( pairCapacity * 2 );
        for ( var a = 0; a < joints.Length; a++ )
            for ( var b = a + 1; b < joints.Length; b++ )
            {
                pairs.Add( (a, b) );
                var samples = MstSkeleton.SampleOnBone( joints[a], joints[b] );
                var inside = samples.Count( s => IsInside( s, voxels ) );
                pairAttr.Add( Vector3.Distance( joints[a], joints[b] ) );
                pairAttr.Add( inside / (samples.Length + 1e-10f) );
            }
        var pairArray = pairs.ToArray();

        var jointTensor = Tensor.From( Flatten( joints ), joints.Length, 3 );
        var rootLogits = net.RootNet.Forward( pos, tplEdges, geoEdges, jointTensor );
        var rootId = 0;
        for ( var i = 1; i < joints.Length; i++ )
            if ( rootLogits.Data[i] > rootLogits.Data[rootId] )
                rootId = i;

        var boneLogits = net.BoneNet.Forward( pos, tplEdges, geoEdges, jointTensor,
            pairArray, Tensor.From( pairAttr.ToArray(), pairArray.Length, 2 ) );
        var cost = MstSkeleton.BuildCostMatrix(
            joints.Length, pairArray, boneLogits.Sigmoid().Data );
        MstSkeleton.IncreaseCostForOutsideBone( cost, joints, p => IsInside( p, voxels ) );
        var (parents, _) = MstSkeleton.PrimMstSymmetry( cost, rootId, joints );

        // ---- weights (predict_skinning) ----
        var bones = SkinPipeline.GetBones( joints, parents );
        var geoDist = VolumetricGeodesic.Compute( proxyNormalized, bones, surfaceGeodesic, voxels );
        var (skinInput, nearestBones, mask) = SkinPipeline.BuildSkinInput( bones, geoDist );
        var logits = net.SkinNet.Forward( pos, skinInput, tplEdges, geoEdges );
        var boneWeights = SkinPipeline.FinalizeWeights(
            logits, nearestBones, mask, bones.Length, tplEdges );

        // ---- assemble: bone weights → joint weights, proxy → full mesh ----
        var skeleton = BuildSkeleton( joints, parents, pivot, scale );
        var jointOrder = skeleton.Order;   // RigNet joint id → skeleton index

        // Influences depend only on the proxy vertex, so pack them once per proxy vertex
        // and copy them to every full-mesh vertex that maps onto it.
        var proxyIndices = new int[vertexCount * MaxInfluences];
        var proxyWeights = new float[vertexCount * MaxInfluences];
        var row = new float[joints.Length];
        for ( var v = 0; v < vertexCount; v++ )
        {
            Array.Clear( row );
            for ( var b = 0; b < bones.Length; b++ )
            {
                // assemble_skel_skin: a bone's weight binds to its starting joint
                // (a leaf bone's to the leaf joint itself).
                if ( boneWeights[v, b] > 0f )
                    row[bones[b].StartJoint] += boneWeights[v, b];
            }
            PackInfluences( row, jointOrder, joints, normalized[v],
                proxyIndices, proxyWeights, v * MaxInfluences );
        }

        var totalVertices = mesh.Positions.Length;
        var indices = new int[totalVertices * MaxInfluences];
        var weights = new float[totalVertices * MaxInfluences];
        for ( var v = 0; v < totalVertices; v++ )
        {
            var source = vertexMap[v] * MaxInfluences;
            Array.Copy( proxyIndices, source, indices, v * MaxInfluences, MaxInfluences );
            Array.Copy( proxyWeights, source, weights, v * MaxInfluences, MaxInfluences );
        }

        return new RigResult
        {
            Skeleton = skeleton.Skeleton,
            Weights = new SkinWeights { BoneIndices = indices, Weights = weights },
            SolverName = "deep-learning/rignet",
            Degraded = false,
            Explanation = $"RigNet predicted {joints.Length} joints "
                + $"({bones.Length} bones) from the mesh's shape.",
        };
    }

    /// <summary>
    /// Writes the strongest (up to <see cref="MaxInfluences"/>) positive entries of
    /// <paramref name="row"/> into <paramref name="indices"/>/<paramref name="weights"/>
    /// starting at <paramref name="offset"/>. Joint ids are remapped through
    /// <paramref name="jointOrder"/>, and weights are normalized to sum to 1.
    /// If the row has no positive weight, the slot binds fully to the joint nearest <paramref name="point"/>.
    /// Unused slots are not written.
    /// </summary>
    static void PackInfluences( float[] row, int[] jointOrder, Vector3[] joints, Vector3 point,
        int[] indices, float[] weights, int offset )
    {
        Span<int> top = stackalloc int[MaxInfluences];
        var count = 0;
        for ( var j = 0; j < row.Length; j++ )
        {
            if ( row[j] <= 0f )
                continue;
            if ( count < MaxInfluences )
            {
                top[count++] = j;
                continue;
            }

            // All slots full: replace the weakest kept influence if this one is stronger.
            var weakest = 0;
            for ( var k = 1; k < MaxInfluences; k++ )
                if ( row[top[k]] < row[top[weakest]] )
                    weakest = k;
            if ( row[j] > row[top[weakest]] )
                top[weakest] = j;
        }

        var total = 0f;
        for ( var k = 0; k < count; k++ )
            total += row[top[k]];

        if ( count == 0 || total <= 0f )
        {
            indices[offset] = jointOrder[FindNearestJoint( joints, point )];
            weights[offset] = 1f;
            return;
        }
        for ( var k = 0; k < count; k++ )
        {
            indices[offset + k] = jointOrder[top[k]];
            weights[offset + k] = row[top[k]] / total;
        }
    }

    /// <summary>Flattens points into an interleaved x, y, z float array.</summary>
    static float[] Flatten( ReadOnlySpan<Vector3> points )
    {
        var flat = new float[points.Length * 3];
        for ( var i = 0; i < points.Length; i++ )
        {
            flat[i * 3] = points[i].X;
            flat[i * 3 + 1] = points[i].Y;
            flat[i * 3 + 2] = points[i].Z;
        }
        return flat;
    }

    /// <summary>inside_check equivalent against our solid voxelization.</summary>
    static bool IsInside( Vector3 point, VoxelGrid voxels )
    {
        var (x, y, z) = voxels.VoxelOf( point );
        return voxels.IsSolid( x, y, z );
    }

    /// <summary>Index of the joint closest to <paramref name="point"/> (first one wins on ties).</summary>
    static int FindNearestJoint( Vector3[] joints, Vector3 point )
    {
        var best = 0;
        for ( var j = 1; j < joints.Length; j++ )
            if ( Vector3.DistanceSquared( joints[j], point )
                < Vector3.DistanceSquared( joints[best], point ) )
                best = j;
        return best;
    }

    /// <summary>
    /// Skeleton in parent-before-child (breadth-first) order, back in mesh space.
    /// Returns the skeleton plus the map RigNet joint id → skeleton index.
    /// </summary>
    /// <exception cref="FormatException">
    /// <paramref name="parents"/> has no root (-1) entry, or not every joint is reachable from it.
    /// </exception>
    static (RigSkeleton Skeleton, int[] Order) BuildSkeleton(
        Vector3[] joints, int[] parents, Vector3 pivot, float scale )
    {
        var root = Array.IndexOf( parents, -1 );
        if ( root < 0 )
            throw new FormatException( "RigNet's skeleton hierarchy has no root joint." );

        // Children per joint, in ascending id order (matches the BFS emission order of a linear scan).
        var children = new List<int>[joints.Length];
        for ( var c = 0; c < parents.Length; c++ )
        {
            var p = parents[c];
            if ( (uint)p < (uint)children.Length )
                ( children[p] ??= new List<int>() ).Add( c );
        }

        var order = new int[joints.Length];   // RigNet joint id → skeleton index
        var skeleton = new RigSkeleton();
        var queue = new Queue<int>();
        queue.Enqueue( root );
        var emitted = 0;
        while ( queue.TryDequeue( out var j ) )
        {
            order[j] = emitted++;
            skeleton.Joints.Add( new RigJoint
            {
                Name = j == root ? "root" : $"joint_{j}",
                Parent = j == root ? -1 : order[parents[j]],
                Position = joints[j] / scale + pivot,
            } );
            if ( children[j] is { } kids )
                foreach ( var c in kids )
                    queue.Enqueue( c );
        }

        // Unreached joints would otherwise keep order == 0 and silently bind to the root.
        if ( emitted != joints.Length )
            throw new FormatException(
                $"RigNet's skeleton hierarchy reaches only {emitted} of {joints.Length} joints from the root." );
        return (skeleton, order);
    }
}