Rides/TrackRides/TrackSection.cs
using System;
using System.Collections.Immutable;
using System.Text.Json.Serialization;
using HC3.Terrain;

namespace HC3.Rides;

#nullable enable

/// <summary>
/// Scene-wide notifications raised by <see cref="TrackSection"/> components.
/// </summary>
public interface ITrackEvent : ISceneEvent<ITrackEvent>
{
	/// <summary>
	/// Posted to everything in the scene whenever <paramref name="track"/> is rebuilt,
	/// moved, or destroyed.
	/// </summary>
	/// <param name="track">The track section that changed.</param>
	void Changed( TrackSection track );
}

/// <summary>
/// A continuous connected piece of track.
/// Needs a <see cref="TrackDefinition"/>, and a <see cref="TrackMesh"/> component on the same object to be visible.
/// </summary>
public sealed class TrackSection : Component, Component.ExecuteInEditor, Component.INetworkSnapshot, ITrackSection
{
	/// <summary>
	/// Serializable description of a track section: a start node plus the elements appended to it in order.
	/// </summary>
	public sealed record Model( TrackNode Start, ImmutableArray<TrackElement> Elements );

	/// <summary>
	/// Track grid resolution. Half of the main grid resolution.
	/// </summary>
	public const int GridSize = GridManager.GridSize / 2;

	/// <summary>
	/// Smallest height increment for track.
	/// </summary>
	public const int HeightStep = GridSize / 2;

	/// <summary>
	/// Grid cells further than this from the rails (or from the train clearance point above them)
	/// are not recorded in <see cref="_closestTrack"/>.
	/// </summary>
	private const float ObstructionDistance = 16f;

	/// <summary>
	/// Extra height above the rails that is also measured, to make sure there's space for the train.
	/// </summary>
	private const float TrainClearance = 12f;

	private bool _splineInvalid = true;
	private readonly Spline _spline = new();
	private BBox _localBounds;

	/// <summary>
	/// Distance along <see cref="_spline"/> of each entry in <see cref="_nodes"/>, filled during <see cref="UpdateSpline"/>.
	/// </summary>
	private readonly List<float> _nodeDistances = new();

	private readonly record struct ClosestTrack( int ElementIndex, float DistanceSq );

	/// <summary>
	/// For each nearby grid cell, the index of the closest element and its squared distance.
	/// Rebuilt in <see cref="UpdateClosestTrack"/>.
	/// </summary>
	private readonly Dictionary<Vector3Int, ClosestTrack> _closestTrack = new();

	private TrackDefinition? _trackDef;

	/// <summary>
	/// Describes the position and orientation of track between each element.
	/// Always has exactly 1 more item than <see cref="_elements"/>.
	/// </summary>
	private readonly List<TrackNode> _nodes = new()
	{
		new TrackNode( new Vector3Int( -1, 0, 0 ), new TrackOrientation( TrackHeading.Forward ) ),
		new TrackNode( new Vector3Int( 1, 0, 0 ), new TrackOrientation( TrackHeading.Forward ) )
	};

	/// <summary>
	/// Describes what track is doing between each node.
	/// Must always have exactly 1 item less than <see cref="_nodes"/>.
	/// </summary>
	private readonly List<TrackElement> _elements = new()
	{
		TrackElementDefinition.Station.GetElement(
			0, TrackBanking.None, TrackBanking.None )
		?? throw new Exception( "Can't find station track element!" )
	};

	/// <inheritdoc/>
	public IReadOnlyList<TrackNode> Nodes => _nodes;

	/// <inheritdoc/>
	public IReadOnlyList<TrackElement> Elements => _elements;

	/// <summary>
	/// Snapshot of this track section to be serialized.
	/// Assigning replaces the whole track; if an element can't follow the previous node,
	/// an error is logged and the track keeps only the elements before it.
	/// </summary>
	[Property, JsonInclude, Hide]
	public Model Serialized
	{
		get => new( _nodes[0], _elements.ToImmutableArray() );
		set
		{
			// SetElements validates every element against the previous node and always
			// invalidates the spline, even when it has to stop early.
			if ( !SetElements( value.Start, value.Elements ) )
			{
				Log.Error( "Unable to deserialize track section!" );
			}
		}
	}

	/// <summary>
	/// Is this track a closed loop? A track with no elements is never considered a loop,
	/// even though its single node trivially equals itself.
	/// </summary>
	public bool IsCycle => _elements.Count > 0 && _nodes[0] == _nodes[^1];

