AutoRig/Dl/Safetensors.cs

Parser for safetensors files. Reads a u64 header length, parses the JSON header, and extracts tensors (F32, F16, BF16) into float arrays with shapes, returning a dictionary of Tensor objects.

File Access
using System.Buffers.Binary;
using System.Text;
using System.Text.Json;

namespace AutoRig.Dl;

/// <summary>
/// Parses the safetensors container (little-endian u64 header length + JSON header + raw payload).
/// The DL runtime is float32: F32 tensors are copied as-is, and F16/BF16 tensors (the half
/// precision most transformer checkpoints ship in) are widened to float32 on load. Any other
/// dtype throws with the dtype named so the catalog can explain the incompatibility.
/// </summary>
public static class Safetensors
{
    /// <summary>Size in bytes of the leading header-length field.</summary>
    private const int LengthPrefixSize = 8;

    /// <summary>
    /// Decodes every tensor of an in-memory safetensors file. The <c>__metadata__</c> header
    /// entry is skipped.
    /// </summary>
    /// <param name="data">The complete file contents.</param>
    /// <returns>The decoded tensors keyed by header name (ordinal comparison).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    /// <exception cref="FormatException">
    /// Malformed or truncated content, or a tensor whose dtype is not F32, F16 or BF16.
    /// </exception>
    public static IReadOnlyDictionary<string, Tensor> Parse( byte[] data )
    {
        ArgumentNullException.ThrowIfNull( data );
        if ( data.Length < LengthPrefixSize + 1 )
            throw new FormatException( "safetensors: file too small for a header." );

        // The format fixes little-endian byte order; BitConverter would follow the host's instead.
        var headerLength = BinaryPrimitives.ReadUInt64LittleEndian( data );
        if ( headerLength == 0 || headerLength > (ulong)(data.Length - LengthPrefixSize) )
            throw new FormatException(
                $"safetensors: header length {headerLength} does not fit the file ({data.Length} bytes)." );

        // Narrowing is safe: headerLength was just bounded by data.Length.
        var header = ReadHeader( data, (int)headerLength );
        var payloadStart = LengthPrefixSize + (int)headerLength;
        var payloadLength = data.Length - payloadStart;
        var tensors = new Dictionary<string, Tensor>( StringComparer.Ordinal );

        foreach ( var property in header.EnumerateObject() )
        {
            if ( property.Name == "__metadata__" )
                continue;
            tensors[property.Name] = ReadTensor( property.Name, property.Value, data, payloadStart, payloadLength );
        }
        return tensors;
    }

    /// <summary>Parses the JSON header that follows the length prefix and checks it is an object.</summary>
    private static JsonElement ReadHeader( byte[] data, int headerLength )
    {
        JsonElement root;
        try
        {
            var text = Encoding.UTF8.GetString( data, LengthPrefixSize, headerLength );
            using var doc = JsonDocument.Parse( text );
            // Clone so the element stays valid after the document is disposed.
            root = doc.RootElement.Clone();
        }
        catch ( JsonException e )
        {
            throw new FormatException( $"safetensors: invalid JSON header ({e.Message}).", e );
        }

        // EnumerateObject would otherwise throw InvalidOperationException for e.g. "[]".
        if ( root.ValueKind != JsonValueKind.Object )
            throw new FormatException( $"safetensors: JSON header must be an object, not {root.ValueKind}." );
        return root;
    }

    /// <summary>Validates one header entry and decodes its slice of the payload.</summary>
    private static Tensor ReadTensor(
        string name, JsonElement entry, byte[] data, int payloadStart, int payloadLength )
    {
        if ( entry.ValueKind != JsonValueKind.Object )
            throw new FormatException( $"safetensors: tensor '{name}' entry is not a JSON object." );

        var dtype = entry.TryGetProperty( "dtype", out var dtypeProp ) && dtypeProp.ValueKind == JsonValueKind.String
            ? dtypeProp.GetString()
            : null;
        var bytesPerValue = dtype switch
        {
            "F32" => 4,
            "F16" or "BF16" => 2,   // transformer checkpoints ship half precision
            _ => throw new FormatException(
                $"safetensors: tensor '{name}' has dtype {dtype ?? "?"} - "
                + "only F32/F16/BF16 are supported." ),
        };

        var shape = ReadShape( name, entry );
        var (begin, end) = ReadOffsets( name, entry );

        var expected = ByteSize( shape, bytesPerValue, payloadLength );
        if ( begin < 0 || end < begin || end > payloadLength || end - begin != expected )
            throw new FormatException(
                $"safetensors: tensor '{name}' offsets [{begin},{end}) are inconsistent "
                + $"with shape [{string.Join( ",", shape )}] and payload size {payloadLength}." );

        // Narrowing is safe: expected == end - begin <= payloadLength <= int.MaxValue.
        var start = payloadStart + (int)begin;
        var count = (int)(expected / bytesPerValue);
        var values = dtype switch
        {
            "F32" => DecodeF32( data, start, count ),
            "F16" => DecodeF16( data, start, count ),
            _ => DecodeBf16( data, start, count ),   // only BF16 is left after the dtype check
        };
        return Tensor.From( values, shape );
    }

