diff options
-rw-r--r-- | core.cpp | 262 | ||||
-rw-r--r-- | core.h | 99 | ||||
-rw-r--r-- | python_client.py | 23 |
3 files changed, 289 insertions, 95 deletions
@@ -1,5 +1,6 @@ #include <string.h> #include <stdint.h> +#include <stdlib.h> #include <stdio.h> #include <errno.h> #include <fcntl.h> @@ -10,14 +11,20 @@ #include <sys/select.h> #include <netinet/in.h> #include <gnutls/gnutls.h> +#include <list> -#include "buffer.h" #include "dns.h" +#include "core.h" -#define CLIENT_LIMIT 100 -#define SERVER_LIMIT 20 -#define SESSION_KEEPALIVE 5 +Client *clients[CLIENT_LIMIT]; +Server *servers[SERVER_LIMIT]; +int clientCount, serverCount; +bool quitFlag = false; + +static gnutls_dh_params_t dh_params; +static gnutls_certificate_credentials_t serverCreds, clientCreds; + static bool setSocketNonBlocking(int sock) { @@ -35,95 +42,31 @@ static bool setSocketNonBlocking(int sock) { } -struct SocketRWCommon { - Buffer inputBuf, outputBuf; - - enum ConnState { - CS_DISCONNECTED = 0, - CS_WAITING_DNS = 1, // server only - CS_WAITING_CONNECT = 2, // server only - CS_TLS_HANDSHAKE = 3, - CS_CONNECTED = 4 - }; - ConnState state; - - int sock; - gnutls_session_t tls; - bool tlsActive; - - SocketRWCommon() { - sock = -1; - state = CS_DISCONNECTED; - tlsActive = false; - } - - ~SocketRWCommon() { - close(); - } - - void tryTLSHandshake(); - virtual void close(); - - void readAction(); - void writeAction(); - bool hasTlsPendingData() { - // should un-inline this maybe? - if (tlsActive) - return (gnutls_record_check_pending(tls) > 0); - else - return false; - } -private: - virtual void processReadBuffer() = 0; -}; - -struct Client : SocketRWCommon { - time_t deadTime; - - void startService(int _sock, bool withTls); - void close(); - -private: - void processReadBuffer(); - void handleLine(char *line, int size); -}; - -struct Server : SocketRWCommon { - char ircHostname[256]; - int ircPort; - int dnsQueryId; - bool ircUseTls; - - Server() { - dnsQueryId = -1; - ircUseTls = false; - } - - ~Server() { - if (dnsQueryId != -1) - DNS::closeQuery(dnsQueryId); - } - - void beginConnect(); - void tryConnectPhase(); - void connectionSuccessful(); - - void close(); +static Client *findClientWithKey(uint8_t *key) { + for (int i = 0; i < clientCount; i++) + if (!memcmp(clients[i]->sessionKey, key, SESSION_KEY_SIZE)) + return clients[i]; -private: - void processReadBuffer(); - void handleLine(char *line, int size); -}; + return 0; +} -Client *clients[CLIENT_LIMIT]; -Server *servers[SERVER_LIMIT]; -int clientCount, serverCount; -bool quitFlag = false; -static gnutls_dh_params_t dh_params; -static gnutls_certificate_credentials_t serverCreds, clientCreds; +SocketRWCommon::SocketRWCommon() { + sock = -1; + state = CS_DISCONNECTED; + tlsActive = false; +} +SocketRWCommon::~SocketRWCommon() { + close(); +} +bool SocketRWCommon::hasTlsPendingData() const { + if (tlsActive) + return (gnutls_record_check_pending(tls) > 0); + else + return false; +} void SocketRWCommon::tryTLSHandshake() { int hsRet = gnutls_handshake(tls); @@ -196,9 +139,10 @@ void SocketRWCommon::readAction() { printf("[fd=%d] Read 0! Socket closing.\n", sock); close(); - } else if (amount < 0) + } else if (amount < 0) { perror("Error while reading!"); - // Close connection in that case, if a fatal error occurs? + close(); + } } void SocketRWCommon::writeAction() { @@ -220,9 +164,30 @@ void SocketRWCommon::writeAction() { outputBuf.trimFromStart(amount); } else if (amount == 0) printf("Sent 0!\n"); - else if (amount < 0) + else if (amount < 0) { perror("Error while sending!"); - // Close connection in that case, if a fatal error occurs? + close(); + } +} + + + + +Client::Client() { + authState = AS_LOGIN_WAIT; + memset(sessionKey, 0, sizeof(sessionKey)); + readBufPosition = 0; + + nextPacketID = 1; + lastReceivedPacketID = 0; +} +Client::~Client() { + std::list<Packet *>::iterator + i = packetCache.begin(), + e = packetCache.end(); + + for (; i != e; ++i) + delete *i; } @@ -280,12 +245,46 @@ void Client::startService(int _sock, bool withTls) { void Client::close() { SocketRWCommon::close(); - // TODO: add canSafelyKeepAlive var, check it here, to kill - // never-authed conns instantly - deadTime = time(NULL) + SESSION_KEEPALIVE; + + if (authState == AS_AUTHED) + deadTime = time(NULL) + SESSION_KEEPALIVE; + else + deadTime = time(NULL) - 1; // kill instantly +} + + +void Client::generateSessionKey() { + time_t now = time(NULL); + + while (true) { + for (int i = 0; i < SESSION_KEY_SIZE; i++) { + if (i < sizeof(time_t)) + sessionKey[i] = ((uint8_t*)&now)[i]; + else + sessionKey[i] = rand() & 255; + } + + // Is any other client already using this key? + // It's ridiculously unlikely, but... probably best + // to check just in case! + bool foundMatch = false; + + for (int i = 0; i < clientCount; i++) { + if (clients[i] != this) { + if (!memcmp(clients[i]->sessionKey, sessionKey, SESSION_KEY_SIZE)) + foundMatch = true; + } + } + + // If there's none, we can safely leave! + if (!foundMatch) + break; + } } + void Client::handleLine(char *line, int size) { + // This is a terrible mess that will be replaced shortly if (strncmp(line, "all ", 4) == 0) { for (int i = 0; i < clientCount; i++) { clients[i]->outputBuf.append(&line[4], size - 4); @@ -309,6 +308,34 @@ void Client::handleLine(char *line, int size) { int sid = line[0] - '0'; servers[sid]->outputBuf.append(&line[1], size - 1); servers[sid]->outputBuf.append("\r\n", 2); + } else if (strncmp(line, "login", 5) == 0) { + if (line[5] == 0) { + // no session key + generateSessionKey(); + authState = AS_AUTHED; + outputBuf.append("OK ", 3); + for (int i = 0; i < SESSION_KEY_SIZE; i++) { + char bits[4]; + sprintf(bits, "%02x", sessionKey[i]); + outputBuf.append(bits, 2); + } + outputBuf.append("\n", 1); + } else { + // This is awful. Don't care about writing clean code + // for something I'm going to throw away shortly... + uint8_t pkey[SESSION_KEY_SIZE]; + for (int i = 0; i < SESSION_KEY_SIZE; i++) { + char highc = line[6 + (i * 2)]; + char lowc = line[7 + (i * 2)]; + int high = ((highc >= '0') && (highc <= '9')) ? (highc - '0') : (highc - 'a' + 10); + int low = ((lowc >= '0') && (lowc <= '9')) ? (lowc - '0') : (lowc - 'a' + 10); + pkey[i] = (high << 4) | low; + } + + Client *other = findClientWithKey(pkey); + if (other && other->authState == AS_AUTHED) + other->stealConnection(this); + } } } @@ -323,6 +350,7 @@ void Client::processReadBuffer() { if (buf[pos] == '\r' || buf[pos] == '\n') { if (pos > lineBegin) { buf[pos] = 0; + readBufPosition = pos + 1; handleLine(&buf[lineBegin], pos - lineBegin); } @@ -334,9 +362,48 @@ void Client::processReadBuffer() { // If we managed to handle anything, lop it off the buffer inputBuf.trimFromStart(pos); + readBufPosition = 0; } +void Client::stealConnection(Client *other) { + close(); + + inputBuf.clear(); + inputBuf.append( + &other->inputBuf.data()[other->readBufPosition], + other->inputBuf.size() - other->readBufPosition); + + // Not sure if we need to copy the outputbuf but it can't hurt + outputBuf.clear(); + outputBuf.append(other->outputBuf.data(), other->outputBuf.size()); + + sock = other->sock; + tls = other->tls; + tlsActive = other->tlsActive; + state = other->state; + + other->sock = -1; + other->tls = 0; + other->tlsActive = false; + other->state = CS_DISCONNECTED; + other->close(); +} + + + + +Server::Server() { + dnsQueryId = -1; + ircUseTls = false; +} +Server::~Server() { + if (dnsQueryId != -1) + DNS::closeQuery(dnsQueryId); +} + + + void Server::handleLine(char *line, int size) { for (int i = 0; i < clientCount; i++) { clients[i]->outputBuf.append(line, size); @@ -421,6 +488,7 @@ void Server::tryConnectPhase() { state = CS_WAITING_CONNECT; } else { perror("[Server] Could not connect"); + close(); } } else { // Whoa, we're connected? Neat. @@ -474,6 +542,10 @@ int main(int argc, char **argv) { clientCount = 0; for (int i = 0; i < CLIENT_LIMIT; i++) clients[i] = NULL; + serverCount = 0; + for (int i = 0; i < SERVER_LIMIT; i++) + servers[i] = NULL; + int ret; ret = gnutls_global_init(); @@ -0,0 +1,99 @@ +#ifndef CORE_H +#define CORE_H + +#include "buffer.h" + +#define CLIENT_LIMIT 100 +#define SERVER_LIMIT 20 + +#define SESSION_KEEPALIVE 30 + +#define SESSION_KEY_SIZE 16 + +struct SocketRWCommon { + Buffer inputBuf, outputBuf; + + enum ConnState { + CS_DISCONNECTED = 0, + CS_WAITING_DNS = 1, // server only + CS_WAITING_CONNECT = 2, // server only + CS_TLS_HANDSHAKE = 3, + CS_CONNECTED = 4 + }; + ConnState state; + + int sock; + gnutls_session_t tls; + bool tlsActive; + + SocketRWCommon(); + virtual ~SocketRWCommon(); + + void tryTLSHandshake(); + virtual void close(); + + void readAction(); + void writeAction(); + bool hasTlsPendingData() const; +private: + virtual void processReadBuffer() = 0; +}; + + +struct Packet { + int type; + int id; + Buffer data; +}; + +struct Client : SocketRWCommon { + enum AuthState { + AS_LOGIN_WAIT = 0, + AS_AUTHED = 1 + }; + + AuthState authState; + uint8_t sessionKey[SESSION_KEY_SIZE]; + time_t deadTime; + + std::list<Packet *> packetCache; + int nextPacketID, lastReceivedPacketID; + + Client(); + ~Client(); + + void startService(int _sock, bool withTls); + void close(); + +private: + int readBufPosition; + void processReadBuffer(); + void handleLine(char *line, int size); + + void generateSessionKey(); + + void stealConnection(Client *other); +}; + +struct Server : SocketRWCommon { + char ircHostname[256]; + int ircPort; + int dnsQueryId; + bool ircUseTls; + + Server(); + ~Server(); + + void beginConnect(); + void tryConnectPhase(); + void connectionSuccessful(); + + void close(); + +private: + void processReadBuffer(); + void handleLine(char *line, int size); +}; + + +#endif /* CORE_H */ diff --git a/python_client.py b/python_client.py new file mode 100644 index 0000000..561a1c3 --- /dev/null +++ b/python_client.py @@ -0,0 +1,23 @@ +import socket, ssl, threading + +basesock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) +basesock.connect(('localhost', 5454)) +sock = ssl.wrap_socket(basesock) + +def reader(): + print('(Connected)') + while True: + data = sock.read() + if data: + print(data) + else: + print('(Disconnected)') + break + +thd = threading.Thread(None, reader) +thd.start() + +while True: + bit = input() + sock.write((bit + '\n').encode('utf-8')) + |