	/// <summary>
	/// Describes how to generate a model for this track section.
	/// </summary>
	[Property]
	public TrackDefinition? TrackDefinition
	{
		get => _trackDef;
		set
		{
			_trackDef = value;
			InvalidateSpline();
		}
	}

	/// <summary>
	/// Total length of this track section.
	/// </summary>
	public float Length => Spline.Length;

	/// <summary>
	/// Describes the curve of this track. Will always be up to date when the track changes.
	/// </summary>
	private Spline Spline
	{
		get
		{
			UpdateSpline();
			return _spline;
		}
	}

	/// <summary>
	/// Bounds encompassing all the track, local to this game object.
	/// </summary>
	public BBox LocalBounds
	{
		get
		{
			UpdateSpline();
			return _localBounds;
		}
	}

	/// <summary>
	/// Bounds encompassing all the track, in world space.
	/// </summary>
	public BBox WorldBounds => LocalBounds.Transform( WorldTransform );

	/// <summary>
	/// Force the track geometry to rebuild next time it is accessed.
	/// </summary>
	[Button( "Rebuild" )]
	private void InvalidateSpline()
	{
		_splineInvalid = true;
	}

	protected override void OnValidate()
	{
		InvalidateSpline();
	}

	protected override void OnEnabled()
	{
		InvalidateSpline();

		Transform.OnTransformChanged += OnTransformChanged;
	}

	protected override void OnDisabled()
	{
		Transform.OnTransformChanged -= OnTransformChanged;
	}

	protected override void OnUpdate()
	{
		UpdateSpline();
	}

	protected override void OnDestroy()
	{
		Transform.OnTransformChanged -= OnTransformChanged;

		ITrackEvent.Post( x => x.Changed( this ) );
	}

	/// <summary>
	/// Moving the track moves its world-space geometry, so listeners are told it changed.
	/// </summary>
	private void OnTransformChanged()
	{
		ITrackEvent.Post( x => x.Changed( this ) );
	}

	/// <summary>
	/// If the elements of this track have changed, rebuild the spline,
	/// bounds and collision grid, then post <see cref="ITrackEvent.Changed"/>.
	/// Does nothing while the cached geometry is still valid.
	/// </summary>
	private void UpdateSpline()
	{
		if ( !_splineInvalid ) return;

		// Cleared before rebuilding so that the Spline getter, used indirectly by
		// UpdateClosestTrack via SampleAtDistance, doesn't recurse back into here.
		_splineInvalid = false;

		_spline.Clear();
		_nodeDistances.Clear();

		_spline.AddTrack( _nodes[0], _elements, _nodeDistances );
		_localBounds = _spline.Bounds.Grow( 32f );

		UpdateClosestTrack();

		ITrackEvent.Post( x => x.Changed( this ) );
	}

	/// <summary>
	/// Update a 3D grid of distances to the nearest track piece, for quick collision detection.
	/// Only cells within <see cref="ObstructionDistance"/> of a sample are recorded.
	/// </summary>
	private void UpdateClosestTrack()
	{
		_closestTrack.Clear();

		const float gridScale = 1f / GridSize;
		const float sampleStep = GridSize * 0.25f;
		const float maxDistSq = ObstructionDistance * ObstructionDistance;

		// We'll measure distances from both the track level, and a clearance
		// height above the track to make sure there's space for the train

		var railHeight = TrackDefinition?.RailHeight ?? 0f;
		var clearanceHeight = railHeight + TrainClearance;

		for ( var i = 0; i < _elements.Count; ++i )
		{
			var prevDist = _nodeDistances[i];
			var nextDist = _nodeDistances[i + 1];

			for ( var t = prevDist + sampleStep; t < nextDist; t += sampleStep )
			{
				var sample = SampleAtDistance( t );
				var trackPos = sample.Position + sample.Up * railHeight;
				var clearancePos = sample.Position + sample.Up * clearanceHeight;

				// Limit to grid cells within ObstructionDistance of the track

				var minGridPos = (Vector3.Min( trackPos, clearancePos ) - ObstructionDistance) * gridScale;
				var maxGridPos = (Vector3.Max( trackPos, clearancePos ) + ObstructionDistance) * gridScale;

				var min = minGridPos.FloorToInt();
				var max = maxGridPos.CeilToInt();

				for ( var z = min.z; z <= max.z; ++z )
					for ( var y = min.y; y <= max.y; ++y )
						for ( var x = min.x; x <= max.x; ++x )
						{
							var gridPos = new Vector3Int( x, y, z );
							var cellMin = gridPos * GridSize;
							var cellMax = cellMin + GridSize;

							// Find the closest distance from this cell, either to the track or
							// train clearance position above the track

							var closestTrackPos = Vector3.Clamp( trackPos, cellMin, cellMax );
							var closestClearancePos = Vector3.Clamp( clearancePos, cellMin, cellMax );

							var distanceSq = Math.Min(
								(closestTrackPos - trackPos).LengthSquared,
								(closestClearancePos - clearancePos).LengthSquared );

							if ( distanceSq > maxDistSq ) continue;

							// We might already have a measurement for that cell from another part of
							// the track, let's use the closest

							if ( _closestTrack.TryGetValue( gridPos, out var closest ) && closest.DistanceSq <= distanceSq )
							{
								continue;
							}

							_closestTrack[gridPos] = new ClosestTrack( i, distanceSq );
						}
			}
		}
	}

