123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- using System;
- using System.Collections.Concurrent;
- using System.Linq;
- using System.Net.Sockets;
- using System.Threading;
namespace Mirror.SimpleWeb
{
    /// <summary>
    /// WebSocket server built on <see cref="TcpListener"/>.
    /// <para>
    /// A background accept thread takes incoming TCP clients. Each client gets its own
    /// thread that creates the stream (via <see cref="ServerSslHelper"/>), performs the
    /// handshake (via <see cref="ServerHandshake"/>) and then runs the receive loop.
    /// A second per-connection thread runs the send loop.
    /// </para>
    /// <para>
    /// Connected/Disconnected events are enqueued on <see cref="receiveQueue"/>. The same
    /// queue is handed to the receive loop.
    /// </para>
    /// </summary>
    public class WebSocketServer
    {
        /// <summary>
        /// Events produced by the connection threads. The owner of the server is expected
        /// to drain this queue.
        /// </summary>
        public readonly ConcurrentQueue<Message> receiveQueue = new ConcurrentQueue<Message>();

        readonly TcpConfig tcpConfig;
        readonly int maxMessageSize;

        TcpListener listener;
        Thread acceptThread;

        // Written by Stop() on the caller's thread and read by every handshake thread.
        // volatile guarantees those threads observe the write.
        volatile bool serverStopped;

        readonly ServerHandshake handShake;
        readonly ServerSslHelper sslHelper;
        readonly BufferPool bufferPool;

        // Only connections that completed the handshake (and so have a connId) are stored here.
        readonly ConcurrentDictionary<int, Connection> connections = new ConcurrentDictionary<int, Connection>();

        // Last assigned connection id. It is incremented atomically, so the first id handed out is 1.
        int _idCounter = 0;

        /// <param name="tcpConfig">Settings applied to every accepted <see cref="TcpClient"/>.</param>
        /// <param name="maxMessageSize">Largest payload accepted by the receive loop. Also used, with the header size, for the send buffer.</param>
        /// <param name="handshakeMaxSize">Passed to <see cref="ServerHandshake"/>.</param>
        /// <param name="sslConfig">Passed to <see cref="ServerSslHelper"/>.</param>
        /// <param name="bufferPool">Pool shared by the handshake and the receive loops.</param>
        public WebSocketServer(TcpConfig tcpConfig, int maxMessageSize, int handshakeMaxSize, SslConfig sslConfig, BufferPool bufferPool)
        {
            this.tcpConfig = tcpConfig;
            this.maxMessageSize = maxMessageSize;
            sslHelper = new ServerSslHelper(sslConfig);
            this.bufferPool = bufferPool;
            handShake = new ServerHandshake(this.bufferPool, handshakeMaxSize);
        }

        /// <summary>
        /// Starts a <see cref="TcpListener"/> on <paramref name="port"/> and spawns the
        /// background thread that accepts clients.
        /// </summary>
        public void Listen(int port)
        {
            listener = TcpListener.Create(port);
            listener.Start();

            Log.Info($"Server has started on port {port}");

            acceptThread = new Thread(acceptLoop);
            acceptThread.IsBackground = true;
            acceptThread.Name = "SimpleWeb AcceptLoop";
            acceptThread.Start();
        }

        /// <summary>
        /// Stops accepting clients and disposes every registered connection.
        /// </summary>
        public void Stop()
        {
            // Set first so handshake threads that are still running bail out instead of registering.
            serverStopped = true;

            // Interrupt then stop so that Exception is handled correctly
            acceptThread?.Interrupt();
            listener?.Stop();
            acceptThread = null;

            Log.Info("Server stopped, Closing all connections...");

            // Take a copy first: disposing a connection removes it from the dictionary
            // (see AfterConnectionDisposed), and removals must not break this loop.
            Connection[] connectionsCopy = connections.Values.ToArray();
            foreach (Connection conn in connectionsCopy)
            {
                conn.Dispose();
            }

            connections.Clear();
        }

        /// <summary>
        /// Body of the accept thread. It blocks on the listener and spawns one handshake/receive
        /// thread per client. It exits when the listener is stopped or the thread is interrupted.
        /// </summary>
        void acceptLoop()
        {
            try
            {
                try
                {
                    while (true)
                    {
                        TcpClient client = listener.AcceptTcpClient();
                        tcpConfig.ApplyTo(client);

                        // TODO keep track of connections before they are in connections dictionary
                        // this might not be a problem as HandshakeAndReceiveLoop checks for stop
                        // and returns/disposes before sending message to queue
                        Connection conn = new Connection(client, AfterConnectionDisposed);
                        Log.Info($"A client connected {conn}");

                        // handshake needs its own thread as it needs to wait for message from client
                        Thread receiveThread = new Thread(() => HandshakeAndReceiveLoop(conn));

                        conn.receiveThread = receiveThread;

                        receiveThread.IsBackground = true;
                        receiveThread.Name = "SimpleWeb HandshakeAndReceiveLoop";
                        receiveThread.Start();
                    }
                }
                catch (SocketException)
                {
                    // Stop() interrupts this thread and then stops the listener, which makes
                    // AcceptTcpClient throw. Turn that case into an interrupt; otherwise rethrow.
                    Utils.CheckForInterupt();
                    throw;
                }
            }
            catch (ThreadInterruptedException e) { Log.InfoException(e); }
            catch (ThreadAbortException e) { Log.InfoException(e); }
            catch (Exception e) { Log.Exception(e); }
        }

