Code/Components/GaussianSplatRenderer.cs
namespace Sandbox;
/// <summary>
/// Renders a Gaussian splat point cloud loaded from a PLY or SOG file.
/// All GaussianSplatRenderer instances in a scene share a single
/// <see cref="SceneGaussianSplatSystem"/> for globally-sorted rendering.
/// </summary>
[Title( "Gaussian Splat Renderer" )]
[Category( "Rendering" )]
[Icon( "blur_on" )]
public sealed class GaussianSplatRenderer : Component, Component.ExecuteInEditor
{
	SceneGaussianSplatObject _so;
	SceneGaussianSplatSystem _system;

	/// <summary>
	/// Relative path to the .ply or .sog file containing the Gaussian splat data.
	/// </summary>
	[Property, Order( -100 )]
	public string FilePath { get; set; }

	/// <summary>
	/// Size of each splat billboard in world units (inches).
	/// </summary>
	[Property, Group( "Appearance" ), Range( 0.1f, 20.0f )]
	public float SplatSize { get; set; } = 1.0f;

	/// <summary>
	/// Minimum opacity threshold; splats below this are discarded at load time.
	/// Changing it triggers a reload.
	/// </summary>
	[Property, Group( "Appearance" ), Range( 0f, 1f )]
	public float MinOpacity { get; set; } = 1f / 255f;

	/// <summary>
	/// When enabled, splats receive lighting from scene lights using normals
	/// derived from the covariance ellipsoid.
	/// </summary>
	[Property, Group( "Appearance" )]
	public bool ReceiveLighting { get; set; } = false;

	/// <summary>
	/// When enabled, splats receive shadow darkening without full lighting re-coloring.
	/// This preserves the original baked look while adding dynamic shadows on top.
	/// Ignored if ReceiveLighting is enabled (full lighting already includes shadows).
	/// </summary>
	[Property, Group( "Appearance" ), HideIf( nameof( ReceiveLighting ), true )]
	public bool ReceiveShadows { get; set; } = false;

	/// <summary>
	/// Color used to tint shadowed regions. Allows matching the shadow color baked into
	/// the original splat data. Darker = stronger shadow, lighter = subtler effect.
	/// </summary>
	[Property, Group( "Appearance" ), HideIf( nameof( ReceiveLighting ), true ), ShowIf( nameof( ReceiveShadows ), true )]
	public Color ShadowTint { get; set; } = new Color( 0.1f, 0.1f, 0.15f, 1.0f );

	/// <summary>
	/// Color tint multiplied with each splat's color. White means no tinting.
	/// </summary>
	[Property, Group( "Appearance" )]
	public Color Tint { get; set; } = Color.White;

	/// <summary>
	/// When enabled, splats are progressively culled based on distance from the camera
	/// using the LOD curve to control the falloff.
	/// </summary>
	[Property, Group( "LOD" )]
	public bool EnableLOD { get; set; } = false;

	/// <summary>
	/// Maximum distance (in world units) for the LOD system.
	/// At this distance or beyond, the curve evaluates at its end (t=1).
	/// </summary>
	[Property, Group( "LOD" ), Range( 100f, 50000f )]
	public float LODMaxDistance { get; set; } = 10000f;

	/// <summary>
	/// Controls what fraction of splats are kept at each normalized distance (0=camera, 1=max distance).
	/// Y-axis: 1 = keep all splats, 0 = cull all splats.
	/// </summary>
	[Property, Group( "LOD" )]
	public Curve LODCurve { get; set; } = new Curve( new Curve.Frame( 0f, 1f ), new Curve.Frame( 1f, 0f ) );

	/// <summary>
	/// When enabled, the splat cloud is divided into spatial chunks and LOD is evaluated
	/// per-chunk rather than per-object. This gives much better results for large environment
	/// scans where different parts of the cloud are at different distances from the camera.
	/// </summary>
	[Property, Group( "LOD" )]
	public bool EnableChunkedLOD { get; set; } = false;

