Code/StateMachine.cs
using System;
using System.Collections.Generic;
using System.Linq;
using Sandbox.Diagnostics;
namespace Sandbox.States;
[Title( "State Machine" ), Icon( "smart_toy" ), Category( "State Machines" )]
public sealed class StateMachineComponent : Component
{
	/// <summary>
	/// Maximum number of state transitions that may be taken in a single fixed update.
	/// If another transition is still pending once this many have been taken, an error is
	/// logged and no further transitions are taken during that update.
	/// </summary>
	public const int MaxInstantTransitions = 16;

	// States and transitions keyed by id. Both draw their ids from the shared _nextId counter.
	private readonly Dictionary<int, State> _states = new();
	private readonly Dictionary<int, Transition> _transitions = new();
	private int _nextId = 0;

	/// <summary>
	/// All states in this machine.
	/// </summary>
	public IEnumerable<State> States => _states.Values;

	/// <summary>
	/// All transitions between states in this machine.
	/// </summary>
	public IEnumerable<Transition> Transitions => _transitions.Values;

	/// <summary>
	/// Which state becomes active when the machine starts?
	/// <see cref="AddState"/> assigns the first added state here if none is set yet.
	/// </summary>
	public State? InitialState { get; set; }

	/// <summary>
	/// Which state is currently active? Resolved from <see cref="CurrentStateId"/>;
	/// <c>null</c> when no id is set or the id no longer names a state in this machine.
	/// </summary>
	public State? CurrentState
	{
		get => CurrentStateId is { } id ? _states.GetValueOrDefault( id ) : null;
		private set => CurrentStateId = value?.Id;
	}

	/// <summary>
	/// How long have we been in the current state?
	/// Advanced only on the non-proxy side; reset to zero on every transition.
	/// </summary>
	public float StateTime { get; private set; }

	/// <summary>
	/// Backing id for <see cref="CurrentState"/>. This id, rather than the state object,
	/// is what is marked as a <c>[Property]</c>.
	/// </summary>
	[Property] private int? CurrentStateId { get; set; }

	// True until the first fixed update, which fires the enter callbacks for whatever
	// state is current at that moment.
	private bool _firstUpdate = true;

	protected override void OnStart()
	{
		// Only the non-proxy side picks the starting state.
		if ( !Network.IsProxy && InitialState is { } initial )
		{
			CurrentState = initial;
		}
	}

	/// <summary>
	/// Send a message to trigger a transition with a matching <see cref="Transition.Message"/>.
	/// Logs a warning and does nothing when called on a proxy.
	/// </summary>
	/// <param name="message">Message name.</param>
	public void SendMessage( string message )
	{
		if ( IsProxy )
		{
			Log.Warning( $"Can't call {nameof(SendMessage)} on a StateMachine owned by another connection." );
			return;
		}

		if ( CurrentState?.GetNextTransition( message ) is { } transition )
		{
			DoTransition( transition );
		}
	}

	/// <summary>
	/// Invokes <paramref name="action"/> if it is non-null, logging any exception instead of
	/// propagating it, so a faulty user callback cannot abort the machine's update.
	/// </summary>
	private static void InvokeSafe( Action? action )
	{
		try
		{
			action?.Invoke();
		}
		catch ( Exception ex )
		{
			Log.Error( ex );
		}
	}

	protected override void OnFixedUpdate()
	{
		if ( _firstUpdate )
		{
			_firstUpdate = false;
			CurrentState?.Entered();
			InvokeSafe( CurrentState?.OnEnterState );
		}

		if ( !Network.IsProxy )
		{
			var prevTime = StateTime;
			StateTime += Time.Delta;

			// Keep following transitions until none is pending. StateTime is reset to zero by
			// each transition, and prevTime is zeroed so the next state's check starts fresh.
			var transitions = 0;
			while ( CurrentState?.GetNextTransition( prevTime, StateTime ) is { } transition )
			{
				if ( transitions++ >= MaxInstantTransitions )
				{
					// Previously this limit was hit silently, hiding transition loops.
					Log.Error( $"{nameof(StateMachineComponent)} exceeded {MaxInstantTransitions} transitions in a single update (current state: {CurrentState}); possible transition loop." );
					break;
				}

				DoTransition( transition );
				prevTime = 0f;
			}
		}

		InvokeSafe( CurrentState?.OnUpdateState );
	}

