Park/GridManager.cs
using HC3.Terrain;
using System;
using System.Collections;
using System.Collections.Immutable;

namespace HC3;

/// <summary>
/// Describes which <see cref="GridObject"/>s overlap a cell in the world.
/// Also caches, per level, the accumulated path-blocker masks and the precomputed navigation edges of the cell.
/// </summary>
public sealed class GridCell : IEnumerable<GridObject>
{
	// TODO: can support objects at various heights

	/// <summary>
	/// Every <see cref="GridObject"/> currently overlapping this cell.
	/// </summary>
	private readonly HashSet<GridObject> _objects = new();

	/// <summary>
	/// Cached blocked direction masks per level. Recomputed by <see cref="RebuildBlockerCache"/>, which runs
	/// whenever an object is added/removed and may also be invoked by callers after blocker masks change.
	/// </summary>
	private readonly Dictionary<int, BlockerCache> _blockerCaches = new();

	/// <summary>
	/// Precomputed navigation edges per level: maps a level to the 3D grid positions reachable from this cell
	/// at that level. Rebuilt by <see cref="RebuildNavEdges"/> (called from <see cref="DispatchNeighborsChanged"/>).
	/// Intended for pathfinding so walkability does not have to be recomputed per query.
	/// </summary>
	private readonly Dictionary<int, List<Vector3Int>> _navEdges = new();

	/// <summary>
	/// 2D grid coordinate of this cell.
	/// </summary>
	public Vector2Int Position { get; }

	public int X => Position.x;
	public int Y => Position.y;

	/// <summary>
	/// Bitwise union of the grid flags of every object in this cell.
	/// </summary>
	public GridObjectFlags Flags { get; private set; }

	/// <summary>
	/// Does this cell contain something that blocks construction?
	/// </summary>
	public bool IsConstructionBlocked => (Flags & GridObjectFlags.BlocksConstruction) != 0;

	/// <summary>
	/// Does this cell contain nothing at all?
	/// </summary>
	internal bool IsEmpty => _objects.Count == 0;

	internal GridCell( Vector2Int position )
	{
		Position = position;
	}

	/// <summary>
	/// Adds <paramref name="gridObject"/> to this cell, then refreshes <see cref="Flags"/> and the blocker cache.
	/// </summary>
	/// <param name="gridObject">Object now overlapping this cell.</param>
	/// <param name="forceFlagUpdate">
	/// Refresh flags/caches even when the object is already present
	/// (presumably used when an object's flags or blocker masks changed — confirm with callers).
	/// </param>
	/// <returns>True if the caches were refreshed; false if the object was already present and no refresh was forced.</returns>
	internal bool AddInternal( GridObject gridObject, bool forceFlagUpdate )
	{
		if ( !_objects.Add( gridObject ) && !forceFlagUpdate )
		{
			return false;
		}

		UpdateFlags();
		RebuildBlockerCache();
		return true;
	}

	/// <summary>
	/// Removes <paramref name="gridObject"/> from this cell and refreshes <see cref="Flags"/> and the blocker cache.
	/// </summary>
	/// <returns>False if the object was not in this cell (nothing changed).</returns>
	internal bool RemoveInternal( GridObject gridObject )
	{
		if ( !_objects.Remove( gridObject ) )
		{
			return false;
		}

		UpdateFlags();
		RebuildBlockerCache();
		return true;
	}

	/// <summary>
	/// Recomputes <see cref="Flags"/> as the OR of every contained object's grid flags.
	/// </summary>
	private void UpdateFlags()
	{
		Flags = 0;

		foreach ( var obj in _objects )
		{
			Flags |= obj.GridFlags;
		}
	}

	/// <summary>
	/// All components of type <typeparamref name="T"/> on valid objects in this cell, at any level.
	/// </summary>
	public IEnumerable<T> GetComponents<T>() => _objects
		.Where( x => x.IsValid() )
		.SelectMany( x => x.GetComponents<T>() );

	/// <summary>
	/// Check if any valid object in this cell has a component of type <typeparamref name="T"/>.
	/// Stops at the first match and does not build a LINQ pipeline.
	/// </summary>
	public bool HasComponent<T>()
	{
		foreach ( var obj in _objects )
		{
			if ( !obj.IsValid() ) continue;
			if ( obj.GetComponent<T>() is not null ) return true;
		}
		return false;
	}

