// DeepLearningSolver contains the deep-learning rigging inference pipeline: using a RigNetBundle of neural nets, it prepares a decimated, normalized proxy mesh, predicts joints, builds a skeleton with symmetry-aware MST logic, computes skinning weights from volumetric geodesic features, and returns a RigResult with the skeleton and skin weights.
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.RigNet;
using AutoRig.Rig;
using AutoRig.Voxel;
namespace AutoRig.Solve;
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// The four RigNet networks (joint/mask GCNs from the meanshift checkpoint,
/// plus the root, bone and skin nets) together with the meanshift bandwidth.
/// Load once and reuse across rigs.
/// </summary>
public sealed class RigNetBundle
{
    /// <summary>Predicts per-vertex joint offsets (gcn_meanshift checkpoint).</summary>
    public required JointPredNet JointNet;

    /// <summary>Predicts per-vertex attention logits (gcn_meanshift checkpoint).</summary>
    public required JointPredNet MaskNet;

    /// <summary>Meanshift bandwidth stored alongside the joint/mask networks.</summary>
    public required float Bandwidth;

    /// <summary>Scores which predicted joint is the root.</summary>
    public required RootNet RootNet;

    /// <summary>Scores candidate bones between joint pairs.</summary>
    public required BoneNet BoneNet;

    /// <summary>Predicts skinning logits per vertex.</summary>
    public required SkinNet SkinNet;

    /// <summary>
    /// Builds a bundle from the official checkpoints zip (as downloaded from the
    /// catalog), whose entries end in
    /// `{gcn_meanshift,rootnet,bonenet,skinnet}/model_best.pth.tar`.
    /// </summary>
    /// <param name="zipBytes">Raw bytes of the checkpoints archive.</param>
    /// <exception cref="FormatException">A required checkpoint entry is missing.</exception>
    public static RigNetBundle FromCheckpointsZip( byte[] zipBytes )
    {
        var archive = TorchCheckpoint.ReadArchive( zipBytes );

        // Finds the first archive entry for the given network and parses its state dict.
        IReadOnlyDictionary<string, Tensor> LoadStateDict( string network )
        {
            var suffix = $"{network}/model_best.pth.tar";
            foreach ( var key in archive.Keys )
            {
                if ( key.EndsWith( suffix, StringComparison.Ordinal ) )
                    return TorchCheckpoint.Parse( archive[key] );
            }
            throw new FormatException( $"Checkpoints zip has no '{network}/model_best.pth.tar'." );
        }

        // Load order is significant only for which missing entry is reported first.
        var meanshift = RigNetLoader.LoadMeanshift( LoadStateDict( "gcn_meanshift" ) );
        var rootNet = RigNetLoader.LoadRootNet( LoadStateDict( "rootnet" ) );
        var boneNet = RigNetLoader.LoadBoneNet( LoadStateDict( "bonenet" ) );
        var skinNet = RigNetLoader.LoadSkinNet( LoadStateDict( "skinnet" ) );

        return new RigNetBundle
        {
            JointNet = meanshift.JointNet,
            MaskNet = meanshift.MaskNet,
            Bandwidth = meanshift.Bandwidth,
            RootNet = rootNet,
            BoneNet = boneNet,
            SkinNet = skinNet,
        };
    }
}
/// <summary>
/// The deep-learning solver: quick_start.py's inference pipeline in pure C#.
/// Joints from jointnet+masknet meanshift clustering, hierarchy from
/// rootnet/bonenet with the symmetry-aware MST, weights from skinnet over
/// volumetric geodesic features — computed on a ≤3k-vertex proxy in RigNet's
/// normalized y-up space, then transferred back to the full mesh.
/// </summary>
public static class DeepLearningSolver
{
    /// <summary>Vertex budget for the decimated proxy the networks run on.</summary>
    public const int ProxyVertexBudget = 3000;

    /// <summary>Default density threshold passed to the meanshift joint extraction.</summary>
    public const float DefaultDensityThreshold = 1e-5f;

