diff options
Diffstat (limited to '')
-rw-r--r-- | buffer.h | 57 | ||||
-rwxr-xr-x | build.sh | 2 | ||||
-rw-r--r-- | core.cpp | 236 | ||||
-rw-r--r-- | core.h | 24 | ||||
-rw-r--r-- | python_client.py | 91 |
5 files changed, 331 insertions, 79 deletions
@@ -10,6 +10,7 @@ private: bool m_freeBuffer; int m_size; int m_capacity; + int m_readPointer; char m_preAllocBuffer[0x200]; public: @@ -18,6 +19,7 @@ public: m_freeBuffer = false; m_size = 0; m_capacity = sizeof(m_preAllocBuffer); + m_readPointer = 0; } ~Buffer() { @@ -27,6 +29,17 @@ public: } } + void useExistingBuffer(char *data, int size) { + if (m_freeBuffer) + delete[] m_data; + + m_data = data; + m_freeBuffer = false; + m_size = size; + m_capacity = size; + m_readPointer = 0; + } + char *data() const { return m_data; } int size() const { return m_size; } int capacity() const { return m_capacity; } @@ -66,6 +79,9 @@ public: memcpy(&m_data[m_size], data, size); m_size += size; } + void append(const Buffer &buf) { + append(buf.data(), buf.size()); + } void resize(int size) { if (size > m_capacity) setCapacity(size + 0x100); @@ -86,11 +102,11 @@ public: void writeU32(uint32_t v) { append((const char *)&v, 4); } - void writeU16(uint32_t v) { append((const char *)&v, 2); } - void writeU8(uint32_t v) { append((const char *)&v, 1); } - void writeS32(uint32_t v) { append((const char *)&v, 4); } - void writeS16(uint32_t v) { append((const char *)&v, 2); } - void writeS8(uint32_t v) { append((const char *)&v, 1); } + void writeU16(uint16_t v) { append((const char *)&v, 2); } + void writeU8(uint8_t v) { append((const char *)&v, 1); } + void writeS32(int32_t v) { append((const char *)&v, 4); } + void writeS16(int16_t v) { append((const char *)&v, 2); } + void writeS8(int8_t v) { append((const char *)&v, 1); } void writeStr(const char *data, int size = -1) { if (size == -1) @@ -98,6 +114,37 @@ public: writeU32(size); append(data, size); } + + void readSeek(int pos) { + m_readPointer = pos; + } + int readTell() const { + return m_readPointer; + } + bool readRemains(int size) const { + if ((size > 0) && ((m_readPointer + size) <= m_size)) + return true; + return false; + } + void read(char *output, int size) { + if ((m_readPointer + size) > m_size) { + // Not enough space to read the whole thing...! + int copy = m_size - m_readPointer; + if (copy > 0) + memcpy(output, &m_data[m_readPointer], copy); + memset(&output[copy], 0, size - copy); + m_readPointer = m_size; + } else { + memcpy(output, &m_data[m_readPointer], size); + m_readPointer += size; + } + } + uint32_t readU32() { uint32_t v; read((char *)&v, 4); return v; } + uint16_t readU16() { uint16_t v; read((char *)&v, 2); return v; } + uint8_t readU8() { uint8_t v; read((char *)&v, 1); return v; } + int32_t readS32() { int32_t v; read((char *)&v, 4); return v; } + int16_t readS16() { int16_t v; read((char *)&v, 2); return v; } + int8_t readS8() { int8_t v; read((char *)&v, 1); return v; } }; #endif /* BUFFER_H */ @@ -1,4 +1,4 @@ #!/bin/sh mkdir -p binary -g++ -o binary/nb4 core.cpp dns.cpp -lgnutls -pthread +g++ -o binary/nb4 -std=c++11 core.cpp dns.cpp -lgnutls -pthread @@ -42,6 +42,14 @@ static bool setSocketNonBlocking(int sock) { } +static bool isNullSessionKey(uint8_t *key) { + for (int i = 0; i < SESSION_KEY_SIZE; i++) + if (key[i] != 0) + return true; + + return false; +} + static Client *findClientWithKey(uint8_t *key) { for (int i = 0; i < clientCount; i++) if (!memcmp(clients[i]->sessionKey, key, SESSION_KEY_SIZE)) @@ -282,91 +290,143 @@ void Client::generateSessionKey() { } } +void Client::clearCachedPackets(int maxID) { + packetCache.remove_if([maxID](Packet *&pkt) { + return (pkt->id <= maxID); + }); +} + void Client::handleLine(char *line, int size) { // This is a terrible mess that will be replaced shortly - if (strncmp(line, "all ", 4) == 0) { - for (int i = 0; i < clientCount; i++) { - clients[i]->outputBuf.append(&line[4], size - 4); - clients[i]->outputBuf.append("\n", 1); + if (authState == AS_AUTHED) { + if (strncmp(line, "all ", 4) == 0) { + for (int i = 0; i < clientCount; i++) { + clients[i]->outputBuf.append(&line[4], size - 4); + clients[i]->outputBuf.append("\n", 1); + } + } else if (strcmp(line, "quit") == 0) { + quitFlag = true; + } else if (strncmp(line, "resolve ", 8) == 0) { + DNS::makeQuery(&line[8]); + } else if (strncmp(&line[1], "ddsrv ", 6) == 0) { + servers[serverCount] = new Server; + strcpy(servers[serverCount]->ircHostname, &line[7]); + servers[serverCount]->ircPort = 1191; + servers[serverCount]->ircUseTls = (line[0] == 's'); + serverCount++; + outputBuf.append("Your wish is my command!\n", 25); + } else if (strncmp(line, "connsrv", 7) == 0) { + int sid = line[7] - '0'; + servers[sid]->beginConnect(); + } else if (line[0] >= '0' && line[0] <= '9') { + int sid = line[0] - '0'; + servers[sid]->outputBuf.append(&line[1], size - 1); + servers[sid]->outputBuf.append("\r\n", 2); } - } else if (strcmp(line, "quit") == 0) { - quitFlag = true; - } else if (strncmp(line, "resolve ", 8) == 0) { - DNS::makeQuery(&line[8]); - } else if (strncmp(&line[1], "ddsrv ", 6) == 0) { - servers[serverCount] = new Server; - strcpy(servers[serverCount]->ircHostname, &line[7]); - servers[serverCount]->ircPort = 1191; - servers[serverCount]->ircUseTls = (line[0] == 's'); - serverCount++; - outputBuf.append("Your wish is my command!\n", 25); - } else if (strncmp(line, "connsrv", 7) == 0) { - int sid = line[7] - '0'; - servers[sid]->beginConnect(); - } else if (line[0] >= '0' && line[0] <= '9') { - int sid = line[0] - '0'; - servers[sid]->outputBuf.append(&line[1], size - 1); - servers[sid]->outputBuf.append("\r\n", 2); - } else if (strncmp(line, "login", 5) == 0) { - if (line[5] == 0) { - // no session key + } else { + } +} + +void Client::handlePacket(Packet::Type type, char *data, int size) { + Buffer pkt; + pkt.useExistingBuffer(data, size); + + printf("[fd=%d] Packet : type %d, size %d\n", sock, type, size); + + if (authState == AS_LOGIN_WAIT) { + if (type == Packet::C2B_OOB_LOGIN) { + int error = 0; + + uint32_t protocolVersion = pkt.readU32(); + if (protocolVersion != PROTOCOL_VERSION) + error = 1; + + uint32_t lastReceivedByClient = pkt.readU32(); + + if (!pkt.readRemains(SESSION_KEY_SIZE)) + error = 2; + + uint8_t reqKey[SESSION_KEY_SIZE]; + pkt.read((char *)reqKey, SESSION_KEY_SIZE); + + if (!isNullSessionKey(reqKey)) { + Client *other = findClientWithKey(reqKey); + if (other && other->authState == AS_AUTHED) { + // Yep, we can go! + other->resumeSession(this, lastReceivedByClient); + return; + } + } + + // If we got here, it means we couldn't resume the session. + // Start over. generateSessionKey(); authState = AS_AUTHED; - outputBuf.append("OK ", 3); - for (int i = 0; i < SESSION_KEY_SIZE; i++) { - char bits[4]; - sprintf(bits, "%02x", sessionKey[i]); - outputBuf.append(bits, 2); - } - outputBuf.append("\n", 1); - } else { - // This is awful. Don't care about writing clean code - // for something I'm going to throw away shortly... - uint8_t pkey[SESSION_KEY_SIZE]; - for (int i = 0; i < SESSION_KEY_SIZE; i++) { - char highc = line[6 + (i * 2)]; - char lowc = line[7 + (i * 2)]; - int high = ((highc >= '0') && (highc <= '9')) ? (highc - '0') : (highc - 'a' + 10); - int low = ((lowc >= '0') && (lowc <= '9')) ? (lowc - '0') : (lowc - 'a' + 10); - pkey[i] = (high << 4) | low; - } - Client *other = findClientWithKey(pkey); - if (other && other->authState == AS_AUTHED) - other->stealConnection(this); + Buffer pkt; + pkt.append((char *)sessionKey, SESSION_KEY_SIZE); + sendPacket(Packet::B2C_OOB_LOGIN_SUCCESS, pkt); + + } else { + printf("[fd=%d] Unrecognised packet in AS_LOGIN_WAIT authstate: type %d, size %d\n", + sock, type, size); + } + } else if (authState == AS_AUTHED) { + //if (type == Packet::) { + /*} else */{ + printf("[fd=%d] Unrecognised packet in AS_AUTHED authstate: type %d, size %d\n", + sock, type, size); } } } void Client::processReadBuffer() { - // Try to process as many lines as we can - // This function will be changed to custom protocol eventually - char *buf = inputBuf.data(); - int bufSize = inputBuf.size(); - int lineBegin = 0, pos = 0; + // Try to process as many packets as we have in inputBuf - while (pos < bufSize) { - if (buf[pos] == '\r' || buf[pos] == '\n') { - if (pos > lineBegin) { - buf[pos] = 0; - readBufPosition = pos + 1; - handleLine(&buf[lineBegin], pos - lineBegin); - } + // Basic header is 8 bytes + // Extended (non-OOB) header is 16 bytes + inputBuf.readSeek(0); + readBufPosition = 0; - lineBegin = pos + 1; + while (inputBuf.readRemains(8)) { + // We have 8 bytes, so we can try to read a basic header + Packet::Type type = (Packet::Type)inputBuf.readU16(); + int reserved = inputBuf.readU16(); + uint32_t packetSize = inputBuf.readU32(); + + // Do we now have the whole packet in memory...? + int extHeaderSize = (type & Packet::T_OUT_OF_BAND_FLAG) ? 0 : 8; + + if (!inputBuf.readRemains(packetSize + extHeaderSize)) + break; + + + if (!(type & Packet::T_OUT_OF_BAND_FLAG)) { + // Handle packet system things for non-OOB packets + uint32_t packetID = inputBuf.readU32(); + uint32_t lastReceivedByClient = inputBuf.readU32(); + + lastReceivedPacketID = packetID; + clearCachedPackets(lastReceivedByClient); } - pos++; + // Yep, we can process it! + + // Save the position of the next packet + readBufPosition = inputBuf.readTell() + packetSize; + handlePacket(type, &inputBuf.data()[inputBuf.readTell()], packetSize); + + inputBuf.readSeek(readBufPosition); } // If we managed to handle anything, lop it off the buffer - inputBuf.trimFromStart(pos); + inputBuf.trimFromStart(readBufPosition); readBufPosition = 0; } -void Client::stealConnection(Client *other) { +void Client::resumeSession(Client *other, int lastReceivedByClient) { close(); inputBuf.clear(); @@ -388,8 +448,58 @@ void Client::stealConnection(Client *other) { other->tlsActive = false; other->state = CS_DISCONNECTED; other->close(); + + // Now send them everything we've got! + Buffer pkt; + pkt.writeU32(lastReceivedPacketID); + sendPacket(Packet::B2C_OOB_SESSION_RESUMED, pkt); + + clearCachedPackets(lastReceivedByClient); + + std::list<Packet*>::iterator + i = packetCache.begin(), + e = packetCache.end(); + + for (; i != e; ++i) + sendPacketOverWire(*i); } +void Client::sendPacket(Packet::Type type, const Buffer &data, bool allowUnauthed) { + Packet *packet = new Packet; + packet->type = type; + packet->data.append(data); + + if (type & Packet::T_OUT_OF_BAND_FLAG) { + packet->id = 0; + } else { + packet->id = nextPacketID; + nextPacketID++; + } + + if (state == CS_CONNECTED) + if (authState == AS_AUTHED || allowUnauthed) + sendPacketOverWire(packet); + + if (type & Packet::T_OUT_OF_BAND_FLAG) + delete packet; + else + packetCache.push_back(packet); +} + +void Client::sendPacketOverWire(const Packet *packet) { + Buffer header; + header.writeU16(packet->type); + header.writeU16(0); + header.writeU32(packet->data.size()); + + if (!(packet->type & Packet::T_OUT_OF_BAND_FLAG)) { + header.writeU32(packet->id); + header.writeU32(lastReceivedPacketID); + } + + outputBuf.append(header); + outputBuf.append(packet->data); +} @@ -759,7 +869,7 @@ int main(int argc, char **argv) { clients[clientCount] = client; ++clientCount; - client->startService(sock, true); + client->startService(sock, SERVE_VIA_TLS); } } } @@ -10,6 +10,10 @@ #define SESSION_KEY_SIZE 16 +#define PROTOCOL_VERSION 1 + +#define SERVE_VIA_TLS false + struct SocketRWCommon { Buffer inputBuf, outputBuf; @@ -41,7 +45,17 @@ private: struct Packet { - int type; + enum Type { + T_OUT_OF_BAND_FLAG = 0x8000, + + C2B_OOB_LOGIN = 0x8001, + + B2C_OOB_LOGIN_SUCCESS = 0x8001, + B2C_OOB_LOGIN_FAILED = 0x8002, + B2C_OOB_SESSION_RESUMED = 0x8003, + }; + + Type type; int id; Buffer data; }; @@ -65,14 +79,18 @@ struct Client : SocketRWCommon { void startService(int _sock, bool withTls); void close(); + void sendPacket(Packet::Type type, const Buffer &data, bool allowUnauthed = false); + private: int readBufPosition; void processReadBuffer(); + void handlePacket(Packet::Type type, char *data, int size); void handleLine(char *line, int size); void generateSessionKey(); - - void stealConnection(Client *other); + void resumeSession(Client *other, int lastReceivedByClient); + void sendPacketOverWire(const Packet *packet); + void clearCachedPackets(int maxID); }; struct Server : SocketRWCommon { diff --git a/python_client.py b/python_client.py index 561a1c3..21f4c55 100644 --- a/python_client.py +++ b/python_client.py @@ -1,23 +1,100 @@ -import socket, ssl, threading +import socket, ssl, threading, struct basesock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) basesock.connect(('localhost', 5454)) -sock = ssl.wrap_socket(basesock) +#sock = ssl.wrap_socket(basesock) +sock = basesock + +nextID = 1 +lastReceivedPacketID = 0 +packetCache = [] +packetLock = threading.Lock() + +class Packet: + def __init__(self, type, data): + global nextID + + self.type = type + self.data = data + if (type & 0x8000) == 0: + self.id = nextID + nextID = nextID + 1 + + def sendOverWire(self): + header = struct.pack('<HHI', self.type, 0, len(self.data)) + if (self.type & 0x8000) == 0: + extHeader = struct.pack('<II', self.id, lastReceivedPacketID) + else: + extHeader = b'' + + sock.sendall(header) + if extHeader: + sock.sendall(extHeader) + sock.sendall(self.data) + + +def clearCachedPackets(pid): + for packet in packetCache[:]: + if packet.id <= pid: + packetCache.remove(packet) def reader(): + global lastReceivedPacketID + readbuf = b'' + print('(Connected)') while True: - data = sock.read() - if data: - print(data) - else: + data = sock.recv(1024) + if not data: print('(Disconnected)') break + readbuf += data + + pos = 0 + bufsize = len(readbuf) + while True: + if (pos + 8) > bufsize: + break + + type, reserved, size = struct.unpack_from('<HHI', readbuf, pos) + pos += 8 + + extHeaderSize = 8 if ((type & 0x8000) == 0) else 0 + if (pos + extHeaderSize + size) > bufsize: + break + + if ((type & 0x8000) == 0): + pid, lastReceivedByServer = struct.unpack_from('<II', readbuf, pos) + pos += 8 + + with packetLock: + lastReceivedPacketID = pid + clearCachedPackets(lastReceivedByServer) + + packetdata = data[pos:pos+size] + print('0x%x : %d bytes : %s' % (type, size, packetdata)) + pos += size + +def writePacket(type, data): + with packetLock: + packet = Packet(type, data) + if (type & 0x8000) != 0: + packetCache.append(packet) + packet.sendOverWire() + + thd = threading.Thread(None, reader) thd.start() while True: bit = input() - sock.write((bit + '\n').encode('utf-8')) + bits = bit.split(' ', 1) + cmd = bits[0] + + if cmd == 'login': + writePacket(0x8001, struct.pack('<II 16s', 0, 0, b'\0'*16)) + elif cmd == 'cmd': + data = bits[1].encode('utf-8') + writePacket(1, struct.pack('<I', len(data)) + data) |