Code/AutoRig/Solve/Organic/OrganicSkinner.cs

Organic skinning implementation for voxel-based meshes. Builds and caches per-voxel nearest-bone influence lists using geodesic BFS from bone segments, computes per-vertex skin weights from the cache, supports incremental re-skin when joints move, and applies one-pass smoothing over welded vertex neighborhoods.

File AccessNetworking
using AutoRig.Mesh;
using AutoRig.Rig;
using AutoRig.Voxel;

namespace AutoRig.Solve.Organic;

// s&box compat: the engine defines Vector2/Vector3 in the GLOBAL namespace, which
// shadows using-directive imports - alias explicitly to System.Numerics.
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Per-voxel bone-influence lists (the costly half of organic skinning), retained so that
/// moving a joint only re-floods the geodesic fields of the bones it touches (live re-skin).
/// One instance serves a single mesh/grid pair and a fixed joint COUNT; any structural
/// change to the skeleton requires building a fresh cache.
/// </summary>
public sealed class OrganicSkinCache
{
    // Mesh and voxel grid the lists were computed for (assigned by OrganicSkinner.BuildCache).
    internal RigMesh Mesh = null!;
    internal VoxelGrid Grid = null!;

    // Flat per-voxel slot arrays, a fixed number of slots per voxel, kept sorted by
    // ascending geodesic distance; an empty slot holds bone -1 / distance int.MaxValue.
    internal short[] InfluenceBones = [];
    internal int[] InfluenceDistances = [];

    // Joint count of the skeleton the lists were built against.
    internal int BoneCount;

    // Not publicly constructible; OrganicSkinner.BuildCache creates instances.
    internal OrganicSkinCache() { }

    /// <summary>
    /// True when this cache was built for exactly this <paramref name="mesh"/> instance and
    /// <paramref name="skeleton"/> still has the joint count the cache was built with.
    /// </summary>
    public bool Matches( RigMesh mesh, RigSkeleton skeleton )
    {
        if ( !ReferenceEquals( Mesh, mesh ) )
            return false;
        return BoneCount == skeleton.Joints.Count;
    }
}

/// <summary>
/// Smooth skin weights via geodesic voxel distances: each bone seeds the solid voxels
/// nearest to points sampled along its segment and BFS-floods the solid interior; a vertex
/// weights the (up to) 4 geodesically nearest bones with inverse-square falloff, and one
/// Laplacian pass then blends weights across welded mesh edges. Distances travel only
/// through solid voxels, so weights never bleed across air gaps (opposite limbs stay
/// independent).
/// </summary>
public static class OrganicSkinner
{
    /// <summary>
    /// Influence slots kept per voxel in <see cref="OrganicSkinCache"/>. Only the first 4 feed
    /// a vertex; the extra depth is what remains available after <see cref="ReskinBones"/>
    /// evicts entries.
    /// </summary>
    const int InfluencesTracked = 8;

    /// <summary>Influences written per vertex (fixed stride of the SkinWeights arrays).</summary>
    const int InfluencesPerVertex = 4;

    /// <summary>
    /// One-shot skinning: builds a fresh <see cref="OrganicSkinCache"/> and derives weights
    /// from it. Use <see cref="BuildCache"/> + <see cref="ReskinBones"/> for interactive edits.
    /// </summary>
    public static SkinWeights Skin( RigMesh mesh, VoxelGrid grid, RigSkeleton skeleton )
        => SkinFromCache( BuildCache( mesh, grid, skeleton ), skeleton );

    /// <summary>
    /// Full influence-field build: validates the skeleton, allocates <c>InfluencesTracked</c>
    /// empty slots per voxel (bone -1, distance <see cref="int.MaxValue"/>), then inserts
    /// every bone's geodesic BFS field in ascending bone order.
    /// </summary>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    public static OrganicSkinCache BuildCache( RigMesh mesh, VoxelGrid grid, RigSkeleton skeleton )
    {
        ArgumentNullException.ThrowIfNull( mesh );
        ArgumentNullException.ThrowIfNull( grid );
        ArgumentNullException.ThrowIfNull( skeleton );
        skeleton.Validate();

        var slots = VoxelCount( grid ) * InfluencesTracked;
        var cache = new OrganicSkinCache
        {
            Mesh = mesh,
            Grid = grid,
            InfluenceBones = new short[slots],
            InfluenceDistances = new int[slots],
            BoneCount = skeleton.Joints.Count,
        };
        Array.Fill( cache.InfluenceDistances, int.MaxValue );
        Array.Fill( cache.InfluenceBones, (short)-1 );

        for ( var bone = 0; bone < skeleton.Joints.Count; bone++ )
            InsertBoneField( cache, skeleton, bone );
        return cache;
    }