	/// <summary>
	/// Check if any valid object in this cell has a component of type <typeparamref name="T"/>
	/// matching <paramref name="predicate"/>. Stops at the first match.
	/// </summary>
	public bool HasComponent<T>( Func<T, bool> predicate )
	{
		foreach ( var obj in _objects )
		{
			if ( !obj.IsValid() ) continue;
			foreach ( var comp in obj.GetComponents<T>() )
			{
				if ( predicate( comp ) ) return true;
			}
		}
		return false;
	}

	/// <summary>
	/// Components of type <typeparamref name="T"/> on valid objects whose level is exactly <paramref name="level"/>.
	/// </summary>
	public IEnumerable<T> GetComponents<T>( int level ) => _objects
		.Where( x => x.IsValid() && x.Level == level )
		.SelectMany( x => x.GetComponents<T>() );

	/// <summary>
	/// Components of type <typeparamref name="T"/> on valid objects whose level is within
	/// <paramref name="heightSpan"/> levels of <paramref name="level"/> (the exact level always matches).
	/// </summary>
	public IEnumerable<T> GetComponents<T>( int level, int heightSpan ) => _objects
		.Where( x => x.IsValid() && (x.Level == level || Math.Abs( x.Level - level ) <= heightSpan) )
		.SelectMany( x => x.GetComponents<T>() );

	/// <summary>First component of type <typeparamref name="T"/> in this cell, or default.</summary>
	public T GetComponent<T>() => GetComponents<T>().FirstOrDefault();

	/// <summary>First component of type <typeparamref name="T"/> at exactly <paramref name="level"/>, or default.</summary>
	public T GetComponent<T>( int level ) => GetComponents<T>( level ).FirstOrDefault();

	public IEnumerator<GridObject> GetEnumerator() => _objects.GetEnumerator();
	IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

	/// <summary>
	/// True if any valid <see cref="GridObject"/> at exactly <paramref name="level"/> reports itself walkable.
	/// </summary>
	public bool IsWalkable( int level ) => GetComponents<GridObject>( level ).Any( x => x.IsWalkable );

	/// <summary>
	/// Gets the cached blocked directions mask for a given level.
	/// Returns an empty (zero) mask if no blockers exist at this level.
	/// </summary>
	public PathMask GetBlockedDirections( int level )
	{
		return _blockerCaches.TryGetValue( level, out var cache ) ? cache.BlockedDirections : 0;
	}

	/// <summary>
	/// Gets the cached blocked navigation mask for a given level.
	/// Returns an empty (zero) mask if no blockers exist at this level.
	/// </summary>
	public PathMask GetBlockedNavigation( int level )
	{
		return _blockerCaches.TryGetValue( level, out var cache ) ? cache.BlockedNavigation : 0;
	}

	/// <summary>
	/// Rebuilds the blocker cache for all levels by iterating all <see cref="IPathBlocker"/> components in this cell.
	/// Called when objects are added/removed or blocker masks change.
	/// </summary>
	internal void RebuildBlockerCache()
	{
		_blockerCaches.Clear();

		foreach ( var blocker in GetComponents<IPathBlocker>() )
		{
			if ( blocker is not Component comp || !comp.IsValid() ) continue;

			// The blocker's level comes from the GridObject on the same GameObject.
			var gridObj = comp.GetComponent<GridObject>();
			if ( gridObj is null || !gridObj.IsValid() ) continue;

			int level = gridObj.Level;

			if ( !_blockerCaches.TryGetValue( level, out var cache ) )
			{
				cache = new BlockerCache();
			}

			// For each cell edge, map it into the blocker's rotated frame and test the blocker's own masks;
			// the cache stores the unrotated cell edge so lookups need no rotation.
			foreach ( var edge in GridManager.AllEdges )
			{
				var mask = edge.ToPathMask();
				var rotatedMask = blocker.GetRotatedMask( mask );

				if ( (blocker.BlockedDirections & rotatedMask) != 0 )
					cache.BlockedDirections |= mask;

				if ( (blocker.BlockedNavigation & rotatedMask) != 0 )
					cache.BlockedNavigation |= mask;
			}

			// BlockerCache is a struct: write the accumulated copy back.
			_blockerCaches[level] = cache;
		}
	}

