TH1/Unity/Assets/Scripts/TH1_Logic/Steam/NetworkPayloadCodec.cs
2026-05-23 21:03:20 +08:00

163 lines
6.3 KiB
C#

using System;
using System.IO;
using System.IO.Compression;
namespace TH1_Logic.Steam
{
/// <summary>
/// Wraps outgoing network payloads in an optional Deflate-compressed envelope and unwraps
/// incoming ones.
/// </summary>
/// <remarks>
/// Wire layout of an encoded payload (integers little-endian):
/// <code>
/// [0..3]   magic 0x31434854 (the ASCII bytes "THC1")
/// [4]      format version (2)
/// [5]      algorithm id (2 = Deflate)
/// [6]      flags (bit 0 = compressed)
/// [7]      header size in bytes (20)
/// [8..11]  uncompressed length
/// [12..15] compressed payload length
/// [16..19] 32-bit FNV-1a hash of the uncompressed bytes
/// [20..]   Deflate-compressed payload
/// </code>
/// Payloads that are not worth compressing are sent unchanged. The receiver tells the two forms
/// apart only by the magic prefix, so the encoder always wraps raw payloads that happen to start
/// with that prefix.
/// </remarks>
public static class NetworkPayloadCodec
{
    private const int Magic = 0x31434854; // "THC1" when written little-endian
    private const byte Version = 2;
    private const byte AlgorithmDeflate = 2;
    private const byte FlagCompressed = 1;
    private const int HeaderSize = 20;

    // Payloads smaller than this are sent raw unless wrapping is forced or required.
    private const int CompressionThresholdBytes = 16 * 1024;

    // Largest uncompressed size the decoder accepts; larger payloads are never wrapped.
    private const int MaxDecodedBytes = 64 * 1024 * 1024;

    // Compression is only used when it saves at least MinSavedBytes AND MinSavedRatio of the raw size.
    private const int MinSavedBytes = 256;
    private const double MinSavedRatio = 0.05;

    /// <summary>
    /// Encodes a payload for sending: compresses it when it is large enough and compression saves
    /// enough bytes, otherwise returns the same array unchanged.
    /// </summary>
    /// <param name="rawBytes">Payload to send; <c>null</c> and empty arrays are returned as-is.</param>
    /// <returns>Either <paramref name="rawBytes"/> itself or a new envelope array.</returns>
    public static byte[] Encode(byte[] rawBytes)
    {
        return EncodeInternal(rawBytes, false);
    }

    /// <summary>
    /// Same as <see cref="Encode"/>, but when <paramref name="forceCompression"/> is true the size
    /// threshold and savings checks are skipped and the payload is wrapped regardless.
    /// </summary>
    /// <param name="rawBytes">Payload to encode; <c>null</c> and empty arrays are returned as-is.</param>
    /// <param name="forceCompression">Wrap even small or poorly compressible payloads.</param>
    /// <returns>Either <paramref name="rawBytes"/> itself or a new envelope array.</returns>
    public static byte[] EncodeForDiagnostics(byte[] rawBytes, bool forceCompression)
    {
        return EncodeInternal(rawBytes, forceCompression);
    }

    /// <summary>
    /// Decodes a received payload. Returns the input unchanged when it does not carry the envelope
    /// prefix; otherwise validates the header, decompresses and verifies the hash.
    /// </summary>
    /// <param name="wireBytes">Bytes received from the network; may be <c>null</c>.</param>
    /// <returns>The original uncompressed payload.</returns>
    /// <exception cref="InvalidDataException">
    /// The envelope header is unsupported or inconsistent, the decoded length does not match the
    /// header, or the hash check fails. <see cref="DeflateStream"/> also reports corrupt compressed
    /// data with this exception type.
    /// </exception>
    public static byte[] DecodeIfNeeded(byte[] wireBytes)
    {
        if (!HasEnvelopePrefix(wireBytes))
        {
            return wireBytes;
        }

        var version = wireBytes[4];
        var algorithm = wireBytes[5];
        var flags = wireBytes[6];
        var headerSize = wireBytes[7];
        var rawLength = ReadInt32(wireBytes, 8);
        var payloadLength = ReadInt32(wireBytes, 12);
        var rawHash = ReadInt32(wireBytes, 16);

        if (version != Version || headerSize != HeaderSize)
        {
            throw new InvalidDataException($"Unsupported network payload codec header: version={version}, header={headerSize}");
        }
        if (algorithm != AlgorithmDeflate || (flags & FlagCompressed) == 0)
        {
            throw new InvalidDataException($"Unsupported network payload algorithm: algorithm={algorithm}, flags={flags}");
        }
        if (rawLength < 0 || rawLength > MaxDecodedBytes)
        {
            throw new InvalidDataException($"Invalid network payload raw length: {rawLength}");
        }
        // Compare against the remaining length instead of adding to payloadLength, which could overflow.
        if (payloadLength < 0 || payloadLength != wireBytes.Length - HeaderSize)
        {
            throw new InvalidDataException($"Invalid network payload length: payload={payloadLength}, bytes={wireBytes.Length}");
        }

        var rawBytes = DecompressDeflate(wireBytes, HeaderSize, payloadLength, rawLength);
        if (ComputeHash(rawBytes) != rawHash)
        {
            throw new InvalidDataException("Network payload hash mismatch");
        }
        return rawBytes;
    }

