diff --git a/src/httpserver.cpp b/src/httpserver.cpp index b3bd56d7a5c..3c374603b34 100644 --- a/src/httpserver.cpp +++ b/src/httpserver.cpp @@ -999,6 +999,8 @@ void HTTPRequest::WriteReply(HTTPStatusCode status, std::span r res.m_keep_alive = false; } + m_client->m_keep_alive = res.m_keep_alive; + // Serialize the response headers const std::string headers{res.StringifyHeaders()}; const auto headers_bytes{std::as_bytes(std::span(headers.begin(), headers.end()))}; @@ -1062,7 +1064,9 @@ bool HTTPClient::SendBytesFromBuffer() m_origin, m_node_id, err); - // TODO: disconnect + m_send_ready = false; + m_prevent_disconnect = false; + m_disconnect = true; return false; } @@ -1075,11 +1079,47 @@ bool HTTPClient::SendBytesFromBuffer() bytes_sent, m_origin, m_node_id); + + // This check is inside the if(!empty) block meaning "there was data but now its gone". + // We shouldn't even be calling SendBytesFromBuffer() when the send buffer is empty, + // but for belt-and-suspenders, we don't want to modify the disconnect flags if SendBytesFromBuffer() was a no-op. + if (m_send_buffer.empty()) { + m_send_ready = false; + m_prevent_disconnect = false; + + // Our work is done here + if (!m_keep_alive) { + m_disconnect = true; + return false; + } + } } return true; } +void HTTPServer::CloseConnectionInternal(std::shared_ptr& client) +{ + if (CloseConnection(client->m_node_id)) { + LogDebug(BCLog::HTTP, "Disconnected HTTP client %s (id=%d)\n", client->m_origin, client->m_node_id); + } else { + LogDebug(BCLog::HTTP, "Failed to disconnect non-existent HTTP client %s (id=%d)\n", client->m_origin, client->m_node_id); + } +} + +void HTTPServer::DisconnectClients() +{ + for (auto it = m_connected_clients.begin(); it != m_connected_clients.end();) { + if ((it->second->m_disconnect || m_disconnect_all_clients) && !it->second->m_prevent_disconnect) { + CloseConnectionInternal(it->second); + it = m_connected_clients.erase(it); + } else { + ++it; + } + } + m_no_clients = m_connected_clients.size() == 0; +} + bool HTTPServer::EventNewConnectionAccepted(NodeId node_id, const CService& me, const CService& them) @@ -1118,6 +1158,9 @@ void HTTPServer::EventGotData(NodeId node_id, std::span data) return; } + // Prevent disconnect until all requests are completely handled. + client->m_prevent_disconnect = true; + // Copy data from socket buffer to client receive buffer client->m_recv_buffer.insert( client->m_recv_buffer.end(), @@ -1141,8 +1184,8 @@ void HTTPServer::EventGotData(NodeId node_id, std::span data) e.what()); // We failed to read a complete request from the buffer - // TODO: respond with HTTP_BAD_REQUEST and disconnect - + req->WriteReply(HTTP_BAD_REQUEST); + client->m_disconnect = true; break; } @@ -1160,6 +1203,33 @@ void HTTPServer::EventGotData(NodeId node_id, std::span data) } } +void HTTPServer::EventGotEOF(NodeId node_id) +{ + // Get the HTTPClient + auto client{GetClientById(node_id)}; + if (client == nullptr) { + return; + } + + client->m_disconnect = true; +} + +void HTTPServer::EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) +{ + // Get the HTTPClient + auto client{GetClientById(node_id)}; + if (client == nullptr) { + return; + } + + client->m_disconnect = true; +} + +void HTTPServer::EventIOLoopCompletedForAll() +{ + DisconnectClients(); +} + bool HTTPServer::ShouldTryToSend(NodeId node_id) const { // Get the HTTPClient diff --git a/src/httpserver.h b/src/httpserver.h index 7c1c21156aa..9ec99a23b79 100644 --- a/src/httpserver.h +++ b/src/httpserver.h @@ -305,6 +305,17 @@ public: // Checked in the Sockman I/O loop to avoid locking m_send_mutex if there's nothing to send. std::atomic_bool m_send_ready{false}; + // Set to true when we receive request data and set to false once m_send_buffer is cleared. + // Checked during DisconnectClients(). All of these operations take place in the Sockman I/O loop. + bool m_prevent_disconnect{false}; + + // Client request to keep connection open after all requests have been responded to. + // Set by (potentially multiple) worker threads and checked in the Sockman I/O loop. + std::atomic_bool m_keep_alive{false}; + + // Flag this client for disconnection on next loop + bool m_disconnect{false}; + explicit HTTPClient(NodeId node_id, CService addr) : m_node_id(node_id), m_addr(addr) { m_origin = addr.ToStringAddrPort(); @@ -325,6 +336,9 @@ public: class HTTPServer : public SockMan { +private: + void CloseConnectionInternal(std::shared_ptr& client); + public: explicit HTTPServer(std::function)> func) : m_request_dispatcher(func) {}; @@ -339,6 +353,13 @@ public: std::shared_ptr GetClientById(NodeId node_id) const; + // Close underlying connections where flagged + void DisconnectClients(); + + // Flag used during shutdown to bypass keep-alive flag. + // Set by main thread and read by Sockman I/O thread + std::atomic_bool m_disconnect_all_clients{false}; + /** * Be notified when a new connection has been accepted. * @param[in] node_id Id of the newly accepted connection. @@ -373,14 +394,22 @@ public: * makes sense at the application level. * @param[in] node_id Node whose socket got EOF. */ - virtual void EventGotEOF(NodeId node_id) override {}; + virtual void EventGotEOF(NodeId node_id) override; /** * Called when we get an irrecoverable error trying to read from a socket. * @param[in] node_id Node whose socket got an error. * @param[in] errmsg Message describing the error. */ - virtual void EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) override {}; + virtual void EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) override; + + /** + * SockMan has completed send+recv for all nodes. + * Can be used to execute periodic tasks for all nodes, like disconnecting + * nodes due to higher level logic. + * The implementation in SockMan does nothing. + */ + virtual void EventIOLoopCompletedForAll() override; /** * Can be used to temporarily pause sends on a connection. diff --git a/src/test/httpserver_tests.cpp b/src/test/httpserver_tests.cpp index de8739b9a82..14804e12dd4 100644 --- a/src/test/httpserver_tests.cpp +++ b/src/test/httpserver_tests.cpp @@ -385,6 +385,17 @@ BOOST_AUTO_TEST_CASE(http_client_server_tests) } BOOST_CHECK_EQUAL(actual, expected); + // Wait up to one minute for connection to be closed + attempts = 6000; + while (attempts > 0) + { + if (server.m_no_clients) break; + + std::this_thread::sleep_for(10ms); + --attempts; + } + BOOST_REQUIRE(server.m_no_clients); + // Close server server.interruptNet(); // Wait for I/O loop to finish, after all sockets are closed