	/// <summary>
	/// Gets the precomputed navigable neighbors for a given level in this cell.
	/// Returns empty if no edges are cached for this level.
	/// </summary>
	public IReadOnlyList<Vector3Int> GetNavEdges( int level )
	{
		return _navEdges.TryGetValue( level, out var edges ) ? edges : (IReadOnlyList<Vector3Int>)Array.Empty<Vector3Int>();
	}

	/// <summary>
	/// Rebuilds the navigation edge cache for every level that has a valid walkable object in this cell.
	/// An edge is recorded only when <c>GridNavigation</c> reports it walkable in both directions.
	/// Leaves the cache empty when no <c>GridNavigation</c> instance exists.
	/// </summary>
	internal void RebuildNavEdges()
	{
		_navEdges.Clear();

		var nav = GridNavigation.Instance;
		if ( nav is null ) return;

		foreach ( var obj in _objects )
		{
			if ( !obj.IsValid() || !obj.IsWalkable ) continue;

			int level = obj.Level;

			// Several walkable objects may share a level; edges only need computing once per level.
			if ( _navEdges.ContainsKey( level ) ) continue;

			var cellPos = new Vector3Int( Position.x, Position.y, level );
			var edges = new List<Vector3Int>();

			foreach ( var edge in GridManager.AllEdges )
			{
				if ( nav.IsWalkable( cellPos, edge, out var destPos ) &&
					nav.IsWalkable( destPos, edge.GetOpposite(), out _ ) )
				{
					edges.Add( destPos );
				}
			}

			_navEdges[level] = edges;
		}
	}

	/// <summary>
	/// Hashes by <see cref="Position"/>. Equality remains reference equality, which is consistent
	/// because equal references always share a position.
	/// </summary>
	public override int GetHashCode()
	{
		return Position.GetHashCode();
	}

	/// <summary>
	/// Snapshot buffer reused by <see cref="DispatchNeighborsChanged"/>, so event handlers can add/remove objects
	/// from this cell without invalidating the enumeration. It is static (shared by every cell), so dispatch
	/// must not be re-entered from inside a handler.
	/// </summary>
	private static readonly List<GridObject> _objectsCopy = new();

	/// <summary>
	/// Posts <c>NeighborsChanged</c> to every valid object in this cell, then rebuilds this cell's nav edges.
	/// </summary>
	internal void DispatchNeighborsChanged()
	{
		_objectsCopy.Clear();
		_objectsCopy.AddRange( _objects );

		foreach ( var obj in _objectsCopy )
		{
			// A handler earlier in this loop may have destroyed this object.
			if ( !obj.IsValid() ) continue;

			IGridObjectEvent.PostToGameObject( obj.GameObject, x => x.NeighborsChanged( this ) );
		}

		// Don't keep (possibly destroyed) objects referenced from the shared static buffer.
		_objectsCopy.Clear();

		RebuildNavEdges();
	}

	/// <summary>
	/// Options for <see cref="GetConnection"/>.
	/// </summary>
	[Flags]
	public enum ConnectionFlags
	{
		Default = 0,
		RequirePlaced = 1 << 0, // skip unplaced objects
	}

	/// <summary>
	/// Finds the first <see cref="IPathConnector"/> in this cell that can connect <paramref name="grid3d"/>
	/// across <paramref name="edge"/>, or null if none can.
	/// </summary>
	public IPathConnector GetConnection( Vector3Int grid3d, TileEdge edge, ConnectionFlags flags = ConnectionFlags.Default )
	{
		var connectors = GetComponents<IPathConnector>();

		if ( (flags & ConnectionFlags.RequirePlaced) != 0 )
		{
			// Previews are not real connections: placement objects must be placed, ghost paths are skipped,
			// and any other connector is kept.
			connectors = connectors.Where( x => (x as IPlacementObject)?.IsPlaced ?? !(x as Path)?.GhostPreview ?? true );
		}

		return connectors.FirstOrDefault( x => x.CanConnectTile( grid3d, edge ) );
	}
}

/// <summary>
/// Per-level blocker summary for a single grid cell: the OR of every <see cref="IPathBlocker"/> at that level,
/// expressed in the cell's own (unrotated) edge directions.
/// </summary>
public struct BlockerCache
{
	/// <summary>
	/// Edges of the cell across which path connections are blocked at this level.
	/// </summary>
	public PathMask BlockedDirections;

