WebSocketServer.cs 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Linq;
  4. using System.Net.Sockets;
  5. using System.Threading;
  namespace Mirror.SimpleWeb
  {
      /// <summary>
      /// Accepts TCP clients on a background thread, performs stream setup and the WebSocket
      /// handshake for each one on its own thread, then runs a send thread and a receive loop per
      /// connection. Connected/Disconnected events are reported through <see cref="receiveQueue"/>.
      /// </summary>
      public class WebSocketServer
      {
          /// <summary>
          /// Events produced by the connection threads. Connected/Disconnected are enqueued by this
          /// class; the queue is also handed to <c>ReceiveLoop</c>, which may enqueue further messages.
          /// </summary>
          public readonly ConcurrentQueue<Message> receiveQueue = new ConcurrentQueue<Message>();

          readonly TcpConfig tcpConfig;
          readonly int maxMessageSize;

          TcpListener listener;
          Thread acceptThread;

          // Written by Stop() on the caller's thread and read by every handshake thread, so it is
          // volatile to make sure those threads observe the write.
          volatile bool serverStopped;

          readonly ServerHandshake handShake;
          readonly ServerSslHelper sslHelper;
          readonly BufferPool bufferPool;

          // Connections whose stream setup and handshake succeeded, keyed by connId.
          readonly ConcurrentDictionary<int, Connection> connections = new ConcurrentDictionary<int, Connection>();

          // Source of connection ids; incremented with Interlocked from the handshake threads.
          int _idCounter = 0;

          /// <summary>Creates the server. Call <see cref="Listen"/> to start accepting clients.</summary>
          /// <param name="tcpConfig">Applied to every accepted <see cref="TcpClient"/>.</param>
          /// <param name="maxMessageSize">Maximum message size; sizes each send buffer and is passed to the receive loop.</param>
          /// <param name="handshakeMaxSize">Passed to <c>ServerHandshake</c>.</param>
          /// <param name="sslConfig">Used to build the <c>ServerSslHelper</c> that creates each client stream.</param>
          /// <param name="bufferPool">Shared with the handshake and the receive loops.</param>
          public WebSocketServer(TcpConfig tcpConfig, int maxMessageSize, int handshakeMaxSize, SslConfig sslConfig, BufferPool bufferPool)
          {
              this.tcpConfig = tcpConfig;
              this.maxMessageSize = maxMessageSize;
              sslHelper = new ServerSslHelper(sslConfig);
              this.bufferPool = bufferPool;
              handShake = new ServerHandshake(this.bufferPool, handshakeMaxSize);
          }

          /// <summary>
          /// Starts listening on <paramref name="port"/> and starts a background thread that accepts clients.
          /// </summary>
          public void Listen(int port)
          {
              // A previous Stop() leaves this set; without the reset, every client accepted after a
              // restart would be dropped immediately after its handshake.
              serverStopped = false;

              listener = TcpListener.Create(port);
              listener.Start();
              Log.Info($"Server has started on port {port}");

              acceptThread = new Thread(acceptLoop);
              acceptThread.IsBackground = true;
              acceptThread.Start();
          }

          /// <summary>
          /// Stops accepting clients and disposes every registered connection.
          /// Handshakes still in progress observe <c>serverStopped</c> and drop their client.
          /// </summary>
          public void Stop()
          {
              serverStopped = true;

              // Interrupt then stop so that Exception is handled correctly
              acceptThread?.Interrupt();
              listener?.Stop();
              acceptThread = null;

              Log.Info("Server stopped, Closing all connections...");
              // make copy so that foreach doesn't break if values are removed
              Connection[] connectionsCopy = connections.Values.ToArray();
              foreach (Connection conn in connectionsCopy)
              {
                  conn.Dispose();
              }

              connections.Clear();
          }

          /// <summary>
          /// Accept-thread body. Accepts clients until the listener is stopped or the thread is
          /// interrupted, and starts one handshake/receive thread per client.
          /// </summary>
          /// <remarks>
          /// Interrupt/abort exceptions are logged at info level; any other exception is logged as
          /// an error. NOTE(review): an exception while setting up a single client also ends the loop.
          /// </remarks>
          void acceptLoop()
          {
              try
              {
                  try
                  {
                      while (true)
                      {
                          TcpClient client = listener.AcceptTcpClient();
                          tcpConfig.ApplyTo(client);

                          // Connections are not tracked until their handshake finishes; the handshake
                          // thread checks serverStopped (before and after registering) to cover Stop().
                          Connection conn = new Connection(client, AfterConnectionDisposed);
                          Log.Info($"A client connected {conn}");

                          // handshake needs its own thread as it needs to wait for message from client
                          Thread receiveThread = new Thread(() => HandshakeAndReceiveLoop(conn));
                          conn.receiveThread = receiveThread;
                          receiveThread.IsBackground = true;
                          receiveThread.Start();
                      }
                  }
                  catch (SocketException)
                  {
                      // listener.Stop() surfaces as a SocketException; presumably this converts it to an
                      // interrupt exception when the thread was interrupted (TODO confirm in Utils).
                      Utils.CheckForInterupt();
                      throw;
                  }
              }
              catch (ThreadInterruptedException e) { Log.InfoException(e); }
              catch (ThreadAbortException e) { Log.InfoException(e); }
              catch (Exception e) { Log.Exception(e); }
          }

          /// <summary>
          /// Per-client thread body. Creates the stream, performs the handshake, registers the
          /// connection, starts its send thread, then runs the receive loop on this thread.
          /// The connection is always disposed when this method returns.
          /// </summary>
          void HandshakeAndReceiveLoop(Connection conn)
          {
              try
              {
                  if (!sslHelper.TryCreateStream(conn))
                  {
                      Log.Error($"Failed to create SSL Stream {conn}");
                      return; // finally disposes conn
                  }

                  if (!handShake.TryHandshake(conn))
                  {
                      Log.Error($"Handshake Failed {conn}");
                      return; // finally disposes conn
                  }
                  Log.Info($"Sent Handshake {conn}");

                  // check if Stop has been called since accepting this client
                  if (serverStopped)
                  {
                      Log.Info("Server stops after successful handshake");
                      return;
                  }

                  conn.connId = Interlocked.Increment(ref _idCounter);
                  connections.TryAdd(conn.connId, conn);
                  receiveQueue.Enqueue(new Message(conn.connId, EventType.Connected));

                  // Stop() may have run between the check above and TryAdd. If so, its snapshot of
                  // `connections` did not include this connection and Stop() will not dispose it.
                  // Re-check now that it is registered; returning lets finally dispose it.
                  if (serverStopped)
                  {
                      Log.Info($"Server stopped while registering {conn}");
                      return;
                  }

                  Thread sendThread = new Thread(() =>
                  {
                      SendLoop.Config sendConfig = new SendLoop.Config(
                          conn,
                          bufferSize: Constants.HeaderSize + maxMessageSize,
                          setMask: false);

                      SendLoop.Loop(sendConfig);
                  });

                  conn.sendThread = sendThread;
                  sendThread.IsBackground = true;
                  sendThread.Name = $"SendLoop {conn.connId}";
                  sendThread.Start();

                  ReceiveLoop.Config receiveConfig = new ReceiveLoop.Config(
                      conn,
                      maxMessageSize,
                      expectMask: true,
                      receiveQueue,
                      bufferPool);

                  // Blocks this thread until the receive loop exits.
                  ReceiveLoop.Loop(receiveConfig);
              }
              catch (ThreadInterruptedException e) { Log.InfoException(e); }
              catch (ThreadAbortException e) { Log.InfoException(e); }
              catch (Exception e) { Log.Exception(e); }
              finally
              {
                  // Single disposal point for every exit path, including failed setup/handshake.
                  conn.Dispose();
              }
          }

          /// <summary>
          /// Callback handed to every <see cref="Connection"/> (presumably invoked once it is disposed).
          /// For connections that were assigned an id, enqueues Disconnected and unregisters them.
          /// </summary>
          void AfterConnectionDisposed(Connection conn)
          {
              if (conn.connId != Connection.IdNotSet)
              {
                  receiveQueue.Enqueue(new Message(conn.connId, EventType.Disconnected));
                  connections.TryRemove(conn.connId, out Connection _);
              }
          }

          /// <summary>
          /// Queues <paramref name="buffer"/> on the connection's send queue and signals <c>sendPending</c>.
          /// Logs a warning if <paramref name="id"/> is not a registered connection.
          /// </summary>
          public void Send(int id, ArrayBuffer buffer)
          {
              if (connections.TryGetValue(id, out Connection conn))
              {
                  conn.sendQueue.Enqueue(buffer);
                  conn.sendPending.Set();
              }
              else
              {
                  Log.Warn($"Cant send message to {id} because connection was not found in dictionary. Maybe it disconnected.");
              }
          }

          /// <summary>Disposes the connection with the given id.</summary>
          /// <returns>true if the id was registered; false (with a warning logged) otherwise.</returns>
          public bool CloseConnection(int id)
          {
              if (connections.TryGetValue(id, out Connection conn))
              {
                  Log.Info($"Kicking connection {id}");
                  conn.Dispose();
                  return true;
              }
              else
              {
                  Log.Warn($"Failed to kick {id} because id not found");
                  return false;
              }
          }

          /// <summary>
          /// Returns the remote endpoint of connection <paramref name="id"/> as a string, or null if
          /// the id is unknown or the address can no longer be read from its socket.
          /// </summary>
          public string GetClientAddress(int id)
          {
              if (!connections.TryGetValue(id, out Connection conn))
              {
                  Log.Error($"Cant get address of connection {id} because connection was not found in dictionary");
                  return null;
              }

              try
              {
                  // Socket.RemoteEndPoint can be null for a socket that is not connected.
                  // NOTE(review): TcpClient.Client may also be null after close on some runtimes.
                  return conn.client.Client?.RemoteEndPoint?.ToString();
              }
              catch (ObjectDisposedException)
              {
                  // The socket is already closed, but the entry is only removed from `connections`
                  // once AfterConnectionDisposed runs.
                  Log.Warn($"Cant get address of connection {id} because its socket was closed");
                  return null;
              }
              catch (SocketException e)
              {
                  Log.Warn($"Cant get address of connection {id}: {e.Message}");
                  return null;
              }
          }
      }
  }