Code/AutoRig/Dl/TorchCheckpoint.cs

Parser for PyTorch checkpoint files. It reads both zip-format and legacy pickle-stream checkpoints into named float32 Tensor objects, using a restricted unpickler that accepts only the opcodes a tensor state dict needs and resolves only a small set of safe globals.

File Access · Obfuscated Code
using System.IO.Compression;
using System.Text;

namespace AutoRig.Dl;

/// <summary>
/// Reads PyTorch checkpoints (both the zip format and the pre-1.6 LEGACY stream
/// format) into named float32 tensors. The embedded pickles run on a RESTRICTED
/// stack machine: only the opcodes a tensor state dict uses are accepted, and only
/// three globals are resolvable (torch._utils._rebuild_tensor_v2, torch.FloatStorage,
/// collections.OrderedDict). Unsupported opcodes throw FormatException; any other
/// global becomes an inert placeholder that is never invoked and is dropped when
/// tensors are materialized - arbitrary pickle code can never run.
/// </summary>
public static class TorchCheckpoint
{
    /// <summary>Parses an in-memory checkpoint, picking zip vs. legacy format from its first bytes.</summary>
    /// <exception cref="FormatException">Unsupported container, pickle content, or
    /// inconsistent tensor data.</exception>
    public static IReadOnlyDictionary<string, Tensor> Parse( byte[] data )
    {
        ArgumentNullException.ThrowIfNull( data );
        if ( data.Length > 2 && data[0] == 0x50 && data[1] == 0x4b )   // "PK": zip archive
            return ParseZip( data );
        if ( data.Length > 2 && data[0] == 0x80 )                      // pickle PROTO opcode
            return ParseLegacy( data );
        throw new FormatException( "torch checkpoint: neither a zip archive nor a legacy pickle stream." );
    }

    /// <summary>Zip-format checkpoint already held fully in memory.</summary>
    static IReadOnlyDictionary<string, Tensor> ParseZip( byte[] data )
    {
        var entries = ReadZip( data );
        var (pickleEntry, prefix) = FindPickleEntry( entries.Keys );

        var unpickler = new Unpickler( entries[pickleEntry], 0 );
        var result = unpickler.Run( out _ );
        return Materialize( result, key =>
        {
            if ( !entries.TryGetValue( $"{prefix}data/{key}", out var bytes ) )
                throw new FormatException( $"torch checkpoint: storage '{key}' missing." );
            return ToFloat32Bytes( bytes, unpickler.StorageType( key ) );
        } );
    }

    /// <summary>
    /// Streams a zip-format checkpoint — required for files past the 2GB array
    /// limit (e.g. MagicArticulate's 4.4GB training checkpoint, which is also
    /// ZIP64). The stream must be seekable. <paramref name="keepTensor"/> filters
    /// by full dotted name BEFORE storage bytes are read, so skipping e.g.
    /// optimizer state costs no memory or IO.
    /// </summary>
    /// <exception cref="ArgumentException">The stream is not seekable.</exception>
    /// <exception cref="FormatException">Unsupported container, pickle content, or
    /// inconsistent tensor data.</exception>
    public static IReadOnlyDictionary<string, Tensor> Parse(
        Stream file, Func<string, bool> keepTensor = null )
    {
        ArgumentNullException.ThrowIfNull( file );
        if ( !file.CanSeek )
            throw new ArgumentException( "torch checkpoint: the stream must be seekable.", nameof( file ) );
        var entries = ReadZipIndex( file );
        var (pickleEntry, prefix) = FindPickleEntry( entries.Keys );

        var unpickler = new Unpickler( ReadZipEntry( file, entries[pickleEntry] ), 0 );
        var result = unpickler.Run( out _ );
        return Materialize( result, key =>
        {
            if ( !entries.TryGetValue( $"{prefix}data/{key}", out var entry ) )
                throw new FormatException( $"torch checkpoint: storage '{key}' missing." );
            return ToFloat32Bytes( ReadZipEntry( file, entry ), unpickler.StorageType( key ) );
        }, keepTensor );
    }

    /// <summary>Finds the archive's <c>data.pkl</c> and the folder prefix that its
    /// <c>data/&lt;key&gt;</c> storage entries share (empty when it sits at the root).</summary>
    static (string Entry, string Prefix) FindPickleEntry( IEnumerable<string> names )
    {
        var entry = names.FirstOrDefault( k => k == "data.pkl" || k.EndsWith( "/data.pkl", StringComparison.Ordinal ) )
            ?? throw new FormatException( "torch checkpoint: no data.pkl entry found." );
        return (entry, entry[..^"data.pkl".Length]);
    }

