AsyncSSLSocketTest2.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. /*
  2. * Copyright 2012-present Facebook, Inc.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <folly/io/async/test/AsyncSSLSocketTest.h>
  17. #include <folly/futures/Promise.h>
  18. #include <folly/init/Init.h>
  19. #include <folly/io/async/AsyncSSLSocket.h>
  20. #include <folly/io/async/EventBase.h>
  21. #include <folly/io/async/SSLContext.h>
  22. #include <folly/io/async/ScopedEventBaseThread.h>
  23. #include <folly/portability/GTest.h>
  24. #include <folly/portability/PThread.h>
  25. #include <folly/ssl/Init.h>
  26. using std::cerr;
  27. using std::endl;
  28. using std::list;
  29. using std::min;
  30. using std::string;
  31. using std::vector;
  32. namespace folly {
  33. struct EvbAndContext {
  34. EvbAndContext() {
  35. ctx_.reset(new SSLContext());
  36. ctx_->setOptions(SSL_OP_NO_TICKET);
  37. ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  38. }
  39. std::shared_ptr<AsyncSSLSocket> createSocket() {
  40. return AsyncSSLSocket::newSocket(ctx_, getEventBase());
  41. }
  42. EventBase* getEventBase() {
  43. return evb_.getEventBase();
  44. }
  45. void attach(AsyncSSLSocket& socket) {
  46. socket.attachEventBase(getEventBase());
  47. socket.attachSSLContext(ctx_);
  48. }
  49. folly::ScopedEventBaseThread evb_;
  50. std::shared_ptr<SSLContext> ctx_;
  51. };
  52. class AttachDetachClient : public AsyncSocket::ConnectCallback,
  53. public AsyncTransportWrapper::WriteCallback,
  54. public AsyncTransportWrapper::ReadCallback {
  55. private:
  56. // two threads here - we'll create the socket in one, connect
  57. // in the other, and then read/write in the initial one
  58. EvbAndContext t1_;
  59. EvbAndContext t2_;
  60. std::shared_ptr<AsyncSSLSocket> sslSocket_;
  61. folly::SocketAddress address_;
  62. char buf_[128];
  63. char readbuf_[128];
  64. uint32_t bytesRead_;
  65. // promise to fulfill when done
  66. folly::Promise<bool> promise_;
  67. void detach() {
  68. sslSocket_->detachEventBase();
  69. sslSocket_->detachSSLContext();
  70. }
  71. public:
  72. explicit AttachDetachClient(const folly::SocketAddress& address)
  73. : address_(address), bytesRead_(0) {}
  74. Future<bool> getFuture() {
  75. return promise_.getFuture();
  76. }
  77. void connect() {
  78. // create in one and then move to another
  79. auto t1Evb = t1_.getEventBase();
  80. t1Evb->runInEventBaseThread([this] {
  81. sslSocket_ = t1_.createSocket();
  82. // ensure we can detach and reattach the context before connecting
  83. for (int i = 0; i < 1000; ++i) {
  84. sslSocket_->detachSSLContext();
  85. sslSocket_->attachSSLContext(t1_.ctx_);
  86. }
  87. // detach from t1 and connect in t2
  88. detach();
  89. auto t2Evb = t2_.getEventBase();
  90. t2Evb->runInEventBaseThread([this] {
  91. t2_.attach(*sslSocket_);
  92. sslSocket_->connect(this, address_);
  93. });
  94. });
  95. }
  96. void connectSuccess() noexcept override {
  97. auto t2Evb = t2_.getEventBase();
  98. EXPECT_TRUE(t2Evb->isInEventBaseThread());
  99. cerr << "client SSL socket connected" << endl;
  100. for (int i = 0; i < 1000; ++i) {
  101. sslSocket_->detachSSLContext();
  102. sslSocket_->attachSSLContext(t2_.ctx_);
  103. }
  104. // detach from t2 and then read/write in t1
  105. t2Evb->runInEventBaseThread([this] {
  106. detach();
  107. auto t1Evb = t1_.getEventBase();
  108. t1Evb->runInEventBaseThread([this] {
  109. t1_.attach(*sslSocket_);
  110. sslSocket_->write(this, buf_, sizeof(buf_));
  111. sslSocket_->setReadCB(this);
  112. memset(readbuf_, 'b', sizeof(readbuf_));
  113. bytesRead_ = 0;
  114. });
  115. });
  116. }
  117. void connectErr(const AsyncSocketException& ex) noexcept override {
  118. cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
  119. sslSocket_.reset();
  120. }
  121. void writeSuccess() noexcept override {
  122. cerr << "client write success" << endl;
  123. }
  124. void writeErr(
  125. size_t /* bytesWritten */,
  126. const AsyncSocketException& ex) noexcept override {
  127. cerr << "client writeError: " << ex.what() << endl;
  128. }
  129. void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
  130. *bufReturn = readbuf_ + bytesRead_;
  131. *lenReturn = sizeof(readbuf_) - bytesRead_;
  132. }
  133. void readEOF() noexcept override {
  134. cerr << "client readEOF" << endl;
  135. }
  136. void readErr(const AsyncSocketException& ex) noexcept override {
  137. cerr << "client readError: " << ex.what() << endl;
  138. promise_.setException(ex);
  139. }
  140. void readDataAvailable(size_t len) noexcept override {
  141. EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
  142. EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
  143. cerr << "client read data: " << len << endl;
  144. bytesRead_ += len;
  145. if (len == sizeof(buf_)) {
  146. EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
  147. sslSocket_->closeNow();
  148. sslSocket_.reset();
  149. promise_.setValue(true);
  150. }
  151. }
  152. };
  153. /**
  154. * Test passing contexts between threads
  155. */
  156. TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
  157. // Start listening on a local port
  158. WriteCallbackBase writeCallback;
  159. ReadCallback readCallback(&writeCallback);
  160. HandshakeCallback handshakeCallback(&readCallback);
  161. SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  162. TestSSLServer server(&acceptCallback);
  163. std::shared_ptr<AttachDetachClient> client(
  164. new AttachDetachClient(server.getAddress()));
  165. auto f = client->getFuture();
  166. client->connect();
  167. EXPECT_TRUE(std::move(f).within(std::chrono::seconds(3)).get());
  168. }
  169. class ConnectClient : public AsyncSocket::ConnectCallback {
  170. public:
  171. ConnectClient() = default;
  172. Future<bool> getFuture() {
  173. return promise_.getFuture();
  174. }
  175. void connect(const folly::SocketAddress& addr) {
  176. t1_.getEventBase()->runInEventBaseThread([&] {
  177. socket_ = t1_.createSocket();
  178. socket_->connect(this, addr);
  179. });
  180. }
  181. void connectSuccess() noexcept override {
  182. socket_.reset();
  183. promise_.setValue(true);
  184. }
  185. void connectErr(const AsyncSocketException& /* ex */) noexcept override {
  186. socket_.reset();
  187. promise_.setValue(false);
  188. }
  189. void setCtx(std::shared_ptr<SSLContext> ctx) {
  190. t1_.ctx_ = ctx;
  191. }
  192. private:
  193. EvbAndContext t1_;
  194. // promise to fulfill when done with a value of true if connect succeeded
  195. folly::Promise<bool> promise_;
  196. std::shared_ptr<AsyncSSLSocket> socket_;
  197. };
  198. class NoopReadCallback : public ReadCallbackBase {
  199. public:
  200. NoopReadCallback() : ReadCallbackBase(nullptr) {
  201. state = STATE_SUCCEEDED;
  202. }
  203. void getReadBuffer(void** buf, size_t* lenReturn) override {
  204. *buf = &buffer_;
  205. *lenReturn = 1;
  206. }
  207. void readDataAvailable(size_t) noexcept override {}
  208. uint8_t buffer_{0};
  209. };
  210. TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
  211. // Start listening on a local port
  212. NoopReadCallback readCallback;
  213. HandshakeCallback handshakeCallback(&readCallback);
  214. SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  215. auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
  216. TestSSLServer server(&acceptCallback, ctx);
  217. server.loadTestCerts();
  218. // create a default client
  219. auto c1 = std::make_unique<ConnectClient>();
  220. auto f1 = c1->getFuture();
  221. c1->connect(server.getAddress());
  222. EXPECT_TRUE(std::move(f1).within(std::chrono::seconds(3)).get());
  223. }
  224. TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
  225. // Start listening on a local port
  226. NoopReadCallback readCallback;
  227. HandshakeCallback handshakeCallback(
  228. &readCallback, HandshakeCallback::EXPECT_ERROR);
  229. SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  230. auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
  231. TestSSLServer server(&acceptCallback, ctx);
  232. server.loadTestCerts();
  233. // create a client that doesn't speak TLS 1.2
  234. auto c2 = std::make_unique<ConnectClient>();
  235. auto clientCtx = std::make_shared<SSLContext>();
  236. clientCtx->setOptions(SSL_OP_NO_TLSv1_2);
  237. c2->setCtx(clientCtx);
  238. auto f2 = c2->getFuture();
  239. c2->connect(server.getAddress());
  240. EXPECT_FALSE(std::move(f2).within(std::chrono::seconds(3)).get());
  241. }
  242. } // namespace folly
  243. int main(int argc, char* argv[]) {
  244. folly::ssl::init();
  245. #ifdef SIGPIPE
  246. signal(SIGPIPE, SIG_IGN);
  247. #endif
  248. testing::InitGoogleTest(&argc, argv);
  249. folly::init(&argc, &argv);
  250. return RUN_ALL_TESTS();
  251. OPENSSL_cleanup();
  252. }