Code/AutoRig/Dl/RigNet/RigNetInput.cs

Utilities to build RigNet inputs from a RigMesh and to compute a sampled surface geodesic. Provides normalization, vertex-clustering decimation to a proxy mesh, one-ring topology edges, geodesic neighbor edges, and a SurfaceGeodesic builder. The builder samples the mesh surface, builds a 5-nearest-neighbor sample graph filtered by normal agreement, runs all-pairs Dijkstra with a small inline binary heap, and maps each vertex to its nearest sample.

Native Interop
using AutoRig.Mesh;

namespace AutoRig.Dl.RigNet;

using Vector2 = System.Numerics.Vector2;
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// Builds RigNet's network inputs from a RigMesh, following quick_start's
/// preparation: normalize so the longest bounding-box side is 1, reduce to a proxy
/// mesh of a few thousand vertices (vertex clustering instead of quadric decimation —
/// weights transfer back by nearest proxy vertex, exactly like the original's remesh
/// flow), one-ring topology edges and geodesic-ball edges over a sampled surface
/// geodesic. Randomized reference steps (poisson sampling, random ball picks) are
/// replaced with deterministic equivalents.
/// </summary>
public static class RigNetInput
{
    /// <summary>
    /// normalize_obj: translate so the pivot (center x, min y, center z) sits at the
    /// origin, then scale by 1 / longest bounding-box side.
    /// </summary>
    /// <param name="positions">Source vertex positions.</param>
    /// <returns>
    /// The normalized copy plus the pivot and scale needed to undo it
    /// (<c>original = normalized / Scale + Pivot</c>). A cloud with no extent
    /// (single point or all coincident) keeps a scale of 1 instead of dividing by zero.
    /// </returns>
    public static (Vector3[] Positions, Vector3 Pivot, float Scale) Normalize( Vector3[] positions )
    {
        ArgumentNullException.ThrowIfNull( positions );
        var bounds = Aabb3.FromPoints( positions );
        var size = bounds.Size;
        var longest = MathF.Max( size.X, MathF.Max( size.Y, size.Z ) );
        // 1/0 would be +inf, and (0 * inf) turns every coordinate into NaN.
        var scale = longest > 0f ? 1f / longest : 1f;
        var pivot = new Vector3(
            (bounds.Min.X + bounds.Max.X) * 0.5f, bounds.Min.Y, (bounds.Min.Z + bounds.Max.Z) * 0.5f );
        var result = new Vector3[positions.Length];
        for ( var i = 0; i < positions.Length; i++ )
            result[i] = (positions[i] - pivot) * scale;
        return (result, pivot, scale);
    }

    /// <summary>
    /// Vertex-clustering decimation. Vertices are bucketed on a uniform grid that is
    /// coarsened (×0.8 per step) until the cluster count is at or under
    /// <paramref name="targetVertices"/>, or until the grid is at most 4 cells along the
    /// longest side, whichever comes first; the budget is therefore best-effort. Each
    /// cluster becomes the average of its members. Triangles that collapse, or that
    /// repeat an already-kept triangle with the same winding, are dropped.
    /// </summary>
    /// <param name="mesh">Mesh to reduce.</param>
    /// <param name="targetVertices">Desired maximum proxy vertex count.</param>
    /// <returns>
    /// The proxy plus each original vertex's proxy index (for weight transfer back).
    /// When the mesh already fits the budget, the input mesh itself is returned with an
    /// identity map.
    /// </returns>
    /// <remarks>
    /// Proxy normals and UVs are zero-filled. NOTE(review): confirm callers recompute
    /// proxy normals before using them as network features.
    /// </remarks>
    public static (RigMesh Proxy, int[] VertexMap) Decimate( RigMesh mesh, int targetVertices )
    {
        ArgumentNullException.ThrowIfNull( mesh );
        if ( mesh.Positions.Length <= targetVertices )
            return (mesh, Enumerable.Range( 0, mesh.Positions.Length ).ToArray());

        var (map, proxyPositions) = ClusterVertices( mesh, targetVertices );
        var triangles = RemapTriangles( mesh, map );

        var proxy = new RigMesh
        {
            SourceName = mesh.SourceName,
            Positions = proxyPositions,
            Normals = new Vector3[proxyPositions.Length],
            Uvs = new Vector2[proxyPositions.Length],
            Triangles = triangles,
            TriangleTags = new int[triangles.Length / 3],
            Tags = mesh.Tags,
        };
        return (proxy, map);
    }