    // Shared encode path; decides between sending raw bytes and building an envelope.
    private static byte[] EncodeInternal(byte[] rawBytes, bool forceCompression)
    {
        if (rawBytes == null || rawBytes.Length == 0) return rawBytes;
        // The decoder would reject the envelope's raw length, so oversized payloads go out raw.
        if (rawBytes.Length > MaxDecodedBytes) return rawBytes;

        // A raw payload starting with the envelope prefix would be misparsed by DecodeIfNeeded,
        // so it must be wrapped even when compression is not worthwhile.
        var mustWrap = forceCompression || HasEnvelopePrefix(rawBytes);
        if (!mustWrap && rawBytes.Length < CompressionThresholdBytes) return rawBytes;

        try
        {
            var compressed = CompressDeflate(rawBytes);
            if (!mustWrap && !ShouldUseCompressed(rawBytes.Length, HeaderSize + compressed.Length)) return rawBytes;
            return BuildEnvelope(rawBytes, compressed);
        }
        catch
        {
            // Best-effort: if compression fails for any reason, send the payload uncompressed.
            return rawBytes;
        }
    }

    // True when the bytes are long enough to hold a header and start with the magic value.
    // Used by both the encoder and the decoder so their notion of "looks encoded" cannot drift.
    private static bool HasEnvelopePrefix(byte[] bytes)
    {
        return bytes != null && bytes.Length >= HeaderSize && ReadInt32(bytes, 0) == Magic;
    }

    // Writes the 20-byte header followed by the compressed payload into a new array.
    private static byte[] BuildEnvelope(byte[] rawBytes, byte[] compressed)
    {
        var wireBytes = new byte[HeaderSize + compressed.Length];
        WriteInt32(wireBytes, 0, Magic);
        wireBytes[4] = Version;
        wireBytes[5] = AlgorithmDeflate;
        wireBytes[6] = FlagCompressed;
        wireBytes[7] = HeaderSize;
        WriteInt32(wireBytes, 8, rawBytes.Length);
        WriteInt32(wireBytes, 12, compressed.Length);
        WriteInt32(wireBytes, 16, ComputeHash(rawBytes));
        Buffer.BlockCopy(compressed, 0, wireBytes, HeaderSize, compressed.Length);
        return wireBytes;
    }

    // Deflate-compresses the whole array with the fastest compression level.
    private static byte[] CompressDeflate(byte[] rawBytes)
    {
        using (var output = new MemoryStream())
        {
            // The DeflateStream must be disposed before reading the output so it flushes its final block.
            using (var stream = new DeflateStream(output, CompressionLevel.Fastest))
            {
                stream.Write(rawBytes, 0, rawBytes.Length);
            }
            return output.ToArray();
        }
    }

    // Inflates wireBytes[offset, offset + length) into exactly rawLength bytes.
    // Throws InvalidDataException if the stream yields fewer or more bytes than rawLength.
    private static byte[] DecompressDeflate(byte[] wireBytes, int offset, int length, int rawLength)
    {
        var rawBytes = new byte[rawLength];
        using (var input = new MemoryStream(wireBytes, offset, length))
        using (var stream = new DeflateStream(input, CompressionMode.Decompress))
        {
            var readTotal = 0;
            while (readTotal < rawBytes.Length)
            {
                var read = stream.Read(rawBytes, readTotal, rawBytes.Length - readTotal);
                if (read == 0) break;
                readTotal += read;
            }
            if (readTotal != rawBytes.Length)
            {
                throw new InvalidDataException($"Network payload decoded length mismatch: decoded={readTotal}, expected={rawBytes.Length}");
            }
            // Any extra decoded byte means the header understated the payload size.
            if (stream.ReadByte() != -1)
            {
                throw new InvalidDataException("Network payload decoded length exceeded header length");
            }
        }
        return rawBytes;
    }

    // True when the envelope saves at least MinSavedBytes and at least MinSavedRatio of the raw size.
    private static bool ShouldUseCompressed(int rawLength, int wireLength)
    {
        var savedBytes = rawLength - wireLength;
        if (savedBytes < MinSavedBytes) return false;
        return savedBytes / (double)rawLength >= MinSavedRatio;
    }

    // 32-bit FNV-1a hash (offset basis 2166136261, prime 16777619), returned as a signed int.
    // Used as an integrity check, not for security.
    private static int ComputeHash(byte[] bytes)
    {
        unchecked
        {
            var hash = (uint)2166136261;
            for (var i = 0; i < bytes.Length; i++)
            {
                hash ^= bytes[i];
                hash *= 16777619;
            }
            return (int)hash;
        }
    }

    // Reads a little-endian 32-bit integer at the given offset.
    private static int ReadInt32(byte[] bytes, int offset)
    {
        return bytes[offset]
            | (bytes[offset + 1] << 8)
            | (bytes[offset + 2] << 16)
            | (bytes[offset + 3] << 24);
    }

    // Writes a 32-bit integer at the given offset in little-endian order.
    private static void WriteInt32(byte[] bytes, int offset, int value)
    {
        bytes[offset] = (byte)value;
        bytes[offset + 1] = (byte)(value >> 8);
        bytes[offset + 2] = (byte)(value >> 16);
        bytes[offset + 3] = (byte)(value >> 24);
    }
}
}