Solver component that runs UniRig skeleton inference for a mesh. It encodes a sampled point cloud, runs an autoregressive greedy decode of tokenized skeletons through a decoder model, detokenizes to joint positions, constructs a parent-before-child RigSkeleton, and computes interim skin weights via a geodesic voxel skinner.
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;
using AutoRig.Rig;
using AutoRig.Solve.Organic;
using AutoRig.Voxel;
namespace AutoRig.Solve;
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// UniRig inference (skeleton stage, golden-verified vs PyTorch): point-cloud
/// perceiver conditioning, constrained GREEDY autoregressive skeleton decode
/// (deterministic; the reference samples with beams), detokenization back to mesh
/// space. Skin weights come from our geodesic voxel skinner until UniRig's skin
/// network is ported (documented interim).
/// </summary>
public static class UniRigSolver
{
    /// <summary>Decode budget: ~2900 tokens fit after the 1024-latent prefix; cap
    /// well below for interactivity (~256 bones). The budget includes the leading BOS
    /// token; when it is exhausted an EOS is appended, so the final sequence can be
    /// one token longer.</summary>
    public const int MaxTokens = 1024;

    /// <summary>Optional progress sink (stage names + decode ticks) — long runs
    /// were unobservable, which made gate diagnosis guesswork.</summary>
    /// <remarks>Static, so it is shared by every caller. It is invoked synchronously
    /// on the thread running <see cref="Rig"/>.</remarks>
    public static Action<string> Progress { get; set; }

    /// <summary>
    /// Runs the UniRig skeleton stage on <c>analysis.Mesh</c>. It encodes a sampled
    /// point cloud, greedily decodes a constrained token sequence, detokenizes it into
    /// mesh-space joints, and skins the result with the geodesic voxel skinner.
    /// </summary>
    /// <param name="analysis">Analysis result whose <c>Mesh</c> is rigged.</param>
    /// <param name="model">Loaded UniRig skeleton model (mesh encoder, output
    /// projection and decoder).</param>
    /// <returns>A non-degraded <see cref="RigResult"/> with the decoded skeleton and
    /// interim skin weights.</returns>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    /// <exception cref="FormatException">The decode cannot form a rig: the token
    /// grammar dead-ends, fewer than two joints are produced, or the joint/parent
    /// data has no root or mismatched lengths.</exception>
    public static RigResult Rig( AnalysisResult analysis, UniRigSkeletonModel model )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( model );
        var mesh = analysis.Mesh;

        // ---- conditioning ----
        Progress?.Invoke( "input: sampling point cloud" );
        var prepared = UniRigInput.Prepare( mesh );
        Progress?.Invoke( "perceiver: encoding 4096-point cloud" );
        var latents = model.MeshEncoder.Encode(
            prepared.PrePoints, prepared.PreNormals, prepared.SampledIndices );
        // Linear projection of the perceiver latents; the result is the decoder prefix.
        var prefix = latents.MatMul( model.OutputProjWeight.Transposed )
            .Add( model.OutputProjBias );

        // ---- constrained greedy decode (no cls, like the reference's no_cls path) ----
        var ids = new List<int> { UniRigTokenizer.TokenBos };
        var cache = new OptKvCache( model.Decoder.Layers.Length );
        var logits = model.Decoder.Forward(
            prefix.Concat( model.Decoder.EmbedTokens( new[] { UniRigTokenizer.TokenBos } ), 0 ),
            cache, positionStart: 0 );
        while ( ids.Count < MaxTokens )
        {
            var allowed = UniRigTokenizer.NextPossibleTokens( ids );
            // Only the logits of the last position predict the next token.
            var lastRow = (logits.Shape[0] - 1) * logits.Shape[1];

            // Arg-max over the grammar-allowed tokens; ties keep the first candidate.
            var best = 0;
            var hasCandidate = false;
            foreach ( var id in allowed )
            {
                if ( !hasCandidate || logits.Data[lastRow + id] > logits.Data[lastRow + best] )
                {
                    best = id;
                    hasCandidate = true;
                }
            }
            if ( !hasCandidate )
                throw new FormatException(
                    $"UniRig decode reached a state with no allowed tokens after {ids.Count} token(s)." );

            ids.Add( best );
            if ( ids.Count % 64 == 0 )
                Progress?.Invoke( $"decode: {ids.Count} tokens" );
            // Stop before paying for a decoder step whose logits would never be read.
            if ( best == UniRigTokenizer.TokenEos || ids.Count >= MaxTokens )
                break;
            logits = model.Decoder.Forward(
                model.Decoder.EmbedTokens( new[] { best } ),
                cache, positionStart: cache.Length );
        }
        // Budget hit: close the sequence.
        // NOTE(review): a truncated sequence may stop mid-joint; whether Detokenize
        // tolerates that depends on UniRigTokenizer — confirm.
        if ( ids[^1] != UniRigTokenizer.TokenEos )
            ids.Add( UniRigTokenizer.TokenEos );

