// Managed C# implementation of the SAME GAT autoencoder: encoder and decoder stacks of explicit GATConv layers written as float32 loops, with per-node forward passes, per-frame max pooling, and loading of layer tensors from a SameWeights container.
using System;
namespace HumanoidRetargeter.Dl;
/// <summary>
/// Pure-managed forward pass of the SAME GAT autoencoder (Lee et al., SIGGRAPH Asia 2023,
/// ckpt0) — the exact graph the spike's explicit-math ONNX export implements
/// (<c>dev/m10/scripts/export_onnx.py::gat_conv_explicit</c>, verified bit-equal to the
/// original PyG layers), re-expressed as plain float32 loops. No ONNX Runtime, no native
/// code; deterministic (fixed accumulation order).
/// </summary>
/// <remarks>
/// <para>Architecture (ckpt0 config): encoder = 4 × GATConv (32 → 16ch×16heads ×3 → 32,
/// ReLU between, last layer 1 head), then per-frame global <b>max</b>-pool over nodes →
/// z[32] per frame. Decoder = 4 × GATConv (38 = z32 ⊕ skel6 input, the 6 skeleton features
/// re-concatenated at every layer, → 11 outputs), ReLU between.</para>
/// <para>Edge convention (both graphs): bidirectional parent↔child pairs plus one self-loop
/// per node — equivalent to the original layer's internal remove/add_self_loops on the
/// unmasked inference path. One graph per frame; frames are batched as disjoint unions with
/// node indices offset per frame.</para>
/// <para>GAT layer math per <c>gat_conv_explicit</c>:
/// <c>h = W·x</c> per node, attention logit per edge
/// <c>e = leaky_relu(att_src·h[src] + att_dst·h[dst], 0.2)</c>, softmax over the incoming
/// edges of each destination node (per head), output = attention-weighted sum of source
/// <c>h</c> over incoming edges, heads concatenated, plus bias.</para>
/// </remarks>
public sealed class SameModel
{
    /// <summary>Encoder input width (6 skeleton + 26 pose features).</summary>
    public const int InputDim = 32;

    /// <summary>Per-frame embedding width.</summary>
    public const int ZDim = 32;

    /// <summary>Target skeleton feature width (lo3 ⊕ go3).</summary>
    public const int SkelDim = 6;

    /// <summary>Decoder output width (r4 ⊕ q6 ⊕ c1, normalized).</summary>
    public const int OutDim = 11;

    /// <summary>Slope applied to negative attention logits (leaky ReLU).</summary>
    private const float NegativeSlope = 0.2f;

    private readonly GatLayer[] _encoder;
    private readonly GatLayer[] _decoder;

    /// <summary>Builds the model from a parsed weight blob.</summary>
    /// <exception cref="ArgumentException">Thrown when the blob lacks expected tensors or
    /// their shapes do not match the ckpt0 architecture.</exception>
    public SameModel(SameWeights weights)
    {
        ArgumentNullException.ThrowIfNull(weights);
        _encoder = LoadStack(weights, "encoder.0.convs", expectedLayers: 4, inputDim: InputDim);
        _decoder = LoadStack(weights, "decoder.0.deconvs", expectedLayers: 4, inputDim: ZDim + SkelDim);
        if (_encoder[^1].OutDim != ZDim)
            throw new ArgumentException($"Encoder output width {_encoder[^1].OutDim}, expected {ZDim}.");
        if (_decoder[^1].OutDim != OutDim)
            throw new ArgumentException($"Decoder output width {_decoder[^1].OutDim}, expected {OutDim}.");
    }

