Editor/AutoRig/DlModelRegistry.cs

Editor-side registry for deep-learning models. Tracks enabled/disabled state per project, persists it to enabled.json, exposes available choices including a user checkpoint, picks best model via HardwareAdvisor, loads and caches model bundles and runnable models from disk, and adapts loading per model architecture.

File Access, Networking
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using AutoRig.Dl;
using AutoRig.Solve;
using Editor;
using Sandbox;

namespace AutoRig.Editor;

/// <summary>
/// The editor's view of usable deep-learning models: catalog entries (installed +
/// enabled) plus an optional user-loaded checkpoint. Enabled flags persist per
/// project (autorig_dl/enabled.json). Picks the best enabled model for a mesh via
/// the hardware advisor.
/// </summary>
public static class DlModelRegistry
{
    /// <summary>One selectable model for the dropdown/manage list.</summary>
    public sealed class ModelChoice
    {
        public required string Id;          // catalog id, or UserId for the user checkpoint
        public required string Title;
        public CatalogEntry Entry;          // null for the user checkpoint
        public RigNetBundle LoadedBundle;   // pre-loaded for the user checkpoint
    }

    /// <summary>Id of the user-loaded checkpoint in enablement state and choices.</summary>
    public const string UserId = "user:checkpoint";

    static DlModelRegistry()
    {
        // Code/ is whitelist-safe and therefore SEQUENTIAL by default -
        // System.Threading.Tasks.Parallel is not on the s&box whitelist. The
        // editor is unwhitelisted, so give the inference runtime real cores.
        AutoRig.Dl.Concurrency.ForRunner =
            ( from, to, body ) => System.Threading.Tasks.Parallel.For( from, to, body );
    }

    // Enabled flags read from _loadedStatePath (absent id => enabled).
    static readonly Dictionary<string, bool> _enabled = new();

    // The state file _enabled was loaded from; null until the first Load().
    // Tracked as a path rather than a bool so that opening another project
    // re-reads that project's flags instead of reusing the previous project's
    // (and later overwriting the new project's file with them).
    static string _loadedStatePath;

    static RigNetBundle _userBundle;
    static string _userLabel;

    // Parsed checkpoints keyed by full file path so repeat selections skip disk.
    // _bundleCache holds RigNet bundles; _modelCache holds any runnable model.
    static readonly Dictionary<string, RigNetBundle> _bundleCache = new();
    static readonly Dictionary<string, object> _modelCache = new();

    /// <summary>Raised whenever enablement or the model set changes (UI refresh).</summary>
    public static Action Changed { get; set; }

    /// <summary>
    /// Per-project state file. Falls back to "./autorig_dl/enabled.json" when no
    /// project is open.
    /// </summary>
    static string StatePath => Path.Combine(
        Project.Current?.GetAssetsPath() ?? ".", "autorig_dl", "enabled.json" );

    /// <summary>
    /// Ensures <see cref="_enabled"/> reflects the current project's state file.
    /// It re-reads whenever the resolved path differs from the last one loaded.
    /// Best-effort: a missing, unreadable or malformed file leaves every model
    /// at its default (enabled).
    /// </summary>
    static void Load()
    {
        var path = StatePath;
        if ( path == _loadedStatePath )
            return;

        // Flags are per project: drop whatever the previous state file held.
        _loadedStatePath = path;
        _enabled.Clear();
        try
        {
            if ( !File.Exists( path ) )
                return;
            var doc = System.Text.Json.JsonSerializer
                .Deserialize<Dictionary<string, bool>>( File.ReadAllText( path ) );
            if ( doc is null )   // the file literally contained JSON `null`
                return;
            foreach ( var (key, value) in doc )
                _enabled[key] = value;
        }
        catch ( Exception )
        {
            // unreadable state - defaults apply
        }
    }

    /// <summary>
    /// Writes the enabled flags to the current project's state file, creating the
    /// directory if needed. Best-effort: failures are ignored, so the in-memory
    /// flags still apply for this session.
    /// </summary>
    static void Save()
    {
        var path = StatePath;
        try
        {
            Directory.CreateDirectory( Path.GetDirectoryName( path ) );
            File.WriteAllText( path,
                System.Text.Json.JsonSerializer.Serialize( _enabled ) );
        }
        catch ( Exception )
        {
            // unwritable project - keep the in-memory flags for this session only
        }
    }

    /// <summary>Installed models default to enabled until the user says otherwise.</summary>
    public static bool IsEnabled( string id )
    {
        Load();
        return !_enabled.TryGetValue( id, out var enabled ) || enabled;
    }

    /// <summary>Persists an enabled flag for <paramref name="id"/> and raises <see cref="Changed"/>.</summary>
    public static void SetEnabled( string id, bool enabled )
    {
        Load();
        _enabled[id] = enabled;
        Save();
        Changed?.Invoke();
    }