	/// <summary>
	/// Edges of the cell across which navigation is blocked at this level.
	/// </summary>
	public PathMask BlockedNavigation;
}

/// <summary>
/// Scene-wide registry of <see cref="GridCell"/>s. It tracks which <see cref="GridObject"/>s occupy which cells,
/// batches neighbor-change notifications to once per update, and provides world/grid coordinate conversions.
/// </summary>
public sealed partial class GridManager : Component
{
	/// <summary>Horizontal size of one grid cell, in world units.</summary>
	public const int GridSize = 64;

	/// <summary>Height of one vertical grid level, in world units.</summary>
	public const int HeightStep = 32;

	/// <summary>Offset from a cell's minimum corner to its horizontal centre.</summary>
	public static Vector3 CentreOffset = new( GridSize / 2f, GridSize / 2f, 0 );

	/// <summary>The four cardinal cell edges.</summary>
	public static IReadOnlyList<TileEdge> AllEdges { get; } = new[]
	{
		TileEdge.Up, TileEdge.Right, TileEdge.Down, TileEdge.Left
	}.ToImmutableArray();

	public static GridManager Instance { get; private set; }

	[ConVar( "hc3.debug.grid", ConVarFlags.Cheat )]
	public static bool DebugDraw { get; set; } = false;

	[RequireComponent]
	public ParkTerrain Terrain { get; private set; }

	private readonly Dictionary<Vector2Int, GridCell> _cells = new();

	/// <summary>
	/// Positions adjacent to a change since the last update; each existing cell here receives
	/// <see cref="GridCell.DispatchNeighborsChanged"/>, which also rebuilds its nav edges.
	/// </summary>
	private readonly HashSet<Vector2Int> _changedNeighbors = new();

	private readonly List<Vector2Int> _changedNeighborCopy = new();

	/// <summary>
	/// Positions whose own contents changed since the last update. A cell's nav edges depend on its own
	/// walkable objects, not just its neighbors, so these cells get their nav edges rebuilt too.
	/// </summary>
	private readonly HashSet<Vector2Int> _dirtyNavCells = new();

	private readonly List<Vector2Int> _dirtyNavCellCopy = new();

	/// <summary>
	/// Maintained set of all walkable GridObjects. Avoids Scene.GetAll&lt;GridObject&gt;() scans.
	/// </summary>
	private readonly HashSet<GridObject> _walkableObjects = new();

	/// <summary>
	/// Global navigation version. Incremented whenever the walkable graph changes
	/// (paths added/removed, decorations changed, regions recalculated).
	/// Used by agents to cheaply detect when their route may be invalid.
	/// </summary>
	public uint NavigationVersion { get; private set; }

	/// <summary>
	/// Whether region recalculation is pending. Debounced to avoid recalculating every frame.
	/// </summary>
	private bool _regionsDirty;
	private int _regionDirtyFrameCount;
	private const int RegionDebounceFrames = 3;

	protected override void OnAwake()
	{
		Instance = this;
	}

	protected override void OnUpdate()
	{
		if ( _changedNeighbors.Count > 0 || _dirtyNavCells.Count > 0 )
		{
			// Snapshot and clear before dispatching: handlers may add/remove objects, which queues
			// new changes for the next update instead of mutating the sets being iterated.
			_changedNeighborCopy.Clear();
			_changedNeighborCopy.AddRange( _changedNeighbors );
			_changedNeighbors.Clear();

			_dirtyNavCellCopy.Clear();
			_dirtyNavCellCopy.AddRange( _dirtyNavCells );
			_dirtyNavCells.Clear();

			foreach ( var gridPos in _changedNeighborCopy )
			{
				GetCell( gridPos )?.DispatchNeighborsChanged();
			}

			// Neighbor dispatch only rebuilds cells *next to* a change. The changed cell's own outgoing
			// edges (e.g. for a newly added walkable object) would otherwise stay stale, so rebuild them here.
			// Running after dispatch means rebuilds see any changes made by event handlers.
			foreach ( var gridPos in _dirtyNavCellCopy )
			{
				GetCell( gridPos )?.RebuildNavEdges();
			}

			_regionsDirty = true;
			_regionDirtyFrameCount = 0;
			NavigationVersion++;
		}

		// Debounce region recalculation: wait a few frames of stability before recalculating
		if ( _regionsDirty )
		{
			_regionDirtyFrameCount++;
			if ( _regionDirtyFrameCount >= RegionDebounceFrames )
			{
				DirtyRegions();
				_regionsDirty = false;
			}
		}

		if ( !DebugDraw ) return;

		using ( Gizmo.Scope( "gridmanager_debug" ) )
		{
			Gizmo.Transform = new();
			Gizmo.Draw.Color = Color.Red.WithAlpha( 0.1f );
			Gizmo.Draw.LineThickness = 2;

			foreach ( var cell in _cells.Keys )
			{
				var height = GetWorldHeight( cell );
				var min = GridToWorldPosition( cell ).WithZ( height - 64f );
				var max = min + new Vector3( GridSize, GridSize, 128f );
				Gizmo.Draw.LineBBox( new BBox( min, max ) );
			}
		}
	}

