AutoRig/Dl/UniRig/SkinVae.cs

Neural model implementation and loader for the skin VAE used by SkinTokens. Defines SkinVaeBlock (a pre-LayerNorm transformer block with either self- or cross-attention, followed by a GELU feed-forward layer) and SkinVae (point embedding, condition encoding, FSQ dequantization, and per-bone decoding). SkinVaeLoader builds a SkinVae from a dictionary of named Tensor weights.

Native Interop
using AutoRig.Dl.Nn;

namespace AutoRig.Dl.UniRig;

using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// One pre-LayerNorm transformer block of the SkinTokens skin VAE: a single
/// attention sub-layer (self- or cross-attention, selected by <see cref="IsCross"/>)
/// followed by a GELU feed-forward sub-layer, each added back residually.
/// </summary>
/// <remarks>
/// SkinTokens' FSQ skin-VAE (vae.model.* inside the grpo checkpoint) is the
/// network that turns the decoded skin TOKENS into real per-point weights.
/// Tripo2/DiT blocks (pre-LN, heads 8, separate q/k/v without bias, out proj
/// with bias, LayerNorm'd cross K/V, GELU-proj FFN ×4). Cond path: one cross
/// block (fps queries ← whole cloud) + self blocks + LN + cond_quant. Decode:
/// FSQ codes (base-8 digits, (d−4)/4, project_out) + cond through post_quant,
/// self blocks once per bone, then point queries cross-attend → sigmoid.
/// </remarks>
public sealed class SkinVaeBlock
{
    public required Tensor NormGamma, NormBeta;         // norm1 (self) / norm2 (cross)
    public Tensor CrossNormGamma, CrossNormBeta;        // attn2.norm_cross (cross only; unused by self blocks)
    public required Tensor ToQ, ToK, ToV;               // no bias
    public required Tensor OutWeight, OutBias;          // to_out.0
    public required Tensor Norm3Gamma, Norm3Beta;
    public required Tensor FfProjWeight, FfProjBias;    // 768 → 3072, GELU
    public required Tensor FfOutWeight, FfOutBias;      // 3072 → 768
    public required bool IsCross;

    /// <summary>Attention heads used by every block.</summary>
    public const int Heads = 8;

    /// <summary>
    /// Runs the block: x + out_proj(attn(LN(x), K/V)), then + FFN(LN(·)).
    /// </summary>
    /// <param name="x">Residual-stream rows that supply the attention queries.</param>
    /// <param name="encoderStates">
    /// Key/value source for a cross block (LayerNorm'd with the cross norm first).
    /// Required when <see cref="IsCross"/> is true; ignored by a self block.
    /// </param>
    /// <returns>The updated residual stream.</returns>
    /// <exception cref="ArgumentNullException">
    /// A cross block was called without <paramref name="encoderStates"/>.
    /// </exception>
    public Tensor Forward( Tensor x, Tensor encoderStates = null )
    {
        var h = TransformerOps.LayerNorm( x, NormGamma, NormBeta );

        // Self-attention draws K/V from the normed input itself; cross-attention
        // draws them from the (separately normed) encoder states.
        Tensor kvSource;
        if ( IsCross )
        {
            ArgumentNullException.ThrowIfNull( encoderStates );
            kvSource = TransformerOps.LayerNorm( encoderStates, CrossNormGamma, CrossNormBeta );
        }
        else
            kvSource = h;

        var attended = TransformerOps.Attention(
            h.MatMul( ToQ.Transposed ),
            kvSource.MatMul( ToK.Transposed ),
            kvSource.MatMul( ToV.Transposed ),
            Heads, causal: false );
        x = x.Add( attended.MatMul( OutWeight.Transposed ).Add( OutBias ) );

        h = TransformerOps.LayerNorm( x, Norm3Gamma, Norm3Beta );
        h = TransformerOps.Gelu( h.MatMul( FfProjWeight.Transposed ).Add( FfProjBias ) );
        return x.Add( h.MatMul( FfOutWeight.Transposed ).Add( FfOutBias ) );
    }
}

