Editor/SboxMcpServer.cs
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Editor;
using Sandbox;

namespace SboxMcpServer;

/// <summary>
/// HTTP/SSE transport layer for the MCP server.
/// Responsible for accepting connections, managing SSE sessions, and routing
/// raw HTTP requests. Tool logic lives in SceneToolHandlers / ConsoleToolHandlers,
/// dispatch in RpcDispatcher, schemas in SceneToolDefinitions / ConsoleToolDefinitions,
/// and scene helpers in SceneQueryHelpers.
/// </summary>
public static class McpServer
{
	[ConVar( "mcp_server_port", ConVarFlags.Saved )]
	public static int Port { get; set; } = 8098;

	[DllImport("winmm.dll", EntryPoint = "PlaySound", CharSet = CharSet.Auto)]
	private static extern bool PlaySystemSound(string pszSound, IntPtr hmod, uint fdwSound);

	// ── GUI events & state ─────────────────────────────────────────────────
	public static event Action OnServerStateChanged;
	public static event Action<string> OnLogMessage;
	public static bool IsRunning   => _listener != null && _listener.IsListening;
	public static int  SessionCount => _sessions.Count;

	/// <summary>Safely dispatches an action to the main thread, swallowing exceptions if the thread is gone (e.g. during cancel).</summary>
	private static async void SafeRunOnMainThread( Action action )
	{
		try
		{
			await GameTask.RunInThreadAsync( action );
		}
		catch { /* Main thread unavailable (cancelled / disposed) – safe to ignore */ }
	}

	private static void LogInfo( string msg )
	{
		SafeRunOnMainThread( () =>
		{
			Log.Info( msg );
			OnLogMessage?.Invoke( msg );
		} );
	}

	private static void LogError( string msg )
	{
		SafeRunOnMainThread( () =>
		{
			Log.Error( msg );
			OnLogMessage?.Invoke( $"[ERROR] {msg}" );
		} );
	}

	private static void LogWarning( string msg )
	{
		SafeRunOnMainThread( () =>
		{
			Log.Warning( msg );
			OnLogMessage?.Invoke( $"[WARNING] {msg}" );
		} );
	}

	private static void NotifyStateChanged()
	{
		SafeRunOnMainThread( () => OnServerStateChanged?.Invoke() );
	}

	// ── Internal state ─────────────────────────────────────────────────────
	private static HttpListener _listener;
	private static CancellationTokenSource _cts;
	private static readonly ConcurrentDictionary<string, McpSession> _sessions = new();

	// Tracks in-flight RPC tasks so StopServer can wait for them to finish
	private static readonly ConcurrentDictionary<Guid, Task> _inflightTasks = new();

	/// <summary>Resolves the library root folder from the source file path (works in S&box's sandbox where Assembly.Location is null).</summary>
	private static string GetLibraryRoot( [CallerFilePath] string sourceFile = "" )
	{
		// sourceFile = .../Libraries/sbox_mcp/Editor/SboxMcpServer.cs  →  go up two levels to get .../Libraries/sbox_mcp
		return Path.GetFullPath( Path.Combine( Path.GetDirectoryName( sourceFile ), ".." ) );
	}

	internal static readonly JsonSerializerOptions JsonOptions = new()
	{
		PropertyNamingPolicy   = JsonNamingPolicy.CamelCase,
		DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
		WriteIndented          = false
	};

	// ── Lifecycle ──────────────────────────────────────────────────────────

	public static void StartServer()
	{
		if ( _listener != null && _listener.IsListening )
		{
			LogInfo( "MCP Server is already running" );
			return;
		}

		try
		{
			_listener = new HttpListener();
			_listener.Prefixes.Add( $"http://localhost:{Port}/" );
			_listener.Prefixes.Add( $"http://127.0.0.1:{Port}/" );
			_listener.Start();

			_cts = new CancellationTokenSource();
			Task.Run( () => ListenLoop( _cts.Token ) );

			LogInfo( $"Started Model Context Protocol Server on port {Port}" );
			NotifyStateChanged();

			// Play the user's startup sound via native Windows API to bypass S&box asset tracking restrictions
			try
			{
				var libRoot = GetLibraryRoot();
				var soundPath = Path.Combine( libRoot, "Sounds", "Startup sound.wav" );
				PlaySystemSound( soundPath, IntPtr.Zero, 0x00020001 );
			}
			catch ( Exception soundEx )
			{
				LogWarning( $"Could not play startup sound: {soundEx.Message}" );
			}
		}
		catch ( Exception ex )
		{
			LogError( $"Failed to start MCP Server: {ex.Message}" );
		}
	}

	public static void StopServer()
	{
		try
		{
			_cts?.Cancel();

			try { _listener?.Stop(); } catch { }
			try { _listener?.Close(); } catch { }
			_listener = null;

			// Don't wait for in-flight tasks — they may be blocked waiting for the
			// main thread (GameTask.RunInThreadAsync) which would deadlock here.
			// Just clear the tracking dictionary; the tasks will see the cancelled token
			// and wind down on their own.
			_inflightTasks.Clear();

			foreach ( var session in _sessions.Values )
			{
				try { session.Tcs.TrySetResult( true ); } catch { }
				try { session.SseResponse?.Close(); } catch { }
			}
			_sessions.Clear();

			LogInfo( "Stopped Model Context Protocol Server" );
			NotifyStateChanged();
		}
		catch ( Exception ex )
		{
			try { LogError( $"Error stopping MCP Server: {ex.Message}" ); } catch { }
		}
	}

	// ── HTTP listen loop ───────────────────────────────────────────────────