	/// <summary>
	/// Size of each spatial chunk in world units (inches). Splats are assigned to chunks
	/// based on their local position at load time. Smaller chunks = finer-grained LOD
	/// but more GPU overhead for the per-chunk LOD pass.
	/// </summary>
	[Property, Group( "LOD" ), Range( 50f, 5000f )]
	public float ChunkSize { get; set; } = 500f;

	/// <summary>
	/// Chunks whose center is above this local height (Y-axis) are exempt from LOD
	/// and always render at full density. Useful for scanned skies/ceilings.
	/// 0 = disabled (no height exemption).
	/// </summary>
	[Property, Group( "LOD" )]
	public float ChunkHeightExemption { get; set; } = 0f;

	/// <summary>
	/// Chunks whose center is beyond this distance from the object origin are exempt
	/// from LOD and always render at full density. Useful for distant backdrop geometry.
	/// 0 = disabled (no distance exemption).
	/// </summary>
	[Property, Group( "LOD" )]
	public float ChunkDistanceExemption { get; set; } = 0f;

	/// <summary>
	/// Scale applied when importing positions and splat scales.
	/// PLY files from 3DGS training are typically in meters, 39.3701 converts to inches.
	/// </summary>
	private const float ImportScale = 39.3701f;

	/// <summary>
	/// Number of evenly spaced LOD curve samples written to the scene object.
	/// </summary>
	private const int LODCurveSampleCount = 8;

	/// <summary>
	/// Cached processed splat data to avoid re-reading files on play mode transitions.
	/// Keyed by (file path, min opacity). <see cref="ImportScale"/> is a compile-time
	/// constant, so it is not part of the key. Entries are evicted when the last component
	/// holding a reference to them releases it. Entries created by <see cref="Preload"/>
	/// that no component ever acquires stay cached.
	/// </summary>
	private static readonly Dictionary<(string File, float MinOpacity), CachedSplatData> _splatCache = new();

	/// <summary>
	/// Tracks how many live components hold each cache key.
	/// When a key's reference count drops to zero, the cache entry is evicted.
	/// </summary>
	private static readonly Dictionary<(string File, float MinOpacity), int> _cacheRefCounts = new();

	private record struct CachedSplatData(
		SceneGaussianSplatObject.SplatPosition[] Positions,
		SceneGaussianSplatObject.SplatData[] Data,
		float[] Opacities,
		int VisibleCount,
		int TotalCount,
		bool HasCovariance
	);

	/// <summary>
	/// The cache key this component instance currently holds a reference to.
	/// Used for ref-counting so we can evict unused entries.
	/// </summary>
	private (string File, float MinOpacity)? _activeCacheKey;

	// File path / opacity threshold the current scene object data was loaded with.
	// Compared against the live properties every frame to detect when a reload is needed.
	private string _loadedFile;
	private float _loadedMinOpacity;

	/// <summary>
	/// Pre-load a splat file into the CPU cache without creating any GPU resources or scene objects.
	/// Call during loading screens so that future instantiations of prefabs using this file are instant.
	/// Returns the number of visible splats (after opacity filtering) so you can pass it to ReserveCapacity.
	/// Returns 0 for an empty path or when loading fails (the failure is logged).
	/// Note: once a component has used the entry, it is evicted when the last such component is destroyed.
	/// </summary>
	public static int Preload( string filePath, float minOpacity = 1f / 255f )
	{
		if ( string.IsNullOrWhiteSpace( filePath ) )
			return 0;

		try
		{
			// LoadFileIntoCache returns the existing entry directly on a cache hit.
			return LoadFileIntoCache( filePath, minOpacity ).VisibleCount;
		}
		catch ( Exception ex )
		{
			Log.Error( $"GaussianSplatRenderer.Preload: Failed to load '{filePath}': {ex.Message}" );
			return 0;
		}
	}

