summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTreeki <treeki@gmail.com>2014-01-16 20:37:55 +0100
committerTreeki <treeki@gmail.com>2014-01-16 20:37:55 +0100
commitb61d25d0d80668faf2df22a6f04bc9509c3fae23 (patch)
tree6e4d85a137f4da79a06c95650b7f36f0fe9a6102
parent61811def989446464a15029a40442b2c789ca371 (diff)
downloadbounce4-b61d25d0d80668faf2df22a6f04bc9509c3fae23.tar.gz
bounce4-b61d25d0d80668faf2df22a6f04bc9509c3fae23.zip
refactor a bunch of socket stuff to be more DRY, fix client SSL handshakes blocking
-rw-r--r--core.cpp261
1 files changed, 111 insertions, 150 deletions
diff --git a/core.cpp b/core.cpp
index 472cdee..8a11433 100644
--- a/core.cpp
+++ b/core.cpp
@@ -19,6 +19,22 @@
#define SESSION_KEEPALIVE 5
+
+static bool setSocketNonBlocking(int sock) {
+ int opts = fcntl(sock, F_GETFL);
+ if (opts < 0) {
+ perror("Could not get fcntl options\n");
+ return false;
+ }
+ opts |= O_NONBLOCK;
+ if (fcntl(sock, F_SETFL, opts) == -1) {
+ perror("Could not set fcntl options\n");
+ return false;
+ }
+ return true;
+}
+
+
struct SocketRWCommon {
Buffer inputBuf, outputBuf;
@@ -48,16 +64,28 @@ struct SocketRWCommon {
void tryTLSHandshake();
virtual void close();
+
+ void readAction();
+ void writeAction();
+ bool hasTlsPendingData() {
+ // should un-inline this maybe?
+ if (gnutlsSessionInited)
+ return (gnutls_record_check_pending(tls) > 0);
+ else
+ return false;
+ }
+private:
+ virtual void processReadBuffer() = 0;
};
struct Client : SocketRWCommon {
time_t deadTime;
void startService(int _sock, bool withTls);
- void processInput();
void close();
private:
+ void processReadBuffer();
void handleLine(char *line, int size);
};
@@ -81,8 +109,8 @@ struct Server : SocketRWCommon {
void close();
- void processInput();
private:
+ void processReadBuffer();
void handleLine(char *line, int size);
};
@@ -135,12 +163,79 @@ void SocketRWCommon::close() {
}
}
+void SocketRWCommon::readAction() {
+ // Ensure we have at least 0x200 bytes space free
+ // (Up this, maybe?)
+ int bufSize = inputBuf.size();
+ int requiredSize = bufSize + 0x200;
+ if (requiredSize < inputBuf.capacity())
+ inputBuf.setCapacity(requiredSize);
+
+ ssize_t amount;
+ if (gnutlsSessionInited) {
+ amount = gnutls_record_recv(tls,
+ &inputBuf.data()[bufSize],
+ 0x200);
+ } else {
+ amount = recv(sock,
+ &inputBuf.data()[bufSize],
+ 0x200,
+ 0);
+ }
+
+
+ if (amount > 0) {
+ // Yep, we have data
+ printf("[fd=%d] Read %d bytes\n", sock, amount);
+ inputBuf.resize(bufSize + amount);
+
+ processReadBuffer();
+
+ } else if (amount == 0) {
+ printf("[fd=%d] Read 0! Socket closing.\n", sock);
+ close();
+
+ } else if (amount < 0)
+ perror("Error while reading!");
+ // Close connection in that case, if a fatal error occurs?
+}
+
+void SocketRWCommon::writeAction() {
+ // What can we get rid of...?
+ ssize_t amount;
+ if (gnutlsSessionInited) {
+ amount = gnutls_record_send(tls,
+ outputBuf.data(),
+ outputBuf.size());
+ } else {
+ amount = send(sock,
+ outputBuf.data(),
+ outputBuf.size(),
+ 0);
+ }
+
+ if (amount > 0) {
+ printf("[fd=%d] Wrote %d bytes\n", sock, amount);
+ outputBuf.trimFromStart(amount);
+ } else if (amount == 0)
+ printf("Sent 0!\n");
+ else if (amount < 0)
+ perror("Error while sending!");
+ // Close connection in that case, if a fatal error occurs?
+}
+
void Client::startService(int _sock, bool withTls) {
close();
sock = _sock;
+ if (!setSocketNonBlocking(sock)) {
+ perror("[Client::startService] Could not set non-blocking");
+ close();
+ return;
+ }
+
if (withTls) {
int initRet = gnutls_init(&tls, GNUTLS_SERVER);
if (initRet != GNUTLS_E_SUCCESS) {
@@ -175,6 +270,8 @@ void Client::startService(int _sock, bool withTls) {
gnutlsSessionInited = true;
state = CS_TLS_HANDSHAKE;
+
+ printf("[fd=%d] preparing for TLS handshake\n", sock);
} else {
state = CS_CONNECTED;
}
@@ -214,7 +311,7 @@ void Client::handleLine(char *line, int size) {
}
}
-void Client::processInput() {
+void Client::processReadBuffer() {
// Try to process as many lines as we can
// This function will be changed to custom protocol eventually
char *buf = inputBuf.data();
@@ -245,7 +342,7 @@ void Server::handleLine(char *line, int size) {
clients[i]->outputBuf.append("\n", 1);
}
}
-void Server::processInput() {
+void Server::processReadBuffer() {
// Try to process as many lines as we can
char *buf = inputBuf.data();
int bufSize = inputBuf.size();
@@ -306,15 +403,8 @@ void Server::tryConnectPhase() {
return;
}
- int opts = fcntl(sock, F_GETFL);
- if (opts < 0) {
- perror("[Server] Could not get fcntl options");
- close();
- return;
- }
- opts |= O_NONBLOCK;
- if (fcntl(sock, F_SETFL, opts) == -1) {
- perror("[Server] Could not set fcntl options");
+ if (!setSocketNonBlocking(sock)) {
+ perror("[Server] Could not set non-blocking");
close();
return;
}
@@ -441,14 +531,8 @@ int main(int argc, char **argv) {
return 1;
}
- int opts = fcntl(listener, F_GETFL);
- if (opts < 0) {
- perror("Could not get fcntl options\n");
- return 1;
- }
- opts |= O_NONBLOCK;
- if (fcntl(listener, F_SETFL, opts) == -1) {
- perror("Could not set fcntl options\n");
+ if (!setSocketNonBlocking(listener)) {
+ perror("[Listener] Could not set non-blocking");
return 1;
}
@@ -533,82 +617,15 @@ int main(int argc, char **argv) {
printf("[%lu select:%d]\n", now, numFDs);
- // This is really not very DRY.
- // Once I implement SSL properly, I'll look at the common bits between these
- // two blocks and make it cleaner...
-
-
for (int i = 0; i < clientCount; i++) {
if (clients[i]->sock != -1) {
- if (FD_ISSET(clients[i]->sock, &writeSet)) {
- // What can we get rid of...?
- Client *client = clients[i];
- ssize_t amount;
- if (client->gnutlsSessionInited) {
- amount = gnutls_record_send(client->tls,
- client->outputBuf.data(),
- client->outputBuf.size());
- } else {
- amount = send(client->sock,
- client->outputBuf.data(),
- client->outputBuf.size(),
- 0);
- }
-
- if (amount > 0) {
- printf("[%d] Wrote %d bytes\n", i, amount);
- client->outputBuf.trimFromStart(amount);
- } else if (amount == 0)
- printf("Sent 0!\n");
- else if (amount < 0)
- perror("Error while sending!");
- // Close connection in that case, if a fatal error occurs?
- }
-
-
- if (FD_ISSET(clients[i]->sock, &readSet) || (clients[i]->gnutlsSessionInited && gnutls_record_check_pending(clients[i]->tls) > 0)) {
- Client *client = clients[i];
-
- // Ensure we have at least 0x200 bytes space free
- // (Up this, maybe?)
- int bufSize = client->inputBuf.size();
- int requiredSize = bufSize + 0x200;
- if (requiredSize < client->inputBuf.capacity())
- client->inputBuf.setCapacity(requiredSize);
-
- ssize_t amount;
- if (client->gnutlsSessionInited) {
- amount = gnutls_record_recv(client->tls,
- &client->inputBuf.data()[bufSize],
- 0x200);
- } else {
- amount = recv(client->sock,
- &client->inputBuf.data()[bufSize],
- 0x200,
- 0);
- }
-
-
- if (amount > 0) {
- // Yep, we have data
- printf("[%d] Read %d bytes\n", i, amount);
- client->inputBuf.resize(bufSize + amount);
-
- client->processInput();
-
- } else if (amount == 0) {
- printf("[%d] Read 0! Client closing.\n", i);
- client->close();
-
- } else if (amount < 0)
- perror("Error while reading!");
- // Close connection in that case, if a fatal error occurs?
- }
+ if (FD_ISSET(clients[i]->sock, &writeSet))
+ clients[i]->writeAction();
+ if (FD_ISSET(clients[i]->sock, &readSet) || clients[i]->hasTlsPendingData())
+ clients[i]->readAction();
}
}
-
-
for (int i = 0; i < serverCount; i++) {
if (servers[i]->sock != -1) {
if (FD_ISSET(servers[i]->sock, &writeSet)) {
@@ -639,69 +656,13 @@ int main(int argc, char **argv) {
}
} else {
- // What can we get rid of...?
-
- ssize_t amount;
- if (server->gnutlsSessionInited) {
- amount = gnutls_record_send(server->tls,
- server->outputBuf.data(),
- server->outputBuf.size());
- } else {
- amount = send(server->sock,
- server->outputBuf.data(),
- server->outputBuf.size(),
- 0);
- }
-
- if (amount > 0) {
- printf("[%d] Wrote %d bytes\n", i, amount);
- server->outputBuf.trimFromStart(amount);
- } else if (amount == 0)
- printf("Sent 0!\n");
- else if (amount < 0)
- perror("Error while sending!");
- // Close connection in that case, if a fatal error occurs?
+ server->writeAction();
}
}
if (FD_ISSET(servers[i]->sock, &readSet) || (servers[i]->gnutlsSessionInited && gnutls_record_check_pending(servers[i]->tls) > 0)) {
- Server *server = servers[i];
-
- // Ensure we have at least 0x200 bytes space free
- // (Up this, maybe?)
- int bufSize = server->inputBuf.size();
- int requiredSize = bufSize + 0x200;
- if (requiredSize < server->inputBuf.capacity())
- server->inputBuf.setCapacity(requiredSize);
-
- ssize_t amount;
- if (server->gnutlsSessionInited) {
- amount = gnutls_record_recv(server->tls,
- &server->inputBuf.data()[bufSize],
- 0x200);
- } else {
- amount = recv(server->sock,
- &server->inputBuf.data()[bufSize],
- 0x200,
- 0);
- }
-
-
- if (amount > 0) {
- // Yep, we have data
- printf("[%d] Read %d bytes\n", i, amount);
- server->inputBuf.resize(bufSize + amount);
-
- server->processInput();
-
- } else if (amount == 0) {
- printf("[%d] Read 0! Server closing.\n", i);
- server->close();
-
- } else if (amount < 0)
- perror("Error while reading!");
- // Close connection in that case, if a fatal error occurs?
+ servers[i]->readAction();
}
}
}