Editor-side registry for deep-learning models. Tracks enabled/disabled state per project, persists it to enabled.json, exposes the available choices (including a user-loaded checkpoint), picks the best model via HardwareAdvisor, loads and caches model bundles and runnable models from disk, and dispatches loading by model architecture.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using AutoRig.Dl;
using AutoRig.Solve;
using Editor;
using Sandbox;
namespace AutoRig.Editor;
/// <summary>
/// The editor's view of usable deep-learning models: catalog entries (installed +
/// enabled) plus an optional user-loaded checkpoint. Enabled flags persist per
/// project (autorig_dl/enabled.json) and are reloaded when the current project
/// changes. Picks the best enabled model for a mesh via the hardware advisor and
/// loads/caches the runnable model behind each choice.
/// </summary>
public static class DlModelRegistry
{
	/// <summary>One selectable model for the dropdown/manage list.</summary>
	public sealed class ModelChoice
	{
		public required string Id; // catalog id, or UserId ("user:checkpoint") for the user checkpoint
		public required string Title;
		public CatalogEntry Entry; // null for the user checkpoint
		public RigNetBundle LoadedBundle; // pre-loaded for the user checkpoint
	}

	static DlModelRegistry()
	{
		// Code/ is whitelist-safe and therefore SEQUENTIAL by default -
		// System.Threading.Tasks.Parallel is not on the s&box whitelist. The
		// editor is unwhitelisted, so give the inference runtime real cores.
		AutoRig.Dl.Concurrency.ForRunner =
			( from, to, body ) => System.Threading.Tasks.Parallel.For( from, to, body );
	}

	// Enabled flags keyed by model id. Only ids explicitly set via SetEnabled
	// (or read back from the saved file) appear; absent ids count as enabled.
	static readonly Dictionary<string, bool> _enabled = new();

	// The StatePath whose contents are currently held in _enabled (null until
	// first load). Compared on every access so a project switch reloads flags
	// instead of leaking the previous project's state into the new one.
	static string _loadedStatePath;

	static RigNetBundle _userBundle;
	static string _userLabel;

	// Loaded models, keyed by the primary file path on disk.
	static readonly Dictionary<string, RigNetBundle> _bundleCache = new();
	static readonly Dictionary<string, object> _modelCache = new();

	/// <summary>Raised whenever enablement or the model set changes (UI refresh).</summary>
	public static Action Changed { get; set; }

	/// <summary>Id used for the user-loaded checkpoint in choices and in the enabled map.</summary>
	public const string UserId = "user:checkpoint";

	static string StatePath => Path.Combine(
		Project.Current?.GetAssetsPath() ?? ".", "autorig_dl", "enabled.json" );

	/// <summary>
	/// Ensures _enabled reflects the current project's saved flags. A no-op when
	/// they are already loaded for the current StatePath; otherwise clears the map
	/// (back to defaults) and reads the file if it exists.
	/// </summary>
	static void Load()
	{
		var path = StatePath;
		if ( _loadedStatePath == path )
			return;

		// First use, or the editor switched projects: start from defaults.
		_loadedStatePath = path;
		_enabled.Clear();
		try
		{
			if ( !File.Exists( path ) )
				return;

			var saved = System.Text.Json.JsonSerializer
				.Deserialize<Dictionary<string, bool>>( File.ReadAllText( path ) );

			// A file containing the JSON literal `null` deserializes to null.
			if ( saved is null )
				return;

			foreach ( var (key, value) in saved )
				_enabled[key] = value;
		}
		catch ( Exception )
		{
			// Unreadable or malformed state: the defaults (everything enabled) apply.
		}
	}

	/// <summary>Writes the enabled flags to the current project's StatePath.</summary>
	static void Save()
	{
		try
		{
			var path = StatePath;
			Directory.CreateDirectory( Path.GetDirectoryName( path ) );
			File.WriteAllText( path, System.Text.Json.JsonSerializer.Serialize( _enabled ) );
		}
		catch ( Exception )
		{
			// Best-effort persistence: a read-only or missing project folder must
			// not break the editor. The in-memory flags stay authoritative for
			// the rest of this session.
		}
	}

	/// <summary>Installed models default to enabled until the user says otherwise.</summary>
	public static bool IsEnabled( string id )
	{
		Load();
		return !_enabled.TryGetValue( id, out var enabled ) || enabled;
	}

	/// <summary>Records the user's choice for <paramref name="id"/>, persists it, and raises Changed.</summary>
	public static void SetEnabled( string id, bool enabled )
	{
		Load();
		_enabled[id] = enabled;
		Save();
		Changed?.Invoke();
	}

	/// <summary>Registers/replaces the user's own checkpoint (already loaded).</summary>
	public static void SetUserCheckpoint( RigNetBundle bundle, string label )
	{
		_userBundle = bundle;
		_userLabel = label;
		Changed?.Invoke();
	}

