Code/AutoRig/Dl/RigNet/PointNet.cs

PointNet primitives and set-abstraction / feature-propagation modules used by RigNet. Implements farthest point sampling, radius neighbor search, PointConv message passing, global max-pool, kNN interpolation, and three module classes (SaModule, GlobalSaModule, FpModule) that compose these ops with small MLPs.

Native Interop
namespace AutoRig.Dl.RigNet;

/// <summary>
/// PointNet++ set-abstraction / feature-propagation primitives as used by RigNet's
/// ROOTNET and BONENET joint encoders (torch_geometric 1.x semantics, transcribed
/// from models/ROOT_GCN.py + PairCls_GCN.py; see docs/superpowers/rignet-port-notes.md).
/// Farthest-point sampling starts deterministically at index 0 — the original uses a
/// random start, so results are one instance of the original's ensemble.
/// Position tensors are read as three floats per row (x, y, z); feature tensors are
/// read row-major with <c>Shape[1]</c> channels per row.
/// </summary>
public static class PointNet
{
    /// <summary>
    /// Farthest point sampling: ceil(ratio·N) indices in selection order, seeded at
    /// index 0; each step adds the point with the largest squared distance to the
    /// selected set (ties → lowest index, matching argmax).
    /// </summary>
    /// <param name="pos">Point positions, one point per row.</param>
    /// <param name="ratio">Fraction of points to keep; the resulting count is clamped to [1, N].</param>
    /// <returns>
    /// Selected row indices in the order they were picked; an empty array when
    /// <paramref name="pos"/> has no rows.
    /// </returns>
    public static int[] FarthestPointSample( Tensor pos, float ratio )
    {
        var n = pos.Shape[0];
        // Nothing to sample from an empty cloud (ceil(ratio·0) = 0). Without this guard
        // Math.Clamp(count, 1, 0) throws because its min exceeds its max.
        if ( n == 0 )
            return Array.Empty<int>();

        var count = (int)MathF.Ceiling( ratio * n );
        count = Math.Clamp( count, 1, n );

        var selected = new int[count];
        // bestDistance[i]: squared distance from point i to its nearest selected point so far.
        var bestDistance = new float[n];
        Array.Fill( bestDistance, float.MaxValue );
        selected[0] = 0;
        for ( var s = 1; s < count; s++ )
        {
            // Only the most recently added point can lower any running minimum.
            var last = selected[s - 1];
            var farthest = -1;
            var farthestDistance = float.MinValue;
            for ( var i = 0; i < n; i++ )
            {
                var d = DistanceSquared( pos, i, pos, last );
                if ( d < bestDistance[i] )
                    bestDistance[i] = d;
                // Strict '>' keeps the lowest index on ties (argmax semantics).
                if ( bestDistance[i] > farthestDistance )
                {
                    farthestDistance = bestDistance[i];
                    farthest = i;
                }
            }
            selected[s] = farthest;
        }
        return selected;
    }

    /// <summary>
    /// Radius neighborhood (torch_cluster.radius): for each query row, up to
    /// maxNeighbors source indices within radius (squared distance ≤ radius²),
    /// in ascending source order.
    /// </summary>
    /// <param name="source">Candidate neighbor positions.</param>
    /// <param name="query">Positions whose neighborhoods are gathered.</param>
    /// <param name="radius">Inclusive search radius.</param>
    /// <param name="maxNeighbors">Per-query cap; the first matches in source order win.</param>
    /// <returns>Directed (source, query) pairs, grouped by query in ascending query order.</returns>
    public static (int Source, int Target)[] RadiusNeighbors(
        Tensor source, Tensor query, float radius, int maxNeighbors = 64 )
    {
        var edges = new List<(int, int)>();
        var radiusSq = radius * radius;
        for ( var q = 0; q < query.Shape[0]; q++ )
        {
            var found = 0;
            for ( var s = 0; s < source.Shape[0] && found < maxNeighbors; s++ )
            {
                if ( DistanceSquared( source, s, query, q ) <= radiusSq )
                {
                    edges.Add( (s, q) );
                    found++;
                }
            }
        }
        return edges.ToArray();
    }

