Editor/ShaderGraphPlus/Nodes/Utility/CustomFunction.cs
using System.Text;
using static ShaderGraphPlus.ShaderGraphPlusGlobals;
namespace ShaderGraphPlus.Nodes;
/// <summary>
/// Data-less marker attribute for properties. In this file it tags
/// <c>CustomFunctionNode.Source</c>; how the marker is consumed is not visible here.
/// </summary>
[AttributeUsage( AttributeTargets.Property )]
internal sealed class HLSLAssetPathAttribute : Attribute
{
	// Explicit internal constructor: the marker can only be created from within this assembly.
	internal HLSLAssetPathAttribute() { }
}
/// <summary>
/// Selects where a custom function node gets its HLSL function from.
/// </summary>
public enum CustomCodeNodeMode
{
	/// <summary>
	/// Generate the function and place it within the generated shader,
	/// in whichever stage the node is used in.
	/// </summary>
	Generate = 0,

	/// <summary>
	/// Retrieve the function from an external HLSL include file.
	/// </summary>
	File = 1,
}
/// <summary>
/// Shader stage(s) a piece of code is compatible with.
/// </summary>
public enum CompatableShaderStage
{
	/// <summary>
	/// Both the vertex and the pixel shader stage.
	/// </summary>
	Both = 0,

	/// <summary>
	/// Only the vertex shader stage.
	/// </summary>
	Vertex = 1,

