Skip to content

Commit

Permalink
refactor: proxy protocol parsers
Browse files Browse the repository at this point in the history
  • Loading branch information
LaoSparrow committed Aug 10, 2024
1 parent af70bc9 commit 65943b5
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 207 deletions.
6 changes: 3 additions & 3 deletions ProxyProtocolSocket/ProxyProtocolSocket.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<Version>2.0</Version>
<Version>2.1</Version>
<Authors>LaoSparrow</Authors>
<AssemblyVersion>2.0</AssemblyVersion>
<FileVersion>2.0</FileVersion>
<AssemblyVersion>2.1</AssemblyVersion>
<FileVersion>2.1</FileVersion>
</PropertyGroup>

<ItemGroup>
Expand Down
68 changes: 35 additions & 33 deletions ProxyProtocolSocket/Utils/Net/ProxyProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,20 @@ public class ProxyProtocol
private static readonly byte[] V1_SIGNATURE = Encoding.ASCII.GetBytes("PROXY");
private static readonly byte[] V2_OR_ABOVE_SIGNATURE = { 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A };
private const int V2_OR_ABOVE_ID_LENGTH = 16;
private const int BUFFER_SIZE = 107;
private const int MAX_BUFFER_SIZE = 232;
#endregion

#region Members
private NetworkStream _stream;
private IPEndPoint _remoteEndpoint;

private byte[] _buffer = new byte[BUFFER_SIZE];
private int _bufferPosition = 0;
private byte[] _buffer = new byte[MAX_BUFFER_SIZE];
// i.e. the end of _buffer
// _buffer[_bufferSize - 1] is the last byte read from the stream
private int _bufferSize = 0;

private bool _isParserCached = false;
private IProxyProtocolParser? _cachedParser = null;
private IProxyProtocolParser? _cachedParser = null;
private ProxyProtocolVersion? _protocolVersion = null;
#endregion

Expand All @@ -62,7 +64,7 @@ public async Task Parse()
var parser = await GetParser();
if (parser == null)
return;
Logger.Log($"Calling parser {parser.GetType().Name}");
Logger.Log($"[{_remoteEndpoint}] calling {parser.GetType().Name}.Parse()");
await parser.Parse();
}

Expand Down Expand Up @@ -106,22 +108,14 @@ public async Task<ProxyProtocolCommand> GetCommand()
if (_isParserCached)
return _cachedParser;

Logger.Log("Getting parser");
Logger.Log($"[{_remoteEndpoint}] selecting parser...");
// Get parser corresponding to version
switch (await GetVersion())
_cachedParser = await GetVersion() switch
{
case ProxyProtocolVersion.V1:
_cachedParser = new ProxyProtocolParserV1(_stream, _remoteEndpoint, _buffer, ref _bufferPosition);
break;

case ProxyProtocolVersion.V2:
_cachedParser = new ProxyProtocolParserV2(_stream, _remoteEndpoint, _buffer, ref _bufferPosition);
break;

default:
_cachedParser = null;
break;
}
ProxyProtocolVersion.V1 => new ProxyProtocolParserV1(_stream, _remoteEndpoint, _buffer, _bufferSize),
ProxyProtocolVersion.V2 => new ProxyProtocolParserV2(_stream, _remoteEndpoint, _buffer, _bufferSize),
_ => null
};
_isParserCached = true;
return _cachedParser;
}
Expand All @@ -132,21 +126,21 @@ public async Task<ProxyProtocolVersion> GetVersion()
if (_protocolVersion != null)
return (ProxyProtocolVersion)_protocolVersion;

Logger.Log("Getting version info");
Logger.Log($"[{_remoteEndpoint}] interpreting protocol version...");

_protocolVersion = ProxyProtocolVersion.Unknown;
// Check if is version 1
await GetBytesToPosition(V1_SIGNATURE.Length);
await GetBytesTillBufferSize(V1_SIGNATURE.Length);
if (IsVersion1(_buffer))
_protocolVersion = ProxyProtocolVersion.V1;
else
{
// Check if is version 2 or above
await GetBytesToPosition(V2_OR_ABOVE_ID_LENGTH);
await GetBytesTillBufferSize(V2_OR_ABOVE_ID_LENGTH);
if (IsVersion2OrAbove(_buffer))
{
// Check versions
if (IsVersion2(_buffer, true))
if (IsVersion2(_buffer, false))
_protocolVersion = ProxyProtocolVersion.V2;
}
}
Expand All @@ -156,35 +150,43 @@ public async Task<ProxyProtocolVersion> GetVersion()
#endregion

#region Private methods
private async Task GetBytesToPosition(int position)
private async Task GetBytesTillBufferSize(int size)
{
if (position <= _bufferPosition)
if (size <= _bufferSize)
return;
await GetBytesFromStream(position - _bufferPosition);
await GetBytesFromStream(size - _bufferSize);
}

