Tokenizer and detokenizer for UniRig skeleton token sequences. It provides discretization helpers, a constrained-decode state machine, an allowed-next-token mask, bone counting, and conversion of token sequences into a UniRigSkeleton with joints, parents, and tails.
namespace AutoRig.Dl.UniRig;
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// Skeleton decoded from a UniRig token sequence, stored as parallel
/// per-joint arrays (entry i of each array describes joint i).
/// </summary>
public sealed class UniRigSkeleton
{
    /// <summary>Joint (head) positions.</summary>
    public required Vector3[] Joints;

    /// <summary>Index of each joint's parent joint; -1 marks the root.</summary>
    public required int[] Parents;

    /// <summary>Tail position of each joint (leaf and branch tails are extruded).</summary>
    public required Vector3[] Tails;
}
/// <summary>
/// UniRig's skeleton tokenizer (transcribed from src/tokenizer/tokenizer_part.py +
/// spec.make_skeleton, config tokenizer_parts_articulationxl_256): 256 coordinate
/// bins in [-1,1], BOS/EOS/branch/pad/spring/part/class tokens, the constrained-
/// decode state machine, and detokenization to joints/parents/tails.
/// </summary>
public sealed class UniRigTokenizer
{
    public const int NumDiscrete = 256;
    public const float RangeLo = -1f, RangeHi = 1f;
    public const int TokenBranch = NumDiscrete + 0; // 256
    public const int TokenBos = NumDiscrete + 1; // 257
    public const int TokenEos = NumDiscrete + 2; // 258
    public const int TokenPad = NumDiscrete + 3; // 259
    public const int TokenSpring = NumDiscrete + 4; // 260
    public const int TokenPartBody = NumDiscrete + 5; // 261 (parts: body 0, hand 1)
    public const int TokenPartHand = NumDiscrete + 6; // 262
    public const int TokenClsNone = NumDiscrete + 7; // 263
    public const int TokenClsVroid = NumDiscrete + 8; // 264
    public const int TokenClsMixamo = NumDiscrete + 9; // 265
    public const int TokenClsArticulationXl = NumDiscrete + 10; // 266
    public const int VocabSize = NumDiscrete + 11; // 267

    // Tail extrusion factor for leaf / branch bones (make_skeleton's extrude_scale).
    const float ExtrudeScale = 0.5f;

    // Squared direction length at or below which an extrusion direction is
    // treated as degenerate and replaced by +Z.
    const float DegenerateLengthSquared = 1e-18f;

    /// <summary>
    /// Maps a continuous coordinate to the bin whose center (as returned by
    /// <see cref="Undiscretize"/>) is nearest, clamped to [0, NumDiscrete - 1].
    /// </summary>
    /// <remarks>
    /// The <c>- 0.5f</c> mirrors the <c>+ 0.5f</c> bin-center offset in
    /// <see cref="Undiscretize"/>; without it odd bin centers round up into the
    /// next bin and <c>Discretize(Undiscretize(b)) != b</c>. Clamping happens in
    /// the float domain so huge/infinite inputs saturate instead of overflowing
    /// the int conversion; NaN maps to bin 0.
    /// </remarks>
    public static int Discretize( float value )
    {
        var t = (value - RangeLo) / (RangeHi - RangeLo) * NumDiscrete - 0.5f;
        var clamped = Math.Clamp( MathF.Round( t ), 0f, NumDiscrete - 1f );
        return float.IsNaN( clamped ) ? 0 : (int)clamped;
    }

    /// <summary>Center of coordinate bin <paramref name="bin"/> within [RangeLo, RangeHi].</summary>
    public static float Undiscretize( int bin )
        => (bin + 0.5f) / NumDiscrete * (RangeHi - RangeLo) + RangeLo;

    enum State
    {
        ExpectBos, ExpectClsOrPartOrJoint, ExpectPartOrJoint,
        ExpectJoint, ExpectJoint2, ExpectJoint3, ExpectBranchOrPartOrJoint,
    }

    // Class tokens other than TokenClsNone (callers test TokenClsNone separately).
    static bool IsCls( int id ) => id is >= TokenClsVroid and <= TokenClsArticulationXl;

    static bool IsPart( int id ) => id is TokenPartBody or TokenPartHand;

    // Coordinate-bin tokens: ids in [0, NumDiscrete).
    static bool IsCoordinate( int id ) => id is >= 0 and < NumDiscrete;