    /// <summary>Normalizes accepted storage payloads to raw float32 bytes (bf16/f16
    /// checkpoints — e.g. SkinTokens' bfloat16 model — widen at load). Any other
    /// storage type (FloatStorage) is returned unchanged. Byte order follows
    /// <see cref="BitConverter"/>, i.e. a little-endian host is assumed.</summary>
    /// <exception cref="FormatException">The widened storage would exceed the array size limit.</exception>
    internal static byte[] ToFloat32Bytes( byte[] bytes, string storageType )
    {
        switch ( storageType )
        {
            case "BFloat16Storage":
            {
                var result = new byte[WidenedLength( bytes, storageType )];
                for ( var i = 0; i < bytes.Length / 2; i++ )
                {
                    // bf16 = the high 16 bits of an f32 (little-endian layout).
                    result[i * 4 + 2] = bytes[i * 2];
                    result[i * 4 + 3] = bytes[i * 2 + 1];
                }
                return result;
            }
            case "HalfStorage":
            {
                var result = new byte[WidenedLength( bytes, storageType )];
                for ( var i = 0; i < bytes.Length / 2; i++ )
                {
                    var value = (float)BitConverter.ToHalf( bytes, i * 2 );
                    BitConverter.TryWriteBytes( result.AsSpan( i * 4, 4 ), value );
                }
                return result;
            }
            default:
                return bytes;   // FloatStorage — already f32
        }
    }

    /// <summary>Byte length after widening whole 16-bit elements to 32 bits (a stray
    /// trailing byte is dropped instead of producing a partial float).</summary>
    static int WidenedLength( byte[] bytes, string storageType )
    {
        var elements = bytes.Length / 2;
        if ( elements > int.MaxValue / 4 )
            throw new FormatException( $"torch checkpoint: {storageType} storage too large to widen to float32." );
        return elements * 4;
    }

    /// <summary>
    /// Legacy stream: four consecutive pickles (magic, protocol, sys_info, object),
    /// then a pickle listing storage keys, then per key an i64 element count followed
    /// by the raw storage bytes (element width taken from the storage type).
    /// </summary>
    static IReadOnlyDictionary<string, Tensor> ParseLegacy( byte[] data )
    {
        var offset = 0;
        object result = null;
        Unpickler objectUnpickler = null;
        // Only the fourth pickle (the object) is kept; the first three are read to skip them.
        for ( var p = 0; p < 4; p++ )
        {
            objectUnpickler = new Unpickler( data, offset );
            result = objectUnpickler.Run( out offset );
        }

        var keysObject = new Unpickler( data, offset ).Run( out offset );
        if ( keysObject is not List<object> keyList )
            throw new FormatException( "torch checkpoint: legacy storage-key list missing." );

        var storages = new Dictionary<string, byte[]>( StringComparer.Ordinal );
        foreach ( var keyObject in keyList )
        {
            if ( keyObject is not string key )
                throw new FormatException( "torch checkpoint: legacy storage key is not a string." );
            if ( data.Length - offset < 8 )
                throw new FormatException( "torch checkpoint: legacy stream truncated at storage sizes." );
            var numel = BitConverter.ToInt64( data, offset );
            offset += 8;

            // Element width comes from the storage TYPE recorded at its persistent id
            // (optimizer state mixes Long/Int storages between the float tensors).
            var width = objectUnpickler.Storages.TryGetValue( key, out var info ) ? info.Width : 4;
            // numel * width <= remaining bytes, divided out so a hostile numel cannot overflow.
            if ( numel < 0 || numel > (data.Length - offset) / width )
                throw new FormatException( $"torch checkpoint: legacy storage '{key}' overruns the file." );
            var byteCount = (int)(numel * width);

            if ( info.Type is "FloatStorage" or "BFloat16Storage" or "HalfStorage" )
            {
                var bytes = new byte[byteCount];
                Array.Copy( data, offset, bytes, 0, byteCount );
                storages[key] = ToFloat32Bytes( bytes, info.Type );
            }
            offset += byteCount;
        }

        return Materialize( result, key =>
            storages.TryGetValue( key, out var bytes )
                ? bytes
                : throw new FormatException( $"torch checkpoint: legacy storage '{key}' missing." ) );
    }