    /// <summary>Grid-clusters the mesh vertices, coarsening until the budget or the 4-cell floor is reached.</summary>
    static (int[] Map, Vector3[] Centroids) ClusterVertices( RigMesh mesh, int targetVertices )
    {
        var positions = mesh.Positions;
        var bounds = mesh.ComputeBounds();
        var size = bounds.Size;
        var longest = MathF.Max( size.X, MathF.Max( size.Y, size.Z ) );

        var resolution = (int)MathF.Ceiling( MathF.Cbrt( targetVertices ) ) * 2;
        while ( true )
        {
            // All-coincident vertices: any positive cell size puts them in one bucket
            // and avoids the 0/0 = NaN grid coordinates.
            var cell = longest > 0f ? longest / resolution : 1f;
            var clusters = new Dictionary<(int, int, int), int>();
            var map = new int[positions.Length];
            var sums = new List<Vector3>();
            var counts = new List<int>();
            for ( var v = 0; v < positions.Length; v++ )
            {
                var p = (positions[v] - bounds.Min) / cell;
                var key = ((int)p.X, (int)p.Y, (int)p.Z);
                if ( !clusters.TryGetValue( key, out var id ) )
                {
                    id = sums.Count;
                    clusters[key] = id;
                    sums.Add( Vector3.Zero );
                    counts.Add( 0 );
                }
                map[v] = id;
                sums[id] += positions[v];
                counts[id]++;
            }

            if ( sums.Count <= targetVertices || resolution <= 4 )
            {
                var centroids = new Vector3[sums.Count];
                for ( var i = 0; i < sums.Count; i++ )
                    centroids[i] = sums[i] / counts[i];
                return (map, centroids);
            }
            resolution = (int)(resolution * 0.8f);
        }
    }

    /// <summary>Rewrites triangles onto cluster ids, dropping collapsed and same-winding duplicate faces.</summary>
    static int[] RemapTriangles( RigMesh mesh, int[] map )
    {
        var triangles = new List<int>();
        var seen = new HashSet<(int, int, int)>();
        for ( var t = 0; t < mesh.TriangleCount; t++ )
        {
            int a = map[mesh.Triangles[t * 3]], b = map[mesh.Triangles[t * 3 + 1]],
                c = map[mesh.Triangles[t * 3 + 2]];
            if ( a == b || b == c || a == c )
                continue;

            // Rotate so the smallest id leads. Rotation keeps the winding, so the key
            // matches only same-facing repeats. Opposite-facing sheets survive.
            var key = a < b && a < c ? (a, b, c) : b < c ? (b, c, a) : (c, a, b);
            if ( !seen.Add( key ) )
                continue;

            triangles.Add( a );
            triangles.Add( b );
            triangles.Add( c );
        }
        return triangles.ToArray();
    }

    /// <summary>
    /// get_tpl_edges: one-ring mesh edges in both directions, sorted by source vertex
    /// then by neighbor, without self-loops. Degenerate triangles that repeat an index
    /// contribute only their distinct pairs.
    /// </summary>
    /// <param name="vertexCount">Number of vertices referenced by <paramref name="triangles"/>.</param>
    /// <param name="triangles">Flat triangle index buffer (3 indices per triangle).</param>
    /// <exception cref="ArgumentException">The index count is not a multiple of 3.</exception>
    public static (int From, int To)[] TopologyEdges( int vertexCount, int[] triangles )
    {
        ArgumentNullException.ThrowIfNull( triangles );
        if ( triangles.Length % 3 != 0 )
            throw new ArgumentException( "Triangle index count must be a multiple of 3.", nameof( triangles ) );

        var neighbors = new HashSet<int>[vertexCount];
        for ( var v = 0; v < vertexCount; v++ )
            neighbors[v] = new HashSet<int>();
        for ( var t = 0; t < triangles.Length; t += 3 )
        {
            int a = triangles[t], b = triangles[t + 1], c = triangles[t + 2];
            Link( neighbors, a, b );
            Link( neighbors, a, c );
            Link( neighbors, b, c );
        }

        var edges = new List<(int, int)>();
        for ( var v = 0; v < vertexCount; v++ )
            foreach ( var n in neighbors[v].OrderBy( n => n ) )
                edges.Add( (v, n) );
        return edges.ToArray();
    }

