From f8c12318d40ce7ee5ce811897a11108e8ab391ca Mon Sep 17 00:00:00 2001
From: Treeki <treeki@gmail.com>
Date: Fri, 17 Jan 2014 07:56:04 +0100
Subject: lots of code cleanup, add session resuming, add simple python client
 for testing

---
 core.cpp         | 262 +++++++++++++++++++++++++++++++++++--------------------
 core.h           |  99 +++++++++++++++++++++
 python_client.py |  23 +++++
 3 files changed, 289 insertions(+), 95 deletions(-)
 create mode 100644 core.h
 create mode 100644 python_client.py

diff --git a/core.cpp b/core.cpp
index a27510e..f969033 100644
--- a/core.cpp
+++ b/core.cpp
@@ -1,5 +1,6 @@
 #include <string.h>
 #include <stdint.h>
+#include <stdlib.h>
 #include <stdio.h>
 #include <errno.h>
 #include <fcntl.h>
@@ -10,14 +11,20 @@
 #include <sys/select.h>
 #include <netinet/in.h>
 #include <gnutls/gnutls.h>
+#include <list>
 
-#include "buffer.h"
 #include "dns.h"
+#include "core.h"
 
-#define CLIENT_LIMIT 100
-#define SERVER_LIMIT 20
 
-#define SESSION_KEEPALIVE 5
+Client *clients[CLIENT_LIMIT];
+Server *servers[SERVER_LIMIT];
+int clientCount, serverCount;
+bool quitFlag = false;
+
+static gnutls_dh_params_t dh_params;
+static gnutls_certificate_credentials_t serverCreds, clientCreds;
+
 
 
 static bool setSocketNonBlocking(int sock) {
@@ -35,95 +42,31 @@ static bool setSocketNonBlocking(int sock) {
 }
 
 
-struct SocketRWCommon {
-	Buffer inputBuf, outputBuf;
-
-	enum ConnState {
-		CS_DISCONNECTED = 0,
-		CS_WAITING_DNS = 1, // server only
-		CS_WAITING_CONNECT = 2, // server only
-		CS_TLS_HANDSHAKE = 3,
-		CS_CONNECTED = 4
-	};
-	ConnState state;
-
-	int sock;
-	gnutls_session_t tls;
-	bool tlsActive;
-
-	SocketRWCommon() {
-		sock = -1;
-		state = CS_DISCONNECTED;
-		tlsActive = false;
-	}
-
-	~SocketRWCommon() {
-		close();
-	}
-
-	void tryTLSHandshake();
-	virtual void close();
-
-	void readAction();
-	void writeAction();
-	bool hasTlsPendingData() {
-		// should un-inline this maybe?
-		if (tlsActive)
-			return (gnutls_record_check_pending(tls) > 0);
-		else
-			return false;
-	}
-private:
-	virtual void processReadBuffer() = 0;
-};
-
-struct Client : SocketRWCommon {
-	time_t deadTime;
-
-	void startService(int _sock, bool withTls);
-	void close();
-
-private:
-	void processReadBuffer();
-	void handleLine(char *line, int size);
-};
-
-struct Server : SocketRWCommon {
-	char ircHostname[256];
-	int ircPort;
-	int dnsQueryId;
-	bool ircUseTls;
-
-	Server() {
-		dnsQueryId = -1;
-		ircUseTls = false;
-	}
-
-	~Server() {
-		if (dnsQueryId != -1)
-			DNS::closeQuery(dnsQueryId);
-	}
-
-	void beginConnect();
-	void tryConnectPhase();
-	void connectionSuccessful();
-
-	void close();
+static Client *findClientWithKey(uint8_t *key) {
+	for (int i = 0; i < clientCount; i++)
+		if (!memcmp(clients[i]->sessionKey, key, SESSION_KEY_SIZE))
+			return clients[i];
 
-private:
-	void processReadBuffer();
-	void handleLine(char *line, int size);
-};
+	return 0;
+}
 
 
-Client *clients[CLIENT_LIMIT];
-Server *servers[SERVER_LIMIT];
-int clientCount, serverCount;
-bool quitFlag = false;
 
-static gnutls_dh_params_t dh_params;
-static gnutls_certificate_credentials_t serverCreds, clientCreds;
+SocketRWCommon::SocketRWCommon() {
+	sock = -1;
+	state = CS_DISCONNECTED;
+	tlsActive = false;
+}
+SocketRWCommon::~SocketRWCommon() {
+	close();
+}
 
+bool SocketRWCommon::hasTlsPendingData() const {
+	if (tlsActive)
+		return (gnutls_record_check_pending(tls) > 0);
+	else
+		return false;
+}
 
 void SocketRWCommon::tryTLSHandshake() {
 	int hsRet = gnutls_handshake(tls);
@@ -196,9 +139,10 @@ void SocketRWCommon::readAction() {
 		printf("[fd=%d] Read 0! Socket closing.\n", sock);
 		close();
 
-	} else if (amount < 0)
+	} else if (amount < 0) {
 		perror("Error while reading!");
-	// Close connection in that case, if a fatal error occurs?
+		close();
+	}
 }
 
 void SocketRWCommon::writeAction() {
@@ -220,9 +164,30 @@ void SocketRWCommon::writeAction() {
 		outputBuf.trimFromStart(amount);
 	} else if (amount == 0)
 		printf("Sent 0!\n");
-	else if (amount < 0)
+	else if (amount < 0) {
 		perror("Error while sending!");
-	// Close connection in that case, if a fatal error occurs?
+		close();
+	}
+}
+
+
+
+
+Client::Client() {
+	authState = AS_LOGIN_WAIT;
+	memset(sessionKey, 0, sizeof(sessionKey));
+	readBufPosition = 0;
+
+	nextPacketID = 1;
+	lastReceivedPacketID = 0;
+}
+Client::~Client() {
+	std::list<Packet *>::iterator
+		i = packetCache.begin(),
+		  e = packetCache.end();
+
+	for (; i != e; ++i)
+		delete *i;
 }
 
 
