diff options
| author | Treeki <treeki@gmail.com> | 2014-01-17 07:56:04 +0100 | 
|---|---|---|
| committer | Treeki <treeki@gmail.com> | 2014-01-17 07:56:04 +0100 | 
| commit | f8c12318d40ce7ee5ce811897a11108e8ab391ca (patch) | |
| tree | ac0ce754397435e0e1b309bd3ae5d0fcf5e0e2ea | |
| parent | 3761bfa6d6924b0842331f9fe8b80dca4f2e2450 (diff) | |
| download | bounce4-f8c12318d40ce7ee5ce811897a11108e8ab391ca.tar.gz bounce4-f8c12318d40ce7ee5ce811897a11108e8ab391ca.zip | |
lots of code cleanup, add session resuming, add simple python client for testing
Diffstat (limited to '')
| -rw-r--r-- | core.cpp | 262 | ||||
| -rw-r--r-- | core.h | 99 | ||||
| -rw-r--r-- | python_client.py | 23 | 
3 files changed, 289 insertions, 95 deletions
| @@ -1,5 +1,6 @@  #include <string.h>  #include <stdint.h> +#include <stdlib.h>  #include <stdio.h>  #include <errno.h>  #include <fcntl.h> @@ -10,14 +11,20 @@  #include <sys/select.h>  #include <netinet/in.h>  #include <gnutls/gnutls.h> +#include <list> -#include "buffer.h"  #include "dns.h" +#include "core.h" -#define CLIENT_LIMIT 100 -#define SERVER_LIMIT 20 -#define SESSION_KEEPALIVE 5 +Client *clients[CLIENT_LIMIT]; +Server *servers[SERVER_LIMIT]; +int clientCount, serverCount; +bool quitFlag = false; + +static gnutls_dh_params_t dh_params; +static gnutls_certificate_credentials_t serverCreds, clientCreds; +  static bool setSocketNonBlocking(int sock) { @@ -35,95 +42,31 @@ static bool setSocketNonBlocking(int sock) {  } -struct SocketRWCommon { -	Buffer inputBuf, outputBuf; - -	enum ConnState { -		CS_DISCONNECTED = 0, -		CS_WAITING_DNS = 1, // server only -		CS_WAITING_CONNECT = 2, // server only -		CS_TLS_HANDSHAKE = 3, -		CS_CONNECTED = 4 -	}; -	ConnState state; - -	int sock; -	gnutls_session_t tls; -	bool tlsActive; - -	SocketRWCommon() { -		sock = -1; -		state = CS_DISCONNECTED; -		tlsActive = false; -	} - -	~SocketRWCommon() { -		close(); -	} - -	void tryTLSHandshake(); -	virtual void close(); - -	void readAction(); -	void writeAction(); -	bool hasTlsPendingData() { -		// should un-inline this maybe? -		if (tlsActive) -			return (gnutls_record_check_pending(tls) > 0); -		else -			return false; -	} -private: -	virtual void processReadBuffer() = 0; -}; - -struct Client : SocketRWCommon { -	time_t deadTime; - -	void startService(int _sock, bool withTls); -	void close(); - -private: -	void processReadBuffer(); -	void handleLine(char *line, int size); -}; - -struct Server : SocketRWCommon { -	char ircHostname[256]; -	int ircPort; -	int dnsQueryId; -	bool ircUseTls; - -	Server() { -		dnsQueryId = -1; -		ircUseTls = false; -	} - -	~Server() { -		if (dnsQueryId != -1) -			DNS::closeQuery(dnsQueryId); -	} - -	void beginConnect(); -	void tryConnectPhase(); -	void connectionSuccessful(); - -	void close(); +static Client *findClientWithKey(uint8_t *key) { +	for (int i = 0; i < clientCount; i++) +		if (!memcmp(clients[i]->sessionKey, key, SESSION_KEY_SIZE)) +			return clients[i]; -private: -	void processReadBuffer(); -	void handleLine(char *line, int size); -}; +	return 0; +} -Client *clients[CLIENT_LIMIT]; -Server *servers[SERVER_LIMIT]; -int clientCount, serverCount; -bool quitFlag = false; -static gnutls_dh_params_t dh_params; -static gnutls_certificate_credentials_t serverCreds, clientCreds; +SocketRWCommon::SocketRWCommon() { +	sock = -1; +	state = CS_DISCONNECTED; +	tlsActive = false; +} +SocketRWCommon::~SocketRWCommon() { +	close(); +} +bool SocketRWCommon::hasTlsPendingData() const { +	if (tlsActive) +		return (gnutls_record_check_pending(tls) > 0); +	else +		return false; +}  void SocketRWCommon::tryTLSHandshake() {  	int hsRet = gnutls_handshake(tls); @@ -196,9 +139,10 @@ void SocketRWCommon::readAction() {  		printf("[fd=%d] Read 0! Socket closing.\n", sock);  		close(); -	} else if (amount < 0) +	} else if (amount < 0) {  		perror("Error while reading!"); -	// Close connection in that case, if a fatal error occurs? +		close(); +	}  }  void SocketRWCommon::writeAction() { @@ -220,9 +164,30 @@ void SocketRWCommon::writeAction() {  		outputBuf.trimFromStart(amount);  	} else if (amount == 0)  		printf("Sent 0!\n"); -	else if (amount < 0) +	else if (amount < 0) {  		perror("Error while sending!"); -	// Close connection in that case, if a fatal error occurs? +		close(); +	} +} + + + + +Client::Client() { +	authState = AS_LOGIN_WAIT; +	memset(sessionKey, 0, sizeof(sessionKey)); +	readBufPosition = 0; + +	nextPacketID = 1; +	lastReceivedPacketID = 0; +} +Client::~Client() { +	std::list<Packet *>::iterator +		i = packetCache.begin(), +		  e = packetCache.end(); + +	for (; i != e; ++i) +		delete *i;  } @@ -280,12 +245,46 @@ void Client::startService(int _sock, bool withTls) {  void Client::close() {  	SocketRWCommon::close(); -	// TODO: add canSafelyKeepAlive var, check it here, to kill -	// never-authed conns instantly -	deadTime = time(NULL) + SESSION_KEEPALIVE; + +	if (authState == AS_AUTHED) +		deadTime = time(NULL) + SESSION_KEEPALIVE; +	else +		deadTime = time(NULL) - 1; // kill instantly +} + + +void Client::generateSessionKey() { +	time_t now = time(NULL); + +	while (true) { +		for (int i = 0; i < SESSION_KEY_SIZE; i++) { +			if (i < sizeof(time_t)) +				sessionKey[i] = ((uint8_t*)&now)[i]; +			else +				sessionKey[i] = rand() & 255; +		} + +		// Is any other client already using this key? +		// It's ridiculously unlikely, but... probably best +		// to check just in case! +		bool foundMatch = false; + +		for (int i = 0; i < clientCount; i++) { +			if (clients[i] != this) { +				if (!memcmp(clients[i]->sessionKey, sessionKey, SESSION_KEY_SIZE)) +					foundMatch = true; +			} +		} + +		// If there's none, we can safely leave! +		if (!foundMatch) +			break; +	}  } +  void Client::handleLine(char *line, int size) { +	// This is a terrible mess that will be replaced shortly  	if (strncmp(line, "all ", 4) == 0) {  		for (int i = 0; i < clientCount; i++) {  			clients[i]->outputBuf.append(&line[4], size - 4); @@ -309,6 +308,34 @@ void Client::handleLine(char *line, int size) {  		int sid = line[0] - '0';  		servers[sid]->outputBuf.append(&line[1], size - 1);  		servers[sid]->outputBuf.append("\r\n", 2); +	} else if (strncmp(line, "login", 5) == 0) { +		if (line[5] == 0) { +			// no session key +			generateSessionKey(); +			authState = AS_AUTHED; +			outputBuf.append("OK ", 3); +			for (int i = 0; i < SESSION_KEY_SIZE; i++) { +				char bits[4]; +				sprintf(bits, "%02x", sessionKey[i]); +				outputBuf.append(bits, 2); +			} +			outputBuf.append("\n", 1); +		} else { +			// This is awful. Don't care about writing clean code +			// for something I'm going to throw away shortly... +			uint8_t pkey[SESSION_KEY_SIZE]; +			for (int i = 0; i < SESSION_KEY_SIZE; i++) { +				char highc = line[6 + (i * 2)]; +				char lowc = line[7 + (i * 2)]; +				int high = ((highc >= '0') && (highc <= '9')) ? (highc - '0') : (highc - 'a' + 10); +				int low = ((lowc >= '0') && (lowc <= '9')) ? (lowc - '0') : (lowc - 'a' + 10); +				pkey[i] = (high << 4) | low; +			} + +			Client *other = findClientWithKey(pkey); +			if (other && other->authState == AS_AUTHED) +				other->stealConnection(this); +		}  	}  } @@ -323,6 +350,7 @@ void Client::processReadBuffer() {  		if (buf[pos] == '\r' || buf[pos] == '\n') {  			if (pos > lineBegin) {  				buf[pos] = 0; +				readBufPosition = pos + 1;  				handleLine(&buf[lineBegin], pos - lineBegin);  			} @@ -334,9 +362,48 @@ void Client::processReadBuffer() {  	// If we managed to handle anything, lop it off the buffer  	inputBuf.trimFromStart(pos); +	readBufPosition = 0;  } +void Client::stealConnection(Client *other) { +	close(); + +	inputBuf.clear(); +	inputBuf.append( +			&other->inputBuf.data()[other->readBufPosition], +			other->inputBuf.size() - other->readBufPosition); + +	// Not sure if we need to copy the outputbuf but it can't hurt +	outputBuf.clear(); +	outputBuf.append(other->outputBuf.data(), other->outputBuf.size()); + +	sock = other->sock; +	tls = other->tls; +	tlsActive = other->tlsActive; +	state = other->state; + +	other->sock = -1; +	other->tls = 0; +	other->tlsActive = false; +	other->state = CS_DISCONNECTED; +	other->close(); +} + + + + +Server::Server() { +	dnsQueryId = -1; +	ircUseTls = false; +} +Server::~Server() { +	if (dnsQueryId != -1) +		DNS::closeQuery(dnsQueryId); +} + + +  void Server::handleLine(char *line, int size) {  	for (int i = 0; i < clientCount; i++) {  		clients[i]->outputBuf.append(line, size); @@ -421,6 +488,7 @@ void Server::tryConnectPhase() {  						state = CS_WAITING_CONNECT;  					} else {  						perror("[Server] Could not connect"); +						close();  					}  				} else {  					// Whoa, we're connected? Neat. @@ -474,6 +542,10 @@ int main(int argc, char **argv) {  	clientCount = 0;  	for (int i = 0; i < CLIENT_LIMIT; i++)  		clients[i] = NULL; +	serverCount = 0; +	for (int i = 0; i < SERVER_LIMIT; i++) +		servers[i] = NULL; +  	int ret;  	ret = gnutls_global_init(); @@ -0,0 +1,99 @@ +#ifndef CORE_H +#define CORE_H  + +#include "buffer.h" + +#define CLIENT_LIMIT 100 +#define SERVER_LIMIT 20 + +#define SESSION_KEEPALIVE 30 + +#define SESSION_KEY_SIZE 16 + +struct SocketRWCommon { +	Buffer inputBuf, outputBuf; + +	enum ConnState { +		CS_DISCONNECTED = 0, +		CS_WAITING_DNS = 1, // server only +		CS_WAITING_CONNECT = 2, // server only +		CS_TLS_HANDSHAKE = 3, +		CS_CONNECTED = 4 +	}; +	ConnState state; + +	int sock; +	gnutls_session_t tls; +	bool tlsActive; + +	SocketRWCommon(); +	virtual ~SocketRWCommon(); + +	void tryTLSHandshake(); +	virtual void close(); + +	void readAction(); +	void writeAction(); +	bool hasTlsPendingData() const; +private: +	virtual void processReadBuffer() = 0; +}; + + +struct Packet { +	int type; +	int id; +	Buffer data; +}; + +struct Client : SocketRWCommon { +	enum AuthState { +		AS_LOGIN_WAIT = 0, +		AS_AUTHED = 1 +	}; + +	AuthState authState; +	uint8_t sessionKey[SESSION_KEY_SIZE]; +	time_t deadTime; + +	std::list<Packet *> packetCache; +	int nextPacketID, lastReceivedPacketID; + +	Client(); +	~Client(); + +	void startService(int _sock, bool withTls); +	void close(); + +private: +	int readBufPosition; +	void processReadBuffer(); +	void handleLine(char *line, int size); + +	void generateSessionKey(); + +	void stealConnection(Client *other); +}; + +struct Server : SocketRWCommon { +	char ircHostname[256]; +	int ircPort; +	int dnsQueryId; +	bool ircUseTls; + +	Server(); +	~Server(); + +	void beginConnect(); +	void tryConnectPhase(); +	void connectionSuccessful(); + +	void close(); + +private: +	void processReadBuffer(); +	void handleLine(char *line, int size); +}; + + +#endif /* CORE_H */ diff --git a/python_client.py b/python_client.py new file mode 100644 index 0000000..561a1c3 --- /dev/null +++ b/python_client.py @@ -0,0 +1,23 @@ +import socket, ssl, threading + +basesock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) +basesock.connect(('localhost', 5454)) +sock = ssl.wrap_socket(basesock) + +def reader(): +	print('(Connected)') +	while True: +		data = sock.read() +		if data: +			print(data) +		else: +			print('(Disconnected)') +			break + +thd = threading.Thread(None, reader) +thd.start() + +while True: +	bit = input() +	sock.write((bit + '\n').encode('utf-8')) + | 
