summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTreeki <treeki@gmail.com>2014-01-16 20:19:00 +0100
committerTreeki <treeki@gmail.com>2014-01-16 20:19:00 +0100
commita47ea91881c3cca4c8e3bdd3ea13d09886cafa94 (patch)
tree91e9e6ebcbe08044f9594a577b9ad83b186d7868
downloadbounce4-a47ea91881c3cca4c8e3bdd3ea13d09886cafa94.tar.gz
bounce4-a47ea91881c3cca4c8e3bdd3ea13d09886cafa94.zip
initial commit of stuff
-rw-r--r--.gitignore3
-rw-r--r--buffer.h103
-rwxr-xr-xbuild.sh4
-rw-r--r--core.cpp745
-rw-r--r--dns.cpp159
-rw-r--r--dns.h15
6 files changed, 1029 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ca63371
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+binary
+*.crt
+*.key
diff --git a/buffer.h b/buffer.h
new file mode 100644
index 0000000..ef6d027
--- /dev/null
+++ b/buffer.h
@@ -0,0 +1,103 @@
+#ifndef BUFFER_H
+#define BUFFER_H
+
+#include <string.h>
+#include <stdint.h>
+
+class Buffer {
+private:
+ char *m_data;
+ bool m_freeBuffer;
+ int m_size;
+ int m_capacity;
+ char m_preAllocBuffer[0x200];
+
+public:
+ Buffer() {
+ m_data = m_preAllocBuffer;
+ m_freeBuffer = false;
+ m_size = 0;
+ m_capacity = sizeof(m_preAllocBuffer);
+ }
+
+ ~Buffer() {
+ if ((m_data != NULL) && m_freeBuffer) {
+ delete[] m_data;
+ m_data = NULL;
+ }
+ }
+
+ char *data() const { return m_data; }
+ int size() const { return m_size; }
+ int capacity() const { return m_capacity; }
+
+ void setCapacity(int capacity) {
+ if (capacity == m_capacity)
+ return;
+
+ // Trim the size down if it's too big to fit
+ if (m_size > capacity)
+ m_size = capacity;
+
+ char *newBuf = new char[capacity];
+
+ if (m_data != NULL) {
+ memcpy(newBuf, m_data, m_size);
+ if (m_freeBuffer)
+ delete[] m_data;
+ }
+
+ m_data = newBuf;
+ m_capacity = capacity;
+ m_freeBuffer = true;
+ }
+
+ void clear() {
+ m_size = 0;
+ }
+ void append(const char *data, int size) {
+ if (size <= 0)
+ return;
+
+ int requiredSize = m_size + size;
+ if (requiredSize > m_capacity)
+ setCapacity(requiredSize + 0x100);
+
+ memcpy(&m_data[m_size], data, size);
+ m_size += size;
+ }
+ void resize(int size) {
+ if (size > m_capacity)
+ setCapacity(size + 0x100);
+ m_size = size;
+ }
+
+ void trimFromStart(int amount) {
+ if (amount <= 0)
+ return;
+ if (amount >= m_size) {
+ clear();
+ return;
+ }
+
+ memmove(m_data, &m_data[amount], m_size - amount);
+ m_size -= amount;
+ }
+
+
+ void writeU32(uint32_t v) { append((const char *)&v, 4); }
+ void writeU16(uint32_t v) { append((const char *)&v, 2); }
+ void writeU8(uint32_t v) { append((const char *)&v, 1); }
+ void writeS32(uint32_t v) { append((const char *)&v, 4); }
+ void writeS16(uint32_t v) { append((const char *)&v, 2); }
+ void writeS8(uint32_t v) { append((const char *)&v, 1); }
+
+ void writeStr(const char *data, int size = -1) {
+ if (size == -1)
+ size = strlen(data);
+ writeU32(size);
+ append(data, size);
+ }
+};
+
+#endif /* BUFFER_H */
diff --git a/build.sh b/build.sh
new file mode 100755
index 0000000..58d6365
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,4 @@
+#!/bin/sh
+mkdir binary
+g++ -o binary/nb4 core.cpp dns.cpp -lgnutls -pthread
+
diff --git a/core.cpp b/core.cpp
new file mode 100644
index 0000000..472cdee
--- /dev/null
+++ b/core.cpp
@@ -0,0 +1,745 @@
+#include <string.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <time.h>
+#include <sys/time.h>
+#include <sys/socket.h>
+#include <sys/select.h>
+#include <netinet/in.h>
+#include <gnutls/gnutls.h>
+
+#include "buffer.h"
+#include "dns.h"
+
+#define CLIENT_LIMIT 100
+#define SERVER_LIMIT 20
+
+#define SESSION_KEEPALIVE 5
+
+struct SocketRWCommon {
+ Buffer inputBuf, outputBuf;
+
+ enum ConnState {
+ CS_DISCONNECTED = 0,
+ CS_WAITING_DNS = 1, // server only
+ CS_WAITING_CONNECT = 2, // server only
+ CS_TLS_HANDSHAKE = 3,
+ CS_CONNECTED = 4
+ };
+ ConnState state;
+
+ int sock;
+ gnutls_session_t tls;
+ bool isTls, gnutlsSessionInited;
+
+ SocketRWCommon() {
+ sock = -1;
+ state = CS_DISCONNECTED;
+ isTls = false;
+ gnutlsSessionInited = false;
+ }
+
+ ~SocketRWCommon() {
+ close();
+ }
+
+ void tryTLSHandshake();
+ virtual void close();
+};
+
+struct Client : SocketRWCommon {
+ time_t deadTime;
+
+ void startService(int _sock, bool withTls);
+ void processInput();
+ void close();
+
+private:
+ void handleLine(char *line, int size);
+};
+
+struct Server : SocketRWCommon {
+ char hostname[256];
+ int port;
+ int dnsQueryId;
+
+ Server() {
+ dnsQueryId = -1;
+ }
+
+ ~Server() {
+ if (dnsQueryId != -1)
+ DNS::closeQuery(dnsQueryId);
+ }
+
+ void beginConnect();
+ void tryConnectPhase();
+ void connectionSuccessful();
+
+ void close();
+
+ void processInput();
+private:
+ void handleLine(char *line, int size);
+};
+
+
+Client *clients[CLIENT_LIMIT];
+Server *servers[SERVER_LIMIT];
+int clientCount, serverCount;
+bool quitFlag = false;
+
+static gnutls_dh_params_t dh_params;
+static gnutls_certificate_credentials_t serverCreds, clientCreds;
+
+
+void SocketRWCommon::tryTLSHandshake() {
+ int hsRet = gnutls_handshake(tls);
+ if (gnutls_error_is_fatal(hsRet)) {
+ printf("[SocketRWCommon::tryTLSHandshake] gnutls_handshake borked\n");
+ gnutls_perror(hsRet);
+ close();
+ return;
+ }
+
+ if (hsRet == GNUTLS_E_SUCCESS) {
+ // We're in !!
+ state = CS_CONNECTED;
+
+ inputBuf.clear();
+ outputBuf.clear();
+
+ printf("[SocketRWCommon connected via SSL!]\n");
+ }
+}
+
+void SocketRWCommon::close() {
+ if (sock != -1) {
+ if (gnutlsSessionInited)
+ gnutls_bye(tls, GNUTLS_SHUT_RDWR);
+ shutdown(sock, SHUT_RDWR);
+ ::close(sock);
+ }
+
+ sock = -1;
+ inputBuf.clear();
+ outputBuf.clear();
+ state = CS_DISCONNECTED;
+
+ if (gnutlsSessionInited) {
+ gnutls_deinit(tls);
+ gnutlsSessionInited = false;
+ }
+}
+
+
+void Client::startService(int _sock, bool withTls) {
+ close();
+
+ sock = _sock;
+
+ if (withTls) {
+ int initRet = gnutls_init(&tls, GNUTLS_SERVER);
+ if (initRet != GNUTLS_E_SUCCESS) {
+ printf("[Client::startService] gnutls_init borked\n");
+ gnutls_perror(initRet);
+ close();
+ return;
+ }
+
+ // TODO: error check this
+ int ret;
+ const char *errPos;
+
+ ret = gnutls_priority_set_direct(tls, "PERFORMANCE:%SERVER_PRECEDENCE", &errPos);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("gnutls_priority_set_direct failure: %s\n", gnutls_strerror(ret));
+ close();
+ return;
+ }
+
+ ret = gnutls_credentials_set(tls, GNUTLS_CRD_CERTIFICATE, clientCreds);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("gnutls_credentials_set failure: %s\n", gnutls_strerror(ret));
+ close();
+ return;
+ }
+
+ gnutls_certificate_server_set_request(tls, GNUTLS_CERT_IGNORE);
+
+ gnutls_transport_set_int(tls, sock);
+
+ gnutlsSessionInited = true;
+
+ state = CS_TLS_HANDSHAKE;
+ } else {
+ state = CS_CONNECTED;
+ }
+}
+
+void Client::close() {
+ SocketRWCommon::close();
+ // TODO: add canSafelyKeepAlive var, check it here, to kill
+ // never-authed conns instantly
+ deadTime = time(NULL) + SESSION_KEEPALIVE;
+}
+
+void Client::handleLine(char *line, int size) {
+ if (strncmp(line, "all ", 4) == 0) {
+ for (int i = 0; i < clientCount; i++) {
+ clients[i]->outputBuf.append(&line[4], size - 4);
+ clients[i]->outputBuf.append("\n", 1);
+ }
+ } else if (strcmp(line, "quit") == 0) {
+ quitFlag = true;
+ } else if (strncmp(line, "resolve ", 8) == 0) {
+ DNS::makeQuery(&line[8]);
+ } else if (strncmp(&line[1], "ddsrv ", 6) == 0) {
+ servers[serverCount] = new Server;
+ strcpy(servers[serverCount]->hostname, &line[7]);
+ servers[serverCount]->port = 1191;
+ servers[serverCount]->isTls = (line[0] == 's');
+ serverCount++;
+ outputBuf.append("Your wish is my command!\n", 25);
+ } else if (strncmp(line, "connsrv", 7) == 0) {
+ int sid = line[7] - '0';
+ servers[sid]->beginConnect();
+ } else if (line[0] >= '0' && line[0] <= '9') {
+ int sid = line[0] - '0';
+ servers[sid]->outputBuf.append(&line[1], size - 1);
+ servers[sid]->outputBuf.append("\r\n", 2);
+ }
+}
+
+void Client::processInput() {
+ // Try to process as many lines as we can
+ // This function will be changed to custom protocol eventually
+ char *buf = inputBuf.data();
+ int bufSize = inputBuf.size();
+ int lineBegin = 0, pos = 0;
+
+ while (pos < bufSize) {
+ if (buf[pos] == '\r' || buf[pos] == '\n') {
+ if (pos > lineBegin) {
+ buf[pos] = 0;
+ handleLine(&buf[lineBegin], pos - lineBegin);
+ }
+
+ lineBegin = pos + 1;
+ }
+
+ pos++;
+ }
+
+ // If we managed to handle anything, lop it off the buffer
+ inputBuf.trimFromStart(pos);
+}
+
+
+void Server::handleLine(char *line, int size) {
+ for (int i = 0; i < clientCount; i++) {
+ clients[i]->outputBuf.append(line, size);
+ clients[i]->outputBuf.append("\n", 1);
+ }
+}
+void Server::processInput() {
+ // Try to process as many lines as we can
+ char *buf = inputBuf.data();
+ int bufSize = inputBuf.size();
+ int lineBegin = 0, pos = 0;
+
+ while (pos < bufSize) {
+ if (buf[pos] == '\r' || buf[pos] == '\n') {
+ if (pos > lineBegin) {
+ buf[pos] = 0;
+ handleLine(&buf[lineBegin], pos - lineBegin);
+ }
+
+ lineBegin = pos + 1;
+ }
+
+ pos++;
+ }
+
+ // If we managed to handle anything, lop it off the buffer
+ inputBuf.trimFromStart(pos);
+}
+
+
+
+void Server::beginConnect() {
+ if (state == CS_DISCONNECTED) {
+ DNS::closeQuery(dnsQueryId); // just in case
+ dnsQueryId = DNS::makeQuery(hostname);
+
+ if (dnsQueryId == -1) {
+ // TODO: better error reporting
+ printf("DNS query failed!\n");
+ } else {
+ state = CS_WAITING_DNS;
+ }
+ }
+}
+
+void Server::tryConnectPhase() {
+ if (state == CS_WAITING_DNS) {
+ in_addr result;
+ bool isError;
+
+ if (DNS::checkQuery(dnsQueryId, &result, &isError)) {
+ DNS::closeQuery(dnsQueryId);
+ dnsQueryId = -1;
+
+ if (isError) {
+ printf("DNS query failed at phase 2!\n");
+ state = CS_DISCONNECTED;
+ } else {
+ // OK, if there was no error, we can go ahead and do this...
+
+ sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (sock == -1) {
+ perror("[Server] Failed to socket()");
+ close();
+ return;
+ }
+
+ int opts = fcntl(sock, F_GETFL);
+ if (opts < 0) {
+ perror("[Server] Could not get fcntl options");
+ close();
+ return;
+ }
+ opts |= O_NONBLOCK;
+ if (fcntl(sock, F_SETFL, opts) == -1) {
+ perror("[Server] Could not set fcntl options");
+ close();
+ return;
+ }
+
+ // We have our non-blocking socket, let's try connecting!
+ sockaddr_in outAddr;
+ outAddr.sin_family = AF_INET;
+ outAddr.sin_port = htons(port);
+ outAddr.sin_addr.s_addr = result.s_addr;
+
+ if (connect(sock, (sockaddr *)&outAddr, sizeof(outAddr)) == -1) {
+ if (errno == EINPROGRESS) {
+ state = CS_WAITING_CONNECT;
+ } else {
+ perror("[Server] Could not connect");
+ }
+ } else {
+ // Whoa, we're connected? Neat.
+ connectionSuccessful();
+ }
+ }
+ }
+ }
+}
+
+void Server::connectionSuccessful() {
+ state = CS_CONNECTED;
+
+ inputBuf.clear();
+ outputBuf.clear();
+
+ // Do we need to do any TLS junk?
+ if (isTls) {
+ int initRet = gnutls_init(&tls, GNUTLS_CLIENT);
+ if (initRet != GNUTLS_E_SUCCESS) {
+ printf("[Server::connectionSuccessful] gnutls_init borked\n");
+ gnutls_perror(initRet);
+ close();
+ return;
+ }
+
+ // TODO: error check this
+ const char *errPos;
+ gnutls_priority_set_direct(tls, "NORMAL", &errPos);
+
+ gnutls_credentials_set(tls, GNUTLS_CRD_CERTIFICATE, serverCreds);
+
+ gnutls_transport_set_int(tls, sock);
+
+ gnutlsSessionInited = true;
+ state = CS_TLS_HANDSHAKE;
+ }
+}
+
+void Server::close() {
+ SocketRWCommon::close();
+
+ if (dnsQueryId != -1) {
+ DNS::closeQuery(dnsQueryId);
+ dnsQueryId = -1;
+ }
+}
+
+
+int main(int argc, char **argv) {
+ clientCount = 0;
+ for (int i = 0; i < CLIENT_LIMIT; i++)
+ clients[i] = NULL;
+
+ int ret;
+ ret = gnutls_global_init();
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("gnutls_global_init failure: %s\n", gnutls_strerror(ret));
+ return 1;
+ }
+
+ unsigned int bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH, GNUTLS_SEC_PARAM_LEGACY);
+
+ ret = gnutls_dh_params_init(&dh_params);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("dh_params_init failure: %s\n", gnutls_strerror(ret));
+ return 1;
+ }
+
+ ret = gnutls_dh_params_generate2(dh_params, bits);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("dh_params_generate2 failure: %s\n", gnutls_strerror(ret));
+ return 1;
+ }
+
+ gnutls_certificate_allocate_credentials(&clientCreds);
+ ret = gnutls_certificate_set_x509_key_file(clientCreds, "ssl_test.crt", "ssl_test.key", GNUTLS_X509_FMT_PEM);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("set_x509_key_file failure: %s\n", gnutls_strerror(ret));
+ return 1;
+ }
+ gnutls_certificate_set_dh_params(clientCreds, dh_params);
+
+ gnutls_certificate_allocate_credentials(&serverCreds);
+
+ DNS::start();
+
+
+ // prepare the listen socket
+ int listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (listener == -1) {
+ perror("Could not create the listener socket");
+ return 1;
+ }
+
+ int v = 1;
+ if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &v, sizeof(v)) == -1) {
+ perror("Could not set SO_REUSEADDR");
+ return 1;
+ }
+
+ sockaddr_in listenAddr;
+ listenAddr.sin_family = AF_INET;
+ listenAddr.sin_port = htons(5454);
+ listenAddr.sin_addr.s_addr = htonl(INADDR_ANY);
+
+ if (bind(listener, (sockaddr *)&listenAddr, sizeof(listenAddr)) == -1) {
+ perror("Could not bind to the listener socket");
+ return 1;
+ }
+
+ int opts = fcntl(listener, F_GETFL);
+ if (opts < 0) {
+ perror("Could not get fcntl options\n");
+ return 1;
+ }
+ opts |= O_NONBLOCK;
+ if (fcntl(listener, F_SETFL, opts) == -1) {
+ perror("Could not set fcntl options\n");
+ return 1;
+ }
+
+ if (listen(listener, 10) == -1) {
+ perror("Could not listen()");
+ return 1;
+ }
+
+ printf("Listening!\n");
+
+
+ // do stuff!
+ while (!quitFlag) {
+ fd_set readSet, writeSet;
+ FD_ZERO(&readSet);
+ FD_ZERO(&writeSet);
+
+ int maxFD = listener;
+ FD_SET(listener, &readSet);
+
+ time_t now = time(NULL);
+
+ for (int i = 0; i < clientCount; i++) {
+ if (clients[i]->state == Client::CS_TLS_HANDSHAKE)
+ clients[i]->tryTLSHandshake();
+
+ if (clients[i]->sock != -1) {
+ if (clients[i]->sock > maxFD)
+ maxFD = clients[i]->sock;
+
+ if (clients[i]->state == Client::CS_CONNECTED)
+ FD_SET(clients[i]->sock, &readSet);
+ if (clients[i]->outputBuf.size() > 0)
+ FD_SET(clients[i]->sock, &writeSet);
+
+ } else {
+ // Outdated session, can we kill it?
+ if (now >= clients[i]->deadTime) {
+ printf("[%d] Session expired, deleting\n", now);
+
+ // Yep.
+ Client *client = clients[i];
+
+ // If this is the last socket in the list, we can just
+ // decrement clientCount and all will be fine.
+ clientCount--;
+
+ // Otherwise, we move that pointer into this slot, and
+ // we subtract one from i so that we'll process that slot
+ // on the next loop iteration.
+ if (i != clientCount) {
+ clients[i] = clients[clientCount];
+ i--;
+ }
+ }
+ }
+ }
+
+ for (int i = 0; i < serverCount; i++) {
+ if (servers[i]->state == Server::CS_WAITING_DNS)
+ servers[i]->tryConnectPhase();
+ else if (servers[i]->state == Server::CS_TLS_HANDSHAKE)
+ servers[i]->tryTLSHandshake();
+
+ if (servers[i]->sock != -1) {
+ if (servers[i]->sock > maxFD)
+ maxFD = servers[i]->sock;
+
+ if (servers[i]->state == Server::CS_CONNECTED)
+ FD_SET(servers[i]->sock, &readSet);
+ if (servers[i]->outputBuf.size() > 0 || servers[i]->state == Server::CS_WAITING_CONNECT)
+ FD_SET(servers[i]->sock, &writeSet);
+ }
+ }
+
+ timeval timeout;
+ timeout.tv_sec = 1;
+ timeout.tv_usec = 0;
+ int numFDs = select(maxFD+1, &readSet, &writeSet, NULL, &timeout);
+
+ now = time(NULL);
+ printf("[%lu select:%d]\n", now, numFDs);
+
+
+ // This is really not very DRY.
+ // Once I implement SSL properly, I'll look at the common bits between these
+ // two blocks and make it cleaner...
+
+
+ for (int i = 0; i < clientCount; i++) {
+ if (clients[i]->sock != -1) {
+ if (FD_ISSET(clients[i]->sock, &writeSet)) {
+ // What can we get rid of...?
+ Client *client = clients[i];
+ ssize_t amount;
+ if (client->gnutlsSessionInited) {
+ amount = gnutls_record_send(client->tls,
+ client->outputBuf.data(),
+ client->outputBuf.size());
+ } else {
+ amount = send(client->sock,
+ client->outputBuf.data(),
+ client->outputBuf.size(),
+ 0);
+ }
+
+ if (amount > 0) {
+ printf("[%d] Wrote %d bytes\n", i, amount);
+ client->outputBuf.trimFromStart(amount);
+ } else if (amount == 0)
+ printf("Sent 0!\n");
+ else if (amount < 0)
+ perror("Error while sending!");
+ // Close connection in that case, if a fatal error occurs?
+ }
+
+
+ if (FD_ISSET(clients[i]->sock, &readSet) || (clients[i]->gnutlsSessionInited && gnutls_record_check_pending(clients[i]->tls) > 0)) {
+ Client *client = clients[i];
+
+ // Ensure we have at least 0x200 bytes space free
+ // (Up this, maybe?)
+ int bufSize = client->inputBuf.size();
+ int requiredSize = bufSize + 0x200;
+ if (requiredSize < client->inputBuf.capacity())
+ client->inputBuf.setCapacity(requiredSize);
+
+ ssize_t amount;
+ if (client->gnutlsSessionInited) {
+ amount = gnutls_record_recv(client->tls,
+ &client->inputBuf.data()[bufSize],
+ 0x200);
+ } else {
+ amount = recv(client->sock,
+ &client->inputBuf.data()[bufSize],
+ 0x200,
+ 0);
+ }
+
+
+ if (amount > 0) {
+ // Yep, we have data
+ printf("[%d] Read %d bytes\n", i, amount);
+ client->inputBuf.resize(bufSize + amount);
+
+ client->processInput();
+
+ } else if (amount == 0) {
+ printf("[%d] Read 0! Client closing.\n", i);
+ client->close();
+
+ } else if (amount < 0)
+ perror("Error while reading!");
+ // Close connection in that case, if a fatal error occurs?
+ }
+ }
+ }
+
+
+
+ for (int i = 0; i < serverCount; i++) {
+ if (servers[i]->sock != -1) {
+ if (FD_ISSET(servers[i]->sock, &writeSet)) {
+ Server *server = servers[i];
+
+ if (server->state == Server::CS_WAITING_CONNECT) {
+ // Welp, this means we're connected!
+ // Maybe.
+ // We might have an error condition, in which case,
+ // we're screwed.
+ bool didSucceed = false;
+ int sockErr;
+ socklen_t sockErrSize = sizeof(sockErr);
+
+ if (getsockopt(server->sock, SOL_SOCKET, SO_ERROR, &sockErr, &sockErrSize) == 0) {
+ if (sockErr == 0)
+ didSucceed = true;
+ }
+
+ if (didSucceed) {
+ // WE'RE IN fuck yeah
+ printf("[%d] Connection succeeded!\n", i);
+ server->connectionSuccessful();
+ } else {
+ // Nope. Nuke it.
+ printf("[%d] Connection failed: %d\n", i, sockErr);
+ server->close();
+ }
+
+ } else {
+ // What can we get rid of...?
+
+ ssize_t amount;
+ if (server->gnutlsSessionInited) {
+ amount = gnutls_record_send(server->tls,
+ server->outputBuf.data(),
+ server->outputBuf.size());
+ } else {
+ amount = send(server->sock,
+ server->outputBuf.data(),
+ server->outputBuf.size(),
+ 0);
+ }
+
+ if (amount > 0) {
+ printf("[%d] Wrote %d bytes\n", i, amount);
+ server->outputBuf.trimFromStart(amount);
+ } else if (amount == 0)
+ printf("Sent 0!\n");
+ else if (amount < 0)
+ perror("Error while sending!");
+ // Close connection in that case, if a fatal error occurs?
+ }
+ }
+
+
+ if (FD_ISSET(servers[i]->sock, &readSet) || (servers[i]->gnutlsSessionInited && gnutls_record_check_pending(servers[i]->tls) > 0)) {
+ Server *server = servers[i];
+
+ // Ensure we have at least 0x200 bytes space free
+ // (Up this, maybe?)
+ int bufSize = server->inputBuf.size();
+ int requiredSize = bufSize + 0x200;
+ if (requiredSize < server->inputBuf.capacity())
+ server->inputBuf.setCapacity(requiredSize);
+
+ ssize_t amount;
+ if (server->gnutlsSessionInited) {
+ amount = gnutls_record_recv(server->tls,
+ &server->inputBuf.data()[bufSize],
+ 0x200);
+ } else {
+ amount = recv(server->sock,
+ &server->inputBuf.data()[bufSize],
+ 0x200,
+ 0);
+ }
+
+
+ if (amount > 0) {
+ // Yep, we have data
+ printf("[%d] Read %d bytes\n", i, amount);
+ server->inputBuf.resize(bufSize + amount);
+
+ server->processInput();
+
+ } else if (amount == 0) {
+ printf("[%d] Read 0! Server closing.\n", i);
+ server->close();
+
+ } else if (amount < 0)
+ perror("Error while reading!");
+ // Close connection in that case, if a fatal error occurs?
+ }
+ }
+ }
+
+
+
+ if (FD_ISSET(listener, &readSet)) {
+ // Yay, we have a new connection
+ int sock = accept(listener, NULL, NULL);
+
+ if (clientCount >= CLIENT_LIMIT) {
+ // We can't accept it.
+ printf("Too many connections, we can't accept this one. THIS SHOULD NEVER HAPPEN.\n");
+ shutdown(sock, SHUT_RDWR);
+ close(sock);
+ } else {
+ // Create a new connection
+ printf("[%d] New connection, fd=%d\n", clientCount, sock);
+
+ Client *client = new Client;
+
+ clients[clientCount] = client;
+ ++clientCount;
+
+ client->startService(sock, true);
+ }
+ }
+ }
+
+ // Need to shut down all sockets here
+ for (int i = 0; i < serverCount; i++)
+ servers[i]->close();
+
+ for (int i = 0; i < clientCount; i++)
+ clients[i]->close();
+
+ shutdown(listener, SHUT_RDWR);
+ close(listener);
+}
+
+
diff --git a/dns.cpp b/dns.cpp
new file mode 100644
index 0000000..056bb09
--- /dev/null
+++ b/dns.cpp
@@ -0,0 +1,159 @@
+#include "dns.h"
+#include <pthread.h>
+#include <errno.h>
+#include <string.h>
+#include <stdio.h>
+
+#define DNS_QUERY_COUNT 20
+#define DNS_QUERY_NAME_SIZE 256
+
+enum DNSQueryState {
+ DQS_FREE = 0,
+ DQS_WAITING = 1,
+ DQS_COMPLETE = 2,
+ DQS_ERROR = 3
+};
+
+struct DNSQuery {
+ char name[DNS_QUERY_NAME_SIZE];
+ in_addr result;
+ DNSQueryState status;
+ int version;
+};
+
+static DNSQuery dnsQueue[DNS_QUERY_COUNT];
+static pthread_t dnsThread;
+static pthread_mutex_t dnsQueueMutex;
+static pthread_cond_t dnsQueueCond;
+
+static void *dnsThreadProc(void *);
+
+
+void DNS::start() {
+ pthread_mutex_init(&dnsQueueMutex, NULL);
+ pthread_cond_init(&dnsQueueCond, NULL);
+
+ pthread_create(&dnsThread, NULL, &dnsThreadProc, NULL);
+
+ for (int i = 0; i < DNS_QUERY_COUNT; i++) {
+ dnsQueue[i].status = DQS_FREE;
+ dnsQueue[i].version = 0;
+ }
+}
+
+int DNS::makeQuery(const char *name) {
+ int id = -1;
+
+ pthread_mutex_lock(&dnsQueueMutex);
+
+ for (int i = 0; i < DNS_QUERY_COUNT; i++) {
+ if (dnsQueue[i].status == DQS_FREE) {
+ id = i;
+ break;
+ }
+ }
+
+ if (id != -1) {
+ strncpy(dnsQueue[id].name, name, sizeof(dnsQueue[id].name));
+ dnsQueue[id].name[sizeof(dnsQueue[id].name) - 1] = 0;
+ dnsQueue[id].status = DQS_WAITING;
+ dnsQueue[id].version++;
+ printf("[DNS::%d] New query: %s\n", id, dnsQueue[id].name);
+ }
+
+ pthread_mutex_unlock(&dnsQueueMutex);
+ pthread_cond_signal(&dnsQueueCond);
+
+ return id;
+}
+
+void DNS::closeQuery(int id) {
+ if (id < 0 || id >= DNS_QUERY_COUNT)
+ return;
+
+ pthread_mutex_lock(&dnsQueueMutex);
+ printf("[DNS::%d] Closing query\n", id);
+ dnsQueue[id].status = DQS_FREE;
+ pthread_mutex_unlock(&dnsQueueMutex);
+}
+
+bool DNS::checkQuery(int id, in_addr *pResult, bool *pIsError) {
+ if (id < 0 || id >= DNS_QUERY_COUNT)
+ return false;
+
+ pthread_mutex_lock(&dnsQueueMutex);
+
+ bool finalResult = false;
+ if (dnsQueue[id].status == DQS_COMPLETE) {
+ finalResult = true;
+ *pIsError = false;
+ memcpy(pResult, &dnsQueue[id].result, sizeof(dnsQueue[id].result));
+ } else if (dnsQueue[id].status == DQS_ERROR) {
+ finalResult = true;
+ *pIsError = true;
+ }
+
+ pthread_mutex_unlock(&dnsQueueMutex);
+
+ return finalResult;
+}
+
+
+void *dnsThreadProc(void *) {
+ pthread_mutex_lock(&dnsQueueMutex);
+
+ for (;;) {
+ for (int i = 0; i < DNS_QUERY_COUNT; i++) {
+ if (dnsQueue[i].status == DQS_WAITING) {
+ char nameCopy[DNS_QUERY_NAME_SIZE];
+ memcpy(nameCopy, dnsQueue[i].name, DNS_QUERY_NAME_SIZE);
+
+ int versionCopy = dnsQueue[i].version;
+
+ printf("[DNS::%d] Trying %s...\n", i, nameCopy);
+
+ pthread_mutex_unlock(&dnsQueueMutex);
+
+ addrinfo hints, *res;
+
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_INET;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = IPPROTO_TCP;
+ hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED;
+
+ int s = getaddrinfo(nameCopy, NULL, &hints, &res);
+
+ pthread_mutex_lock(&dnsQueueMutex);
+
+ // Before we write to the request, check that it hasn't been
+ // closed (and possibly replaced...!) by another thread
+
+ if (dnsQueue[i].status == DQS_WAITING && dnsQueue[i].version == versionCopy) {
+ if (s == 0) {
+ // Only try the first one for now...
+ // Is this safe? Not sure.
+ dnsQueue[i].status = DQS_COMPLETE;
+ memcpy(&dnsQueue[i].result, &((sockaddr_in*)res->ai_addr)->sin_addr, sizeof(dnsQueue[i].result));
+
+ printf("[DNS::%d] Resolved %s to %x\n", i, dnsQueue[i].name, dnsQueue[i].result.s_addr);
+ } else {
+ dnsQueue[i].status = DQS_ERROR;
+ printf("[DNS::%d] Error condition: %d\n", i, s);
+ }
+ } else {
+ printf("[DNS::%d] Request was cancelled before getaddrinfo completed\n", i);
+ }
+
+ if (s == 0)
+ freeaddrinfo(res);
+ }
+ }
+
+ pthread_cond_wait(&dnsQueueCond, &dnsQueueMutex);
+ }
+
+ pthread_mutex_unlock(&dnsQueueMutex);
+ return NULL;
+}
+
diff --git a/dns.h b/dns.h
new file mode 100644
index 0000000..78ec75f
--- /dev/null
+++ b/dns.h
@@ -0,0 +1,15 @@
+#ifndef DNS_H
+#define DNS_H
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netdb.h>
+
+namespace DNS {
+ void start();
+ int makeQuery(const char *name);
+ void closeQuery(int id);
+ bool checkQuery(int id, in_addr *pResult, bool *pIsError);
+}
+
+#endif /* DNS_H */