// Neural model implementation and loader for the Skin VAE used by SkinTokens. Defines SkinVaeBlock (transformer-style cross/self attention block and feedforward) and SkinVae (encoding, FSQ dequantization, per-bone decoding and embedding utilities). SkinVaeLoader constructs a SkinVae from a dictionary of named Tensor weights.
using AutoRig.Dl.Nn;
namespace AutoRig.Dl.UniRig;
using AutoRig.Dl;
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// One transformer block of SkinTokens' FSQ skin-VAE (vae.model.* inside the
/// grpo checkpoint): the network that turns the decoded skin TOKENS into real
/// per-point weights. Tripo2/DiT blocks (pre-LN, heads 8, separate q/k/v
/// without bias, out proj with bias, LayerNorm'd cross K/V, GELU-proj FFN ×4).
/// Cond path: one cross block (fps queries ← whole cloud) + self blocks + LN +
/// cond_quant. Decode: FSQ codes (base-8 digits, (d−4)/4, project_out) + cond
/// through post_quant, self blocks once per bone, then point queries
/// cross-attend → sigmoid.
/// </summary>
/// <remarks>
/// A block is either a self-attention or a cross-attention block
/// (<see cref="IsCross"/>). In both cases the attention output and the
/// feed-forward output are each added back onto their input.
/// </remarks>
public sealed class SkinVaeBlock
{
    public required Tensor NormGamma, NormBeta; // norm1 (self) / norm2 (cross)
    public Tensor CrossNormGamma, CrossNormBeta; // attn2.norm_cross (cross only)
    public required Tensor ToQ, ToK, ToV; // no bias
    public required Tensor OutWeight, OutBias; // to_out.0
    public required Tensor Norm3Gamma, Norm3Beta;
    public required Tensor FfProjWeight, FfProjBias; // 768 → 3072, GELU
    public required Tensor FfOutWeight, FfOutBias; // 3072 → 768

    /// <summary>True when keys/values come from encoder states instead of the block input.</summary>
    public required bool IsCross;

    /// <summary>Number of attention heads handed to <c>TransformerOps.Attention</c>.</summary>
    public const int Heads = 8;

    /// <summary>
    /// Runs attention (self or cross) and the feed-forward stage, each with a
    /// pre-LayerNorm and a residual add.
    /// </summary>
    /// <param name="x">Block input; also the source of the attention queries.</param>
    /// <param name="encoderStates">
    /// Key/value source for a cross block. Required when <see cref="IsCross"/>
    /// is true; ignored by self blocks.
    /// </param>
    /// <returns>The block output (input plus attention and feed-forward residuals).</returns>
    /// <exception cref="ArgumentNullException">
    /// The block is a cross block and <paramref name="encoderStates"/> is null.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// The block is a cross block but its norm_cross weights were never set.
    /// </exception>
    public Tensor Forward( Tensor x, Tensor encoderStates = null )
    {
        if ( IsCross )
        {
            // Fail with a clear message instead of a null dereference deep
            // inside the layer-norm / matmul helpers.
            ArgumentNullException.ThrowIfNull( encoderStates );
            if ( CrossNormGamma is null || CrossNormBeta is null )
                throw new InvalidOperationException(
                    "Cross-attention SkinVaeBlock is missing its norm_cross weights." );
        }

        var normed = TransformerOps.LayerNorm( x, NormGamma, NormBeta );

        // Self blocks take K/V from their own normalised input; cross blocks
        // take them from the separately normalised encoder states.
        var keyValueSource = IsCross
            ? TransformerOps.LayerNorm( encoderStates, CrossNormGamma, CrossNormBeta )
            : normed;

        var attended = TransformerOps.Attention(
            normed.MatMul( ToQ.Transposed ),
            keyValueSource.MatMul( ToK.Transposed ),
            keyValueSource.MatMul( ToV.Transposed ),
            Heads, causal: false );

        // Attention residual.
        x = x.Add( attended.MatMul( OutWeight.Transposed ).Add( OutBias ) );

        // Feed-forward residual: LN → proj → GELU → out.
        var hidden = TransformerOps.LayerNorm( x, Norm3Gamma, Norm3Beta );
        hidden = TransformerOps.Gelu( hidden.MatMul( FfProjWeight.Transposed ).Add( FfProjBias ) );
        return x.Add( hidden.MatMul( FfOutWeight.Transposed ).Add( FfOutBias ) );
    }
}
/// <summary>
/// The skin-VAE itself: embeds a point cloud, encodes conditioning latents
/// from it, dequantizes FSQ token indices, and decodes per-point weights for
/// one bone at a time.
/// </summary>
public sealed class SkinVae
{
    /// <summary>Number of levels per FSQ dimension (the digit base).</summary>
    public const int FsqLevels = 8;

