ServerHandshake.cs 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. using System;
  2. using System.IO;
  3. using System.Security.Cryptography;
  4. using System.Text;
  namespace Mirror.SimpleWeb
  {
      /// <summary>
      /// Handles Handshakes from new clients on the server
      /// <para>The server handshake has buffers to reduce allocations when clients connect</para>
      /// </summary>
      internal class ServerHandshake
      {
          /// <summary>Number of leading bytes checked for the "GET" method token.</summary>
          const int GetSize = 3;
          /// <summary>
          /// Byte length of the response built by <see cref="CreateResponse"/>:
          /// the fixed text plus a 28-char base64 SHA1 digest totals 129 ASCII bytes.
          /// </summary>
          const int ResponseLength = 129;
          /// <summary>Expected length of the Sec-WebSocket-Key value (base64 of a 16-byte nonce, RFC 6455).</summary>
          const int KeyLength = 24;
          /// <summary>Key plus the handshake GUID appended by <see cref="AppendGuid"/>; this many bytes are hashed.</summary>
          const int MergedKeyLength = 60;
          /// <summary>
          /// Start of the key header line. The request line always precedes the headers, so the header
          /// is always preceded by CRLF; anchoring on it avoids matching the text elsewhere (e.g. in the path).
          /// Matched case-insensitively because HTTP header names are case-insensitive.
          /// </summary>
          const string KeyHeaderLineStart = "\r\nSec-WebSocket-Key:";
          // this isn't an official max, just a reasonable size for a websocket handshake
          // (note: the constructor always overwrites this default)
          readonly int maxHttpHeaderSize = 3000;
          readonly SHA1 sha1 = SHA1.Create();
          // HashAlgorithm instance members are not thread-safe; this guards the shared sha1 instance
          readonly object sha1Lock = new object();
          readonly BufferPool bufferPool;

          /// <param name="bufferPool">pool used for the header, key and response scratch buffers</param>
          /// <param name="handshakeMaxSize">maximum number of bytes read while looking for the end of the handshake</param>
          public ServerHandshake(BufferPool bufferPool, int handshakeMaxSize)
          {
              this.bufferPool = bufferPool;
              this.maxHttpHeaderSize = handshakeMaxSize;
          }

          ~ServerHandshake()
          {
              sha1.Dispose();
          }

          /// <summary>
          /// Reads the client's opening handshake from <c>conn.stream</c> and, if it is valid, writes the
          /// 101 Switching Protocols response.
          /// </summary>
          /// <returns>true if the response was written; false if the request was rejected or could not be read</returns>
          public bool TryHandshake(Connection conn)
          {
              Stream stream = conn.stream;
              using (ArrayBuffer getHeader = bufferPool.Take(GetSize))
              {
                  if (!ReadHelper.TryRead(stream, getHeader.array, 0, GetSize))
                      return false;
                  getHeader.count = GetSize;
                  if (!IsGet(getHeader.array))
                  {
                      Log.Warn($"First bytes from client was not 'GET' for handshake, instead was {Log.BufferToString(getHeader.array, 0, GetSize)}");
                      return false;
                  }
              }
              string msg = ReadToEndForHandshake(stream);
              if (string.IsNullOrEmpty(msg))
                  return false;
              try
              {
                  AcceptHandshake(stream, msg);
                  return true;
              }
              catch (ArgumentException e)
              {
                  // malformed requests (e.g. missing/invalid key) surface as ArgumentException
                  Log.InfoException(e);
                  return false;
              }
          }

          /// <summary>
          /// Reads from <paramref name="stream"/> until <c>Constants.endOfHandshake</c> is seen, reading at most
          /// <c>maxHttpHeaderSize</c> bytes, and decodes what was read as ASCII.
          /// </summary>
          /// <returns>the decoded text, or null if the match was not found</returns>
          string ReadToEndForHandshake(Stream stream)
          {
              using (ArrayBuffer readBuffer = bufferPool.Take(maxHttpHeaderSize))
              {
                  int? readCountOrFail = ReadHelper.SafeReadTillMatch(stream, readBuffer.array, 0, maxHttpHeaderSize, Constants.endOfHandshake);
                  if (!readCountOrFail.HasValue)
                      return null;
                  int readCount = readCountOrFail.Value;
                  string msg = Encoding.ASCII.GetString(readBuffer.array, 0, readCount);
                  Log.Verbose(msg);
                  return msg;
              }
          }

          /// <summary>Returns true if the first three bytes are ASCII "GET".</summary>
          static bool IsGet(byte[] getHeader)
          {
              // just check bytes here instead of using Encoding.ASCII
              return getHeader[0] == (byte)'G' &&
                     getHeader[1] == (byte)'E' &&
                     getHeader[2] == (byte)'T';
          }

          /// <summary>
          /// Computes the Sec-WebSocket-Accept value for the request in <paramref name="msg"/> and writes the response.
          /// </summary>
          /// <exception cref="ArgumentException">the request has no valid Sec-WebSocket-Key header</exception>
          void AcceptHandshake(Stream stream, string msg)
          {
              // keyBuffer must hold the key AND the appended GUID (MergedKeyLength bytes). Requesting only
              // KeyLength could yield an array too short for AppendGuid's copy, depending on the pool's sizing.
              using (
                  ArrayBuffer keyBuffer = bufferPool.Take(MergedKeyLength),
                  responseBuffer = bufferPool.Take(ResponseLength))
              {
                  GetKey(msg, keyBuffer.array);
                  AppendGuid(keyBuffer.array);
                  byte[] keyHash = CreateHash(keyBuffer.array);
                  CreateResponse(keyHash, responseBuffer.array);
                  stream.Write(responseBuffer.array, 0, ResponseLength);
              }
          }

          /// <summary>
          /// Finds the Sec-WebSocket-Key header in <paramref name="msg"/> and copies its value, as ASCII, to the
          /// first <see cref="KeyLength"/> bytes of <paramref name="keyBuffer"/>.
          /// </summary>
          /// <exception cref="ArgumentException">the header is missing or its value is not exactly KeyLength characters</exception>
          static void GetKey(string msg, byte[] keyBuffer)
          {
              int headerIndex = msg.IndexOf(KeyHeaderLineStart, StringComparison.OrdinalIgnoreCase);
              if (headerIndex < 0)
                  throw new ArgumentException("Handshake did not contain a Sec-WebSocket-Key header", nameof(msg));

              int start = headerIndex + KeyHeaderLineStart.Length;
              // skip optional whitespace between the colon and the value
              while (start < msg.Length && (msg[start] == ' ' || msg[start] == '\t'))
                  start++;

              // the value runs to the end of the header line, minus any trailing whitespace
              int end = msg.IndexOf('\r', start);
              if (end < 0)
                  end = msg.Length;
              while (end > start && (msg[end - 1] == ' ' || msg[end - 1] == '\t'))
                  end--;

              int valueLength = end - start;
              if (valueLength != KeyLength)
                  throw new ArgumentException($"Sec-WebSocket-Key should be {KeyLength} characters but was {valueLength}", nameof(msg));

              Log.Verbose($"Handshake Key: {msg.Substring(start, KeyLength)}");
              Encoding.ASCII.GetBytes(msg, start, KeyLength, keyBuffer, 0);
          }

          /// <summary>Copies the handshake GUID bytes directly after the key in <paramref name="keyBuffer"/>.</summary>
          static void AppendGuid(byte[] keyBuffer)
          {
              Buffer.BlockCopy(Constants.HandshakeGUIDBytes, 0, keyBuffer, KeyLength, Constants.HandshakeGUID.Length);
          }

          /// <summary>Returns the SHA1 digest of the first <see cref="MergedKeyLength"/> bytes of <paramref name="keyBuffer"/>.</summary>
          byte[] CreateHash(byte[] keyBuffer)
          {
              Log.Verbose($"Handshake Hashing {Encoding.ASCII.GetString(keyBuffer, 0, MergedKeyLength)}");
              // NOTE(review): guards against concurrent handshakes sharing this instance; whether callers
              // run TryHandshake concurrently is not visible here, but the lock is cheap either way
              lock (sha1Lock)
              {
                  return sha1.ComputeHash(keyBuffer, 0, MergedKeyLength);
              }
          }

          /// <summary>
          /// Writes the 101 Switching Protocols response, with the base64 of <paramref name="keyHash"/> as the
          /// Sec-WebSocket-Accept value, into the first <see cref="ResponseLength"/> bytes of <paramref name="responseBuffer"/>.
          /// </summary>
          static void CreateResponse(byte[] keyHash, byte[] responseBuffer)
          {
              string keyHashString = Convert.ToBase64String(keyHash);
              // compiler should merge these strings into 1 string before format
              string message = string.Format(
                  "HTTP/1.1 101 Switching Protocols\r\n" +
                  "Connection: Upgrade\r\n" +
                  "Upgrade: websocket\r\n" +
                  "Sec-WebSocket-Accept: {0}\r\n\r\n",
                  keyHashString);
              Log.Verbose($"Handshake Response length {message.Length}, IsExpected {message.Length == ResponseLength}");
              Encoding.ASCII.GetBytes(message, 0, ResponseLength, responseBuffer, 0);
          }
      }
  }