Parser for PyTorch checkpoint files. It reads both zip-format and legacy pickle-stream checkpoints into named float32 Tensor objects, using a restricted unpickler that accepts only the opcodes a tensor state dict uses, resolves a small set of safe globals, and keeps every other global as inert data that is never executed.
using System.IO.Compression;
using System.Text;
namespace AutoRig.Dl;
/// <summary>
/// Reads PyTorch checkpoints (both the zip format and the pre-1.6 LEGACY stream
/// format) into named float32 tensors. The embedded pickles run on a RESTRICTED
/// stack machine that implements only the opcodes a tensor state dict uses and
/// resolves only three globals (torch._utils._rebuild_tensor_v2,
/// torch.FloatStorage, collections.OrderedDict). Every other global becomes
/// inert data that is never looked up or executed (bf16/f16 storage classes are
/// recognized by name this way and widened to float32). Any unsupported opcode
/// throws FormatException - arbitrary pickle code can never run.
/// </summary>
public static class TorchCheckpoint
{
/// <summary>
/// Parses an in-memory checkpoint, choosing the container from its leading
/// bytes: "PK" selects the zip format, a pickle PROTO opcode (0x80) selects the
/// legacy stream. At least three bytes are required either way.
/// </summary>
/// <exception cref="FormatException">Unsupported container, pickle content, or
/// inconsistent tensor data.</exception>
public static IReadOnlyDictionary<string, Tensor> Parse( byte[] data )
{
    ArgumentNullException.ThrowIfNull( data );
    return data switch
    {
        [0x50, 0x4b, _, ..] => ParseZip( data ),
        [0x80, _, _, ..] => ParseLegacy( data ),
        _ => throw new FormatException( "torch checkpoint: neither a zip archive nor a legacy pickle stream." ),
    };
}
/// <summary>
/// Parses a zip-format checkpoint held fully in memory: unpickles the archive's
/// data.pkl, then resolves each tensor's storage from the sibling data/ entries,
/// widening bf16/f16 payloads to float32.
/// </summary>
static IReadOnlyDictionary<string, Tensor> ParseZip( byte[] data )
{
    var archive = ReadZip( data );
    var pklName = archive.Keys.FirstOrDefault( name => name.EndsWith( "/data.pkl", StringComparison.Ordinal ) );
    if ( pklName is null )
        throw new FormatException( "torch checkpoint: no data.pkl entry found." );

    // Everything before "data.pkl" (including the trailing slash) is the archive root.
    var root = pklName.Substring( 0, pklName.Length - "data.pkl".Length );
    var reader = new Unpickler( archive[pklName], 0 );
    var graph = reader.Run( out _ );

    byte[] LoadStorage( string storageKey )
    {
        if ( !archive.TryGetValue( root + "data/" + storageKey, out var raw ) )
            throw new FormatException( $"torch checkpoint: storage '{storageKey}' missing." );
        // Storages the pickle never typed are treated as plain float32.
        var storageType = reader.Storages.TryGetValue( storageKey, out var info ) ? info.Type : "FloatStorage";
        return ToFloat32Bytes( raw, storageType );
    }

    return Materialize( graph, LoadStorage );
}
/// <summary>
/// Streams a zip-format checkpoint — required for files past the 2GB array
/// limit (e.g. MagicArticulate's 4.4GB training checkpoint, which is also
/// ZIP64). The stream must be seekable. <paramref name="keepTensor"/> filters
/// by full dotted name BEFORE storage bytes are read, so skipping e.g.
/// optimizer state costs no memory or IO.
/// </summary>
/// <param name="file">Seekable zip stream; it is repositioned with absolute seeks.</param>
/// <param name="keepTensor">Optional name filter; null keeps every tensor.</param>
public static IReadOnlyDictionary<string, Tensor> Parse(
    Stream file, Func<string, bool> keepTensor = null )
{
    ArgumentNullException.ThrowIfNull( file );
    var index = ReadZipIndex( file );
    var pklName = index.Keys.FirstOrDefault( name => name.EndsWith( "/data.pkl", StringComparison.Ordinal ) );
    if ( pklName is null )
        throw new FormatException( "torch checkpoint: no data.pkl entry found." );

    // Storage entries live beside data.pkl under "<root>data/<key>".
    var root = pklName.Substring( 0, pklName.Length - "data.pkl".Length );
    var reader = new Unpickler( ReadZipEntry( file, index[pklName] ), 0 );
    var graph = reader.Run( out _ );

    // Called once per kept tensor, so each storage entry is pulled from the stream on demand.
    byte[] LoadStorage( string storageKey )
    {
        if ( !index.TryGetValue( root + "data/" + storageKey, out var entry ) )
            throw new FormatException( $"torch checkpoint: storage '{storageKey}' missing." );
        var storageType = reader.Storages.TryGetValue( storageKey, out var info ) ? info.Type : "FloatStorage";
        return ToFloat32Bytes( ReadZipEntry( file, entry ), storageType );
    }

    return Materialize( graph, LoadStorage, keepTensor );
}
/// <summary>Normalizes accepted storage payloads to raw float32 bytes (bf16/f16
/// checkpoints — e.g. SkinTokens' bfloat16 model — widen at load).</summary>
/// <param name="bytes">Raw little-endian storage payload.</param>
/// <param name="storageType">Torch storage class name, e.g. "HalfStorage".</param>
/// <returns>Float32 bytes for bf16/f16 input (a new array twice as long);
/// for any other type the input array itself, unchanged.</returns>
internal static byte[] ToFloat32Bytes( byte[] bytes, string storageType )
{
    if ( storageType == "BFloat16Storage" )
    {
        var widened = new byte[bytes.Length * 2];
        for ( int src = 0, dst = 0; src + 1 < bytes.Length; src += 2, dst += 4 )
        {
            // A bf16 value is exactly the upper half of the matching f32, so it
            // lands in bytes 2..3 of the little-endian float; the low mantissa stays zero.
            widened[dst + 2] = bytes[src];
            widened[dst + 3] = bytes[src + 1];
        }
        return widened;
    }

    if ( storageType == "HalfStorage" )
    {
        var widened = new byte[bytes.Length * 2];
        for ( var element = 0; element < bytes.Length / 2; element++ )
        {
            var single = (float)BitConverter.ToHalf( bytes, element * 2 );
            BitConverter.TryWriteBytes( widened.AsSpan( element * 4, 4 ), single );
        }
        return widened;
    }

    return bytes; // FloatStorage — already f32
}
/// <summary>
/// Legacy stream: four consecutive pickles (magic, protocol, sys_info, object),
/// then a pickle listing storage keys, then per key an i64 element count followed
/// by the raw storage bytes. The element width comes from the storage type seen
/// while unpickling the object; only float32/bf16/f16 storages are kept.
/// </summary>
/// <exception cref="FormatException">Missing/invalid key list, or a storage size
/// that runs past the end of the data.</exception>
static IReadOnlyDictionary<string, Tensor> ParseLegacy( byte[] data )
{
    var offset = 0;
    object result = null;
    Unpickler objectUnpickler = null;
    // Each pass decodes one pickle and advances offset; after the loop, result
    // and objectUnpickler belong to the fourth (object) pickle.
    for ( var p = 0; p < 4; p++ )
    {
        objectUnpickler = new Unpickler( data, offset );
        result = objectUnpickler.Run( out offset );
    }
    var keysObject = new Unpickler( data, offset ).Run( out offset );
    if ( keysObject is not List<object> keyList )
        throw new FormatException( "torch checkpoint: legacy storage-key list missing." );
    var storages = new Dictionary<string, byte[]>( StringComparer.Ordinal );
    foreach ( var keyObject in keyList )
    {
        if ( keyObject is not string key )
            throw new FormatException( "torch checkpoint: legacy storage key is not a string." );
        // Remaining-length form: offset + 8 could overflow near the 2GB array limit.
        if ( data.Length - offset < 8 )
            throw new FormatException( "torch checkpoint: legacy stream truncated at storage sizes." );
        var numel = BitConverter.ToInt64( data, offset );
        offset += 8;
        // Element width comes from the storage TYPE recorded at its persistent id
        // (optimizer state mixes Long/Int storages between the float tensors); a
        // key the object pickle never referenced is assumed to be 4 bytes wide.
        var width = objectUnpickler.Storages.TryGetValue( key, out var info ) ? info.Width : 4;
        // Bound numel BEFORE multiplying: a hostile count could otherwise overflow
        // numel * width into a small or negative byte count and pass the check.
        if ( numel < 0 || numel > (data.Length - offset) / width )
            throw new FormatException( $"torch checkpoint: legacy storage '{key}' overruns the file." );
        var byteCount = (int)(numel * width);
        if ( info.Type is "FloatStorage" or "BFloat16Storage" or "HalfStorage" )
            storages[key] = ToFloat32Bytes( data.AsSpan( offset, byteCount ).ToArray(), info.Type );
        offset += byteCount;
    }
    return Materialize( result, key =>
        storages.TryGetValue( key, out var bytes )
            ? bytes
            : throw new FormatException( $"torch checkpoint: legacy storage '{key}' missing." ) );
}
/// <summary>
/// Converts the unpickled graph's tensor stubs into tensors. Nested dicts are
/// walked recursively with dotted-name prefixes; non-string keys and values that
/// are neither tensor stubs nor dicts (opaque data, lists, scalars) are skipped.
/// </summary>
/// <param name="root">Top-level unpickled value; must be a dict.</param>
/// <param name="storageBytes">Returns the float32 bytes of a storage key.</param>
/// <param name="keep">Optional filter on the full dotted name, applied before
/// <paramref name="storageBytes"/> is called.</param>
/// <param name="prefix">Dotted name prefix for entries at this nesting level.</param>
/// <exception cref="FormatException">The root is not a dict, or a tensor has a
/// negative size/offset or does not fit inside its storage.</exception>
static IReadOnlyDictionary<string, Tensor> Materialize(
    object root, Func<string, byte[]> storageBytes,
    Func<string, bool> keep = null, string prefix = "" )
{
    if ( root is not Dictionary<object, object> dict )
        throw new FormatException( "torch checkpoint: top-level pickle value is not a dict." );
    var tensors = new Dictionary<string, Tensor>( StringComparer.Ordinal );
    foreach ( var (key, value) in dict )
    {
        if ( key is not string name )
            continue;
        var full = prefix + name;
        if ( value is TensorStub stub )
        {
            if ( keep is not null && !keep( full ) )
                continue;
            var bytes = storageBytes( stub.StorageKey );
            var count = CheckedElementCount( stub, bytes.Length, full );
            // Both products fit in int: the check bounds (offset + count) * 4 by bytes.Length.
            var values = new float[count];
            Buffer.BlockCopy( bytes, stub.StorageOffset * 4, values, 0, (int)count * 4 );
            tensors[full] = stub.Shape.Length == 0
                ? Tensor.From( values, 1 )
                : Tensor.From( values, stub.Shape );
        }
        else if ( value is Dictionary<object, object> nested )
        {
            // Nested dicts ("state_dict"/"model" wrappers in training checkpoints).
            foreach ( var (innerKey, innerValue) in Materialize( nested, storageBytes, keep, $"{full}." ) )
                tensors[innerKey] = innerValue;
        }
    }
    return tensors;
}

/// <summary>
/// Element count of <paramref name="stub"/>, validated against a float32 storage
/// of <paramref name="storageLength"/> bytes. The product is bounded per dimension
/// so a hostile shape cannot overflow it past the check.
/// </summary>
static long CheckedElementCount( TensorStub stub, int storageLength, string name )
{
    if ( stub.StorageOffset < 0 || Array.Exists( stub.Shape, d => d < 0 ) )
        throw new FormatException( $"torch checkpoint: tensor '{name}' has a negative size or storage offset." );
    // Whole float32 elements left in the storage after the tensor's offset.
    long available = storageLength / 4 - (long)stub.StorageOffset;
    // An empty tensor reads nothing, however large its other dimensions are.
    long count = Array.IndexOf( stub.Shape, 0 ) >= 0 ? 0 : 1;
    if ( available >= 0 )
    {
        foreach ( var d in stub.Shape )
        {
            // count <= available (< 2^29) before each multiply, so this cannot overflow.
            count *= d;
            if ( count > available )
                break;
        }
    }
    if ( available < 0 || count > available )
        throw new FormatException(
            $"torch checkpoint: storage '{stub.StorageKey}' too small for shape "
            + $"[{string.Join( ",", stub.Shape )}] at offset {stub.StorageOffset}." );
    return count;
}
// ================================================================== zip
/// <summary>
/// Reads a plain zip archive into a map of entry name → uncompressed bytes
/// (stored and deflate entries). Public so model bundles (e.g. the RigNet
/// checkpoints zip) can be read without extraction.
/// </summary>
/// <exception cref="FormatException">The data is not a readable zip archive.</exception>
public static IReadOnlyDictionary<string, byte[]> ReadArchive( byte[] data )
{
    return ReadZip( data );
}
// ---- streaming zip index (ZIP64-aware) — for checkpoints past 2GB ----
/// <summary>Central-directory facts needed to fetch one entry later: the absolute
/// offset of its local header, its stored (compressed) byte count, and its zip
/// compression method (ReadZipEntry accepts 0 = stored and 8 = deflate).</summary>
readonly record struct ZipIndexEntry( long LocalHeaderOffset, long CompressedSize, ushort Method );
/// <summary>
/// Builds a name → entry index from a seekable zip stream's central directory.
/// Follows the ZIP64 end-of-central-directory record and the per-entry ZIP64
/// extra field whenever the 16/32-bit fields overflowed. No entry payload is read.
/// </summary>
/// <exception cref="FormatException">Missing or corrupt end record, central
/// directory, or ZIP64 data (including offsets that point outside the file).</exception>
static Dictionary<string, ZipIndexEntry> ReadZipIndex( Stream file )
{
    // EOCD scan over the file tail (22-byte record + a comment of up to 64KB).
    var tailLength = (int)Math.Min( file.Length, 66_000 );
    var tail = ReadAt( file, file.Length - tailLength, tailLength );
    var eocd = -1;
    for ( var i = tailLength - 22; i >= 0; i-- )
    {
        if ( tail[i] == 0x50 && tail[i + 1] == 0x4b && tail[i + 2] == 0x05 && tail[i + 3] == 0x06 )
        {
            eocd = i;
            break;
        }
    }
    if ( eocd < 0 )
        throw new FormatException( "torch checkpoint: not a zip archive (no end-of-central-directory)." );
    long count = BitConverter.ToUInt16( tail, eocd + 10 );
    long cdOffset = BitConverter.ToUInt32( tail, eocd + 16 );
    if ( count == 0xFFFF || cdOffset == 0xFFFFFFFF )
    {
        // ZIP64: locator sits 20 bytes before the EOCD and points at the 56-byte EOCD64 record.
        var locatorAt = eocd - 20;
        if ( locatorAt < 0 || BitConverter.ToUInt32( tail, locatorAt ) != 0x07064b50 )
            throw new FormatException( "torch checkpoint: ZIP64 locator missing." );
        var eocd64Offset = BitConverter.ToInt64( tail, locatorAt + 8 );
        if ( eocd64Offset < 0 || eocd64Offset > file.Length - 56 )
            throw new FormatException( "torch checkpoint: corrupt ZIP64 end-of-central-directory." );
        var eocd64 = ReadAt( file, eocd64Offset, 56 );
        if ( BitConverter.ToUInt32( eocd64, 0 ) != 0x06064b50 )
            throw new FormatException( "torch checkpoint: corrupt ZIP64 end-of-central-directory." );
        count = BitConverter.ToInt64( eocd64, 32 );
        cdOffset = BitConverter.ToInt64( eocd64, 48 );
    }
    // A directory offset outside the file would otherwise become a negative read length.
    if ( count < 0 || cdOffset < 0 || cdOffset > file.Length )
        throw new FormatException( "torch checkpoint: corrupt central directory." );
    // Central directory can itself be large — read it in one buffer (it is
    // tiny relative to the payload: ~100 bytes per entry).
    var cdLength = file.Length - cdOffset;
    if ( cdLength > int.MaxValue )
        throw new FormatException( "torch checkpoint: central directory too large." );
    var cd = ReadAt( file, cdOffset, (int)cdLength );
    var entries = new Dictionary<string, ZipIndexEntry>( StringComparer.Ordinal );
    var at = 0;
    for ( long e = 0; e < count; e++ )
    {
        if ( cd.Length - at < 46 || BitConverter.ToUInt32( cd, at ) != 0x02014b50 )
            throw new FormatException( "torch checkpoint: corrupt central directory." );
        var method = BitConverter.ToUInt16( cd, at + 10 );
        long compressedSize = BitConverter.ToUInt32( cd, at + 20 );
        long uncompressedSize = BitConverter.ToUInt32( cd, at + 24 );
        var nameLength = BitConverter.ToUInt16( cd, at + 28 );
        var extraLength = BitConverter.ToUInt16( cd, at + 30 );
        var commentLength = BitConverter.ToUInt16( cd, at + 32 );
        long localOffset = BitConverter.ToUInt32( cd, at + 42 );
        var recordLength = 46 + nameLength + extraLength + commentLength;
        // The variable-length name/extra/comment must also fit in the directory.
        if ( cd.Length - at < recordLength )
            throw new FormatException( "torch checkpoint: corrupt central directory." );
        var name = Encoding.UTF8.GetString( cd, at + 46, nameLength );
        // ZIP64 extra field 0x0001: u64 values, in order, only for the
        // fixed fields that overflowed to 0xFFFFFFFF. The uncompressed size is
        // read only to advance past it.
        var extraAt = at + 46 + nameLength;
        var extraEnd = extraAt + extraLength;
        while ( extraAt + 4 <= extraEnd )
        {
            var id = BitConverter.ToUInt16( cd, extraAt );
            var size = BitConverter.ToUInt16( cd, extraAt + 2 );
            if ( id == 0x0001 )
            {
                var v = extraAt + 4;
                var fieldEnd = Math.Min( v + size, extraEnd );
                if ( uncompressedSize == 0xFFFFFFFF )
                    uncompressedSize = ReadZip64Value( cd, ref v, fieldEnd );
                if ( compressedSize == 0xFFFFFFFF )
                    compressedSize = ReadZip64Value( cd, ref v, fieldEnd );
                if ( localOffset == 0xFFFFFFFF )
                    localOffset = ReadZip64Value( cd, ref v, fieldEnd );
            }
            extraAt += 4 + size;
        }
        if ( compressedSize < 0 || localOffset < 0 )
            throw new FormatException( $"torch checkpoint: corrupt ZIP64 sizes for '{name}'." );
        entries[name] = new ZipIndexEntry( localOffset, compressedSize, method );
        at += recordLength;
    }
    return entries;

    // Reads the next u64 of a ZIP64 extra field without running past the field.
    static long ReadZip64Value( byte[] buffer, ref int cursor, int end )
    {
        if ( end - cursor < 8 )
            throw new FormatException( "torch checkpoint: truncated ZIP64 extra field." );
        var value = BitConverter.ToInt64( buffer, cursor );
        cursor += 8;
        return value;
    }
}
/// <summary>
/// Reads one indexed entry's payload. Its position is recomputed from the local
/// header, whose name/extra lengths can differ from the central directory's.
/// Deflated payloads are inflated.
/// </summary>
/// <exception cref="FormatException">Corrupt or out-of-range local header or
/// payload, an entry past 2GB, an unsupported compression method, or corrupt
/// deflate data.</exception>
static byte[] ReadZipEntry( Stream file, ZipIndexEntry entry )
{
    // Range-check before seeking, so bad offsets raise FormatException rather than EndOfStreamException.
    if ( entry.LocalHeaderOffset < 0 || entry.LocalHeaderOffset > file.Length - 30 )
        throw new FormatException( "torch checkpoint: corrupt local header." );
    var header = ReadAt( file, entry.LocalHeaderOffset, 30 );
    if ( BitConverter.ToUInt32( header, 0 ) != 0x04034b50 )
        throw new FormatException( "torch checkpoint: corrupt local header." );
    var nameLength = BitConverter.ToUInt16( header, 26 );
    var extraLength = BitConverter.ToUInt16( header, 28 );
    var dataOffset = entry.LocalHeaderOffset + 30 + nameLength + extraLength;
    if ( entry.CompressedSize > int.MaxValue )
        throw new FormatException( "torch checkpoint: single zip entry past 2GB is not supported." );
    if ( entry.CompressedSize < 0 || dataOffset + entry.CompressedSize > file.Length )
        throw new FormatException( "torch checkpoint: zip entry overruns the file." );
    // Reject unknown methods before paying for the payload read.
    if ( entry.Method != 0 && entry.Method != 8 )
        throw new FormatException( $"torch checkpoint: unsupported compression method {entry.Method}." );
    var raw = ReadAt( file, dataOffset, (int)entry.CompressedSize );
    if ( entry.Method == 0 )
        return raw;
    try
    {
        using var inflater = new DeflateStream( new MemoryStream( raw ), CompressionMode.Decompress );
        using var output = new MemoryStream();
        inflater.CopyTo( output );
        return output.ToArray();
    }
    catch ( InvalidDataException ex )
    {
        throw new FormatException( "torch checkpoint: corrupt deflate data.", ex );
    }
}
/// <summary>Reads exactly <paramref name="count"/> bytes starting at the absolute
/// <paramref name="offset"/>, leaving the stream positioned just past them.</summary>
static byte[] ReadAt( Stream file, long offset, int count )
{
    file.Seek( offset, SeekOrigin.Begin );
    var chunk = new byte[count];
    // Fails with EndOfStreamException if the stream ends early.
    file.ReadExactly( chunk, 0, count );
    return chunk;
}
/// <summary>
/// Central-directory zip walk over an in-memory archive (stored + deflate
/// entries; ZIP64 records are not consulted on this path).
/// </summary>
/// <returns>Entry name → uncompressed bytes.</returns>
/// <exception cref="FormatException">No end record; a corrupt or out-of-range
/// directory record or local header; an entry that overruns the data; truncated
/// or corrupt deflate data; or an unsupported compression method.</exception>
static Dictionary<string, byte[]> ReadZip( byte[] data )
{
    // EOCD scan from the end (22-byte record plus an optional trailing comment).
    var eocd = -1;
    for ( var i = data.Length - 22; i >= 0; i-- )
    {
        if ( data[i] == 0x50 && data[i + 1] == 0x4b && data[i + 2] == 0x05 && data[i + 3] == 0x06 )
        {
            eocd = i;
            break;
        }
    }
    if ( eocd < 0 )
        throw new FormatException( "torch checkpoint: not a zip archive (no end-of-central-directory)." );
    var count = BitConverter.ToUInt16( data, eocd + 10 );
    var cdOffset = BitConverter.ToUInt32( data, eocd + 16 );
    // Offsets are u32 but the buffer is int-indexed: validate in wide arithmetic
    // before any (int) cast, or a large offset wraps negative and slips past the
    // bounds checks as an ArgumentOutOfRangeException.
    if ( cdOffset > data.Length )
        throw new FormatException( "torch checkpoint: corrupt central directory." );
    var entries = new Dictionary<string, byte[]>( StringComparer.Ordinal );
    var at = (int)cdOffset;
    for ( var e = 0; e < count; e++ )
    {
        if ( data.Length - at < 46 || BitConverter.ToUInt32( data, at ) != 0x02014b50 )
            throw new FormatException( "torch checkpoint: corrupt central directory." );
        var method = BitConverter.ToUInt16( data, at + 10 );
        var compressedSize = BitConverter.ToUInt32( data, at + 20 );
        var uncompressedSize = BitConverter.ToUInt32( data, at + 24 );
        var nameLength = BitConverter.ToUInt16( data, at + 28 );
        var extraLength = BitConverter.ToUInt16( data, at + 30 );
        var commentLength = BitConverter.ToUInt16( data, at + 32 );
        var localOffset = BitConverter.ToUInt32( data, at + 42 );
        var recordLength = 46 + nameLength + extraLength + commentLength;
        if ( data.Length - at < recordLength )
            throw new FormatException( "torch checkpoint: corrupt central directory." );
        var name = Encoding.UTF8.GetString( data, at + 46, nameLength );
        at += recordLength;
        // Local header carries its own name/extra lengths.
        if ( localOffset + 30L > data.Length || BitConverter.ToUInt32( data, (int)localOffset ) != 0x04034b50 )
            throw new FormatException( $"torch checkpoint: corrupt local header for '{name}'." );
        var lh = (int)localOffset;
        var localNameLength = BitConverter.ToUInt16( data, lh + 26 );
        var localExtraLength = BitConverter.ToUInt16( data, lh + 28 );
        var dataStart = (long)lh + 30 + localNameLength + localExtraLength;
        if ( dataStart + compressedSize > data.Length )
            throw new FormatException( $"torch checkpoint: entry '{name}' overruns the file." );
        var start = (int)dataStart;
        byte[] content;
        if ( method == 0 )
        {
            content = new byte[compressedSize];
            Array.Copy( data, start, content, 0, (int)compressedSize );
        }
        else if ( method == 8 )
        {
            content = Inflate( data, start, (int)compressedSize, uncompressedSize, name );
        }
        else
        {
            throw new FormatException( $"torch checkpoint: unsupported zip method {method} for '{name}'." );
        }
        entries[name] = content;
    }
    return entries;

    // Inflates exactly `expected` bytes of deflate data from source[start..start+length).
    static byte[] Inflate( byte[] source, int start, int length, uint expected, string name )
    {
        if ( expected > int.MaxValue )
            throw new FormatException( $"torch checkpoint: entry '{name}' is too large to inflate in memory." );
        var content = new byte[expected];
        try
        {
            using var ms = new MemoryStream( source, start, length );
            using var deflate = new DeflateStream( ms, CompressionMode.Decompress );
            var read = 0;
            while ( read < content.Length )
            {
                var n = deflate.Read( content, read, content.Length - read );
                if ( n <= 0 )
                    throw new FormatException( $"torch checkpoint: truncated deflate data in '{name}'." );
                read += n;
            }
        }
        catch ( InvalidDataException ex )
        {
            throw new FormatException( $"torch checkpoint: corrupt deflate data in '{name}'.", ex );
        }
        return content;
    }
}
// ================================================================== pickle
/// <summary>A float-typed storage referenced by a pickle persistent id: its
/// archive key and the element count the pickle declared for it.</summary>
sealed class StorageRef
{
    public required string Key { get; init; }
    public required long Numel { get; init; }
}
/// <summary>A whitelisted pickle global (module + name): a marker the unpickler
/// dispatches on. Nothing is imported or executed.</summary>
sealed class GlobalRef
{
    public required string Module { get; init; }
    public required string Name { get; init; }
}
/// <summary>
/// An unknown global — or anything built from one — kept as INERT DATA; nothing
/// is ever executed. <see cref="What"/> is a human-readable description.
/// </summary>
sealed class OpaqueRef
{
    public required string What { get; init; }
}
/// <summary>
/// A tensor recorded during unpickling and materialized once storages are read:
/// which storage it views, its element offset into that storage, and its shape
/// (already verified to be contiguous row-major).
/// </summary>
sealed class TensorStub
{
    public required string StorageKey { get; init; }
    public required int StorageOffset { get; init; }
    public required int[] Shape { get; init; }
}
sealed class Unpickler
{
// Bytes being decoded. Legacy checkpoints hold several pickles back to back,
// so decoding can start at any offset.
readonly byte[] _data;
// Pickle value stack; MarkSentinel entries delimit MARK groups.
readonly List<object> _stack = new();
// Memo table by slot index: written by BINPUT/LONG_BINPUT/MEMOIZE, read by BINGET/LONG_BINGET.
readonly Dictionary<int, object> _memo = new();
// Read cursor into _data.
int _pos;
/// <summary>Every storage seen via persistent ids: key → (type name, element width).</summary>
public Dictionary<string, (string Type, int Width)> Storages { get; } = new( StringComparer.Ordinal );
// Unique object pushed by MARK; PopToMark finds it by reference identity.
static readonly object MarkSentinel = new();
/// <summary>Prepares to decode one pickle from <paramref name="pickle"/>,
/// starting at byte <paramref name="startOffset"/>.</summary>
public Unpickler( byte[] pickle, int startOffset )
{
_data = pickle;
_pos = startOffset;
}
/// <summary>Consumes and returns the next byte of the pickle stream.</summary>
byte Next()
{
    if ( _pos < _data.Length )
        return _data[_pos++];
    throw new FormatException( "torch checkpoint: pickle stream truncated." );
}
/// <summary>Consumes the next <paramref name="count"/> bytes into a fresh array.</summary>
/// <exception cref="FormatException">Fewer than <paramref name="count"/> bytes
/// remain, or <paramref name="count"/> is negative.</exception>
byte[] NextBytes( int count )
{
    // Compare against the remaining length instead of computing _pos + count.
    // A huge length prefix would overflow that sum and pass the check, and a
    // negative one (BINUNICODE/BINSTRING lengths are signed) would reach new byte[-n].
    if ( count < 0 || count > _data.Length - _pos )
        throw new FormatException( "torch checkpoint: pickle stream truncated." );
    var span = new byte[count];
    Array.Copy( _data, _pos, span, 0, count );
    _pos += count;
    return span;
}
/// <summary>Consumes one '\n'-terminated ASCII line (the GLOBAL opcode's
/// operands) and returns it without the terminator.</summary>
string ReadLine()
{
    var lineStart = _pos;
    var newline = Array.IndexOf( _data, (byte)'\n', lineStart );
    if ( newline < 0 )
    {
        _pos = _data.Length;
        throw new FormatException( "torch checkpoint: unterminated pickle text line." );
    }
    _pos = newline + 1;
    return Encoding.ASCII.GetString( _data, lineStart, newline - lineStart );
}
/// <summary>Pushes a value onto the pickle stack.</summary>
void Push( object value )
{
    _stack.Add( value );
}
/// <summary>Removes and returns the top of the pickle stack.</summary>
object Pop()
{
    var top = _stack.Count - 1;
    if ( top < 0 )
        throw new FormatException( "torch checkpoint: pickle stack underflow." );
    var value = _stack[top];
    _stack.RemoveAt( top );
    return value;
}
/// <summary>Removes everything above the most recent MARK, plus the MARK itself,
/// and returns those items in push order (bottom to top).</summary>
List<object> PopToMark()
{
    var markAt = _stack.Count - 1;
    while ( markAt >= 0 && !ReferenceEquals( _stack[markAt], MarkSentinel ) )
        markAt--;
    if ( markAt < 0 )
    {
        // No MARK: the stack is drained, exactly as popping until underflow would.
        _stack.Clear();
        throw new FormatException( "torch checkpoint: pickle stack underflow." );
    }
    var items = _stack.GetRange( markAt + 1, _stack.Count - markAt - 1 );
    _stack.RemoveRange( markAt, _stack.Count - markAt );
    return items;
}
/// <summary>
/// Decodes one pickle through its STOP opcode and returns the value STOP pops.
/// <paramref name="endOffset"/> receives the position just past STOP, where the
/// next pickle of a legacy stream begins.
/// </summary>
public object Run( out int endOffset )
{
    var top = RunInner();
    endOffset = _pos;
    return top;
}
/// <summary>
/// The opcode loop: decodes until STOP and returns the value on top of the
/// stack. Only the opcodes a tensor state dict needs are implemented; any other
/// opcode throws FormatException.
/// </summary>
object RunInner()
{
    while ( true )
    {
        var op = Next();
        switch ( op )
        {
            case 0x80: Next(); break; // PROTO n
            case 0x95: NextBytes( 8 ); break; // FRAME len (the whole pickle is already in memory)
            case (byte)'}': Push( new Dictionary<object, object>() ); break; // EMPTY_DICT
            case (byte)']': Push( new List<object>() ); break; // EMPTY_LIST
            case (byte)'(': Push( MarkSentinel ); break; // MARK
            case (byte)'N': Push( null ); break; // NONE
            case 0x88: Push( true ); break; // NEWTRUE
            case 0x89: Push( false ); break; // NEWFALSE
            case (byte)'X': // BINUNICODE
            {
                var length = BitConverter.ToInt32( NextBytes( 4 ), 0 );
                Push( Encoding.UTF8.GetString( NextBytes( length ) ) );
                break;
            }
            case 0x8c: // SHORT_BINUNICODE
            {
                int length = Next();
                Push( Encoding.UTF8.GetString( NextBytes( length ) ) );
                break;
            }
            case (byte)'U': // SHORT_BINSTRING (legacy sys_info)
            {
                int length = Next();
                Push( Encoding.Latin1.GetString( NextBytes( length ) ) );
                break;
            }
            case (byte)'T': // BINSTRING
            {
                var length = BitConverter.ToInt32( NextBytes( 4 ), 0 );
                Push( Encoding.Latin1.GetString( NextBytes( length ) ) );
                break;
            }
            case (byte)'J': Push( BitConverter.ToInt32( NextBytes( 4 ), 0 ) ); break; // BININT
            case (byte)'K': Push( (int)Next() ); break; // BININT1
            case (byte)'M': Push( (int)BitConverter.ToUInt16( NextBytes( 2 ), 0 ) ); break; // BININT2
            case 0x8a: // LONG1
            {
                int length = Next();
                Push( DecodeLong1( NextBytes( length ) ) );
                break;
            }
            case (byte)'q': _memo[Next()] = Peek(); break; // BINPUT
            case (byte)'r': _memo[BitConverter.ToInt32( NextBytes( 4 ), 0 )] = Peek(); break; // LONG_BINPUT
            case 0x94: _memo[_memo.Count] = Peek(); break; // MEMOIZE
            case (byte)'h': Push( Memo( Next() ) ); break; // BINGET
            case (byte)'j': Push( Memo( BitConverter.ToInt32( NextBytes( 4 ), 0 ) ) ); break; // LONG_BINGET
            case (byte)'t': Push( PopToMark().ToArray() ); break; // TUPLE
            case (byte)')': Push( Array.Empty<object>() ); break; // EMPTY_TUPLE
            case 0x85: Push( new[] { Pop() } ); break; // TUPLE1
            case 0x86: // TUPLE2
            {
                var b = Pop(); var a = Pop();
                Push( new[] { a, b } );
                break;
            }
            case 0x87: // TUPLE3
            {
                var c = Pop(); var b = Pop(); var a = Pop();
                Push( new[] { a, b, c } );
                break;
            }
            case (byte)'c': // GLOBAL
            {
                var module = ReadLine();
                var name = ReadLine();
                Push( ResolveGlobal( module, name ) );
                break;
            }
            case 0x93: // STACK_GLOBAL
            {
                var name = Pop() as string ?? throw Bad( "STACK_GLOBAL name" );
                var module = Pop() as string ?? throw Bad( "STACK_GLOBAL module" );
                Push( ResolveGlobal( module, name ) );
                break;
            }
            case (byte)'Q': // BINPERSID
            {
                Push( ResolvePersistentId( Pop() ) );
                break;
            }
            case (byte)'R': // REDUCE
            {
                var args = Pop() as object[] ?? throw Bad( "REDUCE args" );
                var callable = Pop();
                Push( Invoke( callable, args ) );
                break;
            }
            case (byte)'s': // SETITEM
            {
                var value = Pop();
                var key = Pop();
                (Peek() as Dictionary<object, object> ?? throw Bad( "SETITEM target" ))[key] = value;
                break;
            }
            case (byte)'u': // SETITEMS
            {
                var items = PopToMark();
                var dict = Peek() as Dictionary<object, object> ?? throw Bad( "SETITEMS target" );
                for ( var i = 0; i + 1 < items.Count; i += 2 )
                    dict[items[i]] = items[i + 1];
                break;
            }
            case (byte)'a': // APPEND (single)
            {
                var item = Pop();
                (Peek() as List<object> ?? throw Bad( "APPEND target" )).Add( item );
                break;
            }
            case (byte)'0': Pop(); break; // POP
            case (byte)'e': // APPENDS
            {
                var items = PopToMark();
                var list = Peek() as List<object> ?? throw Bad( "APPENDS target" );
                list.AddRange( items );
                break;
            }
            case (byte)'G': // BINFLOAT (big-endian f64)
            {
                var raw = NextBytes( 8 );
                var swapped = new byte[8];
                for ( var i = 0; i < 8; i++ )
                    swapped[i] = raw[7 - i];
                Push( BitConverter.ToDouble( swapped, 0 ) );
                break;
            }
            case 0x81: // NEWOBJ: cls(*args) - inert
            {
                Pop(); // args
                var cls = Pop();
                Push( cls is GlobalRef g
                    ? Invoke( g, Array.Empty<object>() )
                    : new OpaqueRef { What = "NEWOBJ instance" } );
                break;
            }
            case (byte)'b': // BUILD: apply state - inert merge
            {
                var state = Pop();
                if ( Peek() is Dictionary<object, object> targetDict
                    && state is Dictionary<object, object> stateDict )
                {
                    foreach ( var (k, v) in stateDict )
                        targetDict[k] = v;
                }
                // Opaque targets: state discarded, nothing executed.
                break;
            }
            case (byte)'.': // STOP
                return Pop();
            default:
                throw new FormatException(
                    $"torch checkpoint: unsupported pickle opcode 0x{op:X2} at {_pos - 1} "
                    + "(the restricted reader accepts tensor state dicts only)." );
        }
    }
}

/// <summary>
/// Decodes a LONG1 payload: a little-endian two's-complement integer of 0-255
/// bytes. Pickle emits LONG1 precisely for values BININT cannot hold, so they
/// must not be truncated. Values that fit are returned as int; up to 8 bytes as
/// long (ResolvePersistentId accepts long storage sizes); anything wider becomes
/// inert opaque data, so it can never masquerade as a wrong small number.
/// </summary>
static object DecodeLong1( byte[] bytes )
{
    if ( bytes.Length == 0 )
        return 0;
    if ( bytes.Length > 8 )
        return new OpaqueRef { What = $"{bytes.Length}-byte integer" };
    // Start from the sign-extended most significant byte, then shift in the rest.
    long value = (sbyte)bytes[^1];
    for ( var i = bytes.Length - 2; i >= 0; i-- )
        value = (value << 8) | bytes[i];
    if ( value is >= int.MinValue and <= int.MaxValue )
        return (int)value;
    return value;
}
/// <summary>Returns the top of the pickle stack without removing it.</summary>
object Peek()
{
    if ( _stack.Count == 0 )
        throw Bad( "empty stack" );
    return _stack[^1];
}
/// <summary>Reads a memo slot written earlier by BINPUT/LONG_BINPUT/MEMOIZE.</summary>
object Memo( int slot )
{
    if ( _memo.TryGetValue( slot, out var stored ) )
        return stored;
    throw Bad( $"memo slot {slot}" );
}
/// <summary>Builds the FormatException used for structurally malformed pickles.</summary>
static FormatException Bad( string what )
{
    return new FormatException( $"torch checkpoint: malformed pickle ({what})." );
}
/// <summary>
/// Maps a pickle global to a whitelisted marker (GlobalRef) or to inert data
/// (OpaqueRef). No type is ever looked up or loaded.
/// </summary>
static object ResolveGlobal( string module, string name )
{
    var whitelisted = (module, name) is ("torch._utils", "_rebuild_tensor_v2")
        or ("torch", "FloatStorage")
        or ("collections", "OrderedDict");
    // Anything else becomes INERT DATA: nothing is looked up or executed, and
    // values built from it are dropped at materialization. Training
    // checkpoints carry optimizer state and framework metadata we must
    // tolerate without running.
    if ( !whitelisted )
        return new OpaqueRef { What = $"{module}.{name}" };
    return new GlobalRef { Module = module, Name = name };
}
/// <summary>
/// Bytes per element for a torch storage class name. The legacy stream uses it
/// to step over storages, so every class that can appear there needs its true
/// width; a wrong width desynchronizes every later storage. Unlisted names fall
/// back to 1 byte.
/// </summary>
static int StorageWidth( string typeName ) => typeName switch
{
    "ComplexDoubleStorage" => 16,
    "DoubleStorage" or "LongStorage" or "ComplexFloatStorage" => 8,
    "FloatStorage" or "IntStorage" or "QInt32Storage" => 4,
    "HalfStorage" or "ShortStorage" or "BFloat16Storage" => 2,
    _ => 1, // Byte/Char/Bool and the 8-bit quantized storages
};
/// <summary>
/// Handles BINPERSID for a storage tuple ("storage", type, key, location, numel, ...).
/// The storage's type name and element width are recorded in Storages. The
/// result is a StorageRef for float32/bf16/f16 storages and inert opaque data
/// for every other type.
/// </summary>
object ResolvePersistentId( object pid )
{
    if ( pid is not object[] { Length: >= 5 } tuple
        || tuple[0] is not "storage"
        || tuple[2] is not string key )
        throw Bad( "persistent id" );