    /// <summary>Reads the <c>shape</c> array; every dimension must be a non-negative int32.</summary>
    private static int[] ReadShape( string name, JsonElement entry )
    {
        if ( !entry.TryGetProperty( "shape", out var shapeProp )
            || shapeProp.ValueKind != JsonValueKind.Array )
            throw new FormatException( $"safetensors: tensor '{name}' has no shape." );

        var shape = new int[shapeProp.GetArrayLength()];
        var axis = 0;
        foreach ( var dim in shapeProp.EnumerateArray() )
        {
            // Negative sizes could cancel in pairs and still match the byte range, so reject them.
            if ( dim.ValueKind != JsonValueKind.Number || !dim.TryGetInt32( out var size ) || size < 0 )
                throw new FormatException(
                    $"safetensors: tensor '{name}' has invalid dimension {dim.GetRawText()} at axis {axis}." );
            shape[axis++] = size;
        }
        return shape;
    }

    /// <summary>Reads the <c>[begin, end)</c> <c>data_offsets</c> pair (relative to the payload start).</summary>
    private static (long Begin, long End) ReadOffsets( string name, JsonElement entry )
    {
        if ( !entry.TryGetProperty( "data_offsets", out var offsetsProp )
            || offsetsProp.ValueKind != JsonValueKind.Array
            || offsetsProp.GetArrayLength() != 2 )
            throw new FormatException( $"safetensors: tensor '{name}' has no data_offsets." );

        var first = offsetsProp[0];
        var second = offsetsProp[1];
        if ( first.ValueKind != JsonValueKind.Number || !first.TryGetInt64( out var begin )
            || second.ValueKind != JsonValueKind.Number || !second.TryGetInt64( out var end ) )
            throw new FormatException( $"safetensors: tensor '{name}' has non-integer data_offsets." );
        return (begin, end);
    }

    /// <summary>
    /// Byte size implied by <paramref name="shape"/>. Once the running product exceeds
    /// <paramref name="limit"/> it saturates to <paramref name="limit"/> + 1, so hostile shapes
    /// cannot overflow and wrap around to a plausible size.
    /// </summary>
    private static long ByteSize( int[] shape, int bytesPerValue, int limit )
    {
        // An empty axis makes the tensor empty whatever the other sizes are.
        if ( Array.IndexOf( shape, 0 ) >= 0 )
            return 0;

        long size = bytesPerValue;
        foreach ( var dim in shape )
        {
            // size <= limit <= int.MaxValue and dim <= int.MaxValue, so this product fits a long.
            size *= dim;
            if ( size > limit )
                return (long)limit + 1;
        }
        return size;
    }

    /// <summary>Decodes <paramref name="count"/> little-endian IEEE-754 binary32 values.</summary>
    private static float[] DecodeF32( byte[] data, int start, int count )
    {
        var values = new float[count];
        if ( BitConverter.IsLittleEndian )
        {
            // Host layout already matches the file, so a raw copy suffices.
            Buffer.BlockCopy( data, start, values, 0, count * sizeof( float ) );
        }
        else
        {
            for ( var i = 0; i < count; i++ )
                values[i] = BinaryPrimitives.ReadSingleLittleEndian( data.AsSpan( start + i * 4 ) );
        }
        return values;
    }

    /// <summary>Decodes <paramref name="count"/> little-endian IEEE-754 binary16 values to float32.</summary>
    private static float[] DecodeF16( byte[] data, int start, int count )
    {
        var values = new float[count];
        for ( var i = 0; i < count; i++ )
            values[i] = (float)BinaryPrimitives.ReadHalfLittleEndian( data.AsSpan( start + i * 2 ) );
        return values;
    }

    /// <summary>Decodes <paramref name="count"/> little-endian bfloat16 values to float32.</summary>
    private static float[] DecodeBf16( byte[] data, int start, int count )
    {
        var values = new float[count];
        for ( var i = 0; i < count; i++ )
        {
            // BF16 is the high 16 bits of an f32: shift them back into place.
            var bits = (uint)BinaryPrimitives.ReadUInt16LittleEndian( data.AsSpan( start + i * 2 ) ) << 16;
            values[i] = BitConverter.UInt32BitsToSingle( bits );
        }
        return values;
    }
}