private async Task GetBytesFromStream(int length)
{
if ((_bufferPosition + length) > _buffer.Length)
if ((_bufferSize + length) > _buffer.Length)
throw new InternalBufferOverflowException();

while (length > 0)
{
if (!_stream.DataAvailable)
throw new EndOfStreamException();

int count = await _stream.ReadAsync(_buffer, _bufferPosition, length);
var count = await _stream.ReadAsync(_buffer.AsMemory(_bufferSize, length));
length -= count;
_bufferPosition += count;
_bufferSize += count;
}
}
#endregion

#region Public static methods
public static bool IsVersion1(byte[] header) => header.Take(V1_SIGNATURE.Length).SequenceEqual(V1_SIGNATURE);
public static bool IsVersion2OrAbove(byte[] header) => header.Take(V2_OR_ABOVE_SIGNATURE.Length).SequenceEqual(V2_OR_ABOVE_SIGNATURE);
public static bool IsVersion2(byte[] header, bool checkedSignature = false) =>
checkedSignature || IsVersion2OrAbove(header) &&

// ReSharper disable once MemberCanBePrivate.Global
public static bool IsVersion1(ReadOnlySpan<byte> header) =>
header[..V1_SIGNATURE.Length].SequenceEqual(V1_SIGNATURE);

// ReSharper disable once MemberCanBePrivate.Global
public static bool IsVersion2OrAbove(ReadOnlySpan<byte> header) =>
header[..V2_OR_ABOVE_SIGNATURE.Length].SequenceEqual(V2_OR_ABOVE_SIGNATURE);

// ReSharper disable once MemberCanBePrivate.Global
public static bool IsVersion2(ReadOnlySpan<byte> header, bool checkSignature = true) =>
(!checkSignature || IsVersion2OrAbove(header)) &&
(header[12] & 0xF0) == 0x20;
#endregion
}
Expand Down
137 changes: 63 additions & 74 deletions ProxyProtocolSocket/Utils/Net/ProxyProtocolParserV1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,113 +9,113 @@ public class ProxyProtocolParserV1 : IProxyProtocolParser
#region Constants
private const string DELIMITER = "\r\n";
private const char SEPARATOR = ' ';
private const int MAX_HEADER_SIZE = 107;
#endregion

#region Members
private NetworkStream _stream;
private IPEndPoint _remoteEndpoint;
private byte[] _buffer;
private int _bufferPosition;
private int _bufferSize;

private bool _isParsed;
private bool _hasParsed;
private AddressFamily _addressFamily = AddressFamily.Unknown;
private ProxyProtocolCommand _protocolCommand = ProxyProtocolCommand.Unknown;
private IPEndPoint? _sourceEndpoint;
private IPEndPoint? _destEndpoint;
#endregion

public ProxyProtocolParserV1(NetworkStream stream, IPEndPoint remoteEndpoint, byte[] buffer, ref int bufferPosition)
public ProxyProtocolParserV1(NetworkStream stream, IPEndPoint remoteEndpoint, byte[] buffer, int bufferSize)
{
#region Args checking
if (stream == null) throw new ArgumentNullException("argument 'stream' cannot be null");
if (stream.CanRead != true) throw new ArgumentException("argument 'stream' is unreadable");
if (remoteEndpoint == null) throw new ArgumentNullException("argument 'remoteEndpoint' cannot be null");
if (buffer == null) throw new ArgumentNullException("argument 'buffer' cannot be null");
if (bufferPosition > buffer.Length) throw new ArgumentException("argument 'bufferPosition' is larger than 'buffer.Length'");
if (stream == null) throw new ArgumentNullException(nameof(stream));
if (stream.CanRead != true) throw new ArgumentException($"argument 'stream' is unreadable");
if (remoteEndpoint == null) throw new ArgumentNullException(nameof(remoteEndpoint));
if (buffer == null) throw new ArgumentNullException(nameof(buffer));
if (bufferSize > buffer.Length) throw new ArgumentException($"argument '{nameof(bufferSize)}' is larger than '{nameof(buffer)}.Length'");
#endregion

#region Filling members
_stream = stream;
_remoteEndpoint = remoteEndpoint;
_buffer = buffer;
_bufferPosition = bufferPosition;
_bufferSize = bufferSize;
#endregion
}

#region Public methods
public async Task Parse()
{
if (_isParsed)
if (_hasParsed)
return;
_isParsed = true;
Logger.Log("Parsing header");
_hasParsed = true;
Logger.Log($"[{_remoteEndpoint}] parsing header...");

#region Getting full header and do first check

await GetFullHeader();
if (_bufferPosition < 2 || _buffer[_bufferPosition - 2] != '\r')
throw new Exception("Header must end with CRLF");

string[] tokens = Encoding.ASCII.GetString(_buffer.Take(_bufferPosition - 2).ToArray()).Split(SEPARATOR);
if (ProxyProtocolSocketPlugin.Config.Settings.LogLevel == LogLevel.Debug)
Logger.Log($"[{_remoteEndpoint}] header content: {Convert.ToHexString(_buffer[.._bufferSize])}");
var tokens = Encoding.ASCII.GetString(_buffer[..(_bufferSize - 2)]).Split(SEPARATOR);
if (tokens.Length < 2)
throw new Exception("Unable to read AddressFamily and protocol");

#endregion

#region Parse address family
AddressFamily addressFamily;
switch (tokens[1])

Logger.Log($"[{_remoteEndpoint}] parsing address family...");
var addressFamily = tokens[1] switch
{
case "TCP4":
addressFamily = AddressFamily.InterNetwork;
break;

case "TCP6":
addressFamily = AddressFamily.InterNetworkV6;
break;

case "UNKNOWN":
addressFamily = AddressFamily.Unspecified;
break;
"TCP4" => AddressFamily.InterNetwork,
"TCP6" => AddressFamily.InterNetworkV6,
"UNKNOWN" => AddressFamily.Unspecified,
_ => throw new Exception("Invalid address family")
};

default:
throw new Exception("Invalid address family");
}
#endregion

