Code/Wasm/Optimize/FunctionTypeOptimizations.cs
using System;
using System.Collections.Generic;
using System.Linq;

namespace WasmBox.Wasm.Optimize {
    /// <summary>
    /// Defines function type optimizations.
    /// </summary>
    public static class FunctionTypeOptimizations {
        /// <summary>
        /// Takes a sequence of function types as input and produces
        /// a list of equivalent, distinct function types and a map
        /// that maps type indices from the old type sequence to
        /// indices of equivalent types in the new function type list.
        /// </summary>
        /// <param name="types">The sequence of types to make distinct.</param>
        /// <param name="newTypes">The list of distinct function types.</param>
        /// <param name="typeMapping">
        /// A map from function types that occur in <c>Types</c> to their equivalents in <c>NewTypes</c>.
        /// </param>
        public static void MakeFunctionTypesDistinct(
            IEnumerable<FunctionType> types,
            out IReadOnlyList<FunctionType> newTypes,
            out IReadOnlyDictionary<uint, uint> typeMapping) {
            var newTypeList = new List<FunctionType>();
            var structuralOldToNewTypeMap = new Dictionary<FunctionType, uint>(
                ConstFunctionTypeComparer.Instance);
            var referentialOldToNewTypeMap = new Dictionary<uint, uint>();
            uint i = 0;
            foreach (var oldType in types) {
                uint newTypeIndex;
                if (structuralOldToNewTypeMap.TryGetValue(oldType, out newTypeIndex)) {
                    referentialOldToNewTypeMap[i] = newTypeIndex;
                }
                else {
                    newTypeIndex = (uint)newTypeList.Count;
                    structuralOldToNewTypeMap[oldType] = newTypeIndex;
                    referentialOldToNewTypeMap[i] = newTypeIndex;
                    newTypeList.Add(oldType);
                }
                i++;
            }
            newTypes = newTypeList;
            typeMapping = referentialOldToNewTypeMap;
        }

        /// <summary>
        /// Rewrites function type references in the given WebAssembly file
        /// by replacing keys from the rewrite map with their corresponding
        /// values.
        /// </summary>
        /// <param name="file">The WebAssembly file to rewrite.</param>
        /// <param name="rewriteMap">A mapping of original type indices to new type indices.</param>
        public static void RewriteFunctionTypeReferences(
            this WasmFile file,
            IReadOnlyDictionary<uint, uint> rewriteMap) {
            // Type references occur only in the import and function sections.
            var importSections = file.GetSections<ImportSection>();
            for (int i = 0; i < importSections.Count; i++) {
                var importSec = importSections[i];
                for (int j = 0; j < importSec.Imports.Count; j++) {
                    var importDecl = importSec.Imports[j] as ImportedFunction;
                    uint newIndex;
                    if (importDecl != null && rewriteMap.TryGetValue(importDecl.TypeIndex, out newIndex)) {
                        importDecl.TypeIndex = newIndex;
                    }
                }
            }

            var funcSections = file.GetSections<FunctionSection>();
            for (int i = 0; i < funcSections.Count; i++) {
                var funcSec = funcSections[i];
                for (int j = 0; j < funcSec.FunctionTypes.Count; j++) {
                    uint newIndex;
                    if (rewriteMap.TryGetValue(funcSec.FunctionTypes[j], out newIndex)) {
                        funcSec.FunctionTypes[j] = newIndex;
                    }
                }
            }
        }

        /// <summary>
        /// Compresses function types in the given WebAssembly file
        /// by including only unique function types.
        /// </summary>
        /// <param name="file">The WebAssembly file to modify.</param>
        public static void CompressFunctionTypes(
            this WasmFile file) {
            // Grab the first type section.
            var typeSection = file.GetFirstSectionOrNull<TypeSection>();
            if (typeSection == null) {
                return;
            }

            // Make all types from the first type section distinct.
            IReadOnlyList<FunctionType> newTypes;
            IReadOnlyDictionary<uint, uint> typeIndexMap;
            MakeFunctionTypesDistinct(typeSection.FunctionTypes, out newTypes, out typeIndexMap);

            // Rewrite the type section's function types.
            typeSection.FunctionTypes.Clear();
            typeSection.FunctionTypes.AddRange(newTypes);

            // Rewrite type indices.
            file.RewriteFunctionTypeReferences(typeIndexMap);
        }
    }

    /// <summary>
    /// An equality comparer for function types that assumes that the contents
    /// of the function types it is given remain constant over the course of
    /// its operation.
    /// </summary>
    public sealed class ConstFunctionTypeComparer : IEqualityComparer<FunctionType> {
        private ConstFunctionTypeComparer() { }

        /// <summary>
        /// Gets the one and only instance of the constant function type comparer.
        /// </summary>
        public static readonly ConstFunctionTypeComparer Instance = new ConstFunctionTypeComparer();

        /// <inheritdoc/>
        public bool Equals(FunctionType x, FunctionType y) {
            return Enumerable.SequenceEqual( x.ParameterTypes, y.ParameterTypes)
                && Enumerable.SequenceEqual( x.ReturnTypes, y.ReturnTypes);
        }

        private static int HashSequence(IEnumerable<WasmValueType> values, int seed) {
            // Based on Brendan's answer to this StackOverflow question:
            // https://stackoverflow.com/questions/16340/how-do-i-generate-a-hashcode-from-a-byte-array-in-c
            int result = seed;
            foreach (var item in values) {
                result = (result * 31) ^ (int)item;
            }
            return result;
        }

        /// <inheritdoc/>
        public int GetHashCode(FunctionType obj) {
            return HashSequence(obj.ReturnTypes, HashSequence(obj.ParameterTypes, 0));
        }
    }
}