summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--buffer.h57
-rwxr-xr-xbuild.sh2
-rw-r--r--core.cpp236
-rw-r--r--core.h24
-rw-r--r--python_client.py91
5 files changed, 331 insertions, 79 deletions
diff --git a/buffer.h b/buffer.h
index ef6d027..df063e9 100644
--- a/buffer.h
+++ b/buffer.h
@@ -10,6 +10,7 @@ private:
bool m_freeBuffer;
int m_size;
int m_capacity;
+ int m_readPointer;
char m_preAllocBuffer[0x200];
public:
@@ -18,6 +19,7 @@ public:
m_freeBuffer = false;
m_size = 0;
m_capacity = sizeof(m_preAllocBuffer);
+ m_readPointer = 0;
}
~Buffer() {
@@ -27,6 +29,17 @@ public:
}
}
+ void useExistingBuffer(char *data, int size) {
+ if (m_freeBuffer)
+ delete[] m_data;
+
+ m_data = data;
+ m_freeBuffer = false;
+ m_size = size;
+ m_capacity = size;
+ m_readPointer = 0;
+ }
+
char *data() const { return m_data; }
int size() const { return m_size; }
int capacity() const { return m_capacity; }
@@ -66,6 +79,9 @@ public:
memcpy(&m_data[m_size], data, size);
m_size += size;
}
+ void append(const Buffer &buf) {
+ append(buf.data(), buf.size());
+ }
void resize(int size) {
if (size > m_capacity)
setCapacity(size + 0x100);
@@ -86,11 +102,11 @@ public:
void writeU32(uint32_t v) { append((const char *)&v, 4); }
- void writeU16(uint32_t v) { append((const char *)&v, 2); }
- void writeU8(uint32_t v) { append((const char *)&v, 1); }
- void writeS32(uint32_t v) { append((const char *)&v, 4); }
- void writeS16(uint32_t v) { append((const char *)&v, 2); }
- void writeS8(uint32_t v) { append((const char *)&v, 1); }
+ void writeU16(uint16_t v) { append((const char *)&v, 2); }
+ void writeU8(uint8_t v) { append((const char *)&v, 1); }
+ void writeS32(int32_t v) { append((const char *)&v, 4); }
+ void writeS16(int16_t v) { append((const char *)&v, 2); }
+ void writeS8(int8_t v) { append((const char *)&v, 1); }
void writeStr(const char *data, int size = -1) {
if (size == -1)
@@ -98,6 +114,37 @@ public:
writeU32(size);
append(data, size);
}
+
+ void readSeek(int pos) {
+ m_readPointer = pos;
+ }
+ int readTell() const {
+ return m_readPointer;
+ }
+ bool readRemains(int size) const {
+ if ((size > 0) && ((m_readPointer + size) <= m_size))
+ return true;
+ return false;
+ }
+ void read(char *output, int size) {
+ if ((m_readPointer + size) > m_size) {
+ // Not enough space to read the whole thing...!
+ int copy = m_size - m_readPointer;
+ if (copy > 0)
+ memcpy(output, &m_data[m_readPointer], copy);
+ memset(&output[copy], 0, size - copy);
+ m_readPointer = m_size;
+ } else {
+ memcpy(output, &m_data[m_readPointer], size);
+ m_readPointer += size;
+ }
+ }
+ uint32_t readU32() { uint32_t v; read((char *)&v, 4); return v; }
+ uint16_t readU16() { uint16_t v; read((char *)&v, 2); return v; }
+ uint8_t readU8() { uint8_t v; read((char *)&v, 1); return v; }
+ int32_t readS32() { int32_t v; read((char *)&v, 4); return v; }
+ int16_t readS16() { int16_t v; read((char *)&v, 2); return v; }
+ int8_t readS8() { int8_t v; read((char *)&v, 1); return v; }
};
#endif /* BUFFER_H */
diff --git a/build.sh b/build.sh
index c45ee0b..8070ce6 100755
--- a/build.sh
+++ b/build.sh
@@ -1,4 +1,4 @@
#!/bin/sh
mkdir -p binary
-g++ -o binary/nb4 core.cpp dns.cpp -lgnutls -pthread
+g++ -o binary/nb4 -std=c++11 core.cpp dns.cpp -lgnutls -pthread
diff --git a/core.cpp b/core.cpp
index f969033..17092c1 100644
--- a/core.cpp
+++ b/core.cpp
@@ -42,6 +42,14 @@ static bool setSocketNonBlocking(int sock) {
}
+static bool isNullSessionKey(uint8_t *key) {
+ for (int i = 0; i < SESSION_KEY_SIZE; i++)
+ if (key[i] != 0)
+ return true;
+
+ return false;
+}
+
static Client *findClientWithKey(uint8_t *key) {
for (int i = 0; i < clientCount; i++)
if (!memcmp(clients[i]->sessionKey, key, SESSION_KEY_SIZE))
@@ -282,91 +290,143 @@ void Client::generateSessionKey() {
}
}
+void Client::clearCachedPackets(int maxID) {
+ packetCache.remove_if([maxID](Packet *&pkt) {
+ return (pkt->id <= maxID);
+ });
+}
+
void Client::handleLine(char *line, int size) {
// This is a terrible mess that will be replaced shortly
- if (strncmp(line, "all ", 4) == 0) {
- for (int i = 0; i < clientCount; i++) {
- clients[i]->outputBuf.append(&line[4], size - 4);
- clients[i]->outputBuf.append("\n", 1);
+ if (authState == AS_AUTHED) {
+ if (strncmp(line, "all ", 4) == 0) {
+ for (int i = 0; i < clientCount; i++) {
+ clients[i]->outputBuf.append(&line[4], size - 4);
+ clients[i]->outputBuf.append("\n", 1);
+ }
+ } else if (strcmp(line, "quit") == 0) {
+ quitFlag = true;
+ } else if (strncmp(line, "resolve ", 8) == 0) {
+ DNS::makeQuery(&line[8]);
+ } else if (strncmp(&line[1], "ddsrv ", 6) == 0) {
+ servers[serverCount] = new Server;
+ strcpy(servers[serverCount]->ircHostname, &line[7]);
+ servers[serverCount]->ircPort = 1191;
+ servers[serverCount]->ircUseTls = (line[0] == 's');
+ serverCount++;
+ outputBuf.append("Your wish is my command!\n", 25);
+ } else if (strncmp(line, "connsrv", 7) == 0) {
+ int sid = line[7] - '0';
+ servers[sid]->beginConnect();
+ } else if (line[0] >= '0' && line[0] <= '9') {
+ int sid = line[0] - '0';
+ servers[sid]->outputBuf.append(&line[1], size - 1);
+ servers[sid]->outputBuf.append("\r\n", 2);
}
- } else if (strcmp(line, "quit") == 0) {
- quitFlag = true;
- } else if (strncmp(line, "resolve ", 8) == 0) {
- DNS::makeQuery(&line[8]);
- } else if (strncmp(&line[1], "ddsrv ", 6) == 0) {
- servers[serverCount] = new Server;
- strcpy(servers[serverCount]->ircHostname, &line[7]);
- servers[serverCount]->ircPort = 1191;
- servers[serverCount]->ircUseTls = (line[0] == 's');
- serverCount++;
- outputBuf.append("Your wish is my command!\n", 25);
- } else if (strncmp(line, "connsrv", 7) == 0) {
- int sid = line[7] - '0';
- servers[sid]->beginConnect();
- } else if (line[0] >= '0' && line[0] <= '9') {
- int sid = line[0] - '0';
- servers[sid]->outputBuf.append(&line[1], size - 1);
- servers[sid]->outputBuf.append("\r\n", 2);
- } else if (strncmp(line, "login", 5) == 0) {
- if (line[5] == 0) {
- // no session key
+ } else {
+ }
+}
+
+void Client::handlePacket(Packet::Type type, char *data, int size) {
+ Buffer pkt;
+ pkt.useExistingBuffer(data, size);
+
+ printf("[fd=%d] Packet : type %d, size %d\n", sock, type, size);
+
+ if (authState == AS_LOGIN_WAIT) {
+ if (type == Packet::C2B_OOB_LOGIN) {
+ int error = 0;
+
+ uint32_t protocolVersion = pkt.readU32();
+ if (protocolVersion != PROTOCOL_VERSION)
+ error = 1;
+
+ uint32_t lastReceivedByClient = pkt.readU32();
+
+ if (!pkt.readRemains(SESSION_KEY_SIZE))
+ error = 2;
+
+ uint8_t reqKey[SESSION_KEY_SIZE];
+ pkt.read((char *)reqKey, SESSION_KEY_SIZE);
+
+ if (!isNullSessionKey(reqKey)) {
+ Client *other = findClientWithKey(reqKey);
+ if (other && other->authState == AS_AUTHED) {
+ // Yep, we can go!
+ other->resumeSession(this, lastReceivedByClient);
+ return;
+ }
+ }
+
+ // If we got here, it means we couldn't resume the session.
+ // Start over.
generateSessionKey();
authState = AS_AUTHED;
- outputBuf.append("OK ", 3);
- for (int i = 0; i < SESSION_KEY_SIZE; i++) {
- char bits[4];
- sprintf(bits, "%02x", sessionKey[i]);
- outputBuf.append(bits, 2);
- }
- outputBuf.append("\n", 1);
- } else {
- // This is awful. Don't care about writing clean code
- // for something I'm going to throw away shortly...
- uint8_t pkey[SESSION_KEY_SIZE];
- for (int i = 0; i < SESSION_KEY_SIZE; i++) {
- char highc = line[6 + (i * 2)];
- char lowc = line[7 + (i * 2)];
- int high = ((highc >= '0') && (highc <= '9')) ? (highc - '0') : (highc - 'a' + 10);
- int low = ((lowc >= '0') && (lowc <= '9')) ? (lowc - '0') : (lowc - 'a' + 10);
- pkey[i] = (high << 4) | low;
- }
- Client *other = findClientWithKey(pkey);
- if (other && other->authState == AS_AUTHED)
- other->stealConnection(this);
+ Buffer pkt;
+ pkt.append((char *)sessionKey, SESSION_KEY_SIZE);
+ sendPacket(Packet::B2C_OOB_LOGIN_SUCCESS, pkt);
+
+ } else {
+ printf("[fd=%d] Unrecognised packet in AS_LOGIN_WAIT authstate: type %d, size %d\n",
+ sock, type, size);
+ }
+ } else if (authState == AS_AUTHED) {
+ //if (type == Packet::) {
+ /*} else */{
+ printf("[fd=%d] Unrecognised packet in AS_AUTHED authstate: type %d, size %d\n",
+ sock, type, size);
}
}
}
void Client::processReadBuffer() {
- // Try to process as many lines as we can
- // This function will be changed to custom protocol eventually
- char *buf = inputBuf.data();
- int bufSize = inputBuf.size();
- int lineBegin = 0, pos = 0;
+ // Try to process as many packets as we have in inputBuf
- while (pos < bufSize) {
- if (buf[pos] == '\r' || buf[pos] == '\n') {
- if (pos > lineBegin) {
- buf[pos] = 0;
- readBufPosition = pos + 1;
- handleLine(&buf[lineBegin], pos - lineBegin);
- }
+ // Basic header is 8 bytes
+ // Extended (non-OOB) header is 16 bytes
+ inputBuf.readSeek(0);
+ readBufPosition = 0;
- lineBegin = pos + 1;
+ while (inputBuf.readRemains(8)) {
+ // We have 8 bytes, so we can try to read a basic header
+ Packet::Type type = (Packet::Type)inputBuf.readU16();
+ int reserved = inputBuf.readU16();
+ uint32_t packetSize = inputBuf.readU32();
+
+ // Do we now have the whole packet in memory...?
+ int extHeaderSize = (type & Packet::T_OUT_OF_BAND_FLAG) ? 0 : 8;
+
+ if (!inputBuf.readRemains(packetSize + extHeaderSize))
+ break;
+
+
+ if (!(type & Packet::T_OUT_OF_BAND_FLAG)) {
+ // Handle packet system things for non-OOB packets
+ uint32_t packetID = inputBuf.readU32();
+ uint32_t lastReceivedByClient = inputBuf.readU32();
+
+ lastReceivedPacketID = packetID;
+ clearCachedPackets(lastReceivedByClient);
}
- pos++;
+ // Yep, we can process it!
+
+ // Save the position of the next packet
+ readBufPosition = inputBuf.readTell() + packetSize;
+ handlePacket(type, &inputBuf.data()[inputBuf.readTell()], packetSize);
+
+ inputBuf.readSeek(readBufPosition);
}
// If we managed to handle anything, lop it off the buffer
- inputBuf.trimFromStart(pos);
+ inputBuf.trimFromStart(readBufPosition);
readBufPosition = 0;
}
-void Client::stealConnection(Client *other) {
+void Client::resumeSession(Client *other, int lastReceivedByClient) {
close();
inputBuf.clear();
@@ -388,8 +448,58 @@ void Client::stealConnection(Client *other) {
other->tlsActive = false;
other->state = CS_DISCONNECTED;
other->close();
+
+ // Now send them everything we've got!
+ Buffer pkt;
+ pkt.writeU32(lastReceivedPacketID);
+ sendPacket(Packet::B2C_OOB_SESSION_RESUMED, pkt);
+
+ clearCachedPackets(lastReceivedByClient);
+
+ std::list<Packet*>::iterator
+ i = packetCache.begin(),
+ e = packetCache.end();
+
+ for (; i != e; ++i)
+ sendPacketOverWire(*i);
}
+void Client::sendPacket(Packet::Type type, const Buffer &data, bool allowUnauthed) {
+ Packet *packet = new Packet;
+ packet->type = type;
+ packet->data.append(data);
+
+ if (type & Packet::T_OUT_OF_BAND_FLAG) {
+ packet->id = 0;
+ } else {
+ packet->id = nextPacketID;
+ nextPacketID++;
+ }
+
+ if (state == CS_CONNECTED)
+ if (authState == AS_AUTHED || allowUnauthed)
+ sendPacketOverWire(packet);
+
+ if (type & Packet::T_OUT_OF_BAND_FLAG)
+ delete packet;
+ else
+ packetCache.push_back(packet);
+}
+
+void Client::sendPacketOverWire(const Packet *packet) {
+ Buffer header;
+ header.writeU16(packet->type);
+ header.writeU16(0);
+ header.writeU32(packet->data.size());
+
+ if (!(packet->type & Packet::T_OUT_OF_BAND_FLAG)) {
+ header.writeU32(packet->id);
+ header.writeU32(lastReceivedPacketID);
+ }
+
+ outputBuf.append(header);
+ outputBuf.append(packet->data);
+}
@@ -759,7 +869,7 @@ int main(int argc, char **argv) {
clients[clientCount] = client;
++clientCount;
- client->startService(sock, true);
+ client->startService(sock, SERVE_VIA_TLS);
}
}
}
diff --git a/core.h b/core.h
index e38d0f6..a01904d 100644
--- a/core.h
+++ b/core.h
@@ -10,6 +10,10 @@
#define SESSION_KEY_SIZE 16
+#define PROTOCOL_VERSION 1
+
+#define SERVE_VIA_TLS false
+
struct SocketRWCommon {
Buffer inputBuf, outputBuf;
@@ -41,7 +45,17 @@ private:
struct Packet {
- int type;
+ enum Type {
+ T_OUT_OF_BAND_FLAG = 0x8000,
+
+ C2B_OOB_LOGIN = 0x8001,
+
+ B2C_OOB_LOGIN_SUCCESS = 0x8001,
+ B2C_OOB_LOGIN_FAILED = 0x8002,
+ B2C_OOB_SESSION_RESUMED = 0x8003,
+ };
+
+ Type type;
int id;
Buffer data;
};
@@ -65,14 +79,18 @@ struct Client : SocketRWCommon {
void startService(int _sock, bool withTls);
void close();
+ void sendPacket(Packet::Type type, const Buffer &data, bool allowUnauthed = false);
+
private:
int readBufPosition;
void processReadBuffer();
+ void handlePacket(Packet::Type type, char *data, int size);
void handleLine(char *line, int size);
void generateSessionKey();
-
- void stealConnection(Client *other);
+ void resumeSession(Client *other, int lastReceivedByClient);
+ void sendPacketOverWire(const Packet *packet);
+ void clearCachedPackets(int maxID);
};
struct Server : SocketRWCommon {
diff --git a/python_client.py b/python_client.py
index 561a1c3..21f4c55 100644
--- a/python_client.py
+++ b/python_client.py
@@ -1,23 +1,100 @@
-import socket, ssl, threading
+import socket, ssl, threading, struct
basesock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
basesock.connect(('localhost', 5454))
-sock = ssl.wrap_socket(basesock)
+#sock = ssl.wrap_socket(basesock)
+sock = basesock
+
+nextID = 1
+lastReceivedPacketID = 0
+packetCache = []
+packetLock = threading.Lock()
+
+class Packet:
+ def __init__(self, type, data):
+ global nextID
+
+ self.type = type
+ self.data = data
+ if (type & 0x8000) == 0:
+ self.id = nextID
+ nextID = nextID + 1
+
+ def sendOverWire(self):
+ header = struct.pack('<HHI', self.type, 0, len(self.data))
+ if (self.type & 0x8000) == 0:
+ extHeader = struct.pack('<II', self.id, lastReceivedPacketID)
+ else:
+ extHeader = b''
+
+ sock.sendall(header)
+ if extHeader:
+ sock.sendall(extHeader)
+ sock.sendall(self.data)
+
+
+def clearCachedPackets(pid):
+ for packet in packetCache[:]:
+ if packet.id <= pid:
+ packetCache.remove(packet)
def reader():
+ global lastReceivedPacketID
+ readbuf = b''
+
print('(Connected)')
while True:
- data = sock.read()
- if data:
- print(data)
- else:
+ data = sock.recv(1024)
+ if not data:
print('(Disconnected)')
break
+ readbuf += data
+
+ pos = 0
+ bufsize = len(readbuf)
+ while True:
+ if (pos + 8) > bufsize:
+ break
+
+ type, reserved, size = struct.unpack_from('<HHI', readbuf, pos)
+ pos += 8
+
+ extHeaderSize = 8 if ((type & 0x8000) == 0) else 0
+ if (pos + extHeaderSize + size) > bufsize:
+ break
+
+ if ((type & 0x8000) == 0):
+ pid, lastReceivedByServer = struct.unpack_from('<II', readbuf, pos)
+ pos += 8
+
+ with packetLock:
+ lastReceivedPacketID = pid
+ clearCachedPackets(lastReceivedByServer)
+
+ packetdata = data[pos:pos+size]
+ print('0x%x : %d bytes : %s' % (type, size, packetdata))
+ pos += size
+
+def writePacket(type, data):
+ with packetLock:
+ packet = Packet(type, data)
+ if (type & 0x8000) != 0:
+ packetCache.append(packet)
+ packet.sendOverWire()
+
+
thd = threading.Thread(None, reader)
thd.start()
while True:
bit = input()
- sock.write((bit + '\n').encode('utf-8'))
+ bits = bit.split(' ', 1)
+ cmd = bits[0]
+
+ if cmd == 'login':
+ writePacket(0x8001, struct.pack('<II 16s', 0, 0, b'\0'*16))
+ elif cmd == 'cmd':
+ data = bits[1].encode('utf-8')
+ writePacket(1, struct.pack('<I', len(data)) + data)