	/// <summary>
	/// Only the pixel shader stage.
	/// </summary>
	Pixel = 2,
}
/// <summary>
/// Inject your own custom HLSL code into ShaderGraphPlus
/// </summary>
[Title( "Custom Function" ), Category( "Utility" ), Icon( "code" )]
public sealed class CustomFunctionNode : ShaderNodePlus, IErroringNode, IWarningNode, BaseNodePlus.IInitializeNode
{
/// <summary>
/// Header color of the node; taken from the theme's function-node color.
/// </summary>
[JsonIgnore, Hide, Browsable( false )]
public override Color NodeTitleColor => ShaderGraphPlusTheme.NodeHeaderColors.FunctionNode;

/// <summary>
/// Always true for this node.
/// </summary>
[JsonIgnore, Hide, Browsable( false )]
public override bool AutoSize => true;

/// <summary>
/// Display title: the node's display name, followed by "(Name)" when a function name is set.
/// </summary>
[Hide]
public override string Title
{
	get
	{
		var displayName = DisplayInfo.For( this ).Name;
		return string.IsNullOrEmpty( Name )
			? $"{displayName}"
			: $"{displayName} ({Name})";
	}
}

/// <summary>
/// Always false for this node.
/// </summary>
[Hide, JsonIgnore]
public override bool CanPreview => false;

/// <summary>
/// Name of the HLSL function. Used as the name of the generated function in
/// <see cref="CustomCodeNodeMode.Generate"/> mode and as the function to call in
/// <see cref="CustomCodeNodeMode.File"/> mode.
/// </summary>
public string Name { get; set; } = "CustomFunction0";

/// <summary>
/// Whether the function is generated from <see cref="Code"/> or taken from the include file <see cref="Source"/>.
/// </summary>
public CustomCodeNodeMode Mode { get; set; } = CustomCodeNodeMode.Generate;

/// <summary>
/// Body of the generated function (only shown when not in File mode).
/// </summary>
[TextArea]
[HideIf( nameof( Mode ), CustomCodeNodeMode.File )]
public string Code { get; set; } = "\nOutFloat3_0 = float3( 1.0f, 0.0f, 1.0f );";

/// <summary>
/// Include file registered with the compiler in File mode; validated as <c>shaders/{Source}</c>.
/// </summary>
[HideIf( nameof( Mode ), CustomCodeNodeMode.Generate )]
[HLSLAssetPath]
public string Source { get; set; }

/// <summary>
/// Hidden, serialized string. NOTE(review): not read anywhere in the visible part of this class.
/// </summary>
[Hide]
public string CodeComment { get; set; } = "";

/// <summary>
/// User-defined input ports; each valid entry becomes an input plug.
/// </summary>
[Title( "Inputs" )]
public List<CustomFunctionNodePort> FunctionInputs { get; set; } = new()
{
	new CustomFunctionNodePort { Name = "InFloat3_0", TypeName = "Vector3" },
};

// Plugs built from FunctionInputs by CreateInputs().
[Hide]
private List<IPlugIn> InternalInputs = new();

/// <summary>
/// Input plugs currently exposed by the node.
/// </summary>
[Hide]
public override IEnumerable<IPlugIn> Inputs => InternalInputs;

/// <summary>
/// User-defined output ports; each valid entry becomes an output plug.
/// </summary>
[Title( "Outputs" )]
public List<CustomFunctionNodePort> FunctionOutputs { get; set; } = new()
{
	new CustomFunctionNodePort { Name = "OutFloat3_0", TypeName = "Vector3" },
};

// Plugs built from FunctionOutputs by CreateOutputs().
[Hide]
private List<IPlugOut> InternalOutputs = new();

/// <summary>
/// Output plugs currently exposed by the node.
/// </summary>
[Hide]
public override IEnumerable<IPlugOut> Outputs => InternalOutputs;

// Port-list hashes from the last OnFrame() pass, used to detect edits to the port lists.
[Hide, JsonIgnore]
int _lastHashCodeInputs = 0;
[Hide, JsonIgnore]
int _lastHashCodeOutputs = 0;
/// <summary>
/// Detects edits to <see cref="FunctionInputs"/> / <see cref="FunctionOutputs"/> by comparing
/// the summed hash of each list with the value seen on the previous call, and rebuilds the
/// matching plugs when it differs. <c>Update()</c> is called once per frame at most, after
/// all rebuilds, mirroring <c>OnNodeCreated()</c>.
/// </summary>
public override void OnFrame()
{
	// A null list hashes to 0 instead of throwing; CreateInputs/CreateOutputs also accept null.
	static int SumPortHashes( List<CustomFunctionNodePort> ports )
	{
		var sum = 0;

		if ( ports is null )
		{
			return sum;
		}

		foreach ( var port in ports )
		{
			// Wrap-around is intended; this is only a change fingerprint.
			unchecked
			{
				sum += port.GetHashCode();
			}
		}

		return sum;
	}

	var inputsHash = SumPortHashes( FunctionInputs );
	var outputsHash = SumPortHashes( FunctionOutputs );
	var plugsChanged = false;

	if ( inputsHash != _lastHashCodeInputs )
	{
		_lastHashCodeInputs = inputsHash;
		CreateInputs();
		plugsChanged = true;
	}

	if ( outputsHash != _lastHashCodeOutputs )
	{
		_lastHashCodeOutputs = outputsHash;
		CreateOutputs();
		plugsChanged = true;
	}

	if ( plugsChanged )
	{
		Update();
	}
}
/// <summary>
/// Initialization entry point of the node; builds the initial plugs.
/// </summary>
public void InitializeNode() => OnNodeCreated();

/// <summary>
/// Builds the input plugs, then the output plugs, then calls <c>Update()</c>.
/// </summary>
private void OnNodeCreated()
{
	CreateInputs();
	CreateOutputs();

	Update();
}
/// <summary>
/// Rebuilds <c>InternalInputs</c> from <see cref="FunctionInputs"/>, ordered by port priority.
/// Invalid ports are skipped. An existing plug with the same port id is refreshed and kept
/// when its type is unchanged; otherwise a new plug is used.
/// </summary>
public void CreateInputs()
{
	if ( FunctionInputs == null )
	{
		InternalInputs = new();
		return;
	}

	var rebuilt = new List<IPlugIn>();

	foreach ( var port in FunctionInputs.OrderBy( x => x.Priority ) )
	{
		if ( !port.IsValid )
		{
			continue;
		}

		var portType = port.Type;
		var info = new PlugInfo()
		{
			Id = port.Id,
			Name = port.Name,
			Type = portType,
			DisplayInfo = new()
			{
				Name = port.Name,
				Fullname = portType.FullName
			}
		};

		var freshPlug = new BasePlugIn( this, info, info.Type );
		var existing = InternalInputs
			.OfType<BasePlugIn>()
			.FirstOrDefault( candidate => candidate.Info.Id == info.Id );

		if ( existing is null )
		{
			rebuilt.Add( freshPlug );
			continue;
		}

		// Refresh the existing plug's metadata before deciding whether it can be reused.
		existing.Info.Name = info.Name;
		existing.Info.Type = info.Type;
		existing.Info.DisplayInfo = info.DisplayInfo;

		rebuilt.Add( existing.Type != freshPlug.Type ? freshPlug : existing );
	}

	InternalInputs = rebuilt;
}
/// <summary>
/// Rebuilds <c>InternalOutputs</c> from <see cref="FunctionOutputs"/>, ordered by port priority.
/// Invalid ports are skipped. An existing plug with the same port id is refreshed and kept
/// when its type is unchanged; otherwise a new plug is used.
/// </summary>
public void CreateOutputs()
{
	if ( FunctionOutputs == null )
	{
		InternalOutputs = new();
		return;
	}

	var rebuilt = new List<IPlugOut>();

	foreach ( var port in FunctionOutputs.OrderBy( x => x.Priority ) )
	{
		if ( !port.IsValid )
		{
			continue;
		}

		var portType = port.Type;
		var info = new PlugInfo()
		{
			Id = port.Id,
			Name = port.Name,
			Type = portType,
			DisplayInfo = new()
			{
				Name = port.Name,
				Fullname = portType.FullName
			}
		};

		var freshPlug = new BasePlugOut( this, info, info.Type );
		var existing = InternalOutputs
			.OfType<BasePlugOut>()
			.FirstOrDefault( candidate => candidate.Info.Id == info.Id );

		if ( existing is null )
		{
			rebuilt.Add( freshPlug );
			continue;
		}

		// Refresh the existing plug's metadata before deciding whether it can be reused.
		existing.Info.Name = info.Name;
		existing.Info.Type = info.Type;
		existing.Info.DisplayInfo = info.DisplayInfo;

		rebuilt.Add( existing.Type != freshPlug.Type ? freshPlug : existing );
	}

	InternalOutputs = rebuilt;
}
/// <summary>
/// Compiles this node into a call to the custom function.
/// </summary>
/// <remarks>
/// In <see cref="CustomCodeNodeMode.File"/> mode the include <see cref="Source"/> is registered and
/// the function <see cref="Name"/> is called; the include file must declare a matching function,
/// as this method does not generate or rewrite it.
/// In <see cref="CustomCodeNodeMode.Generate"/> mode the function is built from <see cref="Code"/>
/// and registered with the compiler before the call is emitted.
/// </remarks>
/// <param name="compiler">Graph compiler used to resolve inputs and register code.</param>
/// <returns>
/// A <c>ResultType.VoidFunction</c> result wrapping the value returned by
/// <c>RegisterCustomCode</c>, or an error result when inputs/outputs could not be resolved
/// or the mode is unknown.
/// </returns>
public NodeResult GetResult( GraphCompiler compiler )
{
	if ( Mode is CustomCodeNodeMode.File )
	{
		var functionInputs = GetFunctionInputs( compiler, out var inputErrors );
		if ( inputErrors.Any() )
		{
			// Report the errors instead of returning an empty result, as Generate mode does.
			return NodeResult.Error( inputErrors.ToArray() );
		}

		compiler.RegisterInclude( Source );

		var outputResults = GetFunctionOutputs( out var outputErrors );
		if ( outputErrors.Any() )
		{
			return NodeResult.Error( outputErrors.ToArray() );
		}

		string funcCall = compiler.RegisterCustomCode( Identifier, Name, functionInputs, outputResults, Mode );
		return new NodeResult( ResultType.VoidFunction, funcCall, true );
	}
	else if ( Mode is CustomCodeNodeMode.Generate )
	{
		// The call arguments are built from Inputs, which CreateInputs orders by Priority.
		// Declare the parameters in that same order so arguments and parameters line up.
		// OrderBy is stable, so equal priorities keep their list order.
		var orderedInputs = FunctionInputs.OrderBy( x => x.Priority ).ToList();

		var sb = new StringBuilder();
		sb.AppendLine( $"void {Name}({ConstructArguments( orderedInputs, false )}{(orderedInputs.Any() ? "," : "")}{ConstructArguments( FunctionOutputs, true )})" );
		sb.AppendLine( "{" );
		sb.AppendLine( GraphCompiler.IndentString( Code, 1 ) );
		sb.AppendLine( "}" );

		var funcName = compiler.RegisterHLSLFunction( sb.ToString(), Name, true );

		var funcInputs = GetFunctionInputs( compiler, out var inputErrors );
		if ( inputErrors.Any() )
		{
			return NodeResult.Error( inputErrors.ToArray() );
		}

		var funcOutputs = GetFunctionOutputs( out var outputErrors );
		if ( outputErrors.Any() )
		{
			return NodeResult.Error( outputErrors.ToArray() );
		}

		var funcCall = compiler.RegisterCustomCode( Identifier, funcName, funcInputs, funcOutputs, Mode );
		return new NodeResult( ResultType.VoidFunction, funcCall, true );
	}

	return NodeResult.Error( $"Unknown Mode '{Mode}'" );
}
/// <summary>
/// Builds the comma-separated argument list for the function call from the node's input plugs.
/// Unconnected plugs of a known value type get a zero-valued default literal; unconnected
/// plugs of any other type are reported as missing.
/// </summary>
/// <param name="compiler">Graph compiler used to evaluate connected inputs.</param>
/// <param name="errors">Receives one message per required input that is not connected.</param>
/// <returns>The argument list, or null when <paramref name="errors"/> is not empty.</returns>
private string GetFunctionInputs( GraphCompiler compiler, out List<string> errors )
{
	// HLSL literal used for an unconnected plug of the given type; null when there is none.
	static string DefaultLiteralFor( Type type )
	{
		if ( type == typeof( bool ) ) return "false";
		if ( type == typeof( int ) ) return "0";
		if ( type == typeof( float ) ) return "0.0f";
		if ( type == typeof( Vector2 ) ) return "float2( 0.0f, 0.0f )";
		if ( type == typeof( Vector3 ) ) return "float3( 0.0f, 0.0f, 0.0f )";
		if ( type == typeof( Vector4 ) || type == typeof( Color ) ) return "float4( 0.0f, 0.0f, 0.0f, 0.0f )";
		if ( type == typeof( Float2x2 ) ) return "float2x2( 0.0f, 0.0f, 0.0f, 0.0f )";
		if ( type == typeof( Float3x3 ) ) return "float3x3( 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f )";
		if ( type == typeof( Float4x4 ) ) return "float4x4( 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f )";
		return null;
	}

	errors = new List<string>();
	var arguments = new List<string>();

	foreach ( IPlugIn plug in Inputs )
	{
		string argument;

		if ( plug.ConnectedOutput is null )
		{
			argument = DefaultLiteralFor( plug.Type );

			if ( argument is null )
			{
				errors.Add( $"Required Input \"{plug.DisplayInfo.Name}\" of type \"{ShaderGraphPlusTheme.NodeHandleConfigs[plug.Type].Name}\" is missing" );
				continue;
			}
		}
		else
		{
			var source = new NodeInput
			{
				Identifier = plug.ConnectedOutput.Node.Identifier,
				Output = plug.ConnectedOutput.Identifier
			};
			var sourceResult = compiler.Result( source );
			argument = sourceResult.ToString();

			var isTexture = sourceResult.ResultType == ResultType.Texture2D || sourceResult.ResultType == ResultType.TextureCube;
			if ( isTexture && compiler.TryGetPreviewImage( argument, out var imagePath ) )
			{
				// Compile the preview image and bind it as an attribute named after the result.
				var texturePath = compiler.CompileTexture( new TextureInput
				{
					DefaultTexture = imagePath,
					Type = sourceResult.ResultType == ResultType.Texture2D ? TextureType.Tex2D : TextureType.TexCube,
					ImageFormat = TextureFormat.DXT5,
					SrgbRead = true,
					DefaultColor = Color.White,
				} );
				var attributeName = argument.TrimStart( "g_t" ).ToString();
				compiler.SetAttribute( attributeName, Texture.Load( texturePath ) );
			}
		}

		arguments.Add( argument );
	}

	return errors.Any() ? null : string.Join( ", ", arguments );
}
/// <summary>
/// Formats a port list as the parameter section of an HLSL function signature.
/// Each parameter is written as "{keyword} {type} {name}", where the keyword is "out"
/// for outputs and empty for inputs; output parameters are additionally padded with a
/// space on both sides. Parameters are comma separated, with no comma after the last one.
/// </summary>
/// <param name="ports">Ports to format, in the order they should be declared.</param>
/// <param name="isOutputs">True to emit <c>out</c> parameters.</param>
/// <returns>The formatted parameter list; empty when <paramref name="ports"/> is empty.</returns>
internal string ConstructArguments( List<CustomFunctionNodePort> ports, bool isOutputs )
{
	var padding = isOutputs ? " " : "";
	var modifier = isOutputs ? "out" : "";
	var lastIndex = ports.Count - 1;

	var parameters = ports.Select( ( port, position ) =>
	{
		var separator = position == lastIndex ? "" : ",";
		return $"{padding}{modifier} {port.HLSLDataType} {port.Name}{separator}{padding}";
	} );

	return string.Concat( parameters );
}
/// <summary>
/// Maps each output port name to its HLSL data type. Ports with a blank name are skipped,
/// and for repeated names only the first port is kept.
/// </summary>
/// <param name="errors">Always set to an empty list by this method.</param>
/// <returns>Output name to HLSL type, in port list order.</returns>
private Dictionary<string, string> GetFunctionOutputs( out List<string> errors )
{
	errors = new List<string>();
	var declared = new Dictionary<string, string>();

	foreach ( var port in FunctionOutputs )
	{
		var hasName = !string.IsNullOrWhiteSpace( port.Name );

		if ( hasName && !declared.ContainsKey( port.Name ) )
		{
			declared.Add( port.Name, port.HLSLDataType );
		}
	}

	return declared;
}
/// <summary>
/// Validates the node configuration and returns one message per distinct problem.
/// Messages appear in this order: missing outputs, input port problems, output port
/// problems, missing function name, and (File mode only) include file problems.
/// </summary>
/// <returns>The list of error messages; empty when the configuration is valid.</returns>
public List<string> GetErrors()
{
	ClearError();

	var errors = new List<string>();

	// Treat a missing list as empty so validation reports a problem instead of throwing.
	var inputs = FunctionInputs ?? new List<CustomFunctionNodePort>();
	var outputs = FunctionOutputs ?? new List<CustomFunctionNodePort>();

	if ( !outputs.Any() )
	{
		errors.Add( "must have atleast 1 output." );
	}

	CollectPortErrors( errors, inputs, outputs, "Input", "an output" );
	CollectPortErrors( errors, outputs, inputs, "Output", "an input" );

	if ( string.IsNullOrWhiteSpace( Name ) )
	{
		errors.Add( "Must have a function name." );
	}

	if ( Mode is CustomCodeNodeMode.File )
	{
		if ( string.IsNullOrWhiteSpace( Source ) )
		{
			errors.Add( "Source path is empty." );
		}
		else if ( !Editor.FileSystem.Content.FileExists( $"shaders/{Source}" ) )
		{
			errors.Add( $"Include file `shaders/{Source}` does not exist." );
		}
	}

	return errors;
}

/// <summary>
/// Appends the validation messages for one side (inputs or outputs) of the node to <paramref name="errors"/>.
/// Every message is added at most once. Names are compared ordinally, ignoring case.
/// </summary>
/// <param name="errors">List the messages are appended to.</param>
/// <param name="ports">Ports being validated.</param>
/// <param name="oppositePorts">Ports of the other side, used for the name-conflict check.</param>
/// <param name="label">"Input" or "Output"; used verbatim and lower-cased in the messages.</param>
/// <param name="oppositeLabel">Other side including its article, e.g. "an output".</param>
private static void CollectPortErrors( List<string> errors, List<CustomFunctionNodePort> ports, List<CustomFunctionNodePort> oppositePorts, string label, string oppositeLabel )
{
	var lowerLabel = label.ToLowerInvariant();

	void AddOnce( string message )
	{
		if ( !errors.Contains( message ) )
		{
			errors.Add( message );
		}
	}

	foreach ( var port in ports )
	{
		var hasName = !string.IsNullOrWhiteSpace( port.Name );

		if ( string.IsNullOrEmpty( port.TypeName ) )
		{
			// An empty type name resolves to no type at all, so report it like a missing one.
			var subject = hasName ? $" '{port.Name}' " : " ";
			AddOnce( $"{label}{subject}must have a Type." );
		}
		else if ( !hasName )
		{
			AddOnce( $"{label} of type '{port.TypeName}' must have a name." );
		}

		if ( !hasName )
		{
			continue;
		}

		var isDuplicate = ports.Any( other => !ReferenceEquals( other, port ) && string.Equals( other.Name, port.Name, StringComparison.OrdinalIgnoreCase ) );
		if ( isDuplicate )
		{
			AddOnce( $"Duplicate {lowerLabel} '{port.Name}'." );
		}

		var conflictsWithOpposite = oppositePorts.Any( other => string.Equals( other.Name, port.Name, StringComparison.OrdinalIgnoreCase ) );
		if ( conflictsWithOpposite )
		{
			AddOnce( $"{label} '{port.Name}' already exists as {oppositeLabel}." );
		}
	}
}
/// <summary>
/// Returns the node's warnings; this node currently reports none.
/// </summary>
/// <returns>A new, empty list.</returns>
public List<string> GetWarnings() => new List<string>();
}
/// <summary>
/// A user-defined input or output of a custom function node: a name, a type name and a sort priority.
/// </summary>
public class CustomFunctionNodePort : IValid
{
	/// <summary>
	/// Identifier of this port. A fresh value is generated for every instance and the
	/// property has no setter.
	/// </summary>
	[Hide]
	public Guid Id { get; } = Guid.NewGuid();

