163 lines
6.3 KiB
C#
163 lines
6.3 KiB
C#
using System;
|
|
using System.IO;
|
|
using System.IO.Compression;
|
|
|
|
namespace TH1_Logic.Steam
|
|
{
|
|
/// <summary>
/// Encodes outgoing network payloads into an optional DEFLATE-compressed envelope and decodes
/// incoming payloads, passing through anything that is not an envelope.
/// </summary>
/// <remarks>
/// Envelope layout (multi-byte integers are little-endian):
/// <code>
///  0  int32  Magic ("THC1")
///  4  byte   Version
///  5  byte   Algorithm (2 = deflate)
///  6  byte   Flags (bit 0 = compressed)
///  7  byte   Header size (20)
///  8  int32  Raw (decoded) length
/// 12  int32  Compressed payload length
/// 16  int32  32-bit FNV-1a hash of the raw bytes
/// 20  ...    Deflate payload
/// </code>
/// Payloads that are not worth compressing are sent unwrapped, so the receiver tells the two forms
/// apart purely by the magic prefix. For that reason a raw payload that itself begins with the magic
/// is always wrapped (see <see cref="EncodeInternal"/>).
/// </remarks>
public static class NetworkPayloadCodec
{
    private const int Magic = 0x31434854; // "THC1" when written little-endian

    private const byte Version = 2;

    private const byte AlgorithmDeflate = 2;

    private const byte FlagCompressed = 1;

    private const int HeaderSize = 20;

    // Payloads smaller than this are sent raw unless compression is forced (or required, see EncodeInternal).
    private const int CompressionThresholdBytes = 16 * 1024;

    // Largest decoded size the decoder will accept/allocate; larger inputs are never wrapped by the encoder.
    private const int MaxDecodedBytes = 64 * 1024 * 1024;

    // Compression is used only when it saves at least MinSavedBytes AND MinSavedRatio of the raw size.
    private const int MinSavedBytes = 256;

    private const double MinSavedRatio = 0.05;

    // In DEFLATE a 258-byte match can cost as little as 2 compressed bits, so no valid stream expands
    // by more than 1032:1. A header claiming more than that is malformed; rejecting it up front avoids
    // allocating a huge output buffer for a tiny hostile packet.
    private const long MaxDeflateExpansionRatio = 1032;

    /// <summary>
    /// Encodes <paramref name="rawBytes"/> for sending. Small or poorly-compressible payloads are
    /// returned unchanged; otherwise a new compressed envelope is returned.
    /// </summary>
    /// <param name="rawBytes">Payload to send; <c>null</c> and empty arrays are returned as-is.</param>
    /// <returns>Either <paramref name="rawBytes"/> itself or a newly allocated envelope.</returns>
    public static byte[] Encode(byte[] rawBytes)
    {
        return EncodeInternal(rawBytes, false);
    }

    /// <summary>
    /// Same as <see cref="Encode"/>, but can force the compressed envelope regardless of the size
    /// threshold and the savings heuristics.
    /// </summary>
    /// <param name="rawBytes">Payload to encode; <c>null</c> and empty arrays are returned as-is.</param>
    /// <param name="forceCompression">When true, wrap the payload even if compression does not pay off.</param>
    /// <returns>Either <paramref name="rawBytes"/> itself or a newly allocated envelope.</returns>
    public static byte[] EncodeForDiagnostics(byte[] rawBytes, bool forceCompression)
    {
        return EncodeInternal(rawBytes, forceCompression);
    }

    /// <summary>
    /// Decodes a received payload. Inputs that are <c>null</c>, shorter than the header, or that do
    /// not start with the envelope magic are returned unchanged.
    /// </summary>
    /// <param name="wireBytes">Bytes received from the network.</param>
    /// <returns>The decoded raw payload, or <paramref name="wireBytes"/> itself when it is not an envelope.</returns>
    /// <exception cref="InvalidDataException">
    /// The envelope header is unsupported or inconsistent, or the decoded length or hash does not match.
    /// </exception>
    public static byte[] DecodeIfNeeded(byte[] wireBytes)
    {
        if (!HasEnvelopeMagic(wireBytes)) return wireBytes;

        var version = wireBytes[4];
        var algorithm = wireBytes[5];
        var flags = wireBytes[6];
        var headerSize = wireBytes[7];
        var rawLength = ReadInt32(wireBytes, 8);
        var payloadLength = ReadInt32(wireBytes, 12);
        var rawHash = ReadInt32(wireBytes, 16);

        if (version != Version || headerSize != HeaderSize)
            throw new InvalidDataException($"Unsupported network payload codec header: version={version}, header={headerSize}");

        if (algorithm != AlgorithmDeflate || (flags & FlagCompressed) == 0)
            throw new InvalidDataException($"Unsupported network payload algorithm: algorithm={algorithm}, flags={flags}");

        if (rawLength < 0 || rawLength > MaxDecodedBytes)
            throw new InvalidDataException($"Invalid network payload raw length: {rawLength}");

        // Compare against (Length - HeaderSize) rather than (HeaderSize + payloadLength) so a huge
        // payloadLength cannot overflow the addition.
        if (payloadLength < 0 || payloadLength != wireBytes.Length - HeaderSize)
            throw new InvalidDataException($"Invalid network payload length: payload={payloadLength}, bytes={wireBytes.Length}");

        if (rawLength > payloadLength * MaxDeflateExpansionRatio)
            throw new InvalidDataException($"Network payload raw length exceeds deflate limits: raw={rawLength}, payload={payloadLength}");

        var rawBytes = DecompressDeflate(wireBytes, HeaderSize, payloadLength, rawLength);
        if (ComputeHash(rawBytes) != rawHash)
            throw new InvalidDataException("Network payload hash mismatch");

        return rawBytes;
    }

    /// <summary>
    /// Shared encoder. Returns the raw array when it should be sent unwrapped, otherwise a new envelope.
    /// </summary>
    /// <param name="rawBytes">Payload to encode.</param>
    /// <param name="forceCompression">Skip the size threshold and the savings heuristics.</param>
    private static byte[] EncodeInternal(byte[] rawBytes, bool forceCompression)
    {
        if (rawBytes == null || rawBytes.Length == 0) return rawBytes;

        // The decoder rejects raw lengths above this limit, so such payloads cannot be wrapped.
        // NOTE(review): a payload this large that also starts with the magic is still ambiguous on
        // the wire; fixing that would need a wire-format change.
        if (rawBytes.Length > MaxDecodedBytes) return rawBytes;

        // DecodeIfNeeded treats anything starting with the magic as an envelope. A raw payload
        // that happens to start with those bytes must therefore always be wrapped, or the receiver
        // would misparse it.
        var mustWrap = forceCompression || HasEnvelopeMagic(rawBytes);
        if (!mustWrap && rawBytes.Length < CompressionThresholdBytes) return rawBytes;

        try
        {
            var compressed = CompressDeflate(rawBytes);
            var wireLength = HeaderSize + compressed.Length;
            if (!mustWrap && !ShouldUseCompressed(rawBytes.Length, wireLength)) return rawBytes;

            var wireBytes = new byte[wireLength];
            WriteInt32(wireBytes, 0, Magic);
            wireBytes[4] = Version;
            wireBytes[5] = AlgorithmDeflate;
            wireBytes[6] = FlagCompressed;
            wireBytes[7] = HeaderSize;
            WriteInt32(wireBytes, 8, rawBytes.Length);
            WriteInt32(wireBytes, 12, compressed.Length);
            WriteInt32(wireBytes, 16, ComputeHash(rawBytes));
            Buffer.BlockCopy(compressed, 0, wireBytes, HeaderSize, compressed.Length);
            return wireBytes;
        }
        catch (Exception)
        {
            // Deliberate best-effort: if compression fails for any reason, send the payload raw.
            return rawBytes;
        }
    }

    /// <summary>
    /// True when <paramref name="bytes"/> is long enough to hold a header and starts with the
    /// envelope magic. This is exactly the condition under which <see cref="DecodeIfNeeded"/> parses
    /// the bytes as an envelope.
    /// </summary>
    private static bool HasEnvelopeMagic(byte[] bytes)
    {
        return bytes != null && bytes.Length >= HeaderSize && ReadInt32(bytes, 0) == Magic;
    }

    /// <summary>Compresses <paramref name="rawBytes"/> into a raw DEFLATE stream (no zlib/gzip wrapper).</summary>
    private static byte[] CompressDeflate(byte[] rawBytes)
    {
        using (var output = new MemoryStream())
        {
            // The DeflateStream must be disposed before reading the buffer so its final block is flushed.
            using (var stream = new DeflateStream(output, CompressionLevel.Fastest))
            {
                stream.Write(rawBytes, 0, rawBytes.Length);
            }

            return output.ToArray();
        }
    }

    /// <summary>
    /// Inflates exactly <paramref name="rawLength"/> bytes from the DEFLATE stream stored at
    /// <paramref name="offset"/> (for <paramref name="length"/> bytes) inside <paramref name="wireBytes"/>.
    /// </summary>
    /// <exception cref="InvalidDataException">The stream yields fewer or more bytes than expected.</exception>
    private static byte[] DecompressDeflate(byte[] wireBytes, int offset, int length, int rawLength)
    {
        var rawBytes = new byte[rawLength];
        using (var input = new MemoryStream(wireBytes, offset, length))
        using (var stream = new DeflateStream(input, CompressionMode.Decompress))
        {
            var readTotal = 0;
            while (readTotal < rawBytes.Length)
            {
                var read = stream.Read(rawBytes, readTotal, rawBytes.Length - readTotal);
                if (read == 0) break;
                readTotal += read;
            }

            if (readTotal != rawBytes.Length)
                throw new InvalidDataException($"Network payload decoded length mismatch: decoded={readTotal}, expected={rawBytes.Length}");

            // Any further output means the header understated the decoded size.
            if (stream.ReadByte() != -1)
                throw new InvalidDataException("Network payload decoded length exceeded header length");
        }

        return rawBytes;
    }

    /// <summary>
    /// Decides whether the envelope is worth sending instead of the raw payload. The envelope must
    /// save at least <see cref="MinSavedBytes"/> bytes and at least <see cref="MinSavedRatio"/> of the raw size.
    /// </summary>
    private static bool ShouldUseCompressed(int rawLength, int wireLength)
    {
        var savedBytes = rawLength - wireLength;
        if (savedBytes < MinSavedBytes) return false;
        return savedBytes / (double)rawLength >= MinSavedRatio;
    }

    /// <summary>32-bit FNV-1a hash of <paramref name="bytes"/>, reinterpreted as a signed int.</summary>
    private static int ComputeHash(byte[] bytes)
    {
        unchecked
        {
            var hash = (uint)2166136261; // FNV offset basis
            for (var i = 0; i < bytes.Length; i++)
            {
                hash ^= bytes[i];
                hash *= 16777619; // FNV prime
            }

            return (int)hash;
        }
    }

    /// <summary>Reads a little-endian 32-bit integer starting at <paramref name="offset"/>.</summary>
    private static int ReadInt32(byte[] bytes, int offset)
    {
        return bytes[offset]
            | (bytes[offset + 1] << 8)
            | (bytes[offset + 2] << 16)
            | (bytes[offset + 3] << 24);
    }

    /// <summary>Writes <paramref name="value"/> as a little-endian 32-bit integer at <paramref name="offset"/>.</summary>
    private static void WriteInt32(byte[] bytes, int offset, int value)
    {
        bytes[offset] = (byte)value;
        bytes[offset + 1] = (byte)(value >> 8);
        bytes[offset + 2] = (byte)(value >> 16);
        bytes[offset + 3] = (byte)(value >> 24);
    }
}
|
|
}
|