/// <summary>
/// SkinTokens' FSQ skin VAE: embeds an oriented point cloud, encodes it into
/// condition latents, dequantizes FSQ skin tokens, and decodes them one bone
/// at a time into per-point sigmoid weights.
/// </summary>
public sealed class SkinVae
{
    /// <summary>FSQ levels per dimension — the base of the token-index digit decomposition.</summary>
    public const int FsqLevels = 8;

    /// <summary>FSQ dimensions (digits) per token — the input width of FSQ.project_out.</summary>
    public const int FsqDims = 5;

    // Condition encoder (point cloud → cond latents).
    public required Tensor CondProjInWeight, CondProjInBias;   // 54 → 768
    public required SkinVaeBlock[] CondBlocks;                 // [0] cross, rest self
    public required Tensor CondNormOutGamma, CondNormOutBeta;
    public required Tensor CondQuantWeight, CondQuantBias;     // 768 → 512

    // FSQ dequantization and post-quant projection.
    public required Tensor FsqProjOutWeight, FsqProjOutBias;   // 5 → 512
    public required Tensor PostQuantWeight, PostQuantBias;     // 512 → 768

    // Decoder (latents + point queries → per-point weight logits).
    public required SkinVaeBlock[] DecoderSelfBlocks;
    public required SkinVaeBlock DecoderCrossBlock;
    public required Tensor DecoderProjQueryWeight, DecoderProjQueryBias;   // 54 → 768
    public required Tensor DecoderNormOutGamma, DecoderNormOutBeta;
    public required Tensor DecoderProjOutWeight, DecoderProjOutBias;       // 768 → 1

    /// <summary>
    /// [fourier(p) (51), normal (3)] rows — the same layout the Michelangelo
    /// perceiver uses (logspace 2^0..2^7, include_pi false). The Fourier part is
    /// [p (3), sin (24), cos (24)], where sin/cos column d·8 + f holds axis d at
    /// frequency 2^f.
    /// </summary>
    /// <param name="points">Point positions, one output row each.</param>
    /// <param name="normals">Per-point normals; needs at least as many entries as <paramref name="points"/>.</param>
    /// <returns>A (points.Length × 54) tensor.</returns>
    /// <exception cref="ArgumentNullException">Either array is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="normals"/> is shorter than <paramref name="points"/>.</exception>
    public static Tensor Embed( Vector3[] points, Vector3[] normals )
    {
        const int frequencies = 8;                          // 2^0 .. 2^7
        const int fourierDim = 3 + 2 * 3 * frequencies;     // raw xyz + sin + cos = 51
        const int width = fourierDim + 3;                   // + normal = 54

        ArgumentNullException.ThrowIfNull( points );
        ArgumentNullException.ThrowIfNull( normals );
        if ( normals.Length < points.Length )
            throw new ArgumentException( "Expected one normal per point.", nameof( normals ) );

        var rows = points.Length;
        var input = new float[rows * width];
        Span<float> p = stackalloc float[3];
        for ( var r = 0; r < rows; r++ )
        {
            var at = r * width;
            p[0] = points[r].X;
            p[1] = points[r].Y;
            p[2] = points[r].Z;
            input[at + 0] = p[0];
            input[at + 1] = p[1];
            input[at + 2] = p[2];
            for ( var d = 0; d < 3; d++ )
                for ( var f = 0; f < frequencies; f++ )
                {
                    var scaled = p[d] * (1 << f);
                    input[at + 3 + d * frequencies + f] = MathF.Sin( scaled );
                    input[at + 3 + 3 * frequencies + d * frequencies + f] = MathF.Cos( scaled );
                }
            input[at + fourierDim + 0] = normals[r].X;
            input[at + fourierDim + 1] = normals[r].Y;
            input[at + fourierDim + 2] = normals[r].Z;
        }
        return Tensor.From( input, rows, width );
    }

