Builds and uses a SAME decoder graph for a target rig. Constructs a per-rig graph of nodes (bones + synthesized end joints), computes aligned rest geometry and normalized features, tiles features for batched input, and converts decoder outputs (root trajectory + 6D rotations per node) into a Clip of target bone local transforms.
using System;
using System.Collections.Generic;
using System.Numerics;
using HumanoidRetargeter.Mapping;
using HumanoidRetargeter.Maths;
using HumanoidRetargeter.Skeleton;
using HumanoidRetargeter.Solve;
using HumanoidRetargeter.Target;
namespace HumanoidRetargeter.Dl;
using Vector3 = System.Numerics.Vector3; // s&box compat: shadow engine's global-namespace Vector3 (see Code/HumanoidRetargeter/Assembly.cs)
/// <summary>
/// Description of a target skeleton as presented to the SAME decoder, bundled with the
/// rest-pose data required to map decoder output back onto the rig's bone locals.
/// </summary>
public sealed class SameTargetGraph
{
    /// <summary>Rig from which this graph was derived.</summary>
    public required TargetRig Rig { get; init; }

    /// <summary>Index, within the rig skeleton, of the bone acting as graph root (the hips).</summary>
    public required int HipsBone { get; init; }

    /// <summary>For each node, the rig bone index it represents; synthesized end joints hold -1.</summary>
    public required int[] NodeBone { get; init; }

    /// <summary>For each node, the index of its parent node in the graph; the root (node 0) holds -1.</summary>
    public required int[] NodeParent { get; init; }

    /// <summary>Per-node names: the bone name, or the leaf bone name plus <c>_end</c> for end joints.</summary>
    public required string[] NodeNames { get; init; }

    /// <summary>Normalized skeleton features (local offset followed by global offset),
    /// stored flat as [NodeCount × 6].</summary>
    public required float[] X { get; init; }

    /// <summary>Un-normalized parent-relative offsets in aligned space, stored flat as [NodeCount × 3].</summary>
    public required float[] LoRaw { get; init; }

    /// <summary>Un-normalized rest offsets relative to the root's ground-plane position,
    /// stored flat as [NodeCount × 3].</summary>
    public required float[] GoRaw { get; init; }

    /// <summary>Rotation that carries rig-space quantities into the canonical SAME frame.</summary>
    public required Quaternion Align { get; init; }

    /// <summary>Offset (in aligned space) added to decoded root positions so that the
    /// rig's rest root XZ placement is restored before converting back to rig space.</summary>
    public required Vector3 RestOffset { get; init; }

    /// <summary>Rest world rotation (rig space) for every node; an end joint reuses the
    /// rotation of the leaf bone it extends.</summary>
    public required Quaternion[] NodeRestWorldRot { get; init; }