	/// <summary>
	/// True when the port has a non-blank name and its type name resolves to a type.
	/// </summary>
	[Hide, JsonIgnore]
	public bool IsValid => !string.IsNullOrWhiteSpace( Name ) && Type != null;

	/// <summary>
	/// Port name; also used as the HLSL parameter name.
	/// </summary>
	[KeyProperty]
	public string Name { get; set; }

	/// <summary>
	/// The type <see cref="TypeName"/> resolves to, or null when <see cref="TypeName"/>
	/// is null/empty or cannot be resolved.
	/// </summary>
	[Hide, JsonIgnore]
	public Type Type
	{
		get
		{
			if ( string.IsNullOrEmpty( TypeName ) )
			{
				return null;
			}

			Type knownType = TypeName switch
			{
				"bool" => typeof( bool ),
				"int" => typeof( int ),
				"float" => typeof( float ),
				"Vector2" => typeof( Vector2 ),
				"Vector3" => typeof( Vector3 ),
				"Vector4" => typeof( Vector4 ),
				"Color" => typeof( Color ),
				"float2x2" => typeof( Float2x2 ),
				"float3x3" => typeof( Float3x3 ),
				"float4x4" => typeof( Float4x4 ),
				"Texture2D" or "TextureCube" => typeof( Texture ),
				"SamplerState" => typeof( Sampler ),
				_ => null,
			};

			// Types flagged in GraphCompiler.ValueTypes are resolved through the editor type library.
			if ( knownType != null && GraphCompiler.ValueTypes.TryGetValue( knownType, out var isEditorType ) && isEditorType )
			{
				return EditorTypeLibrary.GetType( knownType.FullName )?.TargetType ?? knownType;
			}

			// C# keyword names are looked up by their full CLR name.
			var lookupName = TypeName switch
			{
				"bool" => typeof( bool ).FullName,
				"int" => typeof( int ).FullName,
				"float" => typeof( float ).FullName,
				_ => TypeName,
			};

			// A failed lookup falls back to the locally known type (null for unknown names)
			// instead of throwing, so IsValid can report the port as invalid.
			return TypeLibrary.GetType( lookupName )?.TargetType ?? knownType;
		}
	}

