AutoRig/Solve/TransferSolver.cs

Solver that transfers a rig and its skin weights from a donor mesh to a target mesh. It normalizes both meshes into a shared unit space, stretches the donor per axis to match the target's proportions, and snaps donor joints inside the target's solid volume (using a voxel grid of the target). It then computes each target vertex's weights by blending the weights of its nearest donor vertices with inverse-squared-distance weighting. Finally, it smooths the weights by one-ring averaging and assembles the resulting RigResult.

File AccessNetworking
using AutoRig.Analyze;
using AutoRig.Dl.RigNet;
using AutoRig.Rig;
using AutoRig.Voxel;

namespace AutoRig.Solve;

using Vector3 = System.Numerics.Vector3;

/// <summary>
/// The transfer solver (spec §5.4): fits a rigged donor's skeleton into the target
/// mesh and carries its skin weights across. Both models are normalized into the
/// same unit space, the donor is stretched per-axis to the target's proportions,
/// joints are snapped into the target's solid, and each target vertex blends the
/// weights of its nearest donor vertices (inverse-squared distance, then one-ring
/// smoothing).
/// </summary>
public static class TransferSolver
{
    /// <summary>How many nearest donor vertices each target vertex blends weights from.</summary>
    const int NearestDonorVertices = 4;

    /// <summary>Fraction of each vertex's weight row replaced by its one-ring mean.</summary>
    const float SmoothingBlend = 0.5f;

    /// <summary>Bone-influence slots per vertex in both the donor and the output weight arrays.</summary>
    const int InfluencesPerVertex = 4;

    /// <summary>Resolution argument handed to <c>VoxelGrid.Build</c> for joint snapping.</summary>
    const int VoxelResolution = 64;

    /// <summary>
    /// Transfers <paramref name="donor"/>'s skeleton and skin weights onto the mesh
    /// described by <paramref name="analysis"/>.
    /// </summary>
    /// <param name="analysis">Analysis of the target; only its <c>Mesh</c> is read.</param>
    /// <param name="donor">Rigged donor supplying the skeleton, mesh and per-vertex weights.</param>
    /// <returns>
    /// A rig whose joints are mapped back through the target's normalization and whose
    /// weights hold up to four influences per target vertex, normalized to sum to one.
    /// </returns>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    /// <exception cref="ArgumentException">The donor skeleton has no joints.</exception>
    public static RigResult Rig( AnalysisResult analysis, DonorRig donor )
    {
        ArgumentNullException.ThrowIfNull( analysis );
        ArgumentNullException.ThrowIfNull( donor );
        var target = analysis.Mesh;

        // Without joints there is nothing to bind to; the fallback below would emit
        // bone index 0 into an empty skeleton.
        var jointCount = donor.Skeleton.Joints.Count;
        if ( jointCount == 0 )
            throw new ArgumentException( "Donor rig has no joints to transfer.", nameof( donor ) );

        // ---- shared unit space + per-axis affine donor → target ----
        var (targetNormalized, targetPivot, targetScale) = RigNetInput.Normalize( target.Positions );
        var (donorNormalized, donorPivot, donorScale) = RigNetInput.Normalize( donor.Mesh.Positions );

        var targetBounds = Mesh.Aabb3.FromPoints( targetNormalized );
        var donorBounds = Mesh.Aabb3.FromPoints( donorNormalized );
        var stretch = new Vector3(
            SafeRatio( targetBounds.Size.X, donorBounds.Size.X ),
            SafeRatio( targetBounds.Size.Y, donorBounds.Size.Y ),
            SafeRatio( targetBounds.Size.Z, donorBounds.Size.Z ) );
        // After stretching, shift so the donor box's min corner lands on the target's.
        var offset = targetBounds.Min - donorBounds.Min * stretch;

        Vector3 Fit( Vector3 donorPoint ) => donorPoint * stretch + offset;

        var donorVertices = new Vector3[donorNormalized.Length];
        for ( var v = 0; v < donorNormalized.Length; v++ )
            donorVertices[v] = Fit( donorNormalized[v] );

        // Donor joints through the same normalize (mesh-derived pivot/scale) + fit.
        var joints = new Vector3[jointCount];
        for ( var j = 0; j < jointCount; j++ )
            joints[j] = Fit( (donor.Skeleton.Joints[j].Position - donorPivot) * donorScale );

        // ---- snap joints into the target's solid (skipped when voxelization fails) ----
        var voxels = BuildVoxels( target, targetNormalized );
        if ( voxels is not null )
        {
            for ( var j = 0; j < jointCount; j++ )
                joints[j] = SnapInside( joints[j], voxels );
        }

        // ---- weights: inverse-squared-distance blend, then one-ring smoothing ----
        var raw = BlendDonorWeights( targetNormalized, donorVertices, donor, jointCount );
        SmoothOneRing( raw, target.Triangles );

        // ---- assemble ----
        var skeleton = new RigSkeleton();
        for ( var j = 0; j < jointCount; j++ )
        {
            skeleton.Joints.Add( new RigJoint
            {
                Name = donor.Skeleton.Joints[j].Name,
                Parent = donor.Skeleton.Joints[j].Parent,
                // Undo the target normalization to return to the target's own space.
                Position = joints[j] / targetScale + targetPivot,
            } );
        }

        var (indices, weights) = PackTopInfluences( raw, joints, targetNormalized );

        return new RigResult
        {
            Skeleton = skeleton,
            Weights = new SkinWeights { BoneIndices = indices, Weights = weights },
            SolverName = "transfer",
            Degraded = false,
            Explanation = $"Transferred {jointCount} joints and their skin weights "
                + $"from '{donor.Mesh.SourceName}'.",
        };
    }