    /// <summary>Registers/replaces the user's own checkpoint (already loaded).</summary>
    public static void SetUserCheckpoint( RigNetBundle bundle, string label )
    {
        _userBundle = bundle;
        _userLabel = label;
        Changed?.Invoke();
    }

    /// <summary>
    /// This machine can run the model at all (baseline RAM check). Models failing
    /// this are FORCE-disabled: locked off in the manage list, never auto-picked.
    /// </summary>
    public static bool MeetsHardware( CatalogEntry entry )
        => HardwareAdvisor.Assess(
                entry, 1, GC.GetGCMemoryInfo().TotalAvailableMemoryBytes ).Verdict
            != Verdict.Insufficient;

    /// <summary>Every ENABLED, usable model (installed + hardware-capable catalog
    /// entries + the user checkpoint). The user checkpoint, when present, comes first.</summary>
    public static List<ModelChoice> EnabledChoices()
    {
        Load();
        var choices = new List<ModelChoice>();
        if ( _userBundle is not null && IsEnabled( UserId ) )
            choices.Add( new ModelChoice
            {
                Id = UserId,
                Title = $"Your checkpoint ({_userLabel})",
                LoadedBundle = _userBundle,
            } );
        foreach ( var entry in ModelCatalog.Entries )
            if ( entry.RuntimeSupported && CatalogDialog.IsInstalled( entry )
                && IsEnabled( entry.Id ) && MeetsHardware( entry ) )
                choices.Add( new ModelChoice { Id = entry.Id, Title = entry.Title, Entry = entry } );
        return choices;
    }

    /// <summary>True when the user has loaded their own checkpoint this session.</summary>
    public static bool HasUserCheckpoint => _userBundle is not null;
    public static string UserCheckpointLabel => _userLabel;

    /// <summary>
    /// The best enabled model for a mesh: hardware-advisor verdicts rank
    /// Recommended before Slow (Insufficient is skipped); the user checkpoint
    /// counts as Recommended. Ties keep the earlier choice. Null when nothing
    /// usable is enabled.
    /// </summary>
    public static ModelChoice SelectBest( int vertexCount )
    {
        var available = GC.GetGCMemoryInfo().TotalAvailableMemoryBytes;
        ModelChoice best = null;
        var bestRank = int.MaxValue;
        foreach ( var choice in EnabledChoices() )
        {
            var rank = choice.Entry is null
                ? 0
                : HardwareAdvisor.Assess( choice.Entry, Math.Max( vertexCount, 1 ), available )
                    .Verdict switch
                {
                    Verdict.Recommended => 0,
                    Verdict.Slow => 1,
                    _ => int.MaxValue,   // Insufficient - never auto-pick
                };
            if ( rank < bestRank )
            {
                bestRank = rank;
                best = choice;
            }
        }
        // An int.MaxValue rank never beats the initial bestRank, so best is
        // still null when every candidate was Insufficient.
        return best;
    }

    /// <summary>Loads (and caches) the RigNet bundle behind a choice.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="choice"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The choice has neither a loaded bundle nor a catalog entry.</exception>
    public static RigNetBundle LoadBundle( ModelChoice choice )
    {
        if ( choice is null )
            throw new ArgumentNullException( nameof( choice ) );
        if ( choice.LoadedBundle is not null )
            return choice.LoadedBundle;
        var path = CheckpointPath( choice );
        if ( _bundleCache.TryGetValue( path, out var cached ) )
            return cached;
        var bundle = RigNetBundle.FromCheckpointsZip( File.ReadAllBytes( path ) );
        _bundleCache[path] = bundle;
        return bundle;
    }

    /// <summary>
    /// Loads (and caches) the runnable model behind a choice, dispatched by
    /// architecture: RigNetBundle or UniRigSkeletonModel. The user checkpoint is
    /// always RigNet-format today.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="choice"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The choice has neither a loaded bundle nor a catalog entry.</exception>
    /// <exception cref="FormatException">The entry's architecture has no runtime.</exception>
    public static object LoadRunnableModel( ModelChoice choice )
    {
        if ( choice is null )
            throw new ArgumentNullException( nameof( choice ) );
        if ( choice.LoadedBundle is not null )
            return choice.LoadedBundle;
        var path = CheckpointPath( choice );
        if ( _modelCache.TryGetValue( path, out var cached ) )
            return cached;