    /// <summary>
    /// Converts the unpickled graph's tensor stubs into tensors. String-keyed tensor
    /// stubs become entries named by their dotted path; nested dicts ("state_dict" /
    /// "model" wrappers) are walked recursively; every other value is ignored.
    /// </summary>
    /// <param name="root">Unpickled top-level value; must be a dict.</param>
    /// <param name="storageBytes">Returns the float32 bytes of a storage key.</param>
    /// <param name="keep">Optional filter on the full dotted name, consulted before storage bytes are fetched.</param>
    /// <param name="prefix">Dotted prefix of <paramref name="root"/> (recursion only).</param>
    /// <param name="ancestors">Dicts currently being walked (recursion only).</param>
    static IReadOnlyDictionary<string, Tensor> Materialize(
        object root, Func<string, byte[]> storageBytes,
        Func<string, bool> keep = null, string prefix = "",
        HashSet<Dictionary<object, object>> ancestors = null )
    {
        if ( root is not Dictionary<object, object> dict )
            throw new FormatException( "torch checkpoint: top-level pickle value is not a dict." );

        // Memo references let a pickle build a dict that contains itself; skipping
        // back-edges keeps a hostile file from recursing until the stack overflows.
        // (Dictionary does not override Equals, so the set compares by reference.)
        ancestors ??= new HashSet<Dictionary<object, object>>();
        ancestors.Add( dict );

        var tensors = new Dictionary<string, Tensor>( StringComparer.Ordinal );
        foreach ( var (key, value) in dict )
        {
            if ( key is not string name )
                continue;
            var full = prefix + name;
            if ( value is TensorStub stub )
            {
                if ( keep is not null && !keep( full ) )
                    continue;
                tensors[full] = ToTensor( stub, storageBytes( stub.StorageKey ) );
            }
            else if ( value is Dictionary<object, object> nested && !ancestors.Contains( nested ) )
            {
                // Nested dicts ("state_dict"/"model" wrappers in training checkpoints).
                foreach ( var (innerKey, innerValue) in Materialize( nested, storageBytes, keep, $"{full}.", ancestors ) )
                    tensors[innerKey] = innerValue;
            }
        }

        ancestors.Remove( dict );
        return tensors;
    }

    /// <summary>Copies one tensor's elements out of its (already float32) storage bytes.</summary>
    static Tensor ToTensor( TensorStub stub, byte[] bytes )
    {
        var count = ElementCount( stub.Shape );
        // Equivalent to (offset + count) * 4 > bytes.Length, rearranged so a huge
        // element count cannot overflow the comparison.
        if ( count > bytes.Length / 4L - stub.StorageOffset )
            throw new FormatException(
                $"torch checkpoint: storage '{stub.StorageKey}' too small for shape "
                + $"[{string.Join( ",", stub.Shape )}] at offset {stub.StorageOffset}." );
        var values = new float[count];
        Buffer.BlockCopy( bytes, stub.StorageOffset * 4, values, 0, (int)count * 4 );
        return stub.Shape.Length == 0
            ? Tensor.From( values, 1 )
            : Tensor.From( values, stub.Shape );
    }

    /// <summary>Product of non-negative dimensions, saturating at <see cref="long.MaxValue"/>
    /// (such a count can never fit a storage, so the caller's size check rejects it).</summary>
    static long ElementCount( int[] shape )
    {
        if ( shape.Contains( 0 ) )
            return 0;
        long count = 1;
        foreach ( var d in shape )
        {
            if ( count > long.MaxValue / d )
                return long.MaxValue;
            count *= d;
        }
        return count;
    }

    // ================================================================== zip

    /// <summary>Reads a plain zip archive (entry name → bytes). Public so model
    /// bundles (e.g. the RigNet checkpoints zip) can be read without extraction.</summary>
    public static IReadOnlyDictionary<string, byte[]> ReadArchive( byte[] data ) => ReadZip( data );

    // ---- streaming zip index (ZIP64-aware) — for checkpoints past 2GB ----

    readonly record struct ZipIndexEntry( long LocalHeaderOffset, long CompressedSize, long UncompressedSize, ushort Method );