	/// <summary>
	/// Pre-load a splat file and reserve GPU buffer capacity so that instantiation is completely lag-free.
	/// Combines <see cref="Preload"/> with <see cref="SceneGaussianSplatSystem.ReserveCapacity"/>.
	/// </summary>
	public static void PreloadAndReserve( SceneWorld sceneWorld, string filePath, float minOpacity = 1f / 255f )
	{
		PreloadAndReserve( sceneWorld, filePath, 1, minOpacity );
	}

	/// <summary>
	/// Pre-load a splat file and reserve GPU buffer capacity for multiple concurrent instances.
	/// Use this when you spawn N copies of the same splat at runtime (e.g., blood particles).
	/// Reserves <paramref name="maxInstances"/> × splatCount capacity in one allocation.
	/// Nothing is reserved if the file has no visible splats, <paramref name="maxInstances"/> is not
	/// positive, or the total would exceed <see cref="int.MaxValue"/> (logged as an error).
	/// </summary>
	public static void PreloadAndReserve( SceneWorld sceneWorld, string filePath, int maxInstances, float minOpacity = 1f / 255f )
	{
		int splatCount = Preload( filePath, minOpacity );
		if ( splatCount <= 0 || maxInstances <= 0 )
			return;

		// Multiply in 64-bit: large clouds times many instances can exceed int.MaxValue,
		// which would otherwise wrap around to a negative capacity.
		long capacity = (long)splatCount * maxInstances;
		if ( capacity > int.MaxValue )
		{
			Log.Error( $"GaussianSplatRenderer.PreloadAndReserve: Requested capacity {capacity} for '{filePath}' exceeds {int.MaxValue}; skipping reservation." );
			return;
		}

		var system = SceneGaussianSplatSystem.GetOrCreate( sceneWorld );
		system.ReserveCapacity( (int)capacity );
	}

	/// <summary>
	/// Load a file from disk, parse it, and store in the static cache. Returns the cached data.
	/// On a cache hit the stored entry is returned without touching disk.
	/// Shared by both Preload() and the normal LoadFromPly() path.
	/// Exceptions from opening or parsing the file propagate to the caller.
	/// </summary>
	private static CachedSplatData LoadFileIntoCache( string filePath, float minOpacity )
	{
		var cacheKey = (filePath, minOpacity);

		// Cache hit: skip disk IO and parsing entirely.
		if ( _splatCache.TryGetValue( cacheKey, out var existing ) )
			return existing;

		using var stream = FileSystem.Mounted.OpenRead( filePath );

		var plyData = filePath.EndsWith( ".sog", StringComparison.OrdinalIgnoreCase )
			? SogReader.Read( stream )
			: PlyReader.Read( stream );

		if ( plyData.VertexCount == 0 )
		{
			// Empty arrays rather than null: consumers already handle zero-length arrays
			// (an opacity filter that rejects every splat produces them too).
			var empty = new CachedSplatData(
				Array.Empty<SceneGaussianSplatObject.SplatPosition>(),
				Array.Empty<SceneGaussianSplatObject.SplatData>(),
				Array.Empty<float>(),
				0, 0, false );
			_splatCache[cacheKey] = empty;
			return empty;
		}

		var scale = ImportScale;
		var hasCovariance = plyData.HasScaleRotation;

		// Pass 1: count survivors so the output arrays can be allocated exactly once.
		int visibleCount = 0;
		for ( int i = 0; i < plyData.VertexCount; i++ )
		{
			float alpha = plyData.HasOpacity ? plyData.Opacities[i] : 1f;
			if ( alpha >= minOpacity )
				visibleCount++;
		}

		var splatPositions = new SceneGaussianSplatObject.SplatPosition[visibleCount];
		var splatData = new SceneGaussianSplatObject.SplatData[visibleCount];
		var opacities = new float[visibleCount];
		int writeIdx = 0;

		// Pass 2: convert and write the survivors.
		for ( int i = 0; i < plyData.VertexCount; i++ )
		{
			float alpha = plyData.HasOpacity ? plyData.Opacities[i] : 1f;

			// Must be the exact negation of the pass-1 test. `alpha < minOpacity` is NOT
			// equivalent when either value is NaN: the splat would be skipped in the count but
			// written here, overrunning the arrays.
			if ( !(alpha >= minOpacity) )
				continue;

			var p = plyData.Positions[i];
			var c = plyData.HasColors ? plyData.Colors[i] : Color.White;

			splatPositions[writeIdx] = new SceneGaussianSplatObject.SplatPosition
			{
				Position = new Vector3( -p.x * scale, -p.z * scale, -p.y * scale )
			};

			var dataEntry = new SceneGaussianSplatObject.SplatData
			{
				PackedColor = PackColorRGBA8( c.r, c.g, c.b, alpha )
			};

			if ( hasCovariance )
			{
				var (covA, covB) = ComputeSplatCovariance( plyData.Scales[i], plyData.Rotations[i], scale );
				splatPositions[writeIdx].CovarianceScale = PackCovariance( covA, covB, ref dataEntry );
			}

			splatData[writeIdx] = dataEntry;
			opacities[writeIdx] = alpha;
			writeIdx++;
		}

		var cached = new CachedSplatData( splatPositions, splatData, opacities, visibleCount, plyData.VertexCount, hasCovariance );
		_splatCache[cacheKey] = cached;

		Log.Info( $"GaussianSplatRenderer: Cached {visibleCount} splats from '{filePath}'" );
		return cached;
	}

