RpcX.cs
using Sandbox;
using Sandbox.Utility;
using System;
using System.Collections.Generic;
namespace ExtendedNetworking;
/// <summary>
/// Marks a method as an extended RPC. The <c>CodeGenerator</c> attribute below wraps
/// every marked method so that calls to it are routed through
/// <c>ExtendedNetworking.RpcXInvoker.OnRpc</c>.
/// </summary>
[CodeGenerator(CodeGeneratorFlags.WrapMethod | CodeGeneratorFlags.Instance | CodeGeneratorFlags.Static, "ExtendedNetworking.RpcXInvoker.OnRpc")]
[AttributeUsage(AttributeTargets.Method)]
public class RpcXAttribute : Attribute
{
    /// <summary>Which connections are permitted to send (call) this RPC.</summary>
    public RpcSender AllowedSenders { get; set; }

    /// <summary>Which connections the RPC is delivered to.</summary>
    public RpcTarget Targets { get; set; }

    /// <summary>Creates the attribute with the given delivery targets and sender permissions.</summary>
    /// <param name="target">Connections that receive the RPC; defaults to <c>RpcTarget.Broadcast</c>.</param>
    /// <param name="allowedSenders">Connections allowed to send it; defaults to <c>RpcSender.Anyone</c>.</param>
    public RpcXAttribute(RpcTarget target = RpcTarget.Broadcast, RpcSender allowedSenders = RpcSender.Anyone)
        => (Targets, AllowedSenders) = (target, allowedSenders);
}
/// <summary>
/// Public surface of the extended RPC system: shorthand attributes for the common
/// delivery targets, plus scoped recipient filters that narrow who receives RpcX calls.
/// </summary>
public static class RpcX
{
    /// <summary>
    /// The recipient filter currently in effect, or <c>null</c> when none is.
    /// Set by the <c>FilterInclude</c>/<c>FilterExclude</c> overloads and cleared when
    /// the handle they return is disposed. Read by <see cref="RpcXInvoker"/> when an
    /// RpcX method is called.
    /// </summary>
    internal static Connection.Filter? Filter;

    /// <summary>RpcX method delivered to the host (<c>RpcTarget.Host</c>).</summary>
    [CodeGenerator(CodeGeneratorFlags.WrapMethod | CodeGeneratorFlags.Instance | CodeGeneratorFlags.Static, "ExtendedNetworking.RpcXInvoker.OnRpc")]
    [AttributeUsage(AttributeTargets.Method)]
    public class HostAttribute : RpcXAttribute
    {
        /// <param name="allowedSender">Connections permitted to send the RPC.</param>
        public HostAttribute(RpcSender allowedSender = RpcSender.Anyone) : base(RpcTarget.Host, allowedSender)
        {
        }
    }

    /// <summary>RpcX method delivered to the owner of the target object (<c>RpcTarget.Owner</c>).</summary>
    [CodeGenerator(CodeGeneratorFlags.WrapMethod | CodeGeneratorFlags.Instance | CodeGeneratorFlags.Static, "ExtendedNetworking.RpcXInvoker.OnRpc")]
    [AttributeUsage(AttributeTargets.Method)]
    public class OwnerAttribute : RpcXAttribute
    {
        /// <param name="allowedSender">Connections permitted to send the RPC.</param>
        public OwnerAttribute(RpcSender allowedSender = RpcSender.Anyone) : base(RpcTarget.Owner, allowedSender)
        {
        }
    }

    /// <summary>RpcX method delivered to every connection (<c>RpcTarget.Broadcast</c>).</summary>
    [CodeGenerator(CodeGeneratorFlags.WrapMethod | CodeGeneratorFlags.Instance | CodeGeneratorFlags.Static, "ExtendedNetworking.RpcXInvoker.OnRpc")]
    [AttributeUsage(AttributeTargets.Method)]
    public class BroadcastAttribute : RpcXAttribute
    {
        /// <param name="allowedSender">Connections permitted to send the RPC.</param>
        public BroadcastAttribute(RpcSender allowedSender = RpcSender.Anyone) : base(RpcTarget.Broadcast, allowedSender)
        {
        }
    }

    /// <summary>RpcX method delivered to admin connections (<c>RpcTarget.Admin</c>).</summary>
    [CodeGenerator(CodeGeneratorFlags.WrapMethod | CodeGeneratorFlags.Instance | CodeGeneratorFlags.Static, "ExtendedNetworking.RpcXInvoker.OnRpc")]
    [AttributeUsage(AttributeTargets.Method)]
    public class AdminAttribute : RpcXAttribute
    {
        /// <param name="allowedSender">Connections permitted to send the RPC.</param>
        public AdminAttribute(RpcSender allowedSender = RpcSender.Anyone) : base(RpcTarget.Admin, allowedSender)
        {
        }
    }