	/// <summary>
	/// Every cell currently tracked (cells are dropped when their last object is removed).
	/// </summary>
	public IEnumerable<GridCell> OccupiedCells => _cells.Values;

	/// <summary>
	/// Gets the world height of a cell using the terrain system
	/// </summary>
	/// <returns>The tile's <c>MaxHeight</c> scaled by <c>Terrain.TileHeight</c>, in world units.</returns>
	public float GetWorldHeight( Vector2Int pos )
	{
		var tile = Terrain[pos];
		var height = tile.MaxHeight * Terrain.TileHeight;

		return height;
	}

#nullable enable

	/// <summary>
	/// If the given grid position is occupied, returns information about the occupation.
	/// </summary>
	public GridCell? GetCell( Vector2Int gridpos ) =>
		_cells.GetValueOrDefault( gridpos );

	internal static ImmutableArray<Vector2Int> Directions { get; } =
		new[] { Vector2Int.Left, Vector2Int.Right, Vector2Int.Up, Vector2Int.Down }.ToImmutableArray();

	/// <summary>
	/// The existing cells orthogonally adjacent to <paramref name="gridPos"/>.
	/// </summary>
	public IEnumerable<GridCell> GetNeighbors( Vector2Int gridPos ) =>
		Directions.Select( dir => GetCell( gridPos + dir ) )
			.OfType<GridCell>();

#nullable disable

	public bool IsConstructionBlocked( Vector2Int gridPos ) =>
		GetCell( gridPos )?.IsConstructionBlocked is true;

	/// <summary>
	/// True if any cell in the inclusive rectangle <paramref name="min"/>..<paramref name="max"/> blocks construction.
	/// </summary>
	public bool IsConstructionBlocked( Vector2Int min, Vector2Int max )
	{
		for ( int x = min.x; x <= max.x; x++ )
		{
			for ( int y = min.y; y <= max.y; y++ )
			{
				if ( IsConstructionBlocked( new Vector2Int( x, y ) ) )
					return true;
			}
		}

		return false;
	}

	/// <summary>
	/// True if a construction-blocking object in the cell at <paramref name="pos"/> overlaps the vertical
	/// level range [pos.z, pos.z + height).
	/// </summary>
	public bool IsConstructionBlocked( Vector3Int pos, int height )
	{
		if ( GetCell( new Vector2Int( pos ) ) is not { } cell )
			return false;

		foreach ( var obj in cell )
		{
			if ( !obj.BlocksConstruction ) continue;

			// Half-open interval overlap test on levels.
			if ( obj.Level >= pos.z + height ) continue;
			if ( obj.Level + obj.Height <= pos.z ) continue;

			return true;
		}

		return false;
	}

	/// <summary>
	/// Ensures a cell exists at <paramref name="gridPos"/> and queues it and its neighbors for update.
	/// </summary>
	public void Add( Vector2Int gridPos )
	{
		if ( !_cells.ContainsKey( gridPos ) )
		{
			_cells[gridPos] = new GridCell( gridPos );
		}

		QueueCellChanged( gridPos );
	}

	/// <summary>
	/// Registers <paramref name="obj"/> in the cell at <paramref name="gridPos"/>, creating the cell if needed.
	/// Queues updates only when the cell reports that something changed.
	/// </summary>
	internal void AddInternal( Vector2Int gridPos, GridObject obj, bool forceFlagUpdate )
	{
		if ( !_cells.TryGetValue( gridPos, out var cell ) )
		{
			_cells[gridPos] = cell = new GridCell( gridPos );
		}

		if ( !cell.AddInternal( obj, forceFlagUpdate ) ) return;

		QueueCellChanged( gridPos );
	}