    /// <summary>One step of the constrained-decode state machine.</summary>
    /// <remarks>
    /// A part token seen in ExpectPartOrJoint keeps the state at ExpectPartOrJoint
    /// (so EOS or another part may follow), whereas a part token seen in
    /// ExpectClsOrPartOrJoint or ExpectBranchOrPartOrJoint moves to ExpectJoint.
    /// NOTE(review): confirm this asymmetry against tokenizer_part.py.
    /// </remarks>
    static State Advance( State state, int id ) => state switch
    {
        State.ExpectBos => State.ExpectClsOrPartOrJoint,
        State.ExpectClsOrPartOrJoint => id < NumDiscrete
            ? State.ExpectJoint2
            : id == TokenClsNone || IsCls( id )
                ? State.ExpectPartOrJoint
                : State.ExpectJoint,
        State.ExpectPartOrJoint => id < NumDiscrete ? State.ExpectJoint2 : State.ExpectPartOrJoint,
        State.ExpectJoint2 => State.ExpectJoint3,
        State.ExpectJoint3 => State.ExpectBranchOrPartOrJoint,
        // Branch, part/spring (and any other non-coordinate id) all expect a fresh triple.
        State.ExpectBranchOrPartOrJoint => id == TokenBranch
            ? State.ExpectJoint
            : id < NumDiscrete
                ? State.ExpectJoint2
                : State.ExpectJoint,
        State.ExpectJoint => State.ExpectJoint2,
        _ => throw new FormatException( $"UniRig tokenizer: bad state {state}" ),
    };

    /// <summary>
    /// Allowed next tokens given the sequence so far — the constrained-decode mask
    /// (next_posible_token). Greedy argmax over logits restricted to this set
    /// reproduces UniRig's LogitsProcessor deterministically.
    /// </summary>
    /// <param name="ids">Tokens generated so far (starting with BOS when non-empty).</param>
    /// <returns>Token ids that may legally follow <paramref name="ids"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
    /// <exception cref="FormatException">The sequence does not start with BOS.</exception>
    public static List<int> NextPossibleTokens( IReadOnlyList<int> ids )
    {
        ArgumentNullException.ThrowIfNull( ids );
        if ( ids.Count == 0 )
            return new List<int> { TokenBos };
        var state = State.ExpectBos;
        foreach ( var id in ids )
        {
            if ( state == State.ExpectBos && id != TokenBos )
                throw new FormatException( "UniRig tokenizer: sequence does not start with BOS." );
            state = Advance( state, id );
        }
        var allowed = new List<int>();
        void AddCls()
        {
            allowed.Add( TokenClsNone );
            allowed.Add( TokenClsVroid );
            allowed.Add( TokenClsMixamo );
            allowed.Add( TokenClsArticulationXl );
        }
        void AddPart()
        {
            allowed.Add( TokenSpring );
            allowed.Add( TokenPartBody );
            allowed.Add( TokenPartHand );
        }
        void AddJoint()
        {
            for ( var i = 0; i < NumDiscrete; i++ )
                allowed.Add( i );
        }
        switch ( state )
        {
            case State.ExpectClsOrPartOrJoint:
                AddCls(); AddPart(); AddJoint();
                break;
            case State.ExpectPartOrJoint:
                AddPart(); AddJoint(); allowed.Add( TokenEos );
                break;
            case State.ExpectJoint or State.ExpectJoint2 or State.ExpectJoint3:
                AddJoint();
                break;
            case State.ExpectBranchOrPartOrJoint:
                AddJoint(); AddPart(); allowed.Add( TokenBranch ); allowed.Add( TokenEos );
                break;
            default:
                throw new FormatException( $"UniRig tokenizer: bad decode state {state}" );
        }
        return allowed;
    }

    /// <summary>Completed bones in a (possibly partial) sequence — the reference's
    /// bones_in_sequence: a bone completes on its 3rd coordinate token, but a
    /// branch's FIRST triple (the re-stated parent position) does NOT count, so
    /// the result equals the number of joints. Counting stops at EOS.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
    /// <exception cref="FormatException">The sequence does not start with BOS.</exception>
    public static int BonesInSequence( IReadOnlyList<int> ids )
    {
        ArgumentNullException.ThrowIfNull( ids );
        var bones = 0;
        var isBranch = false;
        var state = State.ExpectBos;
        foreach ( var id in ids )
        {
            if ( state == State.ExpectBos && id != TokenBos )
                throw new FormatException( "UniRig tokenizer: sequence does not start with BOS." );
            if ( state == State.ExpectBranchOrPartOrJoint && id == TokenBranch )
                isBranch = true;
            var previous = state;
            state = Advance( state, id );
            if ( previous == State.ExpectJoint3 )
            {
                // The triple right after a branch token is the parent position, not a new joint.
                if ( !isBranch )
                    bones++;
                isBranch = false;
            }
            if ( id == TokenEos )
                break;
        }
        return bones;
    }

