// Solver for the "SkinTokens" token-rig format: encodes a sampled point cloud with a learned mesh encoder, runs a constrained two-phase greedy decode to produce a skeleton and per-bone skin codes, decodes the skin codes with a VAE (falling back to a geodesic skinner when they are ragged), and returns a RigResult holding the skeleton and weights.
using AutoRig.Analyze;
using AutoRig.Dl;
using AutoRig.Dl.Nn;
using AutoRig.Dl.UniRig;
using AutoRig.Rig;
using AutoRig.Solve.Organic;
using AutoRig.Voxel;
namespace AutoRig.Solve;
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// SkinTokens (TokenRig) inference, skeleton stage golden-verified vs PyTorch:
/// perceiver conditioning over the full 54000-point cloud, then a constrained
/// GREEDY two-phase decode on the Qwen3 backbone (the reference samples with
/// beams). Phase 1 is UniRig's skeleton grammar; the grammar EOS (258) switches
/// to phase 2, where FSQ skin-codebook ids stream until the unified EOS is
/// FORCED at exactly (length − switchIndex) == joints · 4 — reference behavior
/// verbatim, including its off-by-one (the unified EOS occupies the last skin
/// slot). When the skin phase yields exactly joints · tokensPerSkin ids they are
/// decoded by the FSQ skin-VAE (<see cref="NeuralSkin"/>); otherwise (a ragged
/// sequence, e.g. the token budget ran out mid-skin) weights come from the
/// geodesic <see cref="OrganicSkinner"/>.
/// </summary>
public static class SkinTokensSolver
{
    /// <summary>Reference context is 3192 positions; 512 go to the mesh prefix.</summary>
    public const int MaxTokens = 2680;

    /// <summary>SkinTokens' cls map differs from UniRig's, but "articulation"
    /// (the open-corpus class the demo feeds every upload) lands on id 266.</summary>
    public const int ClsArticulation = UniRigTokenizer.TokenClsArticulationXl;

    /// <summary>Rigs with at most this many bones use a per-vertex stack buffer
    /// in <see cref="NeuralSkin"/>; larger rigs use a heap buffer so that no
    /// bone is ever dropped.</summary>
    const int StackBoneLimit = 64;

    /// <summary>Optional sink for human-readable progress messages.</summary>
    public static Action<string> Progress { get; set; }

    /// <summary>
    /// Runs the SkinTokens pipeline on an analyzed mesh: sample and encode the
    /// point cloud, greedily decode skeleton + skin tokens, rebuild the skeleton
    /// in mesh space (parent-before-child), and produce skin weights.
    /// </summary>
    /// <param name="analysis">Analysis whose <c>Mesh</c> is rigged.</param>
    /// <param name="model">Loaded SkinTokens model.</param>
    /// <returns>Skeleton plus skin weights (neural when the skin tokens are
    /// complete, geodesic otherwise); <c>Explanation</c> says which.</returns>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    /// <exception cref="FormatException">The decoded skeleton has fewer than two
    /// joints, has no root, or is not a single connected tree.</exception>
    public static RigResult Rig( AnalysisResult analysis, SkinTokensModel model )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( model );
        var mesh = analysis.Mesh;

        // ---- conditioning ----
        Progress?.Invoke( "input: sampling 54000-point cloud" );
        var prepared = SkinTokensInput.Prepare( mesh );
        Progress?.Invoke( "perceiver: encoding (512 queries over the full cloud)" );
        var latents = model.MeshEncoder.Encode(
            prepared.Points, prepared.Normals, prepared.SampledIndices );
        var prefix = TransformerOps.RmsNorm(
            latents.MatMul( model.OutputProjWeight.Transposed ).Add( model.OutputProjBias ),
            model.OutputNormGamma );

