AutoRig/Dl/UniRig/UniRigTokenizer.cs

Tokenizer and detokenizer for UniRig skeleton tokens. Provides the token constants, discretize/undiscretize helpers, the constrained-decode mask (NextPossibleTokens), a utility that counts completed bones (BonesInSequence), and Detokenize, which converts a token sequence into a UniRigSkeleton (joints, parent indices, tails).

Networking, File Access
namespace AutoRig.Dl.UniRig;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Skeleton decoded from a UniRig token sequence. As produced by
/// <c>UniRigTokenizer.Detokenize</c>, the three arrays are parallel: entry i of
/// each describes joint i.
/// </summary>
public sealed class UniRigSkeleton
{
    /// <summary>Joint (head) positions.</summary>
    public required Vector3[] Joints;

    /// <summary>Parent joint index per joint; the root uses -1.</summary>
    public required int[] Parents;

    /// <summary>Tail position per joint; leaf and branch tails are extruded.</summary>
    public required Vector3[] Tails;
}

/// <summary>
/// UniRig's skeleton tokenizer (transcribed from src/tokenizer/tokenizer_part.py +
/// spec.make_skeleton, config tokenizer_parts_articulationxl_256): 256 coordinate
/// bins in [-1,1], BOS/EOS/branch/pad/spring/part/class tokens, the constrained-
/// decode state machine, and detokenization to joints/parents/tails.
/// </summary>
public sealed class UniRigTokenizer
{
    /// <summary>Number of coordinate bins; token ids below this value are coordinates.</summary>
    public const int NumDiscrete = 256;

    /// <summary>Continuous coordinate range covered by the bins.</summary>
    public const float RangeLo = -1f, RangeHi = 1f;

    public const int TokenBranch = NumDiscrete + 0;   // 256
    public const int TokenBos = NumDiscrete + 1;      // 257
    public const int TokenEos = NumDiscrete + 2;      // 258
    public const int TokenPad = NumDiscrete + 3;      // 259
    public const int TokenSpring = NumDiscrete + 4;   // 260
    public const int TokenPartBody = NumDiscrete + 5; // 261 (parts: body 0, hand 1)
    public const int TokenPartHand = NumDiscrete + 6; // 262
    public const int TokenClsNone = NumDiscrete + 7;  // 263
    public const int TokenClsVroid = NumDiscrete + 8;         // 264
    public const int TokenClsMixamo = NumDiscrete + 9;        // 265
    public const int TokenClsArticulationXl = NumDiscrete + 10; // 266
    public const int VocabSize = NumDiscrete + 11;    // 267

    /// <summary>
    /// Maps a continuous coordinate to a bin. The value is scaled from
    /// [RangeLo, RangeHi] onto [0, NumDiscrete], rounded with MathF.Round
    /// (ties to even), and clamped to [0, NumDiscrete - 1].
    /// </summary>
    /// <remarks>
    /// NOTE(review): this is not an exact inverse of <see cref="Undiscretize"/>.
    /// A bin centre scales to bin + 0.5 before rounding, so it may land in the
    /// next bin. It is left unchanged pending confirmation against the reference
    /// discretize, because encoded tokens must match what the model was trained on.
    /// </remarks>
    public static int Discretize( float value )
    {
        var t = (value - RangeLo) / (RangeHi - RangeLo) * NumDiscrete;
        return Math.Clamp( (int)MathF.Round( t ), 0, NumDiscrete - 1 );
    }

    /// <summary>Maps a bin index back to the centre of its bin in [RangeLo, RangeHi].</summary>
    public static float Undiscretize( int bin )
        => (bin + 0.5f) / NumDiscrete * (RangeHi - RangeLo) + RangeLo;

    /// <summary>States of the constrained-decode grammar (one per expected token kind).</summary>
    enum State
    {
        ExpectBos, ExpectClsOrPartOrJoint, ExpectPartOrJoint,
        ExpectJoint, ExpectJoint2, ExpectJoint3, ExpectBranchOrPartOrJoint,
    }

    // TokenClsNone is handled separately by callers; this covers the named classes.
    static bool IsCls( int id ) => id is >= TokenClsVroid and <= TokenClsArticulationXl;
    static bool IsPart( int id ) => id is TokenPartBody or TokenPartHand;

    /// <summary>Grammar transition: the state after consuming <paramref name="id"/> in <paramref name="state"/>.</summary>
    static State Advance( State state, int id ) => state switch
    {
        State.ExpectBos => State.ExpectClsOrPartOrJoint,
        State.ExpectClsOrPartOrJoint => id < NumDiscrete
            ? State.ExpectJoint2
            : id == TokenClsNone || IsCls( id )
                ? State.ExpectPartOrJoint
                : State.ExpectJoint,
        State.ExpectPartOrJoint => id < NumDiscrete ? State.ExpectJoint2 : State.ExpectPartOrJoint,
        State.ExpectJoint2 => State.ExpectJoint3,
        State.ExpectJoint3 => State.ExpectBranchOrPartOrJoint,
        State.ExpectBranchOrPartOrJoint => id == TokenBranch
            ? State.ExpectJoint
            : id < NumDiscrete
                ? State.ExpectJoint2
                : State.ExpectJoint,
        State.ExpectJoint => State.ExpectJoint2,
        _ => throw new FormatException( $"UniRig tokenizer: bad state {state}" ),
    };