    /// <summary>Parses the central directory (ZIP64-aware) into entry name → location/sizes.</summary>
    static Dictionary<string, ZipIndexEntry> ReadZipIndex( Stream file )
    {
        // EOCD scan over the file tail (the archive comment can pad it by up to 64KB).
        var tailLength = (int)Math.Min( file.Length, 66_000 );
        var tail = ReadAt( file, file.Length - tailLength, tailLength );
        var eocd = -1;
        for ( var i = tailLength - 22; i >= 0; i-- )
        {
            if ( tail[i] == 0x50 && tail[i + 1] == 0x4b && tail[i + 2] == 0x05 && tail[i + 3] == 0x06 )
            {
                eocd = i;
                break;
            }
        }
        if ( eocd < 0 )
            throw new FormatException( "torch checkpoint: not a zip archive (no end-of-central-directory)." );

        long count = BitConverter.ToUInt16( tail, eocd + 10 );
        long cdOffset = BitConverter.ToUInt32( tail, eocd + 16 );

        if ( count == 0xFFFF || cdOffset == 0xFFFFFFFF )
        {
            // ZIP64: the 20-byte locator sits immediately before the EOCD.
            var locatorAt = eocd - 20;
            if ( locatorAt < 0 || BitConverter.ToUInt32( tail, locatorAt ) != 0x07064b50 )
                throw new FormatException( "torch checkpoint: ZIP64 locator missing." );
            var eocd64Offset = BitConverter.ToInt64( tail, locatorAt + 8 );
            var eocd64 = ReadAt( file, eocd64Offset, 56 );
            if ( BitConverter.ToUInt32( eocd64, 0 ) != 0x06064b50 )
                throw new FormatException( "torch checkpoint: corrupt ZIP64 end-of-central-directory." );
            count = BitConverter.ToInt64( eocd64, 32 );
            cdOffset = BitConverter.ToInt64( eocd64, 48 );
        }

        if ( count < 0 || cdOffset < 0 || cdOffset > file.Length )
            throw new FormatException( "torch checkpoint: corrupt central directory." );

        // Central directory is read in one buffer (it is tiny relative to the
        // payload: ~100 bytes per entry).
        var cdLength = file.Length - cdOffset;
        if ( cdLength > int.MaxValue )
            throw new FormatException( "torch checkpoint: central directory too large." );
        var cd = ReadAt( file, cdOffset, (int)cdLength );

        var entries = new Dictionary<string, ZipIndexEntry>( StringComparer.Ordinal );
        var at = 0;
        for ( long e = 0; e < count; e++ )
        {
            if ( cd.Length - at < 46 || BitConverter.ToUInt32( cd, at ) != 0x02014b50 )
                throw new FormatException( "torch checkpoint: corrupt central directory." );
            var method = BitConverter.ToUInt16( cd, at + 10 );
            long compressedSize = BitConverter.ToUInt32( cd, at + 20 );
            long uncompressedSize = BitConverter.ToUInt32( cd, at + 24 );
            var nameLength = BitConverter.ToUInt16( cd, at + 28 );
            var extraLength = BitConverter.ToUInt16( cd, at + 30 );
            var commentLength = BitConverter.ToUInt16( cd, at + 32 );
            long localOffset = BitConverter.ToUInt32( cd, at + 42 );

            var recordEnd = (long)at + 46 + nameLength + extraLength + commentLength;
            if ( recordEnd > cd.Length )
                throw new FormatException( "torch checkpoint: corrupt central directory." );
            var name = Encoding.UTF8.GetString( cd, at + 46, nameLength );

            // ZIP64 extra field 0x0001: u64 values, in this order, present only for
            // the fixed fields that overflowed to 0xFFFFFFFF.
            var extraAt = at + 46 + nameLength;
            var extraEnd = extraAt + extraLength;
            while ( extraAt + 4 <= extraEnd )
            {
                var id = BitConverter.ToUInt16( cd, extraAt );
                var size = BitConverter.ToUInt16( cd, extraAt + 2 );
                var fieldEnd = extraAt + 4 + size;
                if ( fieldEnd > extraEnd )
                    throw new FormatException( "torch checkpoint: corrupt central directory." );
                if ( id == 0x0001 )
                {
                    var v = extraAt + 4;
                    if ( uncompressedSize == 0xFFFFFFFF )
                        uncompressedSize = ReadZip64Value( cd, ref v, fieldEnd );
                    if ( compressedSize == 0xFFFFFFFF )
                        compressedSize = ReadZip64Value( cd, ref v, fieldEnd );
                    if ( localOffset == 0xFFFFFFFF )
                        localOffset = ReadZip64Value( cd, ref v, fieldEnd );
                }
                extraAt = fieldEnd;
            }

            entries[name] = new ZipIndexEntry( localOffset, compressedSize, uncompressedSize, method );
            at = (int)recordEnd;
        }
        return entries;
    }

    /// <summary>Reads one u64 from a ZIP64 extra field without leaving the field.</summary>
    static long ReadZip64Value( byte[] cd, ref int at, int fieldEnd )
    {
        if ( fieldEnd - at < 8 )
            throw new FormatException( "torch checkpoint: truncated ZIP64 extra field." );
        var value = BitConverter.ToInt64( cd, at );
        at += 8;
        return value;
    }

