AutoRig/Formats/Gltf/GltfDonorImporter.cs

Importer for glTF/GLB files that builds a DonorRig. It parses a GltfDocument, extracts the first skin, computes world transforms, builds a RigSkeleton and collects skinned mesh data (positions, normals, triangles, bone indices/weights) into a RigMesh and SkinWeights, validating the results.

File Access
using System.Numerics;
using AutoRig.Mesh;
using AutoRig.Rig;

namespace AutoRig.Formats.Gltf;

// s&box compat: the engine defines Vector2/Vector3 in the GLOBAL namespace, which
// shadows using-directive imports - alias explicitly to System.Numerics.
using Vector2 = System.Numerics.Vector2;
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Loads a rigged glTF/GLB as a <see cref="DonorRig"/>. Positions come from skin
/// space (vertices as stored, joints from inverted inverseBindMatrices — the pair
/// the spec guarantees consistent), then BOTH are baked by the skinned node's
/// world transform so the donor lines up with what the static importer produces
/// for the same file (e.g. CesiumMan's Z-up→Y-up root rotation).
/// Non-skinned geometry is ignored; only the first skin is used.
/// </summary>
public static class GltfDonorImporter
{
    /// <summary>
    /// Parses <paramref name="data"/> as glTF/GLB and builds a donor rig from the
    /// document's first skin and every mesh node bound to that skin.
    /// </summary>
    /// <param name="data">Raw file bytes, handed to <c>GltfDocument.Parse</c>.</param>
    /// <param name="sourceName">
    /// Name used in error messages and stored as the resulting mesh's <c>SourceName</c>.
    /// </param>
    /// <returns>
    /// The skeleton, the merged mesh and four bone influences per vertex. Mesh,
    /// skeleton and weights are all run through their <c>Validate</c> methods
    /// before the rig is returned.
    /// </returns>
    /// <exception cref="FormatException">Malformed file or no skinned geometry.</exception>
    public static DonorRig Import( byte[] data, string sourceName )
    {
        var document = GltfDocument.Parse( data );
        if ( document.Skins.Count == 0 )
            throw new FormatException( $"'{sourceName}' has no skin - it is not a rigged model." );
        var skin = document.Skins[0];

        // World transform of the first mesh node that uses this skin (shared by the
        // whole donor - multi-skin/multi-space files use the first's). The node is
        // picked with the same test as the geometry loop below, so the bake
        // transform always belongs to a node whose geometry is actually imported.
        var world = Matrix4x4.Identity;
        for ( var nodeIndex = 0; nodeIndex < document.Nodes.Count; nodeIndex++ )
        {
            if ( !IsSkinnedMeshNode( document, nodeIndex ) )
                continue;
            world = NodeWorld( document, nodeIndex, sourceName );
            break;
        }

        // Normals are transformed by the inverse-transpose of `world`. A singular
        // world matrix has no inverse; every normal then falls back to +Z.
        var hasNormalMatrix = Matrix4x4.Invert( world, out var worldInverse );
        var normalMatrix = Matrix4x4.Transpose( worldInverse );

        var (skeleton, jointRemap) = BuildSkeleton( document, skin, world, sourceName );

        // Gather skinned primitives (nodes using skin 0).
        var positions = new List<Vector3>();
        var normals = new List<Vector3>();
        var triangles = new List<int>();
        var triangleTags = new List<int>();
        var tags = new List<string>();
        var boneIndices = new List<int>();
        var boneWeights = new List<float>();

        for ( var nodeIndex = 0; nodeIndex < document.Nodes.Count; nodeIndex++ )
        {
            if ( !IsSkinnedMeshNode( document, nodeIndex ) )
                continue;

            var node = document.Nodes[nodeIndex];
            var mesh = document.Meshes[node.Mesh];

            // One tag per distinct node/mesh name; triangles remember which tag they came from.
            var tagName = node.Name ?? mesh.Name ?? $"node{nodeIndex}";
            var tag = tags.IndexOf( tagName );
            if ( tag < 0 ) { tag = tags.Count; tags.Add( tagName ); }

            foreach ( var primitive in mesh.Primitives )
            {
                var baseVertex = positions.Count;
                var vertexCount = primitive.Positions.Length / 3;

                // Fail fast on an index buffer that cannot be split into triangles.
                if ( primitive.Indices.Length % 3 != 0 )
                    throw new FormatException(
                        $"'{sourceName}': a primitive has {primitive.Indices.Length} indices, "
                        + "which is not a whole number of triangles." );

                for ( var v = 0; v < vertexCount; v++ )
                {
                    positions.Add( Vector3.Transform( new Vector3(
                        primitive.Positions[v * 3],
                        primitive.Positions[v * 3 + 1],
                        primitive.Positions[v * 3 + 2] ), world ) );

                    // Missing, untransformable or degenerate normals all become +Z.
                    var normal = Vector3.UnitZ;
                    if ( hasNormalMatrix && primitive.Normals is { } n )
                    {
                        var transformed = Vector3.TransformNormal(
                            new Vector3( n[v * 3], n[v * 3 + 1], n[v * 3 + 2] ), normalMatrix );
                        var length = transformed.Length();
                        if ( length > 1e-8f )
                            normal = transformed / length;
                    }
                    normals.Add( normal );

                    if ( primitive.Joints is { } joints && primitive.Weights is { } weights )
                    {
                        // Renormalise the row so the four weights sum to 1; a row with
                        // (near-)zero total weight is bound fully to its first joint.
                        float sum = 0;
                        for ( var k = 0; k < 4; k++ )
                            sum += weights[v * 4 + k];
                        for ( var k = 0; k < 4; k++ )
                        {
                            var skinLocal = joints[v * 4 + k];
                            if ( skinLocal < 0 || skinLocal >= jointRemap.Length )
                                throw new FormatException(
                                    $"'{sourceName}': JOINTS_0 index {skinLocal} exceeds the skin's "
                                    + $"{jointRemap.Length} joints." );
                            boneIndices.Add( jointRemap[skinLocal] );
                            boneWeights.Add( sum > 1e-6f ? weights[v * 4 + k] / sum : (k == 0 ? 1f : 0f) );
                        }
                    }
                    else
                    {
                        // Skinned node but this primitive lacks weights: rigid-bind
                        // to the nearest joint so the row still sums to 1.
                        boneIndices.Add( NearestJoint( skeleton, positions[^1] ) );
                        boneWeights.Add( 1f );
                        for ( var k = 1; k < 4; k++ )
                        {
                            boneIndices.Add( 0 );
                            boneWeights.Add( 0f );
                        }
                    }
                }

                for ( var i = 0; i < primitive.Indices.Length; i += 3 )
                {
                    for ( var corner = 0; corner < 3; corner++ )
                    {
                        // Indices are primitive-local. Once offset by baseVertex an
                        // out-of-range value could land inside another primitive's
                        // vertices, so it has to be rejected here.
                        int index = primitive.Indices[i + corner];
                        if ( index < 0 || index >= vertexCount )
                            throw new FormatException(
                                $"'{sourceName}': vertex index {index} exceeds the primitive's "
                                + $"{vertexCount} vertices." );
                        triangles.Add( baseVertex + index );
                    }
                    triangleTags.Add( tag );
                }
            }
        }

        if ( triangles.Count == 0 )
            throw new FormatException(
                $"'{sourceName}' has no skin-bound triangle geometry - it is not a rigged model." );

        var rigMesh = new RigMesh
        {
            SourceName = sourceName,
            Positions = positions.ToArray(),
            Normals = normals.ToArray(),
            // UVs are not read from the file; the array is left zero-initialised.
            Uvs = new Vector2[positions.Count],
            Triangles = triangles.ToArray(),
            TriangleTags = triangleTags.ToArray(),
            Tags = tags.ToArray(),
        };
        rigMesh.Validate();

        var donor = new DonorRig
        {
            Skeleton = skeleton,
            Weights = new SkinWeights
            {
                BoneIndices = boneIndices.ToArray(),
                Weights = boneWeights.ToArray(),
            },
            Mesh = rigMesh,
        };
        donor.Skeleton.Validate();
        donor.Weights.Validate( donor.Mesh, donor.Skeleton );
        return donor;
    }