    /// <summary>Number of FSQ dimensions (digits per token index).</summary>
    public const int FsqDims = 5;

    public required Tensor CondProjInWeight, CondProjInBias; // 54 → 768
    public required SkinVaeBlock[] CondBlocks; // [0] cross, rest self
    public required Tensor CondNormOutGamma, CondNormOutBeta;
    public required Tensor CondQuantWeight, CondQuantBias; // 768 → 512
    public required Tensor FsqProjOutWeight, FsqProjOutBias; // 5 → 512
    public required Tensor PostQuantWeight, PostQuantBias; // 512 → 768
    public required SkinVaeBlock[] DecoderSelfBlocks;
    public required SkinVaeBlock DecoderCrossBlock;
    public required Tensor DecoderProjQueryWeight, DecoderProjQueryBias; // 54 → 768
    public required Tensor DecoderNormOutGamma, DecoderNormOutBeta;
    public required Tensor DecoderProjOutWeight, DecoderProjOutBias; // 768 → 1

    /// <summary>[fourier(p) (51) , normal (3)] rows — the same layout the
    /// Michelangelo perceiver uses (logspace 2^0..2^7, include_pi false).
    /// Each row is: raw xyz (3), sin of every axis × frequency (24), cos of
    /// every axis × frequency (24), then the normal (3).</summary>
    /// <param name="points">Point positions, one row each.</param>
    /// <param name="normals">Normals; entry r belongs to point r.</param>
    /// <returns>A tensor with <c>points.Length</c> rows and 54 columns.</returns>
    /// <exception cref="ArgumentNullException">Either array is null.</exception>
    /// <exception cref="ArgumentException">There are fewer normals than points.</exception>
    public static Tensor Embed( Vector3[] points, Vector3[] normals )
    {
        ArgumentNullException.ThrowIfNull( points );
        ArgumentNullException.ThrowIfNull( normals );
        if ( normals.Length < points.Length )
            throw new ArgumentException( "Every point needs a matching normal.", nameof( normals ) );

        const int frequencies = 8;
        const int sinOffset = 3;
        const int cosOffset = sinOffset + 3 * frequencies;
        const int fourierDim = 51;
        const int width = fourierDim + 3;

        var rows = points.Length;
        var input = new float[rows * width];
        var axis = new float[3]; // scratch, reused for every row

        for ( var r = 0; r < rows; r++ )
        {
            var at = r * width;
            axis[0] = points[r].X;
            axis[1] = points[r].Y;
            axis[2] = points[r].Z;

            input[at + 0] = axis[0];
            input[at + 1] = axis[1];
            input[at + 2] = axis[2];

            for ( var d = 0; d < 3; d++ )
            {
                for ( var f = 0; f < frequencies; f++ )
                {
                    var scaled = axis[d] * (1 << f); // frequency 2^f
                    input[at + sinOffset + d * frequencies + f] = MathF.Sin( scaled );
                    input[at + cosOffset + d * frequencies + f] = MathF.Cos( scaled );
                }
            }

            input[at + fourierDim + 0] = normals[r].X;
            input[at + fourierDim + 1] = normals[r].Y;
            input[at + fourierDim + 2] = normals[r].Z;
        }

        return Tensor.From( input, rows, width );
    }