	/// <summary>
	/// When selected, draws both rails (green for a loop, red otherwise), the height of each node,
	/// and the collision grid cells coloured by distance to the track.
	/// </summary>
	protected override void DrawGizmos()
	{
		if ( !Gizmo.IsSelected ) return;
		if ( Spline is not { PointCount: > 1 } ) return;

		Gizmo.Draw.Color = IsCycle ? Color.Green : Color.Red;

		DrawRailGizmo( -16f );
		DrawRailGizmo( 16f );

		Gizmo.Draw.Color = Color.White;

		foreach ( var node in _nodes )
		{
			var pos = node.Position * new Vector3( GridSize, GridSize, HeightStep );
			Gizmo.Draw.Text( $"{node.Position.z * 0.5f:F1}m", new Transform( pos ) );
		}

		foreach ( var (gridPos, closest) in _closestTrack )
		{
			var min = gridPos * GridSize;
			var max = (gridPos + 1) * GridSize;

			Gizmo.Draw.Color = Color.Lerp( Color.Red, Color.Green, MathF.Sqrt( closest.DistanceSq ) / ObstructionDistance );
			Gizmo.Draw.LineBBox( new BBox( min, max ).Grow( -1f ) );
		}
	}

	/// <summary>
	/// Draws one rail as a polyline sampled every 8 units, offset sideways by <paramref name="y"/>.
	/// </summary>
	private void DrawRailGizmo( float y )
	{
		const float step = 8f;

		var z = TrackDefinition?.RailHeight ?? 0f;
		var prev = GetRailGizmoPosition( SampleAtDistance( 0f ), y, z );

		for ( var x = step; x <= Length; x += step )
		{
			var next = GetRailGizmoPosition( SampleAtDistance( x ), y, z );

			Gizmo.Draw.Line( prev, next );

			prev = next;
		}

		// Close the final partial segment up to the exact end of the track
		Gizmo.Draw.Line( prev, GetRailGizmoPosition( SampleAtDistance( Length ), y, z ) );
	}

	/// <summary>
	/// Offsets a track sample <paramref name="y"/> units to the right and <paramref name="z"/> units up.
	/// </summary>
	private Vector3 GetRailGizmoPosition( in Transform transform, float y, float z )
	{
		return transform.Position + transform.Up * z + transform.Right * y;
	}

	/// <summary>
	/// Add an element to one end of this section. This performs any necessary checks,
	/// then broadcasts the actual build as an RPC.
	/// </summary>
	/// <returns>
	/// False if the track is a closed loop, <paramref name="buildKind"/> isn't
	/// <see cref="TrackBuildKind.Append"/> (the only kind supported so far), or the element
	/// can't follow the last node.
	/// </returns>
	public bool Build( TrackElement element, TrackBuildKind buildKind = TrackBuildKind.Append )
	{
		if ( IsCycle ) return false;

		// Prepending isn't implemented yet, see BuildCore
		if ( buildKind != TrackBuildKind.Append ) return false;

		if ( _nodes[^1].Append( element ) is not { } node ) return false;

		BuildCore( node, element, buildKind );

		InvalidateSpline();
		return true;
	}

	/// <summary>
	/// RPC for actually building a track element after validating it in <see cref="Build"/>.
	/// </summary>
	[Rpc.Broadcast( NetFlags.Reliable | NetFlags.HostOnly )]
	private void BuildCore( TrackNode node, TrackElement.RpcSafe element, TrackBuildKind buildKind )
	{
		if ( buildKind != TrackBuildKind.Append ) throw new NotImplementedException();

		_nodes.Add( node );
		_elements.Add( element );

		InvalidateSpline();
	}