	/// <summary>
	/// Name of the port's type as chosen in the editor; serialized under the JSON name "Type".
	/// </summary>
	[KeyProperty, Editor( ControlWidgetCustomEditors.PortTypeChoiceEditor ), JsonPropertyName( "Type" )]
	public string TypeName { get; set; }

	/// <summary>
	/// Sort key used when the node builds its plugs from the port list.
	/// </summary>
	public int Priority { get; set; }

	/// <summary>
	/// HLSL type keyword for <see cref="TypeName"/>, or an empty string when there is no mapping.
	/// </summary>
	[Hide, JsonIgnore]
	public string HLSLDataType => TypeName switch
	{
		"bool" => "bool",
		"int" => "int",
		"float" => "float",
		"Vector2" => "float2",
		"Vector3" => "float3",
		"Vector4" => "float4",
		"Color" => "float4",
		"float2x2" => "float2x2",
		"float3x3" => "float3x3",
		"float4x4" => "float4x4",
		"Texture2D" => "Texture2D",
		"TextureCube" => "TextureCube",
		"SamplerState" => "SamplerState",
		_ => "",
	};

	/// <summary>
	/// NOTE(review): not read or written anywhere in the visible source.
	/// </summary>
	[Hide, JsonIgnore]
	internal bool IsOutput { get; set; }

	/// <summary>
	/// Hash over <see cref="Id"/>, <see cref="Name"/>, <see cref="TypeName"/> and <see cref="Priority"/>.
	/// The owning node sums these hashes to detect edits to its port lists.
	/// The value changes when the port is edited, and Equals is not overridden.
	/// </summary>
	public override int GetHashCode()
	{
		return System.HashCode.Combine( Id, Name, TypeName, Priority );
	}
}