    /// <summary>
    /// Live re-skin after joints moved: only the changed bones' fields are recomputed
    /// (their stale entries are evicted first). Slightly coarser than a full rebuild — an
    /// unchanged bone that was previously pushed past a voxel's tracked depth is not
    /// recovered when a changed bone's entry is evicted — which is fine for interactive
    /// preview; a full <see cref="BuildCache"/> stays exact.
    /// Mutates the cache to the new skeleton, so successive drags keep working.
    /// </summary>
    /// <param name="changedBones">Bone indices whose segments moved (e.g. the result of
    /// <see cref="BonesAffectedByJoint"/>); duplicates are ignored.</param>
    /// <exception cref="FormatException">The skeleton's joint count differs from the cache's.</exception>
    /// <exception cref="ArgumentOutOfRangeException">A changed bone index is outside the skeleton.</exception>
    public static SkinWeights ReskinBones(
        OrganicSkinCache cache, RigSkeleton skeleton, IReadOnlyCollection<int> changedBones )
    {
        ArgumentNullException.ThrowIfNull( cache );
        ArgumentNullException.ThrowIfNull( skeleton );
        ArgumentNullException.ThrowIfNull( changedBones );
        if ( skeleton.Joints.Count != cache.BoneCount )
            throw new FormatException(
                $"Skin cache holds {cache.BoneCount} bones but the skeleton has "
                + $"{skeleton.Joints.Count} - rebuild the cache after structure changes." );

        if ( changedBones.Count > 0 )
        {
            // Collapse to a set before touching the cache: a bone listed twice must be
            // inserted only once, otherwise it would occupy two slots per voxel and have
            // its weight double-counted.
            var changed = new bool[cache.BoneCount];
            foreach ( var bone in changedBones )
            {
                if ( bone < 0 || bone >= cache.BoneCount )
                    throw new ArgumentOutOfRangeException( nameof( changedBones ), bone,
                        $"Bone index must be in [0, {cache.BoneCount})." );
                changed[bone] = true;
            }

            EvictBones( cache, changed );

            // Ascending order, like BuildCache, so ties among re-inserted bones resolve the
            // same way a full build would.
            for ( var bone = 0; bone < changed.Length; bone++ )
                if ( changed[bone] )
                    InsertBoneField( cache, skeleton, bone );
        }
        return SkinFromCache( cache, skeleton );
    }

    /// <summary>Bones whose segments move when a joint moves: its own plus each child's.</summary>
    public static List<int> BonesAffectedByJoint( RigSkeleton skeleton, int joint )
    {
        ArgumentNullException.ThrowIfNull( skeleton );
        var affected = new List<int> { joint };
        for ( var i = 0; i < skeleton.Joints.Count; i++ )
            if ( skeleton.Joints[i].Parent == joint )
                affected.Add( i );
        return affected;
    }

    /// <summary>
    /// Removes every entry whose bone is flagged in <paramref name="evict"/> from each
    /// voxel's list, compacting the survivors to the front (relative order kept) and
    /// resetting the freed tail slots to empty.
    /// </summary>
    static void EvictBones( OrganicSkinCache cache, bool[] evict )
    {
        var bones = cache.InfluenceBones;
        var distances = cache.InfluenceDistances;
        var totalVoxels = VoxelCount( cache.Grid );
        for ( var i = 0; i < totalVoxels; i++ )
        {
            var baseIndex = i * InfluencesTracked;
            var end = baseIndex + InfluencesTracked;
            var write = baseIndex;
            for ( var read = baseIndex; read < end; read++ )
            {
                var bone = bones[read];
                if ( bone < 0 )
                    break; // lists are packed: the first empty slot ends the list
                if ( evict[bone] )
                    continue;
                bones[write] = bone;
                distances[write] = distances[read];
                write++;
            }
            for ( ; write < end; write++ )
            {
                bones[write] = -1;
                distances[write] = int.MaxValue;
            }
        }
    }

