MeanShift clustering utilities for RigNet. Implements a damped flat-kernel mean-shift (Cluster), flat-kernel density estimation (Density), greedy non-maximum suppression (Nms), and a higher-level ExtractJoints pipeline that filters by attention, mirrors points across the YZ plane, clusters, density-filters, and runs NMS.
namespace AutoRig.Dl.RigNet;
// s&box compat: the engine defines Vector2/Vector3 in the GLOBAL namespace, which
// shadows using-directive imports - alias explicitly to System.Numerics.
using Vector3 = System.Numerics.Vector3;
/// <summary>
/// RigNet's joint clustering, transcribed from utils/cluster_utils.py and
/// quick_start.predict_joints (verbatim references in
/// docs/superpowers/rignet-port-notes.md): damped flat-kernel meanshift with
/// per-point attention weights, density filtering, and greedy NMS.
/// </summary>
public static class MeanShift
{
    /// <summary>Fraction of the distance to the weighted neighbourhood mean covered per sweep.</summary>
    private const float DampingStep = 0.3f;

    /// <summary>Iteration stops once the root of the summed squared movement is at or below this.</summary>
    private const float ConvergenceThreshold = 1e-3f;

    /// <summary>Kernel mass at or below this is treated as "no neighbours"; the point then stays put.</summary>
    private const float MinKernelMass = 1e-10f;

    /// <summary>Candidates whose attention is not strictly above this are dropped by ExtractJoints.</summary>
    private const float MinAttention = 1e-3f;

    /// <summary>
    /// Damped meanshift (0.3 step) with flat kernel K = max(b² - d², 0), attention
    /// weights applied per SOURCE point. Iterates until the total movement
    /// (square root of the summed squared per-point displacement) is ≤ 1e-3 or the
    /// iteration budget is exhausted.
    /// </summary>
    /// <param name="points">Points to shift. The array is not modified.</param>
    /// <param name="bandwidth">Kernel radius b; only pairs with d² &lt; b² interact.</param>
    /// <param name="weights">Optional per-point weights (one per point); null means all 1.</param>
    /// <param name="maxIter">
    /// Iteration budget. The counter starts at 1, so at most maxIter - 1 sweeps run.
    /// </param>
    /// <returns>
    /// The shifted points in input order. For an empty input the input array itself
    /// is returned; otherwise a new array.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="weights"/> is non-null and its length differs from <paramref name="points"/>.
    /// </exception>
    public static Vector3[] Cluster(
        Vector3[] points, float bandwidth, float[]? weights, int maxIter = 40 )
    {
        ArgumentNullException.ThrowIfNull( points );
        var n = points.Length;
        if ( weights is not null && weights.Length != n )
            throw new ArgumentException(
                $"Expected one weight per point ({n}) but got {weights.Length}.", nameof( weights ) );
        if ( n == 0 )
            return points;

        // Double-buffer: read from 'current', write to 'next', then swap.
        var current = new Vector3[n];
        Array.Copy( points, current, n );
        var next = new Vector3[n];
        var bandwidthSq = bandwidth * bandwidth;

        // NOTE(review): starting at 1 means maxIter - 1 sweeps; presumably this mirrors
        // the reference implementation's counter - confirm against the port notes
        // before "fixing" it.
        for ( var iter = 1; iter < maxIter; iter++ )
        {
            // K[i,j] = flat kernel · weight[i] (numpy: K * weights with [n,1] weights
            // scales the FIRST axis); P[i,j] = K[j,i] / Σ_k K[k,i].
            // next_i = current_i + 0.3 * (Σ_j P[i,j]·current_j - current_i).
            float totalMovementSq = 0;
            for ( var i = 0; i < n; i++ )
            {
                var numerator = Vector3.Zero;
                float denominator = 0;
                for ( var j = 0; j < n; j++ )
                {
                    var kernel = bandwidthSq - Vector3.DistanceSquared( current[i], current[j] );
                    if ( kernel <= 0f )
                        continue;

                    // The weight belongs to the contributing (source) point j.
                    var weighted = kernel * (weights is null ? 1f : weights[j]);
                    numerator += current[j] * weighted;
                    denominator += weighted;
                }

                // With no usable kernel mass the mean is the point itself, i.e. no movement.
                var mean = denominator > MinKernelMass ? numerator / denominator : current[i];
                next[i] = current[i] + DampingStep * (mean - current[i]);
                totalMovementSq += Vector3.DistanceSquared( next[i], current[i] );
            }

            // After the swap 'current' always holds the most recent positions.
            (current, next) = (next, current);
            if ( MathF.Sqrt( totalMovementSq ) <= ConvergenceThreshold )
                break;
        }

        return current;
    }

    /// <summary>Flat-kernel density at each point: Σ_j max(b² - d², 0).</summary>
    /// <param name="points">Points to evaluate; each point also contributes b² to its own density.</param>
    /// <param name="bandwidth">Kernel radius b.</param>
    /// <returns>One density value per point, in input order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is null.</exception>
    public static float[] Density( Vector3[] points, float bandwidth )
    {
        ArgumentNullException.ThrowIfNull( points );
        var n = points.Length;
        var density = new float[n];
        var bandwidthSq = bandwidth * bandwidth;
        for ( var i = 0; i < n; i++ )
        {
            for ( var j = 0; j < n; j++ )
            {
                var kernel = bandwidthSq - Vector3.DistanceSquared( points[i], points[j] );
                if ( kernel > 0f )
                    density[i] += kernel;
            }
        }

        return density;
    }