    /// <summary>Resolution passed to <c>VoxelGrid.Build</c> for the proxy's solid voxelization.</summary>
    public const int VoxelResolution = 88;

    /// <summary>Maximum joint influences stored per vertex in the output skin weights.</summary>
    const int MaxInfluences = 4;

    /// <summary>
    /// Rigs the analysed mesh with RigNet: predicts joints, the joint hierarchy
    /// and skin weights on a normalized proxy, then maps them back to the full mesh.
    /// </summary>
    /// <param name="analysis">Analysis whose <c>Mesh</c> is rigged.</param>
    /// <param name="net">The loaded RigNet networks.</param>
    /// <param name="densityThreshold">Density threshold for meanshift joint extraction.</param>
    /// <returns>
    /// A skeleton in mesh space (parent-before-child order) and up to four
    /// normalized joint influences per full-mesh vertex.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="analysis"/> or <paramref name="net"/> is null.</exception>
    /// <exception cref="FormatException">
    /// RigNet found no joint candidates inside the mesh, produced fewer than two
    /// joints, or produced a hierarchy that is not a single tree over all joints.
    /// </exception>
    public static RigResult Rig( AnalysisResult analysis, RigNetBundle net,
        float densityThreshold = DefaultDensityThreshold )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( net );
        var mesh = analysis.Mesh;

        // ---- proxy in normalized space ----
        var (proxy, vertexMap) = RigNetInput.Decimate( mesh, ProxyVertexBudget );
        var (normalized, pivot, scale) = RigNetInput.Normalize( proxy.Positions );
        var proxyNormalized = new Mesh.RigMesh
        {
            SourceName = proxy.SourceName,
            Positions = normalized,
            Normals = proxy.Normals,
            Uvs = proxy.Uvs,
            Triangles = proxy.Triangles,
            TriangleTags = proxy.TriangleTags,
            Tags = proxy.Tags,
        };
        var vertexCount = normalized.Length;
        var tplEdges = RigNetInput.TopologyEdges( vertexCount, proxy.Triangles );
        var surfaceGeodesic = SurfaceGeodesic.Build( proxyNormalized );
        var geoEdges = RigNetInput.GeodesicEdges( surfaceGeodesic, vertexCount );
        var voxels = VoxelGrid.Build( proxyNormalized, VoxelResolution );
        var pos = Tensor.From(
            normalized.SelectMany( p => new[] { p.X, p.Y, p.Z } ).ToArray(), vertexCount, 3 );

        // ---- joints (predict_joints) ----
        var offsets = net.JointNet.Forward( pos, tplEdges, geoEdges );
        var attention = net.MaskNet.Forward( pos, tplEdges, geoEdges ).Sigmoid();
        var insideCandidates = new List<Vector3>();
        var insideAttention = new List<float>();
        for ( var v = 0; v < vertexCount; v++ )
        {
            // Each vertex votes for a joint at its own position plus the predicted offset;
            // only votes landing inside the solid voxelization are kept.
            var candidate = normalized[v] + new Vector3(
                offsets.Data[v * 3], offsets.Data[v * 3 + 1], offsets.Data[v * 3 + 2] );
            if ( IsInside( candidate, voxels ) )
            {
                insideCandidates.Add( candidate );
                insideAttention.Add( attention.Data[v] );
            }
        }
        if ( insideCandidates.Count == 0 )
            throw new FormatException( "RigNet found no joint candidates inside the mesh." );
        var clustered = MeanShift.ExtractJoints(
            insideCandidates.ToArray(), insideAttention.ToArray(), net.Bandwidth, densityThreshold );
        var joints = MstSkeleton.Flip( clustered );
        if ( joints.Length < 2 )
            throw new FormatException( $"RigNet produced {joints.Length} joint(s) - too few for a rig." );

