AsyncSocketTest.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. /*
  2. * Copyright 2015-present Facebook, Inc.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #pragma once
#include <folly/ScopeGuard.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <folly/portability/Sockets.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <new>
#include <string>
#include <vector>
  21. enum StateEnum { STATE_WAITING, STATE_SUCCEEDED, STATE_FAILED };
  22. typedef std::function<void()> VoidCallback;
  23. class ConnCallback : public folly::AsyncSocket::ConnectCallback {
  24. public:
  25. ConnCallback()
  26. : state(STATE_WAITING),
  27. exception(folly::AsyncSocketException::UNKNOWN, "none") {}
  28. void connectSuccess() noexcept override {
  29. state = STATE_SUCCEEDED;
  30. if (successCallback) {
  31. successCallback();
  32. }
  33. }
  34. void connectErr(const folly::AsyncSocketException& ex) noexcept override {
  35. state = STATE_FAILED;
  36. exception = ex;
  37. if (errorCallback) {
  38. errorCallback();
  39. }
  40. }
  41. StateEnum state;
  42. folly::AsyncSocketException exception;
  43. VoidCallback successCallback;
  44. VoidCallback errorCallback;
  45. };
  46. class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
  47. public:
  48. WriteCallback()
  49. : state(STATE_WAITING),
  50. bytesWritten(0),
  51. exception(folly::AsyncSocketException::UNKNOWN, "none") {}
  52. void writeSuccess() noexcept override {
  53. state = STATE_SUCCEEDED;
  54. if (successCallback) {
  55. successCallback();
  56. }
  57. }
  58. void writeErr(
  59. size_t nBytesWritten,
  60. const folly::AsyncSocketException& ex) noexcept override {
  61. LOG(ERROR) << ex.what();
  62. state = STATE_FAILED;
  63. this->bytesWritten = nBytesWritten;
  64. exception = ex;
  65. if (errorCallback) {
  66. errorCallback();
  67. }
  68. }
  69. StateEnum state;
  70. std::atomic<size_t> bytesWritten;
  71. folly::AsyncSocketException exception;
  72. VoidCallback successCallback;
  73. VoidCallback errorCallback;
  74. };
  75. class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
  76. public:
  77. explicit ReadCallback(size_t _maxBufferSz = 4096)
  78. : state(STATE_WAITING),
  79. exception(folly::AsyncSocketException::UNKNOWN, "none"),
  80. buffers(),
  81. maxBufferSz(_maxBufferSz) {}
  82. ~ReadCallback() override {
  83. for (std::vector<Buffer>::iterator it = buffers.begin();
  84. it != buffers.end();
  85. ++it) {
  86. it->free();
  87. }
  88. currentBuffer.free();
  89. }
  90. void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
  91. if (!currentBuffer.buffer) {
  92. currentBuffer.allocate(maxBufferSz);
  93. }
  94. *bufReturn = currentBuffer.buffer;
  95. *lenReturn = currentBuffer.length;
  96. }
  97. void readDataAvailable(size_t len) noexcept override {
  98. currentBuffer.length = len;
  99. buffers.push_back(currentBuffer);
  100. currentBuffer.reset();
  101. if (dataAvailableCallback) {
  102. dataAvailableCallback();
  103. }
  104. }
  105. void readEOF() noexcept override {
  106. state = STATE_SUCCEEDED;
  107. }
  108. void readErr(const folly::AsyncSocketException& ex) noexcept override {
  109. state = STATE_FAILED;
  110. exception = ex;
  111. }
  112. void verifyData(const char* expected, size_t expectedLen) const {
  113. size_t offset = 0;
  114. for (size_t idx = 0; idx < buffers.size(); ++idx) {
  115. const auto& buf = buffers[idx];
  116. size_t cmpLen = std::min(buf.length, expectedLen - offset);
  117. CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
  118. CHECK_EQ(cmpLen, buf.length);
  119. offset += cmpLen;
  120. }
  121. CHECK_EQ(offset, expectedLen);
  122. }
  123. size_t dataRead() const {
  124. size_t ret = 0;
  125. for (const auto& buf : buffers) {
  126. ret += buf.length;
  127. }
  128. return ret;
  129. }
  130. class Buffer {
  131. public:
  132. Buffer() : buffer(nullptr), length(0) {}
  133. Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
  134. void reset() {
  135. buffer = nullptr;
  136. length = 0;
  137. }
  138. void allocate(size_t len) {
  139. assert(buffer == nullptr);
  140. this->buffer = static_cast<char*>(malloc(len));
  141. this->length = len;
  142. }
  143. void free() {
  144. ::free(buffer);
  145. reset();
  146. }
  147. char* buffer;
  148. size_t length;
  149. };
  150. StateEnum state;
  151. folly::AsyncSocketException exception;
  152. std::vector<Buffer> buffers;
  153. Buffer currentBuffer;
  154. VoidCallback dataAvailableCallback;
  155. const size_t maxBufferSz;
  156. };
  157. class BufferCallback : public folly::AsyncTransport::BufferCallback {
  158. public:
  159. BufferCallback() : buffered_(false), bufferCleared_(false) {}
  160. void onEgressBuffered() override {
  161. buffered_ = true;
  162. }
  163. void onEgressBufferCleared() override {
  164. bufferCleared_ = true;
  165. }
  166. bool hasBuffered() const {
  167. return buffered_;
  168. }
  169. bool hasBufferCleared() const {
  170. return bufferCleared_;
  171. }
  172. private:
  173. bool buffered_{false};
  174. bool bufferCleared_{false};
  175. };
  176. class ReadVerifier {};
  177. class TestSendMsgParamsCallback
  178. : public folly::AsyncSocket::SendMsgParamsCallback {
  179. public:
  180. TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
  181. : flags_(flags),
  182. writeFlags_(folly::WriteFlags::NONE),
  183. dataSize_(dataSize),
  184. data_(data),
  185. queriedFlags_(false),
  186. queriedData_(false) {}
  187. void reset(int flags) {
  188. flags_ = flags;
  189. writeFlags_ = folly::WriteFlags::NONE;
  190. queriedFlags_ = false;
  191. queriedData_ = false;
  192. }
  193. int getFlagsImpl(
  194. folly::WriteFlags flags,
  195. int /*defaultFlags*/) noexcept override {
  196. queriedFlags_ = true;
  197. if (writeFlags_ == folly::WriteFlags::NONE) {
  198. writeFlags_ = flags;
  199. } else {
  200. assert(flags == writeFlags_);
  201. }
  202. return flags_;
  203. }
  204. void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
  205. queriedData_ = true;
  206. if (writeFlags_ == folly::WriteFlags::NONE) {
  207. writeFlags_ = flags;
  208. } else {
  209. assert(flags == writeFlags_);
  210. }
  211. assert(data != nullptr);
  212. memcpy(data, data_, dataSize_);
  213. }
  214. uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
  215. if (writeFlags_ == folly::WriteFlags::NONE) {
  216. writeFlags_ = flags;
  217. } else {
  218. assert(flags == writeFlags_);
  219. }
  220. return dataSize_;
  221. }
  222. int flags_;
  223. folly::WriteFlags writeFlags_;
  224. uint32_t dataSize_;
  225. void* data_;
  226. bool queriedFlags_;
  227. bool queriedData_;
  228. };
  229. class TestServer {
  230. public:
  231. // Create a TestServer.
  232. // This immediately starts listening on an ephemeral port.
  233. explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
  234. namespace fsp = folly::portability::sockets;
  235. fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
  236. if (fd_ < 0) {
  237. throw folly::AsyncSocketException(
  238. folly::AsyncSocketException::INTERNAL_ERROR,
  239. "failed to create test server socket",
  240. errno);
  241. }
  242. if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
  243. throw folly::AsyncSocketException(
  244. folly::AsyncSocketException::INTERNAL_ERROR,
  245. "failed to put test server socket in "
  246. "non-blocking mode",
  247. errno);
  248. }
  249. if (enableTFO) {
  250. #if FOLLY_ALLOW_TFO
  251. folly::detail::tfo_enable(fd_, 100);
  252. #endif
  253. }
  254. struct addrinfo hints, *res;
  255. memset(&hints, 0, sizeof(hints));
  256. hints.ai_family = AF_INET;
  257. hints.ai_socktype = SOCK_STREAM;
  258. hints.ai_flags = AI_PASSIVE;
  259. if (getaddrinfo(nullptr, "0", &hints, &res)) {
  260. throw folly::AsyncSocketException(
  261. folly::AsyncSocketException::INTERNAL_ERROR,
  262. "Attempted to bind address to socket with "
  263. "bad getaddrinfo",
  264. errno);
  265. }
  266. SCOPE_EXIT {
  267. freeaddrinfo(res);
  268. };
  269. if (bufSize > 0) {
  270. setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
  271. setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
  272. }
  273. if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
  274. throw folly::AsyncSocketException(
  275. folly::AsyncSocketException::INTERNAL_ERROR,
  276. "failed to bind to async server socket for port 10",
  277. errno);
  278. }
  279. if (listen(fd_, 10) != 0) {
  280. throw folly::AsyncSocketException(
  281. folly::AsyncSocketException::INTERNAL_ERROR,
  282. "failed to listen on test server socket",
  283. errno);
  284. }
  285. address_.setFromLocalAddress(fd_);
  286. // The local address will contain 0.0.0.0.
  287. // Change it to 127.0.0.1, so it can be used to connect to the server
  288. address_.setFromIpPort("127.0.0.1", address_.getPort());
  289. }
  290. ~TestServer() {
  291. if (fd_ != -1) {
  292. close(fd_);
  293. }
  294. }
  295. // Get the address for connecting to the server
  296. const folly::SocketAddress& getAddress() const {
  297. return address_;
  298. }
  299. int acceptFD(int timeout = 50) {
  300. namespace fsp = folly::portability::sockets;
  301. struct pollfd pfd;
  302. pfd.fd = fd_;
  303. pfd.events = POLLIN;
  304. int ret = poll(&pfd, 1, timeout);
  305. if (ret == 0) {
  306. throw folly::AsyncSocketException(
  307. folly::AsyncSocketException::INTERNAL_ERROR,
  308. "test server accept() timed out");
  309. } else if (ret < 0) {
  310. throw folly::AsyncSocketException(
  311. folly::AsyncSocketException::INTERNAL_ERROR,
  312. "test server accept() poll failed",
  313. errno);
  314. }
  315. int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
  316. if (acceptedFd < 0) {
  317. throw folly::AsyncSocketException(
  318. folly::AsyncSocketException::INTERNAL_ERROR,
  319. "test server accept() failed",
  320. errno);
  321. }
  322. return acceptedFd;
  323. }
  324. std::shared_ptr<BlockingSocket> accept(int timeout = 50) {
  325. int fd = acceptFD(timeout);
  326. return std::make_shared<BlockingSocket>(fd);
  327. }
  328. std::shared_ptr<folly::AsyncSocket> acceptAsync(
  329. folly::EventBase* evb,
  330. int timeout = 50) {
  331. int fd = acceptFD(timeout);
  332. return folly::AsyncSocket::newSocket(evb, fd);
  333. }
  334. /**
  335. * Accept a connection, read data from it, and verify that it matches the
  336. * data in the specified buffer.
  337. */
  338. void verifyConnection(const char* buf, size_t len) {
  339. // accept a connection
  340. std::shared_ptr<BlockingSocket> acceptedSocket = accept();
  341. // read the data and compare it to the specified buffer
  342. std::unique_ptr<uint8_t[]> readbuf(new uint8_t[len]);
  343. acceptedSocket->readAll(readbuf.get(), len);
  344. CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
  345. // make sure we get EOF next
  346. uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
  347. CHECK_EQ(bytesRead, 0);
  348. }
  349. private:
  350. int fd_;
  351. folly::SocketAddress address_;
  352. };