    /// <summary>
    /// Encoder: per-node features → per-frame embeddings.
    /// </summary>
    /// <param name="x">Normalized node features, flat [nodeCount × <see cref="InputDim"/>].</param>
    /// <param name="edgeSrc">Edge source node per edge (bidirectional + self-loops).</param>
    /// <param name="edgeDst">Edge destination node per edge.</param>
    /// <param name="batch">Frame id per node (0-based, contiguous).</param>
    /// <param name="frameCount">Number of frames.</param>
    /// <returns>z, flat [frameCount × <see cref="ZDim"/>].</returns>
    /// <exception cref="ArgumentException">Thrown when array lengths are inconsistent, an edge
    /// references a node outside the graph, or a frame id lies outside
    /// [0, <paramref name="frameCount"/>).</exception>
    public float[] Encode(float[] x, int[] edgeSrc, int[] edgeDst, int[] batch, int frameCount)
    {
        ArgumentNullException.ThrowIfNull(x);
        var n = CheckNodes(x, InputDim, edgeSrc, edgeDst);
        // Validate the node → frame mapping before the expensive forward pass: a batch array
        // shorter than the node count would otherwise pool only a prefix of the nodes.
        CheckBatch(batch, n, frameCount);
        var nodeZ = RunEncoder(x, n, edgeSrc, edgeDst);
        return MaxPool(nodeZ, ZDim, batch, frameCount);
    }

    /// <summary>Encoder without the final per-frame pool: per-node embeddings,
    /// flat [nodeCount × <see cref="ZDim"/>] (exposed for parity tests).</summary>
    /// <exception cref="ArgumentException">Thrown when array lengths are inconsistent or an
    /// edge references a node outside the graph.</exception>
    public float[] EncodeNodes(float[] x, int[] edgeSrc, int[] edgeDst)
    {
        ArgumentNullException.ThrowIfNull(x);
        var n = CheckNodes(x, InputDim, edgeSrc, edgeDst);
        return RunEncoder(x, n, edgeSrc, edgeDst);
    }

    /// <summary>
    /// Decoder: per-frame embeddings conditioned on the target skeleton graph → per-node
    /// outputs (normalized r4 ⊕ q6 ⊕ c1).
    /// </summary>
    /// <param name="z">Per-frame embeddings, flat [frameCount × <see cref="ZDim"/>].</param>
    /// <param name="targetX">Normalized target skeleton features tiled per frame,
    /// flat [nodeCount × <see cref="SkelDim"/>].</param>
    /// <param name="edgeSrc">Edge source node per edge (bidirectional + self-loops, all frames).</param>
    /// <param name="edgeDst">Edge destination node per edge.</param>
    /// <param name="batch">Frame id per node.</param>
    /// <returns>hatD, flat [nodeCount × <see cref="OutDim"/>].</returns>
    /// <exception cref="ArgumentException">Thrown when array lengths are inconsistent, an edge
    /// references a node outside the graph, or a frame id has no embedding in
    /// <paramref name="z"/>.</exception>
    public float[] Decode(float[] z, float[] targetX, int[] edgeSrc, int[] edgeDst, int[] batch)
    {
        ArgumentNullException.ThrowIfNull(z);
        ArgumentNullException.ThrowIfNull(targetX);
        ArgumentNullException.ThrowIfNull(batch);
        var n = CheckNodes(targetX, SkelDim, edgeSrc, edgeDst);
        if (z.Length % ZDim != 0)
            throw new ArgumentException($"z length {z.Length} is not a multiple of {ZDim}.");
        CheckBatch(batch, n, z.Length / ZDim);

        // dec_x starts as z gathered per node; skel features are concatenated before EVERY
        // layer (tgt_all_lyr = true in ckpt0).
        var h = new float[n * ZDim];
        for (var i = 0; i < n; i++)
            Array.Copy(z, batch[i] * ZDim, h, i * ZDim, ZDim);

        var width = ZDim;
        for (var i = 0; i < _decoder.Length; i++)
        {
            var concat = ConcatColumns(h, width, targetX, SkelDim, n);
            h = _decoder[i].Forward(concat, n, edgeSrc, edgeDst);
            width = _decoder[i].OutDim;
            // No activation after the output layer.
            if (i + 1 != _decoder.Length)
                Relu(h);
        }
        return h;
    }