    /// <summary>
    /// Builds one dense row of <paramref name="jointCount"/> weights per target point by
    /// blending the donor weights of its <see cref="NearestDonorVertices"/> nearest fitted
    /// donor vertices, each weighted by inverse squared distance.
    /// </summary>
    static float[][] BlendDonorWeights( Vector3[] targetPoints, Vector3[] donorVertices, DonorRig donor, int jointCount )
    {
        var rows = new float[targetPoints.Length][];
        var nearest = new int[NearestDonorVertices];
        var nearestDistance = new float[NearestDonorVertices];
        for ( var v = 0; v < targetPoints.Length; v++ )
        {
            var count = FindNearest( targetPoints[v], donorVertices, nearest, nearestDistance );

            var row = new float[jointCount];
            float weightSum = 0;
            for ( var k = 0; k < count; k++ )
                weightSum += InverseSquaredWeight( nearestDistance[k] );
            for ( var k = 0; k < count; k++ )
            {
                var blend = InverseSquaredWeight( nearestDistance[k] ) / weightSum;
                var firstSlot = nearest[k] * InfluencesPerVertex;
                for ( var slot = 0; slot < InfluencesPerVertex; slot++ )
                {
                    var weight = donor.Weights.Weights[firstSlot + slot];
                    if ( weight > 0f )
                        row[donor.Weights.BoneIndices[firstSlot + slot]] += blend * weight;
                }
            }
            rows[v] = row;
        }
        return rows;
    }

    /// <summary>
    /// Fills <paramref name="nearest"/>/<paramref name="nearestDistance"/> (unsorted) with
    /// the indices and squared distances of the candidates closest to
    /// <paramref name="point"/>; returns how many slots were filled.
    /// </summary>
    static int FindNearest( Vector3 point, Vector3[] candidates, int[] nearest, float[] nearestDistance )
    {
        var capacity = nearest.Length;
        var count = 0;
        var worst = 0;   // slot holding the largest kept distance once full
        for ( var d = 0; d < candidates.Length; d++ )
        {
            var distance = Vector3.DistanceSquared( point, candidates[d] );
            if ( count < capacity )
            {
                nearest[count] = d;
                nearestDistance[count] = distance;
                count++;
                if ( count == capacity )
                    worst = IndexOfMax( nearestDistance );
            }
            else if ( distance < nearestDistance[worst] )
            {
                nearest[worst] = d;
                nearestDistance[worst] = distance;
                worst = IndexOfMax( nearestDistance );
            }
        }
        return count;
    }

