diff --git a/src/common/netif.cpp b/src/common/netif.cpp index 7424f977c7e..d604d1e853b 100644 --- a/src/common/netif.cpp +++ b/src/common/netif.cpp @@ -36,6 +36,9 @@ namespace { // will fail, so we skip that. #if defined(__linux__) || (defined(__FreeBSD__) && __FreeBSD_version >= 1400000) +// Good for responses containing ~ 10,000-15,000 routes. +static constexpr ssize_t NETLINK_MAX_RESPONSE_SIZE{1'048'576}; + std::optional QueryDefaultGatewayImpl(sa_family_t family) { // Create a netlink socket. @@ -84,41 +87,78 @@ std::optional QueryDefaultGatewayImpl(sa_family_t family) // Receive response. char response[4096]; - int64_t recv_result; - do { - recv_result = sock->Recv(response, sizeof(response), 0); - } while (recv_result < 0 && (errno == EINTR || errno == EAGAIN)); - if (recv_result < 0) { - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "recv() from netlink socket: %s\n", NetworkErrorString(errno)); - return std::nullopt; - } + ssize_t total_bytes_read{0}; + bool done{false}; + bool multi_part{false}; + while (!done) { + int64_t recv_result; + do { + recv_result = sock->Recv(response, sizeof(response), 0); + } while (recv_result < 0 && (errno == EINTR || errno == EAGAIN)); + if (recv_result < 0) { + LogPrintLevel(BCLog::NET, BCLog::Level::Error, "recv() from netlink socket: %s\n", NetworkErrorString(errno)); + return std::nullopt; + } - for (nlmsghdr* hdr = (nlmsghdr*)response; NLMSG_OK(hdr, recv_result); hdr = NLMSG_NEXT(hdr, recv_result)) { - rtmsg* r = (rtmsg*)NLMSG_DATA(hdr); - int remaining_len = RTM_PAYLOAD(hdr); + total_bytes_read += recv_result; + if (total_bytes_read > NETLINK_MAX_RESPONSE_SIZE) { + LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "Netlink response exceeded size limit (%zu bytes, family=%d)\n", NETLINK_MAX_RESPONSE_SIZE, family); + return std::nullopt; + } - // Iterate over the attributes. - rtattr *rta_gateway = nullptr; - int scope_id = 0; - for (rtattr* attr = RTM_RTA(r); RTA_OK(attr, remaining_len); attr = RTA_NEXT(attr, remaining_len)) { - if (attr->rta_type == RTA_GATEWAY) { - rta_gateway = attr; - } else if (attr->rta_type == RTA_OIF && sizeof(int) == RTA_PAYLOAD(attr)) { - std::memcpy(&scope_id, RTA_DATA(attr), sizeof(scope_id)); + bool processed_one{false}; + for (nlmsghdr* hdr = (nlmsghdr*)response; NLMSG_OK(hdr, recv_result); hdr = NLMSG_NEXT(hdr, recv_result)) { + rtmsg* r = (rtmsg*)NLMSG_DATA(hdr); + int remaining_len = RTM_PAYLOAD(hdr); + + processed_one = true; + if (hdr->nlmsg_flags & NLM_F_MULTI) { + multi_part = true; + } + + if (hdr->nlmsg_type == NLMSG_DONE) { + done = true; + break; + } + + if (hdr->nlmsg_type != RTM_NEWROUTE) { + continue; // Skip non-route messages + } + + // Only consider default routes (destination prefix length of 0). + if (r->rtm_dst_len != 0) { + continue; + } + + // Iterate over the attributes. + rtattr* rta_gateway = nullptr; + int scope_id = 0; + for (rtattr* attr = RTM_RTA(r); RTA_OK(attr, remaining_len); attr = RTA_NEXT(attr, remaining_len)) { + if (attr->rta_type == RTA_GATEWAY) { + rta_gateway = attr; + } else if (attr->rta_type == RTA_OIF && sizeof(int) == RTA_PAYLOAD(attr)) { + std::memcpy(&scope_id, RTA_DATA(attr), sizeof(scope_id)); + } + } + + // Found gateway? + if (rta_gateway != nullptr) { + if (family == AF_INET && sizeof(in_addr) == RTA_PAYLOAD(rta_gateway)) { + in_addr gw; + std::memcpy(&gw, RTA_DATA(rta_gateway), sizeof(gw)); + return CNetAddr(gw); + } else if (family == AF_INET6 && sizeof(in6_addr) == RTA_PAYLOAD(rta_gateway)) { + in6_addr gw; + std::memcpy(&gw, RTA_DATA(rta_gateway), sizeof(gw)); + return CNetAddr(gw, scope_id); + } } } - // Found gateway? - if (rta_gateway != nullptr) { - if (family == AF_INET && sizeof(in_addr) == RTA_PAYLOAD(rta_gateway)) { - in_addr gw; - std::memcpy(&gw, RTA_DATA(rta_gateway), sizeof(gw)); - return CNetAddr(gw); - } else if (family == AF_INET6 && sizeof(in6_addr) == RTA_PAYLOAD(rta_gateway)) { - in6_addr gw; - std::memcpy(&gw, RTA_DATA(rta_gateway), sizeof(gw)); - return CNetAddr(gw, scope_id); - } + // If we processed at least one message and multi flag not set, or if + // we received no valid messages, then we're done. + if ((processed_one && !multi_part) || !processed_one) { + done = true; } }