    /// <summary>Records an undirected edge, skipping the self-loop a repeated triangle index would create.</summary>
    static void Link( HashSet<int>[] neighbors, int x, int y )
    {
        if ( x == y )
            return;
        neighbors[x].Add( y );
        neighbors[y].Add( x );
    }

    /// <summary>
    /// get_geo_edges: for each vertex, edges (from vertex, to neighbor) to other
    /// vertices whose surface-geodesic distance is at most <paramref name="radius"/>.
    /// At most <paramref name="maxNeighbors"/> edges are kept per vertex. When the ball
    /// is larger, an evenly strided subset in ascending vertex order is used instead of
    /// the original's random sample.
    /// </summary>
    /// <param name="geodesic">Distance oracle built over the same vertex set.</param>
    /// <param name="vertexCount">Number of vertices to connect.</param>
    /// <param name="radius">Geodesic ball radius in normalized units (reference: 0.06).</param>
    /// <param name="maxNeighbors">Per-vertex edge cap (reference: 10).</param>
    public static (int From, int To)[] GeodesicEdges(
        SurfaceGeodesic geodesic, int vertexCount, float radius = 0.06f, int maxNeighbors = 10 )
    {
        ArgumentNullException.ThrowIfNull( geodesic );
        if ( maxNeighbors < 0 )
            throw new ArgumentOutOfRangeException( nameof( maxNeighbors ), maxNeighbors, "Neighbor cap cannot be negative." );

        var edges = new List<(int, int)>();
        var ball = new List<int>();
        for ( var v = 0; v < vertexCount; v++ )
        {
            ball.Clear();
            for ( var u = 0; u < vertexCount; u++ )
                if ( u != v && geodesic.Distance( v, u ) <= radius )
                    ball.Add( u );

            if ( ball.Count <= maxNeighbors )
            {
                foreach ( var u in ball )
                    edges.Add( (v, u) );
                continue;
            }

            // Deterministic stand-in for the reference's random pick. The long math
            // keeps i * Count from overflowing for large caps.
            for ( var i = 0; i < maxNeighbors; i++ )
                edges.Add( (v, ball[(int)((long)i * ball.Count / maxNeighbors)]) );
        }
        return edges.ToArray();
    }
}

/// <summary>
/// Sampled surface geodesic distances (calc_surface_geodesic): deterministic
/// area-weighted surface samples, a 5-NN graph filtered by normal agreement
/// (cos &gt; -0.5), all-pairs Dijkstra, vertices mapped to their nearest sample.
/// Disconnected pairs fall back to 8 + euclidean, like the reference.
/// </summary>
public sealed class SurfaceGeodesic
{
    /// <summary>Nearest other samples considered per sample when building the graph.</summary>
    const int NeighborCount = 5;

    /// <summary>A neighbor edge is kept only when the normals' cosine exceeds this (normals do not oppose).</summary>
    const float MinNormalCos = -0.5f;

    /// <summary>Offset added to the euclidean distance for sample pairs with no graph path.</summary>
    const float DisconnectedPenalty = 8f;

    readonly float[] _sampleDistances;   // row-major [S*S]; +inf where no graph path exists
    readonly int[] _vertexSample;        // vertex → nearest sample id
    readonly Vector3[] _samplePositions;
    readonly int _sampleCount;

    /// <summary>
    /// Approximate geodesic distance between two vertices of the mesh given to
    /// <see cref="Build"/>, measured between their nearest samples. Pairs whose samples
    /// are not connected in the graph get 8 + the euclidean distance between those samples.
    /// </summary>
    public float Distance( int v0, int v1 )
    {
        int s0 = _vertexSample[v0], s1 = _vertexSample[v1];
        var d = _sampleDistances[s0 * _sampleCount + s1];
        return float.IsInfinity( d )
            ? DisconnectedPenalty + Vector3.Distance( _samplePositions[s0], _samplePositions[s1] )
            : d;
    }

    SurfaceGeodesic( float[] sampleDistances, int[] vertexSample, Vector3[] samplePositions )
    {
        _sampleDistances = sampleDistances;
        _vertexSample = vertexSample;
        _samplePositions = samplePositions;
        _sampleCount = samplePositions.Length;
    }