    /// <summary>Cond latents (Q × 512) from the cloud: queries = the picked
    /// rows, K/V = every point (reference fps-samples queries unseeded; we
    /// pick deterministically).</summary>
    /// <param name="points">Point positions of the whole cloud.</param>
    /// <param name="normals">Normals matching <paramref name="points"/>.</param>
    /// <param name="queryIndices">Row indices into the cloud that act as queries.</param>
    /// <returns>One latent row per entry of <paramref name="queryIndices"/>.</returns>
    public Tensor EncodeCond( Vector3[] points, Vector3[] normals, int[] queryIndices )
    {
        var cloud = Embed( points, normals );
        var width = cloud.Shape[1];
        var kv = cloud.MatMul( CondProjInWeight.Transposed ).Add( CondProjInBias );

        // Gather the embedded rows that act as queries, then project them
        // with the same proj_in as the keys/values.
        var queryRows = new float[queryIndices.Length * width];
        for ( var i = 0; i < queryIndices.Length; i++ )
            Array.Copy( cloud.Data, queryIndices[i] * width, queryRows, i * width, width );

        var x = Tensor.From( queryRows, queryIndices.Length, width )
            .MatMul( CondProjInWeight.Transposed ).Add( CondProjInBias );

        // Block 0 cross-attends to the whole cloud; the rest are self blocks.
        x = CondBlocks[0].Forward( x, kv );
        for ( var i = 1; i < CondBlocks.Length; i++ )
            x = CondBlocks[i].Forward( x );

        x = TransformerOps.LayerNorm( x, CondNormOutGamma, CondNormOutBeta );
        return x.MatMul( CondQuantWeight.Transposed ).Add( CondQuantBias );
    }

    /// <summary>FSQ dequantize: token indices (already − skeleton vocab) →
    /// base-8 digits → (d − 4)/4 → project_out. Out-of-range indices (the
    /// unified-EOS slot the reference leaves in the last position) wrap via
    /// the digit decomposition exactly as torch does.</summary>
    /// <remarks>
    /// Digits are taken with floor division and a non-negative modulo, which
    /// is how torch's integer <c>//</c> and <c>%</c> behave. For non-negative
    /// indices this equals plain C# <c>/</c> and <c>%</c>; for negative
    /// indices C#'s truncating operators would produce negative digits.
    /// </remarks>
    /// <param name="indices">Token indices to dequantize.</param>
    /// <returns>One projected code row per index.</returns>
    public Tensor DequantizeCodes( IReadOnlyList<int> indices )
    {
        var codes = new float[indices.Count * FsqDims];
        for ( var i = 0; i < indices.Count; i++ )
        {
            var basis = 1;
            for ( var d = 0; d < FsqDims; d++ )
            {
                var digit = FloorMod( FloorDiv( indices[i], basis ), FsqLevels );
                codes[i * FsqDims + d] = (digit - 4f) / 4f;
                basis *= FsqLevels;
            }
        }

        return Tensor.From( codes, indices.Count, FsqDims )
            .MatMul( FsqProjOutWeight.Transposed ).Add( FsqProjOutBias );
    }

    /// <summary>One bone: latents = post_quant(cat[codes, cond]) through the
    /// self stack once; then per-point sigmoid weights for the query rows.</summary>
    /// <param name="codes">Dequantized code rows for this bone.</param>
    /// <param name="cond">Cond latents, concatenated after the codes.</param>
    /// <param name="queryEmbed">Embedded query points to decode weights for.</param>
    /// <returns>One sigmoid weight per row of the decoder output.</returns>
    public float[] DecodeBone( Tensor codes, Tensor cond, Tensor queryEmbed )
    {
        var latents = codes.Concat( cond, 0 )
            .MatMul( PostQuantWeight.Transposed ).Add( PostQuantBias );
        foreach ( var block in DecoderSelfBlocks )
            latents = block.Forward( latents );

        var queries = queryEmbed.MatMul( DecoderProjQueryWeight.Transposed )
            .Add( DecoderProjQueryBias );
        var decoded = DecoderCrossBlock.Forward( queries, latents );
        decoded = TransformerOps.LayerNorm( decoded, DecoderNormOutGamma, DecoderNormOutBeta );
        var logits = decoded.MatMul( DecoderProjOutWeight.Transposed ).Add( DecoderProjOutBias );

        var weights = new float[logits.Shape[0]];
        for ( var i = 0; i < weights.Length; i++ )
            weights[i] = 1f / (1f + MathF.Exp( -logits.Data[i] )); // sigmoid
        return weights;
    }