@@ -280,12 +245,46 @@ void Client::startService(int _sock, bool withTls) {
 
 void Client::close() {
 	SocketRWCommon::close();
-	// TODO: add canSafelyKeepAlive var, check it here, to kill
-	// never-authed conns instantly
-	deadTime = time(NULL) + SESSION_KEEPALIVE;
+
+	if (authState == AS_AUTHED)
+		deadTime = time(NULL) + SESSION_KEEPALIVE;
+	else
+		deadTime = time(NULL) - 1; // kill instantly
+}
+
+
+void Client::generateSessionKey() {
+	time_t now = time(NULL);
+
+	while (true) {
+		for (int i = 0; i < SESSION_KEY_SIZE; i++) {
+			if (i < sizeof(time_t))
+				sessionKey[i] = ((uint8_t*)&now)[i];
+			else
+				sessionKey[i] = rand() & 255;
+		}
+
+		// Is any other client already using this key?
+		// It's ridiculously unlikely, but... probably best
+		// to check just in case!
+		bool foundMatch = false;
+
+		for (int i = 0; i < clientCount; i++) {
+			if (clients[i] != this) {
+				if (!memcmp(clients[i]->sessionKey, sessionKey, SESSION_KEY_SIZE))
+					foundMatch = true;
+			}
+		}
+
+		// If there's none, we can safely leave!
+		if (!foundMatch)
+			break;
+	}
 }
 
+
 void Client::handleLine(char *line, int size) {
+	// This is a terrible mess that will be replaced shortly
 	if (strncmp(line, "all ", 4) == 0) {
 		for (int i = 0; i < clientCount; i++) {
 			clients[i]->outputBuf.append(&line[4], size - 4);
@@ -309,6 +308,34 @@ void Client::handleLine(char *line, int size) {
 		int sid = line[0] - '0';
 		servers[sid]->outputBuf.append(&line[1], size - 1);
 		servers[sid]->outputBuf.append("\r\n", 2);
+	} else if (strncmp(line, "login", 5) == 0) {
+		if (line[5] == 0) {
+			// no session key
+			generateSessionKey();
+			authState = AS_AUTHED;
+			outputBuf.append("OK ", 3);
+			for (int i = 0; i < SESSION_KEY_SIZE; i++) {
+				char bits[4];
+				sprintf(bits, "%02x", sessionKey[i]);
+				outputBuf.append(bits, 2);
+			}
+			outputBuf.append("\n", 1);
+		} else {
+			// This is awful. Don't care about writing clean code
+			// for something I'm going to throw away shortly...
+			uint8_t pkey[SESSION_KEY_SIZE];
+			for (int i = 0; i < SESSION_KEY_SIZE; i++) {
+				char highc = line[6 + (i * 2)];
+				char lowc = line[7 + (i * 2)];
+				int high = ((highc >= '0') && (highc <= '9')) ? (highc - '0') : (highc - 'a' + 10);
+				int low = ((lowc >= '0') && (lowc <= '9')) ? (lowc - '0') : (lowc - 'a' + 10);
+				pkey[i] = (high << 4) | low;
+			}
+
+			Client *other = findClientWithKey(pkey);
+			if (other && other->authState == AS_AUTHED)
+				other->stealConnection(this);
+		}
 	}
 }
 
@@ -323,6 +350,7 @@ void Client::processReadBuffer() {
 		if (buf[pos] == '\r' || buf[pos] == '\n') {
 			if (pos > lineBegin) {
 				buf[pos] = 0;
+				readBufPosition = pos + 1;
 				handleLine(&buf[lineBegin], pos - lineBegin);
 			}
 
@@ -334,9 +362,48 @@ void Client::processReadBuffer() {
 
 	// If we managed to handle anything, lop it off the buffer
 	inputBuf.trimFromStart(pos);
+	readBufPosition = 0;
 }
 
 
+void Client::stealConnection(Client *other) {
+	close();
+
+	inputBuf.clear();
+	inputBuf.append(
+			&other->inputBuf.data()[other->readBufPosition],
+			other->inputBuf.size() - other->readBufPosition);
+
+	// Not sure if we need to copy the outputbuf but it can't hurt
+	outputBuf.clear();
+	outputBuf.append(other->outputBuf.data(), other->outputBuf.size());
+
+	sock = other->sock;
+	tls = other->tls;
+	tlsActive = other->tlsActive;
+	state = other->state;
+
+	other->sock = -1;
+	other->tls = 0;
+	other->tlsActive = false;
+	other->state = CS_DISCONNECTED;
+	other->close();
+}
+
+
+
+
+Server::Server() {
+	dnsQueryId = -1;
+	ircUseTls = false;
+}
+Server::~Server() {
+	if (dnsQueryId != -1)
+		DNS::closeQuery(dnsQueryId);
+}
+
+
+
 void Server::handleLine(char *line, int size) {
 	for (int i = 0; i < clientCount; i++) {
 		clients[i]->outputBuf.append(line, size);
@@ -421,6 +488,7 @@ void Server::tryConnectPhase() {
 						state = CS_WAITING_CONNECT;
 					} else {
 						perror("[Server] Could not connect");
+						close();
 					}
 				} else {
 					// Whoa, we're connected? Neat.
@@ -474,6 +542,10 @@ int main(int argc, char **argv) {
 	clientCount = 0;
 	for (int i = 0; i < CLIENT_LIMIT; i++)
 		clients[i] = NULL;
+	serverCount = 0;
+	for (int i = 0; i < SERVER_LIMIT; i++)
+		servers[i] = NULL;
+
 
 	int ret;
 	ret = gnutls_global_init();
diff --git a/core.h b/core.h
new file mode 100644
index 0000000..e38d0f6
--- /dev/null
+++ b/core.h
@@ -0,0 +1,99 @@
+#ifndef CORE_H
+#define CORE_H 
+
+#include "buffer.h"
+
+#define CLIENT_LIMIT 100
+#define SERVER_LIMIT 20
+
+#define SESSION_KEEPALIVE 30
+
+#define SESSION_KEY_SIZE 16
+
+struct SocketRWCommon {
+	Buffer inputBuf, outputBuf;
+
+	enum ConnState {
+		CS_DISCONNECTED = 0,
+		CS_WAITING_DNS = 1, // server only
+		CS_WAITING_CONNECT = 2, // server only
+		CS_TLS_HANDSHAKE = 3,
+		CS_CONNECTED = 4
+	};
+	ConnState state;
+
+	int sock;
+	gnutls_session_t tls;
+	bool tlsActive;
+
+	SocketRWCommon();
+	virtual ~SocketRWCommon();
+
+	void tryTLSHandshake();
+	virtual void close();
+
+	void readAction();
+	void writeAction();
+	bool hasTlsPendingData() const;
+private:
+	virtual void processReadBuffer() = 0;
+};
+
+
+struct Packet {
+	int type;
+	int id;
+	Buffer data;
+};
+
+struct Client : SocketRWCommon {
+	enum AuthState {
+		AS_LOGIN_WAIT = 0,
+		AS_AUTHED = 1
+	};
+
+	AuthState authState;
+	uint8_t sessionKey[SESSION_KEY_SIZE];
+	time_t deadTime;
+
+	std::list<Packet *> packetCache;
+	int nextPacketID, lastReceivedPacketID;
+
+	Client();
+	~Client();
+
+	void startService(int _sock, bool withTls);
+	void close();
+
+private:
+	int readBufPosition;
+	void processReadBuffer();
+	void handleLine(char *line, int size);
+
+	void generateSessionKey();
+
+	void stealConnection(Client *other);
+};
+
+struct Server : SocketRWCommon {
+	char ircHostname[256];
+	int ircPort;
+	int dnsQueryId;
+	bool ircUseTls;
+
+	Server();
+	~Server();
+
+	void beginConnect();
+	void tryConnectPhase();
+	void connectionSuccessful();
+
+	void close();
+
+private:
+	void processReadBuffer();
+	void handleLine(char *line, int size);
+};
+
+
+#endif /* CORE_H */
diff --git a/python_client.py b/python_client.py
new file mode 100644
index 0000000..561a1c3
--- /dev/null
+++ b/python_client.py
@@ -0,0 +1,23 @@
+import socket, ssl, threading
+
+basesock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+basesock.connect(('localhost', 5454))
+sock = ssl.wrap_socket(basesock)
+
+def reader():
+	print('(Connected)')
+	while True:
+		data = sock.read()
+		if data:
+			print(data)
+		else:
+			print('(Disconnected)')
+			break
+
+thd = threading.Thread(None, reader)
+thd.start()
+
+while True:
+	bit = input()
+	sock.write((bit + '\n').encode('utf-8'))
+
-- 
cgit v1.2.3