	/// <summary>True when the user has loaded their own checkpoint this session.</summary>
	public static bool HasUserCheckpoint => _userBundle is not null;

	public static string UserCheckpointLabel => _userLabel;

	/// <summary>
	/// This machine can run the model at all: the hardware advisor's verdict for
	/// a one-vertex mesh, given the GC-reported available memory, is not
	/// Insufficient. Models failing this are FORCE-disabled: locked off in the
	/// manage list, never auto-picked.
	/// </summary>
	public static bool MeetsHardware( CatalogEntry entry )
		=> HardwareAdvisor.Assess(
			entry, 1, GC.GetGCMemoryInfo().TotalAvailableMemoryBytes ).Verdict
			!= Verdict.Insufficient;

	/// <summary>Every ENABLED, usable model (installed + hardware-capable catalog
	/// entries + the user checkpoint). The user checkpoint, when present, is first.</summary>
	public static List<ModelChoice> EnabledChoices()
	{
		Load();
		var choices = new List<ModelChoice>();
		if ( _userBundle is not null && IsEnabled( UserId ) )
		{
			choices.Add( new ModelChoice
			{
				Id = UserId,
				Title = $"Your checkpoint ({_userLabel})",
				LoadedBundle = _userBundle,
			} );
		}

		foreach ( var entry in ModelCatalog.Entries )
		{
			if ( entry.RuntimeSupported && CatalogDialog.IsInstalled( entry )
				&& IsEnabled( entry.Id ) && MeetsHardware( entry ) )
			{
				choices.Add( new ModelChoice { Id = entry.Id, Title = entry.Title, Entry = entry } );
			}
		}

		return choices;
	}

	/// <summary>
	/// The best enabled model for a mesh: hardware-advisor verdicts rank
	/// Recommended before Slow (Insufficient is skipped); the user checkpoint
	/// counts as Recommended. Ties keep the earlier choice in EnabledChoices
	/// order. Null when nothing usable is enabled.
	/// </summary>
	public static ModelChoice SelectBest( int vertexCount )
	{
		var available = GC.GetGCMemoryInfo().TotalAvailableMemoryBytes;
		ModelChoice best = null;
		var bestRank = int.MaxValue;
		foreach ( var choice in EnabledChoices() )
		{
			var rank = choice.Entry is null
				? 0
				: HardwareAdvisor.Assess( choice.Entry, Math.Max( vertexCount, 1 ), available )
					.Verdict switch
				{
					Verdict.Recommended => 0,
					Verdict.Slow => 1,
					_ => int.MaxValue, // Insufficient - never auto-pick
				};

			// Strict '<' keeps the first choice on ties and never accepts int.MaxValue.
			if ( rank < bestRank )
			{
				bestRank = rank;
				best = choice;
			}
		}

		return bestRank == int.MaxValue ? null : best;
	}

	/// <summary>
	/// Path of a catalog choice's primary file: the entry's first listed file
	/// inside its install directory.
	/// </summary>
	/// <exception cref="ArgumentException">The choice has no catalog entry.</exception>
	static string PrimaryFilePath( ModelChoice choice )
	{
		var entry = choice.Entry ?? throw new ArgumentException(
			$"Model choice '{choice.Id}' has neither a pre-loaded bundle nor a catalog entry.",
			nameof( choice ) );
		return Path.Combine( CatalogDialog.InstallDir( entry ), entry.Files[0].FileName );
	}

	/// <summary>Loads (and caches) the RigNet bundle behind a choice.</summary>
	/// <exception cref="ArgumentNullException"><paramref name="choice"/> is null.</exception>
	/// <exception cref="ArgumentException">The choice has neither a bundle nor a catalog entry.</exception>
	public static RigNetBundle LoadBundle( ModelChoice choice )
	{
		ArgumentNullException.ThrowIfNull( choice );
		if ( choice.LoadedBundle is not null )
			return choice.LoadedBundle;

		var path = PrimaryFilePath( choice );
		if ( !_bundleCache.TryGetValue( path, out var bundle ) )
		{
			bundle = RigNetBundle.FromCheckpointsZip( File.ReadAllBytes( path ) );
			_bundleCache[path] = bundle;
		}

		return bundle;
	}

	/// <summary>
	/// Loads (and caches) the runnable model behind a choice, dispatched by the
	/// catalog entry's Architecture string. The user checkpoint is always
	/// RigNet-format today and is returned as its pre-loaded bundle.
	/// </summary>
	/// <exception cref="ArgumentNullException"><paramref name="choice"/> is null.</exception>
	/// <exception cref="ArgumentException">The choice has neither a bundle nor a catalog entry.</exception>
	/// <exception cref="FormatException">The architecture has no runtime.</exception>
	public static object LoadRunnableModel( ModelChoice choice )
	{
		ArgumentNullException.ThrowIfNull( choice );
		if ( choice.LoadedBundle is not null )
			return choice.LoadedBundle;

