PointNet primitives and set-abstraction / feature-propagation modules used by RigNet. Implements farthest point sampling, radius neighbor search, PointConv message passing, global max-pool, kNN interpolation, and three module classes (SaModule, GlobalSaModule, FpModule) that compose these ops with small MLPs.
namespace AutoRig.Dl.RigNet;
/// <summary>
/// PointNet++ set-abstraction / feature-propagation primitives as used by RigNet's
/// ROOTNET and BONENET joint encoders (torch_geometric 1.x semantics, transcribed
/// from models/ROOT_GCN.py + PairCls_GCN.py; see docs/superpowers/rignet-port-notes.md).
/// Farthest-point sampling starts deterministically at index 0 — the original uses a
/// random start, so results are one instance of the original's ensemble.
/// Position tensors are read as row-major xyz rows (3 floats per row); feature tensors
/// are read as row-major [rows, C].
/// </summary>
public static class PointNet
{
	/// <summary>
	/// Farthest point sampling: ceil(ratio·N) indices (clamped to [1, N]) in selection
	/// order, seeded at index 0; each step adds the point with the largest distance to
	/// the selected set (ties → lowest index, matching argmax).
	/// </summary>
	/// <param name="pos">Point positions, one xyz row per point.</param>
	/// <param name="ratio">Fraction of the points to keep.</param>
	/// <returns>Selected row indices in selection order; empty when <paramref name="pos"/> has no rows.</returns>
	public static int[] FarthestPointSample( Tensor pos, float ratio )
	{
		var n = pos.Shape[0];
		// Math.Clamp( count, 1, 0 ) throws for an empty cloud; there is simply nothing to sample.
		if ( n == 0 )
			return Array.Empty<int>();
		var count = Math.Clamp( (int)MathF.Ceiling( ratio * n ), 1, n );
		var selected = new int[count];
		// bestDistance[i] = squared distance from point i to its nearest already-selected point.
		var bestDistance = new float[n];
		Array.Fill( bestDistance, float.MaxValue );
		selected[0] = 0;
		for ( var s = 1; s < count; s++ )
		{
			var last = selected[s - 1];
			var farthest = -1;
			var farthestDistance = float.MinValue;
			for ( var i = 0; i < n; i++ )
			{
				// Only the most recently selected point can lower the running minimum.
				var d = DistanceSquared( pos, i, pos, last );
				if ( d < bestDistance[i] )
					bestDistance[i] = d;
				// Strict '>' keeps the lowest index on ties.
				if ( bestDistance[i] > farthestDistance )
				{
					farthestDistance = bestDistance[i];
					farthest = i;
				}
			}
			selected[s] = farthest;
		}
		return selected;
	}

	/// <summary>
	/// Radius neighborhood (torch_cluster.radius): for each query row, up to
	/// maxNeighbors source indices within radius, in ascending source order.
	/// Returns directed (source, query) pairs.
	/// </summary>
	/// <param name="source">Candidate neighbor positions (xyz rows).</param>
	/// <param name="query">Query positions (xyz rows).</param>
	/// <param name="radius">Inclusive search radius (compared as squared distance ≤ radius²).</param>
	/// <param name="maxNeighbors">Per-query cap; the lowest-indexed sources in range win.</param>
	public static (int Source, int Target)[] RadiusNeighbors(
		Tensor source, Tensor query, float radius, int maxNeighbors = 64 )
	{
		var edges = new List<(int, int)>();
		var radiusSq = radius * radius;
		for ( var q = 0; q < query.Shape[0]; q++ )
		{
			var found = 0;
			for ( var s = 0; s < source.Shape[0] && found < maxNeighbors; s++ )
			{
				if ( DistanceSquared( source, s, query, q ) <= radiusSq )
				{
					edges.Add( (s, q) );
					found++;
				}
			}
		}
		return edges.ToArray();
	}