	/// <summary>
	/// Removes <paramref name="obj"/> from the cell at <paramref name="gridPos"/>, dropping the cell once empty.
	/// </summary>
	internal void RemoveInternal( Vector2Int gridPos, GridObject obj )
	{
		if ( !_cells.TryGetValue( gridPos, out var cell ) ) return;

		if ( !cell.RemoveInternal( obj ) ) return;

		if ( cell.IsEmpty )
		{
			_cells.Remove( gridPos );
		}

		QueueCellChanged( gridPos );
	}

	/// <summary>
	/// Queues <paramref name="gridPos"/> for a nav-edge rebuild, and its four neighbors for a
	/// neighbors-changed dispatch, on the next update.
	/// </summary>
	private void QueueCellChanged( Vector2Int gridPos )
	{
		_dirtyNavCells.Add( gridPos );

		foreach ( var direction in Directions )
		{
			_changedNeighbors.Add( gridPos + direction );
		}
	}

	/// <summary>
	/// Snaps a world position down to the minimum corner of its grid cell and height step.
	/// </summary>
	public static Vector3 AlignWorldToGrid( Vector3 worldPosition )
	{
		return new Vector3(
			(int)MathF.Floor( worldPosition.x / GridSize ) * GridSize,
			(int)MathF.Floor( worldPosition.y / GridSize ) * GridSize,
			(int)MathF.Floor( worldPosition.z / HeightStep ) * HeightStep );
	}

	/// <summary>
	/// The 2D cell containing <paramref name="worldPosition"/> (floored, so negative coordinates work).
	/// </summary>
	public static Vector2Int WorldToGridPosition( Vector3 worldPosition )
	{
		var gridSize = GridSize;
		return new Vector2Int( (int)MathF.Floor( worldPosition.x / gridSize ), (int)MathF.Floor( worldPosition.y / gridSize ) );
	}

	/// <summary>
	/// The 3D grid position containing <paramref name="worldPosition"/>; z is floored to <see cref="HeightStep"/> levels.
	/// </summary>
	public static Vector3Int WorldToGridPosition3D( Vector3 worldPosition )
	{
		var gridSize = GridSize;
		return new Vector3Int(
			(int)MathF.Floor( worldPosition.x / gridSize ),
			(int)MathF.Floor( worldPosition.y / gridSize ),
			FloorWorldHeightToGrid( worldPosition.z ) );
	}

	/// <summary>
	/// World-space minimum corner of a 2D cell, at z = 0.
	/// </summary>
	public static Vector3 GridToWorldPosition( Vector2Int gridPosition )
	{
		var gridSize = GridSize;
		var pos = new Vector3( gridPosition.x * gridSize, gridPosition.y * gridSize, 0f );

		return pos;
	}

	/// <summary>
	/// World-space minimum corner of a 3D grid position (z scaled by <see cref="HeightStep"/>).
	/// </summary>
	public static Vector3 GridToWorldPosition( Vector3Int gridPosition )
	{
		var gridSize = GridSize;
		var pos = new Vector3( gridPosition.x * gridSize, gridPosition.y * gridSize, gridPosition.z * HeightStep );

		return pos;
	}

	/// <summary>World height → grid level, rounding down.</summary>
	public static int FloorWorldHeightToGrid( float z ) =>
		(int)MathF.Floor( z / HeightStep );

	/// <summary>World height → grid level, rounding to nearest.</summary>
	public static int RoundWorldHeightToGrid( float z ) =>
		(int)MathF.Round( z / HeightStep );

	/// <summary>
	/// Register a GridObject as walkable. Called when IsWalkable becomes true.
	/// </summary>
	internal void RegisterWalkable( GridObject obj )
	{
		_walkableObjects.Add( obj );
	}

	/// <summary>
	/// Unregister a GridObject as walkable. Called when IsWalkable becomes false or object is destroyed.
	/// </summary>
	internal void UnregisterWalkable( GridObject obj )
	{
		_walkableObjects.Remove( obj );
	}

	/// <summary>
	/// All currently walkable grid objects. Maintained incrementally to avoid Scene.GetAll scans.
	/// </summary>
	public IReadOnlyCollection<GridObject> WalkableObjects => _walkableObjects;
}