	/// <summary>
	/// Normalize a covariance upper triangle by its largest absolute component and pack it into
	/// <paramref name="entry"/> as three half2 words (CovAB0..CovAB2).
	/// Returns the normalization factor (1 when every component is zero). The caller stores it in
	/// SplatPosition.CovarianceScale so the original magnitude can be recovered.
	/// </summary>
	private static float PackCovariance( Vector3 covA, Vector3 covB, ref SceneGaussianSplatObject.SplatData entry )
	{
		// Float16 has a 10-bit mantissa, so values near 1.0 get ~0.001 precision.
		// Without normalization, large covariance values (e.g. big splats after import scale)
		// can exceed the float16 range (65504), and small values lose significant digits.
		// Dividing by the max magnitude maps everything to [-1,1] where float16 precision is best.
		float maxCov = MathF.Max(
			MathF.Max( MathF.Abs( covA.x ), MathF.Abs( covA.y ) ),
			MathF.Max( MathF.Abs( covA.z ),
				MathF.Max( MathF.Abs( covB.x ),
					MathF.Max( MathF.Abs( covB.y ), MathF.Abs( covB.z ) ) ) )
		);

		float covScale = maxCov > 0f ? maxCov : 1f;
		float invScale = 1f / covScale;

		entry.CovAB0 = PackHalf2( covA.x * invScale, covA.y * invScale );
		entry.CovAB1 = PackHalf2( covA.z * invScale, covB.x * invScale );
		entry.CovAB2 = PackHalf2( covB.y * invScale, covB.z * invScale );

		return covScale;
	}

	private int _splatCount;

	/// <summary>
	/// The number of splats currently loaded.
	/// </summary>
	[Property, Group( "Info" ), ReadOnly]
	public int SplatCount => _splatCount;

	protected override void OnEnabled()
	{
		if ( Application.IsHeadless )
			return;

		_system = SceneGaussianSplatSystem.GetOrCreate( Scene.SceneWorld );

		// Reuse existing scene object if we have one (fast re-enable without buffer re-pack).
		// Only create a new one on first enable or after the previous one became invalid.
		if ( _so is null || !_so.IsValid() )
		{
			_so = new SceneGaussianSplatObject( _system );
			_loadedFile = null; // force OnPreRender to push data into the new object
		}

		_so.IsActive = true;
		_so.Transform = WorldTransform;
		_so.Tags.SetFrom( Tags );
	}

	protected override void OnDisabled()
	{
		// Mark inactive instead of destroying — the object stays in the GPU buffer
		// but the cull shader skips its splats. No layout re-pack triggered.
		if ( _so is not null )
			_so.IsActive = false;
	}