    /// <summary>Number of graph nodes in a single frame.</summary>
    public int NodeCount
    {
        get { return NodeBone.Length; }
    }
}
/// <summary>
/// Target side of the SAME port: builds the decoder's skeleton graph from a
/// <see cref="TargetRig"/> and converts decoder output (normalized r ⊕ q ⊕ c per node)
/// into a <see cref="Clip"/> of target bone locals.
/// </summary>
/// <remarks>
/// <para><b>Graph.</b> Nodes are the rig's <see cref="BoneClass.Animated"/> bones in the
/// hips subtree, <b>excluding finger-role bones</b> (the fingerless "core" variant the
/// spike validated — the checkpoint was trained finger-less, so finger joints would only
/// receive smooth but meaningless rotations; target finger bones keep their rest locals
/// instead, matching the geometric solver's CopyPinky posture), plus one end joint per
/// leaf (from the rig's <c>tail_world</c> data when present, else a half-length
/// continuation of the parent segment). Node 0 is the hips, where the decoder's root
/// output row lives.</para>
/// <para><b>Decode.</b> Per frame: denormalize, take the root row's r = (dθ, dx, dz, h),
/// 6D→rotation for every node, FK over the normalized skeleton (identity rest rotations,
/// offsets only), accumulate the facing deltas over time, then conjugate the normalized
/// world rotations back onto the actual rig: <c>R_rig(j) = A⁻¹ · Ĝ(j) · A · R_rest(j)</c>
/// (the normalized skeleton's world rotations are world-space deltas from the rest pose —
/// only rest <i>world orientations</i> are needed, which sidesteps the Citizen
/// no-anatomical-bone-axes issue). Bones outside the graph (fingers, ConstraintDriven,
/// IkBaked, anything above the hips) keep rest locals; IK baking and cleanup passes run
/// later in the shared pipeline.</para>
/// </remarks>
public static class SameTarget
{
/// <summary>Builds the decoder conditioning graph for a target rig.</summary>
/// <param name="rig">Target rig.</param>
/// <param name="stats">Normalization statistics (per-axis mean and scale for the local
/// and global offset features).</param>
/// <param name="align">Apply the rest-geometry world alignment (Y-up, +Z facing).
/// Disabled only by golden-parity tests (the Python reference applies none).</param>
/// <returns>The graph: node topology, raw and normalized rest features, the alignment
/// rotation, and per-node rest world rotations.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="rig"/> or
/// <paramref name="stats"/> is null.</exception>
/// <exception cref="ArgumentException">Thrown when the rig has no identifiable hips bone,
/// or when the hips bone is excluded from the graph (e.g. it is not Animated).</exception>
public static SameTargetGraph Build(TargetRig rig, SameStats stats, bool align = true)
{
    ArgumentNullException.ThrowIfNull(rig);
    ArgumentNullException.ThrowIfNull(stats);
    var skeleton = rig.Skeleton;
    var map = rig.ToMappingResult();
    var hips = rig.BoneForRole(BoneRole.Hips) ?? SameFeatures.FindHips(skeleton, null);
    if (hips < 0 || hips >= skeleton.Count)
        throw new ArgumentException("Target rig has no identifiable hips bone.", nameof(rig));
    var alignRot = align ? ComputeAlignment(skeleton, map) : Quaternion.Identity;

    // ---- node selection: Animated, hips subtree, no finger roles ---------------------
    // Single forward pass; relies on bones being stored parent-before-child.
    var inSubtree = new bool[skeleton.Count];
    inSubtree[hips] = true;
    for (var i = hips + 1; i < skeleton.Count; i++)
    {
        var parent = skeleton[i].ParentIndex;
        if (parent >= 0 && inSubtree[parent])
            inSubtree[i] = true;
    }

    var nodeBone = new List<int>();
    var nodeParent = new List<int>();
    var names = new List<string>();
    var nodeOfBone = new Dictionary<int, int>();
    for (var i = 0; i < skeleton.Count; i++)
    {
        if (!inSubtree[i] || rig.ClassOf(i) != BoneClass.Animated || IsFingerRole(rig.RoleOf(i)))
            continue;

        // Graph parent: the nearest ancestor that is itself a node, so excluded bones
        // in between are bridged. The walk only runs off the top of the hierarchy when
        // no ancestor made it into the graph (i.e. the hips was excluded); such bones
        // are skipped.
        var parentNode = -1;
        if (i != hips)
        {
            var ancestor = skeleton[i].ParentIndex;
            while (ancestor >= 0 && !nodeOfBone.TryGetValue(ancestor, out parentNode))
                ancestor = skeleton[ancestor].ParentIndex;
            if (ancestor < 0)
                continue;
        }

        nodeOfBone[i] = nodeBone.Count;
        nodeBone.Add(i);
        nodeParent.Add(parentNode);
        names.Add(skeleton[i].Name);
    }

    // Node 0 must be the hips: the decoder's root output row lives there.
    if (nodeBone.Count == 0 || nodeBone[0] != hips)
        throw new ArgumentException("Target rig's hips bone is not an Animated bone.", nameof(rig));

    // ---- end joints per leaf ----------------------------------------------------------
    var isLeaf = new bool[nodeBone.Count];
    Array.Fill(isLeaf, true);
    for (var n = 0; n < nodeParent.Count; n++)
    {
        if (nodeParent[n] >= 0)
            isLeaf[nodeParent[n]] = false;
    }

    var endTips = new Dictionary<int, Vector3>(); // node index → rig-space tip delta
    var boneCount = nodeBone.Count; // snapshot: the lists grow while end joints are appended
    for (var n = 0; n < boneCount; n++)
    {
        if (!isLeaf[n])
            continue;
        var b = nodeBone[n];
        Vector3 tip;
        if (rig.TailWorldOf(b) is { } tail)
        {
            tip = tail - skeleton.RestWorld[b].Pos;
        }
        else
        {
            // No tail data: continue the parent segment for half its length.
            var p = skeleton[b].ParentIndex;
            tip = p >= 0 ? (skeleton.RestWorld[b].Pos - skeleton.RestWorld[p].Pos) * 0.5f : Vector3.Zero;
        }

        if (tip.Length() < 1e-6f)
            tip = new Vector3(0f, 2f, 0f); // matches the spike's BVH generator fallback

        nodeBone.Add(-1);
        nodeParent.Add(n);
        names.Add(skeleton[b].Name + "_end");
        endTips[nodeBone.Count - 1] = tip;
    }

    // ---- aligned rest geometry --------------------------------------------------------
    var count = nodeBone.Count;
    var restPos = new Vector3[count];
    var restRot = new Quaternion[count];
    for (var n = 0; n < count; n++)
    {
        if (nodeBone[n] >= 0)
        {
            restPos[n] = Vector3.Transform(skeleton.RestWorld[nodeBone[n]].Pos, alignRot);
            restRot[n] = skeleton.RestWorld[nodeBone[n]].Rot;
        }
        else
        {
            // End joint: leaf position plus tip, and the leaf's rest rotation.
            var leafBone = nodeBone[nodeParent[n]];
            restPos[n] = Vector3.Transform(
                skeleton.RestWorld[leafBone].Pos + endTips[n], alignRot);
            restRot[n] = skeleton.RestWorld[leafBone].Rot;
        }
    }

    // No ground shift on the target: the convention keeps the rig's authored ground
    // plane (the decoded height channel is interpreted against it) — matching the
    // validated spike setup, where the s&box rig was consumed as authored.
    var rootXZ = new Vector3(restPos[0].X, 0f, restPos[0].Z);
    var lo = new float[count * 3];
    var go = new float[count * 3];
    var x = new float[count * SameModel.SkelDim];
    for (var n = 0; n < count; n++)
    {
        Vector3 loV, goV;
        if (n == 0)
        {
            // Root: only its height above the ground plane is encoded.
            loV = new Vector3(0f, restPos[0].Y, 0f);
            goV = loV;
        }
        else
        {
            loV = restPos[n] - restPos[nodeParent[n]];
            goV = restPos[n] - rootXZ;
        }

        lo[n * 3] = loV.X;
        lo[n * 3 + 1] = loV.Y;
        lo[n * 3 + 2] = loV.Z;
        go[n * 3] = goV.X;
        go[n * 3 + 1] = goV.Y;
        go[n * 3 + 2] = goV.Z;

        // Row stride follows the allocation above instead of a literal.
        var row = n * SameModel.SkelDim;
        x[row] = (loV.X - stats.LoM[0]) / stats.LoS[0];
        x[row + 1] = (loV.Y - stats.LoM[1]) / stats.LoS[1];
        x[row + 2] = (loV.Z - stats.LoM[2]) / stats.LoS[2];
        x[row + 3] = (goV.X - stats.GoM[0]) / stats.GoS[0];
        x[row + 4] = (goV.Y - stats.GoM[1]) / stats.GoS[1];
        x[row + 5] = (goV.Z - stats.GoM[2]) / stats.GoS[2];
    }

    return new SameTargetGraph
    {
        Rig = rig,
        HipsBone = hips,
        NodeBone = nodeBone.ToArray(),
        NodeParent = nodeParent.ToArray(),
        NodeNames = names.ToArray(),
        X = x,
        LoRaw = lo,
        GoRaw = go,
        Align = alignRot,
        RestOffset = rootXZ, // already has Y = 0
        NodeRestWorldRot = restRot,
    };
}
/// <summary>Tiles the single-frame graph for a batch of frames: features, edges, batch ids.</summary>
/// <param name="graph">Graph whose per-frame features and parent table are replicated.</param>
/// <param name="frames">Number of frames in the batch; must be non-negative.</param>
/// <returns>The tiled features (one copy of <c>graph.X</c> per frame), the edge lists
/// produced by <c>SameFeatures.BuildEdges</c>, and a per-node array holding the frame
/// index each node belongs to.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="graph"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="frames"/> is negative.</exception>
/// <exception cref="ArgumentException">Thrown when the graph's feature array length does
/// not equal NodeCount × feature width.</exception>
public static (float[] X, int[] EdgeSrc, int[] EdgeDst, int[] Batch) Tile(SameTargetGraph graph, int frames)
{
    ArgumentNullException.ThrowIfNull(graph);
    if (frames < 0)
        throw new ArgumentOutOfRangeException(nameof(frames), frames, "Frame count must be non-negative.");

    var nodes = graph.NodeCount;
    var stride = nodes * SameModel.SkelDim;

    // Each frame occupies exactly `stride` floats; a mismatched source would either
    // leave gaps or spill into the next frame's slot.
    if (graph.X.Length != stride)
        throw new ArgumentException(
            $"Graph feature length {graph.X.Length} != nodes {nodes} × {SameModel.SkelDim}.",
            nameof(graph));

    var x = new float[frames * stride];
    var batch = new int[frames * nodes];
    for (var f = 0; f < frames; f++)
    {
        Array.Copy(graph.X, 0, x, f * stride, stride);
        Array.Fill(batch, f, f * nodes, nodes);
    }

    var (src, dst) = SameFeatures.BuildEdges(graph.NodeParent, frames);
    return (x, src, dst, batch);
}
/// <summary>
/// Converts decoder output into a clip of target bone locals (one frame per decoded
/// frame; non-graph bones at rest).
/// </summary>
/// <param name="hatD">Decoder output, flat [frames·NodeCount × 11].</param>
/// <param name="graph">Graph the decoder was conditioned on.</param>
/// <param name="frames">Number of decoded frames; must be non-negative.</param>
/// <param name="stats">Normalization statistics used to denormalize the root and rotation channels.</param>
/// <param name="name">Name passed to the created clip.</param>
/// <param name="fps">Frame rate passed to the created clip.</param>
/// <param name="looping">Looping flag passed to the created clip.</param>
/// <returns>A clip holding one array of bone locals per decoded frame.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="hatD"/>,
/// <paramref name="graph"/> or <paramref name="stats"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="frames"/> is negative.</exception>
/// <exception cref="ArgumentException">Thrown when <paramref name="hatD"/> does not have
/// frames × NodeCount × output-width elements.</exception>
/// <exception cref="InvalidOperationException">Thrown when the output is non-finite.</exception>
public static Clip DecodeClip(
    float[] hatD, SameTargetGraph graph, int frames, SameStats stats,
    string name, float fps, bool looping)
{
    ArgumentNullException.ThrowIfNull(hatD);
    ArgumentNullException.ThrowIfNull(graph);
    ArgumentNullException.ThrowIfNull(stats);
    if (frames < 0)
        throw new ArgumentOutOfRangeException(nameof(frames), frames, "Frame count must be non-negative.");

    var j = graph.NodeCount;
    if (hatD.Length != frames * j * SameModel.OutDim)
        throw new ArgumentException(
            $"hatD length {hatD.Length} != frames {frames} × nodes {j} × {SameModel.OutDim}.",
            nameof(hatD));

    var skeleton = graph.Rig.Skeleton;
    var alignInv = Quaternion.Conjugate(graph.Align);
    var clip = new Clip(name, fps, looping);

    // Facing accumulator (yaw rotation + ground-plane translation).
    var accumYaw = 0f;
    var accumPos = Vector3.Zero;

    // Reverse lookup: rig bone → graph node (-1 for bones outside the graph).
    var nodeOfBone = new int[skeleton.Count];
    Array.Fill(nodeOfBone, -1);
    for (var n = 0; n < j; n++)
    {
        if (graph.NodeBone[n] >= 0)
            nodeOfBone[graph.NodeBone[n]] = n;
    }

    var normWorld = new Quaternion[j];
    // Scratch for rig-space world transforms; every entry is rewritten each frame, so
    // one buffer serves all frames. `locals` must stay per-frame: the clip keeps it.
    var world = new XForm[skeleton.Count];

    for (var f = 0; f < frames; f++)
    {
        var rootRow = (f * j) * SameModel.OutDim;
        var dTheta = hatD[rootRow] * stats.RS[0] + stats.RM[0];
        var dx = hatD[rootRow + 1] * stats.RS[1] + stats.RM[1];
        var dz = hatD[rootRow + 2] * stats.RS[2] + stats.RM[2];
        var height = hatD[rootRow + 3] * stats.RS[3] + stats.RM[3];

        // accumulate: T_acc(f) = T_acc(f−1) ∘ T_f
        // (translation is applied in the previous frame's facing, then the yaw advances)
        accumPos += Vector3.Transform(new Vector3(dx, 0f, dz),
            Quaternion.CreateFromAxisAngle(Vector3.UnitY, accumYaw));
        accumYaw += dTheta;
        var accumRot = Quaternion.CreateFromAxisAngle(Vector3.UnitY, accumYaw);

        // Normalized-skeleton world rotations via graph FK.
        for (var n = 0; n < j; n++)
        {
            var row = (f * j + n) * SameModel.OutDim + 4; // q starts after r
            var a0 = hatD[row] * stats.QS[0] + stats.QM[0];
            var a1 = hatD[row + 1] * stats.QS[1] + stats.QM[1];
            var a2 = hatD[row + 2] * stats.QS[2] + stats.QM[2];
            var b0 = hatD[row + 3] * stats.QS[3] + stats.QM[3];
            var b1 = hatD[row + 4] * stats.QS[4] + stats.QM[4];
            var b2 = hatD[row + 5] * stats.QS[5] + stats.QM[5];

            // The 6D conversion substitutes fallback axes for degenerate input, which
            // would turn NaN rotation channels into a valid-looking pose. Reject them
            // here so non-finite decoder output is always reported.
            if (!float.IsFinite(a0) || !float.IsFinite(a1) || !float.IsFinite(a2)
                || !float.IsFinite(b0) || !float.IsFinite(b1) || !float.IsFinite(b2))
                throw new InvalidOperationException("DL solve produced non-finite transforms.");

            var local = SixDToQuaternion(a0, a1, a2, b0, b1, b2);
            normWorld[n] = n == 0
                ? MathQ.Normalize(accumRot * local)
                : MathQ.Normalize(normWorld[graph.NodeParent[n]] * local);
        }

        // Root world position in rig space.
        var alignedRootPos = new Vector3(accumPos.X, height, accumPos.Z) + graph.RestOffset;
        var rigRootPos = Vector3.Transform(alignedRootPos, alignInv);

        // Bone locals: graph bones get conjugated world rotations (rest-local
        // translations preserved), the hips additionally gets the decoded position.
        var locals = new XForm[skeleton.Count];
        for (var b = 0; b < skeleton.Count; b++)
        {
            var parent = skeleton[b].ParentIndex;
            var n = nodeOfBone[b];
            if (n < 0)
            {
                locals[b] = skeleton[b].RestLocal;
            }
            else
            {
                // R_rig = A⁻¹ · Ĝ · A · R_rest
                var rigWorldRot = MathQ.Normalize(
                    alignInv * normWorld[n] * graph.Align * graph.NodeRestWorldRot[n]);
                if (b == graph.HipsBone)
                {
                    var desired = new XForm(rigRootPos, rigWorldRot);
                    locals[b] = parent < 0 ? desired : XForm.ToLocal(world[parent], desired);
                }
                else
                {
                    var parentRot = parent < 0 ? Quaternion.Identity : world[parent].Rot;
                    locals[b] = new XForm(
                        skeleton[b].RestLocal.Pos,
                        MathQ.Normalize(Quaternion.Conjugate(parentRot) * rigWorldRot));
                }
            }

            // Keep the running world transform so children can be expressed locally.
            world[b] = parent < 0 ? locals[b] : XForm.Compose(world[parent], locals[b]);
        }

        AssertFiniteFrame(locals);
        clip.Frames.Add(locals);
    }

    return clip;
}
/// <summary>Gram-Schmidt 6D → quaternion (columns e1, e2, e1×e2).</summary>
/// <remarks>
/// Degenerate input is handled so the result is always built from an orthonormal basis:
/// a near-zero first vector falls back to +X, and a second vector that is near-zero or
/// parallel to the first is replaced by a world axis orthogonalized against e1.
/// </remarks>
/// <param name="v1x">X of the first 6D column.</param>
/// <param name="v1y">Y of the first 6D column.</param>
/// <param name="v1z">Z of the first 6D column.</param>
/// <param name="v2x">X of the second 6D column.</param>
/// <param name="v2y">Y of the second 6D column.</param>
/// <param name="v2z">Z of the second 6D column.</param>
/// <returns>The normalized quaternion for the rotation whose columns are e1, e2, e1×e2.</returns>
internal static Quaternion SixDToQuaternion(float v1x, float v1y, float v1z, float v2x, float v2y, float v2z)
{
    const float Epsilon = 1e-12f;
    var v1 = new Vector3(v1x, v1y, v1z);
    var v2 = new Vector3(v2x, v2y, v2z);
    var e1 = v1.LengthSquared() > Epsilon ? Vector3.Normalize(v1) : Vector3.UnitX;

    var u2 = v2 - e1 * Vector3.Dot(e1, v2);
    if (!(u2.LengthSquared() > Epsilon))
    {
        // A fixed +Y fallback is only valid when e1 is perpendicular to Y; otherwise
        // the basis is skewed (or collapses when e1 ∥ Y). Pick the world axis that is
        // not close to e1 and remove its e1 component. For e1 = +X this still yields
        // +Y exactly.
        var seed = e1.Y * e1.Y < 0.81f ? Vector3.UnitY : Vector3.UnitX;
        u2 = seed - e1 * Vector3.Dot(e1, seed);
    }

    var e2 = Vector3.Normalize(u2);
    var e3 = Vector3.Cross(e1, e2);

    // Rows of the row-vector-convention matrix are the columns of the rotation.
    var m = new Matrix4x4(
        e1.X, e1.Y, e1.Z, 0f,
        e2.X, e2.Y, e2.Z, 0f,
        e3.X, e3.Y, e3.Z, 0f,
        0f, 0f, 0f, 1f);
    return MathQ.Normalize(Quaternion.CreateFromRotationMatrix(m));
}
/// <summary>
/// Tells whether a role's name begins with one of the finger prefixes
/// (Thumb, Index, Middle, Ring, Pinky). A missing role is never a finger role.
/// </summary>
private static bool IsFingerRole(BoneRole? role)
{
    if (role is { } value)
    {
        var label = value.ToString();
        foreach (var prefix in new[] { "Thumb", "Index", "Middle", "Ring", "Pinky" })
        {
            if (label.StartsWith(prefix, StringComparison.Ordinal))
                return true;
        }
    }

    return false;
}
/// <summary>
/// Derives the alignment rotation from the character frame computed on the rest pose.
/// If either step raises <see cref="ArgumentException"/>, the identity rotation is used.
/// </summary>
private static Quaternion ComputeAlignment(
    HumanoidRetargeter.Skeleton.Skeleton skeleton, MappingResult map)
{
    // Fallback: assume the rig is already Y-up/+Z (the shipped one is).
    var rotation = Quaternion.Identity;
    try
    {
        var basis = CharacterFrame.Compute(skeleton, map, skeleton.RestWorld);
        rotation = SameFeatures.AlignFromBasis(basis.Lateral, basis.Up, basis.Forward);
    }
    catch (ArgumentException)
    {
        // Keep the identity fallback.
    }

    return rotation;
}
/// <summary>
/// Verifies that every transform in the frame has finite position and rotation components.
/// </summary>
/// <exception cref="InvalidOperationException">Thrown on the first transform containing
/// a NaN or infinite component.</exception>
private static void AssertFiniteFrame(XForm[] locals)
{
    for (var i = 0; i < locals.Length; i++)
    {
        var pos = locals[i].Pos;
        var rot = locals[i].Rot;
        var finite =
            float.IsFinite(pos.X) && float.IsFinite(pos.Y) && float.IsFinite(pos.Z)
            && float.IsFinite(rot.X) && float.IsFinite(rot.Y)
            && float.IsFinite(rot.Z) && float.IsFinite(rot.W);
        if (finite)
            continue;
        throw new InvalidOperationException("DL solve produced non-finite transforms.");
    }
}
}