Editor/Mcp/ToolRegistry.cs
using Sandbox;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace SboxMcp.Mcp;

/// <summary>
/// Discovers <see cref="McpToolAttribute"/>-decorated static methods via reflection,
/// builds JSON Schemas for their parameters, and dispatches incoming MCP
/// <c>tools/call</c> requests by name.
/// </summary>
/// <remarks>
/// Discovery is lazy (first call to <see cref="List"/> or <c>InvokeAsync</c>) and runs under a
/// lock, so the registry may be used from multiple threads. Once populated it is never mutated.
/// </remarks>
public static class ToolRegistry
{
	private static readonly Dictionary<string, ToolEntry> _tools = new();
	private static readonly List<ToolDescriptor> _descriptors = new();
	private static readonly object _initLock = new();

	// JsonSerializerOptions caches type metadata per instance; reuse one instead of allocating per call.
	private static readonly JsonSerializerOptions _resultJsonOptions = new() { WriteIndented = true };

	// Written last, inside _initLock, so the lock-free fast path only observes `true` once the
	// collections above are fully populated.
	private static volatile bool _initialised;

	private sealed class ToolEntry
	{
		public ToolDescriptor Descriptor { get; init; }
		public MethodInfo Method { get; init; }
		public ParameterInfo[] Parameters { get; init; }
	}

	/// <summary>All registered tools, sorted by name (ordinal comparison).</summary>
	public static IReadOnlyList<ToolDescriptor> List()
	{
		EnsureInitialised();
		// Read-only wrapper so callers cannot cast back to List<T> and mutate the registry.
		return _descriptors.AsReadOnly();
	}

	/// <summary>
	/// Invokes a tool by name without cancellation support. Equivalent to calling the
	/// overload that takes a <see cref="CancellationToken"/> with <see cref="CancellationToken.None"/>.
	/// </summary>
	public static Task<ToolCallResult> InvokeAsync( string name, JsonElement? arguments )
		=> InvokeAsync( name, arguments, CancellationToken.None );

	/// <summary>
	/// Invokes the tool registered under <paramref name="name"/>, binding the JSON
	/// <paramref name="arguments"/> object to the tool's parameters by parameter name.
	/// </summary>
	/// <param name="name">Tool name as advertised by <see cref="List"/>.</param>
	/// <param name="arguments">
	/// The request's <c>arguments</c> value. Anything other than a JSON object is treated as
	/// "no arguments supplied".
	/// </param>
	/// <param name="cancellationToken">Passed to every <see cref="CancellationToken"/> parameter the tool declares.</param>
	/// <returns>
	/// A result whose text content is the tool's formatted return value. Unknown tools, invalid or
	/// missing arguments, and exceptions thrown by the tool are reported with <c>IsError = true</c>.
	/// </returns>
	public static async Task<ToolCallResult> InvokeAsync( string name, JsonElement? arguments, CancellationToken cancellationToken )
	{
		EnsureInitialised();

		// Checked explicitly: TryGetValue throws on a null key, and this lookup sits outside the try below.
		if ( string.IsNullOrEmpty( name ) || !_tools.TryGetValue( name, out var entry ) )
		{
			return ErrorResult( $"Unknown tool: {name}" );
		}

		try
		{
			var args = BuildArguments( entry.Parameters, arguments, cancellationToken );
			var result = entry.Method.Invoke( null, args );
			var resolved = await UnwrapAsync( result );
			return new ToolCallResult
			{
				Content = new List<ToolContent> { ToolContent.FromText( FormatResult( resolved ) ) },
			};
		}
		catch ( Exception ex )
		{
			// Synchronous throws from MethodInfo.Invoke arrive wrapped; report the tool's own exception.
			var inner = ex is TargetInvocationException tie && tie.InnerException is not null ? tie.InnerException : ex;
			Log.Warning( $"[MCP] Tool '{name}' failed: {inner.Message}" );
			return ErrorResult( $"Error: {inner.Message}" );
		}
	}