    /// <summary>Reads and (for deflate) inflates one entry located by the central directory.</summary>
    static byte[] ReadZipEntry( Stream file, ZipIndexEntry entry )
    {
        var header = ReadAt( file, entry.LocalHeaderOffset, 30 );
        if ( BitConverter.ToUInt32( header, 0 ) != 0x04034b50 )
            throw new FormatException( "torch checkpoint: corrupt local header." );
        // The local header carries its own name/extra lengths (they may differ from the central copy).
        var nameLength = BitConverter.ToUInt16( header, 26 );
        var extraLength = BitConverter.ToUInt16( header, 28 );
        var dataOffset = entry.LocalHeaderOffset + 30 + nameLength + extraLength;

        if ( entry.CompressedSize > int.MaxValue )
            throw new FormatException( "torch checkpoint: single zip entry past 2GB is not supported." );
        var raw = ReadAt( file, dataOffset, (int)entry.CompressedSize );
        if ( entry.Method == 0 )
            return raw;
        if ( entry.Method != 8 )
            throw new FormatException( $"torch checkpoint: unsupported compression method {entry.Method}." );

        if ( entry.UncompressedSize < 0 || entry.UncompressedSize > int.MaxValue )
            throw new FormatException( "torch checkpoint: single zip entry past 2GB is not supported." );
        var content = new byte[entry.UncompressedSize];
        using var inflater = new DeflateStream( new MemoryStream( raw ), CompressionMode.Decompress );
        var read = 0;
        while ( read < content.Length )
        {
            var n = inflater.Read( content, read, content.Length - read );
            if ( n <= 0 )
                throw new FormatException( "torch checkpoint: truncated deflate data." );
            read += n;
        }
        return content;
    }

    /// <summary>Reads exactly <paramref name="count"/> bytes at <paramref name="offset"/>,
    /// rejecting ranges outside the stream as corrupt data.</summary>
    static byte[] ReadAt( Stream file, long offset, int count )
    {
        if ( offset < 0 || count < 0 || offset > file.Length - count )
            throw new FormatException( "torch checkpoint: zip structure points outside the file." );
        file.Seek( offset, SeekOrigin.Begin );
        var buffer = new byte[count];
        file.ReadExactly( buffer );
        return buffer;
    }

    /// <summary>Central-directory zip walk (stored + deflate) over an in-memory
    /// archive; shares the ZIP64-aware streaming reader.</summary>
    static Dictionary<string, byte[]> ReadZip( byte[] data )
    {
        using var stream = new MemoryStream( data, writable: false );
        var index = ReadZipIndex( stream );
        var entries = new Dictionary<string, byte[]>( index.Count, StringComparer.Ordinal );
        foreach ( var (name, entry) in index )
            entries[name] = ReadZipEntry( stream, entry );
        return entries;
    }

    // ================================================================== pickle

    sealed class StorageRef
    {
        public required string Key;
        public required long Numel;
    }

    sealed class GlobalRef
    {
        public required string Module;
        public required string Name;
    }

    /// <summary>An unknown global resolved as INERT DATA - nothing is ever executed.</summary>
    sealed class OpaqueRef
    {
        public required string What;
    }

    /// <summary>A tensor recorded during unpickling, materialized once storages are read.
    /// StorageOffset is in elements; both it and every Shape entry are validated non-negative.</summary>
    sealed class TensorStub
    {
        public required string StorageKey;
        public required int StorageOffset;
        public required int[] Shape;
    }

    /// <summary>Restricted pickle interpreter: a fixed opcode whitelist and no code execution.</summary>
    sealed class Unpickler
    {
        readonly byte[] _data;
        readonly List<object> _stack = new();
        readonly Dictionary<int, object> _memo = new();
        int _pos;

        /// <summary>Every storage seen via persistent ids: key → (type name, element width).</summary>
        public Dictionary<string, (string Type, int Width)> Storages { get; } = new( StringComparer.Ordinal );

        static readonly object MarkSentinel = new();

        /// <summary>Stands in for a Python <c>None</c> dict key (.NET dictionaries reject null keys).</summary>
        static readonly object NoneKey = new();

        public Unpickler( byte[] pickle, int startOffset )
        {
            _data = pickle;
            _pos = startOffset;
        }

        /// <summary>Storage type recorded for <paramref name="key"/>, defaulting to FloatStorage.</summary>
        public string StorageType( string key )
            => Storages.TryGetValue( key, out var info ) ? info.Type : "FloatStorage";

        byte Next()
        {
            if ( _pos >= _data.Length )
                throw new FormatException( "torch checkpoint: pickle stream truncated." );
            return _data[_pos++];
        }