    /// <summary>
    /// True when the node is bound to skin 0 and references a mesh that exists in
    /// the document. Used both to pick the bake transform and to gather geometry.
    /// </summary>
    static bool IsSkinnedMeshNode( GltfDocument document, int nodeIndex )
    {
        var node = document.Nodes[nodeIndex];
        return node.Skin == 0 && node.Mesh >= 0 && node.Mesh < document.Meshes.Count;
    }

    /// <summary>
    /// Skeleton from the skin's joint nodes, parent-before-child; bind position =
    /// translation of the inverted inverse-bind matrix (falls back to the node's
    /// world transform when the skin has none or the matrix cannot be inverted).
    /// Duplicate names get numeric suffixes (<c>_2</c>, <c>_3</c>, ...).
    /// </summary>
    /// <param name="document">Parsed document that owns the nodes.</param>
    /// <param name="skin">Skin whose joints become the skeleton.</param>
    /// <param name="world">Bake transform applied to skin-space bind positions.</param>
    /// <param name="sourceName">File name used in error messages.</param>
    /// <returns>
    /// The skeleton and a remap table from skin-local joint index to skeleton index.
    /// </returns>
    /// <exception cref="FormatException">
    /// The skin has no joints, references a node that does not exist, or its joint
    /// hierarchy is cyclic.
    /// </exception>
    static (RigSkeleton Skeleton, int[] Remap) BuildSkeleton(
        GltfDocument document, GltfSkin skin, Matrix4x4 world, string sourceName )
    {
        var jointCount = skin.Joints.Length;
        if ( jointCount == 0 )
            throw new FormatException( $"'{sourceName}': the skin has no joints." );

        var jointSet = new Dictionary<int, int>();   // node index → skin-local index
        for ( var i = 0; i < jointCount; i++ )
        {
            int jointNode = skin.Joints[i];
            if ( jointNode < 0 || jointNode >= document.Nodes.Count )
                throw new FormatException(
                    $"'{sourceName}': skin joint {i} references node {jointNode}, but the file has "
                    + $"{document.Nodes.Count} nodes." );
            jointSet[jointNode] = i;
        }

        // Bind position per skin-local joint (skin space → baked by `world`).
        var bind = new Vector3[jointCount];
        for ( var i = 0; i < jointCount; i++ )
        {
            if ( skin.InverseBindMatrices is { } ibm && Matrix4x4.Invert( ibm[i], out var jointWorld ) )
                bind[i] = Vector3.Transform( jointWorld.Translation, world );
            else
                // No IBM ⇒ bind = the joint's own scene-global transform (already
                // scene space - do NOT bake `world` again).
                bind[i] = NodeWorld( document, skin.Joints[i], sourceName ).Translation;
        }

        // Parent = nearest ancestor node that is also a joint of this skin.
        var parentLocal = new int[jointCount];
        for ( var i = 0; i < jointCount; i++ )
        {
            parentLocal[i] = -1;
            var walk = document.Nodes[skin.Joints[i]].Parent;

            // An acyclic chain visits each node at most once, so more steps than
            // there are nodes means the parent links loop.
            var remaining = document.Nodes.Count;
            while ( walk >= 0 )
            {
                if ( remaining-- <= 0 )
                    throw new FormatException( $"'{sourceName}': the node hierarchy contains a parent cycle." );
                if ( jointSet.TryGetValue( walk, out var local ) )
                {
                    parentLocal[i] = local;
                    break;
                }
                walk = document.Nodes[walk].Parent;
            }
        }

        // Multiple roots are legal in glTF; ours requires one - reparent extras
        // under the first root.
        var firstRoot = Array.IndexOf( parentLocal, -1 );
        if ( firstRoot < 0 )
            throw new FormatException(
                $"'{sourceName}': the skin's joints have no root - the joint hierarchy is cyclic." );
        for ( var i = firstRoot + 1; i < parentLocal.Length; i++ )
            if ( parentLocal[i] == -1 )
                parentLocal[i] = firstRoot;

        // Emit parent-before-child (breadth-first from the root).
        var order = new int[jointCount];     // skin-local → skeleton index
        var skeleton = new RigSkeleton();
        var names = new HashSet<string>( StringComparer.Ordinal );
        var queue = new Queue<int>();
        queue.Enqueue( firstRoot );
        var emitted = 0;
        var visited = new bool[jointCount];
        visited[firstRoot] = true;
        while ( queue.Count > 0 )
        {
            var local = queue.Dequeue();
            order[local] = emitted++;

            var baseName = document.Nodes[skin.Joints[local]].Name;
            if ( string.IsNullOrWhiteSpace( baseName ) )
                baseName = $"joint_{local}";
            var name = baseName;
            for ( var suffix = 2; !names.Add( name ); suffix++ )
                name = $"{baseName}_{suffix}";

            skeleton.Joints.Add( new RigJoint
            {
                Name = name,
                Parent = local == firstRoot ? -1 : order[parentLocal[local]],
                Position = bind[local],
            } );
            for ( var c = 0; c < parentLocal.Length; c++ )
                if ( !visited[c] && parentLocal[c] == local )
                {
                    visited[c] = true;
                    queue.Enqueue( c );
                }
        }

        // A joint the walk never reached would keep remap entry 0 and silently
        // rebind its weights to the root; that only happens with cyclic parents.
        if ( emitted != jointCount )
            throw new FormatException(
                $"'{sourceName}': {jointCount - emitted} skin joint(s) are not reachable from the root - "
                + "the joint hierarchy is cyclic." );

        return (skeleton, order);
    }