    /// <summary>
    /// Samples <paramref name="mesh"/>'s surface, connects samples into a normal-filtered
    /// 5-NN graph, runs Dijkstra from every sample, and maps each mesh vertex to its
    /// nearest sample.
    /// </summary>
    /// <param name="mesh">Mesh whose vertices will be queried through <see cref="Distance"/>.</param>
    /// <param name="sampleCount">Number of surface samples; cost is O(S²) memory and O(S² log S) time.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="sampleCount"/> is not positive.</exception>
    /// <exception cref="FormatException">The mesh has zero surface area.</exception>
    /// <exception cref="OverflowException">The S×S distance table would exceed array limits.</exception>
    public static SurfaceGeodesic Build( RigMesh mesh, int sampleCount = 1024 )
    {
        ArgumentNullException.ThrowIfNull( mesh );
        // Zero samples would build an empty table that Distance could never index.
        if ( sampleCount <= 0 )
            throw new ArgumentOutOfRangeException( nameof( sampleCount ), sampleCount, "At least one surface sample is required." );

        var (samples, normals) = SampleSurface( mesh, sampleCount );
        var adjacency = BuildNeighborGraph( samples, normals );
        var all = AllPairsShortestPaths( adjacency );
        var vertexSample = NearestSamples( mesh.Positions, samples );
        return new SurfaceGeodesic( all, vertexSample, samples );
    }

    /// <summary>Undirected 5-NN graph over the samples, keeping only edges whose normals do not oppose.</summary>
    static List<(int To, float Weight)>[] BuildNeighborGraph( Vector3[] samples, Vector3[] normals )
    {
        var s = samples.Length;
        var adjacency = new List<(int To, float Weight)>[s];
        for ( var i = 0; i < s; i++ )
            adjacency[i] = new List<(int To, float Weight)>();

        var order = new int[s];
        var distances = new float[s];
        for ( var i = 0; i < s; i++ )
        {
            for ( var j = 0; j < s; j++ )
            {
                order[j] = j;
                distances[j] = Vector3.DistanceSquared( samples[i], samples[j] );
            }
            Array.Sort( distances, order );

            var taken = 0;
            for ( var k = 0; k < s && taken < NeighborCount; k++ )
            {
                var j = order[k];
                // Skip the sample itself by identity. A coincident sample ties at
                // distance 0 and may sort ahead of i, so i is not guaranteed to be at k = 0.
                if ( j == i )
                    continue;
                taken++;

                var cos = Vector3.Dot( normals[i], normals[j] )
                    / (normals[i].Length() * normals[j].Length() + 1e-10f);
                if ( cos > MinNormalCos )
                {
                    var weight = MathF.Sqrt( distances[k] );
                    adjacency[i].Add( (j, weight) );
                    adjacency[j].Add( (i, weight) );
                }
            }
        }
        return adjacency;
    }

    /// <summary>
    /// Dijkstra from every sample, using the inline binary heap rather than
    /// PriorityQueue, which is not proven against the s&amp;box whitelist. Returns a
    /// row-major S×S table with +inf where no path exists.
    /// </summary>
    static float[] AllPairsShortestPaths( List<(int To, float Weight)>[] adjacency )
    {
        var s = adjacency.Length;
        // checked: s * s silently wraps past ~46k samples; fail loudly instead.
        var all = new float[checked( s * s )];
        var heap = new MinHeap( s * 8 );
        var dist = new float[s]; // reused for every source; each row is copied out below
        for ( var source = 0; source < s; source++ )
        {
            Array.Fill( dist, float.PositiveInfinity );
            dist[source] = 0f;
            heap.Count = 0;
            heap.Push( source, 0f );
            while ( heap.TryPop( out var u, out var du ) )
            {
                // Stale heap entry: u was already settled with a shorter distance.
                if ( du > dist[u] )
                    continue;
                foreach ( var (to, weight) in adjacency[u] )
                {
                    var candidate = du + weight;
                    if ( candidate < dist[to] )
                    {
                        dist[to] = candidate;
                        heap.Push( to, candidate );
                    }
                }
            }
            Array.Copy( dist, 0, all, source * s, s );
        }
        return all;
    }

    /// <summary>For each position, the index of the nearest sample (first one wins on ties).</summary>
    static int[] NearestSamples( Vector3[] positions, Vector3[] samples )
    {
        var vertexSample = new int[positions.Length];
        for ( var v = 0; v < positions.Length; v++ )
        {
            var best = 0;
            var bestDistance = float.MaxValue;
            for ( var i = 0; i < samples.Length; i++ )
            {
                var d = Vector3.DistanceSquared( positions[v], samples[i] );
                if ( d < bestDistance )
                {
                    bestDistance = d;
                    best = i;
                }
            }
            vertexSample[v] = best;
        }
        return vertexSample;
    }

