/* * Copyright 2013-present Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include using folly::ShutdownSocketSet; namespace fsp = folly::portability::sockets; namespace folly { namespace test { ShutdownSocketSet shutdownSocketSet; class Server { public: Server(); void stop(bool abortive); void join(); int port() const { return port_; } int closeClients(bool abortive); private: int acceptSocket_; int port_; enum StopMode { NO_STOP, ORDERLY, ABORTIVE }; std::atomic stop_; std::thread serverThread_; std::vector fds_; }; Server::Server() : acceptSocket_(-1), port_(0), stop_(NO_STOP) { acceptSocket_ = fsp::socket(PF_INET, SOCK_STREAM, 0); CHECK_ERR(acceptSocket_); shutdownSocketSet.add(acceptSocket_); sockaddr_in addr; addr.sin_family = AF_INET; addr.sin_port = 0; addr.sin_addr.s_addr = INADDR_ANY; CHECK_ERR(bind( acceptSocket_, reinterpret_cast(&addr), sizeof(addr))); CHECK_ERR(listen(acceptSocket_, 10)); socklen_t addrLen = sizeof(addr); CHECK_ERR( getsockname(acceptSocket_, reinterpret_cast(&addr), &addrLen)); port_ = ntohs(addr.sin_port); serverThread_ = std::thread([this] { while (stop_ == NO_STOP) { sockaddr_in peer; socklen_t peerLen = sizeof(peer); int fd = accept(acceptSocket_, reinterpret_cast(&peer), &peerLen); if (fd == -1) { if (errno == EINTR) { continue; } if (errno == EINVAL || errno == ENOTSOCK) { // socket broken break; } } CHECK_ERR(fd); shutdownSocketSet.add(fd); fds_.push_back(fd); } if (stop_ != NO_STOP) { closeClients(stop_ == ABORTIVE); } shutdownSocketSet.close(acceptSocket_); acceptSocket_ = -1; port_ = 0; }); } int Server::closeClients(bool abortive) { for (int fd : fds_) { if (abortive) { struct linger l = {1, 0}; CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l))); } shutdownSocketSet.close(fd); } int n = fds_.size(); fds_.clear(); return n; } void Server::stop(bool abortive) { stop_ = abortive ? ABORTIVE : ORDERLY; shutdown(acceptSocket_, SHUT_RDWR); } void Server::join() { serverThread_.join(); } int createConnectedSocket(int port) { int sock = fsp::socket(PF_INET, SOCK_STREAM, 0); CHECK_ERR(sock); sockaddr_in addr; addr.sin_family = AF_INET; addr.sin_port = htons(port); addr.sin_addr.s_addr = htonl((127 << 24) | 1); // XXX CHECK_ERR( connect(sock, reinterpret_cast(&addr), sizeof(addr))); return sock; } void runCloseTest(bool abortive) { Server server; int sock = createConnectedSocket(server.port()); std::thread stopper([&server, abortive] { std::this_thread::sleep_for(std::chrono::milliseconds(200)); server.stop(abortive); server.join(); }); char c; int r = read(sock, &c, 1); if (abortive) { int e = errno; EXPECT_EQ(-1, r); EXPECT_EQ(ECONNRESET, e); } else { EXPECT_EQ(0, r); } close(sock); stopper.join(); EXPECT_EQ(0, server.closeClients(false)); // closed by server when it exited } TEST(ShutdownSocketSetTest, OrderlyClose) { runCloseTest(false); } TEST(ShutdownSocketSetTest, AbortiveClose) { runCloseTest(true); } void runKillTest(bool abortive) { Server server; int sock = createConnectedSocket(server.port()); std::thread killer([&server, abortive] { std::this_thread::sleep_for(std::chrono::milliseconds(200)); shutdownSocketSet.shutdownAll(abortive); server.join(); }); char c; int r = read(sock, &c, 1); // "abortive" is just a hint for ShutdownSocketSet, so accept both // behaviors if (abortive) { if (r == -1) { EXPECT_EQ(ECONNRESET, errno); } else { EXPECT_EQ(r, 0); } } else { EXPECT_EQ(0, r); } close(sock); killer.join(); // NOT closed by server when it exited EXPECT_EQ(1, server.closeClients(false)); } TEST(ShutdownSocketSetTest, OrderlyKill) { runKillTest(false); } TEST(ShutdownSocketSetTest, AbortiveKill) { runKillTest(true); } } // namespace test } // namespace folly