From a47ea91881c3cca4c8e3bdd3ea13d09886cafa94 Mon Sep 17 00:00:00 2001 From: Treeki Date: Thu, 16 Jan 2014 20:19:00 +0100 Subject: initial commit of stuff --- .gitignore | 3 + buffer.h | 103 +++++++++ build.sh | 4 + core.cpp | 745 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ dns.cpp | 159 +++++++++++++ dns.h | 15 ++ 6 files changed, 1029 insertions(+) create mode 100644 .gitignore create mode 100644 buffer.h create mode 100755 build.sh create mode 100644 core.cpp create mode 100644 dns.cpp create mode 100644 dns.h diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ca63371 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +binary +*.crt +*.key diff --git a/buffer.h b/buffer.h new file mode 100644 index 0000000..ef6d027 --- /dev/null +++ b/buffer.h @@ -0,0 +1,103 @@ +#ifndef BUFFER_H +#define BUFFER_H + +#include +#include + +class Buffer { +private: + char *m_data; + bool m_freeBuffer; + int m_size; + int m_capacity; + char m_preAllocBuffer[0x200]; + +public: + Buffer() { + m_data = m_preAllocBuffer; + m_freeBuffer = false; + m_size = 0; + m_capacity = sizeof(m_preAllocBuffer); + } + + ~Buffer() { + if ((m_data != NULL) && m_freeBuffer) { + delete[] m_data; + m_data = NULL; + } + } + + char *data() const { return m_data; } + int size() const { return m_size; } + int capacity() const { return m_capacity; } + + void setCapacity(int capacity) { + if (capacity == m_capacity) + return; + + // Trim the size down if it's too big to fit + if (m_size > capacity) + m_size = capacity; + + char *newBuf = new char[capacity]; + + if (m_data != NULL) { + memcpy(newBuf, m_data, m_size); + if (m_freeBuffer) + delete[] m_data; + } + + m_data = newBuf; + m_capacity = capacity; + m_freeBuffer = true; + } + + void clear() { + m_size = 0; + } + void append(const char *data, int size) { + if (size <= 0) + return; + + int requiredSize = m_size + size; + if (requiredSize > m_capacity) + setCapacity(requiredSize + 0x100); + + memcpy(&m_data[m_size], data, size); + m_size += size; + } + void resize(int size) { + if (size > m_capacity) + setCapacity(size + 0x100); + m_size = size; + } + + void trimFromStart(int amount) { + if (amount <= 0) + return; + if (amount >= m_size) { + clear(); + return; + } + + memmove(m_data, &m_data[amount], m_size - amount); + m_size -= amount; + } + + + void writeU32(uint32_t v) { append((const char *)&v, 4); } + void writeU16(uint32_t v) { append((const char *)&v, 2); } + void writeU8(uint32_t v) { append((const char *)&v, 1); } + void writeS32(uint32_t v) { append((const char *)&v, 4); } + void writeS16(uint32_t v) { append((const char *)&v, 2); } + void writeS8(uint32_t v) { append((const char *)&v, 1); } + + void writeStr(const char *data, int size = -1) { + if (size == -1) + size = strlen(data); + writeU32(size); + append(data, size); + } +}; + +#endif /* BUFFER_H */ diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..58d6365 --- /dev/null +++ b/build.sh @@ -0,0 +1,4 @@ +#!/bin/sh +mkdir binary +g++ -o binary/nb4 core.cpp dns.cpp -lgnutls -pthread + diff --git a/core.cpp b/core.cpp new file mode 100644 index 0000000..472cdee --- /dev/null +++ b/core.cpp @@ -0,0 +1,745 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "buffer.h" +#include "dns.h" + +#define CLIENT_LIMIT 100 +#define SERVER_LIMIT 20 + +#define SESSION_KEEPALIVE 5 + +struct SocketRWCommon { + Buffer inputBuf, outputBuf; + + enum ConnState { + CS_DISCONNECTED = 0, + CS_WAITING_DNS = 1, // server only + CS_WAITING_CONNECT = 2, // server only + CS_TLS_HANDSHAKE = 3, + CS_CONNECTED = 4 + }; + ConnState state; + + int sock; + gnutls_session_t tls; + bool isTls, gnutlsSessionInited; + + SocketRWCommon() { + sock = -1; + state = CS_DISCONNECTED; + isTls = false; + gnutlsSessionInited = false; + } + + ~SocketRWCommon() { + close(); + } + + void tryTLSHandshake(); + virtual void close(); +}; + +struct Client : SocketRWCommon { + time_t deadTime; + + void startService(int _sock, bool withTls); + void processInput(); + void close(); + +private: + void handleLine(char *line, int size); +}; + +struct Server : SocketRWCommon { + char hostname[256]; + int port; + int dnsQueryId; + + Server() { + dnsQueryId = -1; + } + + ~Server() { + if (dnsQueryId != -1) + DNS::closeQuery(dnsQueryId); + } + + void beginConnect(); + void tryConnectPhase(); + void connectionSuccessful(); + + void close(); + + void processInput(); +private: + void handleLine(char *line, int size); +}; + + +Client *clients[CLIENT_LIMIT]; +Server *servers[SERVER_LIMIT]; +int clientCount, serverCount; +bool quitFlag = false; + +static gnutls_dh_params_t dh_params; +static gnutls_certificate_credentials_t serverCreds, clientCreds; + + +void SocketRWCommon::tryTLSHandshake() { + int hsRet = gnutls_handshake(tls); + if (gnutls_error_is_fatal(hsRet)) { + printf("[SocketRWCommon::tryTLSHandshake] gnutls_handshake borked\n"); + gnutls_perror(hsRet); + close(); + return; + } + + if (hsRet == GNUTLS_E_SUCCESS) { + // We're in !! + state = CS_CONNECTED; + + inputBuf.clear(); + outputBuf.clear(); + + printf("[SocketRWCommon connected via SSL!]\n"); + } +} + +void SocketRWCommon::close() { + if (sock != -1) { + if (gnutlsSessionInited) + gnutls_bye(tls, GNUTLS_SHUT_RDWR); + shutdown(sock, SHUT_RDWR); + ::close(sock); + } + + sock = -1; + inputBuf.clear(); + outputBuf.clear(); + state = CS_DISCONNECTED; + + if (gnutlsSessionInited) { + gnutls_deinit(tls); + gnutlsSessionInited = false; + } +} + + +void Client::startService(int _sock, bool withTls) { + close(); + + sock = _sock; + + if (withTls) { + int initRet = gnutls_init(&tls, GNUTLS_SERVER); + if (initRet != GNUTLS_E_SUCCESS) { + printf("[Client::startService] gnutls_init borked\n"); + gnutls_perror(initRet); + close(); + return; + } + + // TODO: error check this + int ret; + const char *errPos; + + ret = gnutls_priority_set_direct(tls, "PERFORMANCE:%SERVER_PRECEDENCE", &errPos); + if (ret != GNUTLS_E_SUCCESS) { + printf("gnutls_priority_set_direct failure: %s\n", gnutls_strerror(ret)); + close(); + return; + } + + ret = gnutls_credentials_set(tls, GNUTLS_CRD_CERTIFICATE, clientCreds); + if (ret != GNUTLS_E_SUCCESS) { + printf("gnutls_credentials_set failure: %s\n", gnutls_strerror(ret)); + close(); + return; + } + + gnutls_certificate_server_set_request(tls, GNUTLS_CERT_IGNORE); + + gnutls_transport_set_int(tls, sock); + + gnutlsSessionInited = true; + + state = CS_TLS_HANDSHAKE; + } else { + state = CS_CONNECTED; + } +} + +void Client::close() { + SocketRWCommon::close(); + // TODO: add canSafelyKeepAlive var, check it here, to kill + // never-authed conns instantly + deadTime = time(NULL) + SESSION_KEEPALIVE; +} + +void Client::handleLine(char *line, int size) { + if (strncmp(line, "all ", 4) == 0) { + for (int i = 0; i < clientCount; i++) { + clients[i]->outputBuf.append(&line[4], size - 4); + clients[i]->outputBuf.append("\n", 1); + } + } else if (strcmp(line, "quit") == 0) { + quitFlag = true; + } else if (strncmp(line, "resolve ", 8) == 0) { + DNS::makeQuery(&line[8]); + } else if (strncmp(&line[1], "ddsrv ", 6) == 0) { + servers[serverCount] = new Server; + strcpy(servers[serverCount]->hostname, &line[7]); + servers[serverCount]->port = 1191; + servers[serverCount]->isTls = (line[0] == 's'); + serverCount++; + outputBuf.append("Your wish is my command!\n", 25); + } else if (strncmp(line, "connsrv", 7) == 0) { + int sid = line[7] - '0'; + servers[sid]->beginConnect(); + } else if (line[0] >= '0' && line[0] <= '9') { + int sid = line[0] - '0'; + servers[sid]->outputBuf.append(&line[1], size - 1); + servers[sid]->outputBuf.append("\r\n", 2); + } +} + +void Client::processInput() { + // Try to process as many lines as we can + // This function will be changed to custom protocol eventually + char *buf = inputBuf.data(); + int bufSize = inputBuf.size(); + int lineBegin = 0, pos = 0; + + while (pos < bufSize) { + if (buf[pos] == '\r' || buf[pos] == '\n') { + if (pos > lineBegin) { + buf[pos] = 0; + handleLine(&buf[lineBegin], pos - lineBegin); + } + + lineBegin = pos + 1; + } + + pos++; + } + + // If we managed to handle anything, lop it off the buffer + inputBuf.trimFromStart(pos); +} + + +void Server::handleLine(char *line, int size) { + for (int i = 0; i < clientCount; i++) { + clients[i]->outputBuf.append(line, size); + clients[i]->outputBuf.append("\n", 1); + } +} +void Server::processInput() { + // Try to process as many lines as we can + char *buf = inputBuf.data(); + int bufSize = inputBuf.size(); + int lineBegin = 0, pos = 0; + + while (pos < bufSize) { + if (buf[pos] == '\r' || buf[pos] == '\n') { + if (pos > lineBegin) { + buf[pos] = 0; + handleLine(&buf[lineBegin], pos - lineBegin); + } + + lineBegin = pos + 1; + } + + pos++; + } + + // If we managed to handle anything, lop it off the buffer + inputBuf.trimFromStart(pos); +} + + + +void Server::beginConnect() { + if (state == CS_DISCONNECTED) { + DNS::closeQuery(dnsQueryId); // just in case + dnsQueryId = DNS::makeQuery(hostname); + + if (dnsQueryId == -1) { + // TODO: better error reporting + printf("DNS query failed!\n"); + } else { + state = CS_WAITING_DNS; + } + } +} + +void Server::tryConnectPhase() { + if (state == CS_WAITING_DNS) { + in_addr result; + bool isError; + + if (DNS::checkQuery(dnsQueryId, &result, &isError)) { + DNS::closeQuery(dnsQueryId); + dnsQueryId = -1; + + if (isError) { + printf("DNS query failed at phase 2!\n"); + state = CS_DISCONNECTED; + } else { + // OK, if there was no error, we can go ahead and do this... + + sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (sock == -1) { + perror("[Server] Failed to socket()"); + close(); + return; + } + + int opts = fcntl(sock, F_GETFL); + if (opts < 0) { + perror("[Server] Could not get fcntl options"); + close(); + return; + } + opts |= O_NONBLOCK; + if (fcntl(sock, F_SETFL, opts) == -1) { + perror("[Server] Could not set fcntl options"); + close(); + return; + } + + // We have our non-blocking socket, let's try connecting! + sockaddr_in outAddr; + outAddr.sin_family = AF_INET; + outAddr.sin_port = htons(port); + outAddr.sin_addr.s_addr = result.s_addr; + + if (connect(sock, (sockaddr *)&outAddr, sizeof(outAddr)) == -1) { + if (errno == EINPROGRESS) { + state = CS_WAITING_CONNECT; + } else { + perror("[Server] Could not connect"); + } + } else { + // Whoa, we're connected? Neat. + connectionSuccessful(); + } + } + } + } +} + +void Server::connectionSuccessful() { + state = CS_CONNECTED; + + inputBuf.clear(); + outputBuf.clear(); + + // Do we need to do any TLS junk? + if (isTls) { + int initRet = gnutls_init(&tls, GNUTLS_CLIENT); + if (initRet != GNUTLS_E_SUCCESS) { + printf("[Server::connectionSuccessful] gnutls_init borked\n"); + gnutls_perror(initRet); + close(); + return; + } + + // TODO: error check this + const char *errPos; + gnutls_priority_set_direct(tls, "NORMAL", &errPos); + + gnutls_credentials_set(tls, GNUTLS_CRD_CERTIFICATE, serverCreds); + + gnutls_transport_set_int(tls, sock); + + gnutlsSessionInited = true; + state = CS_TLS_HANDSHAKE; + } +} + +void Server::close() { + SocketRWCommon::close(); + + if (dnsQueryId != -1) { + DNS::closeQuery(dnsQueryId); + dnsQueryId = -1; + } +} + + +int main(int argc, char **argv) { + clientCount = 0; + for (int i = 0; i < CLIENT_LIMIT; i++) + clients[i] = NULL; + + int ret; + ret = gnutls_global_init(); + if (ret != GNUTLS_E_SUCCESS) { + printf("gnutls_global_init failure: %s\n", gnutls_strerror(ret)); + return 1; + } + + unsigned int bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH, GNUTLS_SEC_PARAM_LEGACY); + + ret = gnutls_dh_params_init(&dh_params); + if (ret != GNUTLS_E_SUCCESS) { + printf("dh_params_init failure: %s\n", gnutls_strerror(ret)); + return 1; + } + + ret = gnutls_dh_params_generate2(dh_params, bits); + if (ret != GNUTLS_E_SUCCESS) { + printf("dh_params_generate2 failure: %s\n", gnutls_strerror(ret)); + return 1; + } + + gnutls_certificate_allocate_credentials(&clientCreds); + ret = gnutls_certificate_set_x509_key_file(clientCreds, "ssl_test.crt", "ssl_test.key", GNUTLS_X509_FMT_PEM); + if (ret != GNUTLS_E_SUCCESS) { + printf("set_x509_key_file failure: %s\n", gnutls_strerror(ret)); + return 1; + } + gnutls_certificate_set_dh_params(clientCreds, dh_params); + + gnutls_certificate_allocate_credentials(&serverCreds); + + DNS::start(); + + + // prepare the listen socket + int listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (listener == -1) { + perror("Could not create the listener socket"); + return 1; + } + + int v = 1; + if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &v, sizeof(v)) == -1) { + perror("Could not set SO_REUSEADDR"); + return 1; + } + + sockaddr_in listenAddr; + listenAddr.sin_family = AF_INET; + listenAddr.sin_port = htons(5454); + listenAddr.sin_addr.s_addr = htonl(INADDR_ANY); + + if (bind(listener, (sockaddr *)&listenAddr, sizeof(listenAddr)) == -1) { + perror("Could not bind to the listener socket"); + return 1; + } + + int opts = fcntl(listener, F_GETFL); + if (opts < 0) { + perror("Could not get fcntl options\n"); + return 1; + } + opts |= O_NONBLOCK; + if (fcntl(listener, F_SETFL, opts) == -1) { + perror("Could not set fcntl options\n"); + return 1; + } + + if (listen(listener, 10) == -1) { + perror("Could not listen()"); + return 1; + } + + printf("Listening!\n"); + + + // do stuff! + while (!quitFlag) { + fd_set readSet, writeSet; + FD_ZERO(&readSet); + FD_ZERO(&writeSet); + + int maxFD = listener; + FD_SET(listener, &readSet); + + time_t now = time(NULL); + + for (int i = 0; i < clientCount; i++) { + if (clients[i]->state == Client::CS_TLS_HANDSHAKE) + clients[i]->tryTLSHandshake(); + + if (clients[i]->sock != -1) { + if (clients[i]->sock > maxFD) + maxFD = clients[i]->sock; + + if (clients[i]->state == Client::CS_CONNECTED) + FD_SET(clients[i]->sock, &readSet); + if (clients[i]->outputBuf.size() > 0) + FD_SET(clients[i]->sock, &writeSet); + + } else { + // Outdated session, can we kill it? + if (now >= clients[i]->deadTime) { + printf("[%d] Session expired, deleting\n", now); + + // Yep. + Client *client = clients[i]; + + // If this is the last socket in the list, we can just + // decrement clientCount and all will be fine. + clientCount--; + + // Otherwise, we move that pointer into this slot, and + // we subtract one from i so that we'll process that slot + // on the next loop iteration. + if (i != clientCount) { + clients[i] = clients[clientCount]; + i--; + } + } + } + } + + for (int i = 0; i < serverCount; i++) { + if (servers[i]->state == Server::CS_WAITING_DNS) + servers[i]->tryConnectPhase(); + else if (servers[i]->state == Server::CS_TLS_HANDSHAKE) + servers[i]->tryTLSHandshake(); + + if (servers[i]->sock != -1) { + if (servers[i]->sock > maxFD) + maxFD = servers[i]->sock; + + if (servers[i]->state == Server::CS_CONNECTED) + FD_SET(servers[i]->sock, &readSet); + if (servers[i]->outputBuf.size() > 0 || servers[i]->state == Server::CS_WAITING_CONNECT) + FD_SET(servers[i]->sock, &writeSet); + } + } + + timeval timeout; + timeout.tv_sec = 1; + timeout.tv_usec = 0; + int numFDs = select(maxFD+1, &readSet, &writeSet, NULL, &timeout); + + now = time(NULL); + printf("[%lu select:%d]\n", now, numFDs); + + + // This is really not very DRY. + // Once I implement SSL properly, I'll look at the common bits between these + // two blocks and make it cleaner... + + + for (int i = 0; i < clientCount; i++) { + if (clients[i]->sock != -1) { + if (FD_ISSET(clients[i]->sock, &writeSet)) { + // What can we get rid of...? + Client *client = clients[i]; + ssize_t amount; + if (client->gnutlsSessionInited) { + amount = gnutls_record_send(client->tls, + client->outputBuf.data(), + client->outputBuf.size()); + } else { + amount = send(client->sock, + client->outputBuf.data(), + client->outputBuf.size(), + 0); + } + + if (amount > 0) { + printf("[%d] Wrote %d bytes\n", i, amount); + client->outputBuf.trimFromStart(amount); + } else if (amount == 0) + printf("Sent 0!\n"); + else if (amount < 0) + perror("Error while sending!"); + // Close connection in that case, if a fatal error occurs? + } + + + if (FD_ISSET(clients[i]->sock, &readSet) || (clients[i]->gnutlsSessionInited && gnutls_record_check_pending(clients[i]->tls) > 0)) { + Client *client = clients[i]; + + // Ensure we have at least 0x200 bytes space free + // (Up this, maybe?) + int bufSize = client->inputBuf.size(); + int requiredSize = bufSize + 0x200; + if (requiredSize < client->inputBuf.capacity()) + client->inputBuf.setCapacity(requiredSize); + + ssize_t amount; + if (client->gnutlsSessionInited) { + amount = gnutls_record_recv(client->tls, + &client->inputBuf.data()[bufSize], + 0x200); + } else { + amount = recv(client->sock, + &client->inputBuf.data()[bufSize], + 0x200, + 0); + } + + + if (amount > 0) { + // Yep, we have data + printf("[%d] Read %d bytes\n", i, amount); + client->inputBuf.resize(bufSize + amount); + + client->processInput(); + + } else if (amount == 0) { + printf("[%d] Read 0! Client closing.\n", i); + client->close(); + + } else if (amount < 0) + perror("Error while reading!"); + // Close connection in that case, if a fatal error occurs? + } + } + } + + + + for (int i = 0; i < serverCount; i++) { + if (servers[i]->sock != -1) { + if (FD_ISSET(servers[i]->sock, &writeSet)) { + Server *server = servers[i]; + + if (server->state == Server::CS_WAITING_CONNECT) { + // Welp, this means we're connected! + // Maybe. + // We might have an error condition, in which case, + // we're screwed. + bool didSucceed = false; + int sockErr; + socklen_t sockErrSize = sizeof(sockErr); + + if (getsockopt(server->sock, SOL_SOCKET, SO_ERROR, &sockErr, &sockErrSize) == 0) { + if (sockErr == 0) + didSucceed = true; + } + + if (didSucceed) { + // WE'RE IN fuck yeah + printf("[%d] Connection succeeded!\n", i); + server->connectionSuccessful(); + } else { + // Nope. Nuke it. + printf("[%d] Connection failed: %d\n", i, sockErr); + server->close(); + } + + } else { + // What can we get rid of...? + + ssize_t amount; + if (server->gnutlsSessionInited) { + amount = gnutls_record_send(server->tls, + server->outputBuf.data(), + server->outputBuf.size()); + } else { + amount = send(server->sock, + server->outputBuf.data(), + server->outputBuf.size(), + 0); + } + + if (amount > 0) { + printf("[%d] Wrote %d bytes\n", i, amount); + server->outputBuf.trimFromStart(amount); + } else if (amount == 0) + printf("Sent 0!\n"); + else if (amount < 0) + perror("Error while sending!"); + // Close connection in that case, if a fatal error occurs? + } + } + + + if (FD_ISSET(servers[i]->sock, &readSet) || (servers[i]->gnutlsSessionInited && gnutls_record_check_pending(servers[i]->tls) > 0)) { + Server *server = servers[i]; + + // Ensure we have at least 0x200 bytes space free + // (Up this, maybe?) + int bufSize = server->inputBuf.size(); + int requiredSize = bufSize + 0x200; + if (requiredSize < server->inputBuf.capacity()) + server->inputBuf.setCapacity(requiredSize); + + ssize_t amount; + if (server->gnutlsSessionInited) { + amount = gnutls_record_recv(server->tls, + &server->inputBuf.data()[bufSize], + 0x200); + } else { + amount = recv(server->sock, + &server->inputBuf.data()[bufSize], + 0x200, + 0); + } + + + if (amount > 0) { + // Yep, we have data + printf("[%d] Read %d bytes\n", i, amount); + server->inputBuf.resize(bufSize + amount); + + server->processInput(); + + } else if (amount == 0) { + printf("[%d] Read 0! Server closing.\n", i); + server->close(); + + } else if (amount < 0) + perror("Error while reading!"); + // Close connection in that case, if a fatal error occurs? + } + } + } + + + + if (FD_ISSET(listener, &readSet)) { + // Yay, we have a new connection + int sock = accept(listener, NULL, NULL); + + if (clientCount >= CLIENT_LIMIT) { + // We can't accept it. + printf("Too many connections, we can't accept this one. THIS SHOULD NEVER HAPPEN.\n"); + shutdown(sock, SHUT_RDWR); + close(sock); + } else { + // Create a new connection + printf("[%d] New connection, fd=%d\n", clientCount, sock); + + Client *client = new Client; + + clients[clientCount] = client; + ++clientCount; + + client->startService(sock, true); + } + } + } + + // Need to shut down all sockets here + for (int i = 0; i < serverCount; i++) + servers[i]->close(); + + for (int i = 0; i < clientCount; i++) + clients[i]->close(); + + shutdown(listener, SHUT_RDWR); + close(listener); +} + + diff --git a/dns.cpp b/dns.cpp new file mode 100644 index 0000000..056bb09 --- /dev/null +++ b/dns.cpp @@ -0,0 +1,159 @@ +#include "dns.h" +#include +#include +#include +#include + +#define DNS_QUERY_COUNT 20 +#define DNS_QUERY_NAME_SIZE 256 + +enum DNSQueryState { + DQS_FREE = 0, + DQS_WAITING = 1, + DQS_COMPLETE = 2, + DQS_ERROR = 3 +}; + +struct DNSQuery { + char name[DNS_QUERY_NAME_SIZE]; + in_addr result; + DNSQueryState status; + int version; +}; + +static DNSQuery dnsQueue[DNS_QUERY_COUNT]; +static pthread_t dnsThread; +static pthread_mutex_t dnsQueueMutex; +static pthread_cond_t dnsQueueCond; + +static void *dnsThreadProc(void *); + + +void DNS::start() { + pthread_mutex_init(&dnsQueueMutex, NULL); + pthread_cond_init(&dnsQueueCond, NULL); + + pthread_create(&dnsThread, NULL, &dnsThreadProc, NULL); + + for (int i = 0; i < DNS_QUERY_COUNT; i++) { + dnsQueue[i].status = DQS_FREE; + dnsQueue[i].version = 0; + } +} + +int DNS::makeQuery(const char *name) { + int id = -1; + + pthread_mutex_lock(&dnsQueueMutex); + + for (int i = 0; i < DNS_QUERY_COUNT; i++) { + if (dnsQueue[i].status == DQS_FREE) { + id = i; + break; + } + } + + if (id != -1) { + strncpy(dnsQueue[id].name, name, sizeof(dnsQueue[id].name)); + dnsQueue[id].name[sizeof(dnsQueue[id].name) - 1] = 0; + dnsQueue[id].status = DQS_WAITING; + dnsQueue[id].version++; + printf("[DNS::%d] New query: %s\n", id, dnsQueue[id].name); + } + + pthread_mutex_unlock(&dnsQueueMutex); + pthread_cond_signal(&dnsQueueCond); + + return id; +} + +void DNS::closeQuery(int id) { + if (id < 0 || id >= DNS_QUERY_COUNT) + return; + + pthread_mutex_lock(&dnsQueueMutex); + printf("[DNS::%d] Closing query\n", id); + dnsQueue[id].status = DQS_FREE; + pthread_mutex_unlock(&dnsQueueMutex); +} + +bool DNS::checkQuery(int id, in_addr *pResult, bool *pIsError) { + if (id < 0 || id >= DNS_QUERY_COUNT) + return false; + + pthread_mutex_lock(&dnsQueueMutex); + + bool finalResult = false; + if (dnsQueue[id].status == DQS_COMPLETE) { + finalResult = true; + *pIsError = false; + memcpy(pResult, &dnsQueue[id].result, sizeof(dnsQueue[id].result)); + } else if (dnsQueue[id].status == DQS_ERROR) { + finalResult = true; + *pIsError = true; + } + + pthread_mutex_unlock(&dnsQueueMutex); + + return finalResult; +} + + +void *dnsThreadProc(void *) { + pthread_mutex_lock(&dnsQueueMutex); + + for (;;) { + for (int i = 0; i < DNS_QUERY_COUNT; i++) { + if (dnsQueue[i].status == DQS_WAITING) { + char nameCopy[DNS_QUERY_NAME_SIZE]; + memcpy(nameCopy, dnsQueue[i].name, DNS_QUERY_NAME_SIZE); + + int versionCopy = dnsQueue[i].version; + + printf("[DNS::%d] Trying %s...\n", i, nameCopy); + + pthread_mutex_unlock(&dnsQueueMutex); + + addrinfo hints, *res; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED; + + int s = getaddrinfo(nameCopy, NULL, &hints, &res); + + pthread_mutex_lock(&dnsQueueMutex); + + // Before we write to the request, check that it hasn't been + // closed (and possibly replaced...!) by another thread + + if (dnsQueue[i].status == DQS_WAITING && dnsQueue[i].version == versionCopy) { + if (s == 0) { + // Only try the first one for now... + // Is this safe? Not sure. + dnsQueue[i].status = DQS_COMPLETE; + memcpy(&dnsQueue[i].result, &((sockaddr_in*)res->ai_addr)->sin_addr, sizeof(dnsQueue[i].result)); + + printf("[DNS::%d] Resolved %s to %x\n", i, dnsQueue[i].name, dnsQueue[i].result.s_addr); + } else { + dnsQueue[i].status = DQS_ERROR; + printf("[DNS::%d] Error condition: %d\n", i, s); + } + } else { + printf("[DNS::%d] Request was cancelled before getaddrinfo completed\n", i); + } + + if (s == 0) + freeaddrinfo(res); + } + } + + pthread_cond_wait(&dnsQueueCond, &dnsQueueMutex); + } + + pthread_mutex_unlock(&dnsQueueMutex); + return NULL; +} + diff --git a/dns.h b/dns.h new file mode 100644 index 0000000..78ec75f --- /dev/null +++ b/dns.h @@ -0,0 +1,15 @@ +#ifndef DNS_H +#define DNS_H + +#include +#include +#include + +namespace DNS { + void start(); + int makeQuery(const char *name); + void closeQuery(int id); + bool checkQuery(int id, in_addr *pResult, bool *pIsError); +} + +#endif /* DNS_H */ -- cgit v1.2.3