diff options
author | Treeki <treeki@gmail.com> | 2014-01-16 20:37:55 +0100 |
---|---|---|
committer | Treeki <treeki@gmail.com> | 2014-01-16 20:37:55 +0100 |
commit | b61d25d0d80668faf2df22a6f04bc9509c3fae23 (patch) | |
tree | 6e4d85a137f4da79a06c95650b7f36f0fe9a6102 /core.cpp | |
parent | 61811def989446464a15029a40442b2c789ca371 (diff) | |
download | bounce4-b61d25d0d80668faf2df22a6f04bc9509c3fae23.tar.gz bounce4-b61d25d0d80668faf2df22a6f04bc9509c3fae23.zip |
refactor a bunch of socket stuff to be more DRY, fix client SSL handshakes blocking
Diffstat (limited to 'core.cpp')
-rw-r--r-- | core.cpp | 261 |
1 files changed, 111 insertions, 150 deletions
@@ -19,6 +19,22 @@ #define SESSION_KEEPALIVE 5 + +static bool setSocketNonBlocking(int sock) { + int opts = fcntl(sock, F_GETFL); + if (opts < 0) { + perror("Could not get fcntl options\n"); + return false; + } + opts |= O_NONBLOCK; + if (fcntl(sock, F_SETFL, opts) == -1) { + perror("Could not set fcntl options\n"); + return false; + } + return true; +} + + struct SocketRWCommon { Buffer inputBuf, outputBuf; @@ -48,16 +64,28 @@ struct SocketRWCommon { void tryTLSHandshake(); virtual void close(); + + void readAction(); + void writeAction(); + bool hasTlsPendingData() { + // should un-inline this maybe? + if (gnutlsSessionInited) + return (gnutls_record_check_pending(tls) > 0); + else + return false; + } +private: + virtual void processReadBuffer() = 0; }; struct Client : SocketRWCommon { time_t deadTime; void startService(int _sock, bool withTls); - void processInput(); void close(); private: + void processReadBuffer(); void handleLine(char *line, int size); }; @@ -81,8 +109,8 @@ struct Server : SocketRWCommon { void close(); - void processInput(); private: + void processReadBuffer(); void handleLine(char *line, int size); }; @@ -135,12 +163,79 @@ void SocketRWCommon::close() { } } +void SocketRWCommon::readAction() { + // Ensure we have at least 0x200 bytes space free + // (Up this, maybe?) + int bufSize = inputBuf.size(); + int requiredSize = bufSize + 0x200; + if (requiredSize < inputBuf.capacity()) + inputBuf.setCapacity(requiredSize); + + ssize_t amount; + if (gnutlsSessionInited) { + amount = gnutls_record_recv(tls, + &inputBuf.data()[bufSize], + 0x200); + } else { + amount = recv(sock, + &inputBuf.data()[bufSize], + 0x200, + 0); + } + + + if (amount > 0) { + // Yep, we have data + printf("[fd=%d] Read %d bytes\n", sock, amount); + inputBuf.resize(bufSize + amount); + + processReadBuffer(); + + } else if (amount == 0) { + printf("[fd=%d] Read 0! Socket closing.\n", sock); + close(); + + } else if (amount < 0) + perror("Error while reading!"); + // Close connection in that case, if a fatal error occurs? +} + +void SocketRWCommon::writeAction() { + // What can we get rid of...? + ssize_t amount; + if (gnutlsSessionInited) { + amount = gnutls_record_send(tls, + outputBuf.data(), + outputBuf.size()); + } else { + amount = send(sock, + outputBuf.data(), + outputBuf.size(), + 0); + } + + if (amount > 0) { + printf("[fd=%d] Wrote %d bytes\n", sock, amount); + outputBuf.trimFromStart(amount); + } else if (amount == 0) + printf("Sent 0!\n"); + else if (amount < 0) + perror("Error while sending!"); + // Close connection in that case, if a fatal error occurs? +} + void Client::startService(int _sock, bool withTls) { close(); sock = _sock; + if (!setSocketNonBlocking(sock)) { + perror("[Client::startService] Could not set non-blocking"); + close(); + return; + } + if (withTls) { int initRet = gnutls_init(&tls, GNUTLS_SERVER); if (initRet != GNUTLS_E_SUCCESS) { @@ -175,6 +270,8 @@ void Client::startService(int _sock, bool withTls) { gnutlsSessionInited = true; state = CS_TLS_HANDSHAKE; + + printf("[fd=%d] preparing for TLS handshake\n", sock); } else { state = CS_CONNECTED; } @@ -214,7 +311,7 @@ void Client::handleLine(char *line, int size) { } } -void Client::processInput() { +void Client::processReadBuffer() { // Try to process as many lines as we can // This function will be changed to custom protocol eventually char *buf = inputBuf.data(); @@ -245,7 +342,7 @@ void Server::handleLine(char *line, int size) { clients[i]->outputBuf.append("\n", 1); } } -void Server::processInput() { +void Server::processReadBuffer() { // Try to process as many lines as we can char *buf = inputBuf.data(); int bufSize = inputBuf.size(); @@ -306,15 +403,8 @@ void Server::tryConnectPhase() { return; } - int opts = fcntl(sock, F_GETFL); - if (opts < 0) { - perror("[Server] Could not get fcntl options"); - close(); - return; - } - opts |= O_NONBLOCK; - if (fcntl(sock, F_SETFL, opts) == -1) { - perror("[Server] Could not set fcntl options"); + if (!setSocketNonBlocking(sock)) { + perror("[Server] Could not set non-blocking"); close(); return; } @@ -441,14 +531,8 @@ int main(int argc, char **argv) { return 1; } - int opts = fcntl(listener, F_GETFL); - if (opts < 0) { - perror("Could not get fcntl options\n"); - return 1; - } - opts |= O_NONBLOCK; - if (fcntl(listener, F_SETFL, opts) == -1) { - perror("Could not set fcntl options\n"); + if (!setSocketNonBlocking(listener)) { + perror("[Listener] Could not set non-blocking"); return 1; } @@ -533,82 +617,15 @@ int main(int argc, char **argv) { printf("[%lu select:%d]\n", now, numFDs); - // This is really not very DRY. - // Once I implement SSL properly, I'll look at the common bits between these - // two blocks and make it cleaner... - - for (int i = 0; i < clientCount; i++) { if (clients[i]->sock != -1) { - if (FD_ISSET(clients[i]->sock, &writeSet)) { - // What can we get rid of...? - Client *client = clients[i]; - ssize_t amount; - if (client->gnutlsSessionInited) { - amount = gnutls_record_send(client->tls, - client->outputBuf.data(), - client->outputBuf.size()); - } else { - amount = send(client->sock, - client->outputBuf.data(), - client->outputBuf.size(), - 0); - } - - if (amount > 0) { - printf("[%d] Wrote %d bytes\n", i, amount); - client->outputBuf.trimFromStart(amount); - } else if (amount == 0) - printf("Sent 0!\n"); - else if (amount < 0) - perror("Error while sending!"); - // Close connection in that case, if a fatal error occurs? - } - - - if (FD_ISSET(clients[i]->sock, &readSet) || (clients[i]->gnutlsSessionInited && gnutls_record_check_pending(clients[i]->tls) > 0)) { - Client *client = clients[i]; - - // Ensure we have at least 0x200 bytes space free - // (Up this, maybe?) - int bufSize = client->inputBuf.size(); - int requiredSize = bufSize + 0x200; - if (requiredSize < client->inputBuf.capacity()) - client->inputBuf.setCapacity(requiredSize); - - ssize_t amount; - if (client->gnutlsSessionInited) { - amount = gnutls_record_recv(client->tls, - &client->inputBuf.data()[bufSize], - 0x200); - } else { - amount = recv(client->sock, - &client->inputBuf.data()[bufSize], - 0x200, - 0); - } - - - if (amount > 0) { - // Yep, we have data - printf("[%d] Read %d bytes\n", i, amount); - client->inputBuf.resize(bufSize + amount); - - client->processInput(); - - } else if (amount == 0) { - printf("[%d] Read 0! Client closing.\n", i); - client->close(); - - } else if (amount < 0) - perror("Error while reading!"); - // Close connection in that case, if a fatal error occurs? - } + if (FD_ISSET(clients[i]->sock, &writeSet)) + clients[i]->writeAction(); + if (FD_ISSET(clients[i]->sock, &readSet) || clients[i]->hasTlsPendingData()) + clients[i]->readAction(); } } - - for (int i = 0; i < serverCount; i++) { if (servers[i]->sock != -1) { if (FD_ISSET(servers[i]->sock, &writeSet)) { @@ -639,69 +656,13 @@ int main(int argc, char **argv) { } } else { - // What can we get rid of...? - - ssize_t amount; - if (server->gnutlsSessionInited) { - amount = gnutls_record_send(server->tls, - server->outputBuf.data(), - server->outputBuf.size()); - } else { - amount = send(server->sock, - server->outputBuf.data(), - server->outputBuf.size(), - 0); - } - - if (amount > 0) { - printf("[%d] Wrote %d bytes\n", i, amount); - server->outputBuf.trimFromStart(amount); - } else if (amount == 0) - printf("Sent 0!\n"); - else if (amount < 0) - perror("Error while sending!"); - // Close connection in that case, if a fatal error occurs? + server->writeAction(); } } if (FD_ISSET(servers[i]->sock, &readSet) || (servers[i]->gnutlsSessionInited && gnutls_record_check_pending(servers[i]->tls) > 0)) { - Server *server = servers[i]; - - // Ensure we have at least 0x200 bytes space free - // (Up this, maybe?) - int bufSize = server->inputBuf.size(); - int requiredSize = bufSize + 0x200; - if (requiredSize < server->inputBuf.capacity()) - server->inputBuf.setCapacity(requiredSize); - - ssize_t amount; - if (server->gnutlsSessionInited) { - amount = gnutls_record_recv(server->tls, - &server->inputBuf.data()[bufSize], - 0x200); - } else { - amount = recv(server->sock, - &server->inputBuf.data()[bufSize], - 0x200, - 0); - } - - - if (amount > 0) { - // Yep, we have data - printf("[%d] Read %d bytes\n", i, amount); - server->inputBuf.resize(bufSize + amount); - - server->processInput(); - - } else if (amount == 0) { - printf("[%d] Read 0! Server closing.\n", i); - server->close(); - - } else if (amount < 0) - perror("Error while reading!"); - // Close connection in that case, if a fatal error occurs? + servers[i]->readAction(); } } } |