    /// <summary>
    /// PointConv message passing over bipartite (source → query) edges:
    /// message = localNn(concat(x_source, pos_source - pos_query)), max-aggregated
    /// per query. Features may be null (message is just the position offset).
    /// Queries with no message keep zeros (scatter_max fill) — cannot happen when
    /// queries are a subset of sources (self distance 0).
    /// </summary>
    /// <param name="x">Optional per-source features (row per source point), or null.</param>
    /// <param name="sourcePos">Source positions (row per source point).</param>
    /// <param name="queryPos">Query positions (row per output row).</param>
    /// <param name="edges">(source, query) pairs, e.g. from <see cref="RadiusNeighbors"/>.</param>
    /// <param name="localNn">MLP applied to every edge's message input.</param>
    /// <returns>A [queryCount, outChannels] tensor of per-query maxima.</returns>
    public static Tensor PointConv(
        Tensor? x, Tensor sourcePos, Tensor queryPos, (int Source, int Target)[] edges, Mlp localNn )
    {
        var featureChannels = x?.Shape[1] ?? 0;
        var inChannels = featureChannels + 3;
        // Build one message-input row per edge: [x_source | pos_source - pos_query].
        var messageInput = new float[edges.Length * inChannels];
        for ( var e = 0; e < edges.Length; e++ )
        {
            var (s, q) = edges[e];
            var baseIndex = e * inChannels;
            for ( var c = 0; c < featureChannels; c++ )
                messageInput[baseIndex + c] = x!.Data[s * featureChannels + c];
            for ( var c = 0; c < 3; c++ )
                messageInput[baseIndex + featureChannels + c] =
                    sourcePos.Data[s * 3 + c] - queryPos.Data[q * 3 + c];
        }

        var messages = localNn.Forward( Tensor.From( messageInput, edges.Length, inChannels ) );
        var outChannels = messages.Shape[1];
        var queryCount = queryPos.Shape[0];
        var result = new float[queryCount * outChannels];
        // Tracks whether a query has received any message, so the first one always
        // overwrites the zero fill (messages may be negative).
        var hasMessage = new bool[queryCount];
        for ( var e = 0; e < edges.Length; e++ )
        {
            var target = edges[e].Target;
            for ( var c = 0; c < outChannels; c++ )
            {
                var value = messages.Data[e * outChannels + c];
                var at = target * outChannels + c;
                if ( !hasMessage[target] || value > result[at] )
                    result[at] = value;
            }
            hasMessage[target] = true;
        }
        return Tensor.From( result, queryCount, outChannels );
    }

    /// <summary>
    /// Column-wise max over all rows → [1, C] (global_max_pool, batch 1).
    /// An input with no rows yields float.MinValue in every column.
    /// </summary>
    public static Tensor GlobalMaxPool( Tensor x )
    {
        var cols = x.Shape[1];
        var pooled = new float[cols];
        Array.Fill( pooled, float.MinValue );
        for ( var r = 0; r < x.Shape[0]; r++ )
            for ( var c = 0; c < cols; c++ )
                pooled[c] = MathF.Max( pooled[c], x.Data[r * cols + c] );
        return Tensor.From( pooled, 1, cols );
    }

    /// <summary>
    /// knn_interpolate: inverse-squared-distance weighted average of the k nearest
    /// source features per query (k capped at the source count; ties → lower index).
    /// </summary>
    /// <param name="x">Source features, one row per source position.</param>
    /// <param name="sourcePos">Positions the features live at.</param>
    /// <param name="queryPos">Positions to interpolate onto.</param>
    /// <param name="k">Neighbors per query; when k ≤ 0 or there are no sources, rows stay zero.</param>
    /// <returns>A [queryCount, channels] tensor.</returns>
    public static Tensor KnnInterpolate( Tensor x, Tensor sourcePos, Tensor queryPos, int k )
    {
        var sourceCount = sourcePos.Shape[0];
        var queryCount = queryPos.Shape[0];
        var channels = x.Shape[1];
        k = Math.Min( k, sourceCount );

        var result = new float[queryCount * channels];
        var distances = new float[sourceCount];
        var order = new int[sourceCount];
        for ( var q = 0; q < queryCount; q++ )
        {
            for ( var s = 0; s < sourceCount; s++ )
            {
                distances[s] = DistanceSquared( sourcePos, s, queryPos, q );
                order[s] = s;
            }
            // Partial selection sort: the first k slots of `order` end up holding the
            // k nearest by (distance, index) — ties resolve to the lower index.
            for ( var pick = 0; pick < k; pick++ )
            {
                var best = pick;
                for ( var s = pick + 1; s < sourceCount; s++ )
                    if ( distances[order[s]] < distances[order[best]]
                        || (distances[order[s]] == distances[order[best]] && order[s] < order[best]) )
                        best = s;
                (order[pick], order[best]) = (order[best], order[pick]);
            }

            // Squared distances are clamped at 1e-16 so a coincident point gets a
            // very large (finite) weight instead of dividing by zero.
            float weightSum = 0;
            for ( var pick = 0; pick < k; pick++ )
                weightSum += 1f / MathF.Max( distances[order[pick]], 1e-16f );
            for ( var pick = 0; pick < k; pick++ )
            {
                var weight = 1f / MathF.Max( distances[order[pick]], 1e-16f ) / weightSum;
                for ( var c = 0; c < channels; c++ )
                    result[q * channels + c] += weight * x.Data[order[pick] * channels + c];
            }
        }
        return Tensor.From( result, queryCount, channels );
    }