	/// <summary>Builds an error result carrying a single text message.</summary>
	private static ToolCallResult ErrorResult( string message )
	{
		return new ToolCallResult
		{
			IsError = true,
			Content = new List<ToolContent> { ToolContent.FromText( message ) },
		};
	}

	/// <summary>
	/// Populates the registry on first use. Thread-safe via double-checked locking; if discovery
	/// throws, the flag stays false and the next call retries from a clean state.
	/// </summary>
	private static void EnsureInitialised()
	{
		if ( _initialised ) return;

		lock ( _initLock )
		{
			if ( _initialised ) return;

			// Start clean in case a previous attempt threw part-way through.
			_tools.Clear();
			_descriptors.Clear();

			// Look only at classes tagged [McpToolGroup] — avoids scanning every type in the
			// s&box editor's combined assembly.
			var methods = GetLoadableTypes( typeof( ToolRegistry ).Assembly )
				.Where( t => t.GetCustomAttribute<McpToolGroupAttribute>() is not null )
				.SelectMany( t => t.GetMethods( BindingFlags.Public | BindingFlags.Static ) )
				.Where( m => m.GetCustomAttribute<McpToolAttribute>() is not null );

			foreach ( var m in methods )
			{
				Register( m );
			}

			_descriptors.Sort( ( a, b ) => string.Compare( a.Name, b.Name, StringComparison.Ordinal ) );
			_initialised = true;
			Log.Info( $"[MCP] Tool registry initialised: {_descriptors.Count} tools." );
		}
	}

	/// <summary>
	/// Adds one [McpTool] method to the registry. Methods with an empty name or a name already
	/// taken are skipped with a warning, so <see cref="List"/> never advertises an unreachable tool.
	/// </summary>
	private static void Register( MethodInfo m )
	{
		var attr = m.GetCustomAttribute<McpToolAttribute>();
		var toolName = attr.Name;

		if ( string.IsNullOrEmpty( toolName ) )
		{
			Log.Warning( $"[MCP] Skipping {m.DeclaringType?.Name}.{m.Name}: [McpTool] has no name." );
			return;
		}

		if ( _tools.TryGetValue( toolName, out var existing ) )
		{
			Log.Warning( $"[MCP] Duplicate tool name '{toolName}' on {m.DeclaringType?.Name}.{m.Name}; "
				+ $"keeping {existing.Method.DeclaringType?.Name}.{existing.Method.Name}." );
			return;
		}

		var description = attr.Description ?? GetAttributeDescription( m ) ?? "";
		var parameters = m.GetParameters();

		var descriptor = new ToolDescriptor
		{
			Name = toolName,
			Description = description,
			InputSchema = BuildInputSchema( parameters ),
		};

		_tools[toolName] = new ToolEntry
		{
			Descriptor = descriptor,
			Method = m,
			Parameters = parameters,
		};
		_descriptors.Add( descriptor );
	}

	/// <summary>
	/// Returns the assembly's types. If some fail to load, returns the ones that did instead of
	/// aborting discovery entirely.
	/// </summary>
	private static IEnumerable<Type> GetLoadableTypes( Assembly assembly )
	{
		try
		{
			return assembly.GetTypes();
		}
		catch ( ReflectionTypeLoadException ex )
		{
			Log.Warning( $"[MCP] Some types failed to load; scanning the remainder: {ex.Message}" );
			return ex.Types.Where( t => t is not null );
		}
	}

	/// <summary>
	/// Builds the JSON Schema <c>{ type: "object", properties, required }</c> describing a tool's
	/// parameters. <see cref="CancellationToken"/> parameters are omitted (the registry supplies them).
	/// </summary>
	private static object BuildInputSchema( ParameterInfo[] parameters )
	{
		var properties = new Dictionary<string, object>();
		var required = new List<string>();

		foreach ( var p in parameters )
		{
			if ( p.ParameterType == typeof( CancellationToken ) ) continue;

			var prop = SchemaFor( p.ParameterType );

			var description = p.GetCustomAttribute<ParamDescriptionAttribute>()?.Description
				?? GetAttributeDescription( p )
				?? "";
			if ( !string.IsNullOrEmpty( description ) ) prop["description"] = description;

			properties[p.Name] = prop;

			if ( !p.IsOptional && !IsNullableReferenceLike( p.ParameterType ) )
				required.Add( p.Name );
		}