    string typeName;
    switch ( tuple[1] )
    {
        case GlobalRef global:
            typeName = global.Name;
            break;
        case OpaqueRef opaque:
            // Non-whitelisted storage classes arrive as "module.Name"; keep just "Name".
            var dot = opaque.What.LastIndexOf( '.' );
            typeName = dot < 0 ? opaque.What : opaque.What[(dot + 1)..];
            break;
        default:
            throw Bad( "storage type" );
    }

    long numel;
    if ( tuple[4] is int narrow )
        numel = narrow;
    else if ( tuple[4] is long wide )
        numel = wide;
    else
        throw Bad( "storage numel" );

    Storages[key] = (typeName, StorageWidth( typeName ));
    var isFloatStorage = typeName is "FloatStorage" or "BFloat16Storage" or "HalfStorage";
    if ( !isFloatStorage )
        return new OpaqueRef { What = $"non-float storage {typeName}" };
    return new StorageRef { Key = key, Numel = numel };
}
/// <summary>
/// Evaluates a REDUCE/NEWOBJ against the whitelist WITHOUT running any code:
/// OrderedDict yields an empty dict, _rebuild_tensor_v2 yields a TensorStub, and
/// opaque callables yield themselves (inert).
/// </summary>
/// <exception cref="FormatException">Unknown callable, malformed tensor arguments,
/// or a tensor layout that is not contiguous row-major.</exception>
object Invoke( object callable, object[] args )
{
    if ( callable is OpaqueRef opaque )
        return opaque; // inert: nothing executed, value dropped later
    if ( callable is not GlobalRef global )
        throw Bad( "REDUCE callable" );
    if ( global is { Module: "collections", Name: "OrderedDict" } )
        return new Dictionary<object, object>();
    if ( global is { Module: "torch._utils", Name: "_rebuild_tensor_v2" } )
    {
        // args = (storage, storage_offset, size, stride, ...)
        if ( args.Length < 4
            || args[1] is not int storageOffset
            || args[2] is not object[] sizeTuple
            || args[3] is not object[] strideTuple )
            throw Bad( "_rebuild_tensor_v2 args" );
        if ( args[0] is OpaqueRef )
            return new OpaqueRef { What = "tensor on unsupported storage" };
        if ( args[0] is not StorageRef storage )
            throw Bad( "_rebuild_tensor_v2 storage" );
        var shape = sizeTuple
            .Select( s => s is int v ? v : throw Bad( "size" ) ).ToArray();
        var stride = strideTuple
            .Select( s => s is int v ? v : throw Bad( "stride" ) ).ToArray();
        // A shorter stride tuple would otherwise fail below with IndexOutOfRangeException.
        if ( stride.Length != shape.Length )
            throw Bad( "stride rank" );
        if ( storageOffset < 0 || Array.Exists( shape, d => d < 0 ) )
            throw Bad( "negative size or storage offset" );
        // Contiguous row-major only, judged the way PyTorch's is_contiguous does.
        // Size-1 dimensions may carry any stride (their index is always 0), and an
        // empty tensor touches no elements. torch.save stores strides verbatim, so
        // both cases occur in real checkpoints.
        if ( Array.IndexOf( shape, 0 ) < 0 )
        {
            long expected = 1;
            for ( var d = shape.Length - 1; d >= 0; d-- )
            {
                if ( shape[d] == 1 )
                    continue;
                if ( stride[d] != expected )
                    throw new FormatException(
                        "torch checkpoint: non-contiguous tensor storage is not supported." );
                expected *= shape[d];
            }
        }
        return new TensorStub
        {
            StorageKey = storage.Key,
            StorageOffset = storageOffset,
            Shape = shape,
        };
    }
    throw Bad( $"REDUCE of {global.Module}.{global.Name}" );
}
}
}