    /// <summary>
    /// Cond latents (Q × 512) from the cloud: queries are the embedded rows at
    /// <paramref name="queryIndices"/>, K/V are every point. The reference
    /// fps-samples its queries unseeded; here the caller picks them deterministically.
    /// </summary>
    /// <param name="points">Point positions.</param>
    /// <param name="normals">Per-point normals (see <see cref="Embed"/>).</param>
    /// <param name="queryIndices">Row indices into <paramref name="points"/> used as queries.</param>
    /// <returns>One cond_quant row per query index.</returns>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    public Tensor EncodeCond( Vector3[] points, Vector3[] normals, int[] queryIndices )
    {
        ArgumentNullException.ThrowIfNull( queryIndices );

        var kvInput = Embed( points, normals );
        var width = kvInput.Shape[1];
        var kv = kvInput.MatMul( CondProjInWeight.Transposed ).Add( CondProjInBias );

        // Gather the query rows from the raw embedding; they go through the same proj_in.
        var queryRows = new float[queryIndices.Length * width];
        for ( var i = 0; i < queryIndices.Length; i++ )
            Array.Copy( kvInput.Data, queryIndices[i] * width, queryRows, i * width, width );
        var x = Tensor.From( queryRows, queryIndices.Length, width )
            .MatMul( CondProjInWeight.Transposed ).Add( CondProjInBias );

        // Only the first block sees the whole cloud; the rest are self-attention.
        x = CondBlocks[0].Forward( x, kv );
        for ( var i = 1; i < CondBlocks.Length; i++ )
            x = CondBlocks[i].Forward( x );
        x = TransformerOps.LayerNorm( x, CondNormOutGamma, CondNormOutBeta );
        return x.MatMul( CondQuantWeight.Transposed ).Add( CondQuantBias );
    }

    /// <summary>
    /// FSQ dequantize: token indices (already − skeleton vocab) → base-8 digits →
    /// (d − 4)/4 → project_out. Out-of-range indices (the unified-EOS slot the
    /// reference leaves in the last position) wrap via the digit decomposition
    /// exactly as torch does, using floor division and a non-negative remainder,
    /// so negative indices also yield digits in [0, FsqLevels).
    /// </summary>
    /// <param name="indices">FSQ token indices, one output row each.</param>
    /// <returns>A (indices.Count × 512) tensor of projected codes.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="indices"/> is null.</exception>
    public Tensor DequantizeCodes( IReadOnlyList<int> indices )
    {
        ArgumentNullException.ThrowIfNull( indices );

        const float halfWidth = FsqLevels / 2;   // digits 0..7 → (d − 4)/4 ∈ [−1, 0.75]
        var codes = new float[indices.Count * FsqDims];
        for ( var i = 0; i < indices.Count; i++ )
        {
            var index = indices[i];
            var basis = 1;
            for ( var d = 0; d < FsqDims; d++ )
            {
                // torch's // and % floor; C#'s / and % truncate toward zero and
                // would produce negative digits for negative indices.
                var digit = FloorMod( FloorDiv( index, basis ), FsqLevels );
                codes[i * FsqDims + d] = (digit - halfWidth) / halfWidth;
                basis *= FsqLevels;
            }
        }
        return Tensor.From( codes, indices.Count, FsqDims )
            .MatMul( FsqProjOutWeight.Transposed ).Add( FsqProjOutBias );
    }