    /// <summary>Installs an include filter built from <paramref name="connections"/>.</summary>
    /// <returns>A handle that removes the filter when disposed.</returns>
    /// <exception cref="InvalidOperationException">An RpcX filter is already active.</exception>
    public static IDisposable FilterInclude([Description("Only send the RPC to these connections.")] IEnumerable<Connection> connections)
        => PushFilter(new Connection.Filter(Connection.Filter.FilterType.Include, connections));

    /// <summary>Installs an include filter built from <paramref name="predicate"/>.</summary>
    /// <returns>A handle that removes the filter when disposed.</returns>
    /// <exception cref="InvalidOperationException">An RpcX filter is already active.</exception>
    public static IDisposable FilterInclude([Description("Only send the RPC to connections that meet the criteria of the predicate.")] Predicate<Connection> predicate)
        => PushFilter(new Connection.Filter(Connection.Filter.FilterType.Include, predicate));

    /// <summary>Installs an include filter matching only <paramref name="connection"/>.</summary>
    /// <returns>A handle that removes the filter when disposed.</returns>
    /// <exception cref="InvalidOperationException">An RpcX filter is already active.</exception>
    public static IDisposable FilterInclude([Description("Only send the RPC to this connection.")] Connection connection)
        => PushFilter(new Connection.Filter(Connection.Filter.FilterType.Include, (Connection c) => c == connection));

    /// <summary>Installs an exclude filter built from <paramref name="predicate"/>.</summary>
    /// <returns>A handle that removes the filter when disposed.</returns>
    /// <exception cref="InvalidOperationException">An RpcX filter is already active.</exception>
    public static IDisposable FilterExclude([Description("Exclude connections that don't meet the criteria of the predicate from receiving the RPC.")] Predicate<Connection> predicate)
        => PushFilter(new Connection.Filter(Connection.Filter.FilterType.Exclude, predicate));

    /// <summary>Installs an exclude filter built from <paramref name="connections"/>.</summary>
    /// <returns>A handle that removes the filter when disposed.</returns>
    /// <exception cref="InvalidOperationException">An RpcX filter is already active.</exception>
    public static IDisposable FilterExclude([Description("Exclude these connections from receiving the RPC.")] IEnumerable<Connection> connections)
        => PushFilter(new Connection.Filter(Connection.Filter.FilterType.Exclude, connections));

    /// <summary>Installs an exclude filter matching only <paramref name="connection"/>.</summary>
    /// <returns>A handle that removes the filter when disposed.</returns>
    /// <exception cref="InvalidOperationException">An RpcX filter is already active.</exception>
    public static IDisposable FilterExclude([Description("Exclude this connection from receiving the RPC.")] Connection connection)
        => PushFilter(new Connection.Filter(Connection.Filter.FilterType.Exclude, (Connection c) => c == connection));

    /// <summary>
    /// Makes <paramref name="filter"/> the active RpcX filter and returns a handle that
    /// clears it. Only one filter may be active at a time.
    /// </summary>
    /// <exception cref="InvalidOperationException">An RpcX filter is already active.</exception>
    private static IDisposable PushFilter(Connection.Filter filter)
    {
        if(Filter.HasValue)
        {
            throw new InvalidOperationException("An RPC filter is already active");
        }
        Filter = filter;

        // Run the release only once: disposing the same handle a second time must not
        // clear a filter that a later FilterInclude/FilterExclude call has installed.
        var released = false;
        return DisposeAction.Create(() =>
        {
            if(released)
                return;
            released = true;
            Filter = null;
        });
    }
}
/// <summary>
/// Runtime backing for <see cref="RpcXAttribute"/>.
/// <list type="bullet">
/// <item>The code generator routes every call to an RpcX method through <see cref="OnRpc"/>.</item>
/// <item>On the sending side, the call is checked against the local connection's permissions.</item>
/// <item>It is then broadcast to connections that are valid targets and pass any active <see cref="RpcX"/> filter.</item>
/// <item>On the receiving side, everything is re-validated before the real method body runs.</item>
/// </list>
/// </summary>
public static class RpcXInvoker
{
#if DEBUG
    /// <summary>Debug builds log invalid incoming RPCs without kicking the sender.</summary>
    public const bool KickInvalidCallers = false;
#else
    /// <summary>Release builds kick connections that send invalid RPCs.</summary>
    public const bool KickInvalidCallers = true;
#endif