    /// <summary>Per-frame max-pool over nodes (the encoder's host-side pool).</summary>
    /// <param name="nodeValues">Per-node values, flat [batch.Length × <paramref name="width"/>].</param>
    /// <param name="width">Number of values per node.</param>
    /// <param name="batch">Frame id per node, each in [0, <paramref name="frameCount"/>).</param>
    /// <param name="frameCount">Number of frames.</param>
    /// <returns>Flat [frameCount × width]. A frame that owns no node keeps
    /// <see cref="float.NegativeInfinity"/> in every column.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="nodeValues"/> does not
    /// hold exactly <paramref name="width"/> values per node, or a frame id is out of range.</exception>
    public static float[] MaxPool(float[] nodeValues, int width, int[] batch, int frameCount)
    {
        ArgumentNullException.ThrowIfNull(nodeValues);
        ArgumentNullException.ThrowIfNull(batch);
        if (width < 0)
            throw new ArgumentOutOfRangeException(nameof(width), width, "Width must not be negative.");
        CheckBatch(batch, batch.Length, frameCount);
        // Compare in 64-bit so a huge node count cannot wrap around and pass the check.
        if ((long)batch.Length * width != nodeValues.Length)
            throw new ArgumentException(
                $"nodeValues length {nodeValues.Length} != {batch.Length} nodes × width {width}.");

        var pooled = new float[frameCount * width];
        Array.Fill(pooled, float.NegativeInfinity);
        for (var i = 0; i < batch.Length; i++)
        {
            var src = i * width;
            var dst = batch[i] * width;
            for (var c = 0; c < width; c++)
            {
                var v = nodeValues[src + c];
                if (v > pooled[dst + c])
                    pooled[dst + c] = v;
            }
        }
        return pooled;
    }

    /// <summary>Runs the encoder stack on already-validated inputs (ReLU between layers,
    /// none after the last).</summary>
    private float[] RunEncoder(float[] x, int n, int[] edgeSrc, int[] edgeDst)
    {
        var h = x;
        for (var i = 0; i < _encoder.Length; i++)
        {
            // Forward allocates a fresh array, so the in-place ReLU never touches caller data.
            h = _encoder[i].Forward(h, n, edgeSrc, edgeDst);
            if (i + 1 != _encoder.Length)
                Relu(h);
        }
        return h;
    }

    /// <summary>
    /// Validates a flat feature array and its edge list, returning the node count.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when the feature length is not a multiple of
    /// <paramref name="width"/>, the edge arrays differ in length, or an edge endpoint is
    /// outside [0, nodeCount).</exception>
    private static int CheckNodes(float[] x, int width, int[] edgeSrc, int[] edgeDst)
    {
        ArgumentNullException.ThrowIfNull(edgeSrc);
        ArgumentNullException.ThrowIfNull(edgeDst);
        if (x.Length % width != 0)
            throw new ArgumentException($"Feature array length {x.Length} is not a multiple of {width}.");
        if (edgeSrc.Length != edgeDst.Length)
            throw new ArgumentException("edgeSrc/edgeDst lengths differ.");

        var n = x.Length / width;
        for (var ei = 0; ei < edgeSrc.Length; ei++)
        {
            // The unsigned comparison rejects negative indices and indices >= n in one test.
            if ((uint)edgeSrc[ei] >= (uint)n || (uint)edgeDst[ei] >= (uint)n)
                throw new ArgumentException(
                    $"Edge {ei} ({edgeSrc[ei]} -> {edgeDst[ei]}) references a node outside [0, {n}).");
        }
        return n;
    }

    /// <summary>
    /// Validates the node → frame mapping: one entry per node, every id in [0, frameCount).
    /// </summary>
    private static void CheckBatch(int[] batch, int nodeCount, int frameCount)
    {
        ArgumentNullException.ThrowIfNull(batch);
        if (frameCount < 0)
            throw new ArgumentOutOfRangeException(nameof(frameCount), frameCount, "Frame count must not be negative.");
        if (batch.Length != nodeCount)
            throw new ArgumentException($"batch length {batch.Length} != node count {nodeCount}.");
        for (var i = 0; i < batch.Length; i++)
        {
            if ((uint)batch[i] >= (uint)frameCount)
                throw new ArgumentException(
                    $"batch[{i}] = {batch[i]} is outside the frame range [0, {frameCount}).");
        }
    }

    /// <summary>In-place ReLU: negative entries become zero.</summary>
    private static void Relu(float[] values)
    {
        for (var i = 0; i < values.Length; i++)
        {
            if (values[i] < 0f)
                values[i] = 0f;
        }
    }