	private static async Task ListenLoop( CancellationToken token )
	{
		while ( !token.IsCancellationRequested && _listener != null && _listener.IsListening )
		{
			try
			{
				var context = await _listener.GetContextAsync();
				_ = Task.Run( () => HandleContext( context ), token );
			}
			catch ( Exception ex ) when ( ex is not ObjectDisposedException )
			{
				LogError( $"Error in MCP listen loop: {ex.Message}" );
			}
		}
	}

	private static async Task HandleContext( HttpListenerContext context )
	{
		var req = context.Request;
		var res = context.Response;

		res.Headers.Add( "Access-Control-Allow-Origin",  "*" );
		res.Headers.Add( "Access-Control-Allow-Methods", "GET, POST, OPTIONS" );
		res.Headers.Add( "Access-Control-Allow-Headers", "*" );

		if ( req.HttpMethod == "OPTIONS" ) { res.StatusCode = 200; res.Close(); return; }

		try
		{
			if      ( req.Url.AbsolutePath == "/sse"     && req.HttpMethod == "GET"  ) await HandleSse( req, res );
			else if ( req.Url.AbsolutePath == "/message" && req.HttpMethod == "POST" ) await HandleMessage( req, res );
			else    { res.StatusCode = 404; res.Close(); }
		}
		catch ( Exception ex )
		{
			LogError( $"Error handling MCP request: {ex.Message}" );
			res.StatusCode = 500;
			res.Close();
		}
	}

	// ── SSE connection ─────────────────────────────────────────────────────

	private static async Task HandleSse( HttpListenerRequest req, HttpListenerResponse res )
	{
		var sessionId = Guid.NewGuid().ToString();
		var session   = new McpSession { SessionId = sessionId, SseResponse = res };
		_sessions[sessionId] = session;

		res.ContentType = "text/event-stream";
		res.Headers.Add( "Cache-Control", "no-cache" );
		res.Headers.Add( "Connection",    "keep-alive" );

		try
		{
			var msg    = $"event: endpoint\ndata: /message?sessionId={sessionId}\n\n";
			var buffer = Encoding.UTF8.GetBytes( msg );
			await res.OutputStream.WriteAsync( buffer, 0, buffer.Length );
			await res.OutputStream.FlushAsync();

			LogInfo( $"Created new MCP SSE session: {sessionId}" );
			NotifyStateChanged();

			await session.Tcs.Task; // keep alive until closed
		}
		catch ( Exception ex ) { LogError( $"SSE connection error: {ex.Message}" ); }
		finally
		{
			_sessions.TryRemove( sessionId, out _ );
			try { res.Close(); } catch { }
			LogInfo( $"Closed MCP SSE session: {sessionId}" );
			NotifyStateChanged();
		}
	}

	// ── Message (JSON-RPC) ─────────────────────────────────────────────────

	private static async Task HandleMessage( HttpListenerRequest req, HttpListenerResponse res )
	{
		var sessionId = req.QueryString["sessionId"];
		if ( string.IsNullOrEmpty( sessionId ) || !_sessions.TryGetValue( sessionId, out var session ) )
		{
			res.StatusCode = 400; res.Close(); return;
		}

		using var reader = new StreamReader( req.InputStream, Encoding.UTF8 );
		var body = await reader.ReadToEndAsync();

		try
		{
			using var doc = JsonDocument.Parse( body );
			var root      = doc.RootElement;
			string method = root.TryGetProperty( "method", out var m ) ? m.GetString() : null;
			object id     = null;

			if ( root.TryGetProperty( "id", out var idProp ) )
			{
				if      ( idProp.ValueKind == JsonValueKind.Number ) id = idProp.GetInt32();
				else if ( idProp.ValueKind == JsonValueKind.String ) id = idProp.GetString();
			}

			res.StatusCode = 202;
			res.Close();

			if ( id != null )
			{
				var bodyCopy   = body;
				var idCopy     = id;
				var methodCopy = method;
				var taskId     = Guid.NewGuid();
				var task       = GameTask.RunInThreadAsync( async () =>
				{
					try
					{
						await RpcDispatcher.ProcessRpcRequest(
							session, idCopy, methodCopy, bodyCopy,
							JsonOptions, LogInfo, LogError );
					}
					catch ( Exception ex )
					{
						LogError( $"ProcessRpcRequest unhandled fault: {ex.Message}" );
						var errResponse = new
						{
							jsonrpc = "2.0",
							id      = idCopy,
							result  = (object)null,
							error   = new { code = -32603, message = $"Internal error: {ex.Message}" }
						};
						var errJson = JsonSerializer.Serialize( errResponse, JsonOptions );
						await SendSseEvent( session, "message", errJson );
					}
					finally
					{
						_inflightTasks.TryRemove( taskId, out _ );
					}
				} );
				_inflightTasks[taskId] = task;
			}
			else if ( method == "notifications/initialized" )
			{
				session.Initialized = true;
				LogInfo( $"MCP Session {sessionId} initialized." );
				NotifyStateChanged();
			}
		}
		catch ( Exception ex )
		{
			LogError( $"Error parsing JSON-RPC: {ex.Message}" );
		}
	}

	// ── SSE write ──────────────────────────────────────────────────────────

	/// <summary>Writes a single SSE event to the given session's output stream.</summary>
	internal static async Task SendSseEvent( McpSession session, string eventName, string data )
	{
		if ( session.SseResponse == null || !session.SseResponse.OutputStream.CanWrite ) return;
		try
		{
			var msg    = $"event: {eventName}\ndata: {data}\n\n";
			var buffer = Encoding.UTF8.GetBytes( msg );
			await session.SseResponse.OutputStream.WriteAsync( buffer, 0, buffer.Length );
			await session.SseResponse.OutputStream.FlushAsync();
		}
		catch ( Exception ex )
		{
			LogWarning( $"Failed to send SSE event to session {session.SessionId}: {ex.Message}" );
		}
	}
}