    /// <summary>
    /// Set by <see cref="OnRpcCalled"/> just before it invokes the target method, so the
    /// generated wrapper runs the original body instead of sending the RPC again.
    /// The first wrapped call consumes it, so RpcX calls made from inside that body are
    /// sent over the network as usual.
    /// </summary>
    private static bool _invokingFromRemote;

    /// <summary>
    /// Wrapper target for every RpcX method; referenced by name from the
    /// <c>CodeGenerator</c> attributes.
    /// </summary>
    /// <param name="m">The intercepted method call.</param>
    /// <param name="args">The call's arguments.</param>
    public static void OnRpc(WrappedMethod m, params object[]? args)
    {
        if(_invokingFromRemote)
        {
            // Consume the flag before running the body: any nested RpcX call is a new
            // call and must go through the normal send path.
            _invokingFromRemote = false;
            m.Resume();
            return;
        }
        CallRpc(m, args);
    }

    /// <summary>
    /// Sending side of an RpcX call. Drops the call if the local connection is not an
    /// allowed sender. Otherwise it broadcasts the call to connections that are valid
    /// targets and pass the active <see cref="RpcX"/> filter.
    /// </summary>
    private static void CallRpc(WrappedMethod m, object[]? args)
    {
        ThreadSafe.AssertIsMainThread();

        var attribute = m.GetAttribute<RpcXAttribute>();
        var rpcTarget = attribute.Targets;
        var obj = m.Object;
        var owner = GetOwner(obj);

        // Calls the local connection isn't allowed to make are dropped without a message.
        if(!IsValidRpcCaller(attribute.AllowedSenders, Connection.Local, owner))
            return;

        // Snapshot the RpcX filter (a nullable struct) so the predicate uses the value
        // that was active when this call was made.
        var filter = RpcX.Filter;

        IDisposable filterScope;
        try
        {
            filterScope = Rpc.FilterInclude(connection =>
                IsValidRpcTarget(rpcTarget, connection, owner) &&
                (!filter.HasValue || filter.Value.IsRecipient(connection)));
        }
        catch(InvalidOperationException)
        {
            // NOTE(review): assumed to mean a plain Rpc filter is already active; RpcX
            // needs Rpc's filter slot for itself, so callers must filter via RpcX.
            Log.Error($"Use {nameof(RpcX)} for filtering instead of {nameof(Rpc)}");
            return;
        }

        using(filterScope)
        {
            OnRpcCalled(m.MethodIdentity, obj, args);
        }
    }

    /// <summary>
    /// Receiving side of an RpcX call, run on each connection the sender's filter allowed.
    /// Before invoking the method it re-checks, in order:
    /// <list type="bullet">
    /// <item>the method identity;</item>
    /// <item>the target object;</item>
    /// <item>the arguments;</item>
    /// <item>the RpcX attribute;</item>
    /// <item>the sender's permission;</item>
    /// <item>that this connection is a valid target.</item>
    /// </list>
    /// Any failure logs a warning and calls <see cref="KickCaller"/>.
    /// </summary>
    [Rpc.Broadcast]
    private static void OnRpcCalled(int methodIdentity, object? obj, object[]? args)
    {
        var member = Game.TypeLibrary.GetMemberByIdent(methodIdentity);
        if(member is not MethodDescription method)
        {
            Log.Warning($"{Rpc.Caller} calling method with identity {methodIdentity}, but method not found.");
            KickCaller();
            return;
        }

        // Static methods must arrive without a target, instance methods with one.
        if((obj is null) != method.IsStatic)
        {
            Log.Warning($"{Rpc.Caller} calling method {method.Name}, but had invalid target.");
            KickCaller();
            return;
        }

        if(!AreArgumentsValid(method, args))
        {
            Log.Warning($"{Rpc.Caller} calling method {method.Name}, but had invalid arguments.");
            KickCaller();
            return;
        }

        // Only methods explicitly marked as RpcX may be invoked remotely.
        if(method.GetCustomAttribute<RpcXAttribute>() is not { } attribute)
        {
            Log.Warning($"{Rpc.Caller} calling method {method.Name}, but no rpc attribute found.");
            KickCaller();
            return;
        }

        var owner = GetOwner(obj);
        if(!IsValidRpcCaller(attribute.AllowedSenders, Rpc.Caller, owner))
        {
            Log.Warning($"{Rpc.Caller} calling method {method.Name}, but had no permission.");
            KickCaller();
            return;
        }

        if(!IsValidRpcTarget(attribute.Targets, Connection.Local, owner))
        {
            Log.Warning($"{Rpc.Caller} calling method {method.Name}, but target is invalid.");
            KickCaller();
            return;
        }

        _invokingFromRemote = true;
        try
        {
            method.Invoke(obj, args);
        }
        finally
        {
            // Reset even if the wrapper was never reached or the body threw.
            _invokingFromRemote = false;
        }
    }