    /// <summary>Row-wise concatenation of two flat row-major matrices with <paramref name="n"/>
    /// rows each: result row i = a row i followed by b row i.</summary>
    private static float[] ConcatColumns(float[] a, int aWidth, float[] b, int bWidth, int n)
    {
        var outWidth = aWidth + bWidth;
        var result = new float[n * outWidth];
        for (var i = 0; i < n; i++)
        {
            Array.Copy(a, i * aWidth, result, i * outWidth, aWidth);
            Array.Copy(b, i * bWidth, result, i * outWidth + aWidth, bWidth);
        }
        return result;
    }

    /// <summary>
    /// Loads <paramref name="expectedLayers"/> consecutive GAT layers named
    /// <c>{prefix}.{i}.*</c> and checks that each layer's input width matches what the
    /// forward pass will feed it.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when a layer's input width does not match
    /// the width produced by the preceding layer (plus the skeleton features for decoder
    /// stacks), or when a layer's own tensors are inconsistent.</exception>
    private static GatLayer[] LoadStack(SameWeights weights, string prefix, int expectedLayers, int inputDim)
    {
        // Decoder layers past the first see the previous output ⊕ skel features.
        var skipWidth = prefix.Contains("deconvs", StringComparison.Ordinal) ? SkelDim : 0;
        var layers = new GatLayer[expectedLayers];
        var expectedIn = inputDim;
        for (var i = 0; i < expectedLayers; i++)
        {
            var layer = new GatLayer(
                weights.Get($"{prefix}.{i}.lin_src.weight"),
                weights.Get($"{prefix}.{i}.att_src"),
                weights.Get($"{prefix}.{i}.att_dst"),
                weights.Get($"{prefix}.{i}.bias"),
                $"{prefix}.{i}");
            // Check EVERY layer, not just the first: a mismatch further down the stack would
            // otherwise only surface on the first forward pass.
            if (layer.InDim != expectedIn)
                throw new ArgumentException(
                    $"{prefix}.{i} expects input width {layer.InDim}, model contract says {expectedIn}.");
            layers[i] = layer;
            expectedIn = layer.OutDim + skipWidth;
        }
        return layers;
    }

    /// <summary>One GATConv re-expressed as explicit tensor math (no masking path).</summary>
    private sealed class GatLayer
    {
        /// <summary>Input feature width per node.</summary>
        public int InDim { get; }

        /// <summary>Number of attention heads.</summary>
        public int Heads { get; }

        /// <summary>Output channels per head.</summary>
        public int Channels { get; }

        /// <summary>Output feature width per node (heads concatenated).</summary>
        public int OutDim => Heads * Channels;

        private readonly float[] _w;      // [Heads*Channels, InDim] row-major
        private readonly float[] _attSrc; // [Heads, Channels]
        private readonly float[] _attDst; // [Heads, Channels]
        private readonly float[] _bias;   // [Heads*Channels]

        /// <summary>Binds the layer to its four tensors after checking ranks and shapes.</summary>
        /// <exception cref="ArgumentException">Thrown when a tensor has an unexpected rank or
        /// the shapes are mutually inconsistent.</exception>
        public GatLayer(WeightTensor w, WeightTensor attSrc, WeightTensor attDst, WeightTensor bias, string name)
        {
            if (w.Shape.Length != 2 || attSrc.Shape.Length != 3 || attDst.Shape.Length != 3 || bias.Shape.Length != 1)
                throw new ArgumentException($"Unexpected tensor ranks in GAT layer '{name}'.");
            // Heads and channels come from the trailing two dimensions of att_src.
            Heads = attSrc.Shape[1];
            Channels = attSrc.Shape[2];
            InDim = w.Shape[1];
            if (w.Shape[0] != OutDim || bias.Shape[0] != OutDim
                || attDst.Shape[1] != Heads || attDst.Shape[2] != Channels)
                throw new ArgumentException($"Inconsistent tensor shapes in GAT layer '{name}'.");
            _w = w.Data;
            _attSrc = attSrc.Data;
            _attDst = attDst.Data;
            _bias = bias.Data;
        }