	/// <summary>
	/// Remove the last element from one end of this section. This performs any necessary checks,
	/// then broadcasts the actual demolition as an RPC.
	/// </summary>
	/// <returns>False if this is the only remaining element.</returns>
	public bool Demolish( TrackBuildKind buildKind )
	{
		// TODO: remove whole section if this is the last element

		if ( _elements.Count <= 1 ) return false;

		DemolishCore( buildKind );
		return true;
	}

	/// <summary>
	/// RPC for actually demolishing a track element after validating it in <see cref="Demolish"/>.
	/// </summary>
	[Rpc.Broadcast( NetFlags.Reliable | NetFlags.HostOnly )]
	private void DemolishCore( TrackBuildKind buildKind )
	{
		if ( buildKind != TrackBuildKind.Append ) throw new NotImplementedException();

		_nodes.RemoveAt( _nodes.Count - 1 );
		_elements.RemoveAt( _elements.Count - 1 );

		InvalidateSpline();
	}

	/// <inheritdoc/>
	public float GetDistanceAtNode( int nodeIndex )
	{
		UpdateSpline();
		return _nodeDistances[nodeIndex];
	}

	/// <summary>
	/// Get the track / train car transform a given <paramref name="distance"/> along the track. 
	/// </summary>
	public Transform SampleAtDistance( float distance ) =>
		this.SampleAtDistance( TrackDefinition, Spline, distance );

	bool ITrackSection.HasObstruction( Vector3Int trackGridPos, IReadOnlySet<TrackNode>? ignoreNodes )
	{
		// Make sure the collision grid matches the current elements: a stale grid could miss
		// newly built track, or hold an ElementIndex that indexes past the end of _nodes
		// right after a demolish.
		UpdateSpline();

		if ( !_closestTrack.TryGetValue( trackGridPos, out var closest ) ) return false;

		if ( ignoreNodes?.Contains( _nodes[closest.ElementIndex + 0] ) is true ) return false;
		if ( ignoreNodes?.Contains( _nodes[closest.ElementIndex + 1] ) is true ) return false;

		return true;
	}

	/// <summary>
	/// Replaces the contents of this track with a new list of <paramref name="elements"/> from a given <paramref name="start"/> node.
	/// </summary>
	/// <returns>
	/// False if an element couldn't follow the previous node. The track then keeps only the
	/// elements before the failing one.
	/// </returns>
	public bool SetElements( TrackNode start, IEnumerable<TrackElement> elements )
	{
		// Materialize first, in case the caller passed a view of this track's own elements,
		// which the Clear() below would otherwise empty before we read them.
		var newElements = elements.ToArray();

		_nodes.Clear();
		_elements.Clear();

		_nodes.Add( start );

		var success = true;

		foreach ( var element in newElements )
		{
			if ( _nodes[^1].Append( element ) is not { } next )
			{
				success = false;
				break;
			}

			_nodes.Add( next );
			_elements.Add( element );
		}

		// The previous contents are gone either way, so the cached geometry must be rebuilt
		InvalidateSpline();

		return success;
	}

	/// <summary>
	/// Convert a local track grid position to a world position.
	/// </summary>
	public Vector3 TrackGridToWorld( Vector3 trackGridPos )
	{
		var localPos = trackGridPos * new Vector3( GridSize, GridSize, HeightStep );
		return WorldTransform.PointToWorld( localPos );
	}

	/// <summary>
	/// Convert a world position to a local track grid position.
	/// </summary>
	public Vector3 WorldToTrackGrid( Vector3 worldPos )
	{
		var localPos = WorldTransform.PointToLocal( worldPos );
		return localPos / new Vector3( GridSize, GridSize, HeightStep );
	}

	protected override void OnFixedUpdate()
	{
		base.OnFixedUpdate();

		// Train physics are only simulated on the host
		if ( !Networking.IsHost ) return;

		SimulateTrains();
	}

	/// <summary>
	/// Temporary list to let us iterate trains on child objects during <see cref="SimulateTrains"/>.
	/// </summary>
	private readonly List<Train> _trains = new();

