AsyncSSLSocketTest.h 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531
  1. /*
  2. * Copyright 2012-present Facebook, Inc.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #pragma once
  17. #include <signal.h>
  18. #include <folly/ExceptionWrapper.h>
  19. #include <folly/SocketAddress.h>
  20. #include <folly/experimental/TestUtil.h>
  21. #include <folly/io/async/AsyncSSLSocket.h>
  22. #include <folly/io/async/AsyncServerSocket.h>
  23. #include <folly/io/async/AsyncSocket.h>
  24. #include <folly/io/async/AsyncTimeout.h>
  25. #include <folly/io/async/AsyncTransport.h>
  26. #include <folly/io/async/EventBase.h>
  27. #include <folly/io/async/ssl/SSLErrors.h>
  28. #include <folly/io/async/test/TestSSLServer.h>
  29. #include <folly/portability/GTest.h>
  30. #include <folly/portability/PThread.h>
  31. #include <folly/portability/Sockets.h>
  32. #include <folly/portability/Unistd.h>
  33. #include <fcntl.h>
  34. #include <sys/types.h>
  35. #include <condition_variable>
  36. #include <iostream>
  37. #include <list>
  38. #include <memory>
  39. namespace folly {
  40. // The destructors of all callback classes assert that the state is
  41. // STATE_SUCCEEDED, for both positive and negative tests. The tests
  42. // are responsible for setting the succeeded state properly before the
  43. // destructors are called.
  44. class SendMsgParamsCallbackBase
  45. : public folly::AsyncSocket::SendMsgParamsCallback {
  46. public:
  47. SendMsgParamsCallbackBase() {}
  48. void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) {
  49. socket_ = socket;
  50. oldCallback_ = socket_->getSendMsgParamsCB();
  51. socket_->setSendMsgParamCB(this);
  52. }
  53. int getFlagsImpl(
  54. folly::WriteFlags flags,
  55. int /*defaultFlags*/) noexcept override {
  56. return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
  57. }
  58. void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
  59. oldCallback_->getAncillaryData(flags, data);
  60. }
  61. uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
  62. return oldCallback_->getAncillaryDataSize(flags);
  63. }
  64. std::shared_ptr<AsyncSSLSocket> socket_;
  65. folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
  66. };
  67. class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
  68. public:
  69. SendMsgFlagsCallback() {}
  70. void resetFlags(int flags) {
  71. flags_ = flags;
  72. }
  73. int getFlagsImpl(
  74. folly::WriteFlags flags,
  75. int /*defaultFlags*/) noexcept override {
  76. if (flags_) {
  77. return flags_;
  78. } else {
  79. return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
  80. }
  81. }
  82. int flags_{0};
  83. };
  84. class SendMsgDataCallback : public SendMsgFlagsCallback {
  85. public:
  86. SendMsgDataCallback() {}
  87. void resetData(std::vector<char>&& data) {
  88. ancillaryData_.swap(data);
  89. }
  90. void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
  91. if (ancillaryData_.size()) {
  92. std::cerr << "getAncillaryData: copying data" << std::endl;
  93. memcpy(data, ancillaryData_.data(), ancillaryData_.size());
  94. } else {
  95. oldCallback_->getAncillaryData(flags, data);
  96. }
  97. }
  98. uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
  99. if (ancillaryData_.size()) {
  100. std::cerr << "getAncillaryDataSize: returning size" << std::endl;
  101. return ancillaryData_.size();
  102. } else {
  103. return oldCallback_->getAncillaryDataSize(flags);
  104. }
  105. }
  106. std::vector<char> ancillaryData_;
  107. };
  108. class WriteCallbackBase : public AsyncTransportWrapper::WriteCallback {
  109. public:
  110. explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
  111. : state(STATE_WAITING),
  112. bytesWritten(0),
  113. exception(AsyncSocketException::UNKNOWN, "none"),
  114. mcb_(mcb) {}
  115. ~WriteCallbackBase() override {
  116. EXPECT_EQ(STATE_SUCCEEDED, state);
  117. }
  118. virtual void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) {
  119. socket_ = socket;
  120. if (mcb_) {
  121. mcb_->setSocket(socket);
  122. }
  123. }
  124. void writeSuccess() noexcept override {
  125. std::cerr << "writeSuccess" << std::endl;
  126. state = STATE_SUCCEEDED;
  127. }
  128. void writeErr(
  129. size_t nBytesWritten,
  130. const AsyncSocketException& ex) noexcept override {
  131. std::cerr << "writeError: bytesWritten " << nBytesWritten << ", exception "
  132. << ex.what() << std::endl;
  133. state = STATE_FAILED;
  134. this->bytesWritten = nBytesWritten;
  135. exception = ex;
  136. socket_->close();
  137. }
  138. std::shared_ptr<AsyncSSLSocket> socket_;
  139. StateEnum state;
  140. size_t bytesWritten;
  141. AsyncSocketException exception;
  142. SendMsgParamsCallbackBase* mcb_;
  143. };
  144. class ExpectWriteErrorCallback : public WriteCallbackBase {
  145. public:
  146. explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
  147. : WriteCallbackBase(mcb) {}
  148. ~ExpectWriteErrorCallback() override {
  149. EXPECT_EQ(STATE_FAILED, state);
  150. EXPECT_EQ(
  151. exception.getType(),
  152. AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
  153. EXPECT_EQ(exception.getErrno(), 22);
  154. // Suppress the assert in ~WriteCallbackBase()
  155. state = STATE_SUCCEEDED;
  156. }
  157. };
  158. #ifdef FOLLY_HAVE_MSG_ERRQUEUE
  159. /* copied from include/uapi/linux/net_tstamp.h */
  160. /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
  // Subset of the SO_TIMESTAMPING option bits used by these tests. Each
  // enumerator is a single bit; the values are combined with bitwise OR.
  enum SOF_TIMESTAMPING {
    SOF_TIMESTAMPING_TX_SOFTWARE = 0x002, // bit 1
    SOF_TIMESTAMPING_SOFTWARE = 0x010, // bit 4
    SOF_TIMESTAMPING_OPT_ID = 0x080, // bit 7
    SOF_TIMESTAMPING_TX_SCHED = 0x100, // bit 8
    SOF_TIMESTAMPING_TX_ACK = 0x200, // bit 9
    SOF_TIMESTAMPING_OPT_TSONLY = 0x800, // bit 11
  };
  169. class WriteCheckTimestampCallback : public WriteCallbackBase {
  170. public:
  171. explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
  172. : WriteCallbackBase(mcb) {}
  173. ~WriteCheckTimestampCallback() override {
  174. EXPECT_EQ(STATE_SUCCEEDED, state);
  175. EXPECT_TRUE(gotTimestamp_);
  176. EXPECT_TRUE(gotByteSeq_);
  177. }
  178. void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) override {
  179. WriteCallbackBase::setSocket(socket);
  180. EXPECT_NE(socket_->getFd(), 0);
  181. int flags = SOF_TIMESTAMPING_OPT_ID | SOF_TIMESTAMPING_OPT_TSONLY |
  182. SOF_TIMESTAMPING_SOFTWARE;
  183. AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
  184. int ret = tstampingOpt.apply(socket_->getFd(), flags);
  185. EXPECT_EQ(ret, 0);
  186. }
  187. void checkForTimestampNotifications() noexcept {
  188. int fd = socket_->getFd();
  189. std::vector<char> ctrl(1024, 0);
  190. unsigned char data;
  191. struct msghdr msg;
  192. iovec entry;
  193. memset(&msg, 0, sizeof(msg));
  194. entry.iov_base = &data;
  195. entry.iov_len = sizeof(data);
  196. msg.msg_iov = &entry;
  197. msg.msg_iovlen = 1;
  198. msg.msg_control = ctrl.data();
  199. msg.msg_controllen = ctrl.size();
  200. int ret;
  201. while (true) {
  202. ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
  203. if (ret < 0) {
  204. if (errno != EAGAIN) {
  205. auto errnoCopy = errno;
  206. std::cerr << "::recvmsg exited with code " << ret
  207. << ", errno: " << errnoCopy << std::endl;
  208. AsyncSocketException ex(
  209. AsyncSocketException::INTERNAL_ERROR,
  210. "recvmsg() failed",
  211. errnoCopy);
  212. exception = ex;
  213. }
  214. return;
  215. }
  216. for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
  217. cmsg != nullptr && cmsg->cmsg_len != 0;
  218. cmsg = CMSG_NXTHDR(&msg, cmsg)) {
  219. if (cmsg->cmsg_level == SOL_SOCKET &&
  220. cmsg->cmsg_type == SCM_TIMESTAMPING) {
  221. gotTimestamp_ = true;
  222. continue;
  223. }
  224. if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
  225. (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
  226. gotByteSeq_ = true;
  227. continue;
  228. }
  229. }
  230. }
  231. }
  232. bool gotTimestamp_{false};
  233. bool gotByteSeq_{false};
  234. };
  235. #endif // FOLLY_HAVE_MSG_ERRQUEUE
  236. class ReadCallbackBase : public AsyncTransportWrapper::ReadCallback {
  237. public:
  238. explicit ReadCallbackBase(WriteCallbackBase* wcb)
  239. : wcb_(wcb), state(STATE_WAITING) {}
  240. ~ReadCallbackBase() override {
  241. EXPECT_EQ(STATE_SUCCEEDED, state);
  242. }
  243. void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) {
  244. socket_ = socket;
  245. }
  246. void setState(StateEnum s) {
  247. state = s;
  248. if (wcb_) {
  249. wcb_->state = s;
  250. }
  251. }
  252. void readErr(const AsyncSocketException& ex) noexcept override {
  253. std::cerr << "readError " << ex.what() << std::endl;
  254. state = STATE_FAILED;
  255. socket_->close();
  256. }
  257. void readEOF() noexcept override {
  258. std::cerr << "readEOF" << std::endl;
  259. socket_->close();
  260. }
  261. std::shared_ptr<AsyncSSLSocket> socket_;
  262. WriteCallbackBase* wcb_;
  263. StateEnum state;
  264. };
  265. class ReadCallback : public ReadCallbackBase {
  266. public:
  267. explicit ReadCallback(WriteCallbackBase* wcb)
  268. : ReadCallbackBase(wcb), buffers() {}
  269. ~ReadCallback() override {
  270. for (std::vector<Buffer>::iterator it = buffers.begin();
  271. it != buffers.end();
  272. ++it) {
  273. it->free();
  274. }
  275. currentBuffer.free();
  276. }
  277. void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
  278. if (!currentBuffer.buffer) {
  279. currentBuffer.allocate(4096);
  280. }
  281. *bufReturn = currentBuffer.buffer;
  282. *lenReturn = currentBuffer.length;
  283. }
  284. void readDataAvailable(size_t len) noexcept override {
  285. std::cerr << "readDataAvailable, len " << len << std::endl;
  286. currentBuffer.length = len;
  287. wcb_->setSocket(socket_);
  288. // Write back the same data.
  289. socket_->write(wcb_, currentBuffer.buffer, len);
  290. buffers.push_back(currentBuffer);
  291. currentBuffer.reset();
  292. state = STATE_SUCCEEDED;
  293. }
  294. class Buffer {
  295. public:
  296. Buffer() : buffer(nullptr), length(0) {}
  297. Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
  298. void reset() {
  299. buffer = nullptr;
  300. length = 0;
  301. }
  302. void allocate(size_t len) {
  303. assert(buffer == nullptr);
  304. this->buffer = static_cast<char*>(malloc(len));
  305. this->length = len;
  306. }
  307. void free() {
  308. ::free(buffer);
  309. reset();
  310. }
  311. char* buffer;
  312. size_t length;
  313. };
  314. std::vector<Buffer> buffers;
  315. Buffer currentBuffer;
  316. };
  317. class ReadErrorCallback : public ReadCallbackBase {
  318. public:
  319. explicit ReadErrorCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
  320. // Return nullptr buffer to trigger readError()
  321. void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
  322. *bufReturn = nullptr;
  323. *lenReturn = 0;
  324. }
  325. void readDataAvailable(size_t /* len */) noexcept override {
  326. // This should never to called.
  327. FAIL();
  328. }
  329. void readErr(const AsyncSocketException& ex) noexcept override {
  330. ReadCallbackBase::readErr(ex);
  331. std::cerr << "ReadErrorCallback::readError" << std::endl;
  332. setState(STATE_SUCCEEDED);
  333. }
  334. };
  335. class ReadEOFCallback : public ReadCallbackBase {
  336. public:
  337. explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
  338. // Return nullptr buffer to trigger readError()
  339. void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
  340. *bufReturn = nullptr;
  341. *lenReturn = 0;
  342. }
  343. void readDataAvailable(size_t /* len */) noexcept override {
  344. // This should never to called.
  345. FAIL();
  346. }
  347. void readEOF() noexcept override {
  348. ReadCallbackBase::readEOF();
  349. setState(STATE_SUCCEEDED);
  350. }
  351. };
  352. class WriteErrorCallback : public ReadCallback {
  353. public:
  354. explicit WriteErrorCallback(WriteCallbackBase* wcb) : ReadCallback(wcb) {}
  355. void readDataAvailable(size_t len) noexcept override {
  356. std::cerr << "readDataAvailable, len " << len << std::endl;
  357. currentBuffer.length = len;
  358. // close the socket before writing to trigger writeError().
  359. ::close(socket_->getFd());
  360. wcb_->setSocket(socket_);
  361. // Write back the same data.
  362. folly::test::msvcSuppressAbortOnInvalidParams(
  363. [&] { socket_->write(wcb_, currentBuffer.buffer, len); });
  364. if (wcb_->state == STATE_FAILED) {
  365. setState(STATE_SUCCEEDED);
  366. } else {
  367. state = STATE_FAILED;
  368. }
  369. buffers.push_back(currentBuffer);
  370. currentBuffer.reset();
  371. }
  372. void readErr(const AsyncSocketException& ex) noexcept override {
  373. std::cerr << "readError " << ex.what() << std::endl;
  374. // do nothing since this is expected
  375. }
  376. };
  377. class EmptyReadCallback : public ReadCallback {
  378. public:
  379. explicit EmptyReadCallback() : ReadCallback(nullptr) {}
  380. void readErr(const AsyncSocketException& ex) noexcept override {
  381. std::cerr << "readError " << ex.what() << std::endl;
  382. state = STATE_FAILED;
  383. if (tcpSocket_) {
  384. tcpSocket_->close();
  385. }
  386. }
  387. void readEOF() noexcept override {
  388. std::cerr << "readEOF" << std::endl;
  389. if (tcpSocket_) {
  390. tcpSocket_->close();
  391. }
  392. state = STATE_SUCCEEDED;
  393. }
  394. std::shared_ptr<AsyncSocket> tcpSocket_;
  395. };
  396. class HandshakeCallback : public AsyncSSLSocket::HandshakeCB {
  397. public:
  398. enum ExpectType { EXPECT_SUCCESS, EXPECT_ERROR };
  399. explicit HandshakeCallback(
  400. ReadCallbackBase* rcb,
  401. ExpectType expect = EXPECT_SUCCESS)
  402. : state(STATE_WAITING), rcb_(rcb), expect_(expect) {}
  403. void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) {
  404. socket_ = socket;
  405. }
  406. void setState(StateEnum s) {
  407. state = s;
  408. rcb_->setState(s);
  409. }
  410. // Functions inherited from AsyncSSLSocketHandshakeCallback
  411. void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
  412. std::lock_guard<std::mutex> g(mutex_);
  413. cv_.notify_all();
  414. EXPECT_EQ(sock, socket_.get());
  415. std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
  416. rcb_->setSocket(socket_);
  417. sock->setReadCB(rcb_);
  418. state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
  419. }
  420. void handshakeErr(
  421. AsyncSSLSocket* /* sock */,
  422. const AsyncSocketException& ex) noexcept override {
  423. std::lock_guard<std::mutex> g(mutex_);
  424. cv_.notify_all();
  425. std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
  426. state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
  427. if (expect_ == EXPECT_ERROR) {
  428. // rcb will never be invoked
  429. rcb_->setState(STATE_SUCCEEDED);
  430. }
  431. errorString_ = ex.what();
  432. }
  433. void waitForHandshake() {
  434. std::unique_lock<std::mutex> lock(mutex_);
  435. cv_.wait(lock, [this] { return state != STATE_WAITING; });
  436. }
  437. ~HandshakeCallback() override {
  438. EXPECT_EQ(STATE_SUCCEEDED, state);
  439. }
  440. void closeSocket() {
  441. socket_->close();
  442. state = STATE_SUCCEEDED;
  443. }
  444. std::shared_ptr<AsyncSSLSocket> getSocket() {
  445. return socket_;
  446. }
  447. StateEnum state;
  448. std::shared_ptr<AsyncSSLSocket> socket_;
  449. ReadCallbackBase* rcb_;
  450. ExpectType expect_;
  451. std::mutex mutex_;
  452. std::condition_variable cv_;
  453. std::string errorString_;
  454. };
  455. class SSLServerAcceptCallback : public SSLServerAcceptCallbackBase {
  456. public:
  457. uint32_t timeout_;
  458. explicit SSLServerAcceptCallback(HandshakeCallback* hcb, uint32_t timeout = 0)
  459. : SSLServerAcceptCallbackBase(hcb), timeout_(timeout) {}
  460. ~SSLServerAcceptCallback() override {
  461. if (timeout_ > 0) {
  462. // if we set a timeout, we expect failure
  463. EXPECT_EQ(hcb_->state, STATE_FAILED);
  464. hcb_->setState(STATE_SUCCEEDED);
  465. }
  466. }
  467. void connAccepted(
  468. const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
  469. auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
  470. std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
  471. hcb_->setSocket(sock);
  472. sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
  473. EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_ACCEPTING);
  474. state = STATE_SUCCEEDED;
  475. }
  476. };
  477. class SSLServerAcceptCallbackDelay : public SSLServerAcceptCallback {
  478. public:
  479. explicit SSLServerAcceptCallbackDelay(HandshakeCallback* hcb)
  480. : SSLServerAcceptCallback(hcb) {}
  481. void connAccepted(
  482. const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
  483. auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
  484. std::cerr << "SSLServerAcceptCallbackDelay::connAccepted" << std::endl;
  485. int fd = sock->getFd();
  486. #ifndef TCP_NOPUSH
  487. {
  488. // The accepted connection should already have TCP_NODELAY set
  489. int value;
  490. socklen_t valueLength = sizeof(value);
  491. int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
  492. EXPECT_EQ(rc, 0);
  493. EXPECT_EQ(value, 1);
  494. }
  495. #endif
  496. // Unset the TCP_NODELAY option.
  497. int value = 0;
  498. socklen_t valueLength = sizeof(value);
  499. int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
  500. EXPECT_EQ(rc, 0);
  501. rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
  502. EXPECT_EQ(rc, 0);
  503. EXPECT_EQ(value, 0);
  504. SSLServerAcceptCallback::connAccepted(sock);
  505. }
  506. };
  507. class SSLServerAsyncCacheAcceptCallback : public SSLServerAcceptCallback {
  508. public:
  509. explicit SSLServerAsyncCacheAcceptCallback(
  510. HandshakeCallback* hcb,
  511. uint32_t timeout = 0)
  512. : SSLServerAcceptCallback(hcb, timeout) {}
  513. void connAccepted(
  514. const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
  515. auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
  516. std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
  517. hcb_->setSocket(sock);
  518. sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
  519. ASSERT_TRUE(
  520. (sock->getSSLState() == AsyncSSLSocket::STATE_ACCEPTING) ||
  521. (sock->getSSLState() == AsyncSSLSocket::STATE_CACHE_LOOKUP));
  522. state = STATE_SUCCEEDED;
  523. }
  524. };
  525. class HandshakeErrorCallback : public SSLServerAcceptCallbackBase {
  526. public:
  527. explicit HandshakeErrorCallback(HandshakeCallback* hcb)
  528. : SSLServerAcceptCallbackBase(hcb) {}
  529. void connAccepted(
  530. const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
  531. auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
  532. std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
  533. // The first call to sslAccept() should succeed.
  534. hcb_->setSocket(sock);
  535. sock->sslAccept(hcb_);
  536. EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_ACCEPTING);
  537. // The second call to sslAccept() should fail.
  538. HandshakeCallback callback2(hcb_->rcb_);
  539. callback2.setSocket(sock);
  540. sock->sslAccept(&callback2);
  541. EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_ERROR);
  542. // Both callbacks should be in the error state.
  543. EXPECT_EQ(hcb_->state, STATE_FAILED);
  544. EXPECT_EQ(callback2.state, STATE_FAILED);
  545. state = STATE_SUCCEEDED;
  546. hcb_->setState(STATE_SUCCEEDED);
  547. callback2.setState(STATE_SUCCEEDED);
  548. }
  549. };
  550. class HandshakeTimeoutCallback : public SSLServerAcceptCallbackBase {
  551. public:
  552. explicit HandshakeTimeoutCallback(HandshakeCallback* hcb)
  553. : SSLServerAcceptCallbackBase(hcb) {}
  554. void connAccepted(
  555. const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
  556. std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
  557. auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
  558. hcb_->setSocket(sock);
  559. sock->getEventBase()->tryRunAfterDelay(
  560. [=] {
  561. std::cerr << "Delayed SSL accept, client will have close by now"
  562. << std::endl;
  563. // SSL accept will fail
  564. EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_UNINIT);
  565. hcb_->socket_->sslAccept(hcb_);
  566. // This registers for an event
  567. EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_ACCEPTING);
  568. state = STATE_SUCCEEDED;
  569. },
  570. 100);
  571. }
  572. };
  573. class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
  574. public:
  575. ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
  576. // We don't care if we get invoked or not.
  577. // The client may time out and give up before connAccepted() is even
  578. // called.
  579. state = STATE_SUCCEEDED;
  580. }
  581. void connAccepted(
  582. const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
  583. std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
  584. // Just wait a while before closing the socket, so the client
  585. // will time out waiting for the handshake to complete.
  586. s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
  587. }
  588. };
  589. class TestSSLAsyncCacheServer : public TestSSLServer {
  590. public:
  591. explicit TestSSLAsyncCacheServer(
  592. SSLServerAcceptCallbackBase* acb,
  593. int lookupDelay = 100)
  594. : TestSSLServer(acb) {
  595. SSL_CTX* sslCtx = ctx_->getSSLCtx();
  596. #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
  597. SSL_CTX_sess_set_get_cb(
  598. sslCtx, TestSSLAsyncCacheServer::getSessionCallback);
  599. #endif
  600. SSL_CTX_set_session_cache_mode(
  601. sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
  602. asyncCallbacks_ = 0;
  603. asyncLookups_ = 0;
  604. lookupDelay_ = lookupDelay;
  605. }
  606. uint32_t getAsyncCallbacks() const {
  607. return asyncCallbacks_;
  608. }
  609. uint32_t getAsyncLookups() const {
  610. return asyncLookups_;
  611. }
  612. private:
  613. static uint32_t asyncCallbacks_;
  614. static uint32_t asyncLookups_;
  615. static uint32_t lookupDelay_;
  616. static SSL_SESSION* getSessionCallback(
  617. SSL* ssl,
  618. unsigned char* /* sess_id */,
  619. int /* id_len */,
  620. int* copyflag) {
  621. *copyflag = 0;
  622. asyncCallbacks_++;
  623. (void)ssl;
  624. #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
  625. if (!SSL_want_sess_cache_lookup(ssl)) {
  626. // libssl.so mismatch
  627. std::cerr << "no async support" << std::endl;
  628. return nullptr;
  629. }
  630. AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
  631. assert(sslSocket != nullptr);
  632. // Going to simulate an async cache by just running delaying the miss 100ms
  633. if (asyncCallbacks_ % 2 == 0) {
  634. // This socket is already blocked on lookup, return miss
  635. std::cerr << "returning miss" << std::endl;
  636. } else {
  637. // fresh meat - block it
  638. std::cerr << "async lookup" << std::endl;
  639. sslSocket->getEventBase()->tryRunAfterDelay(
  640. std::bind(&AsyncSSLSocket::restartSSLAccept, sslSocket),
  641. lookupDelay_);
  642. *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
  643. asyncLookups_++;
  644. }
  645. #endif
  646. return nullptr;
  647. }
  648. };
  // Free helper functions shared by the SSL socket tests. Only the
  // declarations are visible here; the definitions live outside this header.
  // NOTE(review): the descriptions below are inferred from the names and
  // signatures alone -- confirm against the definitions.
  //
  // getfds: presumably fills `fds` with two file descriptors.
  649. void getfds(int fds[2]);
  // getctx: presumably configures the given client and server SSL contexts.
  650. void getctx(
  651. std::shared_ptr<folly::SSLContext> clientCtx,
  652. std::shared_ptr<folly::SSLContext> serverCtx);
  // sslsocketpair: presumably creates a client/server AsyncSSLSocket pair
  // attached to `eventBase` and returns it through the two out-parameters.
  653. void sslsocketpair(
  654. EventBase* eventBase,
  655. AsyncSSLSocket::UniquePtr* clientSock,
  656. AsyncSSLSocket::UniquePtr* serverSock);
  657. class BlockingWriteClient : private AsyncSSLSocket::HandshakeCB,
  658. private AsyncTransportWrapper::WriteCallback {
  659. public:
  660. explicit BlockingWriteClient(AsyncSSLSocket::UniquePtr socket)
  661. : socket_(std::move(socket)), bufLen_(2500), iovCount_(2000) {
  662. // Fill buf_
  663. buf_ = std::make_unique<uint8_t[]>(bufLen_);
  664. for (uint32_t n = 0; n < sizeof(buf_); ++n) {
  665. buf_[n] = n % 0xff;
  666. }
  667. // Initialize iov_
  668. iov_ = std::make_unique<struct iovec[]>(iovCount_);
  669. for (uint32_t n = 0; n < iovCount_; ++n) {
  670. iov_[n].iov_base = buf_.get() + n;
  671. if (n & 0x1) {
  672. iov_[n].iov_len = n % bufLen_;
  673. } else {
  674. iov_[n].iov_len = bufLen_ - (n % bufLen_);
  675. }
  676. }
  677. socket_->sslConn(this, std::chrono::milliseconds(100));
  678. }
  679. struct iovec* getIovec() const {
  680. return iov_.get();
  681. }
  682. uint32_t getIovecCount() const {
  683. return iovCount_;
  684. }
  685. private:
  686. void handshakeSuc(AsyncSSLSocket*) noexcept override {
  687. socket_->writev(this, iov_.get(), iovCount_);
  688. }
  689. void handshakeErr(
  690. AsyncSSLSocket*,
  691. const AsyncSocketException& ex) noexcept override {
  692. ADD_FAILURE() << "client handshake error: " << ex.what();
  693. }
  694. void writeSuccess() noexcept override {
  695. socket_->close();
  696. }
  697. void writeErr(
  698. size_t bytesWritten,
  699. const AsyncSocketException& ex) noexcept override {
  700. ADD_FAILURE() << "client write error after " << bytesWritten
  701. << " bytes: " << ex.what();
  702. }
  703. AsyncSSLSocket::UniquePtr socket_;
  704. uint32_t bufLen_;
  705. uint32_t iovCount_;
  706. std::unique_ptr<uint8_t[]> buf_;
  707. std::unique_ptr<struct iovec[]> iov_;
  708. };
  709. class BlockingWriteServer : private AsyncSSLSocket::HandshakeCB,
  710. private AsyncTransportWrapper::ReadCallback {
  711. public:
  712. explicit BlockingWriteServer(AsyncSSLSocket::UniquePtr socket)
  713. : socket_(std::move(socket)), bufSize_(2500 * 2000), bytesRead_(0) {
  714. buf_ = std::make_unique<uint8_t[]>(bufSize_);
  715. socket_->sslAccept(this, std::chrono::milliseconds(100));
  716. }
  717. void checkBuffer(struct iovec* iov, uint32_t count) const {
  718. uint32_t idx = 0;
  719. for (uint32_t n = 0; n < count; ++n) {
  720. size_t bytesLeft = bytesRead_ - idx;
  721. int rc = memcmp(
  722. buf_.get() + idx,
  723. iov[n].iov_base,
  724. std::min(iov[n].iov_len, bytesLeft));
  725. if (rc != 0) {
  726. FAIL() << "buffer mismatch at iovec " << n << "/" << count
  727. << ": rc=" << rc;
  728. }
  729. if (iov[n].iov_len > bytesLeft) {
  730. FAIL() << "server did not read enough data: "
  731. << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
  732. << " in iovec " << n << "/" << count;
  733. }
  734. idx += iov[n].iov_len;
  735. }
  736. if (idx != bytesRead_) {
  737. ADD_FAILURE() << "server read extra data: " << bytesRead_
  738. << " bytes read; expected " << idx;
  739. }
  740. }
  741. private:
  742. void handshakeSuc(AsyncSSLSocket*) noexcept override {
  743. // Wait 10ms before reading, so the client's writes will initially block.
  744. socket_->getEventBase()->tryRunAfterDelay(
  745. [this] { socket_->setReadCB(this); }, 10);
  746. }
  747. void handshakeErr(
  748. AsyncSSLSocket*,
  749. const AsyncSocketException& ex) noexcept override {
  750. ADD_FAILURE() << "server handshake error: " << ex.what();
  751. }
  752. void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
  753. *bufReturn = buf_.get() + bytesRead_;
  754. *lenReturn = bufSize_ - bytesRead_;
  755. }
  756. void readDataAvailable(size_t len) noexcept override {
  757. bytesRead_ += len;
  758. socket_->setReadCB(nullptr);
  759. socket_->getEventBase()->tryRunAfterDelay(
  760. [this] { socket_->setReadCB(this); }, 2);
  761. }
  762. void readEOF() noexcept override {
  763. socket_->close();
  764. }
  765. void readErr(const AsyncSocketException& ex) noexcept override {
  766. ADD_FAILURE() << "server read error: " << ex.what();
  767. }
  768. AsyncSSLSocket::UniquePtr socket_;
  769. uint32_t bufSize_;
  770. uint32_t bytesRead_;
  771. std::unique_ptr<uint8_t[]> buf_;
  772. };
  773. class AlpnClient : private AsyncSSLSocket::HandshakeCB,
  774. private AsyncTransportWrapper::WriteCallback {
  775. public:
  776. explicit AlpnClient(AsyncSSLSocket::UniquePtr socket)
  777. : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
  778. socket_->sslConn(this);
  779. }
  780. const unsigned char* nextProto;
  781. unsigned nextProtoLength;
  782. folly::Optional<AsyncSocketException> except;
  783. private:
  784. void handshakeSuc(AsyncSSLSocket*) noexcept override {
  785. socket_->getSelectedNextProtocol(&nextProto, &nextProtoLength);
  786. }
  787. void handshakeErr(
  788. AsyncSSLSocket*,
  789. const AsyncSocketException& ex) noexcept override {
  790. except = ex;
  791. }
  792. void writeSuccess() noexcept override {
  793. socket_->close();
  794. }
  795. void writeErr(
  796. size_t bytesWritten,
  797. const AsyncSocketException& ex) noexcept override {
  798. ADD_FAILURE() << "client write error after " << bytesWritten
  799. << " bytes: " << ex.what();
  800. }
  801. AsyncSSLSocket::UniquePtr socket_;
  802. };
  803. class AlpnServer : private AsyncSSLSocket::HandshakeCB,
  804. private AsyncTransportWrapper::ReadCallback {
  805. public:
  806. explicit AlpnServer(AsyncSSLSocket::UniquePtr socket)
  807. : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
  808. socket_->sslAccept(this);
  809. }
  810. const unsigned char* nextProto;
  811. unsigned nextProtoLength;
  812. folly::Optional<AsyncSocketException> except;
  813. private:
  814. void handshakeSuc(AsyncSSLSocket*) noexcept override {
  815. socket_->getSelectedNextProtocol(&nextProto, &nextProtoLength);
  816. }
  817. void handshakeErr(
  818. AsyncSSLSocket*,
  819. const AsyncSocketException& ex) noexcept override {
  820. except = ex;
  821. }
  822. void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
  823. *lenReturn = 0;
  824. }
  825. void readDataAvailable(size_t /* len */) noexcept override {}
  826. void readEOF() noexcept override {
  827. socket_->close();
  828. }
  829. void readErr(const AsyncSocketException& ex) noexcept override {
  830. ADD_FAILURE() << "server read error: " << ex.what();
  831. }
  832. AsyncSSLSocket::UniquePtr socket_;
  833. };
  834. class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
  835. public AsyncTransportWrapper::ReadCallback {
  836. public:
  837. explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
  838. : socket_(std::move(socket)) {
  839. socket_->sslAccept(this);
  840. }
  841. ~RenegotiatingServer() override {
  842. socket_->setReadCB(nullptr);
  843. }
  844. void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
  845. LOG(INFO) << "Renegotiating server handshake success";
  846. socket_->setReadCB(this);
  847. }
  848. void handshakeErr(
  849. AsyncSSLSocket*,
  850. const AsyncSocketException& ex) noexcept override {
  851. ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
  852. }
  853. void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
  854. *lenReturn = sizeof(buf);
  855. *bufReturn = buf;
  856. }
  857. void readDataAvailable(size_t /* len */) noexcept override {}
  858. void readEOF() noexcept override {}
  859. void readErr(const AsyncSocketException& ex) noexcept override {
  860. LOG(INFO) << "server got read error " << ex.what();
  861. auto exPtr = dynamic_cast<const SSLException*>(&ex);
  862. ASSERT_NE(nullptr, exPtr);
  863. std::string exStr(ex.what());
  864. SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
  865. ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
  866. renegotiationError_ = true;
  867. }
  868. AsyncSSLSocket::UniquePtr socket_;
  869. unsigned char buf[128];
  870. bool renegotiationError_{false};
  871. };
  872. #ifndef OPENSSL_NO_TLSEXT
  873. class SNIClient : private AsyncSSLSocket::HandshakeCB,
  874. private AsyncTransportWrapper::WriteCallback {
  875. public:
  876. explicit SNIClient(AsyncSSLSocket::UniquePtr socket)
  877. : serverNameMatch(false), socket_(std::move(socket)) {
  878. socket_->sslConn(this);
  879. }
  880. bool serverNameMatch;
  881. private:
  882. void handshakeSuc(AsyncSSLSocket*) noexcept override {
  883. serverNameMatch = socket_->isServerNameMatch();
  884. }
  885. void handshakeErr(
  886. AsyncSSLSocket*,
  887. const AsyncSocketException& ex) noexcept override {
  888. ADD_FAILURE() << "client handshake error: " << ex.what();
  889. }
  890. void writeSuccess() noexcept override {
  891. socket_->close();
  892. }
  893. void writeErr(
  894. size_t bytesWritten,
  895. const AsyncSocketException& ex) noexcept override {
  896. ADD_FAILURE() << "client write error after " << bytesWritten
  897. << " bytes: " << ex.what();
  898. }
  899. AsyncSSLSocket::UniquePtr socket_;
  900. };
  901. class SNIServer : private AsyncSSLSocket::HandshakeCB,
  902. private AsyncTransportWrapper::ReadCallback {
  903. public:
  904. explicit SNIServer(
  905. AsyncSSLSocket::UniquePtr socket,
  906. const std::shared_ptr<folly::SSLContext>& ctx,
  907. const std::shared_ptr<folly::SSLContext>& sniCtx,
  908. const std::string& expectedServerName)
  909. : serverNameMatch(false),
  910. socket_(std::move(socket)),
  911. sniCtx_(sniCtx),
  912. expectedServerName_(expectedServerName) {
  913. ctx->setServerNameCallback(
  914. std::bind(&SNIServer::serverNameCallback, this, std::placeholders::_1));
  915. socket_->sslAccept(this);
  916. }
  917. bool serverNameMatch;
  918. private:
  919. void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
  920. void handshakeErr(
  921. AsyncSSLSocket*,
  922. const AsyncSocketException& ex) noexcept override {
  923. ADD_FAILURE() << "server handshake error: " << ex.what();
  924. }
  925. void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
  926. *lenReturn = 0;
  927. }
  928. void readDataAvailable(size_t /* len */) noexcept override {}
  929. void readEOF() noexcept override {
  930. socket_->close();
  931. }
  932. void readErr(const AsyncSocketException& ex) noexcept override {
  933. ADD_FAILURE() << "server read error: " << ex.what();
  934. }
  935. folly::SSLContext::ServerNameCallbackResult serverNameCallback(SSL* ssl) {
  936. const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
  937. if (sniCtx_ && sn && !strcasecmp(expectedServerName_.c_str(), sn)) {
  938. AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
  939. sslSocket->switchServerSSLContext(sniCtx_);
  940. serverNameMatch = true;
  941. return folly::SSLContext::SERVER_NAME_FOUND;
  942. } else {
  943. serverNameMatch = false;
  944. return folly::SSLContext::SERVER_NAME_NOT_FOUND;
  945. }
  946. }
  947. AsyncSSLSocket::UniquePtr socket_;
  948. std::shared_ptr<folly::SSLContext> sniCtx_;
  949. std::string expectedServerName_;
  950. };
  951. #endif
  952. class SSLClient : public AsyncSocket::ConnectCallback,
  953. public AsyncTransportWrapper::WriteCallback,
  954. public AsyncTransportWrapper::ReadCallback {
  955. private:
  956. EventBase* eventBase_;
  957. std::shared_ptr<AsyncSSLSocket> sslSocket_;
  958. SSL_SESSION* session_;
  959. std::shared_ptr<folly::SSLContext> ctx_;
  960. uint32_t requests_;
  961. folly::SocketAddress address_;
  962. uint32_t timeout_;
  963. char buf_[128];
  964. char readbuf_[128];
  965. uint32_t bytesRead_;
  966. uint32_t hit_;
  967. uint32_t miss_;
  968. uint32_t errors_;
  969. uint32_t writeAfterConnectErrors_;
  970. // These settings test that we eventually drain the
  971. // socket, even if the maxReadsPerEvent_ is hit during
  972. // a event loop iteration.
  973. static constexpr size_t kMaxReadsPerEvent = 2;
  974. // 2 event loop iterations
  975. static constexpr size_t kMaxReadBufferSz =
  976. sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
  977. public:
  978. SSLClient(
  979. EventBase* eventBase,
  980. const folly::SocketAddress& address,
  981. uint32_t requests,
  982. uint32_t timeout = 0)
  983. : eventBase_(eventBase),
  984. session_(nullptr),
  985. requests_(requests),
  986. address_(address),
  987. timeout_(timeout),
  988. bytesRead_(0),
  989. hit_(0),
  990. miss_(0),
  991. errors_(0),
  992. writeAfterConnectErrors_(0) {
  993. ctx_.reset(new folly::SSLContext());
  994. ctx_->setOptions(SSL_OP_NO_TICKET);
  995. ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  996. memset(buf_, 'a', sizeof(buf_));
  997. }
  998. ~SSLClient() override {
  999. if (session_) {
  1000. SSL_SESSION_free(session_);
  1001. }
  1002. if (errors_ == 0) {
  1003. EXPECT_EQ(bytesRead_, sizeof(buf_));
  1004. }
  1005. }
  1006. uint32_t getHit() const {
  1007. return hit_;
  1008. }
  1009. uint32_t getMiss() const {
  1010. return miss_;
  1011. }
  1012. uint32_t getErrors() const {
  1013. return errors_;
  1014. }
  1015. uint32_t getWriteAfterConnectErrors() const {
  1016. return writeAfterConnectErrors_;
  1017. }
  1018. void connect(bool writeNow = false) {
  1019. sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
  1020. if (session_ != nullptr) {
  1021. sslSocket_->setSSLSession(session_);
  1022. }
  1023. requests_--;
  1024. sslSocket_->connect(this, address_, timeout_);
  1025. if (sslSocket_ && writeNow) {
  1026. // write some junk, used in an error test
  1027. sslSocket_->write(this, buf_, sizeof(buf_));
  1028. }
  1029. }
  1030. void connectSuccess() noexcept override {
  1031. std::cerr << "client SSL socket connected" << std::endl;
  1032. if (sslSocket_->getSSLSessionReused()) {
  1033. hit_++;
  1034. } else {
  1035. miss_++;
  1036. if (session_ != nullptr) {
  1037. SSL_SESSION_free(session_);
  1038. }
  1039. session_ = sslSocket_->getSSLSession();
  1040. }
  1041. // write()
  1042. sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
  1043. sslSocket_->write(this, buf_, sizeof(buf_));
  1044. sslSocket_->setReadCB(this);
  1045. memset(readbuf_, 'b', sizeof(readbuf_));
  1046. bytesRead_ = 0;
  1047. }
  1048. void connectErr(const AsyncSocketException& ex) noexcept override {
  1049. std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
  1050. errors_++;
  1051. sslSocket_.reset();
  1052. }
  1053. void writeSuccess() noexcept override {
  1054. std::cerr << "client write success" << std::endl;
  1055. }
  1056. void writeErr(
  1057. size_t /* bytesWritten */,
  1058. const AsyncSocketException& ex) noexcept override {
  1059. std::cerr << "client writeError: " << ex.what() << std::endl;
  1060. if (!sslSocket_) {
  1061. writeAfterConnectErrors_++;
  1062. }
  1063. }
  1064. void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
  1065. *bufReturn = readbuf_ + bytesRead_;
  1066. *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
  1067. }
  1068. void readEOF() noexcept override {
  1069. std::cerr << "client readEOF" << std::endl;
  1070. }
  1071. void readErr(const AsyncSocketException& ex) noexcept override {
  1072. std::cerr << "client readError: " << ex.what() << std::endl;
  1073. }
  1074. void readDataAvailable(size_t len) noexcept override {
  1075. std::cerr << "client read data: " << len << std::endl;
  1076. bytesRead_ += len;
  1077. if (bytesRead_ == sizeof(buf_)) {
  1078. EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
  1079. sslSocket_->closeNow();
  1080. sslSocket_.reset();
  1081. if (requests_ != 0) {
  1082. connect();
  1083. }
  1084. }
  1085. }
  1086. };
  1087. class SSLHandshakeBase : public AsyncSSLSocket::HandshakeCB,
  1088. private AsyncTransportWrapper::WriteCallback {
  1089. public:
  1090. explicit SSLHandshakeBase(
  1091. AsyncSSLSocket::UniquePtr socket,
  1092. bool preverifyResult,
  1093. bool verifyResult)
  1094. : handshakeVerify_(false),
  1095. handshakeSuccess_(false),
  1096. handshakeError_(false),
  1097. socket_(std::move(socket)),
  1098. preverifyResult_(preverifyResult),
  1099. verifyResult_(verifyResult) {}
  1100. AsyncSSLSocket::UniquePtr moveSocket() && {
  1101. return std::move(socket_);
  1102. }
  1103. bool handshakeVerify_;
  1104. bool handshakeSuccess_;
  1105. bool handshakeError_;
  1106. std::chrono::nanoseconds handshakeTime;
  1107. protected:
  1108. AsyncSSLSocket::UniquePtr socket_;
  1109. bool preverifyResult_;
  1110. bool verifyResult_;
  1111. // HandshakeCallback
  1112. bool handshakeVer(
  1113. AsyncSSLSocket* /* sock */,
  1114. bool preverifyOk,
  1115. X509_STORE_CTX* /* ctx */) noexcept override {
  1116. handshakeVerify_ = true;
  1117. EXPECT_EQ(preverifyResult_, preverifyOk);
  1118. return verifyResult_;
  1119. }
  1120. void handshakeSuc(AsyncSSLSocket*) noexcept override {
  1121. LOG(INFO) << "Handshake success";
  1122. handshakeSuccess_ = true;
  1123. if (socket_) {
  1124. handshakeTime = socket_->getHandshakeTime();
  1125. }
  1126. }
  1127. void handshakeErr(
  1128. AsyncSSLSocket*,
  1129. const AsyncSocketException& ex) noexcept override {
  1130. LOG(INFO) << "Handshake error " << ex.what();
  1131. handshakeError_ = true;
  1132. if (socket_) {
  1133. handshakeTime = socket_->getHandshakeTime();
  1134. }
  1135. }
  1136. // WriteCallback
  1137. void writeSuccess() noexcept override {
  1138. if (socket_) {
  1139. socket_->close();
  1140. }
  1141. }
  1142. void writeErr(
  1143. size_t bytesWritten,
  1144. const AsyncSocketException& ex) noexcept override {
  1145. ADD_FAILURE() << "client write error after " << bytesWritten
  1146. << " bytes: " << ex.what();
  1147. }
  1148. };
  1149. class SSLHandshakeClient : public SSLHandshakeBase {
  1150. public:
  1151. SSLHandshakeClient(
  1152. AsyncSSLSocket::UniquePtr socket,
  1153. bool preverifyResult,
  1154. bool verifyResult)
  1155. : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
  1156. socket_->sslConn(this, std::chrono::milliseconds::zero());
  1157. }
  1158. };
  1159. class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
  1160. public:
  1161. SSLHandshakeClientNoVerify(
  1162. AsyncSSLSocket::UniquePtr socket,
  1163. bool preverifyResult,
  1164. bool verifyResult)
  1165. : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
  1166. socket_->sslConn(
  1167. this,
  1168. std::chrono::milliseconds::zero(),
  1169. folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  1170. }
  1171. };
  1172. class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
  1173. public:
  1174. SSLHandshakeClientDoVerify(
  1175. AsyncSSLSocket::UniquePtr socket,
  1176. bool preverifyResult,
  1177. bool verifyResult)
  1178. : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
  1179. socket_->sslConn(
  1180. this,
  1181. std::chrono::milliseconds::zero(),
  1182. folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
  1183. }
  1184. };
  1185. class SSLHandshakeServer : public SSLHandshakeBase {
  1186. public:
  1187. SSLHandshakeServer(
  1188. AsyncSSLSocket::UniquePtr socket,
  1189. bool preverifyResult,
  1190. bool verifyResult)
  1191. : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
  1192. socket_->sslAccept(this, std::chrono::milliseconds::zero());
  1193. }
  1194. };
  1195. class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
  1196. public:
  1197. SSLHandshakeServerParseClientHello(
  1198. AsyncSSLSocket::UniquePtr socket,
  1199. bool preverifyResult,
  1200. bool verifyResult)
  1201. : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
  1202. socket_->enableClientHelloParsing();
  1203. socket_->sslAccept(this, std::chrono::milliseconds::zero());
  1204. }
  1205. std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
  1206. protected:
  1207. void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
  1208. handshakeSuccess_ = true;
  1209. sock->getSSLSharedCiphers(sharedCiphers_);
  1210. sock->getSSLServerCiphers(serverCiphers_);
  1211. sock->getSSLClientCiphers(clientCiphers_);
  1212. chosenCipher_ = sock->getNegotiatedCipherName();
  1213. }
  1214. };
  1215. class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
  1216. public:
  1217. SSLHandshakeServerNoVerify(
  1218. AsyncSSLSocket::UniquePtr socket,
  1219. bool preverifyResult,
  1220. bool verifyResult)
  1221. : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
  1222. socket_->sslAccept(
  1223. this,
  1224. std::chrono::milliseconds::zero(),
  1225. folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  1226. }
  1227. };
  1228. class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
  1229. public:
  1230. SSLHandshakeServerDoVerify(
  1231. AsyncSSLSocket::UniquePtr socket,
  1232. bool preverifyResult,
  1233. bool verifyResult)
  1234. : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
  1235. socket_->sslAccept(
  1236. this,
  1237. std::chrono::milliseconds::zero(),
  1238. folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
  1239. }
  1240. };
  1241. class EventBaseAborter : public AsyncTimeout {
  1242. public:
  1243. EventBaseAborter(EventBase* eventBase, uint32_t timeoutMS)
  1244. : AsyncTimeout(eventBase, AsyncTimeout::InternalEnum::INTERNAL),
  1245. eventBase_(eventBase) {
  1246. scheduleTimeout(timeoutMS);
  1247. }
  1248. void timeoutExpired() noexcept override {
  1249. FAIL() << "test timed out";
  1250. eventBase_->terminateLoopSoon();
  1251. }
  1252. private:
  1253. EventBase* eventBase_;
  1254. };
  1255. class SSLAcceptEvbRunner : public SSLAcceptRunner {
  1256. public:
  1257. explicit SSLAcceptEvbRunner(EventBase* evb) : evb_(evb) {}
  1258. ~SSLAcceptEvbRunner() override = default;
  1259. void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
  1260. const override {
  1261. evb_->runInLoop([acceptFunc = std::move(acceptFunc),
  1262. finallyFunc = std::move(finallyFunc)]() mutable {
  1263. finallyFunc(acceptFunc());
  1264. });
  1265. }
  1266. protected:
  1267. EventBase* evb_;
  1268. };
  1269. class SSLAcceptErrorRunner : public SSLAcceptEvbRunner {
  1270. public:
  1271. explicit SSLAcceptErrorRunner(EventBase* evb) : SSLAcceptEvbRunner(evb) {}
  1272. ~SSLAcceptErrorRunner() override = default;
  1273. void run(Function<int()> /*acceptFunc*/, Function<void(int)> finallyFunc)
  1274. const override {
  1275. evb_->runInLoop(
  1276. [finallyFunc = std::move(finallyFunc)]() mutable { finallyFunc(-1); });
  1277. }
  1278. };
  1279. class SSLAcceptCloseRunner : public SSLAcceptEvbRunner {
  1280. public:
  1281. explicit SSLAcceptCloseRunner(EventBase* evb, folly::AsyncSSLSocket* sock)
  1282. : SSLAcceptEvbRunner(evb), socket_(sock) {}
  1283. ~SSLAcceptCloseRunner() override = default;
  1284. void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
  1285. const override {
  1286. evb_->runInLoop([acceptFunc = std::move(acceptFunc),
  1287. finallyFunc = std::move(finallyFunc),
  1288. sock = socket_]() mutable {
  1289. auto ret = acceptFunc();
  1290. sock->closeNow();
  1291. finallyFunc(ret);
  1292. });
  1293. }
  1294. private:
  1295. folly::AsyncSSLSocket* socket_;
  1296. };
  1297. class SSLAcceptDestroyRunner : public SSLAcceptEvbRunner {
  1298. public:
  1299. explicit SSLAcceptDestroyRunner(EventBase* evb, SSLHandshakeBase* base)
  1300. : SSLAcceptEvbRunner(evb), sslBase_(base) {}
  1301. ~SSLAcceptDestroyRunner() override = default;
  1302. void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
  1303. const override {
  1304. evb_->runInLoop([acceptFunc = std::move(acceptFunc),
  1305. finallyFunc = std::move(finallyFunc),
  1306. sslBase = sslBase_]() mutable {
  1307. auto ret = acceptFunc();
  1308. std::move(*sslBase).moveSocket();
  1309. finallyFunc(ret);
  1310. });
  1311. }
  1312. private:
  1313. SSLHandshakeBase* sslBase_;
  1314. };
  1315. } // namespace folly