	/// <summary>
	/// RPC sent by the owner to apply a transition, looked up by id, on every peer.
	/// </summary>
	/// <exception cref="InvalidOperationException">No transition with <paramref name="transitionId"/> exists.</exception>
	[Rpc.Broadcast( NetFlags.Reliable | NetFlags.OwnerOnly )]
	private void BroadcastTransition( int transitionId )
	{
		if ( !_transitions.TryGetValue( transitionId, out var transition ) )
		{
			throw new InvalidOperationException( $"Unknown transition id: {transitionId}" );
		}

		DoTransitionInternal( transition );
	}

	/// <summary>
	/// Takes <paramref name="transition"/>: via <see cref="BroadcastTransition"/> when networking
	/// is active and we are not a proxy, otherwise directly.
	/// NOTE(review): the caller loop in OnFixedUpdate assumes the broadcast also applies the
	/// transition locally before returning — confirm against the Rpc.Broadcast semantics.
	/// </summary>
	private void DoTransition( Transition transition )
	{
		if ( !IsProxy && Network.Active )
		{
			BroadcastTransition( transition.Id );
		}
		else
		{
			DoTransitionInternal( transition );
		}
	}

	/// <summary>
	/// Applies a transition locally. Fires the leave, transition and enter callbacks in that
	/// order, makes the target current, and resets <see cref="StateTime"/>.
	/// Warns, but still proceeds, if the machine is not in the transition's source state.
	/// </summary>
	private void DoTransitionInternal( Transition transition )
	{
		var current = CurrentState;

		if ( current != null && current != transition.Source )
		{
			Log.Warning( $"Expected to transition from {transition.Source}, but we're in state {current}!" );
		}

		InvokeSafe( current?.OnLeaveState );
		InvokeSafe( transition.OnTransition );

		transition.LastTransitioned = 0f;

		CurrentState = current = transition.Target;
		StateTime = 0f;

		current.Entered();
		InvokeSafe( current.OnEnterState );
	}

	/// <summary>
	/// Creates a new state in this machine. The first state added becomes
	/// <see cref="InitialState"/> if none is set.
	/// </summary>
	public State AddState()
	{
		var state = new State( this, _nextId++ );

		_states.Add( state.Id, state );
		state.IsValid = true;

		InitialState ??= state;

		return state;
	}

	/// <summary>
	/// Removes <paramref name="state"/> along with every transition into or out of it,
	/// clearing <see cref="InitialState"/> / <see cref="CurrentState"/> if they referred to it.
	/// </summary>
	internal void RemoveState( State state )
	{
		Assert.AreEqual( this, state.StateMachine );
		Assert.AreEqual( state, _states[state.Id] );

		if ( InitialState == state )
		{
			InitialState = null;
		}

		if ( CurrentState == state )
		{
			CurrentState = null;
		}

		// Materialize first: Transition.Remove mutates _transitions.
		var transitions = Transitions
			.Where( x => x.Source == state || x.Target == state )
			.ToArray();

		foreach ( var transition in transitions )
		{
			transition.Remove();
		}

		_states.Remove( state.Id );
		state.IsValid = false;
	}

	/// <summary>
	/// Creates a transition from <paramref name="source"/> to <paramref name="target"/>;
	/// both must belong to this machine.
	/// </summary>
	/// <exception cref="ArgumentNullException">Either state is null.</exception>
	internal Transition AddTransition( State source, State target )
	{
		ArgumentNullException.ThrowIfNull( source, nameof( source ) );
		ArgumentNullException.ThrowIfNull( target, nameof( target ) );

		Assert.AreEqual( this, source.StateMachine );
		Assert.AreEqual( this, target.StateMachine );

		var transition = new Transition( _nextId++, source, target );

		_transitions.Add( transition.Id, transition );
		transition.IsValid = true;

		source.InvalidateTransitions();

		return transition;
	}

	/// <summary>
	/// Removes <paramref name="transition"/> and invalidates its source state's transitions.
	/// </summary>
	internal void RemoveTransition( Transition transition )
	{
		Assert.AreEqual( this, transition.StateMachine );
		Assert.AreEqual( transition, _transitions[transition.Id] );

		_transitions.Remove( transition.Id );
		transition.IsValid = false;

		transition.Source.InvalidateTransitions();
	}

	/// <summary>
	/// Removes every state and transition, marks them invalid, and resets
	/// <see cref="InitialState"/>, <see cref="StateTime"/> and the id counter.
	/// <see cref="CurrentStateId"/> is deliberately left untouched.
	/// </summary>
	internal void Clear()
	{
		// Snapshot before clearing so dropped objects can be invalidated afterwards, mirroring
		// RemoveState / RemoveTransition (previously they kept reporting IsValid == true).
		var oldStates = _states.Values.ToArray();
		var oldTransitions = _transitions.Values.ToArray();

		_states.Clear();
		_transitions.Clear();

		foreach ( var transition in oldTransitions )
		{
			transition.IsValid = false;
		}

		foreach ( var state in oldStates )
		{
			state.IsValid = false;
		}

		InitialState = null;
		StateTime = 0f;

		_nextId = 0;
	}