	/// <summary>
	/// PointConv message passing over bipartite (source → query) edges:
	/// message = localNn(concat(x_source, pos_source - pos_query)), max-aggregated
	/// per query. Features may be null (message is just the position offset).
	/// Queries with no message keep zeros (scatter_max fill). This cannot happen when
	/// queries are a subset of the sources and edges come from <see cref="RadiusNeighbors"/>
	/// with maxNeighbors ≥ 1, because every query is within distance 0 of itself.
	/// </summary>
	/// <param name="x">Per-source features (one row per source point), or null.</param>
	/// <param name="sourcePos">Source positions (xyz rows).</param>
	/// <param name="queryPos">Query positions (xyz rows); one output row per query.</param>
	/// <param name="edges">(source, query) index pairs.</param>
	/// <param name="localNn">MLP applied to every edge message.</param>
	/// <returns>[queryCount, outChannels] max-aggregated messages.</returns>
	public static Tensor PointConv(
		Tensor? x, Tensor sourcePos, Tensor queryPos, (int Source, int Target)[] edges, Mlp localNn )
	{
		var featureChannels = x?.Shape[1] ?? 0;
		var inChannels = featureChannels + 3;
		// Build all edge messages as one batch so the MLP runs once.
		var messageInput = new float[edges.Length * inChannels];
		for ( var e = 0; e < edges.Length; e++ )
		{
			var (s, q) = edges[e];
			var baseIndex = e * inChannels;
			for ( var c = 0; c < featureChannels; c++ )
				messageInput[baseIndex + c] = x!.Data[s * featureChannels + c];
			for ( var c = 0; c < 3; c++ )
				messageInput[baseIndex + featureChannels + c] =
					sourcePos.Data[s * 3 + c] - queryPos.Data[q * 3 + c];
		}
		var messages = localNn.Forward( Tensor.From( messageInput, edges.Length, inChannels ) );
		var outChannels = messages.Shape[1];
		var queryCount = queryPos.Shape[0];
		var result = new float[queryCount * outChannels];
		// The first message for a query overwrites the zero fill; later ones take the max.
		var hasMessage = new bool[queryCount];
		for ( var e = 0; e < edges.Length; e++ )
		{
			var target = edges[e].Target;
			for ( var c = 0; c < outChannels; c++ )
			{
				var value = messages.Data[e * outChannels + c];
				var at = target * outChannels + c;
				if ( !hasMessage[target] || value > result[at] )
					result[at] = value;
			}
			hasMessage[target] = true;
		}
		return Tensor.From( result, queryCount, outChannels );
	}

	/// <summary>Column-wise max over all rows → [1, C] (global_max_pool, batch 1).</summary>
	/// <remarks>
	/// The running max starts at negative infinity, so a column that is entirely -∞ pools
	/// to -∞ rather than float.MinValue. A tensor with zero rows yields a row of -∞.
	/// </remarks>
	public static Tensor GlobalMaxPool( Tensor x )
	{
		var cols = x.Shape[1];
		var pooled = new float[cols];
		Array.Fill( pooled, float.NegativeInfinity );
		for ( var r = 0; r < x.Shape[0]; r++ )
			for ( var c = 0; c < cols; c++ )
				pooled[c] = MathF.Max( pooled[c], x.Data[r * cols + c] );
		return Tensor.From( pooled, 1, cols );
	}

	/// <summary>
	/// knn_interpolate: inverse-squared-distance weighted average of the k nearest
	/// source features per query (k capped at the source count; ties → lower index).
	/// Squared distances are clamped to at least 1e-16 before inversion, so coincident
	/// points get a large but finite weight.
	/// </summary>
	/// <param name="x">Per-source features (one row per source point).</param>
	/// <param name="sourcePos">Source positions (xyz rows).</param>
	/// <param name="queryPos">Query positions (xyz rows); one output row per query.</param>
	/// <param name="k">Number of neighbors; a value ≤ 0 leaves the output all zeros.</param>
	/// <returns>[queryCount, C] interpolated features.</returns>
	public static Tensor KnnInterpolate( Tensor x, Tensor sourcePos, Tensor queryPos, int k )
	{
		var sourceCount = sourcePos.Shape[0];
		var queryCount = queryPos.Shape[0];
		var channels = x.Shape[1];
		k = Math.Min( k, sourceCount );
		var result = new float[queryCount * channels];
		// Scratch buffers reused across queries.
		var distances = new float[sourceCount];
		var order = new int[sourceCount];
		for ( var q = 0; q < queryCount; q++ )
		{
			for ( var s = 0; s < sourceCount; s++ )
			{
				distances[s] = DistanceSquared( sourcePos, s, queryPos, q );
				order[s] = s;
			}
			// Partial selection sort: the first k slots of `order` receive the k nearest by
			// (distance, index). The swaps scramble the tail, hence the explicit index tie-break.
			for ( var pick = 0; pick < k; pick++ )
			{
				var best = pick;
				for ( var s = pick + 1; s < sourceCount; s++ )
					if ( distances[order[s]] < distances[order[best]]
						|| (distances[order[s]] == distances[order[best]] && order[s] < order[best]) )
						best = s;
				(order[pick], order[best]) = (order[best], order[pick]);
			}
			float weightSum = 0;
			for ( var pick = 0; pick < k; pick++ )
				weightSum += 1f / MathF.Max( distances[order[pick]], 1e-16f );
			for ( var pick = 0; pick < k; pick++ )
			{
				var weight = 1f / MathF.Max( distances[order[pick]], 1e-16f ) / weightSum;
				for ( var c = 0; c < channels; c++ )
					result[q * channels + c] += weight * x.Data[order[pick] * channels + c];
			}
		}
		return Tensor.From( result, queryCount, channels );
	}