    /// <summary>
    /// BFS one bone's segment (parent joint → joint; a root bone degenerates to a point at
    /// its joint) and insert its distances into each reachable voxel's sorted best-N list.
    /// </summary>
    static void InsertBoneField( OrganicSkinCache cache, RigSkeleton skeleton, int bone )
    {
        var joint = skeleton.Joints[bone];
        var start = joint.Parent >= 0 ? skeleton.Joints[joint.Parent].Position : joint.Position;
        var end = joint.Position;
        var bones = cache.InfluenceBones;
        var distances = cache.InfluenceDistances;
        var totalVoxels = VoxelCount( cache.Grid );

        var field = GeodesicFromSegment( cache.Grid, start, end );
        for ( var i = 0; i < totalVoxels; i++ )
        {
            var d = field[i];
            if ( d == int.MaxValue )
                continue; // unreachable from this bone
            // Insert into this voxel's best-N (sorted ascending by distance); strict '<' keeps
            // earlier-inserted bones ahead on ties, and the last slot falls off when full.
            var baseIndex = i * InfluencesTracked;
            for ( var k = 0; k < InfluencesTracked; k++ )
            {
                if ( d < distances[baseIndex + k] )
                {
                    for ( var m = InfluencesTracked - 1; m > k; m-- )
                    {
                        distances[baseIndex + m] = distances[baseIndex + m - 1];
                        bones[baseIndex + m] = bones[baseIndex + m - 1];
                    }
                    distances[baseIndex + k] = d;
                    bones[baseIndex + k] = (short)bone;
                    break;
                }
            }
        }
    }

    /// <summary>
    /// Per-vertex weights from prepared influence lists, followed by one smoothing pass and
    /// a call to <c>SkinWeights.Validate</c>. Each vertex reads the list of its own voxel (or
    /// the nearest solid one) and weights up to 4 bones by <c>1 / (d + 1)^2</c> of the
    /// geodesic hop count <c>d</c>, normalized to sum 1. Vertices with no solid voxel nearby
    /// or an empty list are bound fully to <c>skeleton.Root</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    public static SkinWeights SkinFromCache( OrganicSkinCache cache, RigSkeleton skeleton )
    {
        ArgumentNullException.ThrowIfNull( cache );
        ArgumentNullException.ThrowIfNull( skeleton );

        var mesh = cache.Mesh;
        var grid = cache.Grid;
        var influenceBones = cache.InfluenceBones;
        var influenceDistances = cache.InfluenceDistances;

        // Per-vertex weights from its (nearest solid) voxel's influences.
        var vertexCount = mesh.Positions.Length;
        var boneIndices = new int[vertexCount * InfluencesPerVertex];
        var weights = new float[vertexCount * InfluencesPerVertex];
        var fallbackBone = skeleton.Root;

        Span<float> raw = stackalloc float[InfluencesPerVertex];
        Span<int> bones = stackalloc int[InfluencesPerVertex];
        for ( var v = 0; v < vertexCount; v++ )
        {
            var outBase = v * InfluencesPerVertex;
            var voxel = NearestSolidVoxel( grid, mesh.Positions[v] );
            if ( voxel is null )
            {
                boneIndices[outBase] = fallbackBone;
                weights[outBase] = 1f;
                continue;
            }
            var baseIndex = grid.Index( voxel.Value.X, voxel.Value.Y, voxel.Value.Z ) * InfluencesTracked;

            float total = 0;
            var used = 0;
            for ( var k = 0; k < InfluencesTracked && used < InfluencesPerVertex; k++ )
            {
                var bone = influenceBones[baseIndex + k];
                if ( bone < 0 )
                    break;
                var d = influenceDistances[baseIndex + k];
                var w = 1f / ((d + 1f) * (d + 1f)); // +1 keeps seed voxels (d = 0) finite
                raw[used] = w;
                bones[used] = bone;
                total += w;
                used++;
            }

            if ( used == 0 || total <= 0f )
            {
                boneIndices[outBase] = fallbackBone;
                weights[outBase] = 1f;
                continue;
            }
            for ( var k = 0; k < used; k++ )
            {
                boneIndices[outBase + k] = bones[k];
                weights[outBase + k] = raw[k] / total;
            }
        }

        var result = new SkinWeights { BoneIndices = boneIndices, Weights = weights };
        Smooth( mesh, result );
        result.Validate( mesh, skeleton );
        return result;
    }

