Code/AutoRig/Dl/RigNet/MeanShift.cs

MeanShift clustering utilities for RigNet. Implements damped flat-kernel mean-shift, density estimation, greedy NMS, and a post-processing pipeline ExtractJoints that filters by attention, mirrors points across the YZ plane, clusters, density-filters, and applies NMS.

namespace AutoRig.Dl.RigNet;

// s&box compat: the engine defines Vector2/Vector3 in the GLOBAL namespace, which
// shadows using-directive imports - alias explicitly to System.Numerics.
using Vector3 = System.Numerics.Vector3;

/// <summary>
/// RigNet's joint clustering, transcribed from utils/cluster_utils.py and
/// quick_start.predict_joints (verbatim references in
/// docs/superpowers/rignet-port-notes.md): damped flat-kernel meanshift with
/// per-point attention weights, density filtering, and greedy NMS.
/// </summary>
/// <remarks>
/// Every routine compares all point pairs, so cost grows quadratically with the
/// number of points. Companion arrays (weights, density, attention) are indexed
/// in lockstep with the point array and must therefore be at least as long as it;
/// surplus trailing entries are ignored.
/// </remarks>
public static class MeanShift
{
    /// <summary>Fraction of the way each point moves toward its local mean per pass.</summary>
    private const float Damping = 0.3f;

    /// <summary>Total per-pass movement at or below which clustering stops.</summary>
    private const float ConvergenceDistance = 1e-3f;

    /// <summary>Weighted kernel sum at or below which a point is left where it is.</summary>
    private const float MinKernelSum = 1e-10f;

    /// <summary>Candidates must have attention strictly above this to be clustered.</summary>
    private const float MinAttention = 1e-3f;

    /// <summary>
    /// Damped meanshift (0.3 step) with flat kernel K = max(b² - d², 0), attention
    /// weights applied per SOURCE point. Stops once the total movement of a pass is
    /// &lt;= 1e-3 or the pass budget is exhausted.
    /// </summary>
    /// <param name="points">Points to shift. The array is not modified.</param>
    /// <param name="bandwidth">Kernel radius b; points at distance &gt;= b do not interact.</param>
    /// <param name="weights">
    /// Optional weight per point (at least <c>points.Length</c> entries);
    /// <c>null</c> weights every point equally.
    /// </param>
    /// <param name="maxIter">
    /// Pass budget. The counter starts at 1, so at most <c>maxIter - 1</c> passes run
    /// (NOTE(review): presumably kept for parity with the Python reference's counter —
    /// confirm against the port notes before changing).
    /// </param>
    /// <returns>
    /// A new array with the shifted positions, in input order. An empty input is
    /// returned as-is (same instance).
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="weights"/> has fewer entries than <paramref name="points"/>.
    /// </exception>
    public static Vector3[] Cluster(
        Vector3[] points, float bandwidth, float[]? weights, int maxIter = 40 )
    {
        ArgumentNullException.ThrowIfNull( points );
        var n = points.Length;
        if ( weights is not null && weights.Length < n )
            throw new ArgumentException(
                $"Expected at least {n} weights (one per point) but got {weights.Length}.",
                nameof( weights ) );
        if ( n == 0 )
            return points;

        // Double buffer: every pass reads `current` and writes `next`, so all points
        // are updated from the same snapshot.
        var current = new Vector3[n];
        Array.Copy( points, current, n );
        var next = new Vector3[n];
        var bandwidthSq = bandwidth * bandwidth;

        for ( var iter = 1; iter < maxIter; iter++ )
        {
            // K[i,j] = flat kernel · weight[i] (numpy: K * weights with [n,1] weights
            // scales the FIRST axis); P[i,j] = K[j,i] / Σ_k K[k,i].
            // next_i = current_i + 0.3 * (Σ_j P[i,j]·current_j - current_i).
            float totalMovementSq = 0;
            for ( var i = 0; i < n; i++ )
            {
                var numerator = Vector3.Zero;
                float denominator = 0;
                for ( var j = 0; j < n; j++ )
                {
                    var kernel = bandwidthSq - Vector3.DistanceSquared( current[i], current[j] );
                    if ( kernel <= 0f )
                        continue;
                    // The weight belongs to the contributing (source) point j.
                    var weighted = weights is null ? kernel : kernel * weights[j];
                    numerator += current[j] * weighted;
                    denominator += weighted;
                }

                // With (almost) no kernel mass there is no meaningful mean: stay put.
                var mean = denominator > MinKernelSum ? numerator / denominator : current[i];
                next[i] = current[i] + Damping * (mean - current[i]);
                totalMovementSq += Vector3.DistanceSquared( next[i], current[i] );
            }

            (current, next) = (next, current);
            if ( MathF.Sqrt( totalMovementSq ) <= ConvergenceDistance )
                break;
        }
        return current;
    }

    /// <summary>Flat-kernel density at each point: Σ_j max(b² - d², 0).</summary>
    /// <param name="points">Points to evaluate. The array is not modified.</param>
    /// <param name="bandwidth">Kernel radius b.</param>
    /// <returns>
    /// One density per point, in input order. The sum includes the point itself
    /// (j = i), which contributes b².
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is null.</exception>
    public static float[] Density( Vector3[] points, float bandwidth )
    {
        ArgumentNullException.ThrowIfNull( points );
        var n = points.Length;
        var density = new float[n];
        var bandwidthSq = bandwidth * bandwidth;
        for ( var i = 0; i < n; i++ )
            for ( var j = 0; j < n; j++ )
            {
                var kernel = bandwidthSq - Vector3.DistanceSquared( points[i], points[j] );
                if ( kernel > 0f )
                    density[i] += kernel;
            }
        return density;
    }