	protected override void OnDestroy()
	{
		// Permanent destruction — release the scene object and cache reference.
		ReleaseCacheRef();
		_so?.Destroy();
		_so = null;
		_system = null;
		_loadedFile = null;
		_splatCount = 0;
	}

	/// <summary>
	/// Release this component's reference to its current cache entry.
	/// Evicts the entry when no other component references it.
	/// </summary>
	private void ReleaseCacheRef()
	{
		if ( _activeCacheKey is not { } key )
			return;

		_activeCacheKey = null;

		if ( !_cacheRefCounts.TryGetValue( key, out int count ) )
			return;

		if ( count <= 1 )
		{
			_cacheRefCounts.Remove( key );
			_splatCache.Remove( key );
		}
		else
		{
			_cacheRefCounts[key] = count - 1;
		}
	}

	/// <summary>
	/// Add a reference to a cache entry for this component.
	/// Releases any previous reference to a different key first.
	/// </summary>
	private void AcquireCacheRef( (string File, float MinOpacity) key )
	{
		// Already holding this key (e.g. a reload after OnEnabled recreated the scene object).
		// Releasing first would drop the count to zero and evict the very entry being used,
		// leaving a ref count with no backing cache entry.
		if ( _activeCacheKey is { } held && held.Equals( key ) )
			return;

		ReleaseCacheRef();
		_activeCacheKey = key;
		_cacheRefCounts[key] = _cacheRefCounts.GetValueOrDefault( key ) + 1;
	}

	protected override void OnPreRender()
	{
		if ( _so is null || !_so.IsValid() )
			return;

		// Flush expired GPU buffer disposals on the main thread.
		// RenderSceneObject runs on the render thread where Dispose is forbidden,
		// so this is the main-thread hook that ensures deferred disposals complete.
		_system?.FlushPendingDisposals();

		_so.Transform = WorldTransform;
		_so.SplatSize = SplatSize;
		_so.ReceiveLighting = ReceiveLighting;
		_so.ReceiveShadows = ReceiveShadows;
		_so.ShadowTint = ShadowTint;
		_so.Tint = Tint;
		_so.LODMaxDistance = EnableLOD ? LODMaxDistance : 0f;
		_so.EnableChunkedLOD = EnableLOD && EnableChunkedLOD;
		_so.ChunkSize = ChunkSize;
		_so.ChunkHeightExemption = ChunkHeightExemption;
		_so.ChunkDistanceExemption = ChunkDistanceExemption;
		SampleLODCurve();
		UpdateChunkAssignment();
		_so.Tags.SetFrom( Tags );

		// Reload if the file path or opacity threshold (the cache key) changed.
		if ( _loadedFile != FilePath || _loadedMinOpacity != MinOpacity )
		{
			LoadPlyFile();
		}
	}

	/// <summary>
	/// Sample the LOD curve at evenly spaced points over [0, 1], clamp each sample to [0, 1],
	/// and write them to the scene object. Does nothing when LOD is disabled.
	/// </summary>
	private void SampleLODCurve()
	{
		if ( !EnableLOD )
			return;

		var samples = _so.LODCurveSamples;
		for ( int i = 0; i < LODCurveSampleCount; i++ )
		{
			float t = i / (float)(LODCurveSampleCount - 1);
			samples[i] = LODCurve.Evaluate( t ).Clamp( 0f, 1f );
		}
	}

	// Chunk parameters used for the last ComputeChunks call, for change detection.
	private float _lastChunkSize;
	private float _lastChunkHeightExemption;
	private float _lastChunkDistanceExemption;
	private bool _lastChunkedLOD;

