// Parser for the SAME model weight blob format. It reads a byte array, extracts named float32 tensors (shape and flat data) into a dictionary, and provides accessors for tensors and ms.* stats.
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
namespace HumanoidRetargeter.Dl;
/// <summary>
/// One named float32 tensor of a <see cref="SameWeights"/> blob.
/// </summary>
/// <remarks>
/// The struct itself is <c>readonly</c>, but it holds both arrays by reference: copies of a
/// <see cref="WeightTensor"/> share the same arrays, and the array contents remain mutable.
/// Record equality compares the two array references, not their elements.
/// This type does not validate that <paramref name="Data"/> matches <paramref name="Shape"/>;
/// <see cref="SameWeights.Parse"/> establishes that relationship for the tensors it creates
/// (a rank-0 tensor has an empty shape and exactly one data element).
/// </remarks>
/// <param name="Shape">Dimensions (row-major / C order).</param>
/// <param name="Data">Flat float32 data, length = product of <paramref name="Shape"/>.</param>
public readonly record struct WeightTensor(int[] Shape, float[] Data);
/// <summary>
/// Parses the committed SAME model weight blob
/// (<c>Assets/humanoid_retargeter/dl/same_v1.weights</c>, produced by
/// <c>dev/m10/scripts/export_weights.py</c>) into named float32 tensors: the GAT
/// encoder/decoder parameters under their original PyTorch state-dict keys plus the
/// <c>ms.*</c> feature normalization statistics. No file IO — callers pass bytes
/// (the Editor reads the asset file; tests read the repo file).
/// </summary>
/// <remarks>
/// Blob format (little endian): 8-byte magic <c>HRSAMEW1</c>, int32 tensor count, then per
/// tensor: int32 UTF-8 name length, name bytes, int32 rank, int32 dims, float32 data.
/// Bytes following the last declared tensor are ignored.
/// </remarks>
public sealed class SameWeights
{
    private const string Magic = "HRSAMEW1";

    // Plausibility limits: anything outside them is treated as a corrupt blob rather than
    // being allowed to drive allocations.
    private const int MaxTensorCount = 4096;
    private const int MaxNameLength = 1024;
    private const int MaxRank = 8;

    // Largest element count whose byte size still fits in an int (array / ReadBytes limit).
    private const int MaxElementCount = int.MaxValue / sizeof(float);

    private readonly Dictionary<string, WeightTensor> _tensors;

    private SameWeights(Dictionary<string, WeightTensor> tensors)
    {
        _tensors = tensors;
    }

    /// <summary>All tensors by name.</summary>
    public IReadOnlyDictionary<string, WeightTensor> Tensors => _tensors;

    /// <summary>Returns the named tensor.</summary>
    /// <param name="name">Tensor name exactly as stored in the blob (ordinal comparison).</param>
    /// <returns>The tensor registered under <paramref name="name"/>.</returns>
    /// <exception cref="ArgumentException">Thrown when the blob has no such tensor.</exception>
    public WeightTensor Get(string name)
        => _tensors.TryGetValue(name, out var tensor)
            ? tensor
            : throw new ArgumentException($"SAME weights blob has no tensor '{name}'.");

    /// <summary>Returns the flat data of the named <c>ms.*</c> normalization stat
    /// (e.g. <c>Stat("q_m")</c> for the rotation-feature mean).</summary>
    /// <param name="key">Stat name without the <c>ms.</c> prefix.</param>
    /// <returns>The data array of tensor <c>"ms." + key</c> (the stored array, not a copy).</returns>
    /// <exception cref="ArgumentException">Thrown when the blob has no such tensor.</exception>
    public float[] Stat(string key) => Get("ms." + key).Data;

    /// <summary>Parses a weight blob.</summary>
    /// <param name="data">The complete blob bytes; not modified.</param>
    /// <returns>The parsed tensors, keyed by name.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="data"/> is null.</exception>
    /// <exception cref="FormatException">Thrown on a wrong magic or truncated/invalid data.</exception>
    public static SameWeights Parse(byte[] data)
    {
        ArgumentNullException.ThrowIfNull(data);
        try
        {
            using var reader = new BinaryReader(new MemoryStream(data, writable: false));

            // ReadBytes returns a short array at end of stream, so a blob shorter than the
            // magic simply fails the comparison below.
            var magic = Encoding.ASCII.GetString(reader.ReadBytes(Magic.Length));
            if (magic != Magic)
                throw new FormatException(
                    $"Not a SAME weight blob (magic '{magic}', expected '{Magic}').");

            var count = reader.ReadInt32();
            if (count is < 0 or > MaxTensorCount)
                throw new FormatException($"Implausible tensor count {count}.");

            var tensors = new Dictionary<string, WeightTensor>(count, StringComparer.Ordinal);
            for (var t = 0; t < count; t++)
            {
                var nameLength = reader.ReadInt32();
                if (nameLength is <= 0 or > MaxNameLength)
                    throw new FormatException($"Implausible tensor name length {nameLength}.");

                var nameBytes = reader.ReadBytes(nameLength);
                if (nameBytes.Length != nameLength)
                    throw new FormatException("Truncated SAME weight blob.");
                var name = Encoding.UTF8.GetString(nameBytes);

                var rank = reader.ReadInt32();
                if (rank is < 0 or > MaxRank)
                    throw new FormatException($"Implausible tensor rank {rank} for '{name}'.");

                var shape = new int[rank];
                long total = 1; // rank 0 (scalar) therefore carries exactly one element
                for (var d = 0; d < rank; d++)
                {
                    shape[d] = reader.ReadInt32();
                    if (shape[d] <= 0)
                        throw new FormatException($"Non-positive dimension in tensor '{name}'.");

                    // Bound the running product on every step: total <= MaxElementCount and
                    // shape[d] <= int.MaxValue, so the multiplication cannot overflow a long.
                    // Checking only after the loop would let up to eight dims wrap around.
                    total *= shape[d];
                    if (total > MaxElementCount)
                        throw new FormatException($"Tensor '{name}' too large.");
                }

                // Compare against what is actually left before reading: ReadBytes allocates
                // the full requested size up front, so a corrupt header must not be able to
                // force a huge allocation.
                var byteCount = (int)total * sizeof(float);
                var remaining = reader.BaseStream.Length - reader.BaseStream.Position;
                if (byteCount > remaining)
                    throw new FormatException($"Truncated data for tensor '{name}'.");

                var bytes = reader.ReadBytes(byteCount);
                if (bytes.Length != byteCount)
                    throw new FormatException($"Truncated data for tensor '{name}'.");

                // The blob is little endian, but BlockCopy reinterprets bytes in host order;
                // swap each float's bytes first when running on a big-endian host.
                if (!BitConverter.IsLittleEndian)
                {
                    for (var offset = 0; offset < bytes.Length; offset += sizeof(float))
                        Array.Reverse(bytes, offset, sizeof(float));
                }

                var values = new float[total];
                Buffer.BlockCopy(bytes, 0, values, 0, bytes.Length);

                if (!tensors.TryAdd(name, new WeightTensor(shape, values)))
                    throw new FormatException($"Duplicate tensor '{name}'.");
            }

            return new SameWeights(tensors);
        }
        catch (EndOfStreamException ex)
        {
            // A fixed-size field (count, length, rank, dim) ran past the end of the blob.
            throw new FormatException("Truncated SAME weight blob.", ex);
        }
    }
}