	/// <summary>Squared Euclidean distance between xyz row <paramref name="rowA"/> of a and row <paramref name="rowB"/> of b.</summary>
	static float DistanceSquared( Tensor a, int rowA, Tensor b, int rowB )
	{
		float sum = 0;
		for ( var c = 0; c < 3; c++ )
		{
			var d = a.Data[rowA * 3 + c] - b.Data[rowB * 3 + c];
			sum += d * d;
		}
		return sum;
	}
}
/// <summary>
/// Set-abstraction level: farthest-point subsample, radius neighborhood around each
/// sampled point, then PointConv over those neighborhoods.
/// </summary>
public sealed class SaModule
{
	/// <summary>Fraction of input points kept by farthest point sampling.</summary>
	public required float Ratio;
	/// <summary>Neighborhood radius, in the same units as the positions.</summary>
	public required float Radius;
	/// <summary>MLP applied to every (source → sampled point) edge message.</summary>
	public required Mlp LocalNn;
	/// <summary>
	/// Maximum neighbors gathered per sampled point. The default of 64 matches the value
	/// previously fixed by <c>PointNet.RadiusNeighbors</c>' default argument.
	/// </summary>
	public int MaxNeighbors = 64;

	/// <summary>Runs one set-abstraction level.</summary>
	/// <param name="x">Per-point features, or null to use position offsets only.</param>
	/// <param name="pos">Point positions (xyz rows).</param>
	/// <returns>Aggregated features for the sampled points, and their positions.</returns>
	public (Tensor? X, Tensor Pos) Forward( Tensor? x, Tensor pos )
	{
		var idx = PointNet.FarthestPointSample( pos, Ratio );
		var queryPos = SelectRows( pos, idx );
		var edges = PointNet.RadiusNeighbors( pos, queryPos, Radius, MaxNeighbors );
		var features = PointNet.PointConv( x, pos, queryPos, edges, LocalNn );
		return (features, queryPos);
	}

	/// <summary>Copies the listed rows of a 2-D tensor, in the given order, into a new [rows.Length, C] tensor.</summary>
	internal static Tensor SelectRows( Tensor t, int[] rows )
	{
		var cols = t.Shape[1];
		var data = new float[rows.Length * cols];
		for ( var r = 0; r < rows.Length; r++ )
			Array.Copy( t.Data, rows[r] * cols, data, r * cols, cols );
		return Tensor.From( data, rows.Length, cols );
	}
}
/// <summary>
/// Global set abstraction: every point's features are concatenated with its position,
/// passed through <see cref="Nn"/>, and max-pooled into a single row placed at the origin.
/// </summary>
public sealed class GlobalSaModule
{
	/// <summary>MLP applied per point to concat(x, pos).</summary>
	public required Mlp Nn;

	/// <returns>The pooled [1, C] feature row, and a [1, 3] zero position.</returns>
	public (Tensor X, Tensor Pos) Forward( Tensor x, Tensor pos )
	{
		var withPositions = x.Concat( pos, 1 );
		var perPoint = Nn.Forward( withPositions );
		var pooled = PointNet.GlobalMaxPool( perPoint );
		var origin = Tensor.Zeros( 1, 3 );
		return (pooled, origin);
	}
}
/// <summary>
/// Feature propagation: interpolates coarse features onto the skip level's points
/// (k nearest neighbors), appends the skip features when present, then applies <see cref="Nn"/>.
/// </summary>
public sealed class FpModule
{
	/// <summary>Neighbor count used for interpolation.</summary>
	public required int K;
	/// <summary>MLP applied to the (optionally skip-concatenated) interpolated features.</summary>
	public required Mlp Nn;

	/// <param name="x">Coarse-level features.</param>
	/// <param name="pos">Coarse-level positions.</param>
	/// <param name="skipX">Skip-level features to append column-wise, or null.</param>
	/// <param name="skipPos">Skip-level positions; one output row per skip point.</param>
	public Tensor Forward( Tensor x, Tensor pos, Tensor? skipX, Tensor skipPos )
	{
		var upsampled = PointNet.KnnInterpolate( x, pos, skipPos, K );
		var combined = skipX is null ? upsampled : upsampled.Concat( skipX, 1 );
		return Nn.Forward( combined );
	}
}