#region Do second check

if (addressFamily == AddressFamily.Unspecified)
{
_protocolCommand = ProxyProtocolCommand.Local;
_sourceEndpoint = _remoteEndpoint;
_isParsed = true;
_hasParsed = true;
return;
}
else if (tokens.Length < 6)
throw new Exception("Unable to read ipaddresses and ports");

if (tokens.Length < 6)
throw new Exception("Impossible to read ip addresses and ports as the number of tokens is less than 6");

#endregion

#region Parse source and dest end point
IPEndPoint sourceEP;
IPEndPoint destEP;

Logger.Log($"[{_remoteEndpoint}] parsing endpoints...");
IPEndPoint sourceEp;
IPEndPoint destEp;
try
{
// TODO: IP format validation
IPAddress sourceAddr = IPAddress.Parse(tokens[2]);
IPAddress destAddr = IPAddress.Parse(tokens[3]);
int sourcePort = Convert.ToInt32(tokens[4]);
int destPort = Convert.ToInt32(tokens[5]);
sourceEP = new IPEndPoint(sourceAddr, sourcePort);
destEP = new IPEndPoint(destAddr, destPort);
var sourceAddr = IPAddress.Parse(tokens[2]);
var destAddr = IPAddress.Parse(tokens[3]);
var sourcePort = Convert.ToInt32(tokens[4]);
var destPort = Convert.ToInt32(tokens[5]);
sourceEp = new IPEndPoint(sourceAddr, sourcePort);
destEp = new IPEndPoint(destAddr, destPort);
}
catch (Exception ex)
{
throw new Exception("Unable to parse ip addresses and ports", ex);
}

#endregion

_addressFamily = addressFamily;
_protocolCommand = ProxyProtocolCommand.Proxy;
_sourceEndpoint = sourceEP;
_destEndpoint = destEP;
_sourceEndpoint = sourceEp;
_destEndpoint = destEp;
}

public async Task<IPEndPoint?> GetSourceEndpoint()
Expand Down Expand Up @@ -146,58 +146,47 @@ public async Task<ProxyProtocolCommand> GetCommand()
#region Private methods
private async Task GetFullHeader()
{
Logger.Log($"Getting full header");
for (int i = 1; ; i++)
Logger.Log($"[{_remoteEndpoint}] getting full header");
for (var i = 7; i < MAX_HEADER_SIZE; i++) // Search after "PROXY" signature
{
if (await GetOneByteOfPosition(i) == '\n')
break;
if (i >= _buffer.Length)
throw new Exception("Reaching the end of buffer without reaching the delimiter of version 1");
if (await GetOneByteAtPosition(i) != DELIMITER[1])
continue;
if (await GetOneByteAtPosition(i - 1) != DELIMITER[0])
throw new Exception("Header must end with CRLF");
return;
}
throw new Exception("Failed to find any delimiter within the maximum header size of version 1");
}

private async Task GetBytesToPosition(int position)
private async Task GetBytesTillBufferSize(int size)
{
if (position <= _bufferPosition)
if (size <= _bufferSize)
return;
await GetBytesFromStream(position - _bufferPosition);
await GetBytesFromStream(size - _bufferSize);
}

private async Task GetBytesFromStream(int length)
{
if ((_bufferPosition + length) > _buffer.Length)
if (_bufferSize + length > _buffer.Length)
throw new InternalBufferOverflowException();

while (length > 0)
{
if (!_stream.DataAvailable)
throw new EndOfStreamException();

int count = await _stream.ReadAsync(_buffer, _bufferPosition, length);
var count = await _stream.ReadAsync(_buffer.AsMemory(_bufferSize, length));
length -= count;
_bufferPosition += count;
_bufferSize += count;
}
}

private async Task<byte> GetOneByteOfPosition(int position)
private async Task<byte> GetOneByteAtPosition(int position)
{
await GetBytesToPosition(position);
return _buffer[position - 1];
}

private byte GetOneByteFromStream()
{
if ((_bufferPosition + 1) > _buffer.Length)
throw new InternalBufferOverflowException();

int readState = _stream.ReadByte();
if (readState < 0)
throw new EndOfStreamException();

_buffer[_bufferPosition] = (byte)readState;
_bufferPosition++;
return (byte)readState;
await GetBytesTillBufferSize(position + 1);
return _buffer[position];
}

#endregion
}
}
Loading

0 comments on commit 65943b5

Please sign in to comment.