diff --git a/src/net.cpp b/src/net.cpp index dbdbbf9d4e9..25f390961b9 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -913,6 +913,74 @@ size_t V1Transport::GetSendMemoryUsage() const noexcept return m_message_to_send.GetMemoryUsage(); } +namespace { + +/** List of short messages as defined in BIP324, in order. + * + * Only message types that are actually implemented in this codebase need to be listed, as other + * messages get ignored anyway - whether we know how to decode them or not. + */ +const std::array V2_MESSAGE_IDS = { + "", // 12 bytes follow encoding the message type like in V1 + NetMsgType::ADDR, + NetMsgType::BLOCK, + NetMsgType::BLOCKTXN, + NetMsgType::CMPCTBLOCK, + NetMsgType::FEEFILTER, + NetMsgType::FILTERADD, + NetMsgType::FILTERCLEAR, + NetMsgType::FILTERLOAD, + NetMsgType::GETBLOCKS, + NetMsgType::GETBLOCKTXN, + NetMsgType::GETDATA, + NetMsgType::GETHEADERS, + NetMsgType::HEADERS, + NetMsgType::INV, + NetMsgType::MEMPOOL, + NetMsgType::MERKLEBLOCK, + NetMsgType::NOTFOUND, + NetMsgType::PING, + NetMsgType::PONG, + NetMsgType::SENDCMPCT, + NetMsgType::TX, + NetMsgType::GETCFILTERS, + NetMsgType::CFILTER, + NetMsgType::GETCFHEADERS, + NetMsgType::CFHEADERS, + NetMsgType::GETCFCHECKPT, + NetMsgType::CFCHECKPT, + NetMsgType::ADDRV2, + // Unimplemented message types that are assigned in BIP324: + "", + "", + "", + "" +}; + +class V2MessageMap +{ + std::unordered_map m_map; + +public: + V2MessageMap() noexcept + { + for (size_t i = 1; i < std::size(V2_MESSAGE_IDS); ++i) { + m_map.emplace(V2_MESSAGE_IDS[i], i); + } + } + + std::optional operator()(const std::string& message_name) const noexcept + { + auto it = m_map.find(message_name); + if (it == m_map.end()) return std::nullopt; + return it->second; + } +}; + +const V2MessageMap V2_MESSAGE_MAP; + +} // namespace + V2Transport::V2Transport(NodeId nodeid, bool initiating, int type_in, int version_in) noexcept : m_cipher{}, m_initiating{initiating}, m_nodeid{nodeid}, m_v1_fallback{nodeid, type_in, version_in}, m_recv_type{type_in}, m_recv_version{version_in}, @@ -1299,7 +1367,16 @@ std::optional V2Transport::GetMessageType(Span& cont uint8_t first_byte = contents[0]; contents = contents.subspan(1); // Strip first byte. - if (first_byte != 0) return std::nullopt; // TODO: implement short encoding + if (first_byte != 0) { + // Short (1 byte) encoding. + if (first_byte < std::size(V2_MESSAGE_IDS)) { + // Valid short message id. + return V2_MESSAGE_IDS[first_byte]; + } else { + // Unknown short message id. + return std::nullopt; + } + } if (contents.size() < CMessageHeader::COMMAND_SIZE) { return std::nullopt; // Long encoding needs 12 message type bytes. @@ -1364,11 +1441,19 @@ bool V2Transport::SetMessageToSend(CSerializedNetMsg& msg) noexcept // buffer to just one, and leaves the responsibility for queueing them up to the caller. if (!(m_send_state == SendState::READY && m_send_buffer.empty())) return false; // Construct contents (encoding message type + payload). - // Initialize with zeroes, and then write the message type string starting at offset 1. - // This means contents[0] and the unused positions in contents[1..13] remain 0x00. - std::vector contents(1 + CMessageHeader::COMMAND_SIZE + msg.data.size(), 0); - std::copy(msg.m_type.begin(), msg.m_type.end(), contents.data() + 1); - std::copy(msg.data.begin(), msg.data.end(), contents.begin() + 1 + CMessageHeader::COMMAND_SIZE); + std::vector contents; + auto short_message_id = V2_MESSAGE_MAP(msg.m_type); + if (short_message_id) { + contents.resize(1 + msg.data.size()); + contents[0] = *short_message_id; + std::copy(msg.data.begin(), msg.data.end(), contents.begin() + 1); + } else { + // Initialize with zeroes, and then write the message type string starting at offset 1. + // This means contents[0] and the unused positions in contents[1..13] remain 0x00. + contents.resize(1 + CMessageHeader::COMMAND_SIZE + msg.data.size(), 0); + std::copy(msg.m_type.begin(), msg.m_type.end(), contents.data() + 1); + std::copy(msg.data.begin(), msg.data.end(), contents.begin() + 1 + CMessageHeader::COMMAND_SIZE); + } // Construct ciphertext in send buffer. m_send_buffer.resize(contents.size() + BIP324Cipher::EXPANSION); m_cipher.Encrypt(MakeByteSpan(contents), {}, false, MakeWritableByteSpan(m_send_buffer));