        // ---- two-phase constrained greedy decode ----
        var ids = new List<int> { UniRigTokenizer.TokenBos, ClsArticulation };
        var cache = new OptKvCache( model.Decoder.Layers.Length );
        var logits = model.Decoder.Forward(
            prefix.Concat( model.Decoder.EmbedTokens( ids.ToArray() ), 0 ),
            cache, positionStart: 0 );
        var switchAt = -1;
        while ( ids.Count < MaxTokens )
        {
            // Only the logits of the newest position matter for greedy choice.
            var lastRow = (logits.Shape[0] - 1) * logits.Shape[1];
            int best;
            if ( switchAt < 0 )
            {
                // Phase 1: argmax restricted to what the skeleton grammar allows.
                var allowed = UniRigTokenizer.NextPossibleTokens( ids );
                best = allowed[0];
                foreach ( var id in allowed )
                    if ( logits.Data[lastRow + id] > logits.Data[lastRow + best] )
                        best = id;
            }
            else if ( ForcesUnifiedEos( ids, switchAt, model.TokensPerSkin ) )
            {
                best = model.UnifiedEos;
            }
            else
            {
                best = model.SwitchToken; // argmax over [258, unifiedEos)
                for ( var id = model.SwitchToken; id < model.UnifiedEos; id++ )
                    if ( logits.Data[lastRow + id] > logits.Data[lastRow + best] )
                        best = id;
            }
            ids.Add( best );
            if ( switchAt < 0 && best == model.SwitchToken )
            {
                switchAt = ids.Count - 1;
                Progress?.Invoke( $"decode: skeleton closed at {ids.Count} tokens; skin phase" );
            }
            if ( best == model.UnifiedEos )
                break;
            if ( ids.Count % 64 == 0 )
                Progress?.Invoke( $"decode: {ids.Count} tokens" );
            logits = model.Decoder.Forward(
                model.Decoder.EmbedTokens( new[] { best } ),
                cache, positionStart: cache.Length );
        }
        if ( switchAt < 0 )
        {
            ids.Add( UniRigTokenizer.TokenEos ); // budget hit — close the skeleton
            switchAt = ids.Count - 1;
        }
        var decoded = UniRigTokenizer.Detokenize( ids.Take( switchAt + 1 ).ToList() );
        if ( decoded.Joints.Length < 2 )
            throw new FormatException(
                $"SkinTokens produced {decoded.Joints.Length} joint(s) - too few for a rig." );

        // ---- skeleton back to mesh space, parent-before-child ----
        var joints = decoded.Joints
            .Select( p => p * prepared.Scale + prepared.Center ).ToArray();
        var (skeleton, order) = BuildOrdered( joints, decoded.Parents );
        skeleton.Validate();