    /// <summary>
    /// Geodesic hop counts from one bone segment. The segment is sampled every half cell;
    /// the nearest solid voxel to each sample (see <see cref="NearestSolidVoxel"/>) is a seed
    /// at distance 0, and a 6-connected unit-cost BFS then floods through solid voxels only.
    /// Voxels that are air or unreachable keep <see cref="int.MaxValue"/>.
    /// </summary>
    static int[] GeodesicFromSegment( VoxelGrid grid, Vector3 start, Vector3 end )
    {
        var distance = new int[VoxelCount( grid )];
        Array.Fill( distance, int.MaxValue );
        var queue = new Queue<(int X, int Y, int Z)>();

        var length = Vector3.Distance( start, end );
        var steps = Math.Max( 1, (int)(length / (grid.CellSize * 0.5f)) );
        for ( var s = 0; s <= steps; s++ )
        {
            var p = Vector3.Lerp( start, end, (float)s / steps );
            var seed = NearestSolidVoxel( grid, p );
            if ( seed is null )
                continue;
            var i = grid.Index( seed.Value.X, seed.Value.Y, seed.Value.Z );
            if ( distance[i] != 0 ) // several samples can share a seed voxel
            {
                distance[i] = 0;
                queue.Enqueue( seed.Value );
            }
        }

        while ( queue.Count > 0 )
        {
            var (x, y, z) = queue.Dequeue();
            var next = distance[grid.Index( x, y, z )] + 1;
            Visit( x - 1, y, z );
            Visit( x + 1, y, z );
            Visit( x, y - 1, z );
            Visit( x, y + 1, z );
            Visit( x, y, z - 1 );
            Visit( x, y, z + 1 );

            void Visit( int nx, int ny, int nz )
            {
                if ( !InBounds( grid, nx, ny, nz ) || !grid.IsSolid( nx, ny, nz ) )
                    return;
                var i = grid.Index( nx, ny, nz );
                if ( distance[i] <= next )
                    return;
                distance[i] = next;
                queue.Enqueue( (nx, ny, nz) );
            }
        }
        return distance;
    }

    /// <summary>
    /// The voxel containing <paramref name="world"/> if it is inside the grid and solid;
    /// otherwise searches Chebyshev shells of radius 1..4 around it and returns the in-grid
    /// solid voxel whose center is closest to <paramref name="world"/> on the first shell that
    /// has any (an approximation of the true nearest). Null when nothing is found.
    /// </summary>
    static (int X, int Y, int Z)? NearestSolidVoxel( VoxelGrid grid, Vector3 world )
    {
        var (x, y, z) = grid.VoxelOf( world );
        // Points outside the grid (e.g. dragged joints) may map to out-of-range cells, so
        // bounds-check before IsSolid, exactly as the shell search does.
        if ( InBounds( grid, x, y, z ) && grid.IsSolid( x, y, z ) )
            return (x, y, z);
        for ( var radius = 1; radius <= 4; radius++ )
        {
            (int X, int Y, int Z)? best = null;
            var bestDistance = float.MaxValue;
            for ( var dz = -radius; dz <= radius; dz++ )
                for ( var dy = -radius; dy <= radius; dy++ )
                    for ( var dx = -radius; dx <= radius; dx++ )
                    {
                        if ( Math.Max( Math.Abs( dx ), Math.Max( Math.Abs( dy ), Math.Abs( dz ) ) ) != radius )
                            continue; // shell only
                        var nx = x + dx; var ny = y + dy; var nz = z + dz;
                        if ( !InBounds( grid, nx, ny, nz ) || !grid.IsSolid( nx, ny, nz ) )
                            continue;
                        var d = Vector3.DistanceSquared( world, grid.CenterOf( nx, ny, nz ) );
                        if ( d < bestDistance )
                        {
                            bestDistance = d;
                            best = (nx, ny, nz);
                        }
                    }
            if ( best is not null )
                return best;
        }
        return null;
    }