    /// <summary>Integer division rounding toward −∞ (divisor must be positive).</summary>
    static int FloorDiv( int value, int divisor )
    {
        var quotient = value / divisor;
        return value % divisor < 0 ? quotient - 1 : quotient;
    }

    /// <summary>Remainder in [0, modulus) (modulus must be positive).</summary>
    static int FloorMod( int value, int modulus )
    {
        var remainder = value % modulus;
        return remainder < 0 ? remainder + modulus : remainder;
    }
}
/// <summary>
/// Builds a <see cref="SkinVae"/> from a dictionary of named checkpoint
/// tensors, accepting keys with or without a leading "state_dict." prefix.
/// </summary>
public static class SkinVaeLoader
{
    /// <summary>
    /// Reads every cond-encoder and decoder block plus the surrounding
    /// projection / norm tensors and assembles the model.
    /// </summary>
    /// <param name="tensors">Checkpoint tensors keyed by their full name.</param>
    /// <returns>The assembled skin-VAE.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
    /// <exception cref="FormatException">
    /// A tensor is missing, fewer than two blocks exist on either path, or
    /// the cross/self layout is not the one the model expects (cond: cross
    /// first, then self; decoder: self, then cross last).
    /// </exception>
    public static SkinVae Load( IReadOnlyDictionary<string, Tensor> tensors )
    {
        ArgumentNullException.ThrowIfNull( tensors );

        var prefix = tensors.Keys.Any( k => k.StartsWith( "state_dict.", StringComparison.Ordinal ) )
            ? "state_dict.vae.model."
            : "vae.model.";

        var condBlocks = ReadBlocks( tensors, $"{prefix}cond_encoder.blocks." );
        var decoderBlocks = ReadBlocks( tensors, $"{prefix}decoder.blocks." );

        if ( condBlocks.Count < 2 || decoderBlocks.Count < 2 )
            throw new FormatException( "SkinTokens checkpoint: skin-VAE blocks missing." );

        // SkinVae runs cond blocks 1.. and every decoder block but the last
        // without encoder states, so any of those being a cross block would
        // only blow up later at inference time. Reject the layout here.
        var decoderSelfBlocks = decoderBlocks.Take( decoderBlocks.Count - 1 ).ToArray();
        var layoutOk = condBlocks[0].IsCross
            && decoderBlocks[^1].IsCross
            && !condBlocks.Skip( 1 ).Any( block => block.IsCross )
            && !decoderSelfBlocks.Any( block => block.IsCross );
        if ( !layoutOk )
            throw new FormatException( "SkinTokens checkpoint: unexpected skin-VAE block layout." );

        return new SkinVae
        {
            CondProjInWeight = Get( tensors, $"{prefix}cond_encoder.proj_in.weight" ),
            CondProjInBias = Get( tensors, $"{prefix}cond_encoder.proj_in.bias" ),
            CondBlocks = condBlocks.ToArray(),
            CondNormOutGamma = Get( tensors, $"{prefix}cond_encoder.norm_out.weight" ),
            CondNormOutBeta = Get( tensors, $"{prefix}cond_encoder.norm_out.bias" ),
            CondQuantWeight = Get( tensors, $"{prefix}cond_quant.weight" ),
            CondQuantBias = Get( tensors, $"{prefix}cond_quant.bias" ),
            FsqProjOutWeight = Get( tensors, $"{prefix}FSQ.project_out.weight" ),
            FsqProjOutBias = Get( tensors, $"{prefix}FSQ.project_out.bias" ),
            PostQuantWeight = Get( tensors, $"{prefix}post_quant.weight" ),
            PostQuantBias = Get( tensors, $"{prefix}post_quant.bias" ),
            DecoderSelfBlocks = decoderSelfBlocks,
            DecoderCrossBlock = decoderBlocks[^1],
            DecoderProjQueryWeight = Get( tensors, $"{prefix}decoder.proj_query.weight" ),
            DecoderProjQueryBias = Get( tensors, $"{prefix}decoder.proj_query.bias" ),
            DecoderNormOutGamma = Get( tensors, $"{prefix}decoder.norm_out.weight" ),
            DecoderNormOutBeta = Get( tensors, $"{prefix}decoder.norm_out.bias" ),
            DecoderProjOutWeight = Get( tensors, $"{prefix}decoder.proj_out.weight" ),
            DecoderProjOutBias = Get( tensors, $"{prefix}decoder.proj_out.bias" ),
        };
    }

