prudp: Code cleanup

This commit is contained in:
Exzap 2024-04-18 19:22:28 +02:00
parent ee36992bd6
commit e2f9725719
3 changed files with 410 additions and 422 deletions

View file

@ -106,7 +106,7 @@ nexService::nexService()
nexService::nexService(prudpClient* con) : nexService()
{
if (con->isConnected() == false)
if (con->IsConnected() == false)
cemu_assert_suspicious();
this->conNexService = con;
bufferReceive = std::vector<uint8>(1024 * 4);
@ -191,7 +191,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest)
uint32 callId = _currentCallId;
_currentCallId++;
// check state of connection
if (conNexService->getConnectionState() != prudpClient::STATE_CONNECTED)
if (conNexService->GetConnectionState() != prudpClient::ConnectionState::Connected)
{
nexServiceResponse_t response = { 0 };
response.isSuccessful = false;
@ -214,7 +214,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest)
assert_dbg();
memcpy((packetBuffer + 0x0D), &queuedRequest->parameterData.front(), queuedRequest->parameterData.size());
sint32 length = 0xD + (sint32)queuedRequest->parameterData.size();
conNexService->sendDatagram(packetBuffer, length, true);
conNexService->SendDatagram(packetBuffer, length, true);
// remember request
nexActiveRequestInfo_t requestInfo = { 0 };
requestInfo.callId = callId;
@ -299,13 +299,13 @@ void nexService::registerForAsyncProcessing()
void nexService::updateTemporaryConnections()
{
// check for connection
conNexService->update();
if (conNexService->isConnected())
conNexService->Update();
if (conNexService->IsConnected())
{
if (connectionState == STATE_CONNECTING)
connectionState = STATE_CONNECTED;
}
if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED)
if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
connectionState = STATE_DISCONNECTED;
}
@ -356,18 +356,18 @@ void nexService::sendRequestResponse(nexServiceRequest_t* request, uint32 errorC
// update length field
*(uint32*)response.getDataPtr() = response.getWriteIndex()-4;
if(request->nex->conNexService)
request->nex->conNexService->sendDatagram(response.getDataPtr(), response.getWriteIndex(), true);
request->nex->conNexService->SendDatagram(response.getDataPtr(), response.getWriteIndex(), true);
}
void nexService::updateNexServiceConnection()
{
if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED)
if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
{
this->connectionState = STATE_DISCONNECTED;
return;
}
conNexService->update();
sint32 datagramLen = conNexService->receiveDatagram(bufferReceive);
conNexService->Update();
sint32 datagramLen = conNexService->ReceiveDatagram(bufferReceive);
if (datagramLen > 0)
{
if (nexIsRequest(&bufferReceive[0], datagramLen))
@ -454,12 +454,12 @@ bool _extractStationUrlParamValue(const char* urlStr, const char* paramName, cha
return false;
}
void nexServiceAuthentication_parseStationURL(char* urlStr, stationUrl_t* stationUrl)
void nexServiceAuthentication_parseStationURL(char* urlStr, prudpStationUrl* stationUrl)
{
// example:
// prudps:/address=34.210.xxx.xxx;port=60181;CID=1;PID=2;sid=1;stream=10;type=2
memset(stationUrl, 0, sizeof(stationUrl_t));
memset(stationUrl, 0, sizeof(prudpStationUrl));
char optionValue[128];
if (_extractStationUrlParamValue(urlStr, "address", optionValue, sizeof(optionValue)))
@ -499,7 +499,7 @@ typedef struct
sint32 kerberosTicketSize;
uint8 kerberosTicket2[4096];
sint32 kerberosTicket2Size;
stationUrl_t server;
prudpStationUrl server;
// progress info
bool hasError;
bool done;
@ -611,18 +611,18 @@ void nexServiceSecure_handleResponse_RegisterEx(nexService* nex, nexServiceRespo
return;
}
nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* accessKey, const char* nexToken)
nexService* nex_secureLogin(prudpAuthServerInfo* authServerInfo, const char* accessKey, const char* nexToken)
{
prudpClient* prudpSecureSock = new prudpClient(authServerInfo->server.ip, authServerInfo->server.port, accessKey, authServerInfo);
// wait until connected
while (true)
{
prudpSecureSock->update();
if (prudpSecureSock->isConnected())
prudpSecureSock->Update();
if (prudpSecureSock->IsConnected())
{
break;
}
if (prudpSecureSock->getConnectionState() == prudpClient::STATE_DISCONNECTED)
if (prudpSecureSock->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
{
// timeout or disconnected
cemuLog_log(LogType::Force, "NEX: Secure login connection time-out");
@ -638,7 +638,7 @@ nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* access
nexPacketBuffer packetBuffer(tempNexBufferArray, sizeof(tempNexBufferArray), true);
char clientStationUrl[256];
sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->getSourcePort());
sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->GetSourcePort());
// station url list
packetBuffer.writeU32(1);
packetBuffer.writeString(clientStationUrl);
@ -737,9 +737,9 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer
return nullptr;
}
// auth info
auto authServerInfo = std::make_unique<authServerInfo_t>();
auto authServerInfo = std::make_unique<prudpAuthServerInfo>();
// decrypt ticket
RC4Ctx_t rc4Ticket;
RC4Ctx rc4Ticket;
RC4_initCtx(&rc4Ticket, kerberosKey, 16);
RC4_transform(&rc4Ticket, nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, nexAuthService.kerberosTicket2);
nexPacketBuffer packetKerberosTicket(nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, false);
@ -756,7 +756,7 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer
memcpy(authServerInfo->kerberosKey, kerberosKey, 16);
memcpy(authServerInfo->secureKey, secureKey, 16);
memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(stationUrl_t));
memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(prudpStationUrl));
authServerInfo->userPid = pid;
return nex_secureLogin(authServerInfo.get(), accessKey, nexToken);