    /// <summary>A minimal (id, key) binary min-heap for Dijkstra; duplicates are allowed.</summary>
    sealed class MinHeap
    {
        int[] _ids;
        float[] _keys;
        public int Count;

        public MinHeap( int capacity )
        {
            _ids = new int[Math.Max( capacity, 16 )];
            _keys = new float[Math.Max( capacity, 16 )];
        }

        /// <summary>Inserts (id, key), doubling storage when full, and sifts it up.</summary>
        public void Push( int id, float key )
        {
            if ( Count == _ids.Length )
            {
                Array.Resize( ref _ids, Count * 2 );
                Array.Resize( ref _keys, Count * 2 );
            }
            var i = Count++;
            while ( i > 0 )
            {
                var parent = (i - 1) / 2;
                if ( _keys[parent] <= key )
                    break;
                _ids[i] = _ids[parent];
                _keys[i] = _keys[parent];
                i = parent;
            }
            _ids[i] = id;
            _keys[i] = key;
        }

        /// <summary>Removes the smallest-key entry; returns false when empty.</summary>
        public bool TryPop( out int id, out float key )
        {
            if ( Count == 0 )
            {
                id = 0;
                key = 0f;
                return false;
            }
            id = _ids[0];
            key = _keys[0];
            Count--;
            // Move the last entry into the root hole and sift it down.
            var lastId = _ids[Count];
            var lastKey = _keys[Count];
            var i = 0;
            while ( true )
            {
                var child = i * 2 + 1;
                if ( child >= Count )
                    break;
                if ( child + 1 < Count && _keys[child + 1] < _keys[child] )
                    child++;
                if ( _keys[child] >= lastKey )
                    break;
                _ids[i] = _ids[child];
                _keys[i] = _keys[child];
                i = child;
            }
            _ids[i] = lastId;
            _keys[i] = lastKey;
            return true;
        }
    }

    /// <summary>Deterministic area-weighted surface sampling with low-discrepancy
    /// barycentrics (stands in for poisson-disk sampling). Sample normals are the
    /// unit face normal, or +Y for degenerate triangles.</summary>
    static (Vector3[] Points, Vector3[] Normals) SampleSurface( RigMesh mesh, int sampleCount )
    {
        var triangleCount = mesh.TriangleCount;
        var cumulative = new float[triangleCount];
        float total = 0f;
        for ( var t = 0; t < triangleCount; t++ )
        {
            var a = mesh.Positions[mesh.Triangles[t * 3]];
            var b = mesh.Positions[mesh.Triangles[t * 3 + 1]];
            var c = mesh.Positions[mesh.Triangles[t * 3 + 2]];
            total += Vector3.Cross( b - a, c - a ).Length() * 0.5f;
            cumulative[t] = total;
        }
        if ( total <= 0f )
            throw new FormatException( "Cannot sample a mesh with zero surface area." );

        var points = new Vector3[sampleCount];
        var normals = new Vector3[sampleCount];
        for ( var i = 0; i < sampleCount; i++ )
        {
            // Stratified position along the cumulative-area axis picks the triangle.
            var target = (i + 0.5f) / sampleCount * total;
            var t = Array.BinarySearch( cumulative, target );
            if ( t < 0 )
                t = ~t;
            t = Math.Min( t, triangleCount - 1 );

            var a = mesh.Positions[mesh.Triangles[t * 3]];
            var b = mesh.Positions[mesh.Triangles[t * 3 + 1]];
            var c = mesh.Positions[mesh.Triangles[t * 3 + 2]];

            // Low-discrepancy barycentric from the sample index (plastic constants),
            // reflected back into the triangle when it lands in the other half.
            var r1 = (i * 0.7548776662466927f) % 1f;
            var r2 = (i * 0.5698402909980532f) % 1f;
            if ( r1 + r2 > 1f )
            {
                r1 = 1f - r1;
                r2 = 1f - r2;
            }
            points[i] = a + (b - a) * r1 + (c - a) * r2;
            var normal = Vector3.Cross( b - a, c - a );
            normals[i] = normal.Length() > 1e-12f ? Vector3.Normalize( normal ) : Vector3.UnitY;
        }
        return (points, normals);
    }
}