        // ---- hierarchy (predict_skeleton) ----
        // Every unordered joint pair is a candidate bone, described by
        // (length, fraction of samples along it that lie inside the mesh).
        var pairCount = joints.Length * (joints.Length - 1) / 2;
        var pairs = new List<(int A, int B)>( pairCount );
        var pairAttr = new List<float>( pairCount * 2 );
        for ( var a = 0; a < joints.Length; a++ )
        {
            for ( var b = a + 1; b < joints.Length; b++ )
            {
                pairs.Add( (a, b) );
                var samples = MstSkeleton.SampleOnBone( joints[a], joints[b] );
                var inside = samples.Count( s => IsInside( s, voxels ) );
                pairAttr.Add( Vector3.Distance( joints[a], joints[b] ) );
                pairAttr.Add( inside / (samples.Length + 1e-10f) );
            }
        }
        var pairArray = pairs.ToArray();
        var jointTensor = Tensor.From(
            joints.SelectMany( p => new[] { p.X, p.Y, p.Z } ).ToArray(), joints.Length, 3 );

        // Root = argmax of the root logits (first index wins ties).
        var rootLogits = net.RootNet.Forward( pos, tplEdges, geoEdges, jointTensor );
        var rootId = 0;
        for ( var i = 1; i < joints.Length; i++ )
        {
            if ( rootLogits.Data[i] > rootLogits.Data[rootId] )
                rootId = i;
        }

        var boneLogits = net.BoneNet.Forward( pos, tplEdges, geoEdges, jointTensor,
            pairArray, Tensor.From( pairAttr.ToArray(), pairArray.Length, 2 ) );
        var cost = MstSkeleton.BuildCostMatrix(
            joints.Length, pairArray, boneLogits.Sigmoid().Data );
        MstSkeleton.IncreaseCostForOutsideBone( cost, joints, p => IsInside( p, voxels ) );
        var (parents, _) = MstSkeleton.PrimMstSymmetry( cost, rootId, joints );

        // ---- weights (predict_skinning) ----
        var bones = SkinPipeline.GetBones( joints, parents );
        var geoDist = VolumetricGeodesic.Compute( proxyNormalized, bones, surfaceGeodesic, voxels );
        var (skinInput, nearestBones, mask) = SkinPipeline.BuildSkinInput( bones, geoDist );
        var logits = net.SkinNet.Forward( pos, skinInput, tplEdges, geoEdges );
        var boneWeights = SkinPipeline.FinalizeWeights(
            logits, nearestBones, mask, bones.Length, tplEdges );

        // ---- assemble: bone weights → joint weights, proxy → full mesh ----
        var skeleton = BuildSkeleton( joints, parents, pivot, scale );
        var jointOrder = skeleton.Order; // RigNet joint id → skeleton index

        // Every full-mesh vertex copies its proxy vertex's influences, so resolve
        // each proxy vertex once rather than once per full-mesh vertex.
        var proxyIndices = new int[vertexCount * MaxInfluences];
        var proxyWeights = new float[vertexCount * MaxInfluences];
        var jointRow = new float[joints.Length];
        var top = new int[MaxInfluences];
        for ( var p = 0; p < vertexCount; p++ )
        {
            // assemble_skel_skin: a bone's weight binds to its starting joint
            // (a leaf bone's to the leaf joint itself).
            Array.Clear( jointRow );
            for ( var b = 0; b < bones.Length; b++ )
            {
                if ( boneWeights[p, b] > 0f )
                    jointRow[bones[b].StartJoint] += boneWeights[p, b];
            }

            var slot = p * MaxInfluences;
            var count = SelectStrongest( jointRow, top );
            var total = 0f;
            for ( var k = 0; k < count; k++ )
                total += jointRow[top[k]];
            if ( count == 0 || total <= 0f )
            {
                // No positive weight at all: bind rigidly to the nearest joint.
                proxyIndices[slot] = jointOrder[FindNearestJoint( joints, normalized[p] )];
                proxyWeights[slot] = 1f;
                continue;
            }
            for ( var k = 0; k < count; k++ )
            {
                proxyIndices[slot + k] = jointOrder[top[k]];
                proxyWeights[slot + k] = jointRow[top[k]] / total;
            }
        }