File diff suppressed because it is too large Load diff

View file

@ -4,26 +4,26 @@
#define RC4_N 256
typedef struct
struct RC4Ctx
{
unsigned char S[RC4_N];
int i;
int j;
}RC4Ctx_t;
};
void RC4_initCtx(RC4Ctx_t* rc4Ctx, char *key);
void RC4_initCtx(RC4Ctx_t* rc4Ctx, unsigned char* key, int keyLen);
void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned char* output);
void RC4_initCtx(RC4Ctx* rc4Ctx, const char* key);
void RC4_initCtx(RC4Ctx* rc4Ctx, unsigned char* key, int keyLen);
void RC4_transform(RC4Ctx* rc4Ctx, unsigned char* input, int len, unsigned char* output);
typedef struct
struct prudpStreamSettings
{
uint8 checksumBase; // calculated from key
uint8 accessKeyDigest[16]; // MD5 hash of key
RC4Ctx_t rc4Client;
RC4Ctx_t rc4Server;
}prudpStreamSettings_t;
RC4Ctx rc4Client;
RC4Ctx rc4Server;
};
typedef struct
struct prudpStationUrl
{
uint32 ip;
uint16 port;
@ -32,19 +32,17 @@ typedef struct
sint32 sid;
sint32 stream;
sint32 type;
}stationUrl_t;
};
typedef struct
struct prudpAuthServerInfo
{
uint32 userPid;
uint8 secureKey[16];
uint8 kerberosKey[16];
uint8 secureTicket[1024];
sint32 secureTicketLength;
stationUrl_t server;
}authServerInfo_t;
uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length);
prudpStationUrl server;
};
class prudpPacket
{
@ -66,7 +64,7 @@ public:
static sint32 calculateSizeFromPacketData(uint8* data, sint32 length);
prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature);
prudpPacket(prudpStreamSettings* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature);
bool requiresAck();
void setData(uint8* data, sint32 length);
void setFragmentIndex(uint8 fragmentIndex);
@ -87,7 +85,7 @@ private:
uint16 flags;
uint8 sessionId;
uint32 specifiedPacketSignature;
prudpStreamSettings_t* streamSettings;
prudpStreamSettings* streamSettings;
std::vector<uint8> packetData;
bool isEncrypted;
uint16 m_sequenceId{0};
@ -97,7 +95,7 @@ private:
class prudpIncomingPacket
{
public:
prudpIncomingPacket(prudpStreamSettings_t* streamSettings, uint8* data, sint32 length);
prudpIncomingPacket(prudpStreamSettings* streamSettings, uint8* data, sint32 length);
bool hasError();
@ -122,83 +120,91 @@ public:
private:
bool isInvalid = false;
prudpStreamSettings_t* streamSettings = nullptr;
prudpStreamSettings* streamSettings = nullptr;
};
typedef struct
{
prudpPacket* packet;
uint32 initialSendTimestamp;
uint32 lastRetryTimestamp;
sint32 retryCount;
}prudpAckRequired_t;
class prudpClient
{
struct PacketWithAckRequired
{
PacketWithAckRequired(prudpPacket* packet, uint32 initialSendTimestamp) :
packet(packet), initialSendTimestamp(initialSendTimestamp), lastRetryTimestamp(initialSendTimestamp) { }
prudpPacket* packet;
uint32 initialSendTimestamp;
uint32 lastRetryTimestamp;
sint32 retryCount{0};
};
public:
static const int STATE_CONNECTING = 0;
static const int STATE_CONNECTED = 1;
static const int STATE_DISCONNECTED = 2;
enum class ConnectionState : uint8
{
Connecting,
Connected,
Disconnected
};
public:
prudpClient(uint32 dstIp, uint16 dstPort, const char* key);
prudpClient(uint32 dstIp, uint16 dstPort, const char* key, authServerInfo_t* authInfo);
prudpClient(uint32 dstIp, uint16 dstPort, const char* key, prudpAuthServerInfo* authInfo);
~prudpClient();
bool isConnected();
bool IsConnected() const { return m_currentConnectionState == ConnectionState::Connected; }
ConnectionState GetConnectionState() const { return m_currentConnectionState; }
uint16 GetSourcePort() const { return m_srcPort; }
uint8 getConnectionState();
void acknowledgePacket(uint16 sequenceId);
void sortIncomingDataPacket(prudpIncomingPacket* incomingPacket);
void handleIncomingPacket(prudpIncomingPacket* incomingPacket);
bool update(); // check for new incoming packets, returns true if receiveDatagram() should be called
bool Update(); // update connection state and check for incoming packets. Returns true if ReceiveDatagram() should be called
sint32 receiveDatagram(std::vector<uint8>& outputBuffer);
void sendDatagram(uint8* input, sint32 length, bool reliable = true);
uint16 getSourcePort();
SOCKET getSocket();
sint32 ReceiveDatagram(std::vector<uint8>& outputBuffer);
void SendDatagram(uint8* input, sint32 length, bool reliable = true);
private:
prudpClient();
void directSendPacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort);
sint32 kerberosEncryptData(uint8* input, sint32 length, uint8* output);
void queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort);
void HandleIncomingPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket);
void DirectSendPacket(prudpPacket* packet);
sint32 KerberosEncryptData(uint8* input, sint32 length, uint8* output);
void QueuePacket(prudpPacket* packet);
void AcknowledgePacket(uint16 sequenceId);
void SortIncomingDataPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket);
void SendCurrentHandshakePacket();
private:
uint16 srcPort;
uint32 dstIp;
uint16 dstPort;
uint8 vport_src;
uint8 vport_dst;
prudpStreamSettings_t streamSettings;
std::vector<prudpAckRequired_t> list_packetsWithAckReq;
std::vector<prudpIncomingPacket*> queue_incomingPackets;
// connection
uint8 currentConnectionState;
uint32 serverConnectionSignature;
uint32 clientConnectionSignature;
bool hasSentCon;
uint32 lastPingTimestamp;
uint16 m_srcPort;
uint32 m_dstIp;
uint16 m_dstPort;
uint8 m_srcVPort;
uint8 m_dstVPort;
prudpStreamSettings m_streamSettings;
std::vector<PacketWithAckRequired> m_dataPacketsWithAckReq;
std::vector<std::unique_ptr<prudpIncomingPacket>> m_incomingPacketQueue;
uint16 outgoingSequenceId;
uint16 incomingSequenceId;
// connection handshake state
bool m_hasSynAck{false};
bool m_hasConAck{false};
uint32 m_lastHandshakeTimestamp{0};
uint8 m_handshakeRetryCount{0};
// connection
ConnectionState m_currentConnectionState;
uint32 m_serverConnectionSignature;
uint32 m_clientConnectionSignature;
uint32 m_lastPingTimestamp;
uint16 m_outgoingReliableSequenceId{2}; // 1 is reserved for CON
uint16 m_incomingSequenceId;
uint16 m_outgoingSequenceId_ping{0};
uint8 m_unacknowledgedPingCount{0};
uint8 clientSessionId;
uint8 serverSessionId;
uint8 m_clientSessionId;
uint8 m_serverSessionId;
// secure
bool isSecureConnection;
authServerInfo_t authInfo;
bool m_isSecureConnection{false};
prudpAuthServerInfo m_authInfo;
// socket
SOCKET socketUdp;
SOCKET m_socketUdp;
};
uint32 prudpGetMSTimestamp();