    /// <summary>
    /// One Laplacian pass over welded-edge neighbors: each welded vertex keeps its own weights
    /// at full strength plus its neighbors' weights scaled by 1/neighborCount; the 4 strongest
    /// bones are kept and re-normalized. All blends are computed from the pre-smoothing
    /// weights before any are written. Vertices in no non-degenerate triangle are unchanged.
    /// </summary>
    static void Smooth( RigMesh mesh, SkinWeights skin )
    {
        var weld = Analyze.MeshSegmenter.WeldMap( mesh, 0f );

        // Undirected adjacency between welded ids. Ids are indexed as vertices below, so
        // WeldMap presumably maps each vertex to a representative vertex index (verify).
        var neighbors = new Dictionary<int, HashSet<int>>();
        void Link( int a, int b )
        {
            if ( a == b )
                return; // triangle collapsed by welding: a vertex is not its own neighbor
            AddEdge( a, b );
            AddEdge( b, a );
        }
        void AddEdge( int from, int to )
        {
            if ( !neighbors.TryGetValue( from, out var set ) )
                neighbors[from] = set = new HashSet<int>();
            set.Add( to );
        }
        for ( var t = 0; t < mesh.Triangles.Length; t += 3 )
        {
            var a = weld[mesh.Triangles[t]];
            var b = weld[mesh.Triangles[t + 1]];
            var c = weld[mesh.Triangles[t + 2]];
            Link( a, b ); Link( b, c ); Link( a, c );
        }

        // Blend + pick the top 4 once per welded id (all welded copies share the result).
        var blended = new Dictionary<int, (List<KeyValuePair<int, float>> Top, float Total)>();
        foreach ( var (id, set) in neighbors )
        {
            var accumulator = new Dictionary<int, float>();
            AccumulateWeights( skin, id, 1f, accumulator );
            var neighborScale = 1f / set.Count;
            foreach ( var n in set )
                AccumulateWeights( skin, n, neighborScale, accumulator );
            if ( accumulator.Count == 0 )
                continue;
            var top = accumulator.OrderByDescending( kv => kv.Value ).Take( InfluencesPerVertex ).ToList();
            var total = top.Sum( kv => kv.Value );
            if ( total > 0f )
                blended[id] = (top, total);
        }

        for ( var v = 0; v < mesh.Positions.Length; v++ )
        {
            if ( !blended.TryGetValue( weld[v], out var entry ) )
                continue;
            var (top, total) = entry;
            for ( var k = 0; k < InfluencesPerVertex; k++ )
            {
                var slot = v * InfluencesPerVertex + k;
                if ( k < top.Count )
                {
                    skin.BoneIndices[slot] = top[k].Key;
                    skin.Weights[slot] = top[k].Value / total;
                }
                else
                {
                    skin.BoneIndices[slot] = 0;
                    skin.Weights[slot] = 0f;
                }
            }
        }
    }

    /// <summary>Adds <paramref name="scale"/> × each positive weight of <paramref name="vertex"/>
    /// into <paramref name="accumulator"/>, keyed by bone index.</summary>
    static void AccumulateWeights( SkinWeights skin, int vertex, float scale, Dictionary<int, float> accumulator )
    {
        for ( var k = 0; k < InfluencesPerVertex; k++ )
        {
            var slot = vertex * InfluencesPerVertex + k;
            var w = skin.Weights[slot];
            if ( w <= 0f )
                continue; // unused slot
            var bone = skin.BoneIndices[slot];
            accumulator[bone] = accumulator.GetValueOrDefault( bone ) + w * scale;
        }
    }

    /// <summary>Total cell count of <paramref name="grid"/>.</summary>
    static int VoxelCount( VoxelGrid grid ) => grid.SizeX * grid.SizeY * grid.SizeZ;

    /// <summary>True when (x, y, z) addresses a cell inside <paramref name="grid"/>.</summary>
    static bool InBounds( VoxelGrid grid, int x, int y, int z )
        => x >= 0 && x < grid.SizeX && y >= 0 && y < grid.SizeY && z >= 0 && z < grid.SizeZ;
}