    /// <summary>
    /// One bone: latents = post_quant(cat[codes, cond]) (concatenated along axis 0)
    /// through the self stack once; then the projected point queries cross-attend
    /// to the latents and the per-row logits go through a sigmoid.
    /// </summary>
    /// <param name="codes">Dequantized codes for this bone — presumably <see cref="DequantizeCodes"/> output; confirm at call site.</param>
    /// <param name="cond">Condition latents — presumably <see cref="EncodeCond"/> output.</param>
    /// <param name="queryEmbed">Embedded query points — presumably <see cref="Embed"/> rows.</param>
    /// <returns>One weight in [0, 1] per logit row.</returns>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    public float[] DecodeBone( Tensor codes, Tensor cond, Tensor queryEmbed )
    {
        ArgumentNullException.ThrowIfNull( codes );
        ArgumentNullException.ThrowIfNull( cond );
        ArgumentNullException.ThrowIfNull( queryEmbed );

        var latents = codes.Concat( cond, 0 )
            .MatMul( PostQuantWeight.Transposed ).Add( PostQuantBias );
        foreach ( var block in DecoderSelfBlocks )
            latents = block.Forward( latents );

        var q = queryEmbed.MatMul( DecoderProjQueryWeight.Transposed )
            .Add( DecoderProjQueryBias );
        var output = DecoderCrossBlock.Forward( q, latents );
        output = TransformerOps.LayerNorm( output, DecoderNormOutGamma, DecoderNormOutBeta );
        var logits = output.MatMul( DecoderProjOutWeight.Transposed ).Add( DecoderProjOutBias );

        var weights = new float[logits.Shape[0]];
        for ( var i = 0; i < weights.Length; i++ )
            weights[i] = 1f / (1f + MathF.Exp( -logits.Data[i] ));
        return weights;
    }

    /// <summary>Integer division rounding toward −∞ (torch's <c>//</c>).</summary>
    static int FloorDiv( int a, int b )
    {
        var q = a / b;
        return a % b != 0 && (a < 0) != (b < 0) ? q - 1 : q;
    }

    /// <summary>Remainder carrying the divisor's sign (torch's <c>%</c> / <c>torch.remainder</c>).</summary>
    static int FloorMod( int a, int b )
    {
        var r = a % b;
        return r != 0 && (r < 0) != (b < 0) ? r + b : r;
    }
}

/// <summary>
/// Builds a <see cref="SkinVae"/> from a flat map of checkpoint parameter names
/// to tensors (the <c>vae.model.*</c> weights, optionally nested under
/// <c>state_dict.</c>).
/// </summary>
public static class SkinVaeLoader
{
    /// <summary>Loads every skin-VAE weight and validates the block layout.</summary>
    /// <param name="tensors">Checkpoint tensors keyed by full parameter name.</param>
    /// <returns>The assembled model.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">
    /// A required tensor is missing, there are fewer than two encoder or decoder
    /// blocks, or the cross/self block layout is not the one <see cref="SkinVae"/> runs.
    /// </exception>
    public static SkinVae Load( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        // If any key is nested under "state_dict.", the VAE weights are assumed to be too.
        var prefix = tensors.Keys.Any( k => k.StartsWith( "state_dict.", StringComparison.Ordinal ) )
            ? "state_dict.vae.model."
            : "vae.model.";
        Tensor Param( string name ) => Get( tensors, prefix + name );

        var condBlocks = Blocks( tensors, $"{prefix}cond_encoder.blocks." );
        var decoderBlocks = Blocks( tensors, $"{prefix}decoder.blocks." );

        if ( condBlocks.Count < 2 || decoderBlocks.Count < 2 )
            throw new FormatException( "SkinTokens checkpoint: skin-VAE blocks missing." );

        // SkinVae hands encoder states only to the first cond block and the last
        // decoder block; every other block is called without them, so it must be
        // self-attention — reject anything else here rather than at inference.
        var condLayoutOk = condBlocks[0].IsCross && condBlocks.Skip( 1 ).All( b => !b.IsCross );
        var decoderLayoutOk = decoderBlocks[^1].IsCross && decoderBlocks.SkipLast( 1 ).All( b => !b.IsCross );
        if ( !condLayoutOk || !decoderLayoutOk )
            throw new FormatException( "SkinTokens checkpoint: unexpected skin-VAE block layout." );

        return new SkinVae
        {
            CondProjInWeight = Param( "cond_encoder.proj_in.weight" ),
            CondProjInBias = Param( "cond_encoder.proj_in.bias" ),
            CondBlocks = condBlocks.ToArray(),
            CondNormOutGamma = Param( "cond_encoder.norm_out.weight" ),
            CondNormOutBeta = Param( "cond_encoder.norm_out.bias" ),
            CondQuantWeight = Param( "cond_quant.weight" ),
            CondQuantBias = Param( "cond_quant.bias" ),
            FsqProjOutWeight = Param( "FSQ.project_out.weight" ),
            FsqProjOutBias = Param( "FSQ.project_out.bias" ),
            PostQuantWeight = Param( "post_quant.weight" ),
            PostQuantBias = Param( "post_quant.bias" ),
            DecoderSelfBlocks = decoderBlocks.Take( decoderBlocks.Count - 1 ).ToArray(),
            DecoderCrossBlock = decoderBlocks[^1],
            DecoderProjQueryWeight = Param( "decoder.proj_query.weight" ),
            DecoderProjQueryBias = Param( "decoder.proj_query.bias" ),
            DecoderNormOutGamma = Param( "decoder.norm_out.weight" ),
            DecoderNormOutBeta = Param( "decoder.norm_out.bias" ),
            DecoderProjOutWeight = Param( "decoder.proj_out.weight" ),
            DecoderProjOutBias = Param( "decoder.proj_out.bias" ),
        };
    }

