diff options
Diffstat (limited to '')
-rw-r--r-- | core.cpp | 236 |
1 files changed, 173 insertions, 63 deletions
@@ -42,6 +42,14 @@ static bool setSocketNonBlocking(int sock) { } +static bool isNullSessionKey(uint8_t *key) { + for (int i = 0; i < SESSION_KEY_SIZE; i++) + if (key[i] != 0) + return true; + + return false; +} + static Client *findClientWithKey(uint8_t *key) { for (int i = 0; i < clientCount; i++) if (!memcmp(clients[i]->sessionKey, key, SESSION_KEY_SIZE)) @@ -282,91 +290,143 @@ void Client::generateSessionKey() { } } +void Client::clearCachedPackets(int maxID) { + packetCache.remove_if([maxID](Packet *&pkt) { + return (pkt->id <= maxID); + }); +} + void Client::handleLine(char *line, int size) { // This is a terrible mess that will be replaced shortly - if (strncmp(line, "all ", 4) == 0) { - for (int i = 0; i < clientCount; i++) { - clients[i]->outputBuf.append(&line[4], size - 4); - clients[i]->outputBuf.append("\n", 1); + if (authState == AS_AUTHED) { + if (strncmp(line, "all ", 4) == 0) { + for (int i = 0; i < clientCount; i++) { + clients[i]->outputBuf.append(&line[4], size - 4); + clients[i]->outputBuf.append("\n", 1); + } + } else if (strcmp(line, "quit") == 0) { + quitFlag = true; + } else if (strncmp(line, "resolve ", 8) == 0) { + DNS::makeQuery(&line[8]); + } else if (strncmp(&line[1], "ddsrv ", 6) == 0) { + servers[serverCount] = new Server; + strcpy(servers[serverCount]->ircHostname, &line[7]); + servers[serverCount]->ircPort = 1191; + servers[serverCount]->ircUseTls = (line[0] == 's'); + serverCount++; + outputBuf.append("Your wish is my command!\n", 25); + } else if (strncmp(line, "connsrv", 7) == 0) { + int sid = line[7] - '0'; + servers[sid]->beginConnect(); + } else if (line[0] >= '0' && line[0] <= '9') { + int sid = line[0] - '0'; + servers[sid]->outputBuf.append(&line[1], size - 1); + servers[sid]->outputBuf.append("\r\n", 2); } - } else if (strcmp(line, "quit") == 0) { - quitFlag = true; - } else if (strncmp(line, "resolve ", 8) == 0) { - DNS::makeQuery(&line[8]); - } else if (strncmp(&line[1], "ddsrv ", 6) == 0) { - servers[serverCount] = new Server; - strcpy(servers[serverCount]->ircHostname, &line[7]); - servers[serverCount]->ircPort = 1191; - servers[serverCount]->ircUseTls = (line[0] == 's'); - serverCount++; - outputBuf.append("Your wish is my command!\n", 25); - } else if (strncmp(line, "connsrv", 7) == 0) { - int sid = line[7] - '0'; - servers[sid]->beginConnect(); - } else if (line[0] >= '0' && line[0] <= '9') { - int sid = line[0] - '0'; - servers[sid]->outputBuf.append(&line[1], size - 1); - servers[sid]->outputBuf.append("\r\n", 2); - } else if (strncmp(line, "login", 5) == 0) { - if (line[5] == 0) { - // no session key + } else { + } +} + +void Client::handlePacket(Packet::Type type, char *data, int size) { + Buffer pkt; + pkt.useExistingBuffer(data, size); + + printf("[fd=%d] Packet : type %d, size %d\n", sock, type, size); + + if (authState == AS_LOGIN_WAIT) { + if (type == Packet::C2B_OOB_LOGIN) { + int error = 0; + + uint32_t protocolVersion = pkt.readU32(); + if (protocolVersion != PROTOCOL_VERSION) + error = 1; + + uint32_t lastReceivedByClient = pkt.readU32(); + + if (!pkt.readRemains(SESSION_KEY_SIZE)) + error = 2; + + uint8_t reqKey[SESSION_KEY_SIZE]; + pkt.read((char *)reqKey, SESSION_KEY_SIZE); + + if (!isNullSessionKey(reqKey)) { + Client *other = findClientWithKey(reqKey); + if (other && other->authState == AS_AUTHED) { + // Yep, we can go! + other->resumeSession(this, lastReceivedByClient); + return; + } + } + + // If we got here, it means we couldn't resume the session. + // Start over. generateSessionKey(); authState = AS_AUTHED; - outputBuf.append("OK ", 3); - for (int i = 0; i < SESSION_KEY_SIZE; i++) { - char bits[4]; - sprintf(bits, "%02x", sessionKey[i]); - outputBuf.append(bits, 2); - } - outputBuf.append("\n", 1); - } else { - // This is awful. Don't care about writing clean code - // for something I'm going to throw away shortly... - uint8_t pkey[SESSION_KEY_SIZE]; - for (int i = 0; i < SESSION_KEY_SIZE; i++) { - char highc = line[6 + (i * 2)]; - char lowc = line[7 + (i * 2)]; - int high = ((highc >= '0') && (highc <= '9')) ? (highc - '0') : (highc - 'a' + 10); - int low = ((lowc >= '0') && (lowc <= '9')) ? (lowc - '0') : (lowc - 'a' + 10); - pkey[i] = (high << 4) | low; - } - Client *other = findClientWithKey(pkey); - if (other && other->authState == AS_AUTHED) - other->stealConnection(this); + Buffer pkt; + pkt.append((char *)sessionKey, SESSION_KEY_SIZE); + sendPacket(Packet::B2C_OOB_LOGIN_SUCCESS, pkt); + + } else { + printf("[fd=%d] Unrecognised packet in AS_LOGIN_WAIT authstate: type %d, size %d\n", + sock, type, size); + } + } else if (authState == AS_AUTHED) { + //if (type == Packet::) { + /*} else */{ + printf("[fd=%d] Unrecognised packet in AS_AUTHED authstate: type %d, size %d\n", + sock, type, size); } } } void Client::processReadBuffer() { - // Try to process as many lines as we can - // This function will be changed to custom protocol eventually - char *buf = inputBuf.data(); - int bufSize = inputBuf.size(); - int lineBegin = 0, pos = 0; + // Try to process as many packets as we have in inputBuf - while (pos < bufSize) { - if (buf[pos] == '\r' || buf[pos] == '\n') { - if (pos > lineBegin) { - buf[pos] = 0; - readBufPosition = pos + 1; - handleLine(&buf[lineBegin], pos - lineBegin); - } + // Basic header is 8 bytes + // Extended (non-OOB) header is 16 bytes + inputBuf.readSeek(0); + readBufPosition = 0; - lineBegin = pos + 1; + while (inputBuf.readRemains(8)) { + // We have 8 bytes, so we can try to read a basic header + Packet::Type type = (Packet::Type)inputBuf.readU16(); + int reserved = inputBuf.readU16(); + uint32_t packetSize = inputBuf.readU32(); + + // Do we now have the whole packet in memory...? + int extHeaderSize = (type & Packet::T_OUT_OF_BAND_FLAG) ? 0 : 8; + + if (!inputBuf.readRemains(packetSize + extHeaderSize)) + break; + + + if (!(type & Packet::T_OUT_OF_BAND_FLAG)) { + // Handle packet system things for non-OOB packets + uint32_t packetID = inputBuf.readU32(); + uint32_t lastReceivedByClient = inputBuf.readU32(); + + lastReceivedPacketID = packetID; + clearCachedPackets(lastReceivedByClient); } - pos++; + // Yep, we can process it! + + // Save the position of the next packet + readBufPosition = inputBuf.readTell() + packetSize; + handlePacket(type, &inputBuf.data()[inputBuf.readTell()], packetSize); + + inputBuf.readSeek(readBufPosition); } // If we managed to handle anything, lop it off the buffer - inputBuf.trimFromStart(pos); + inputBuf.trimFromStart(readBufPosition); readBufPosition = 0; } -void Client::stealConnection(Client *other) { +void Client::resumeSession(Client *other, int lastReceivedByClient) { close(); inputBuf.clear(); @@ -388,8 +448,58 @@ void Client::stealConnection(Client *other) { other->tlsActive = false; other->state = CS_DISCONNECTED; other->close(); + + // Now send them everything we've got! + Buffer pkt; + pkt.writeU32(lastReceivedPacketID); + sendPacket(Packet::B2C_OOB_SESSION_RESUMED, pkt); + + clearCachedPackets(lastReceivedByClient); + + std::list<Packet*>::iterator + i = packetCache.begin(), + e = packetCache.end(); + + for (; i != e; ++i) + sendPacketOverWire(*i); } +void Client::sendPacket(Packet::Type type, const Buffer &data, bool allowUnauthed) { + Packet *packet = new Packet; + packet->type = type; + packet->data.append(data); + + if (type & Packet::T_OUT_OF_BAND_FLAG) { + packet->id = 0; + } else { + packet->id = nextPacketID; + nextPacketID++; + } + + if (state == CS_CONNECTED) + if (authState == AS_AUTHED || allowUnauthed) + sendPacketOverWire(packet); + + if (type & Packet::T_OUT_OF_BAND_FLAG) + delete packet; + else + packetCache.push_back(packet); +} + +void Client::sendPacketOverWire(const Packet *packet) { + Buffer header; + header.writeU16(packet->type); + header.writeU16(0); + header.writeU32(packet->data.size()); + + if (!(packet->type & Packet::T_OUT_OF_BAND_FLAG)) { + header.writeU32(packet->id); + header.writeU32(lastReceivedPacketID); + } + + outputBuf.append(header); + outputBuf.append(packet->data); +} @@ -759,7 +869,7 @@ int main(int argc, char **argv) { clients[clientCount] = client; ++clientCount; - client->startService(sock, true); + client->startService(sock, SERVE_VIA_TLS); } } } |