        byte[] NextBytes( int count )
        {
            if ( count < 0 )
                throw Bad( "negative length" );
            // Compared against the remaining length so a huge count cannot overflow _pos + count.
            if ( count > _data.Length - _pos )
                throw new FormatException( "torch checkpoint: pickle stream truncated." );
            var chunk = new byte[count];
            Array.Copy( _data, _pos, chunk, 0, count );
            _pos += count;
            return chunk;
        }

        string ReadLine()
        {
            var start = _pos;
            while ( _pos < _data.Length && _data[_pos] != (byte)'\n' )
                _pos++;
            if ( _pos >= _data.Length )
                throw new FormatException( "torch checkpoint: unterminated pickle text line." );
            var line = Encoding.ASCII.GetString( _data, start, _pos - start );
            _pos++;
            return line;
        }

        void Push( object value ) => _stack.Add( value );

        object Pop()
        {
            if ( _stack.Count == 0 )
                throw new FormatException( "torch checkpoint: pickle stack underflow." );
            var value = _stack[^1];
            _stack.RemoveAt( _stack.Count - 1 );
            return value;
        }

        List<object> PopToMark()
        {
            var items = new List<object>();
            while ( true )
            {
                var value = Pop();
                if ( ReferenceEquals( value, MarkSentinel ) )
                    break;
                items.Add( value );
            }
            items.Reverse();
            return items;
        }

        /// <summary>Runs one pickle up to its STOP opcode; <paramref name="endOffset"/> is the byte after it.</summary>
        public object Run( out int endOffset )
        {
            var result = RunInner();
            endOffset = _pos;
            return result;
        }

