mirror of
https://github.com/bitcoin/bitcoin.git
synced 2025-04-29 23:09:44 -04:00
refactor: merge transport serializer and deserializer into Transport class
This allows state that is shared between both directions to be encapsulated into a single object. Specifically the v2 transport protocol introduced by BIP324 has sending state (the encryption keys) that depends on received messages (the DH key exchange). Having a single object for both means it can hide logic from callers related to that key exchange and other interactions.
This commit is contained in:
parent
23f3f402fc
commit
93594e42c3
4 changed files with 37 additions and 42 deletions
21
src/net.cpp
21
src/net.cpp
|
@ -681,16 +681,16 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
|
||||||
nRecvBytes += msg_bytes.size();
|
nRecvBytes += msg_bytes.size();
|
||||||
while (msg_bytes.size() > 0) {
|
while (msg_bytes.size() > 0) {
|
||||||
// absorb network data
|
// absorb network data
|
||||||
int handled = m_deserializer->Read(msg_bytes);
|
int handled = m_transport->Read(msg_bytes);
|
||||||
if (handled < 0) {
|
if (handled < 0) {
|
||||||
// Serious header problem, disconnect from the peer.
|
// Serious header problem, disconnect from the peer.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m_deserializer->Complete()) {
|
if (m_transport->Complete()) {
|
||||||
// decompose a transport agnostic CNetMessage from the deserializer
|
// decompose a transport agnostic CNetMessage from the deserializer
|
||||||
bool reject_message{false};
|
bool reject_message{false};
|
||||||
CNetMessage msg = m_deserializer->GetMessage(time, reject_message);
|
CNetMessage msg = m_transport->GetMessage(time, reject_message);
|
||||||
if (reject_message) {
|
if (reject_message) {
|
||||||
// Message deserialization failed. Drop the message but don't disconnect the peer.
|
// Message deserialization failed. Drop the message but don't disconnect the peer.
|
||||||
// store the size of the corrupt message
|
// store the size of the corrupt message
|
||||||
|
@ -717,7 +717,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes)
|
int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
|
||||||
{
|
{
|
||||||
// copy data to temporary parsing buffer
|
// copy data to temporary parsing buffer
|
||||||
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
|
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
|
||||||
|
@ -757,7 +757,7 @@ int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes)
|
||||||
return nCopy;
|
return nCopy;
|
||||||
}
|
}
|
||||||
|
|
||||||
int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes)
|
int V1Transport::readData(Span<const uint8_t> msg_bytes)
|
||||||
{
|
{
|
||||||
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
|
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
|
||||||
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
|
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
|
||||||
|
@ -774,7 +774,7 @@ int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes)
|
||||||
return nCopy;
|
return nCopy;
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint256& V1TransportDeserializer::GetMessageHash() const
|
const uint256& V1Transport::GetMessageHash() const
|
||||||
{
|
{
|
||||||
assert(Complete());
|
assert(Complete());
|
||||||
if (data_hash.IsNull())
|
if (data_hash.IsNull())
|
||||||
|
@ -782,7 +782,7 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
|
||||||
return data_hash;
|
return data_hash;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message)
|
CNetMessage V1Transport::GetMessage(const std::chrono::microseconds time, bool& reject_message)
|
||||||
{
|
{
|
||||||
// Initialize out parameter
|
// Initialize out parameter
|
||||||
reject_message = false;
|
reject_message = false;
|
||||||
|
@ -819,7 +819,7 @@ CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
|
|
||||||
void V1TransportSerializer::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const
|
void V1Transport::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const
|
||||||
{
|
{
|
||||||
// create dbl-sha256 checksum
|
// create dbl-sha256 checksum
|
||||||
uint256 hash = Hash(msg.data);
|
uint256 hash = Hash(msg.data);
|
||||||
|
@ -2822,8 +2822,7 @@ CNode::CNode(NodeId idIn,
|
||||||
ConnectionType conn_type_in,
|
ConnectionType conn_type_in,
|
||||||
bool inbound_onion,
|
bool inbound_onion,
|
||||||
CNodeOptions&& node_opts)
|
CNodeOptions&& node_opts)
|
||||||
: m_deserializer{std::make_unique<V1TransportDeserializer>(V1TransportDeserializer(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION))},
|
: m_transport{std::make_unique<V1Transport>(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION)},
|
||||||
m_serializer{std::make_unique<V1TransportSerializer>(V1TransportSerializer())},
|
|
||||||
m_permission_flags{node_opts.permission_flags},
|
m_permission_flags{node_opts.permission_flags},
|
||||||
m_sock{sock},
|
m_sock{sock},
|
||||||
m_connected{GetTime<std::chrono::seconds>()},
|
m_connected{GetTime<std::chrono::seconds>()},
|
||||||
|
@ -2908,7 +2907,7 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg)
|
||||||
|
|
||||||
// make sure we use the appropriate network transport format
|
// make sure we use the appropriate network transport format
|
||||||
std::vector<unsigned char> serializedHeader;
|
std::vector<unsigned char> serializedHeader;
|
||||||
pnode->m_serializer->prepareForTransport(msg, serializedHeader);
|
pnode->m_transport->prepareForTransport(msg, serializedHeader);
|
||||||
size_t nTotalSize = nMessageSize + serializedHeader.size();
|
size_t nTotalSize = nMessageSize + serializedHeader.size();
|
||||||
|
|
||||||
size_t nBytesSent = 0;
|
size_t nBytesSent = 0;
|
||||||
|
|
41
src/net.h
41
src/net.h
|
@ -253,24 +253,31 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/** The TransportDeserializer takes care of holding and deserializing the
|
/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */
|
||||||
* network receive buffer. It can deserialize the network buffer into a
|
class Transport {
|
||||||
* transport protocol agnostic CNetMessage (message type & payload)
|
|
||||||
*/
|
|
||||||
class TransportDeserializer {
|
|
||||||
public:
|
public:
|
||||||
|
virtual ~Transport() {}
|
||||||
|
|
||||||
|
// 1. Receiver side functions, for decoding bytes received on the wire into transport protocol
|
||||||
|
// agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of
|
||||||
|
// these functions are called concurrently w.r.t. one another.
|
||||||
|
|
||||||
// returns true if the current deserialization is complete
|
// returns true if the current deserialization is complete
|
||||||
virtual bool Complete() const = 0;
|
virtual bool Complete() const = 0;
|
||||||
// set the serialization context version
|
// set the deserialization context version
|
||||||
virtual void SetVersion(int version) = 0;
|
virtual void SetVersion(int version) = 0;
|
||||||
/** read and deserialize data, advances msg_bytes data pointer */
|
/** read and deserialize data, advances msg_bytes data pointer */
|
||||||
virtual int Read(Span<const uint8_t>& msg_bytes) = 0;
|
virtual int Read(Span<const uint8_t>& msg_bytes) = 0;
|
||||||
// decomposes a message from the context
|
// decomposes a message from the context
|
||||||
virtual CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) = 0;
|
virtual CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) = 0;
|
||||||
virtual ~TransportDeserializer() {}
|
|
||||||
|
// 2. Sending side functions:
|
||||||
|
|
||||||
|
// prepare message for transport (header construction, error-correction computation, payload encryption, etc.)
|
||||||
|
virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class V1TransportDeserializer final : public TransportDeserializer
|
class V1Transport final : public Transport
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
const CChainParams& m_chain_params;
|
const CChainParams& m_chain_params;
|
||||||
|
@ -300,7 +307,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
|
V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
|
||||||
: m_chain_params(chain_params),
|
: m_chain_params(chain_params),
|
||||||
m_node_id(node_id),
|
m_node_id(node_id),
|
||||||
hdrbuf(nTypeIn, nVersionIn),
|
hdrbuf(nTypeIn, nVersionIn),
|
||||||
|
@ -331,19 +338,7 @@ public:
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override;
|
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override;
|
||||||
};
|
|
||||||
|
|
||||||
/** The TransportSerializer prepares messages for the network transport
|
|
||||||
*/
|
|
||||||
class TransportSerializer {
|
|
||||||
public:
|
|
||||||
// prepare message for transport (header construction, error-correction computation, payload encryption, etc.)
|
|
||||||
virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const = 0;
|
|
||||||
virtual ~TransportSerializer() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
class V1TransportSerializer : public TransportSerializer {
|
|
||||||
public:
|
|
||||||
void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override;
|
void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -359,8 +354,8 @@ struct CNodeOptions
|
||||||
class CNode
|
class CNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
const std::unique_ptr<TransportDeserializer> m_deserializer; // Used only by SocketHandler thread
|
/** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv. */
|
||||||
const std::unique_ptr<const TransportSerializer> m_serializer;
|
const std::unique_ptr<Transport> m_transport;
|
||||||
|
|
||||||
const NetPermissionFlags m_permission_flags;
|
const NetPermissionFlags m_permission_flags;
|
||||||
|
|
||||||
|
|
|
@ -24,9 +24,10 @@ void initialize_p2p_transport_serialization()
|
||||||
|
|
||||||
FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serialization)
|
FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serialization)
|
||||||
{
|
{
|
||||||
// Construct deserializer, with a dummy NodeId
|
// Construct transports for both sides, with dummy NodeIds.
|
||||||
V1TransportDeserializer deserializer{Params(), NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION};
|
V1Transport recv_transport{Params(), NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION};
|
||||||
V1TransportSerializer serializer{};
|
V1Transport send_transport{Params(), NodeId{1}, SER_NETWORK, INIT_PROTO_VERSION};
|
||||||
|
|
||||||
FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()};
|
FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()};
|
||||||
|
|
||||||
auto checksum_assist = fuzzed_data_provider.ConsumeBool();
|
auto checksum_assist = fuzzed_data_provider.ConsumeBool();
|
||||||
|
@ -63,14 +64,14 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial
|
||||||
mutable_msg_bytes.insert(mutable_msg_bytes.end(), payload_bytes.begin(), payload_bytes.end());
|
mutable_msg_bytes.insert(mutable_msg_bytes.end(), payload_bytes.begin(), payload_bytes.end());
|
||||||
Span<const uint8_t> msg_bytes{mutable_msg_bytes};
|
Span<const uint8_t> msg_bytes{mutable_msg_bytes};
|
||||||
while (msg_bytes.size() > 0) {
|
while (msg_bytes.size() > 0) {
|
||||||
const int handled = deserializer.Read(msg_bytes);
|
const int handled = recv_transport.Read(msg_bytes);
|
||||||
if (handled < 0) {
|
if (handled < 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (deserializer.Complete()) {
|
if (recv_transport.Complete()) {
|
||||||
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
|
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
|
||||||
bool reject_message{false};
|
bool reject_message{false};
|
||||||
CNetMessage msg = deserializer.GetMessage(m_time, reject_message);
|
CNetMessage msg = recv_transport.GetMessage(m_time, reject_message);
|
||||||
assert(msg.m_type.size() <= CMessageHeader::COMMAND_SIZE);
|
assert(msg.m_type.size() <= CMessageHeader::COMMAND_SIZE);
|
||||||
assert(msg.m_raw_message_size <= mutable_msg_bytes.size());
|
assert(msg.m_raw_message_size <= mutable_msg_bytes.size());
|
||||||
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
|
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
|
||||||
|
@ -78,7 +79,7 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial
|
||||||
|
|
||||||
std::vector<unsigned char> header;
|
std::vector<unsigned char> header;
|
||||||
auto msg2 = CNetMsgMaker{msg.m_recv.GetVersion()}.Make(msg.m_type, Span{msg.m_recv});
|
auto msg2 = CNetMsgMaker{msg.m_recv.GetVersion()}.Make(msg.m_type, Span{msg.m_recv});
|
||||||
serializer.prepareForTransport(msg2, header);
|
send_transport.prepareForTransport(msg2, header);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,7 +73,7 @@ void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span<const uint8_t> msg_by
|
||||||
bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const
|
bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> ser_msg_header;
|
std::vector<uint8_t> ser_msg_header;
|
||||||
node.m_serializer->prepareForTransport(ser_msg, ser_msg_header);
|
node.m_transport->prepareForTransport(ser_msg, ser_msg_header);
|
||||||
|
|
||||||
bool complete;
|
bool complete;
|
||||||
NodeReceiveMsgBytes(node, ser_msg_header, complete);
|
NodeReceiveMsgBytes(node, ser_msg_header, complete);
|
||||||
|
|
Loading…
Add table
Reference in a new issue