Skip to content

Commit

Permalink
Length-based switch dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
jcouv committed Dec 21, 2022
1 parent 8c32d3a commit 486e8ba
Show file tree
Hide file tree
Showing 30 changed files with 15,518 additions and 108 deletions.
1 change: 1 addition & 0 deletions src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,7 @@
<Field Name="Expression" Type="BoundExpression" Null="disallow" />
<Field Name="Cases" Type="ImmutableArray&lt;(ConstantValue value, LabelSymbol label)&gt;" />
<Field Name="DefaultLabel" Type="LabelSymbol" Null="disallow" />
<Field Name="LengthBasedStringSwitchDataOpt" Type="LengthBasedStringSwitchData?" />
</Node>

<Node Name="BoundIfStatement" Base="BoundStatement">
Expand Down
313 changes: 313 additions & 0 deletions src/Compilers/CSharp/Portable/BoundTree/LengthBasedStringSwitchData.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis.CodeGen;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp
{
// The general idea is to stratify cases by
// first bucketing on Length
// then bucketing on a character position selected by heuristic
// and finally switching to exact string (this is a simple string comparison when only one possibility remains).
//
// The benefit of this approach is that it much reduces the need for computing
// the input string's hashcode.
//
// We emit something like:
//
// // null case:
// if (key is null)
// goto labelNull; OR goto labelDefault;
//
// switch (key.Length)
// {
// // empty string doesn't need a char or string test
// case 0: goto labelEmpty;
//
// // strings of length 1 don't need any further validation once we've checked one char
// case 1:
// switch (key[posM])
// {
// case '1': goto label1;
// case '2': goto label2;
// ...
// default: goto labelDefault;
// }
// ...
// // when a given length is sufficient to narrow down to one case we skip the char test:
// case N: if (key == "caseN") { goto labelN; } else { goto labelDefault; }
// ...
// case M:
// switch (key[posM])
// {
// // when a single character check narrows down to one possibility:
// case '1': if (key == "caseM1") { goto labelM1; } else { goto labelDefault; }
//
// // when a single character check leaves a few possibilities remaining (worst case scenario):
// case '2':
// switch (key)
// {
// case "caseM1_A": goto labelM1_A;
// case "caseM1_B": goto labelM1_B;
// ...
// default: goto labelDefault;
// }
// ...
// default: goto labelDefault;
// }
// ...
// default: goto labelDefault;
// }

internal class LengthBasedStringSwitchData
{
internal readonly LengthJumpTable _lengthJumpTable;
internal readonly ImmutableArray<CharJumpTable> _charJumpTables;
internal readonly ImmutableArray<StringJumpTable> _stringJumpTables;

internal LengthBasedStringSwitchData(LengthJumpTable lengthJumpTable,
ImmutableArray<CharJumpTable> charJumpTables, ImmutableArray<StringJumpTable> stringJumpTables)
{
_lengthJumpTable = lengthJumpTable;
_charJumpTables = charJumpTables;
_stringJumpTables = stringJumpTables;
}

internal struct LengthJumpTable
{
public readonly LabelSymbol? nullCaseLabel;
public ImmutableArray<(ConstantValue value, LabelSymbol label)> lengthCaseLabels;

public LengthJumpTable(LabelSymbol? nullCaseLabel, ImmutableArray<(ConstantValue value, LabelSymbol label)> lengthCaseLabels)
{
Debug.Assert(lengthCaseLabels.All(c => c.value.IsIntegral));

this.nullCaseLabel = nullCaseLabel;
this.lengthCaseLabels = lengthCaseLabels;
}
}

internal struct CharJumpTable
{
public readonly LabelSymbol label;
public readonly int selectedCharPosition;
public readonly ImmutableArray<(ConstantValue value, LabelSymbol label)> charCaseLabels;

internal CharJumpTable(LabelSymbol label, int selectedCharPosition, ImmutableArray<(ConstantValue value, LabelSymbol label)> charCaseLabels)
{
Debug.Assert(charCaseLabels.All(c => c.value.IsChar));

this.label = label;
this.selectedCharPosition = selectedCharPosition;
this.charCaseLabels = charCaseLabels;
}
}

internal struct StringJumpTable
{
public readonly LabelSymbol label;
public readonly ImmutableArray<(ConstantValue value, LabelSymbol label)> stringCaseLabels;

internal StringJumpTable(LabelSymbol label, ImmutableArray<(ConstantValue value, LabelSymbol label)> stringCaseLabels)
{
Debug.Assert(stringCaseLabels.All(c => c.value.IsString) && stringCaseLabels.Length > 0);

this.label = label;
this.stringCaseLabels = stringCaseLabels;
}
}

// Based on benchmarks, the previous hashcode-based approach arguably performs better
// when buckets have 6 candidates or more.
internal bool ShouldGenerateLengthBasedSwitch(int labelsCount)
{
return SwitchStringJumpTableEmitter.ShouldGenerateHashTableSwitch(labelsCount) &&
_stringJumpTables.All(t => t.stringCaseLabels!.Length <= 5);
}

internal static LengthBasedStringSwitchData Create(ImmutableArray<(ConstantValue value, LabelSymbol label)> inputCases)
{
Debug.Assert(inputCases.All(c => c.value.IsString && c.label is not null));

LabelSymbol? nullCaseLabel = null;
foreach (var inputCase in inputCases)
{
if (inputCase.value.IsNull)
{
Debug.Assert(nullCaseLabel is null, "At most one null case per string dispatch");
nullCaseLabel = inputCase.label;
}
}

var lengthCaseLabels = ArrayBuilder<(ConstantValue value, LabelSymbol label)>.GetInstance();
var charJumpTables = ArrayBuilder<CharJumpTable>.GetInstance();
var stringJumpTables = ArrayBuilder<StringJumpTable>.GetInstance();
foreach (var group in inputCases.Where(c => !c.value.IsNull).GroupBy(c => c.value.StringValue!.Length))
{
int stringLength = group.Key;
var labelForLength = CreateAndRegisterCharJumpTables(stringLength, group.Select(e => (e.value, e.label)).ToImmutableArray(), charJumpTables, stringJumpTables);
lengthCaseLabels.Add((ConstantValue.Create(stringLength), labelForLength));
}

var lengthJumpTable = new LengthJumpTable(nullCaseLabel, lengthCaseLabels.ToImmutableAndFree());
return new LengthBasedStringSwitchData(lengthJumpTable, charJumpTables.ToImmutableAndFree(), stringJumpTables.ToImmutableAndFree());
}

private static LabelSymbol CreateAndRegisterCharJumpTables(int stringLength, ImmutableArray<(ConstantValue value, LabelSymbol label)> casesWithGivenLength,
ArrayBuilder<CharJumpTable> charJumpTables, ArrayBuilder<StringJumpTable> stringJumpTables)
{
Debug.Assert(stringLength >= 0);
Debug.Assert(casesWithGivenLength.All(c => c.value.StringValue!.Length == stringLength));
Debug.Assert(casesWithGivenLength.Length > 0);

if (stringLength == 0)
{
// Only the empty string has zero Length, no need for further testing
return casesWithGivenLength.Single().label;
}

if (casesWithGivenLength.Length == 1)
{
// We only have one case for the given string length, we don't need to do a char test
// Instead we'll jump straight to the final string test
return CreateAndRegisterStringJumpTable(casesWithGivenLength, stringJumpTables);
}

var bestCharacterPosition = selectBestCharacterIndex(stringLength, casesWithGivenLength);
var charCaseLabels = ArrayBuilder<(ConstantValue value, LabelSymbol label)>.GetInstance();
foreach (var group in casesWithGivenLength.GroupBy(c => c.value.StringValue![bestCharacterPosition]))
{
char character = group.Key;
var label = CreateAndRegisterStringJumpTable(group.Select(c => c).ToImmutableArray(), stringJumpTables);
charCaseLabels.Add((ConstantValue.Create(character), label));
}

var charJumpTable = new CharJumpTable(label: new GeneratedLabelSymbol("char-dispatch"), bestCharacterPosition, charCaseLabels.ToImmutableAndFree());
charJumpTables.Add(charJumpTable);
return charJumpTable.label;

static int selectBestCharacterIndex(int stringLength, ImmutableArray<(ConstantValue value, LabelSymbol label)> caseLabels)
{
// We pick the position that maximizes number of buckets with a single entry.
// We break ties by preferring lower max bucket size.
Debug.Assert(stringLength > 0);
Debug.Assert(caseLabels.Length > 0);
int bestIndex = -1;
int bestIndexSingleEntryCount = -1;
int bestIndexLargestBucket = int.MaxValue;
for (int currentPosition = 0; currentPosition < stringLength; currentPosition++)
{
(int singleEntryCount, int largestBucket) = positionScore(currentPosition, caseLabels);

if (singleEntryCount > bestIndexSingleEntryCount ||
(singleEntryCount == bestIndexSingleEntryCount && largestBucket < bestIndexLargestBucket))
{
bestIndexSingleEntryCount = singleEntryCount;
bestIndexLargestBucket = largestBucket;
bestIndex = currentPosition;
}
}

return bestIndex;
}

// Given a position and a set of string cases of matching lengths, inspect the buckets created by inspecting
// those strings at that position. Return the count how many buckets have a single entry and the size of the largest bucket.
static (int singleEntryCount, int largestBucket) positionScore(int position, ImmutableArray<(ConstantValue value, LabelSymbol label)> caseLabels)
{
var countPerChar = PooledDictionary<char, int>.GetInstance();
foreach (var caseLabel in caseLabels)
{
Debug.Assert(caseLabel.value.StringValue is not null);
var currentChar = caseLabel.value.StringValue[position];
if (countPerChar.TryGetValue(currentChar, out var currentCount))
{
countPerChar[currentChar]++;
}
else
{
countPerChar[currentChar] = 1;
}
}

var singleEntryCount = countPerChar.Values.Count(c => c == 1);
var largestBucket = countPerChar.Values.Max();
countPerChar.Free();
return (singleEntryCount, largestBucket);
}
}

private static LabelSymbol CreateAndRegisterStringJumpTable(ImmutableArray<(ConstantValue value, LabelSymbol label)> cases, ArrayBuilder<StringJumpTable> stringJumpTables)
{
Debug.Assert(cases.Length > 0);

if (cases.Length == 1 && cases[0].value.StringValue!.Length == 1)
{
// If we have a single case that consists of a 1-length string then we can skip the string check
return cases[0].label;
}

var stringJumpTable = new StringJumpTable(label: new GeneratedLabelSymbol("string-dispatch"), cases);
stringJumpTables.Add(stringJumpTable);
return stringJumpTable.label;
}

#if DEBUG
public string Dump()
{
var builder = new StringBuilder();
builder.AppendLine("Length dispatch:");
builder.AppendLine($"Buckets: {string.Join(", ", _charJumpTables.Select(t => t.charCaseLabels.Length))}");
builder.AppendLine($" case null: {readable(_lengthJumpTable.nullCaseLabel)}");
dump(_lengthJumpTable.lengthCaseLabels);
builder.AppendLine();

builder.AppendLine("Char dispatches:");
foreach (var charJumpTable in _charJumpTables)
{
builder.AppendLine($"Label {readable(charJumpTable.label)}:");
builder.AppendLine($" Selected char position: {charJumpTable.selectedCharPosition}:");
dump(charJumpTable.charCaseLabels!);
}
builder.AppendLine();

builder.AppendLine("String dispatches:");
foreach (var stringJumpTable in _stringJumpTables)
{
builder.AppendLine($"Label {readable(stringJumpTable.label)}:");
dump(stringJumpTable.stringCaseLabels!);
}
builder.AppendLine();

return builder.ToString();

void dump(ImmutableArray<(ConstantValue value, LabelSymbol label)> cases)
{
foreach (var (constant, label) in cases)
{
builder.AppendLine($" case {constant}: {readable(label)}");
}
}

string readable(LabelSymbol? label)
{
if (label is null)
{
return "<null>";
}

return label.ToString();
}
}
#endif
}
}
3 changes: 3 additions & 0 deletions src/Compilers/CSharp/Portable/CSharpResources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -7406,4 +7406,7 @@ To remove the warning, you can use /reference instead (set the Embed Interop Typ
<data name="WRN_ParamsArrayInLambdaOnly_Title" xml:space="preserve">
<value>Parameter has params modifier in lambda but not in target delegate type.</value>
</data>
<data name="IDS_DisableLengthBasedSwitch" xml:space="preserve">
<value>disable-length-based switch</value>
</data>
</root>
Loading

0 comments on commit 486e8ba

Please sign in to comment.