    /// <summary>
    /// Greedy NMS: highest-density point wins, everything within the bandwidth
    /// (distance ≤ bandwidth) is suppressed. Equal densities are visited from the
    /// highest index to the lowest (numpy [::-1] of a stable ascending sort —
    /// replicate its exact ordering).
    /// </summary>
    /// <param name="points">Candidate points.</param>
    /// <param name="density">Density of each candidate; one value per point.</param>
    /// <param name="bandwidth">Suppression radius.</param>
    /// <returns>The surviving points, in their original index order.</returns>
    /// <exception cref="ArgumentNullException">An array argument is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="density"/> and <paramref name="points"/> differ in length.
    /// </exception>
    public static Vector3[] Nms( Vector3[] points, float[] density, float bandwidth )
    {
        ArgumentNullException.ThrowIfNull( points );
        ArgumentNullException.ThrowIfNull( density );
        var n = points.Length;
        if ( density.Length != n )
            throw new ArgumentException(
                $"Expected one density per point ({n}) but got {density.Length}.", nameof( density ) );

        var order = Enumerable.Range( 0, n )
            .OrderByDescending( i => density[i] )
            .ThenByDescending( i => i ) // numpy [::-1] of a stable ascending sort
            .ToArray();

        var unique = new bool[n];
        Array.Fill( unique, true );
        foreach ( var i in order )
        {
            // Already claimed by a denser neighbour.
            if ( !unique[i] )
                continue;

            for ( var j = 0; j < n; j++ )
            {
                if ( Vector3.Distance( points[j], points[i] ) <= bandwidth )
                    unique[j] = false;
            }

            // The sweep above also cleared the winner itself; restore it.
            unique[i] = true;
        }

        return Enumerable.Range( 0, n ).Where( i => unique[i] )
            .Select( i => points[i] ).ToArray();
    }

    /// <summary>
    /// The full predict_joints clustering stage (post-network): filter by attention
    /// > 1e-3, mirror across the YZ plane, cluster, density-filter by
    /// density/Σdensity > threshold, NMS.
    /// </summary>
    /// <param name="candidates">Candidate joint positions.</param>
    /// <param name="attention">Attention value of each candidate; one per candidate.</param>
    /// <param name="bandwidth">Bandwidth shared by clustering, density estimation and NMS.</param>
    /// <param name="densityThreshold">Minimum share of the total density a point must exceed to be kept.</param>
    /// <returns>
    /// The extracted joints; an empty array when the total density is not positive
    /// (which includes the case where no candidate passes the attention filter).
    /// </returns>
    /// <exception cref="ArgumentNullException">An array argument is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="attention"/> and <paramref name="candidates"/> differ in length.
    /// </exception>
    public static Vector3[] ExtractJoints(
        Vector3[] candidates, float[] attention, float bandwidth, float densityThreshold )
    {
        ArgumentNullException.ThrowIfNull( candidates );
        ArgumentNullException.ThrowIfNull( attention );
        if ( attention.Length != candidates.Length )
            throw new ArgumentException(
                $"Expected one attention value per candidate ({candidates.Length}) but got {attention.Length}.",
                nameof( attention ) );

        // Attention filter.
        var kept = new List<Vector3>( candidates.Length );
        var keptAttention = new List<float>( candidates.Length );
        for ( var i = 0; i < candidates.Length; i++ )
        {
            if ( attention[i] > MinAttention )
            {
                kept.Add( candidates[i] );
                keptAttention.Add( attention[i] );
            }
        }

        // Symmetrize across the YZ plane (x → -x), attention tiled: the first half
        // holds the originals, the second half their mirror images.
        var keptCount = kept.Count;
        var symmetrized = new Vector3[keptCount * 2];
        var weights = new float[keptCount * 2];
        for ( var i = 0; i < keptCount; i++ )
        {
            var point = kept[i];
            symmetrized[i] = point;
            symmetrized[keptCount + i] = new Vector3( -point.X, point.Y, point.Z );
            weights[i] = keptAttention[i];
            weights[keptCount + i] = keptAttention[i];
        }

        var shifted = Cluster( symmetrized, bandwidth, weights );
        var density = Density( shifted, bandwidth );
        var densitySum = density.Sum();

        // Nothing to normalise against (no points, or no positive kernel mass at all).
        if ( densitySum <= 0f )
            return Array.Empty<Vector3>();

        var filteredPoints = new List<Vector3>( shifted.Length );
        var filteredDensity = new List<float>( shifted.Length );
        for ( var i = 0; i < shifted.Length; i++ )
        {
            if ( density[i] / densitySum > densityThreshold )
            {
                filteredPoints.Add( shifted[i] );
                filteredDensity.Add( density[i] );
            }
        }

        return Nms( filteredPoints.ToArray(), filteredDensity.ToArray(), bandwidth );
    }
}