    /// <summary>
    /// Allowed next tokens given the sequence so far — the constrained-decode mask
    /// (next_posible_token). Greedy argmax over logits restricted to this set
    /// reproduces UniRig's LogitsProcessor deterministically.
    /// </summary>
    /// <param name="ids">Tokens emitted so far; when non-empty, the first must be BOS.</param>
    /// <returns>The token ids permitted at the next position.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
    /// <exception cref="FormatException">The sequence does not start with BOS.</exception>
    public static List<int> NextPossibleTokens( IReadOnlyList<int> ids )
    {
        ArgumentNullException.ThrowIfNull( ids );
        if ( ids.Count == 0 )
            return new List<int> { TokenBos };

        var state = State.ExpectBos;
        foreach ( var id in ids )
        {
            if ( state == State.ExpectBos && id != TokenBos )
                throw new FormatException( "UniRig tokenizer: sequence does not start with BOS." );
            state = Advance( state, id );
        }

        var allowed = new List<int>();
        void AddCls()
        {
            allowed.Add( TokenClsNone );
            allowed.Add( TokenClsVroid );
            allowed.Add( TokenClsMixamo );
            allowed.Add( TokenClsArticulationXl );
        }
        void AddPart()
        {
            allowed.Add( TokenSpring );
            allowed.Add( TokenPartBody );
            allowed.Add( TokenPartHand );
        }
        void AddJoint()
        {
            for ( var i = 0; i < NumDiscrete; i++ )
                allowed.Add( i );
        }

        switch ( state )
        {
            case State.ExpectClsOrPartOrJoint:
                AddCls(); AddPart(); AddJoint();
                break;
            case State.ExpectPartOrJoint:
                AddPart(); AddJoint(); allowed.Add( TokenEos );
                break;
            case State.ExpectJoint or State.ExpectJoint2 or State.ExpectJoint3:
                AddJoint();
                break;
            case State.ExpectBranchOrPartOrJoint:
                AddJoint(); AddPart(); allowed.Add( TokenBranch ); allowed.Add( TokenEos );
                break;
            default:
                throw new FormatException( $"UniRig tokenizer: bad decode state {state}" );
        }
        return allowed;
    }

    /// <summary>Completed bones in a (possibly partial) sequence — the reference's
    /// bones_in_sequence: a bone completes on its 3rd coordinate token, but a
    /// branch's FIRST triple (the re-stated parent position) does NOT count, so
    /// the result equals the number of joints.</summary>
    /// <param name="ids">Tokens so far; the first must be BOS. Counting stops at EOS.</param>
    /// <returns>The number of completed bones.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
    /// <exception cref="FormatException">The sequence does not start with BOS.</exception>
    public static int BonesInSequence( IReadOnlyList<int> ids )
    {
        ArgumentNullException.ThrowIfNull( ids );
        var bones = 0;
        var isBranch = false;
        var state = State.ExpectBos;
        foreach ( var id in ids )
        {
            if ( state == State.ExpectBos && id != TokenBos )
                throw new FormatException( "UniRig tokenizer: sequence does not start with BOS." );
            if ( state == State.ExpectBranchOrPartOrJoint && id == TokenBranch )
                isBranch = true;
            var previous = state;
            state = Advance( state, id );
            if ( previous == State.ExpectJoint3 )
            {
                // The first triple after BRANCH restates the parent; it is not a new bone.
                if ( !isBranch )
                    bones++;
                isBranch = false;
            }
            if ( id == TokenEos )
                break;
        }
        return bones;
    }