	/// <summary>
	/// Recompute chunk assignments if chunked LOD settings have changed, or clear them
	/// when chunked LOD has been turned off.
	/// </summary>
	/// <param name="force">Recompute even if no parameter changed (e.g. after new splat data was set).</param>
	private void UpdateChunkAssignment( bool force = false )
	{
		if ( _so is null ) return;

		bool needsChunks = EnableLOD && EnableChunkedLOD;

		if ( needsChunks )
		{
			bool paramsChanged =
				!_lastChunkedLOD ||
				_lastChunkSize != ChunkSize ||
				_lastChunkHeightExemption != ChunkHeightExemption ||
				_lastChunkDistanceExemption != ChunkDistanceExemption;

			if ( !force && !paramsChanged && _so.ChunkIds is not null )
				return;

			_so.ChunkSize = ChunkSize;
			_so.ChunkHeightExemption = ChunkHeightExemption;
			_so.ChunkDistanceExemption = ChunkDistanceExemption;
			_so.ComputeChunks();

			_lastChunkSize = ChunkSize;
			_lastChunkHeightExemption = ChunkHeightExemption;
			_lastChunkDistanceExemption = ChunkDistanceExemption;
			_lastChunkedLOD = true;
		}
		else if ( _lastChunkedLOD )
		{
			_so.ClearChunks();
			_lastChunkedLOD = false;
		}
	}

	/// <summary>
	/// Record the current file/opacity as loaded and push matching data to the scene object.
	/// Failures are logged and leave the component rendering nothing (no retry until the
	/// file path or threshold changes again).
	/// </summary>
	private void LoadPlyFile()
	{
		_loadedFile = FilePath;
		_loadedMinOpacity = MinOpacity;
		_splatCount = 0;

		if ( string.IsNullOrWhiteSpace( FilePath ) )
		{
			ClearLoadedData();
			return;
		}

		try
		{
			LoadFromPly();
		}
		catch ( Exception ex )
		{
			Log.Error( $"GaussianSplatRenderer: Failed to load '{FilePath}': {ex.Message}" );
			ClearLoadedData();
		}
	}

	/// <summary>
	/// Drop everything tied to the previously loaded data: the cache reference, the scene
	/// object's splats, the collider, and the reported splat count.
	/// </summary>
	private void ClearLoadedData()
	{
		ReleaseCacheRef();
		_splatCount = 0;
		_so.SetSplatData( null, null, 0 );
		GetComponent<GaussianSplatCollider>()?.Clear();
	}

	/// <summary>
	/// Fetch (or load) the cached data for the current file and push it to the scene object
	/// and collider. Exceptions propagate to <see cref="LoadPlyFile"/>.
	/// </summary>
	private void LoadFromPly()
	{
		var cacheKey = (FilePath, MinOpacity);

		// LoadFileIntoCache handles both cache-hit and disk-read paths
		var cached = LoadFileIntoCache( FilePath, MinOpacity );
		AcquireCacheRef( cacheKey );

		_splatCount = cached.VisibleCount;

		_so.SetSplatData( cached.Positions, cached.Data, cached.VisibleCount, cached.HasCovariance );

		// Chunks are assigned from splat positions, so any assignment computed before this
		// point (e.g. over the previous or empty data earlier in OnPreRender) is stale.
		UpdateChunkAssignment( force: true );

		GetComponent<GaussianSplatCollider>()?.BuildFromPositions( cached.Positions, cached.Opacities );

		Log.Info( $"GaussianSplatRenderer: Loaded {cached.VisibleCount} splats for '{FilePath}'" );
	}

	/// <summary>
	/// Pack two float values into a single uint as a pair of float16 values.
	/// <paramref name="a"/> occupies the low 16 bits, <paramref name="b"/> the high 16 bits.
	/// </summary>
	private static uint PackHalf2( float a, float b )
	{
		ushort ha = BitConverter.HalfToUInt16Bits( (Half)a );
		ushort hb = BitConverter.HalfToUInt16Bits( (Half)b );
		return ha | ((uint)hb << 16);
	}