	/// <summary>
	/// Whole-machine snapshot exposed as a property; setting it replaces all states and transitions.
	/// </summary>
	[Property]
	private Model Serialized
	{
		get => SerializeInternal();
		set => DeserializeInternal( value, true );
	}

	internal record Model(
		IReadOnlyList<State.Model> States,
		IReadOnlyList<Transition.Model> Transitions,
		int? InitialStateId );

	/// <summary>
	/// Snapshot of every state and transition, each list ordered by id.
	/// </summary>
	internal Model SerializeInternal()
	{
		return new Model(
			States.Select( x => x.Serialize() ).OrderBy( x => x.Id ).ToArray(),
			Transitions.Select( x => x.Serialize() ).OrderBy( x => x.Id ).ToArray(),
			InitialState?.Id );
	}

	/// <summary>
	/// Rebuilds states and transitions from <paramref name="model"/>.
	/// </summary>
	/// <param name="model">Snapshot to load. A null model loads nothing (after the optional clear).</param>
	/// <param name="replace">
	/// When true, clears the machine first and restores <see cref="InitialState"/> from the model.
	/// When false, model ids are shifted by the current id counter so they don't collide with existing ones.
	/// </param>
	internal void DeserializeInternal( Model model, bool replace )
	{
		if ( replace )
		{
			Clear();
		}

		// Json.Deserialize yields null for a "null" document; previously this threw.
		if ( model is null )
		{
			return;
		}

		var idOffset = _nextId;

		foreach ( var stateModel in model.States )
		{
			var state = new State( this, stateModel.Id + idOffset );

			_states.Add( state.Id, state );
			_nextId = Math.Max( _nextId, state.Id + 1 );

			state.IsValid = true;
			state.Deserialize( stateModel );
		}

		foreach ( var transitionModel in model.Transitions )
		{
			var transition = new Transition( transitionModel.Id + idOffset,
				_states[transitionModel.SourceId + idOffset],
				_states[transitionModel.TargetId + idOffset] );

			_transitions.Add( transition.Id, transition );
			_nextId = Math.Max( _nextId, transition.Id + 1 );

			transition.IsValid = true;
			transition.Deserialize( transitionModel );

			// Same bookkeeping as AddTransition.
			transition.Source.InvalidateTransitions();
		}

		if ( replace )
		{
			InitialState = model.InitialStateId is { } id ? _states[id + idOffset] : null;
		}
	}

	/// <summary>
	/// Serializes a subset of this machine to JSON. Only valid states belonging to this machine
	/// are kept, and only valid transitions whose source and target are both in that set.
	/// The result carries no initial state.
	/// </summary>
	public string Serialize( IEnumerable<State> states, IEnumerable<Transition> transitions )
	{
		var stateSet = states
			.Where( x => x.IsValid && x.StateMachine == this )
			.ToHashSet();

		var transitionSet = transitions
			.Where( x => x.IsValid && stateSet.Contains( x.Source ) && stateSet.Contains( x.Target ) )
			.ToHashSet();

		var model = new Model(
			stateSet.Select( x => x.Serialize() ).OrderBy( x => x.Id ).ToArray(),
			transitionSet.Select( x => x.Serialize() ).OrderBy( x => x.Id ).ToArray(),
			null );

		return Json.Serialize( model );
	}

	/// <summary>
	/// Serializes the whole machine to JSON.
	/// </summary>
	public string SerializeAll()
	{
		return Json.Serialize( Serialized );
	}

	/// <summary>
	/// Replaces the whole machine with the contents of <paramref name="json"/>.
	/// </summary>
	public void DeserializeAll( string json )
	{
		Serialized = Json.Deserialize<Model>( json );
	}

	/// <summary>
	/// Adds the states and transitions from <paramref name="json"/> alongside the existing ones,
	/// with fresh ids, and returns the newly created objects ordered by id.
	/// </summary>
	public (IReadOnlyList<State> States, IReadOnlyList<Transition> Transitions) DeserializeInsert( string json )
	{
		var model = Json.Deserialize<Model>( json );
		var baseId = _nextId;

		DeserializeInternal( model, false );

		return (
			States.Where( x => x.Id >= baseId ).OrderBy( x => x.Id ).ToArray(),
			Transitions.Where( x => x.Id >= baseId ).OrderBy( x => x.Id ).ToArray() );
	}
}