    /// <summary>detokenize + make_skeleton: token sequence → joints/parents/tails.</summary>
    /// <param name="ids">BOS, body tokens, EOS, optionally followed by PAD tokens.</param>
    /// <returns>The decoded skeleton; Joints, Parents and Tails are parallel arrays.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
    /// <exception cref="FormatException">Any of the following:
    /// <list type="bullet">
    /// <item>BOS or EOS is missing.</item>
    /// <item>An unexpected token appears in the body.</item>
    /// <item>A coordinate triple is truncated by EOS.</item>
    /// <item>A coordinate triple contains a non-coordinate token.</item>
    /// </list></exception>
    public static UniRigSkeleton Detokenize( IReadOnlyList<int> ids )
    {
        ArgumentNullException.ThrowIfNull( ids );
        if ( ids.Count < 2 || ids[0] != TokenBos )
            throw new FormatException( "UniRig tokenizer: sequence must start with BOS." );
        var end = ids.Count;
        while ( end > 0 && ids[end - 1] == TokenPad )
            end--;
        if ( ids[end - 1] != TokenEos )
            throw new FormatException( "UniRig tokenizer: sequence must end with EOS." );

        // Body tokens occupy [1, bodyEnd); ids[bodyEnd] is the EOS.
        var bodyEnd = end - 1;

        var joints = new List<Vector3>();
        var parentPositions = new List<Vector3>();
        var tailsByBone = new Dictionary<int, Vector3>();
        var isBranch = false;
        Vector3? lastJoint = null;

        // Reads one coordinate triple starting at `at`. The grammar permits e.g.
        // `BRANCH x y z EOS`, so a triple can run into EOS/PAD/PART tokens. Reject
        // that explicitly instead of decoding special-token ids as coordinates
        // (or indexing past the end of the list).
        Vector3 Coords( int at )
        {
            if ( at + 3 > bodyEnd )
                throw new FormatException( $"UniRig tokenizer: truncated coordinate triple at token {at}." );
            for ( var c = at; c < at + 3; c++ )
            {
                if ( ids[c] is < 0 or >= NumDiscrete )
                    throw new FormatException( $"UniRig tokenizer: expected coordinate token at {c}, got {ids[c]}." );
            }
            return new Vector3( Undiscretize( ids[at] ), Undiscretize( ids[at + 1] ), Undiscretize( ids[at + 2] ) );
        }

        var i = 1;
        while ( i < bodyEnd )
        {
            var id = ids[i];
            if ( id < NumDiscrete )
            {
                Vector3 current;
                if ( isBranch )
                {
                    // Branch: restated parent position, then the new joint.
                    parentPositions.Add( Coords( i ) );
                    current = Coords( i + 3 );
                    joints.Add( current );
                    i += 6;
                }
                else
                {
                    current = Coords( i );
                    joints.Add( current );
                    parentPositions.Add( lastJoint ?? current );   // root parents itself
                    i += 3;
                }
                // Continuing a chain: the previous joint's tail is this joint.
                if ( lastJoint is not null )
                    tailsByBone[joints.Count - 2] = current;
                lastJoint = current;
                isBranch = false;
            }
            else if ( id == TokenBranch )
            {
                isBranch = true;
                lastJoint = null;
                i++;
            }
            else if ( id == TokenSpring || IsPart( id ) || IsCls( id ) || id == TokenClsNone )
            {
                i++;   // parts/class recorded upstream; geometry unaffected
            }
            else
            {
                throw new FormatException( $"UniRig tokenizer: unexpected token {id}." );
            }
        }

        // make_skeleton: parent of joint i = earlier bone whose joint is nearest to
        // the recorded parent position (scan i-1 → 0, strictly-smaller wins, so the
        // LATER index wins ties — reversed scan order, matching the reference).
        var count = joints.Count;
        var parents = new int[count];
        for ( var j = 0; j < count; j++ )
        {
            if ( j == 0 )
            {
                parents[0] = -1;
                continue;
            }
            var best = -1;
            var bestDistance = float.MaxValue;
            for ( var k = j - 1; k >= 0; k-- )
            {
                var d = Vector3.DistanceSquared( joints[k], parentPositions[j] );
                if ( d < bestDistance )
                {
                    bestDistance = d;
                    best = k;
                }
            }
            parents[j] = best;
        }

        // Tails: recorded child position, else extruded for leaves/branches
        // (extrude_scale 0.5 along the bone direction; z-up fallback).
        var childCount = new int[count];
        foreach ( var p in parents )
            if ( p >= 0 )
                childCount[p]++;
        var tails = new Vector3[count];
        for ( var j = 0; j < count; j++ )
        {
            if ( childCount[j] == 0 )   // leaf
            {
                var direction = parents[j] >= 0 ? joints[j] - joints[parents[j]] : default;
                if ( direction.LengthSquared() <= 1e-18f )
                    direction = new Vector3( 0f, 0f, 1f );
                tails[j] = joints[j] + direction * 0.5f;
            }
            else if ( childCount[j] > 1 )   // branch
            {
                Vector3 direction;
                if ( parents[j] < 0 )
                {
                    // Root branch: extrude straight up by half the mean child distance.
                    float averageLength = 0;
                    for ( var c = 0; c < count; c++ )
                        if ( parents[c] == j )
                            averageLength += Vector3.Distance( joints[j], joints[c] );
                    averageLength /= childCount[j];
                    tails[j] = joints[j] + new Vector3( 0f, 0f, 0.5f * averageLength );
                    continue;
                }
                direction = joints[j] - joints[parents[j]];
                if ( direction.LengthSquared() <= 1e-18f )
                    direction = new Vector3( 0f, 0f, 1f );
                tails[j] = joints[j] + direction * 0.5f;
            }
            else
            {
                // Single child: the next chain joint when one was recorded.
                // NOTE(review): a lone child attached only via BRANCH leaves no
                // recorded tail, giving a zero-length bone here. Confirm this
                // matches the reference make_skeleton.
                tails[j] = tailsByBone.TryGetValue( j, out var tail ) ? tail : joints[j];
            }
        }

        return new UniRigSkeleton
        {
            Joints = joints.ToArray(),
            Parents = parents,
            Tails = tails,
        };
    }
}