        var totalVertices = mesh.Positions.Length;
        var indices = new int[totalVertices * MaxInfluences];
        var weights = new float[totalVertices * MaxInfluences];
        for ( var v = 0; v < totalVertices; v++ )
        {
            var source = vertexMap[v] * MaxInfluences;
            Array.Copy( proxyIndices, source, indices, v * MaxInfluences, MaxInfluences );
            Array.Copy( proxyWeights, source, weights, v * MaxInfluences, MaxInfluences );
        }

        return new RigResult
        {
            Skeleton = skeleton.Skeleton,
            Weights = new SkinWeights { BoneIndices = indices, Weights = weights },
            SolverName = "deep-learning/rignet",
            Degraded = false,
            Explanation = $"RigNet predicted {joints.Length} joints "
                + $"({bones.Length} bones) from the mesh's shape.",
        };
    }

    /// <summary>
    /// Writes into <paramref name="top"/> the indices of up to <c>top.Length</c>
    /// largest positive entries of <paramref name="row"/> (not sorted) and
    /// returns how many slots were filled. Ties keep the earlier index.
    /// </summary>
    static int SelectStrongest( float[] row, int[] top )
    {
        var count = 0;
        for ( var j = 0; j < row.Length; j++ )
        {
            if ( row[j] <= 0f )
                continue;
            if ( count < top.Length )
            {
                top[count++] = j;
                continue;
            }
            // Buffer full: replace the weakest kept entry if this one is strictly stronger.
            var weakest = 0;
            for ( var k = 1; k < top.Length; k++ )
            {
                if ( row[top[k]] < row[top[weakest]] )
                    weakest = k;
            }
            if ( row[j] > row[top[weakest]] )
                top[weakest] = j;
        }
        return count;
    }

    /// <summary>inside_check equivalent against our solid voxelization.</summary>
    static bool IsInside( Vector3 point, VoxelGrid voxels )
    {
        var (x, y, z) = voxels.VoxelOf( point );
        return voxels.IsSolid( x, y, z );
    }

    /// <summary>Index of the joint closest to <paramref name="point"/> (first wins ties).</summary>
    static int FindNearestJoint( Vector3[] joints, Vector3 point )
    {
        var best = 0;
        for ( var j = 1; j < joints.Length; j++ )
        {
            if ( Vector3.DistanceSquared( joints[j], point )
                < Vector3.DistanceSquared( joints[best], point ) )
                best = j;
        }
        return best;
    }

    /// <summary>
    /// Skeleton in parent-before-child (breadth-first) order, back in mesh space.
    /// </summary>
    /// <returns>The skeleton and, per RigNet joint id, that joint's skeleton index.</returns>
    /// <exception cref="FormatException">
    /// <paramref name="parents"/> has no root (-1) entry or does not connect every joint to the root.
    /// </exception>
    static (RigSkeleton Skeleton, int[] Order) BuildSkeleton(
        Vector3[] joints, int[] parents, Vector3 pivot, float scale )
    {
        var root = Array.IndexOf( parents, -1 );
        if ( root < 0 )
            throw new FormatException( "RigNet's skeleton has no root joint." );

        var order = new int[joints.Length]; // RigNet joint id → skeleton index
        var skeleton = new RigSkeleton();
        var queue = new Queue<int>();
        queue.Enqueue( root );
        var emitted = 0;
        while ( queue.Count > 0 )
        {
            var j = queue.Dequeue();
            order[j] = emitted++;
            skeleton.Joints.Add( new RigJoint
            {
                Name = j == root ? "root" : $"joint_{j}",
                Parent = j == root ? -1 : order[parents[j]],
                // Undo the proxy normalization.
                Position = joints[j] / scale + pivot,
            } );
            for ( var c = 0; c < parents.Length; c++ )
            {
                if ( parents[c] == j )
                    queue.Enqueue( c );
            }
        }

        // Unreached joints would otherwise all alias skeleton index 0 and
        // silently corrupt the skin weights.
        if ( emitted != joints.Length )
            throw new FormatException(
                $"RigNet's skeleton connects only {emitted} of {joints.Length} joints to the root." );
        return (skeleton, order);
    }
}