    /// <summary>
    /// Greedy NMS: highest-density point wins, everything within the bandwidth is
    /// suppressed; ties resolved by index order (numpy argsort is stable on the
    /// reversed array — replicate its exact ordering).
    /// </summary>
    /// <param name="points">Candidate points. The array is not modified.</param>
    /// <param name="density">
    /// Density per point (at least <c>points.Length</c> entries).
    /// </param>
    /// <param name="bandwidth">Suppression radius; the comparison is inclusive (&lt;=).</param>
    /// <returns>The surviving points, in their original index order.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="points"/> or <paramref name="density"/> is null.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="density"/> has fewer entries than <paramref name="points"/>.
    /// </exception>
    public static Vector3[] Nms( Vector3[] points, float[] density, float bandwidth )
    {
        ArgumentNullException.ThrowIfNull( points );
        ArgumentNullException.ThrowIfNull( density );
        var n = points.Length;
        if ( density.Length < n )
            throw new ArgumentException(
                $"Expected at least {n} densities (one per point) but got {density.Length}.",
                nameof( density ) );

        // Visit order: density descending, equal densities by index descending.
        var order = Enumerable.Range( 0, n )
            .OrderByDescending( i => density[i] )
            .ThenByDescending( i => i )   // numpy [::-1] of a stable ascending sort
            .ToArray();

        var unique = new bool[n];
        Array.Fill( unique, true );
        foreach ( var i in order )
        {
            // Already claimed by a denser neighbour.
            if ( !unique[i] )
                continue;
            for ( var j = 0; j < n; j++ )
            {
                if ( Vector3.Distance( points[j], points[i] ) <= bandwidth )
                    unique[j] = false;
            }
            // The sweep above also cleared i itself (distance 0); restore the winner.
            unique[i] = true;
        }
        return Enumerable.Range( 0, n ).Where( i => unique[i] )
            .Select( i => points[i] ).ToArray();
    }

    /// <summary>
    /// The full predict_joints clustering stage (post-network): filter by attention
    /// &gt; 1e-3, mirror across the YZ plane, cluster, density-filter by
    /// density/Σdensity &gt; threshold, NMS.
    /// </summary>
    /// <param name="candidates">Candidate joint positions. The array is not modified.</param>
    /// <param name="attention">
    /// Attention per candidate (at least <c>candidates.Length</c> entries).
    /// </param>
    /// <param name="bandwidth">Radius shared by clustering, density and NMS.</param>
    /// <param name="densityThreshold">
    /// Minimum share of the total density (exclusive) a shifted point needs to survive.
    /// </param>
    /// <returns>
    /// The extracted joints; empty when no candidate passes the attention filter or
    /// the total density is not positive.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="candidates"/> or <paramref name="attention"/> is null.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="attention"/> has fewer entries than <paramref name="candidates"/>.
    /// </exception>
    public static Vector3[] ExtractJoints(
        Vector3[] candidates, float[] attention, float bandwidth, float densityThreshold )
    {
        ArgumentNullException.ThrowIfNull( candidates );
        ArgumentNullException.ThrowIfNull( attention );
        if ( attention.Length < candidates.Length )
            throw new ArgumentException(
                $"Expected at least {candidates.Length} attention values (one per candidate) but got {attention.Length}.",
                nameof( attention ) );

        // Attention filter.
        var kept = new List<Vector3>();
        var keptAttention = new List<float>();
        for ( var i = 0; i < candidates.Length; i++ )
        {
            if ( attention[i] > MinAttention )
            {
                kept.Add( candidates[i] );
                keptAttention.Add( attention[i] );
            }
        }

        // Symmetrize across the YZ plane (x → -x), attention tiled: originals fill the
        // first half, their mirror images the second half.
        var symmetrized = new Vector3[kept.Count * 2];
        var weights = new float[kept.Count * 2];
        for ( var i = 0; i < kept.Count; i++ )
        {
            symmetrized[i] = kept[i];
            symmetrized[kept.Count + i] = new Vector3( -kept[i].X, kept[i].Y, kept[i].Z );
            weights[i] = keptAttention[i];
            weights[kept.Count + i] = keptAttention[i];
        }

        var shifted = Cluster( symmetrized, bandwidth, weights );
        var density = Density( shifted, bandwidth );
        var densitySum = density.Sum();
        // Also covers the "nothing survived the attention filter" case (empty sum is 0)
        // and protects the division below.
        if ( densitySum <= 0f )
            return Array.Empty<Vector3>();

        var filteredPoints = new List<Vector3>();
        var filteredDensity = new List<float>();
        for ( var i = 0; i < shifted.Length; i++ )
        {
            if ( density[i] / densitySum > densityThreshold )
            {
                filteredPoints.Add( shifted[i] );
                filteredDensity.Add( density[i] );
            }
        }
        return Nms( filteredPoints.ToArray(), filteredDensity.ToArray(), bandwidth );
    }
}