		var schema = new Dictionary<string, object>
		{
			["type"] = "object",
			["properties"] = properties,
		};
		if ( required.Count > 0 ) schema["required"] = required;
		return schema;
	}

	/// <summary>
	/// JSON Schema fragment for a single parameter type. Adds <c>enum</c> for non-flags enums and
	/// <c>items</c> for collections.
	/// </summary>
	private static Dictionary<string, object> SchemaFor( Type type )
	{
		var t = Nullable.GetUnderlyingType( type ) ?? type;
		var schema = new Dictionary<string, object> { ["type"] = JsonTypeFor( t ) };

		if ( t.IsEnum )
		{
			// Flags enums accept combinations ("A, B"), so listing single names would be too strict.
			if ( !t.IsDefined( typeof( FlagsAttribute ), inherit: false ) )
				schema["enum"] = Enum.GetNames( t );
		}
		else if ( TryGetElementType( t, out var element ) )
		{
			// Some clients reject array schemas without "items". Describe one level only, which
			// also avoids recursing into self-referential collection types.
			schema["items"] = new Dictionary<string, object> { ["type"] = JsonTypeFor( element ) };
		}

		return schema;
	}

	/// <summary>
	/// Find the Description property on any attribute applied to <paramref name="provider"/>.
	/// Avoids the System.ComponentModel.DescriptionAttribute / Sandbox.DescriptionAttribute
	/// namespace tug-of-war by reading whichever was actually applied via reflection.
	/// </summary>
	/// <returns>The first non-empty string Description found, or <c>null</c>.</returns>
	private static string GetAttributeDescription( ICustomAttributeProvider provider )
	{
		foreach ( var attr in provider.GetCustomAttributes( inherit: false ) )
		{
			var prop = attr.GetType().GetProperty( "Description" );
			if ( prop?.GetValue( attr ) is string s && !string.IsNullOrEmpty( s ) ) return s;
		}
		return null;
	}

	/// <summary>
	/// Maps a CLR type to a JSON Schema primitive type name. Unrecognised types fall back to
	/// <c>"string"</c>; <see cref="ConvertJson"/> passes those to <see cref="JsonSerializer"/>.
	/// </summary>
	private static string JsonTypeFor( Type type )
	{
		var t = Nullable.GetUnderlyingType( type ) ?? type;

		if ( t == typeof( string ) ) return "string";
		if ( t == typeof( bool ) ) return "boolean";
		if ( t.IsEnum ) return "string"; // accepted by name (see ConvertJson)
		if ( IsIntegral( t ) ) return "integer"; // "number" would invite 2.5 for an int parameter
		if ( IsFloatingPoint( t ) ) return "number";
		if ( IsDictionary( t ) ) return "object";
		if ( TryGetElementType( t, out _ ) ) return "array";
		return "string"; // safe fallback
	}

	/// <summary>True for the built-in integral numeric types. Enums are excluded even though GetTypeCode reports their underlying type.</summary>
	private static bool IsIntegral( Type t )
		=> !t.IsEnum && (Type.GetTypeCode( t ) is TypeCode.SByte or TypeCode.Byte or TypeCode.Int16
			or TypeCode.UInt16 or TypeCode.Int32 or TypeCode.UInt32 or TypeCode.Int64 or TypeCode.UInt64);

	/// <summary>True for float, double and decimal.</summary>
	private static bool IsFloatingPoint( Type t )
		=> Type.GetTypeCode( t ) is TypeCode.Single or TypeCode.Double or TypeCode.Decimal;

	/// <summary>True if <paramref name="t"/> is a constructed <paramref name="openGeneric"/> (e.g. <c>Task&lt;int&gt;</c> for <c>Task&lt;&gt;</c>).</summary>
	private static bool IsGenericOf( Type t, Type openGeneric )
		=> t.IsGenericType && t.GetGenericTypeDefinition() == openGeneric;