        var decoded = UniRigTokenizer.Detokenize( ids );
        if ( decoded.Joints.Length < 2 )
            throw new FormatException(
                $"UniRig produced {decoded.Joints.Length} joint(s) - too few for a rig." );

        // ---- skeleton back to mesh space, parent-before-child ----
        // Undo the normalization applied by UniRigInput.Prepare (scale, then re-center).
        var joints = decoded.Joints
            .Select( p => p * prepared.Scale + prepared.Center ).ToArray();
        var (skeleton, order) = BuildOrdered( joints, decoded.Parents );
        var dropped = order.Count( o => o < 0 );
        if ( dropped > 0 )
            Progress?.Invoke( $"skeleton: dropped {dropped} joint(s) not reachable from the root" );
        skeleton.Validate();

        // ---- interim skinning: geodesic voxel weights on the decoded skeleton ----
        Progress?.Invoke(
            $"skinning: {skeleton.Joints.Count} joints (decoded {ids.Count} tokens)" );
        // Finer voxel grid for meshes under 50k triangles, coarser above
        // (presumably to bound voxelization/skinning cost).
        var resolution = mesh.TriangleCount < 50_000 ? 96 : 64;
        var grid = VoxelGrid.Build( mesh, resolution );
        var weights = OrganicSkinner.Skin( mesh, grid, skeleton );
        return new RigResult
        {
            Skeleton = skeleton,
            Weights = weights,
            SolverName = "deep-learning/unirig",
            Degraded = false,
            Explanation = $"UniRig predicted {skeleton.Joints.Count} joints "
                + "(greedy decode; weights from geodesic skinning).",
        };
    }

    /// <summary>
    /// Builds a <see cref="RigSkeleton"/> by emitting joints breadth-first from the
    /// root, so every parent appears before its children. Siblings are emitted in
    /// ascending decoded index.
    /// </summary>
    /// <param name="joints">Joint positions, indexed like <paramref name="parents"/>.</param>
    /// <param name="parents">Parent index per joint. <c>-1</c> marks a root; the first
    /// such joint is used as the skeleton root.</param>
    /// <returns>The skeleton, plus <c>Order</c>. <c>Order</c> maps each decoded joint index
    /// to its index in the skeleton, or to <c>-1</c> when the joint is unreachable from
    /// the root and is therefore omitted.</returns>
    /// <exception cref="FormatException">The array lengths differ, or no joint has
    /// parent <c>-1</c>.</exception>
    static (RigSkeleton Skeleton, int[] Order) BuildOrdered( Vector3[] joints, int[] parents )
    {
        if ( parents.Length != joints.Length )
            throw new FormatException(
                $"UniRig produced {joints.Length} joint(s) but {parents.Length} parent link(s)." );
        var root = Array.IndexOf( parents, -1 );
        if ( root < 0 )
            throw new FormatException( "UniRig produced a skeleton with no root joint." );

        // Child lists are built once, in ascending joint index. The root, extra roots
        // and out-of-range parents are never attached, so those joints stay unreachable.
        var children = new List<int>[joints.Length];
        for ( var c = 0; c < parents.Length; c++ )
        {
            var p = parents[c];
            if ( c == root || p < 0 || p >= parents.Length )
                continue;
            (children[p] ??= new List<int>()).Add( c );
        }

        var order = new int[joints.Length];
        Array.Fill( order, -1 );
        var skeleton = new RigSkeleton();
        var queue = new Queue<int>();
        queue.Enqueue( root );
        var emitted = 0;
        while ( queue.Count > 0 )
        {
            var j = queue.Dequeue();
            order[j] = emitted++;
            skeleton.Joints.Add( new RigJoint
            {
                Name = j == root ? "root" : $"bone_{j}",
                // BFS guarantees the parent was already emitted, so order[...] is valid.
                Parent = j == root ? -1 : order[parents[j]],
                Position = joints[j],
            } );
            if ( children[j] is { } kids )
                foreach ( var c in kids )
                    queue.Enqueue( c );
        }
        return (skeleton, order);
    }
}