		var path = PrimaryFilePath( choice );
		if ( _modelCache.TryGetValue( path, out var cached ) )
			return cached;

		object model = choice.Entry.Architecture switch
		{
			// Shares LoadBundle's cache so a RigNet file is read at most once.
			"rignet" => LoadBundle( choice ),
			"unirig-transformer" => AutoRig.Dl.UniRig.UniRigLoader.LoadSkeleton(
				AutoRig.Dl.TorchCheckpoint.Parse( File.ReadAllBytes( path ) ) ),
			"skintokens-transformer" => AutoRig.Dl.UniRig.SkinTokensLoader.LoadModel(
				AutoRig.Dl.TorchCheckpoint.Parse( File.ReadAllBytes( path ) ) ),
			"magicarticulate-transformer" => LoadMagicArticulate( path ),
			"puppeteer-transformer" => LoadPuppeteer( path ),
			"riganything-transformer" => LoadRigAnything( path ),
			"anymate-transformer" => LoadAnymate( path ),
			_ => throw new FormatException(
				$"Architecture '{choice.Entry.Architecture}' has no runtime yet." ),
		};

		_modelCache[path] = model;
		return model;
	}

	/// <summary>Checkpoint-entry filter: keep only tensors whose key starts with "model.".</summary>
	static bool IsModelKey( string key ) => key.StartsWith( "model.", StringComparison.Ordinal );

	/// <summary>The 4.4GB training checkpoint streams from disk (ZIP64); only
	/// "model."-prefixed entries are kept, so the optimizer state is filtered
	/// out before its bytes are read.</summary>
	static object LoadMagicArticulate( string path )
	{
		using var stream = File.OpenRead( path );
		return AutoRig.Dl.MagicArticulate.MagicArticulateLoader.LoadModel(
			AutoRig.Dl.TorchCheckpoint.Parse( stream, IsModelKey ) );
	}

	/// <summary>
	/// Loads the Puppeteer skeleton model from <paramref name="path"/>, then attaches
	/// the skinning transformer from the sibling "puppeteer_skin.pth" when present.
	/// </summary>
	static object LoadPuppeteer( string path )
	{
		AutoRig.Dl.Puppeteer.PuppeteerModel model;
		using ( var stream = File.OpenRead( path ) )
			model = AutoRig.Dl.Puppeteer.PuppeteerLoader.LoadModel(
				AutoRig.Dl.TorchCheckpoint.Parse( stream, IsModelKey ) );

		// The skinning transformer is optional (second catalog file): neural
		// weights when installed, geodesic fallback otherwise.
		var skinPath = Path.Combine( Path.GetDirectoryName( path )!, "puppeteer_skin.pth" );
		if ( File.Exists( skinPath ) )
		{
			using var skinStream = File.OpenRead( skinPath );
			model.Skin = AutoRig.Dl.Puppeteer.PuppeteerSkinLoader.Load(
				AutoRig.Dl.TorchCheckpoint.Parse( skinStream, IsModelKey ) );
		}

		return model;
	}

	/// <summary>Streams the RigAnything checkpoint, keeping only "model."-prefixed entries.</summary>
	static object LoadRigAnything( string path )
	{
		using var stream = File.OpenRead( path );
		return AutoRig.Dl.RigAnything.RigAnythingLoader.LoadModel(
			AutoRig.Dl.TorchCheckpoint.Parse( stream, IsModelKey ) );
	}

	/// <summary>
	/// Anymate ships as three checkpoints in one directory (anymate_joint,
	/// anymate_conn, anymate_skin). Each is parsed with a leading "state_dict."
	/// stripped from its keys and handed to AnymateLoader.Load.
	/// </summary>
	static object LoadAnymate( string path )
	{
		var directory = Path.GetDirectoryName( path )!;

		// Named distinctly from the class's static Load() to avoid shadowing it.
		IReadOnlyDictionary<string, AutoRig.Dl.Tensor> ReadStateDict( string file )
		{
			const string prefix = "state_dict.";
			using var stream = File.OpenRead( Path.Combine( directory, file ) );
			return AutoRig.Dl.TorchCheckpoint.Parse( stream )
				.ToDictionary(
					kv => kv.Key.StartsWith( prefix, StringComparison.Ordinal )
						? kv.Key[prefix.Length..]
						: kv.Key,
					kv => kv.Value );
		}

		return AutoRig.Dl.Anymate.AnymateLoader.Load(
			ReadStateDict( "anymate_joint.pth.tar" ),
			ReadStateDict( "anymate_conn.pth.tar" ),
			ReadStateDict( "anymate_skin.pth.tar" ) );
	}
}