    /// <summary>
    /// Checks that <paramref name="args"/> has exactly one entry per parameter of
    /// <paramref name="method"/>. It also checks that each entry can be passed to its
    /// parameter. For a parameterless method, <c>null</c> and an empty array are both accepted.
    /// </summary>
    private static bool AreArgumentsValid(MethodDescription method, object[]? args)
    {
        var parameters = method.Parameters;
        if(parameters.Length == 0)
            return args is null || args.Length == 0;

        if(args is null || args.Length != parameters.Length)
            return false;

        for(int i = 0; i < parameters.Length; ++i)
        {
            if(!IsArgumentCompatible(parameters[i].ParameterType, args[i]))
                return false;
        }
        return true;
    }

    /// <summary>
    /// Whether <paramref name="arg"/> can be passed to a parameter of type
    /// <paramref name="parameterType"/>.
    /// <list type="bullet">
    /// <item><c>null</c> is accepted for reference types and <see cref="Nullable{T}"/>.</item>
    /// <item>Otherwise the argument's runtime type must be assignable to the parameter type.</item>
    /// <item>This accepts derived types, interfaces, <c>object</c>, and boxed <c>T</c> for <c>T?</c>.</item>
    /// </list>
    /// </summary>
    private static bool IsArgumentCompatible(Type parameterType, object? arg)
    {
        if(arg is null)
            return !parameterType.IsValueType || Nullable.GetUnderlyingType(parameterType) is not null;
        return parameterType.IsAssignableFrom(arg.GetType());
    }

    /// <summary>
    /// The connection that owns <paramref name="target"/>:
    /// <list type="bullet">
    /// <item><c>null</c> when the target is not a <see cref="Component"/> (e.g. static methods);</item>
    /// <item>the network owner, falling back to the host, when the component is networked;</item>
    /// <item>the local connection otherwise.</item>
    /// </list>
    /// </summary>
    private static Connection? GetOwner(object? target)
    {
        if(target is not Component component)
            return null;

        GameObject.NetworkAccessor? network = component.Network;
        if(network is not null && network.Active)
            return network.Owner ?? Connection.Host;

        return Connection.Local;
    }

    /// <summary>
    /// Whether <paramref name="target"/> should receive an RPC with the given target flags.
    /// A <c>null</c> <paramref name="owner"/> satisfies the Owner flag for every connection.
    /// </summary>
    private static bool IsValidRpcTarget(RpcTarget rpcTarget, Connection target, Connection? owner)
    {
        return ((rpcTarget & RpcTarget.Broadcast) != 0) ||
            ((rpcTarget & RpcTarget.Host) != 0 && target.IsHost) ||
            ((rpcTarget & RpcTarget.Owner) != 0 && (owner is null || target == owner)) ||
            ((rpcTarget & RpcTarget.Admin) != 0 && AdminSystem.IsAdmin(target));
    }

    /// <summary>
    /// Whether <paramref name="caller"/> may send an RPC with the given sender flags.
    /// A <c>null</c> <paramref name="owner"/> satisfies the Owner flag for every connection.
    /// </summary>
    private static bool IsValidRpcCaller(RpcSender allowedSender, Connection caller, Connection? owner)
    {
        return ((allowedSender & RpcSender.Anyone) != 0) ||
            ((allowedSender & RpcSender.Host) != 0 && caller.IsHost) ||
            ((allowedSender & RpcSender.Owner) != 0 && (owner is null || caller == owner)) ||
            ((allowedSender & RpcSender.Admin) != 0 && AdminSystem.IsAdmin(caller));
    }

    /// <summary>Kicks the sender of the current RPC when <see cref="KickInvalidCallers"/> is set.</summary>
    private static void KickCaller()
    {
#pragma warning disable CS0162 // Unreachable code detected
        if(KickInvalidCallers)
            Rpc.Caller.Kick("Invalid rpc.");
#pragma warning restore CS0162 // Unreachable code detected
    }
}