        object RunInner()
        {
            while ( true )
            {
                var op = Next();
                switch ( op )
                {
                    case 0x80: Next(); break;                             // PROTO n
                    case 0x95: NextBytes( 8 ); break;                      // FRAME len
                    case (byte)'}': Push( new Dictionary<object, object>() ); break; // EMPTY_DICT
                    case (byte)']': Push( new List<object>() ); break;     // EMPTY_LIST
                    case (byte)'(': Push( MarkSentinel ); break;           // MARK
                    case (byte)'N': Push( null ); break;                   // NONE
                    case 0x88: Push( true ); break;                        // NEWTRUE
                    case 0x89: Push( false ); break;                       // NEWFALSE

                    case (byte)'X':                                        // BINUNICODE
                    {
                        var length = BitConverter.ToInt32( NextBytes( 4 ), 0 );
                        Push( Encoding.UTF8.GetString( NextBytes( length ) ) );
                        break;
                    }
                    case 0x8c:                                             // SHORT_BINUNICODE
                    {
                        int length = Next();
                        Push( Encoding.UTF8.GetString( NextBytes( length ) ) );
                        break;
                    }
                    case (byte)'U':                                        // SHORT_BINSTRING (legacy sys_info)
                    {
                        int length = Next();
                        Push( Encoding.Latin1.GetString( NextBytes( length ) ) );
                        break;
                    }
                    case (byte)'T':                                        // BINSTRING
                    {
                        var length = BitConverter.ToInt32( NextBytes( 4 ), 0 );
                        Push( Encoding.Latin1.GetString( NextBytes( length ) ) );
                        break;
                    }

                    case (byte)'J': Push( BitConverter.ToInt32( NextBytes( 4 ), 0 ) ); break; // BININT
                    case (byte)'K': Push( (int)Next() ); break;            // BININT1
                    case (byte)'M': Push( (int)BitConverter.ToUInt16( NextBytes( 2 ), 0 ) ); break; // BININT2
                    case 0x8a:                                             // LONG1
                    {
                        int length = Next();
                        Push( DecodeLong( NextBytes( length ) ) );
                        break;
                    }

                    case (byte)'q': _memo[Next()] = Peek(); break;         // BINPUT
                    case (byte)'r': _memo[BitConverter.ToInt32( NextBytes( 4 ), 0 )] = Peek(); break; // LONG_BINPUT
                    case 0x94: _memo[_memo.Count] = Peek(); break;         // MEMOIZE
                    case (byte)'h': Push( Memo( Next() ) ); break;         // BINGET
                    case (byte)'j': Push( Memo( BitConverter.ToInt32( NextBytes( 4 ), 0 ) ) ); break; // LONG_BINGET

                    case (byte)'t': Push( PopToMark().ToArray() ); break;  // TUPLE
                    case (byte)')': Push( Array.Empty<object>() ); break;  // EMPTY_TUPLE
                    case 0x85: Push( new[] { Pop() } ); break;             // TUPLE1
                    case 0x86:                                             // TUPLE2
                    {
                        var b = Pop(); var a = Pop();
                        Push( new[] { a, b } );
                        break;
                    }
                    case 0x87:                                             // TUPLE3
                    {
                        var c = Pop(); var b = Pop(); var a = Pop();
                        Push( new[] { a, b, c } );
                        break;
                    }

                    case (byte)'c':                                        // GLOBAL
                    {
                        var module = ReadLine();
                        var name = ReadLine();
                        Push( ResolveGlobal( module, name ) );
                        break;
                    }
                    case 0x93:                                             // STACK_GLOBAL
                    {
                        var name = Pop() as string ?? throw Bad( "STACK_GLOBAL name" );
                        var module = Pop() as string ?? throw Bad( "STACK_GLOBAL module" );
                        Push( ResolveGlobal( module, name ) );
                        break;
                    }

                    case (byte)'Q':                                        // BINPERSID
                    {
                        Push( ResolvePersistentId( Pop() ) );
                        break;
                    }

                    case (byte)'R':                                        // REDUCE
                    {
                        var args = Pop() as object[] ?? throw Bad( "REDUCE args" );
                        var callable = Pop();
                        Push( Invoke( callable, args ) );
                        break;
                    }

                    case (byte)'s':                                        // SETITEM
                    {
                        var value = Pop();
                        var key = Pop();
                        (Peek() as Dictionary<object, object> ?? throw Bad( "SETITEM target" ))[key ?? NoneKey] = value;
                        break;
                    }
                    case (byte)'u':                                        // SETITEMS
                    {
                        var items = PopToMark();
                        var dict = Peek() as Dictionary<object, object> ?? throw Bad( "SETITEMS target" );
                        for ( var i = 0; i + 1 < items.Count; i += 2 )
                            dict[items[i] ?? NoneKey] = items[i + 1];
                        break;
                    }
                    case (byte)'a':                                        // APPEND (single)
                    {
                        var item = Pop();
                        (Peek() as List<object> ?? throw Bad( "APPEND target" )).Add( item );
                        break;
                    }
                    case (byte)'0': Pop(); break;                          // POP
                    case (byte)'e':                                        // APPENDS
                    {
                        var items = PopToMark();
                        var list = Peek() as List<object> ?? throw Bad( "APPENDS target" );
                        list.AddRange( items );
                        break;
                    }

                    case (byte)'G':                                        // BINFLOAT (big-endian f64)
                    {
                        var raw = NextBytes( 8 );
                        var swapped = new byte[8];
                        for ( var i = 0; i < 8; i++ )
                            swapped[i] = raw[7 - i];
                        Push( BitConverter.ToDouble( swapped, 0 ) );
                        break;
                    }

                    case 0x81:                                             // NEWOBJ: cls(*args) - inert
                    {
                        Pop(); // args
                        var cls = Pop();
                        Push( cls is GlobalRef g
                            ? Invoke( g, Array.Empty<object>() )
                            : new OpaqueRef { What = "NEWOBJ instance" } );
                        break;
                    }

                    case (byte)'b':                                        // BUILD: apply state - inert merge
                    {
                        var state = Pop();
                        if ( Peek() is Dictionary<object, object> targetDict
                            && state is Dictionary<object, object> stateDict )
                        {
                            foreach ( var (k, v) in stateDict )
                                targetDict[k] = v;
                        }
                        // Opaque targets: state discarded, nothing executed.
                        break;
                    }

                    case (byte)'.':                                        // STOP
                        return Pop();

                    default:
                        throw new FormatException(
                            $"torch checkpoint: unsupported pickle opcode 0x{op:X2} at {_pos - 1} "
                            + "(the restricted reader accepts tensor state dicts only)." );
                }
            }
        }

        /// <summary>
        /// Decodes a LONG1 payload (little-endian two's complement). Pickle only emits
        /// LONG1 for values outside the int32 range, so the value is sign-extended and
        /// pushed as <c>long</c> (or <c>int</c> when it fits); wider-than-64-bit values
        /// (e.g. the legacy magic number) become inert data instead of being truncated.
        /// </summary>
        static object DecodeLong( byte[] bytes )
        {
            if ( bytes.Length > 8 )
                return new OpaqueRef { What = "integer wider than 64 bits" };
            long value = 0;
            for ( var i = bytes.Length - 1; i >= 0; i-- )
                value = (value << 8) | bytes[i];
            // Payloads shorter than 8 bytes with the top bit set are negative.
            if ( bytes.Length is > 0 and < 8 && (bytes[^1] & 0x80) != 0 )
                value |= -1L << (bytes.Length * 8);
            return value is >= int.MinValue and <= int.MaxValue ? (object)(int)value : value;
        }

        object Peek() => _stack.Count > 0 ? _stack[^1] : throw Bad( "empty stack" );