    /// <summary>
    /// Scene-global transform of a node: its local matrix multiplied by each
    /// ancestor's local matrix in turn, child first.
    /// </summary>
    /// <exception cref="FormatException">The parent links form a cycle.</exception>
    static Matrix4x4 NodeWorld( GltfDocument document, int nodeIndex, string sourceName )
    {
        var world = Matrix4x4.Identity;
        var walk = nodeIndex;

        // Bound the walk by the node count so a cyclic Parent chain cannot hang the import.
        var remaining = document.Nodes.Count;
        while ( walk >= 0 )
        {
            if ( remaining-- <= 0 )
                throw new FormatException( $"'{sourceName}': the node hierarchy contains a parent cycle." );
            world *= document.Nodes[walk].LocalMatrix;
            walk = document.Nodes[walk].Parent;
        }
        return world;
    }

    /// <summary>
    /// Index of the skeleton joint whose position is closest to <paramref name="point"/>.
    /// Ties keep the lower index; returns 0 when the skeleton has at most one joint.
    /// </summary>
    static int NearestJoint( RigSkeleton skeleton, Vector3 point )
    {
        var best = 0;
        for ( var j = 1; j < skeleton.Joints.Count; j++ )
            if ( Vector3.DistanceSquared( skeleton.Joints[j].Position, point )
                < Vector3.DistanceSquared( skeleton.Joints[best].Position, point ) )
                best = j;
        return best;
    }
}