summaryrefslogtreecommitdiff
path: root/core.cpp
diff options
context:
space:
mode:
authorTreeki <treeki@gmail.com>2014-01-16 20:19:00 +0100
committerTreeki <treeki@gmail.com>2014-01-16 20:19:00 +0100
commita47ea91881c3cca4c8e3bdd3ea13d09886cafa94 (patch)
tree91e9e6ebcbe08044f9594a577b9ad83b186d7868 /core.cpp
downloadbounce4-a47ea91881c3cca4c8e3bdd3ea13d09886cafa94.tar.gz
bounce4-a47ea91881c3cca4c8e3bdd3ea13d09886cafa94.zip
initial commit of stuff
Diffstat (limited to '')
-rw-r--r--core.cpp745
1 files changed, 745 insertions, 0 deletions
diff --git a/core.cpp b/core.cpp
new file mode 100644
index 0000000..472cdee
--- /dev/null
+++ b/core.cpp
@@ -0,0 +1,745 @@
+#include <string.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <time.h>
+#include <sys/time.h>
+#include <sys/socket.h>
+#include <sys/select.h>
+#include <netinet/in.h>
+#include <gnutls/gnutls.h>
+
+#include "buffer.h"
+#include "dns.h"
+
+#define CLIENT_LIMIT 100
+#define SERVER_LIMIT 20
+
+#define SESSION_KEEPALIVE 5
+
+struct SocketRWCommon {
+ Buffer inputBuf, outputBuf;
+
+ enum ConnState {
+ CS_DISCONNECTED = 0,
+ CS_WAITING_DNS = 1, // server only
+ CS_WAITING_CONNECT = 2, // server only
+ CS_TLS_HANDSHAKE = 3,
+ CS_CONNECTED = 4
+ };
+ ConnState state;
+
+ int sock;
+ gnutls_session_t tls;
+ bool isTls, gnutlsSessionInited;
+
+ SocketRWCommon() {
+ sock = -1;
+ state = CS_DISCONNECTED;
+ isTls = false;
+ gnutlsSessionInited = false;
+ }
+
+ ~SocketRWCommon() {
+ close();
+ }
+
+ void tryTLSHandshake();
+ virtual void close();
+};
+
+struct Client : SocketRWCommon {
+ time_t deadTime;
+
+ void startService(int _sock, bool withTls);
+ void processInput();
+ void close();
+
+private:
+ void handleLine(char *line, int size);
+};
+
+struct Server : SocketRWCommon {
+ char hostname[256];
+ int port;
+ int dnsQueryId;
+
+ Server() {
+ dnsQueryId = -1;
+ }
+
+ ~Server() {
+ if (dnsQueryId != -1)
+ DNS::closeQuery(dnsQueryId);
+ }
+
+ void beginConnect();
+ void tryConnectPhase();
+ void connectionSuccessful();
+
+ void close();
+
+ void processInput();
+private:
+ void handleLine(char *line, int size);
+};
+
+
+Client *clients[CLIENT_LIMIT];
+Server *servers[SERVER_LIMIT];
+int clientCount, serverCount;
+bool quitFlag = false;
+
+static gnutls_dh_params_t dh_params;
+static gnutls_certificate_credentials_t serverCreds, clientCreds;
+
+
+void SocketRWCommon::tryTLSHandshake() {
+ int hsRet = gnutls_handshake(tls);
+ if (gnutls_error_is_fatal(hsRet)) {
+ printf("[SocketRWCommon::tryTLSHandshake] gnutls_handshake borked\n");
+ gnutls_perror(hsRet);
+ close();
+ return;
+ }
+
+ if (hsRet == GNUTLS_E_SUCCESS) {
+ // We're in !!
+ state = CS_CONNECTED;
+
+ inputBuf.clear();
+ outputBuf.clear();
+
+ printf("[SocketRWCommon connected via SSL!]\n");
+ }
+}
+
+void SocketRWCommon::close() {
+ if (sock != -1) {
+ if (gnutlsSessionInited)
+ gnutls_bye(tls, GNUTLS_SHUT_RDWR);
+ shutdown(sock, SHUT_RDWR);
+ ::close(sock);
+ }
+
+ sock = -1;
+ inputBuf.clear();
+ outputBuf.clear();
+ state = CS_DISCONNECTED;
+
+ if (gnutlsSessionInited) {
+ gnutls_deinit(tls);
+ gnutlsSessionInited = false;
+ }
+}
+
+
+void Client::startService(int _sock, bool withTls) {
+ close();
+
+ sock = _sock;
+
+ if (withTls) {
+ int initRet = gnutls_init(&tls, GNUTLS_SERVER);
+ if (initRet != GNUTLS_E_SUCCESS) {
+ printf("[Client::startService] gnutls_init borked\n");
+ gnutls_perror(initRet);
+ close();
+ return;
+ }
+
+ // TODO: error check this
+ int ret;
+ const char *errPos;
+
+ ret = gnutls_priority_set_direct(tls, "PERFORMANCE:%SERVER_PRECEDENCE", &errPos);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("gnutls_priority_set_direct failure: %s\n", gnutls_strerror(ret));
+ close();
+ return;
+ }
+
+ ret = gnutls_credentials_set(tls, GNUTLS_CRD_CERTIFICATE, clientCreds);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("gnutls_credentials_set failure: %s\n", gnutls_strerror(ret));
+ close();
+ return;
+ }
+
+ gnutls_certificate_server_set_request(tls, GNUTLS_CERT_IGNORE);
+
+ gnutls_transport_set_int(tls, sock);
+
+ gnutlsSessionInited = true;
+
+ state = CS_TLS_HANDSHAKE;
+ } else {
+ state = CS_CONNECTED;
+ }
+}
+
+void Client::close() {
+ SocketRWCommon::close();
+ // TODO: add canSafelyKeepAlive var, check it here, to kill
+ // never-authed conns instantly
+ deadTime = time(NULL) + SESSION_KEEPALIVE;
+}
+
+void Client::handleLine(char *line, int size) {
+ if (strncmp(line, "all ", 4) == 0) {
+ for (int i = 0; i < clientCount; i++) {
+ clients[i]->outputBuf.append(&line[4], size - 4);
+ clients[i]->outputBuf.append("\n", 1);
+ }
+ } else if (strcmp(line, "quit") == 0) {
+ quitFlag = true;
+ } else if (strncmp(line, "resolve ", 8) == 0) {
+ DNS::makeQuery(&line[8]);
+ } else if (strncmp(&line[1], "ddsrv ", 6) == 0) {
+ servers[serverCount] = new Server;
+ strcpy(servers[serverCount]->hostname, &line[7]);
+ servers[serverCount]->port = 1191;
+ servers[serverCount]->isTls = (line[0] == 's');
+ serverCount++;
+ outputBuf.append("Your wish is my command!\n", 25);
+ } else if (strncmp(line, "connsrv", 7) == 0) {
+ int sid = line[7] - '0';
+ servers[sid]->beginConnect();
+ } else if (line[0] >= '0' && line[0] <= '9') {
+ int sid = line[0] - '0';
+ servers[sid]->outputBuf.append(&line[1], size - 1);
+ servers[sid]->outputBuf.append("\r\n", 2);
+ }
+}
+
+void Client::processInput() {
+ // Try to process as many lines as we can
+ // This function will be changed to custom protocol eventually
+ char *buf = inputBuf.data();
+ int bufSize = inputBuf.size();
+ int lineBegin = 0, pos = 0;
+
+ while (pos < bufSize) {
+ if (buf[pos] == '\r' || buf[pos] == '\n') {
+ if (pos > lineBegin) {
+ buf[pos] = 0;
+ handleLine(&buf[lineBegin], pos - lineBegin);
+ }
+
+ lineBegin = pos + 1;
+ }
+
+ pos++;
+ }
+
+ // If we managed to handle anything, lop it off the buffer
+ inputBuf.trimFromStart(pos);
+}
+
+
+void Server::handleLine(char *line, int size) {
+ for (int i = 0; i < clientCount; i++) {
+ clients[i]->outputBuf.append(line, size);
+ clients[i]->outputBuf.append("\n", 1);
+ }
+}
+void Server::processInput() {
+ // Try to process as many lines as we can
+ char *buf = inputBuf.data();
+ int bufSize = inputBuf.size();
+ int lineBegin = 0, pos = 0;
+
+ while (pos < bufSize) {
+ if (buf[pos] == '\r' || buf[pos] == '\n') {
+ if (pos > lineBegin) {
+ buf[pos] = 0;
+ handleLine(&buf[lineBegin], pos - lineBegin);
+ }
+
+ lineBegin = pos + 1;
+ }
+
+ pos++;
+ }
+
+ // If we managed to handle anything, lop it off the buffer
+ inputBuf.trimFromStart(pos);
+}
+
+
+
+void Server::beginConnect() {
+ if (state == CS_DISCONNECTED) {
+ DNS::closeQuery(dnsQueryId); // just in case
+ dnsQueryId = DNS::makeQuery(hostname);
+
+ if (dnsQueryId == -1) {
+ // TODO: better error reporting
+ printf("DNS query failed!\n");
+ } else {
+ state = CS_WAITING_DNS;
+ }
+ }
+}
+
+void Server::tryConnectPhase() {
+ if (state == CS_WAITING_DNS) {
+ in_addr result;
+ bool isError;
+
+ if (DNS::checkQuery(dnsQueryId, &result, &isError)) {
+ DNS::closeQuery(dnsQueryId);
+ dnsQueryId = -1;
+
+ if (isError) {
+ printf("DNS query failed at phase 2!\n");
+ state = CS_DISCONNECTED;
+ } else {
+ // OK, if there was no error, we can go ahead and do this...
+
+ sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (sock == -1) {
+ perror("[Server] Failed to socket()");
+ close();
+ return;
+ }
+
+ int opts = fcntl(sock, F_GETFL);
+ if (opts < 0) {
+ perror("[Server] Could not get fcntl options");
+ close();
+ return;
+ }
+ opts |= O_NONBLOCK;
+ if (fcntl(sock, F_SETFL, opts) == -1) {
+ perror("[Server] Could not set fcntl options");
+ close();
+ return;
+ }
+
+ // We have our non-blocking socket, let's try connecting!
+ sockaddr_in outAddr;
+ outAddr.sin_family = AF_INET;
+ outAddr.sin_port = htons(port);
+ outAddr.sin_addr.s_addr = result.s_addr;
+
+ if (connect(sock, (sockaddr *)&outAddr, sizeof(outAddr)) == -1) {
+ if (errno == EINPROGRESS) {
+ state = CS_WAITING_CONNECT;
+ } else {
+ perror("[Server] Could not connect");
+ }
+ } else {
+ // Whoa, we're connected? Neat.
+ connectionSuccessful();
+ }
+ }
+ }
+ }
+}
+
+void Server::connectionSuccessful() {
+ state = CS_CONNECTED;
+
+ inputBuf.clear();
+ outputBuf.clear();
+
+ // Do we need to do any TLS junk?
+ if (isTls) {
+ int initRet = gnutls_init(&tls, GNUTLS_CLIENT);
+ if (initRet != GNUTLS_E_SUCCESS) {
+ printf("[Server::connectionSuccessful] gnutls_init borked\n");
+ gnutls_perror(initRet);
+ close();
+ return;
+ }
+
+ // TODO: error check this
+ const char *errPos;
+ gnutls_priority_set_direct(tls, "NORMAL", &errPos);
+
+ gnutls_credentials_set(tls, GNUTLS_CRD_CERTIFICATE, serverCreds);
+
+ gnutls_transport_set_int(tls, sock);
+
+ gnutlsSessionInited = true;
+ state = CS_TLS_HANDSHAKE;
+ }
+}
+
+void Server::close() {
+ SocketRWCommon::close();
+
+ if (dnsQueryId != -1) {
+ DNS::closeQuery(dnsQueryId);
+ dnsQueryId = -1;
+ }
+}
+
+
+int main(int argc, char **argv) {
+ clientCount = 0;
+ for (int i = 0; i < CLIENT_LIMIT; i++)
+ clients[i] = NULL;
+
+ int ret;
+ ret = gnutls_global_init();
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("gnutls_global_init failure: %s\n", gnutls_strerror(ret));
+ return 1;
+ }
+
+ unsigned int bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH, GNUTLS_SEC_PARAM_LEGACY);
+
+ ret = gnutls_dh_params_init(&dh_params);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("dh_params_init failure: %s\n", gnutls_strerror(ret));
+ return 1;
+ }
+
+ ret = gnutls_dh_params_generate2(dh_params, bits);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("dh_params_generate2 failure: %s\n", gnutls_strerror(ret));
+ return 1;
+ }
+
+ gnutls_certificate_allocate_credentials(&clientCreds);
+ ret = gnutls_certificate_set_x509_key_file(clientCreds, "ssl_test.crt", "ssl_test.key", GNUTLS_X509_FMT_PEM);
+ if (ret != GNUTLS_E_SUCCESS) {
+ printf("set_x509_key_file failure: %s\n", gnutls_strerror(ret));
+ return 1;
+ }
+ gnutls_certificate_set_dh_params(clientCreds, dh_params);
+
+ gnutls_certificate_allocate_credentials(&serverCreds);
+
+ DNS::start();
+
+
+ // prepare the listen socket
+ int listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (listener == -1) {
+ perror("Could not create the listener socket");
+ return 1;
+ }
+
+ int v = 1;
+ if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &v, sizeof(v)) == -1) {
+ perror("Could not set SO_REUSEADDR");
+ return 1;
+ }
+
+ sockaddr_in listenAddr;
+ listenAddr.sin_family = AF_INET;
+ listenAddr.sin_port = htons(5454);
+ listenAddr.sin_addr.s_addr = htonl(INADDR_ANY);
+
+ if (bind(listener, (sockaddr *)&listenAddr, sizeof(listenAddr)) == -1) {
+ perror("Could not bind to the listener socket");
+ return 1;
+ }
+
+ int opts = fcntl(listener, F_GETFL);
+ if (opts < 0) {
+ perror("Could not get fcntl options\n");
+ return 1;
+ }
+ opts |= O_NONBLOCK;
+ if (fcntl(listener, F_SETFL, opts) == -1) {
+ perror("Could not set fcntl options\n");
+ return 1;
+ }
+
+ if (listen(listener, 10) == -1) {
+ perror("Could not listen()");
+ return 1;
+ }
+
+ printf("Listening!\n");
+
+
+ // do stuff!
+ while (!quitFlag) {
+ fd_set readSet, writeSet;
+ FD_ZERO(&readSet);
+ FD_ZERO(&writeSet);
+
+ int maxFD = listener;
+ FD_SET(listener, &readSet);
+
+ time_t now = time(NULL);
+
+ for (int i = 0; i < clientCount; i++) {
+ if (clients[i]->state == Client::CS_TLS_HANDSHAKE)
+ clients[i]->tryTLSHandshake();
+
+ if (clients[i]->sock != -1) {
+ if (clients[i]->sock > maxFD)
+ maxFD = clients[i]->sock;
+
+ if (clients[i]->state == Client::CS_CONNECTED)
+ FD_SET(clients[i]->sock, &readSet);
+ if (clients[i]->outputBuf.size() > 0)
+ FD_SET(clients[i]->sock, &writeSet);
+
+ } else {
+ // Outdated session, can we kill it?
+ if (now >= clients[i]->deadTime) {
+ printf("[%d] Session expired, deleting\n", now);
+
+ // Yep.
+ Client *client = clients[i];
+
+ // If this is the last socket in the list, we can just
+ // decrement clientCount and all will be fine.
+ clientCount--;
+
+ // Otherwise, we move that pointer into this slot, and
+ // we subtract one from i so that we'll process that slot
+ // on the next loop iteration.
+ if (i != clientCount) {
+ clients[i] = clients[clientCount];
+ i--;
+ }
+ }
+ }
+ }
+
+ for (int i = 0; i < serverCount; i++) {
+ if (servers[i]->state == Server::CS_WAITING_DNS)
+ servers[i]->tryConnectPhase();
+ else if (servers[i]->state == Server::CS_TLS_HANDSHAKE)
+ servers[i]->tryTLSHandshake();
+
+ if (servers[i]->sock != -1) {
+ if (servers[i]->sock > maxFD)
+ maxFD = servers[i]->sock;
+
+ if (servers[i]->state == Server::CS_CONNECTED)
+ FD_SET(servers[i]->sock, &readSet);
+ if (servers[i]->outputBuf.size() > 0 || servers[i]->state == Server::CS_WAITING_CONNECT)
+ FD_SET(servers[i]->sock, &writeSet);
+ }
+ }
+
+ timeval timeout;
+ timeout.tv_sec = 1;
+ timeout.tv_usec = 0;
+ int numFDs = select(maxFD+1, &readSet, &writeSet, NULL, &timeout);
+
+ now = time(NULL);
+ printf("[%lu select:%d]\n", now, numFDs);
+
+
+ // This is really not very DRY.
+ // Once I implement SSL properly, I'll look at the common bits between these
+ // two blocks and make it cleaner...
+
+
+ for (int i = 0; i < clientCount; i++) {
+ if (clients[i]->sock != -1) {
+ if (FD_ISSET(clients[i]->sock, &writeSet)) {
+ // What can we get rid of...?
+ Client *client = clients[i];
+ ssize_t amount;
+ if (client->gnutlsSessionInited) {
+ amount = gnutls_record_send(client->tls,
+ client->outputBuf.data(),
+ client->outputBuf.size());
+ } else {
+ amount = send(client->sock,
+ client->outputBuf.data(),
+ client->outputBuf.size(),
+ 0);
+ }
+
+ if (amount > 0) {
+ printf("[%d] Wrote %d bytes\n", i, amount);
+ client->outputBuf.trimFromStart(amount);
+ } else if (amount == 0)
+ printf("Sent 0!\n");
+ else if (amount < 0)
+ perror("Error while sending!");
+ // Close connection in that case, if a fatal error occurs?
+ }
+
+
+ if (FD_ISSET(clients[i]->sock, &readSet) || (clients[i]->gnutlsSessionInited && gnutls_record_check_pending(clients[i]->tls) > 0)) {
+ Client *client = clients[i];
+
+ // Ensure we have at least 0x200 bytes space free
+ // (Up this, maybe?)
+ int bufSize = client->inputBuf.size();
+ int requiredSize = bufSize + 0x200;
+ if (requiredSize < client->inputBuf.capacity())
+ client->inputBuf.setCapacity(requiredSize);
+
+ ssize_t amount;
+ if (client->gnutlsSessionInited) {
+ amount = gnutls_record_recv(client->tls,
+ &client->inputBuf.data()[bufSize],
+ 0x200);
+ } else {
+ amount = recv(client->sock,
+ &client->inputBuf.data()[bufSize],
+ 0x200,
+ 0);
+ }
+
+
+ if (amount > 0) {
+ // Yep, we have data
+ printf("[%d] Read %d bytes\n", i, amount);
+ client->inputBuf.resize(bufSize + amount);
+
+ client->processInput();
+
+ } else if (amount == 0) {
+ printf("[%d] Read 0! Client closing.\n", i);
+ client->close();
+
+ } else if (amount < 0)
+ perror("Error while reading!");
+ // Close connection in that case, if a fatal error occurs?
+ }
+ }
+ }
+
+
+
+ for (int i = 0; i < serverCount; i++) {
+ if (servers[i]->sock != -1) {
+ if (FD_ISSET(servers[i]->sock, &writeSet)) {
+ Server *server = servers[i];
+
+ if (server->state == Server::CS_WAITING_CONNECT) {
+ // Welp, this means we're connected!
+ // Maybe.
+ // We might have an error condition, in which case,
+ // we're screwed.
+ bool didSucceed = false;
+ int sockErr;
+ socklen_t sockErrSize = sizeof(sockErr);
+
+ if (getsockopt(server->sock, SOL_SOCKET, SO_ERROR, &sockErr, &sockErrSize) == 0) {
+ if (sockErr == 0)
+ didSucceed = true;
+ }
+
+ if (didSucceed) {
+ // WE'RE IN fuck yeah
+ printf("[%d] Connection succeeded!\n", i);
+ server->connectionSuccessful();
+ } else {
+ // Nope. Nuke it.
+ printf("[%d] Connection failed: %d\n", i, sockErr);
+ server->close();
+ }
+
+ } else {
+ // What can we get rid of...?
+
+ ssize_t amount;
+ if (server->gnutlsSessionInited) {
+ amount = gnutls_record_send(server->tls,
+ server->outputBuf.data(),
+ server->outputBuf.size());
+ } else {
+ amount = send(server->sock,
+ server->outputBuf.data(),
+ server->outputBuf.size(),
+ 0);
+ }
+
+ if (amount > 0) {
+ printf("[%d] Wrote %d bytes\n", i, amount);
+ server->outputBuf.trimFromStart(amount);
+ } else if (amount == 0)
+ printf("Sent 0!\n");
+ else if (amount < 0)
+ perror("Error while sending!");
+ // Close connection in that case, if a fatal error occurs?
+ }
+ }
+
+
+ if (FD_ISSET(servers[i]->sock, &readSet) || (servers[i]->gnutlsSessionInited && gnutls_record_check_pending(servers[i]->tls) > 0)) {
+ Server *server = servers[i];
+
+ // Ensure we have at least 0x200 bytes space free
+ // (Up this, maybe?)
+ int bufSize = server->inputBuf.size();
+ int requiredSize = bufSize + 0x200;
+ if (requiredSize < server->inputBuf.capacity())
+ server->inputBuf.setCapacity(requiredSize);
+
+ ssize_t amount;
+ if (server->gnutlsSessionInited) {
+ amount = gnutls_record_recv(server->tls,
+ &server->inputBuf.data()[bufSize],
+ 0x200);
+ } else {
+ amount = recv(server->sock,
+ &server->inputBuf.data()[bufSize],
+ 0x200,
+ 0);
+ }
+
+
+ if (amount > 0) {
+ // Yep, we have data
+ printf("[%d] Read %d bytes\n", i, amount);
+ server->inputBuf.resize(bufSize + amount);
+
+ server->processInput();
+
+ } else if (amount == 0) {
+ printf("[%d] Read 0! Server closing.\n", i);
+ server->close();
+
+ } else if (amount < 0)
+ perror("Error while reading!");
+ // Close connection in that case, if a fatal error occurs?
+ }
+ }
+ }
+
+
+
+ if (FD_ISSET(listener, &readSet)) {
+ // Yay, we have a new connection
+ int sock = accept(listener, NULL, NULL);
+
+ if (clientCount >= CLIENT_LIMIT) {
+ // We can't accept it.
+ printf("Too many connections, we can't accept this one. THIS SHOULD NEVER HAPPEN.\n");
+ shutdown(sock, SHUT_RDWR);
+ close(sock);
+ } else {
+ // Create a new connection
+ printf("[%d] New connection, fd=%d\n", clientCount, sock);
+
+ Client *client = new Client;
+
+ clients[clientCount] = client;
+ ++clientCount;
+
+ client->startService(sock, true);
+ }
+ }
+ }
+
+ // Need to shut down all sockets here
+ for (int i = 0; i < serverCount; i++)
+ servers[i]->close();
+
+ for (int i = 0; i < clientCount; i++)
+ clients[i]->close();
+
+ shutdown(listener, SHUT_RDWR);
+ close(listener);
+}
+
+