        // ---- NEURAL skinning: FSQ codes → VAE → per-vertex weights ----
        var jointCount = decoded.Joints.Length;
        // When the loop ended on the forced unified EOS and BonesInSequence agreed
        // with jointCount, that EOS is the final element taken here (the
        // reference's off-by-one: it fills the last skin slot).
        // NOTE(review): confirm DequantizeCodes is meant to receive that id.
        var skinIds = ids.Skip( switchAt + 1 )
            .Take( jointCount * model.TokensPerSkin )
            .Select( id => id - model.SkeletonVocab )
            .ToList();
        SkinWeights weights = null;
        if ( skinIds.Count == jointCount * model.TokensPerSkin )
        {
            Progress?.Invoke( $"skinning: FSQ VAE, {jointCount} bones" );
            weights = NeuralSkin( model, prepared, mesh, skinIds, jointCount, order );
        }
        var usedNeuralSkin = weights is not null;
        if ( !usedNeuralSkin )
        {
            // Reference behavior on a ragged sequence: no skin prediction.
            Progress?.Invoke( "skinning: geodesic fallback (ragged skin tokens)" );
            var resolution = mesh.TriangleCount < 50_000 ? 96 : 64;
            var grid = VoxelGrid.Build( mesh, resolution );
            weights = OrganicSkinner.Skin( mesh, grid, skeleton );
        }
        return new RigResult
        {
            Skeleton = skeleton,
            Weights = weights,
            SolverName = "deep-learning/skintokens",
            Degraded = false,
            // Report the skinner that actually produced the weights.
            Explanation = usedNeuralSkin
                ? $"SkinTokens predicted {skeleton.Joints.Count} joints "
                  + "with NEURAL skin weights (FSQ skin-VAE; greedy decode)."
                : $"SkinTokens predicted {skeleton.Joints.Count} joints; skin tokens "
                  + "were ragged, so weights come from the geodesic skinner (greedy decode).",
        };
    }

    /// <summary>tokenrig.decode: per-bone FSQ codes → VAE sigmoids on query
    /// points → normalize across bones (asset.normalize_skin) → 4 influences.
    /// Queries are mesh vertices (a strided subset past 4096, IDW-8 transfer —
    /// the reference queries its sampled cloud and kd-transfers likewise).</summary>
    /// <param name="model">Model providing the skin VAE.</param>
    /// <param name="prepared">Normalized point cloud plus the center/scale used
    /// to map mesh positions into the model's frame.</param>
    /// <param name="mesh">Mesh whose vertices receive weights.</param>
    /// <param name="skinIds">Codebook ids, <paramref name="jointCount"/> ·
    /// <c>TokensPerSkin</c> of them, grouped per bone in decode order.</param>
    /// <param name="jointCount">Number of decoded bones.</param>
    /// <param name="order">Decode-order → BFS-order bone index map.</param>
    /// <returns>Up to four normalized influences per vertex, indexed in BFS order.</returns>
    static SkinWeights NeuralSkin(
        SkinTokensModel model, SkinTokensInput.Prepared prepared,
        AutoRig.Mesh.RigMesh mesh, List<int> skinIds, int jointCount, int[] order )
    {
        // Cond latents: 384 queries fps'd from a strided 1536-point subset
        // (the reference subsamples unseeded — deterministic here).
        const int CondTokens = 384;
        var pre = new int[CondTokens * 4];
        for ( var i = 0; i < pre.Length; i++ )
            pre[i] = (int)((long)i * prepared.Points.Length / pre.Length);
        var preXyz = new float[pre.Length * 3];
        for ( var i = 0; i < pre.Length; i++ )
        {
            var p = prepared.Points[pre[i]];
            preXyz[i * 3] = p.X;
            preXyz[i * 3 + 1] = p.Y;
            preXyz[i * 3 + 2] = p.Z;
        }
        var fps = AutoRig.Dl.RigNet.PointNet.FarthestPointSample(
            Tensor.From( preXyz, pre.Length, 3 ), ratio: 0.25f );
        var condIndices = fps.Select( i => pre[i] ).ToArray();
        var cond = model.Vae.EncodeCond( prepared.Points, prepared.Normals, condIndices );

        // Query points: normalized mesh vertices (subset when large).
        var vertexCount = mesh.Positions.Length;
        var queryCount = Math.Min( vertexCount, 4096 );
        var queryVertex = new int[queryCount];
        for ( var i = 0; i < queryCount; i++ )
            queryVertex[i] = (int)((long)i * vertexCount / queryCount);
        var queryPoints = new Vector3[queryCount];
        var queryNormals = new Vector3[queryCount];
        for ( var i = 0; i < queryCount; i++ )
        {
            var v = queryVertex[i];
            queryPoints[i] = (mesh.Positions[v] - prepared.Center) / prepared.Scale;
            queryNormals[i] = v < mesh.Normals.Length ? mesh.Normals[v] : Vector3.UnitZ;
        }
        var queryEmbed = SkinVae.Embed( queryPoints, queryNormals );

        // Per bone: codes → weights on the query set.
        var perBone = new float[jointCount][];
        for ( var j = 0; j < jointCount; j++ )
        {
            var codes = model.Vae.DequantizeCodes(
                skinIds.Skip( j * model.TokensPerSkin ).Take( model.TokensPerSkin ).ToList() );
            perBone[j] = model.Vae.DecodeBone( codes, cond, queryEmbed );
            Progress?.Invoke( $"skinning: bone {j + 1}/{jointCount}" );
        }

        // Per vertex: nearest queried vertices (IDW-8 when subsetted), then
        // normalize across bones and keep the engine's 4 influences.
        var boneIndices = new int[vertexCount * 4];
        var boneWeights = new float[vertexCount * 4];
        Concurrency.For( 0, vertexCount, v =>
        {
            // Every decoded bone competes for the top-4; small rigs stay on the
            // stack, larger ones spill to the heap instead of being truncated.
            Span<float> row = jointCount <= StackBoneLimit
                ? stackalloc float[StackBoneLimit]
                : new float[jointCount];
            var bones = jointCount;
            if ( queryCount == vertexCount )
            {
                for ( var j = 0; j < bones; j++ )
                    row[j] = perBone[j][v];
            }
            else
            {
                var p = (mesh.Positions[v] - prepared.Center) / prepared.Scale;
                Span<int> nearest = stackalloc int[8];
                Span<float> distance = stackalloc float[8];
                var used = 0;
                for ( var q = 0; q < queryCount; q++ )
                {
                    var d = Vector3.DistanceSquared( p, queryPoints[q] );
                    if ( used < 8 )
                    {
                        nearest[used] = q;
                        distance[used++] = d;
                    }
                    else
                    {
                        // Replace the farthest of the current 8 if q is closer.
                        var worst = 0;
                        for ( var k = 1; k < 8; k++ )
                            if ( distance[k] > distance[worst] )
                                worst = k;
                        if ( d < distance[worst] )
                        {
                            nearest[worst] = q;
                            distance[worst] = d;
                        }
                    }
                }
                // Reuse the distance buffer for inverse-distance weights.
                float totalInverse = 0;
                for ( var k = 0; k < used; k++ )
                {
                    distance[k] = 1f / (MathF.Sqrt( distance[k] ) + 1e-8f);
                    totalInverse += distance[k];
                }
                for ( var j = 0; j < bones; j++ )
                {
                    float value = 0;
                    for ( var k = 0; k < used; k++ )
                        value += distance[k] * perBone[j][nearest[k]];
                    row[j] = value / totalInverse;
                }
            }

            // normalize_skin + top-4 (engine limit), remapped to BFS order.
            Span<int> top = stackalloc int[4];
            Span<float> topWeight = stackalloc float[4];
            var kept = 0;
            for ( var j = 0; j < bones; j++ )
            {
                if ( kept < 4 )
                {
                    top[kept] = j;
                    topWeight[kept++] = row[j];
                }
                else
                {
                    var weakest = 0;
                    for ( var k = 1; k < 4; k++ )
                        if ( topWeight[k] < topWeight[weakest] )
                            weakest = k;
                    if ( row[j] > topWeight[weakest] )
                    {
                        top[weakest] = j;
                        topWeight[weakest] = row[j];
                    }
                }
            }
            float total = 0;
            for ( var k = 0; k < kept; k++ )
                total += topWeight[k];
            if ( total <= 1e-8f )
            {
                // Degenerate vertex: bind fully to the first kept bone.
                topWeight[0] = 1f;
                total = 1f;
                kept = 1;
            }
            for ( var k = 0; k < kept; k++ )
            {
                boneIndices[v * 4 + k] = order[top[k]];
                boneWeights[v * 4 + k] = topWeight[k] / total;
            }
        } );
        return new SkinWeights { BoneIndices = boneIndices, Weights = boneWeights };
    }

    /// <summary>The reference's VocabSwitchingLogitsProcessor termination rule:
    /// with 258 at <paramref name="switchAt"/> in the FULL sequence (start
    /// tokens included), the unified EOS is banned until
    /// (length − switchAt) == joints · tokensPerSkin, then forced.</summary>
    internal static bool ForcesUnifiedEos(
        IReadOnlyList<int> ids, int switchAt, int tokensPerSkin )
        => ids.Count - switchAt
            == UniRigTokenizer.BonesInSequence( ids ) * tokensPerSkin;

    /// <summary>
    /// Re-emits the decoded joints breadth-first from the root so that every
    /// parent precedes its children.
    /// </summary>
    /// <param name="joints">Mesh-space joint positions in decode order.</param>
    /// <param name="parents">Parent index per decode-order joint; -1 marks the root.</param>
    /// <returns>The BFS-ordered skeleton, and <c>Order</c> mapping each
    /// decode-order index to its BFS index.</returns>
    /// <exception cref="FormatException">There is no root, or some joints are
    /// unreachable from the first root (extra roots or cycles). Such joints would
    /// otherwise be dropped from the skeleton while their skin weights were
    /// silently remapped onto the root.</exception>
    static (RigSkeleton Skeleton, int[] Order) BuildOrdered( Vector3[] joints, int[] parents )
    {
        var root = Array.IndexOf( parents, -1 );
        if ( root < 0 )
            throw new FormatException( "SkinTokens skeleton has no root joint." );
        var order = new int[joints.Length];
        var skeleton = new RigSkeleton();
        var queue = new Queue<int>();
        queue.Enqueue( root );
        var emitted = 0;
        while ( queue.Count > 0 )
        {
            var j = queue.Dequeue();
            order[j] = emitted++;
            skeleton.Joints.Add( new RigJoint
            {
                Name = j == root ? "root" : $"bone_{j}",
                // BFS guarantees the parent was already emitted.
                Parent = j == root ? -1 : order[parents[j]],
                Position = joints[j],
            } );
            for ( var c = 0; c < parents.Length; c++ )
                if ( c != root && parents[c] == j )
                    queue.Enqueue( c );
        }
        if ( emitted != joints.Length )
            throw new FormatException(
                $"SkinTokens skeleton is not a single tree: {joints.Length - emitted} of "
                + $"{joints.Length} joint(s) are unreachable from the root." );
        return (skeleton, order);
    }
}