// Services/WebSocketService.cs
using System;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.Configuration;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Options;
using System.Collections.Concurrent;
using System.Linq;
using System.Collections.Generic;
using Microsoft.Extensions.DependencyInjection;
using SandboxModelContextProtocol.Server.Services.Interfaces;
using SandboxModelContextProtocol.Server.Services.Models;
using System.IO;
namespace SandboxModelContextProtocol.Server.Services;
/// <summary>
/// Hosted service that runs its own <see cref="WebApplication"/> to accept WebSocket connections on the
/// configured URL and path, keeps a registry of open connections keyed by a generated ID, and forwards
/// received text messages to <see cref="IToolService"/>.
/// </summary>
/// <param name="logger">Logger for server and connection events.</param>
/// <param name="configuration">Configuration added to the embedded web application's configuration.</param>
/// <param name="options">WebSocket server options (listen URL and request path).</param>
/// <param name="serviceProvider">Provider used to resolve <see cref="IToolService"/> for each text message.</param>
public class WebSocketService( ILogger<WebSocketService> logger, IConfiguration configuration, IOptions<Models.WebSocketOptions> options, IServiceProvider serviceProvider ) : IHostedService, IWebSocketService
{
    private readonly ILogger<WebSocketService> _logger = logger;
    private readonly IConfiguration _configuration = configuration;
    private readonly Models.WebSocketOptions _options = options.Value;
    private readonly IServiceProvider _serviceProvider = serviceProvider;

    // Cancelled by StopAsync so every open connection loop exits on shutdown. The token handed to
    // StartAsync cannot serve this purpose: it only signals that startup itself was aborted.
    private readonly CancellationTokenSource _stoppingCts = new();

    private WebApplication? _app;

    // Open connections mapped to their generated connection ID.
    private readonly ConcurrentDictionary<WebSocketConnection, string> _connections = new();

    /// <summary>
    /// Start the WebSocket server
    /// </summary>
    /// <param name="cancellationToken">Signals that startup should be aborted; not used for connection lifetime.</param>
    /// <returns>A task that completes once the embedded web application has started</returns>
    public async Task StartAsync( CancellationToken cancellationToken )
    {
        var builder = WebApplication.CreateBuilder();
        // Configure the application with the same configuration
        builder.Configuration.AddConfiguration( _configuration );
        // Configure the URL from configuration
        builder.WebHost.UseUrls( _options.Url );
        _app = builder.Build();
        _app.UseWebSockets();
        _app.Map( _options.Path, async context =>
        {
            if ( !context.WebSockets.IsWebSocketRequest )
            {
                // Non-WebSocket requests to the WebSocket path are rejected.
                context.Response.StatusCode = 400;
                return;
            }

            using var webSocket = await context.WebSockets.AcceptWebSocketAsync();
            // End the connection loop when either the client goes away or this service is stopped.
            using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource( context.RequestAborted, _stoppingCts.Token );
            await HandleWebSocketConnection( webSocket, connectionCts.Token );
        } );
        await _app.StartAsync( cancellationToken );
        _logger.LogInformation( "WebSocket Server started on {WebSocketPath}", _options.Path );
    }

    /// <summary>
    /// Get all connected WebSocket connections
    /// </summary>
    /// <returns>The connections registered at the time of the call</returns>
    public IEnumerable<WebSocketConnection> GetWebSocketConnections()
    {
        return _connections.Keys;
    }

    /// <summary>
    /// Get a connection by its ID
    /// </summary>
    /// <param name="connectionId">The ID of the connection</param>
    /// <returns>The connection, or <c>null</c> when no registered connection has that ID</returns>
    public WebSocketConnection? GetWebSocketConnection( string connectionId )
    {
        return _connections.FirstOrDefault( kvp => kvp.Value == connectionId ).Key;
    }

    /// <summary>
    /// Send a message to all connected clients
    /// </summary>
    /// <remarks>
    /// Sends sequentially. An exception thrown by one connection's <c>SendAsync</c> propagates to the caller
    /// and the remaining connections are not sent to.
    /// </remarks>
    /// <param name="message">The message to send</param>
    /// <returns>A task that completes after every send has completed</returns>
    public async Task SendToAll( string message )
    {
        foreach ( var connection in _connections.Keys )
        {
            await connection.SendAsync( message );
        }
    }