        /// <summary>
        /// Per-client thread body. It creates the stream, performs the handshake, registers the
        /// connection, starts its send thread and then runs the receive loop on this thread.
        /// The connection is always disposed when this method exits.
        /// </summary>
        void HandshakeAndReceiveLoop(Connection conn)
        {
            try
            {
                bool success = sslHelper.TryCreateStream(conn);
                if (!success)
                {
                    // conn is disposed by the finally block below
                    Log.Error($"Failed to create SSL Stream {conn}");
                    return;
                }

                success = handShake.TryHandshake(conn);
                if (!success)
                {
                    // conn is disposed by the finally block below
                    Log.Error($"Handshake Failed {conn}");
                    return;
                }

                Log.Info($"Sent Handshake {conn}");

                // check if Stop has been called since accepting this client
                // NOTE(review): Stop() can still run between this check and TryAdd below,
                // leaving a registered connection after Stop returns (see TODO in acceptLoop).
                if (serverStopped)
                {
                    Log.Info("Server stops after successful handshake");
                    return;
                }

                conn.connId = Interlocked.Increment(ref _idCounter);
                connections.TryAdd(conn.connId, conn);

                receiveQueue.Enqueue(new Message(conn.connId, EventType.Connected));

                Thread sendThread = new Thread(() =>
                {
                    SendLoop.Config sendConfig = new SendLoop.Config(
                        conn,
                        bufferSize: Constants.HeaderSize + maxMessageSize,
                        setMask: false);

                    SendLoop.Loop(sendConfig);
                });

                conn.sendThread = sendThread;
                sendThread.IsBackground = true;
                sendThread.Name = $"SendLoop {conn.connId}";
                sendThread.Start();

                ReceiveLoop.Config receiveConfig = new ReceiveLoop.Config(
                    conn,
                    maxMessageSize,
                    expectMask: true,
                    receiveQueue,
                    bufferPool);

                ReceiveLoop.Loop(receiveConfig);
            }
            catch (ThreadInterruptedException e) { Log.InfoException(e); }
            catch (ThreadAbortException e) { Log.InfoException(e); }
            catch (Exception e) { Log.Exception(e); }
            finally
            {
                // Single disposal point for every exit path: a failed stream or handshake,
                // a stop during the handshake, the end of the receive loop, or an exception.
                conn.Dispose();
            }
        }

        /// <summary>
        /// Callback given to each <see cref="Connection"/> and run after it is disposed.
        /// Connections that were registered (had an id assigned) get a Disconnected event
        /// and are removed from the dictionary.
        /// </summary>
        void AfterConnectionDisposed(Connection conn)
        {
            if (conn.connId != Connection.IdNotSet)
            {
                receiveQueue.Enqueue(new Message(conn.connId, EventType.Disconnected));
                connections.TryRemove(conn.connId, out Connection _);
            }
        }

        /// <summary>
        /// Queues <paramref name="buffer"/> for connection <paramref name="id"/> and wakes its
        /// send thread. If the id is unknown, it logs a warning and does nothing else.
        /// </summary>
        public void Send(int id, ArrayBuffer buffer)
        {
            if (connections.TryGetValue(id, out Connection conn))
            {
                conn.sendQueue.Enqueue(buffer);
                conn.sendPending.Set();
            }
            else
            {
                Log.Warn($"Cant send message to {id} because connection was not found in dictionary. Maybe it disconnected.");
            }
        }

        /// <summary>
        /// Disposes connection <paramref name="id"/>.
        /// </summary>
        /// <returns>true if the connection was found, otherwise false.</returns>
        public bool CloseConnection(int id)
        {
            if (connections.TryGetValue(id, out Connection conn))
            {
                Log.Info($"Kicking connection {id}");
                conn.Dispose();
                return true;
            }
            else
            {
                Log.Warn($"Failed to kick {id} because id not found");
                return false;
            }
        }

        /// <summary>
        /// Returns the remote endpoint of connection <paramref name="id"/> as a string.
        /// </summary>
        /// <returns>
        /// The endpoint string, or null if the id is unknown or the socket is already
        /// closed or disconnected.
        /// </returns>
        public string GetClientAddress(int id)
        {
            if (!connections.TryGetValue(id, out Connection conn))
            {
                Log.Error($"Cant get address of connection {id} because connection was not found in dictionary");
                return null;
            }

            try
            {
                // A connection can be disposed on another thread before AfterConnectionDisposed
                // removes it from the dictionary. The underlying socket may then be null or closed.
                return conn.client?.Client?.RemoteEndPoint?.ToString();
            }
            catch (ObjectDisposedException)
            {
                Log.Warn($"Cant get address of connection {id} because its socket was already closed");
                return null;
            }
            catch (SocketException)
            {
                Log.Warn($"Cant get address of connection {id} because its socket is not connected");
                return null;
            }
        }
    }
}
|