    /// <summary>Index of the first largest element.</summary>
    static int IndexOfMax( float[] values )
    {
        var best = 0;
        for ( var k = 1; k < values.Length; k++ )
            if ( values[k] > values[best] )
                best = k;
        return best;
    }

    /// <summary>1/d² weight from a squared distance, clamped so coincident points stay finite.</summary>
    static float InverseSquaredWeight( float squaredDistance )
        => 1f / MathF.Max( squaredDistance, 1e-12f );

    /// <summary>
    /// Keeps the <see cref="InfluencesPerVertex"/> strongest positive entries of each row and
    /// renormalizes them. A row with no positive weight binds fully to the joint nearest
    /// its point.
    /// </summary>
    static (int[] Indices, float[] Weights) PackTopInfluences( float[][] rows, Vector3[] joints, Vector3[] points )
    {
        var indices = new int[rows.Length * InfluencesPerVertex];
        var weights = new float[rows.Length * InfluencesPerVertex];
        var top = new int[InfluencesPerVertex];
        for ( var v = 0; v < rows.Length; v++ )
        {
            var row = rows[v];
            var count = 0;
            for ( var j = 0; j < row.Length; j++ )
            {
                if ( row[j] <= 0f )
                    continue;
                if ( count < InfluencesPerVertex )
                {
                    top[count++] = j;
                    continue;
                }
                var weakest = 0;
                for ( var k = 1; k < InfluencesPerVertex; k++ )
                    if ( row[top[k]] < row[top[weakest]] )
                        weakest = k;
                if ( row[j] > row[top[weakest]] )
                    top[weakest] = j;
            }

            float total = 0;
            for ( var k = 0; k < count; k++ )
                total += row[top[k]];

            var firstSlot = v * InfluencesPerVertex;
            if ( count == 0 || total <= 0f )
            {
                indices[firstSlot] = NearestJoint( joints, points[v] );
                weights[firstSlot] = 1f;
                continue;
            }
            for ( var k = 0; k < count; k++ )
            {
                indices[firstSlot + k] = top[k];
                weights[firstSlot + k] = row[top[k]] / total;
            }
        }
        return (indices, weights);
    }

    /// <summary>Ratio target/donor, or 1 when the donor extent is (near) zero.</summary>
    static float SafeRatio( float target, float donor )
        => donor > 1e-6f ? target / donor : 1f;

    /// <summary>
    /// Voxelizes a copy of <paramref name="target"/> whose positions are replaced by
    /// <paramref name="normalizedPositions"/>. Returns null when <c>VoxelGrid.Build</c>
    /// throws <see cref="FormatException"/>, so snapping is skipped.
    /// </summary>
    static VoxelGrid BuildVoxels( Mesh.RigMesh target, Vector3[] normalizedPositions )
    {
        try
        {
            var proxy = new Mesh.RigMesh
            {
                SourceName = target.SourceName,
                Positions = normalizedPositions,
                Normals = target.Normals,
                Uvs = target.Uvs,
                Triangles = target.Triangles,
                TriangleTags = target.TriangleTags,
                Tags = target.Tags,
            };
            return VoxelGrid.Build( proxy, VoxelResolution );
        }
        catch ( FormatException )
        {
            return null;   // degenerate target — skip snapping
        }
    }