        object model = choice.Entry.Architecture switch
        {
            // Route through LoadBundle so the same checkpoint is never parsed
            // twice when both entry points are used.
            "rignet" => LoadBundle( choice ),
            "unirig-transformer" => AutoRig.Dl.UniRig.UniRigLoader.LoadSkeleton(
                AutoRig.Dl.TorchCheckpoint.Parse( File.ReadAllBytes( path ) ) ),
            "skintokens-transformer" => AutoRig.Dl.UniRig.SkinTokensLoader.LoadModel(
                AutoRig.Dl.TorchCheckpoint.Parse( File.ReadAllBytes( path ) ) ),
            "magicarticulate-transformer" => LoadMagicArticulate( path ),
            "puppeteer-transformer" => LoadPuppeteer( path ),
            "riganything-transformer" => LoadRigAnything( path ),
            "anymate-transformer" => LoadAnymate( path ),
            _ => throw new FormatException(
                $"Architecture '{choice.Entry.Architecture}' has no runtime yet." ),
        };
        _modelCache[path] = model;
        return model;
    }

    /// <summary>
    /// Full path of a catalog choice's primary checkpoint: the entry's first
    /// catalog file inside its install directory.
    /// </summary>
    static string CheckpointPath( ModelChoice choice )
    {
        if ( choice.Entry is null )
            throw new InvalidOperationException(
                $"Model '{choice.Id}' has neither a loaded bundle nor a catalog entry." );
        return Path.Combine(
            CatalogDialog.InstallDir( choice.Entry ), choice.Entry.Files[0].FileName );
    }

    /// <summary>The 4.4GB training checkpoint streams from disk (ZIP64) and the
    /// optimizer state is filtered out before any bytes are read.</summary>
    static object LoadMagicArticulate( string path )
    {
        using var stream = File.OpenRead( path );
        return AutoRig.Dl.MagicArticulate.MagicArticulateLoader.LoadModel(
            AutoRig.Dl.TorchCheckpoint.Parse(
                stream, k => k.StartsWith( "model.", StringComparison.Ordinal ) ) );
    }

    /// <summary>
    /// Loads the Puppeteer model from <paramref name="path"/> (keys under "model." only)
    /// and attaches the optional skinning transformer from "puppeteer_skin.pth" in the
    /// same directory when that file exists.
    /// </summary>
    static object LoadPuppeteer( string path )
    {
        AutoRig.Dl.Puppeteer.PuppeteerModel model;
        using ( var stream = File.OpenRead( path ) )
            model = AutoRig.Dl.Puppeteer.PuppeteerLoader.LoadModel(
                AutoRig.Dl.TorchCheckpoint.Parse(
                    stream, k => k.StartsWith( "model.", StringComparison.Ordinal ) ) );

        // The skinning transformer is optional (second catalog file): neural
        // weights when installed, geodesic fallback otherwise.
        var skinPath = Path.Combine( Path.GetDirectoryName( path )!, "puppeteer_skin.pth" );
        if ( File.Exists( skinPath ) )
        {
            using var skinStream = File.OpenRead( skinPath );
            model.Skin = AutoRig.Dl.Puppeteer.PuppeteerSkinLoader.Load(
                AutoRig.Dl.TorchCheckpoint.Parse(
                    skinStream, k => k.StartsWith( "model.", StringComparison.Ordinal ) ) );
        }
        return model;
    }

    /// <summary>Loads the RigAnything model, keeping only tensors under "model.".</summary>
    static object LoadRigAnything( string path )
    {
        using var stream = File.OpenRead( path );
        return AutoRig.Dl.RigAnything.RigAnythingLoader.LoadModel(
            AutoRig.Dl.TorchCheckpoint.Parse(
                stream, k => k.StartsWith( "model.", StringComparison.Ordinal ) ) );
    }

    /// <summary>
    /// Loads Anymate from its three sibling checkpoints (joint, connectivity, skin)
    /// in the directory of <paramref name="path"/>.
    /// </summary>
    static object LoadAnymate( string path )
    {
        var directory = Path.GetDirectoryName( path )!;

        // Parses one checkpoint and strips a leading "state_dict." from each key.
        // Named distinctly so it does not shadow the class's static Load().
        IReadOnlyDictionary<string, AutoRig.Dl.Tensor> ReadStateDict( string file )
        {
            using var stream = File.OpenRead( Path.Combine( directory, file ) );
            return AutoRig.Dl.TorchCheckpoint.Parse( stream )
                .ToDictionary(
                    kv => kv.Key.StartsWith( "state_dict.", StringComparison.Ordinal )
                        ? kv.Key["state_dict.".Length..]
                        : kv.Key,
                    kv => kv.Value );
        }

        return AutoRig.Dl.Anymate.AnymateLoader.Load(
            ReadStateDict( "anymate_joint.pth.tar" ),
            ReadStateDict( "anymate_conn.pth.tar" ),
            ReadStateDict( "anymate_skin.pth.tar" ) );
    }
}