	/// <summary>
	/// Perform train physics simulation, including gravity and collisions.
	/// </summary>
	private void SimulateTrains()
	{
		_trains.Clear();
		_trains.AddRange( GetComponentsInChildren<Train>().Where( x => x.Mass > 0f ) );

		if ( _trains.Count == 0 ) return;

		_trains.Sort( ( a, b ) => a.TrackPosition.CompareTo( b.TrackPosition ) );

		// Link each train to the one ahead of it. The last train wraps around to the first;
		// a lone train has no next train.
		// NOTE(review): the wrap also happens on non-cyclic track — confirm that's intended.
		var prev = _trains[^1];

		foreach ( var next in _trains )
		{
			prev.NextTrain = next == prev ? null : next;
			prev = next;
		}

		var dt = Time.Delta * (DayNightController.Instance?.CurrentTimeScale ?? 1f);

		foreach ( var train in _trains )
		{
			train.PrePhysicsUpdate( dt );
		}

		SimulateCollisions( _trains );

		foreach ( var train in _trains )
		{
			train.PostPhysicsUpdate( dt );
		}
	}

	/// <summary>
	/// Test and respond to collisions between neighbouring trains, bounce trains off the
	/// start of non-cyclic track, and derail cars that run off its end.
	/// Assumes <paramref name="trains"/> is sorted by track position.
	/// </summary>
	private void SimulateCollisions( IReadOnlyList<Train> trains )
	{
		if ( trains.Count == 0 ) return;

		// On a loop the first train also neighbours the last one
		var prev = IsCycle ? trains[^1] : null;

		ITrackSection section = this;

		// Only need to test for collisions between neighbouring trains in the list

		foreach ( var next in trains )
		{
			if ( prev is null || prev == next )
			{
				prev = next;
				continue;
			}

			// We normalize to deal with cyclic track

			var delta = section.NormalizeDelta( next.TrackPosition - next.Length - prev.TrackPosition );

			if ( delta <= 0f && delta > -prev.Length )
			{
				// Make sure trains aren't overlapping

				prev.TrackPosition += delta * 0.5f;
				next.TrackPosition -= delta * 0.5f;

				// Conservation of momentum

				var reaction = prev.Momentum - next.Momentum;

				prev.Velocity -= reaction * 0.5f / prev.Mass;
				next.Velocity += reaction * 0.5f / next.Mass;
			}

			// Update NextTrainDistance, used in station logic

			prev.NextTrainDistance = Math.Max( delta, 0f );

			prev = next;
		}

		// Don't need to handle derailing if track is a cycle

		if ( IsCycle ) return;

		// TODO: derail when going backwards off the start of the track too?

		// For now, bounce back off the start of the track, losing half the speed
		var first = trains[0];

		if ( first.Velocity < 0f && first.TrackPosition < first.Length )
		{
			first.TrackPosition = first.Length;
			first.Velocity *= -0.5f;
		}

		// Derail when going off the end of the track, detaching cars one at a time
		// into temporary physics objects

		var last = trains[^1];

		while ( last.Velocity > 0f && last.TrackPosition > Length && last.Consist.Count > 0 )
		{
			var head = last.Consist.First();

			head.GameObject.SetParent( null );

			last.InvalidateConsist();
			last.TrackPosition -= head.Length;

			head.AddComponent<TemporaryEffect>().DestroyAfterSeconds = 5f;

			var rb = head.AddComponent<Rigidbody>();

			rb.Velocity = head.WorldRotation.Forward * last.Velocity;

			if ( last.Consist.Count == 0 )
			{
				last.DestroyGameObject();
			}
		}
	}

	void INetworkSnapshot.WriteSnapshot( ref ByteStream writer )
	{
		// Snapshot only needs to contain the start node, and list of elements.
		// Each element is written as (definition resource path, flags, end banking),
		// matching what ReadSnapshot reads back.

		writer.Write( Elements.Count );
		writer.Write( Nodes[0] );

		foreach ( var element in Elements )
		{
			writer.Write( element.Definition.ResourcePath );
			writer.Write( element.Flags );
			writer.Write( element.EndBanking );
		}
	}

	void INetworkSnapshot.ReadSnapshot( ref ByteStream reader )
	{
		var elementCount = reader.Read<int>();
		var start = reader.Read<TrackNode>();
		var prev = start;

		var elements = new TrackElement[elementCount];

		for ( var i = 0; i < elementCount; ++i )
		{
			var def = ResourceLibrary.Get<TrackElementDefinition>( reader.Read<string>() )
				?? throw new Exception( "Unknown track element definition." );

			var flags = reader.Read<TrackElementFlags>();
			var banking = reader.Read<TrackBanking>();

			// Start banking isn't serialized; it's implied by the previous node's orientation
			var element = def.GetElement( flags, prev.Orientation.Banking, banking )
				?? throw new Exception( "Invalid track element." );

			elements[i] = element;
			prev = prev.Append( element )
				?? throw new Exception( "Elements are incompatible." );
		}

		SetElements( start, elements );
	}
}