    /// <summary>detokenize + make_skeleton: token sequence → joints/parents/tails.</summary>
    /// <param name="ids">BOS, body tokens, EOS, then optional trailing PAD tokens.</param>
    /// <returns>The decoded skeleton; its arrays are empty when the body holds no joints.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
    /// <exception cref="FormatException">Missing BOS/EOS, an unexpected token, or a
    /// coordinate triple that is truncated or contains a non-coordinate token.</exception>
    public static UniRigSkeleton Detokenize( IReadOnlyList<int> ids )
    {
        ArgumentNullException.ThrowIfNull( ids );
        if ( ids.Count < 2 || ids[0] != TokenBos )
            throw new FormatException( "UniRig tokenizer: sequence must start with BOS." );
        var end = ids.Count;
        while ( end > 0 && ids[end - 1] == TokenPad )
            end--;
        if ( ids[end - 1] != TokenEos )
            throw new FormatException( "UniRig tokenizer: sequence must end with EOS." );

        // Body tokens are ids[1 .. eos); coordinate reads must never reach EOS or beyond.
        var eos = end - 1;
        var joints = new List<Vector3>();
        var parentPositions = new List<Vector3>();
        var tailsByBone = new Dictionary<int, Vector3>();
        var isBranch = false;
        Vector3? lastJoint = null;

        Vector3 Coords( int at )
        {
            if ( at + 3 > eos )
                throw new FormatException( $"UniRig tokenizer: truncated coordinate triple at token {at}." );
            for ( var k = at; k < at + 3; k++ )
            {
                if ( !IsCoordinate( ids[k] ) )
                    throw new FormatException( $"UniRig tokenizer: expected coordinate token at {k}, got {ids[k]}." );
            }
            return new( Undiscretize( ids[at] ), Undiscretize( ids[at + 1] ), Undiscretize( ids[at + 2] ) );
        }

        var i = 1;
        while ( i < eos )
        {
            var id = ids[i];
            if ( IsCoordinate( id ) )
            {
                Vector3 current;
                if ( isBranch )
                {
                    // Branch: first triple re-states the parent position, second is the joint.
                    parentPositions.Add( Coords( i ) );
                    current = Coords( i + 3 );
                    joints.Add( current );
                    i += 6;
                }
                else
                {
                    current = Coords( i );
                    joints.Add( current );
                    parentPositions.Add( lastJoint ?? current ); // root parents itself
                    i += 3;
                }
                // Within a chain the previous joint's tail is this joint.
                if ( lastJoint is not null )
                    tailsByBone[joints.Count - 2] = current;
                lastJoint = current;
                isBranch = false;
            }
            else if ( id == TokenBranch )
            {
                isBranch = true;
                lastJoint = null;
                i++;
            }
            else if ( id == TokenSpring || IsPart( id ) || IsCls( id ) || id == TokenClsNone )
            {
                i++; // parts/class recorded upstream; geometry unaffected
            }
            else
            {
                throw new FormatException( $"UniRig tokenizer: unexpected token {id}." );
            }
        }

        var parents = AssignParents( joints, parentPositions );
        var tails = ComputeTails( joints, parents, tailsByBone );
        return new UniRigSkeleton
        {
            Joints = joints.ToArray(),
            Parents = parents,
            Tails = tails,
        };
    }

    /// <summary>
    /// make_skeleton parent assignment: joint 0 is the root (-1); every later joint j
    /// takes the earlier joint nearest to its recorded parent position. The scan runs
    /// j-1 → 0 with a strictly-smaller comparison, so the LATER index wins ties
    /// (reversed scan order, matching the reference).
    /// </summary>
    static int[] AssignParents( List<Vector3> joints, List<Vector3> parentPositions )
    {
        var count = joints.Count;
        var parents = new int[count];
        if ( count == 0 )
            return parents;
        parents[0] = -1;
        for ( var j = 1; j < count; j++ )
        {
            var best = -1;
            var bestDistance = float.MaxValue;
            for ( var k = j - 1; k >= 0; k-- )
            {
                var d = Vector3.DistanceSquared( joints[k], parentPositions[j] );
                if ( d < bestDistance )
                {
                    bestDistance = d;
                    best = k;
                }
            }
            parents[j] = best;
        }
        return parents;
    }

    /// <summary>
    /// Tails: a single-child joint uses its recorded chain tail (or its own position
    /// when none was recorded); leaves and non-root branches are extruded by
    /// ExtrudeScale along the incoming bone (+Z when degenerate or parentless);
    /// a branching root gets a +Z tail of ExtrudeScale × its mean child distance.
    /// </summary>
    static Vector3[] ComputeTails( List<Vector3> joints, int[] parents, Dictionary<int, Vector3> tailsByBone )
    {
        var count = joints.Count;
        var childCount = new int[count];
        foreach ( var p in parents )
        {
            if ( p >= 0 )
                childCount[p]++;
        }
        var tails = new Vector3[count];
        for ( var j = 0; j < count; j++ )
        {
            var children = childCount[j];
            var parent = parents[j];
            if ( children == 1 )
            {
                tails[j] = tailsByBone.TryGetValue( j, out var tail ) ? tail : joints[j];
            }
            else if ( children > 1 && parent < 0 )
            {
                // Branching root: no incoming bone, so point up by half the mean child distance.
                var totalLength = 0f;
                for ( var c = 0; c < count; c++ )
                {
                    if ( parents[c] == j )
                        totalLength += Vector3.Distance( joints[j], joints[c] );
                }
                tails[j] = joints[j] + new Vector3( 0f, 0f, ExtrudeScale * (totalLength / children) );
            }
            else
            {
                // Leaf (0 children) or non-root branch (>1 children): extrude along the incoming bone.
                var direction = parent >= 0 ? joints[j] - joints[parent] : default;
                if ( direction.LengthSquared() <= DegenerateLengthSquared )
                    direction = new Vector3( 0f, 0f, 1f );
                tails[j] = joints[j] + direction * ExtrudeScale;
            }
        }
        return tails;
    }
}