	/// <summary>Returns <paramref name="t"/> itself or the first interface it implements that constructs <paramref name="openInterface"/>, else <c>null</c>.</summary>
	private static Type FindGenericInterface( Type t, Type openInterface )
	{
		if ( IsGenericOf( t, openInterface ) ) return t;
		return t.GetInterfaces().FirstOrDefault( i => IsGenericOf( i, openInterface ) );
	}

	/// <summary>True for generic dictionary types, which serialise as JSON objects rather than arrays.</summary>
	private static bool IsDictionary( Type t )
		=> FindGenericInterface( t, typeof( IDictionary<,> ) ) is not null
			|| FindGenericInterface( t, typeof( IReadOnlyDictionary<,> ) ) is not null;

	/// <summary>
	/// Gets the element type of an array or <c>IEnumerable&lt;T&gt;</c>. Returns false for strings,
	/// dictionaries and non-collections. Unlike an <c>IEnumerable&lt;object&gt;</c> check, this also
	/// recognises collections of value types such as <c>List&lt;int&gt;</c>.
	/// </summary>
	private static bool TryGetElementType( Type t, out Type element )
	{
		element = null;
		if ( t == typeof( string ) || IsDictionary( t ) ) return false;

		if ( t.IsArray )
		{
			element = t.GetElementType();
			return true;
		}

		var enumerable = FindGenericInterface( t, typeof( IEnumerable<> ) );
		if ( enumerable is null ) return false;

		element = enumerable.GetGenericArguments()[0];
		return true;
	}

	/// <summary>True for reference types and <see cref="Nullable{T}"/>; null is an acceptable value for both.</summary>
	private static bool IsNullableReferenceLike( Type t )
	{
		// Treat reference types as optional when the call site doesn't set [Required].
		// We do this conservatively to avoid forcing strings to be present everywhere.
		return !t.IsValueType || Nullable.GetUnderlyingType( t ) is not null;
	}

	/// <summary>
	/// Builds the positional argument array for <see cref="MethodInfo.Invoke(object, object[])"/>.
	/// Each parameter is taken from the JSON object by name, or else from its C# default, or else
	/// set to null (reference/nullable types only).
	/// </summary>
	/// <exception cref="ArgumentException">A required argument is missing or a value cannot be converted.</exception>
	private static object[] BuildArguments( ParameterInfo[] parameters, JsonElement? arguments, CancellationToken cancellationToken )
	{
		var result = new object[parameters.Length];
		var hasArgs = arguments.HasValue && arguments.Value.ValueKind == JsonValueKind.Object;

		for ( var i = 0; i < parameters.Length; i++ )
		{
			var p = parameters[i];

			if ( p.ParameterType == typeof( CancellationToken ) )
			{
				result[i] = cancellationToken;
				continue;
			}

			if ( hasArgs && arguments.Value.TryGetProperty( p.Name, out var prop )
				&& prop.ValueKind != JsonValueKind.Null && prop.ValueKind != JsonValueKind.Undefined )
			{
				result[i] = ConvertArgument( prop, p );
			}
			else if ( p.HasDefaultValue )
			{
				result[i] = p.DefaultValue;
			}
			else if ( IsNullableReferenceLike( p.ParameterType ) )
			{
				result[i] = null;
			}
			else
			{
				throw new ArgumentException( $"Missing required argument: {p.Name}" );
			}
		}

		return result;
	}

	/// <summary>Converts one argument, rethrowing conversion failures with the parameter name attached.</summary>
	private static object ConvertArgument( JsonElement el, ParameterInfo p )
	{
		try
		{
			return ConvertJson( el, p.ParameterType );
		}
		catch ( Exception ex ) when ( ex is FormatException or OverflowException or InvalidOperationException
			or ArgumentException or JsonException or NotSupportedException )
		{
			throw new ArgumentException( $"Invalid value for argument '{p.Name}': {ex.Message}", ex );
		}
	}

	/// <summary>
	/// Converts a JSON value to <paramref name="targetType"/>. Scalars may arrive either natively
	/// or quoted (many LLM clients send "5" / "true"). Quoted numbers are parsed with the invariant
	/// culture. Enums accept a name (case-insensitive) or an underlying number. Anything else goes
	/// through <see cref="JsonSerializer"/>.
	/// </summary>
	private static object ConvertJson( JsonElement el, Type targetType )
	{
		var t = Nullable.GetUnderlyingType( targetType ) ?? targetType;