    /// <summary>
    /// Moves <paramref name="point"/> to the closest solid voxel center found by an
    /// outward shell search. Returns the point unchanged when it is already in a solid
    /// voxel or when no solid voxel is found.
    /// </summary>
    static Vector3 SnapInside( Vector3 point, VoxelGrid voxels )
    {
        var (x, y, z) = voxels.VoxelOf( point );
        if ( voxels.IsSolid( x, y, z ) )
            return point;

        var maxRadius = Math.Max( voxels.SizeX, Math.Max( voxels.SizeY, voxels.SizeZ ) );
        var best = point;
        var bestDistance = float.MaxValue;
        var limit = maxRadius;
        var found = false;
        for ( var radius = 1; radius <= limit; radius++ )
        {
            for ( var dx = -radius; dx <= radius; dx++ )
                for ( var dy = -radius; dy <= radius; dy++ )
                {
                    // If dx or dy already lies on the shell every dz qualifies;
                    // otherwise only the two dz faces (±radius) do.
                    var onShell = Math.Abs( dx ) == radius || Math.Abs( dy ) == radius;
                    var dzStep = onShell ? 1 : 2 * radius;
                    for ( var dz = -radius; dz <= radius; dz += dzStep )
                    {
                        if ( !voxels.IsSolid( x + dx, y + dy, z + dz ) )
                            continue;
                        var center = voxels.CenterOf( x + dx, y + dy, z + dz );
                        var distance = Vector3.DistanceSquared( center, point );
                        if ( distance < bestDistance )
                        {
                            bestDistance = distance;
                            best = center;
                        }
                    }
                }

            if ( !found && bestDistance < float.MaxValue )
            {
                // The first hit is on Chebyshev shell r, which is not necessarily the
                // Euclidean nearest. Such a hit is at most (r + ½)·√3 voxels away, and any
                // center on shell R is at least R − √3/2 away. So shells beyond √3·(r + 1)
                // cannot be closer. Keep scanning up to that bound (+1 margin).
                // NOTE(review): the bound assumes cubic voxels — confirm VoxelGrid's layout.
                found = true;
                limit = Math.Min( maxRadius, (int)MathF.Ceiling( MathF.Sqrt( 3f ) * (radius + 1) ) + 1 );
            }
        }
        return best;
    }

    /// <summary>Blend each vertex row toward the mean of its one-ring neighbours.</summary>
    static void SmoothOneRing( float[][] rows, int[] triangles )
    {
        var vertexCount = rows.Length;
        if ( vertexCount == 0 )
            return;   // empty mesh: nothing to smooth (rows[0] below would throw)

        var neighbors = new HashSet<int>[vertexCount];
        for ( var v = 0; v < vertexCount; v++ )
            neighbors[v] = new HashSet<int>();
        for ( var t = 0; t < triangles.Length; t += 3 )
        {
            int a = triangles[t], b = triangles[t + 1], c = triangles[t + 2];
            Link( neighbors, a, b );
            Link( neighbors, b, c );
            Link( neighbors, c, a );
        }

        var jointCount = rows[0].Length;
        var smoothed = new float[vertexCount][];
        for ( var v = 0; v < vertexCount; v++ )
        {
            var ring = neighbors[v];
            if ( ring.Count == 0 )
            {
                smoothed[v] = rows[v];
                continue;
            }
            var mean = new float[jointCount];
            foreach ( var n in ring )
                for ( var j = 0; j < jointCount; j++ )
                    mean[j] += rows[n][j];
            var row = new float[jointCount];
            for ( var j = 0; j < jointCount; j++ )
                row[j] = rows[v][j] * (1f - SmoothingBlend)
                    + mean[j] / ring.Count * SmoothingBlend;
            smoothed[v] = row;
        }
        Array.Copy( smoothed, rows, vertexCount );
    }

    /// <summary>Records a symmetric adjacency, skipping self-links from degenerate triangles.</summary>
    static void Link( HashSet<int>[] neighbors, int a, int b )
    {
        if ( a == b )
            return;
        neighbors[a].Add( b );
        neighbors[b].Add( a );
    }

    /// <summary>Index of the joint closest to <paramref name="point"/> (first on ties); expects at least one joint.</summary>
    static int NearestJoint( Vector3[] joints, Vector3 point )
    {
        var best = 0;
        for ( var j = 1; j < joints.Length; j++ )
            if ( Vector3.DistanceSquared( joints[j], point )
                < Vector3.DistanceSquared( joints[best], point ) )
                best = j;
        return best;
    }
}