        object Memo( int slot )
            => _memo.TryGetValue( slot, out var value ) ? value : throw Bad( $"memo slot {slot}" );

        static FormatException Bad( string what )
            => new( $"torch checkpoint: malformed pickle ({what})." );

        static object ResolveGlobal( string module, string name ) => (module, name) switch
        {
            ("torch._utils", "_rebuild_tensor_v2") => new GlobalRef { Module = module, Name = name },
            ("torch", "FloatStorage") => new GlobalRef { Module = module, Name = name },
            ("collections", "OrderedDict") => new GlobalRef { Module = module, Name = name },
            // Anything else becomes INERT DATA: nothing is looked up or executed, and
            // values built from it are dropped at materialization. Training
            // checkpoints carry optimizer state and framework metadata we must
            // tolerate without running.
            _ => new OpaqueRef { What = $"{module}.{name}" },
        };

        static int StorageWidth( string typeName ) => typeName switch
        {
            "DoubleStorage" or "LongStorage" => 8,
            "FloatStorage" or "IntStorage" => 4,
            "HalfStorage" or "ShortStorage" or "BFloat16Storage" => 2,
            _ => 1, // Byte/Char/Bool
        };

        /// <summary>Resolves a ('storage', type, key, location, numel, ...) persistent id,
        /// recording the storage type; non-float storages become inert data.</summary>
        object ResolvePersistentId( object pid )
        {
            if ( pid is not object[] tuple || tuple.Length < 5
                || tuple[0] is not string kind || kind != "storage"
                || tuple[2] is not string key )
                throw Bad( "persistent id" );

            var typeName = tuple[1] switch
            {
                GlobalRef g => g.Name,
                OpaqueRef o => o.What.Contains( '.' ) ? o.What[(o.What.LastIndexOf( '.' ) + 1)..] : o.What,
                _ => throw Bad( "storage type" ),
            };
            var numel = tuple[4] switch { int i => (long)i, long l => l, _ => throw Bad( "storage numel" ) };
            Storages[key] = (typeName, StorageWidth( typeName ));

            if ( typeName is not ("FloatStorage" or "BFloat16Storage" or "HalfStorage") )
                return new OpaqueRef { What = $"non-float storage {typeName}" };
            return new StorageRef { Key = key, Numel = numel };
        }

        /// <summary>Applies one of the whitelisted callables; opaque callables stay inert.</summary>
        object Invoke( object callable, object[] args )
        {
            if ( callable is OpaqueRef opaque )
                return opaque; // inert: nothing executed, value dropped later

            if ( callable is not GlobalRef global )
                throw Bad( "REDUCE callable" );

            if ( global is { Module: "collections", Name: "OrderedDict" } )
                return new Dictionary<object, object>();

            if ( global is { Module: "torch._utils", Name: "_rebuild_tensor_v2" } )
            {
                if ( args.Length < 4
                    || args[1] is not int storageOffset
                    || args[2] is not object[] sizeTuple
                    || args[3] is not object[] strideTuple )
                    throw Bad( "_rebuild_tensor_v2 args" );
                if ( args[0] is OpaqueRef )
                    return new OpaqueRef { What = "tensor on unsupported storage" };
                if ( args[0] is not StorageRef storage )
                    throw Bad( "_rebuild_tensor_v2 storage" );

                var shape = sizeTuple
                    .Select( s => s is int v ? v : throw Bad( "size" ) ).ToArray();
                var stride = strideTuple
                    .Select( s => s is int v ? v : throw Bad( "stride" ) ).ToArray();
                if ( stride.Length != shape.Length )
                    throw Bad( "size/stride rank mismatch" );
                if ( storageOffset < 0 || shape.Any( d => d < 0 ) )
                    throw Bad( "negative size or storage offset" );

                // Contiguous row-major only, judged like PyTorch's is_contiguous:
                // size-1 dimensions may carry any stride, and an empty tensor is
                // always contiguous (neither changes the memory layout read here).
                if ( !shape.Contains( 0 ) )
                {
                    long expected = 1;
                    for ( var d = shape.Length - 1; d >= 0; d-- )
                    {
                        if ( shape[d] == 1 )
                            continue;
                        if ( stride[d] != expected )
                            throw new FormatException(
                                "torch checkpoint: non-contiguous tensor storage is not supported." );
                        expected *= shape[d];
                    }
                }

                return new TensorStub
                {
                    StorageKey = storage.Key,
                    StorageOffset = storageOffset,
                    Shape = shape,
                };
            }

            throw Bad( $"REDUCE of {global.Module}.{global.Name}" );
        }
    }
}