	/// <summary>
	/// Pack RGBA color (0..1 float per channel) into a single uint as RGBA8.
	/// Channels are clamped to [0, 1] and rounded; R is the lowest byte, A the highest.
	/// </summary>
	private static uint PackColorRGBA8( float r, float g, float b, float a )
	{
		uint rb = (uint)(Math.Clamp( r, 0f, 1f ) * 255f + 0.5f);
		uint gb = (uint)(Math.Clamp( g, 0f, 1f ) * 255f + 0.5f);
		uint bb = (uint)(Math.Clamp( b, 0f, 1f ) * 255f + 0.5f);
		uint ab = (uint)(Math.Clamp( a, 0f, 1f ) * 255f + 0.5f);
		return rb | (gb << 8) | (bb << 16) | (ab << 24);
	}

	/// <summary>
	/// Compute the 3D covariance matrix from Gaussian scale and rotation, transformed to engine coordinates.
	/// Returns the upper triangle: CovA = (Σ00, Σ01, Σ02), CovB = (Σ11, Σ12, Σ22).
	/// </summary>
	/// <param name="splatScale">Per-axis scale (already exp'd by the reader, per the original note).</param>
	/// <param name="rotation">Quaternion stored as (w, x, y, z) in x..w; need not be unit length.</param>
	/// <param name="importScale">Unit conversion applied to the scales.</param>
	private static (Vector3 CovA, Vector3 CovB) ComputeSplatCovariance( Vector3 splatScale, Vector4 rotation, float importScale )
	{
		// Scale values are already exp'd by PlyReader. Apply import scale.
		float sx = splatScale.x * importScale;
		float sy = splatScale.y * importScale;
		float sz = splatScale.z * importScale;

		// Quaternion components: rotation stored as (w, x, y, z) in the Vector4
		float qw = rotation.x, qx = rotation.y, qy = rotation.z, qz = rotation.w;

		// Raw 3DGS rotations are not guaranteed to be unit length (the reference implementation
		// normalizes at render time). Normalize so R is orthonormal; a no-op for unit input.
		// A zero quaternion is left as-is, which the formulas below turn into the identity.
		float lenSq = qw * qw + qx * qx + qy * qy + qz * qz;
		if ( lenSq > 0f )
		{
			float invLen = 1f / MathF.Sqrt( lenSq );
			qw *= invLen;
			qx *= invLen;
			qy *= invLen;
			qz *= invLen;
		}

		// Build rotation matrix from quaternion (R[row, col])
		float r00 = 1 - 2 * (qy * qy + qz * qz);
		float r10 = 2 * (qx * qy + qw * qz);
		float r20 = 2 * (qx * qz - qw * qy);
		float r01 = 2 * (qx * qy - qw * qz);
		float r11 = 1 - 2 * (qx * qx + qz * qz);
		float r21 = 2 * (qy * qz + qw * qx);
		float r02 = 2 * (qx * qz + qw * qy);
		float r12 = 2 * (qy * qz - qw * qx);
		float r22 = 1 - 2 * (qx * qx + qy * qy);

		// M = CoordTransform * R * diag(scale)
		// CoordTransform = [-1,0,0; 0,0,-1; 0,-1,0], i.e. (x, y, z) -> (-x, -z, -y):
		// the same axis remap applied to splat positions in LoadFileIntoCache.
		float m00 = -r00 * sx, m01 = -r01 * sy, m02 = -r02 * sz;
		float m10 = -r20 * sx, m11 = -r21 * sy, m12 = -r22 * sz;
		float m20 = -r10 * sx, m21 = -r11 * sy, m22 = -r12 * sz;

		// Covariance = M * M^T (symmetric 3x3)
		return (
			new Vector3(
				m00 * m00 + m01 * m01 + m02 * m02,
				m00 * m10 + m01 * m11 + m02 * m12,
				m00 * m20 + m01 * m21 + m02 * m22
			),
			new Vector3(
				m10 * m10 + m11 * m11 + m12 * m12,
				m10 * m20 + m11 * m21 + m12 * m22,
				m20 * m20 + m21 * m21 + m22 * m22
			)
		);
	}
}