    /// <summary>Squared Euclidean distance between row <paramref name="rowA"/> of a and row <paramref name="rowB"/> of b, over three coordinates.</summary>
    static float DistanceSquared( Tensor a, int rowA, Tensor b, int rowB )
    {
        float sum = 0;
        for ( var c = 0; c < 3; c++ )
        {
            var d = a.Data[rowA * 3 + c] - b.Data[rowB * 3 + c];
            sum += d * d;
        }
        return sum;
    }
}

/// <summary>
/// One PointNet++ set-abstraction level: farthest-point subsample the input, gather a
/// radius neighborhood around each sampled point, then aggregate with PointConv.
/// </summary>
public sealed class SaModule
{
    public required float Ratio;
    public required float Radius;
    public required Mlp LocalNn;

    /// <summary>
    /// Runs the level and returns the aggregated features together with the positions of
    /// the sampled points they belong to.
    /// </summary>
    public (Tensor? X, Tensor Pos) Forward( Tensor? x, Tensor pos )
    {
        var sampledPos = SelectRows( pos, PointNet.FarthestPointSample( pos, Ratio ) );
        var neighborhood = PointNet.RadiusNeighbors( pos, sampledPos, Radius );
        return (PointNet.PointConv( x, pos, sampledPos, neighborhood, LocalNn ), sampledPos);
    }

    /// <summary>Gathers the listed rows of <paramref name="t"/>, in the given order, into a new tensor.</summary>
    internal static Tensor SelectRows( Tensor t, int[] rows )
    {
        var width = t.Shape[1];
        var gathered = new float[rows.Length * width];
        var destOffset = 0;
        foreach ( var row in rows )
        {
            Array.Copy( t.Data, row * width, gathered, destOffset, width );
            destOffset += width;
        }
        return Tensor.From( gathered, rows.Length, width );
    }
}

/// <summary>
/// Global set abstraction: applies <see cref="Nn"/> to each point's concatenated
/// [features | position] row, then max-pools everything into a single row placed at the origin.
/// </summary>
public sealed class GlobalSaModule
{
    public required Mlp Nn;

    /// <summary>Returns the pooled [1, C] feature row and a [1, 3] zero position.</summary>
    public (Tensor X, Tensor Pos) Forward( Tensor x, Tensor pos )
    {
        var joined = x.Concat( pos, 1 );
        var perPoint = Nn.Forward( joined );
        var pooled = PointNet.GlobalMaxPool( perPoint );
        return (pooled, Tensor.Zeros( 1, 3 ));
    }
}

/// <summary>
/// Feature propagation: interpolates coarse features onto the skip-level positions with
/// K nearest neighbors, appends the skip features when present, then applies <see cref="Nn"/>.
/// </summary>
public sealed class FpModule
{
    public required int K;
    public required Mlp Nn;

    /// <summary>Produces one output row per row of <paramref name="skipPos"/>.</summary>
    public Tensor Forward( Tensor x, Tensor pos, Tensor? skipX, Tensor skipPos )
    {
        var upsampled = PointNet.KnnInterpolate( x, pos, skipPos, K );
        var combined = skipX is null ? upsampled : upsampled.Concat( skipX, 1 );
        return Nn.Forward( combined );
    }
}