    /// <summary>Collects blocks "{stem}0.", "{stem}1.", … until the first index with no block.</summary>
    static List<SkinVaeBlock> Blocks( IReadOnlyDictionary<string, Tensor> tensors, string stem )
    {
        var blocks = new List<SkinVaeBlock>();
        while ( HasBlock( tensors, $"{stem}{blocks.Count}." ) )
            blocks.Add( Block( tensors, $"{stem}{blocks.Count}." ) );
        return blocks;
    }

    /// <summary>True when a block rooted at <paramref name="b"/> has a self (norm1) or cross (norm2) norm.</summary>
    static bool HasBlock( IReadOnlyDictionary<string, Tensor> tensors, string b )
        => tensors.ContainsKey( $"{b}norm1.weight" ) || tensors.ContainsKey( $"{b}norm2.weight" );

    /// <summary>
    /// Loads one block rooted at <paramref name="b"/>. A block with norm2 is read as
    /// cross-only (norm2 + attn2, incl. attn2.norm_cross); otherwise it is self-only
    /// (norm1 + attn1). NOTE(review): a block carrying both norm1 and norm2 would have
    /// its attn1 ignored — assumed not to occur in SkinTokens checkpoints.
    /// </summary>
    static SkinVaeBlock Block( IReadOnlyDictionary<string, Tensor> tensors, string b )
    {
        var isCross = tensors.ContainsKey( $"{b}norm2.weight" );
        var attn = isCross ? "attn2" : "attn1";
        var norm = isCross ? "norm2" : "norm1";
        return new SkinVaeBlock
        {
            IsCross = isCross,
            NormGamma = Get( tensors, $"{b}{norm}.weight" ),
            NormBeta = Get( tensors, $"{b}{norm}.bias" ),
            CrossNormGamma = isCross ? Get( tensors, $"{b}{attn}.norm_cross.weight" ) : null,
            CrossNormBeta = isCross ? Get( tensors, $"{b}{attn}.norm_cross.bias" ) : null,
            ToQ = Get( tensors, $"{b}{attn}.to_q.weight" ),
            ToK = Get( tensors, $"{b}{attn}.to_k.weight" ),
            ToV = Get( tensors, $"{b}{attn}.to_v.weight" ),
            OutWeight = Get( tensors, $"{b}{attn}.to_out.0.weight" ),
            OutBias = Get( tensors, $"{b}{attn}.to_out.0.bias" ),
            Norm3Gamma = Get( tensors, $"{b}norm3.weight" ),
            Norm3Beta = Get( tensors, $"{b}norm3.bias" ),
            FfProjWeight = Get( tensors, $"{b}ff.net.0.proj.weight" ),
            FfProjBias = Get( tensors, $"{b}ff.net.0.proj.bias" ),
            FfOutWeight = Get( tensors, $"{b}ff.net.2.weight" ),
            FfOutBias = Get( tensors, $"{b}ff.net.2.bias" ),
        };
    }

    /// <summary>Looks up <paramref name="key"/>, throwing <see cref="FormatException"/> when absent.</summary>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"SkinTokens checkpoint is missing tensor '{key}'." );
}