    /// <summary>
    /// Stop the WebSocket server
    /// </summary>
    /// <param name="cancellationToken">Signals that graceful shutdown should no longer wait</param>
    /// <returns>A task that completes once the server has stopped and been disposed</returns>
    public async Task StopAsync( CancellationToken cancellationToken )
    {
        // Make every open connection loop exit first, so the server does not wait on long-lived sockets.
        _stoppingCts.Cancel();

        if ( _app == null )
        {
            return;
        }

        try
        {
            await _app.StopAsync( cancellationToken );
            _logger.LogInformation( "WebSocket Server stopped" );
        }
        finally
        {
            // WebApplication is IAsyncDisposable; release the host and its services.
            await _app.DisposeAsync();
            _app = null;
        }
    }

    /// <summary>
    /// Add a connection to the registry under a freshly generated GUID string ID.
    /// </summary>
    /// <param name="connection">The connection to register</param>
    private void RegisterWebSocketConnection( WebSocketConnection connection )
    {
        var connectionId = Guid.NewGuid().ToString();
        _connections[connection] = connectionId;
        _logger.LogInformation( "WebSocket connection registered with ID: {ConnectionId}", connectionId );
    }

    /// <summary>
    /// Remove a connection from the registry, logging its ID only if it was actually registered.
    /// </summary>
    /// <param name="connection">The connection to unregister</param>
    private void UnregisterWebSocketConnection( WebSocketConnection connection )
    {
        if ( _connections.TryRemove( connection, out var connectionId ) )
        {
            _logger.LogInformation( "WebSocket connection unregistered: {ConnectionId}", connectionId );
        }
    }

    /// <summary>
    /// Run the receive loop for one WebSocket.
    /// </summary>
    /// <remarks>
    /// Registers the connection and reads messages until the socket leaves the Open state, the client sends a
    /// Close frame, or <paramref name="cancellationToken"/> is cancelled. Complete text messages are decoded as
    /// UTF-8 and forwarded to <see cref="IToolService"/>. The connection is unregistered on exit.
    /// </remarks>
    /// <param name="webSocket">The accepted WebSocket</param>
    /// <param name="cancellationToken">Ends the receive loop when cancelled</param>
    /// <returns>A task that completes when the connection loop exits</returns>
    public async Task HandleWebSocketConnection( WebSocket webSocket, CancellationToken cancellationToken )
    {
        var connection = new WebSocketConnection( webSocket, _logger );
        RegisterWebSocketConnection( connection );
        var buffer = new byte[1024 * 4];
        try
        {
            _logger.LogInformation( "WebSocket connection established" );
            while ( webSocket.State == WebSocketState.Open && !cancellationToken.IsCancellationRequested )
            {
                using var ms = new MemoryStream();
                WebSocketReceiveResult result;
                // Accumulate frames until the end of the message. Cancellation surfaces as an
                // OperationCanceledException from ReceiveAsync, so a partially received message is
                // never dispatched as though it were complete.
                do
                {
                    result = await webSocket.ReceiveAsync( new ArraySegment<byte>( buffer ), cancellationToken );
                    ms.Write( buffer, 0, result.Count );
                } while ( !result.EndOfMessage );

                if ( result.MessageType == WebSocketMessageType.Text )
                {
                    var message = Encoding.UTF8.GetString( ms.ToArray() );
                    _logger.LogInformation( "Received from s&box: {Message}", message );
                    // Handle responses from s&box
                    DispatchResponse( message );
                }
                else if ( result.MessageType == WebSocketMessageType.Close )
                {
                    _logger.LogInformation( "WebSocket close message received from client" );
                    // Complete the close handshake initiated by the client.
                    await webSocket.CloseAsync( WebSocketCloseStatus.NormalClosure, "Connection closed by client", cancellationToken );
                    break;
                }
                // Binary messages are ignored.
            }
        }
        catch ( OperationCanceledException )
        {
            _logger.LogInformation( "WebSocket connection cancelled" );
        }
        catch ( WebSocketException ex )
        {
            _logger.LogError( ex, "WebSocket error occurred" );
        }
        catch ( Exception ex )
        {
            _logger.LogError( ex, "Unexpected error in WebSocket connection" );
        }
        finally
        {
            UnregisterWebSocketConnection( connection );
            _logger.LogInformation( "WebSocket connection closed and unregistered" );
        }
    }

    /// <summary>
    /// Forward one text message to <see cref="IToolService"/>.
    /// </summary>
    /// <remarks>
    /// Exceptions thrown while resolving the service or inside <c>HandleResponse</c> are logged rather than
    /// propagated, so a single bad message does not drop the whole connection.
    /// </remarks>
    /// <param name="message">The decoded text message</param>
    private void DispatchResponse( string message )
    {
        try
        {
            var commandService = _serviceProvider.GetRequiredService<IToolService>();
            commandService.HandleResponse( message );
        }
        catch ( Exception ex )
        {
            _logger.LogError( ex, "Failed to handle message from s&box" );
        }
    }
}