        /// <summary>
        /// Applies the layer to <paramref name="n"/> nodes.
        /// </summary>
        /// <param name="x">Node features, flat [n × <see cref="InDim"/>].</param>
        /// <param name="n">Node count.</param>
        /// <param name="edgeSrc">Source node per edge.</param>
        /// <param name="edgeDst">Destination node per edge.</param>
        /// <returns>A newly allocated flat [n × <see cref="OutDim"/>] array. A node with no
        /// incoming edge receives the bias only.</returns>
        /// <exception cref="ArgumentException">Thrown when <paramref name="x"/> does not hold
        /// exactly <see cref="InDim"/> values per node.</exception>
        public float[] Forward(float[] x, int n, int[] edgeSrc, int[] edgeDst)
        {
            // Report raw lengths: dividing by n would throw for an empty graph and would
            // misreport the width whenever the length is not a multiple of n.
            if (x.Length != n * InDim)
                throw new ArgumentException(
                    $"GAT layer expects {InDim} features for each of {n} nodes, got array length {x.Length}.");

            var hc = OutDim;

            // h = x · Wᵀ → [n, Heads*Channels]
            var h = new float[n * hc];
            for (var i = 0; i < n; i++)
            {
                var xRow = i * InDim;
                var hRow = i * hc;
                for (var o = 0; o < hc; o++)
                {
                    var wRow = o * InDim;
                    var sum = 0f;
                    for (var k = 0; k < InDim; k++)
                        sum += x[xRow + k] * _w[wRow + k];
                    h[hRow + o] = sum;
                }
            }

            // Per-node attention halves: aSrc/aDst[i, head] = Σ_c h[i, head, c] · att[head, c].
            var aSrc = new float[n * Heads];
            var aDst = new float[n * Heads];
            for (var i = 0; i < n; i++)
            {
                for (var head = 0; head < Heads; head++)
                {
                    var hOff = i * hc + head * Channels;
                    var attOff = head * Channels;
                    var src = 0f;
                    var dst = 0f;
                    for (var c = 0; c < Channels; c++)
                    {
                        src += h[hOff + c] * _attSrc[attOff + c];
                        dst += h[hOff + c] * _attDst[attOff + c];
                    }
                    aSrc[i * Heads + head] = src;
                    aDst[i * Heads + head] = dst;
                }
            }

            // Edge logits + segment softmax over the incoming edges of each destination.
            var e = edgeSrc.Length;
            var alpha = new float[e * Heads];
            var segMax = new float[n * Heads];
            Array.Fill(segMax, float.NegativeInfinity);
            for (var ei = 0; ei < e; ei++)
            {
                var s = edgeSrc[ei];
                var d = edgeDst[ei];
                for (var head = 0; head < Heads; head++)
                {
                    var logit = aSrc[s * Heads + head] + aDst[d * Heads + head];
                    if (logit < 0f)
                        logit *= NegativeSlope;
                    alpha[ei * Heads + head] = logit;
                    if (logit > segMax[d * Heads + head])
                        segMax[d * Heads + head] = logit;
                }
            }

            // Subtract the per-destination maximum before exponentiating.
            var segSum = new float[n * Heads];
            for (var ei = 0; ei < e; ei++)
            {
                var d = edgeDst[ei];
                for (var head = 0; head < Heads; head++)
                {
                    var ex = MathF.Exp(alpha[ei * Heads + head] - segMax[d * Heads + head]);
                    alpha[ei * Heads + head] = ex;
                    segSum[d * Heads + head] += ex;
                }
            }

            // out[dst] = Σ_incoming attention · h[src], heads concatenated, + bias.
            var result = new float[n * hc];
            for (var ei = 0; ei < e; ei++)
            {
                var s = edgeSrc[ei];
                var d = edgeDst[ei];
                for (var head = 0; head < Heads; head++)
                {
                    var att = alpha[ei * Heads + head] / segSum[d * Heads + head];
                    var srcOff = s * hc + head * Channels;
                    var dstOff = d * hc + head * Channels;
                    for (var c = 0; c < Channels; c++)
                        result[dstOff + c] += att * h[srcOff + c];
                }
            }
            for (var i = 0; i < n; i++)
            {
                var row = i * hc;
                for (var o = 0; o < hc; o++)
                    result[row + o] += _bias[o];
            }
            return result;
        }
    }
}