Code/AutoRig/Formats/Fbx/FbxDonorImporter.cs

FBX donor rig importer. Parses FBX mesh and connection data to extract skin clusters, joint binds and control-point weights, builds a RigSkeleton and per-vertex skin weights (top-4 normalized) and returns a DonorRig.

File Access, Networking
using System.Numerics;
using AutoRig.Rig;

namespace AutoRig.Formats.Fbx;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Loads a rigged FBX as a <see cref="DonorRig"/>: skin clusters (Deformer "Skin" →
/// SubDeformer "Cluster" with control-point Indexes/Weights) mapped through the mesh
/// importer's corner-dedup, joint bind positions from each cluster's TransformLink
/// (world bind — the same scene space the mesh vertices are baked into), hierarchy
/// from the Model tree restricted to clustered joints.
/// </summary>
public static class FbxDonorImporter
{
    /// <summary>Maximum number of joint influences kept per output vertex.</summary>
    const int MaxInfluences = 4;

    /// <summary>
    /// Imports <paramref name="data"/> as a donor rig: skeleton over the clustered joints,
    /// per-output-vertex skin weights (heaviest <see cref="MaxInfluences"/>, normalized)
    /// and the imported mesh. Vertices with no usable cluster weight are rigidly bound
    /// to the skeleton joint nearest to them.
    /// </summary>
    /// <param name="data">Raw FBX file bytes.</param>
    /// <param name="sourceName">Name used in error messages.</param>
    /// <returns>The validated donor rig.</returns>
    /// <exception cref="FormatException">
    /// Malformed file, cyclic joint/Model hierarchy, or no skin clusters.
    /// </exception>
    public static DonorRig Import( byte[] data, string sourceName )
    {
        var (mesh, vertexSource, root, scene) = FbxMeshImporter.ImportWithSource( data, sourceName );

        // Raw object-object connection pairs (src child → dst parent).
        var pairs = new List<(long Src, long Dst)>();
        if ( root.Child( "Connections" ) is { } connections )
        {
            foreach ( var c in connections.ChildrenNamed( "C" ) )
            {
                if ( c.Properties.Count >= 3
                    && c.Properties[0] is "OO"
                    && c.Properties[1] is (long or int)
                    && c.Properties[2] is (long or int) )
                    pairs.Add( (c.Prop<long>( 1 ), c.Prop<long>( 2 )) );
            }
        }

        // skin deformer → geometry; cluster → skin deformer; joint model → cluster.
        var skinToGeometry = new Dictionary<long, long>();
        var clusterToSkin = new Dictionary<long, long>();
        var clusterJoint = new Dictionary<long, FbxObject>();
        foreach ( var (src, dst) in pairs )
        {
            if ( !scene.ObjectsById.TryGetValue( src, out var srcObj )
                || !scene.ObjectsById.TryGetValue( dst, out var dstObj ) )
                continue;
            if ( srcObj.NodeType == "Deformer" && srcObj.SubClass == "Skin"
                && dstObj.NodeType == "Geometry" )
                skinToGeometry.TryAdd( src, dst );
            else if ( srcObj.NodeType == "Deformer" && srcObj.SubClass == "Cluster"
                && dstObj.NodeType == "Deformer" && dstObj.SubClass == "Skin" )
                clusterToSkin.TryAdd( src, dst );
            else if ( srcObj.NodeType == "Model"
                && dstObj.NodeType == "Deformer" && dstObj.SubClass == "Cluster" )
                clusterJoint.TryAdd( dst, srcObj );
        }

        // Cluster payloads: (geometry, joint model, control-point indexes, weights, bind).
        var jointModels = new List<FbxObject>();
        var jointIndexByModel = new Dictionary<long, int>();
        var bind = new List<Vector3?>();
        var controlPointWeights = new Dictionary<(long Geometry, int ControlPoint), List<(int Joint, float Weight)>>();

        foreach ( var (clusterId, skinId) in clusterToSkin )
        {
            if ( !skinToGeometry.TryGetValue( skinId, out var geometryId )
                || !clusterJoint.TryGetValue( clusterId, out var jointModel )
                || !scene.ObjectsById.TryGetValue( clusterId, out var cluster ) )
                continue;

            if ( !jointIndexByModel.TryGetValue( jointModel.Id, out var joint ) )
            {
                joint = jointModels.Count;
                jointIndexByModel[jointModel.Id] = joint;
                jointModels.Add( jointModel );
                bind.Add( null );
            }

            // Bind = TransformLink translation (falls back to the BindPose table).
            // The first cluster that yields a bind for a joint wins.
            if ( bind[joint] is null )
            {
                var link = cluster.Node.Child( "TransformLink" );
                if ( link is not null && link.AsDoubleArray( 0 ) is { Length: >= 16 } m )
                    bind[joint] = new Vector3( (float)m[12], (float)m[13], (float)m[14] );
                else if ( scene.BindPose.TryGetValue( jointModel.Id, out var poseMatrix ) )
                    bind[joint] = poseMatrix.Translation;
            }

            var indexes = cluster.Node.Child( "Indexes" )?.AsIntArray( 0 );
            var weights = cluster.Node.Child( "Weights" )?.AsDoubleArray( 0 );
            if ( indexes is null || weights is null )
                continue;
            var count = Math.Min( indexes.Length, weights.Length );
            for ( var i = 0; i < count; i++ )
            {
                // Cast first (a huge double overflows to float infinity), then reject
                // NaN/∞ and non-positive weights: a NaN would otherwise poison the
                // normalization total and leave the vertex with all-zero weights.
                var weight = (float)weights[i];
                if ( !float.IsFinite( weight ) || weight <= 0f )
                    continue;
                var key = (geometryId, indexes[i]);
                if ( !controlPointWeights.TryGetValue( key, out var influences ) )
                    controlPointWeights[key] = influences = new List<(int Joint, float Weight)>();
                AddInfluence( influences, joint, weight );
            }
        }

        if ( jointModels.Count == 0 || controlPointWeights.Count == 0 )
            throw new FormatException( $"'{sourceName}' has no skin clusters - it is not a rigged model." );

        var (skeleton, order) = BuildSkeleton( jointModels, jointIndexByModel, bind, sourceName );

        // Per output vertex: control-point weight row → skeleton indices, top-N, normalized.
        var vertexCount = mesh.Positions.Length;
        var boneIndices = new int[vertexCount * MaxInfluences];
        var boneWeights = new float[vertexCount * MaxInfluences];
        for ( var v = 0; v < vertexCount; v++ )
        {
            var slot = v * MaxInfluences;
            if ( controlPointWeights.TryGetValue( vertexSource[v], out var row )
                && TryWriteInfluences( row, order, boneIndices, boneWeights, slot ) )
                continue;

            // No usable weights: rigidly bind to the nearest joint. NearestJoint already
            // returns a skeleton index, so it must NOT be remapped through `order`.
            boneIndices[slot] = NearestJoint( skeleton, mesh.Positions[v] );
            boneWeights[slot] = 1f;
        }

        var donor = new DonorRig
        {
            Skeleton = skeleton,
            Weights = new SkinWeights { BoneIndices = boneIndices, Weights = boneWeights },
            Mesh = mesh,
        };
        donor.Skeleton.Validate();
        donor.Weights.Validate( donor.Mesh, donor.Skeleton );
        return donor;
    }

    /// <summary>
    /// Adds a joint influence to a control point's row. A joint that is already present
    /// has its weight summed, so it occupies at most one influence slot (equivalent
    /// under linear blend skinning).
    /// </summary>
    static void AddInfluence( List<(int Joint, float Weight)> row, int joint, float weight )
    {
        for ( var i = 0; i < row.Count; i++ )
        {
            if ( row[i].Joint == joint )
            {
                row[i] = (joint, row[i].Weight + weight);
                return;
            }
        }
        row.Add( (joint, weight) );
    }

    /// <summary>
    /// Writes the heaviest <see cref="MaxInfluences"/> entries of <paramref name="row"/>
    /// into <paramref name="boneIndices"/>/<paramref name="boneWeights"/> starting at
    /// <paramref name="slot"/>. Cluster-joint indices are mapped to skeleton indices through
    /// <paramref name="order"/>, and weights are normalized to sum to 1.
    /// </summary>
    /// <returns>
    /// False, with nothing written, when the row is empty or its kept weights do not sum
    /// to a finite positive value.
    /// </returns>
    static bool TryWriteInfluences( List<(int Joint, float Weight)> row, int[] order,
        int[] boneIndices, float[] boneWeights, int slot )
    {
        var top = row.OrderByDescending( e => e.Weight ).Take( MaxInfluences ).ToList();
        var total = top.Sum( e => e.Weight );
        if ( top.Count == 0 || !(total > 0f) || !float.IsFinite( total ) )
            return false;

        for ( var k = 0; k < top.Count; k++ )
        {
            boneIndices[slot + k] = order[top[k].Joint];
            boneWeights[slot + k] = top[k].Weight / total;
        }
        return true;
    }

    /// <summary>
    /// Builds the skeleton over the clustered joints:
    /// <list type="bullet">
    /// <item>parent = nearest clustered ancestor in the Model tree;</item>
    /// <item>extra roots are reparented under the first root;</item>
    /// <item>joints are emitted breadth-first, parent before child.</item>
    /// </list>
    /// Returns the cluster-joint → skeleton-index map.
    /// </summary>
    /// <exception cref="FormatException">The Model tree or the clustered-joint hierarchy contains a cycle.</exception>
    static (RigSkeleton Skeleton, int[] Order) BuildSkeleton(
        List<FbxObject> jointModels, Dictionary<long, int> jointIndexByModel,
        List<Vector3?> bind, string sourceName )
    {
        var count = jointModels.Count;

        // Nearest clustered ancestor. The seen-set stops a cyclic ModelParent chain in a
        // malformed file from looping forever.
        var parentLocal = new int[count];
        var seen = new HashSet<long>();
        for ( var i = 0; i < count; i++ )
        {
            parentLocal[i] = -1;
            seen.Clear();
            seen.Add( jointModels[i].Id );
            for ( var walk = jointModels[i].ModelParent; walk is not null; walk = walk.ModelParent )
            {
                if ( !seen.Add( walk.Id ) )
                    throw new FormatException( $"'{sourceName}' has a cyclic Model hierarchy." );
                if ( jointIndexByModel.TryGetValue( walk.Id, out var ancestor ) )
                {
                    parentLocal[i] = ancestor;
                    break;
                }
            }
        }

        var firstRoot = Array.IndexOf( parentLocal, -1 );
        if ( firstRoot < 0 )
            throw new FormatException( $"'{sourceName}' has a cyclic joint hierarchy (no root joint)." );
        for ( var i = firstRoot + 1; i < count; i++ )
            if ( parentLocal[i] == -1 )
                parentLocal[i] = firstRoot;

        // Child lists are filled in ascending local order, so the BFS emits joints in the
        // same order as scanning every joint for each dequeued parent would.
        var children = new List<int>[count];
        for ( var i = 0; i < count; i++ )
            children[i] = new List<int>();
        for ( var i = 0; i < count; i++ )
            if ( parentLocal[i] >= 0 )
                children[parentLocal[i]].Add( i );

        var order = new int[count];
        var skeleton = new RigSkeleton();
        var names = new HashSet<string>( StringComparer.Ordinal );
        var queue = new Queue<int>();
        queue.Enqueue( firstRoot );
        var emitted = 0;
        while ( queue.TryDequeue( out var local ) )
        {
            order[local] = emitted++;

            var baseName = jointModels[local].Name;
            if ( string.IsNullOrWhiteSpace( baseName ) )
                baseName = $"joint_{local}";
            var name = baseName;
            for ( var suffix = 2; !names.Add( name ); suffix++ )
                name = $"{baseName}_{suffix}";

            skeleton.Joints.Add( new RigJoint
            {
                Name = name,
                Parent = local == firstRoot ? -1 : order[parentLocal[local]],
                Position = bind[local] ?? Vector3.Zero,
            } );
            foreach ( var child in children[local] )
                queue.Enqueue( child );
        }

        // Joints caught in a parent cycle never hang off the root. Without this check their
        // `order` entry would stay 0 and their weights would silently collapse onto the root.
        if ( emitted != count )
            throw new FormatException( $"'{sourceName}' has a cyclic joint hierarchy." );
        return (skeleton, order);
    }

    /// <summary>
    /// Index into <paramref name="skeleton"/>.Joints of the joint whose position is closest
    /// to <paramref name="point"/>. Ties go to the lower index; the skeleton must be non-empty.
    /// </summary>
    static int NearestJoint( RigSkeleton skeleton, Vector3 point )
    {
        var best = 0;
        var bestDistance = Vector3.DistanceSquared( skeleton.Joints[0].Position, point );
        for ( var j = 1; j < skeleton.Joints.Count; j++ )
        {
            var distance = Vector3.DistanceSquared( skeleton.Joints[j].Position, point );
            if ( distance < bestDistance )
            {
                best = j;
                bestDistance = distance;
            }
        }
        return best;
    }
}