		if ( t == typeof( string ) ) return el.ValueKind == JsonValueKind.String ? el.GetString() : el.ToString();
		if ( t == typeof( JsonElement ) ) return el.Clone();

		var quoted = el.ValueKind == JsonValueKind.String ? el.GetString() : null;
		var inv = CultureInfo.InvariantCulture;

		if ( t == typeof( bool ) ) return quoted is null ? el.GetBoolean() : bool.Parse( quoted );
		if ( t == typeof( int ) ) return quoted is null ? el.GetInt32() : int.Parse( quoted, NumberStyles.Integer, inv );
		if ( t == typeof( long ) ) return quoted is null ? el.GetInt64() : long.Parse( quoted, NumberStyles.Integer, inv );
		if ( t == typeof( double ) ) return quoted is null ? el.GetDouble() : double.Parse( quoted, NumberStyles.Float, inv );
		if ( t == typeof( float ) ) return (float)(quoted is null ? el.GetDouble() : double.Parse( quoted, NumberStyles.Float, inv ));

		if ( t.IsEnum )
		{
			return quoted is not null
				? Enum.Parse( t, quoted, ignoreCase: true )
				: Enum.ToObject( t, el.GetInt64() );
		}

		// Last resort: deserialise via JsonSerializer
		return JsonSerializer.Deserialize( el.GetRawText(), t );
	}

	/// <summary>
	/// Awaits <paramref name="value"/> if it is a <see cref="Task"/>, <see cref="ValueTask"/> or
	/// <c>ValueTask&lt;T&gt;</c>, and returns its result (null for non-generic tasks). Other values
	/// are returned unchanged.
	/// </summary>
	private static async Task<object> UnwrapAsync( object value )
	{
		Task task;
		switch ( value )
		{
			case Task t:
				task = t;
				break;
			case ValueTask valueTask:
				await valueTask.ConfigureAwait( false );
				return null;
			case not null when IsGenericOf( value.GetType(), typeof( ValueTask<> ) ):
				// ValueTask<T> has no non-generic base; convert so it can be awaited uniformly.
				task = (Task)value.GetType().GetMethod( nameof( ValueTask<object>.AsTask ), Type.EmptyTypes ).Invoke( value, null );
				break;
			default:
				return value;
		}

		await task.ConfigureAwait( false );
		return GetTaskResult( task );
	}

	/// <summary>
	/// Reads <c>Result</c> from a completed <c>Task&lt;T&gt;</c>. Returns null for non-generic tasks,
	/// including those backed at runtime by the BCL-internal <c>Task&lt;VoidTaskResult&gt;</c>
	/// (e.g. <c>async Task</c> methods). Reflecting <c>Result</c> on those yields a meaningless
	/// placeholder that would otherwise be serialised as "{}".
	/// </summary>
	private static object GetTaskResult( Task task )
	{
		for ( var type = task.GetType(); type is not null; type = type.BaseType )
		{
			if ( !IsGenericOf( type, typeof( Task<> ) ) ) continue;

			// VoidTaskResult is internal to the BCL, so it can only be identified by name.
			if ( type.GetGenericArguments()[0].FullName == "System.Threading.Tasks.VoidTaskResult" ) return null;

			return type.GetProperty( nameof( Task<object>.Result ) ).GetValue( task );
		}

		return null;
	}

	/// <summary>
	/// Renders a tool's return value as text. Null becomes "(no output)", strings pass through,
	/// and anything else is serialised as indented JSON.
	/// </summary>
	private static string FormatResult( object value )
	{
		if ( value is null ) return "(no output)";
		if ( value is string s ) return s;
		try
		{
			return JsonSerializer.Serialize( value, _resultJsonOptions );
		}
		catch
		{
			// Best-effort: cycles, unsupported member types or throwing getters fall back to ToString().
			return value.ToString() ?? "(no output)";
		}
	}
}