    /// <summary>Reads consecutively numbered blocks under <paramref name="stem"/>
    /// ("…blocks.") starting at 0 until the next index has no norm tensor.</summary>
    static List<SkinVaeBlock> ReadBlocks( IReadOnlyDictionary<string, Tensor> tensors, string stem )
    {
        var blocks = new List<SkinVaeBlock>();
        while ( HasBlock( tensors, $"{stem}{blocks.Count}." ) )
            blocks.Add( Block( tensors, $"{stem}{blocks.Count}." ) );
        return blocks;
    }

    /// <summary>A block exists when either its norm1 (self) or norm2 (cross) weight is present.</summary>
    static bool HasBlock( IReadOnlyDictionary<string, Tensor> tensors, string b )
        => tensors.ContainsKey( $"{b}norm1.weight" ) || tensors.ContainsKey( $"{b}norm2.weight" );

    /// <summary>Reads one block under prefix <paramref name="b"/>. The presence of
    /// norm2 marks it as a cross block, which selects the attn2/norm2 tensor
    /// names (attn1/norm1 otherwise) and loads the norm_cross weights.</summary>
    static SkinVaeBlock Block( IReadOnlyDictionary<string, Tensor> tensors, string b )
    {
        var isCross = tensors.ContainsKey( $"{b}norm2.weight" );
        var attn = isCross ? "attn2" : "attn1";
        var norm = isCross ? "norm2" : "norm1";
        return new SkinVaeBlock
        {
            IsCross = isCross,
            NormGamma = Get( tensors, $"{b}{norm}.weight" ),
            NormBeta = Get( tensors, $"{b}{norm}.bias" ),
            CrossNormGamma = isCross ? Get( tensors, $"{b}{attn}.norm_cross.weight" ) : null,
            CrossNormBeta = isCross ? Get( tensors, $"{b}{attn}.norm_cross.bias" ) : null,
            ToQ = Get( tensors, $"{b}{attn}.to_q.weight" ),
            ToK = Get( tensors, $"{b}{attn}.to_k.weight" ),
            ToV = Get( tensors, $"{b}{attn}.to_v.weight" ),
            OutWeight = Get( tensors, $"{b}{attn}.to_out.0.weight" ),
            OutBias = Get( tensors, $"{b}{attn}.to_out.0.bias" ),
            Norm3Gamma = Get( tensors, $"{b}norm3.weight" ),
            Norm3Beta = Get( tensors, $"{b}norm3.bias" ),
            FfProjWeight = Get( tensors, $"{b}ff.net.0.proj.weight" ),
            FfProjBias = Get( tensors, $"{b}ff.net.0.proj.bias" ),
            FfOutWeight = Get( tensors, $"{b}ff.net.2.weight" ),
            FfOutBias = Get( tensors, $"{b}ff.net.2.bias" ),
        };
    }

    /// <summary>Fetches a tensor by exact key or throws a <see cref="FormatException"/> naming it.</summary>
    static Tensor Get( IReadOnlyDictionary<string, Tensor> tensors, string key )
        => tensors.TryGetValue( key, out var tensor )
            ? tensor
            : throw new FormatException( $"SkinTokens checkpoint is missing tensor '{key}'." );
}