/* * Copyright 2011-present Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #if FOLLY_OPENSSL_IS_110 #include #endif #ifdef FOLLY_HAVE_MSG_ERRQUEUE #include #endif using std::cerr; using std::endl; using std::list; using std::min; using std::string; using std::vector; using namespace testing; #if defined __linux__ namespace { // to store libc's original setsockopt() typedef int (*setsockopt_ptr)(int, int, int, const void*, socklen_t); setsockopt_ptr real_setsockopt_ = nullptr; // global struct to initialize before main runs. we can init within a test, // or in main, but this method seems to be least intrsive and universal struct GlobalStatic { GlobalStatic() { real_setsockopt_ = (setsockopt_ptr)dlsym(RTLD_NEXT, "setsockopt"); } void reset() noexcept { ttlsDisabledSet.clear(); } // for each fd, tracks whether TTLS is disabled or not std::set ttlsDisabledSet; }; // the constructor will be called before main() which is all we care about GlobalStatic globalStatic; } // namespace // we intercept setsoctopt to test setting NO_TRANSPARENT_TLS opt // this name has to be global int setsockopt( int sockfd, int level, int optname, const void* optval, socklen_t optlen) { if (optname == SO_NO_TRANSPARENT_TLS) { globalStatic.ttlsDisabledSet.insert(sockfd); return 0; } return real_setsockopt_(sockfd, level, optname, optval, optlen); } #endif namespace folly { uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0; uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0; uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0; constexpr size_t SSLClient::kMaxReadBufferSz; constexpr size_t SSLClient::kMaxReadsPerEvent; void getfds(int fds[2]) { if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) { FAIL() << "failed to create socketpair: " << errnoStr(errno); } for (int idx = 0; idx < 2; ++idx) { int flags = fcntl(fds[idx], F_GETFL, 0); if (flags == -1) { FAIL() << "failed to get flags for socket " << idx << ": " << errnoStr(errno); } if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) { FAIL() << "failed to put socket " << idx << " in non-blocking mode: " << errnoStr(errno); } } } void getctx( std::shared_ptr clientCtx, std::shared_ptr serverCtx) { clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadCertificate(kTestCert); serverCtx->loadPrivateKey(kTestKey); } void sslsocketpair( EventBase* eventBase, AsyncSSLSocket::UniquePtr* clientSock, AsyncSSLSocket::UniquePtr* serverSock) { auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); int fds[2]; getfds(fds); getctx(clientCtx, serverCtx); clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false)); serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true)); // (*clientSock)->setSendTimeout(100); // (*serverSock)->setSendTimeout(100); } // client protocol filters bool clientProtoFilterPickPony( unsigned char** client, unsigned int* client_len, const unsigned char*, unsigned int) { // the protocol string in length prefixed byte string. the // length byte is not included in the length static unsigned char p[7] = {6, 'p', 'o', 'n', 'i', 'e', 's'}; *client = p; *client_len = 7; return true; } bool clientProtoFilterPickNone( unsigned char**, unsigned int*, const unsigned char*, unsigned int) { return false; } std::string getFileAsBuf(const char* fileName) { std::string buffer; folly::readFile(fileName, buffer); return buffer; } /** * Test connecting to, writing to, reading from, and closing the * connection to the SSL server. */ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL context. std::shared_ptr sslContext(new SSLContext()); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem"); // sslContext->authenticate(true, false); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->open(std::chrono::milliseconds(10000)); // write() uint8_t buf[128]; memset(buf, 'a', sizeof(buf)); socket->write(buf, sizeof(buf)); // read() uint8_t readbuf[128]; uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf)); EXPECT_EQ(bytesRead, 128); EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0); // close() socket->close(); cerr << "ConnectWriteReadClose test completed" << endl; EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000); } /** * Test reading after server close. */ TEST(AsyncSSLSocketTest, ReadAfterClose) { // Start listening on a local port WriteCallbackBase writeCallback; ReadEOFCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); auto server = std::make_unique(&acceptCallback); // Set up SSL context. auto sslContext = std::make_shared(); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); auto socket = std::make_shared(server->getAddress(), sslContext); socket->open(); // This should trigger an EOF on the client. auto evb = handshakeCallback.getSocket()->getEventBase(); evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); }); std::array readbuf; auto bytesRead = socket->read(readbuf.data(), readbuf.size()); EXPECT_EQ(0, bytesRead); } /** * Test bad renegotiation */ #if !defined(OPENSSL_IS_BORINGSSL) TEST(AsyncSSLSocketTest, Renegotiate) { EventBase eventBase; auto clientCtx = std::make_shared(); auto dfServerCtx = std::make_shared(); std::array fds; getfds(fds.data()); getctx(clientCtx, dfServerCtx); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), true, true); RenegotiatingServer server(std::move(serverSock)); while (!client.handshakeSuccess_ && !client.handshakeError_) { eventBase.loopOnce(); } ASSERT_TRUE(client.handshakeSuccess_); auto sslSock = std::move(client).moveSocket(); sslSock->detachEventBase(); // This is nasty, however we don't want to add support for // renegotiation in AsyncSSLSocket. SSL_renegotiate(const_cast(sslSock->getSSL())); auto socket = std::make_shared(std::move(sslSock)); std::thread t([&]() { eventBase.loopForever(); }); // Trigger the renegotiation. std::array buf; memset(buf.data(), 'a', buf.size()); try { socket->write(buf.data(), buf.size()); } catch (AsyncSocketException& e) { LOG(INFO) << "client got error " << e.what(); } eventBase.terminateLoopSoon(); t.join(); eventBase.loop(); ASSERT_TRUE(server.renegotiationError_); } #endif /** * Negative test for handshakeError(). */ TEST(AsyncSSLSocketTest, HandshakeError) { // Start listening on a local port WriteCallbackBase writeCallback; WriteErrorCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); HandshakeErrorCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL context. std::shared_ptr sslContext(new SSLContext()); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); // connect auto socket = std::make_shared(server.getAddress(), sslContext); // read() bool ex = false; try { socket->open(); uint8_t readbuf[128]; uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf)); LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing"; } catch (AsyncSocketException&) { ex = true; } EXPECT_TRUE(ex); // close() socket->close(); cerr << "HandshakeError test completed" << endl; } /** * Negative test for readError(). */ TEST(AsyncSSLSocketTest, ReadError) { // Start listening on a local port WriteCallbackBase writeCallback; ReadErrorCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL context. std::shared_ptr sslContext(new SSLContext()); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->open(); // write something to trigger ssl handshake uint8_t buf[128]; memset(buf, 'a', sizeof(buf)); socket->write(buf, sizeof(buf)); socket->close(); cerr << "ReadError test completed" << endl; } /** * Negative test for writeError(). */ TEST(AsyncSSLSocketTest, WriteError) { // Start listening on a local port WriteCallbackBase writeCallback; WriteErrorCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL context. std::shared_ptr sslContext(new SSLContext()); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->open(); // write something to trigger ssl handshake uint8_t buf[128]; memset(buf, 'a', sizeof(buf)); socket->write(buf, sizeof(buf)); socket->close(); cerr << "WriteError test completed" << endl; } /** * Test a socket with TCP_NODELAY unset. */ TEST(AsyncSSLSocketTest, SocketWithDelay) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL context. std::shared_ptr sslContext(new SSLContext()); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->open(); // write() uint8_t buf[128]; memset(buf, 'a', sizeof(buf)); socket->write(buf, sizeof(buf)); // read() uint8_t readbuf[128]; uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf)); EXPECT_EQ(bytesRead, 128); EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0); // close() socket->close(); cerr << "SocketWithDelay test completed" << endl; } #if FOLLY_OPENSSL_HAS_ALPN class NextProtocolTest : public Test { // For matching protos public: void SetUp() override { getctx(clientCtx, serverCtx); } void connect(bool unset = false) { getfds(fds); if (unset) { // unsetting NPN for any of [client, server] is enough to make NPN not // work clientCtx->unsetNextProtocols(); } AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); client = std::make_unique(std::move(clientSock)); server = std::make_unique(std::move(serverSock)); eventBase.loop(); } void expectProtocol(const std::string& proto) { expectHandshakeSuccess(); EXPECT_NE(client->nextProtoLength, 0); EXPECT_EQ(client->nextProtoLength, server->nextProtoLength); EXPECT_EQ( memcmp(client->nextProto, server->nextProto, server->nextProtoLength), 0); string selected((const char*)client->nextProto, client->nextProtoLength); EXPECT_EQ(proto, selected); } void expectNoProtocol() { expectHandshakeSuccess(); EXPECT_EQ(client->nextProtoLength, 0); EXPECT_EQ(server->nextProtoLength, 0); EXPECT_EQ(client->nextProto, nullptr); EXPECT_EQ(server->nextProto, nullptr); } void expectHandshakeSuccess() { EXPECT_FALSE(client->except.hasValue()) << "client handshake error: " << client->except->what(); EXPECT_FALSE(server->except.hasValue()) << "server handshake error: " << server->except->what(); } void expectHandshakeError() { EXPECT_TRUE(client->except.hasValue()) << "Expected client handshake error!"; EXPECT_TRUE(server->except.hasValue()) << "Expected server handshake error!"; } EventBase eventBase; std::shared_ptr clientCtx{std::make_shared()}; std::shared_ptr serverCtx{std::make_shared()}; int fds[2]; std::unique_ptr client; std::unique_ptr server; }; TEST_F(NextProtocolTest, AlpnTestOverlap) { clientCtx->setAdvertisedNextProtocols({"blub", "baz"}); serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}); connect(); expectProtocol("baz"); } TEST_F(NextProtocolTest, AlpnTestUnset) { // Identical to above test, except that we want unset NPN before // looping. clientCtx->setAdvertisedNextProtocols({"blub", "baz"}); serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}); connect(true /* unset */); expectNoProtocol(); } TEST_F(NextProtocolTest, AlpnTestNoOverlap) { clientCtx->setAdvertisedNextProtocols({"blub"}); serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}); connect(); expectNoProtocol(); } TEST_F(NextProtocolTest, RandomizedAlpnTest) { // Probability that this test will fail is 2^-64, which could be considered // as negligible. const int kTries = 64; clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}); serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}}); std::set selectedProtocols; for (int i = 0; i < kTries; ++i) { connect(); EXPECT_NE(client->nextProtoLength, 0); EXPECT_EQ(client->nextProtoLength, server->nextProtoLength); EXPECT_EQ( memcmp(client->nextProto, server->nextProto, server->nextProtoLength), 0); string selected((const char*)client->nextProto, client->nextProtoLength); selectedProtocols.insert(selected); expectHandshakeSuccess(); } EXPECT_EQ(selectedProtocols.size(), 2); } #endif #ifndef OPENSSL_NO_TLSEXT /** * 1. Client sends TLSEXT_HOSTNAME in client hello. * 2. Server found a match SSL_CTX and use this SSL_CTX to * continue the SSL handshake. * 3. Server sends back TLSEXT_HOSTNAME in server hello. */ TEST(AsyncSSLSocketTest, SNITestMatch) { EventBase eventBase; std::shared_ptr clientCtx(new SSLContext); std::shared_ptr dfServerCtx(new SSLContext); // Use the same SSLContext to continue the handshake after // tlsext_hostname match. std::shared_ptr hskServerCtx(dfServerCtx); const std::string serverName("xyz.newdev.facebook.com"); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SNIClient client(std::move(clientSock)); SNIServer server( std::move(serverSock), dfServerCtx, hskServerCtx, serverName); eventBase.loop(); EXPECT_TRUE(client.serverNameMatch); EXPECT_TRUE(server.serverNameMatch); } /** * 1. Client sends TLSEXT_HOSTNAME in client hello. * 2. Server cannot find a matching SSL_CTX and continue to use * the current SSL_CTX to do the handshake. * 3. Server does not send back TLSEXT_HOSTNAME in server hello. */ TEST(AsyncSSLSocketTest, SNITestNotMatch) { EventBase eventBase; std::shared_ptr clientCtx(new SSLContext); std::shared_ptr dfServerCtx(new SSLContext); // Use the same SSLContext to continue the handshake after // tlsext_hostname match. std::shared_ptr hskServerCtx(dfServerCtx); const std::string clientRequestingServerName("foo.com"); const std::string serverExpectedServerName("xyz.newdev.facebook.com"); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket( clientCtx, &eventBase, fds[0], clientRequestingServerName)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SNIClient client(std::move(clientSock)); SNIServer server( std::move(serverSock), dfServerCtx, hskServerCtx, serverExpectedServerName); eventBase.loop(); EXPECT_TRUE(!client.serverNameMatch); EXPECT_TRUE(!server.serverNameMatch); } /** * 1. Client sends TLSEXT_HOSTNAME in client hello. * 2. We then change the serverName. * 3. We expect that we get 'false' as the result for serNameMatch. */ TEST(AsyncSSLSocketTest, SNITestChangeServerName) { EventBase eventBase; std::shared_ptr clientCtx(new SSLContext); std::shared_ptr dfServerCtx(new SSLContext); // Use the same SSLContext to continue the handshake after // tlsext_hostname match. std::shared_ptr hskServerCtx(dfServerCtx); const std::string serverName("xyz.newdev.facebook.com"); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); // Change the server name std::string newName("new.com"); clientSock->setServerName(newName); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SNIClient client(std::move(clientSock)); SNIServer server( std::move(serverSock), dfServerCtx, hskServerCtx, serverName); eventBase.loop(); EXPECT_TRUE(!client.serverNameMatch); } /** * 1. Client does not send TLSEXT_HOSTNAME in client hello. * 2. Server does not send back TLSEXT_HOSTNAME in server hello. */ TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) { EventBase eventBase; std::shared_ptr clientCtx(new SSLContext); std::shared_ptr dfServerCtx(new SSLContext); // Use the same SSLContext to continue the handshake after // tlsext_hostname match. std::shared_ptr hskServerCtx(dfServerCtx); const std::string serverExpectedServerName("xyz.newdev.facebook.com"); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SNIClient client(std::move(clientSock)); SNIServer server( std::move(serverSock), dfServerCtx, hskServerCtx, serverExpectedServerName); eventBase.loop(); EXPECT_TRUE(!client.serverNameMatch); EXPECT_TRUE(!server.serverNameMatch); } #endif /** * Test SSL client socket */ TEST(AsyncSSLSocketTest, SSLClientTest) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL client EventBase eventBase; auto client = std::make_shared(&eventBase, server.getAddress(), 1); client->connect(); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); EXPECT_EQ(client->getMiss(), 1); EXPECT_EQ(client->getHit(), 0); cerr << "SSLClientTest test completed" << endl; } /** * Test SSL client socket session re-use */ TEST(AsyncSSLSocketTest, SSLClientTestReuse) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL client EventBase eventBase; auto client = std::make_shared(&eventBase, server.getAddress(), 10); client->connect(); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); EXPECT_EQ(client->getMiss(), 1); EXPECT_EQ(client->getHit(), 9); cerr << "SSLClientTestReuse test completed" << endl; } /** * Test SSL client socket timeout */ TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) { // Start listening on a local port EmptyReadCallback readCallback; HandshakeCallback handshakeCallback( &readCallback, HandshakeCallback::EXPECT_ERROR); HandshakeTimeoutCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL client EventBase eventBase; auto client = std::make_shared(&eventBase, server.getAddress(), 1, 10); client->connect(true /* write before connect completes */); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); usleep(100000); // This is checking that the connectError callback precedes any queued // writeError callbacks. This matches AsyncSocket's behavior EXPECT_EQ(client->getWriteAfterConnectErrors(), 1); EXPECT_EQ(client->getErrors(), 1); EXPECT_EQ(client->getMiss(), 0); EXPECT_EQ(client->getHit(), 0); cerr << "SSLClientTimeoutTest test completed" << endl; } // The next 3 tests need an FB-only extension, and will fail without it #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP /** * Test SSL server async cache */ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback); TestSSLAsyncCacheServer server(&acceptCallback); // Set up SSL client EventBase eventBase; auto client = std::make_shared(&eventBase, server.getAddress(), 10, 500); client->connect(); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); EXPECT_EQ(server.getAsyncCallbacks(), 18); EXPECT_EQ(server.getAsyncLookups(), 9); EXPECT_EQ(client->getMiss(), 10); EXPECT_EQ(client->getHit(), 0); cerr << "SSLServerAsyncCacheTest test completed" << endl; } /** * Test SSL server accept timeout with cache path */ TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50); TestSSLAsyncCacheServer server(&acceptCallback); // Set up SSL client EventBase eventBase; // only do a TCP connect std::shared_ptr sock = AsyncSocket::newSocket(&eventBase); sock->connect(nullptr, server.getAddress()); EmptyReadCallback clientReadCallback; clientReadCallback.tcpSocket_ = sock; sock->setReadCB(&clientReadCallback); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); EXPECT_EQ(readCallback.state, STATE_WAITING); cerr << "SSLServerTimeoutTest test completed" << endl; } /** * Test SSL server accept timeout with cache path */ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50); TestSSLAsyncCacheServer server(&acceptCallback); // Set up SSL client EventBase eventBase; auto client = std::make_shared(&eventBase, server.getAddress(), 2); client->connect(); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); EXPECT_EQ(server.getAsyncCallbacks(), 1); EXPECT_EQ(server.getAsyncLookups(), 1); EXPECT_EQ(client->getErrors(), 1); EXPECT_EQ(client->getMiss(), 1); EXPECT_EQ(client->getHit(), 0); cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl; } /** * Test SSL server accept timeout with cache path */ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback( &readCallback, HandshakeCallback::EXPECT_ERROR); SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback); TestSSLAsyncCacheServer server(&acceptCallback, 500); // Set up SSL client EventBase eventBase; auto client = std::make_shared(&eventBase, server.getAddress(), 2, 100); client->connect(); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); server.getEventBase().runInEventBaseThread( [&handshakeCallback] { handshakeCallback.closeSocket(); }); // give time for the cache lookup to come back and find it closed handshakeCallback.waitForHandshake(); EXPECT_EQ(server.getAsyncCallbacks(), 1); EXPECT_EQ(server.getAsyncLookups(), 1); EXPECT_EQ(client->getErrors(), 1); EXPECT_EQ(client->getMiss(), 1); EXPECT_EQ(client->getHit(), 0); cerr << "SSLServerCacheCloseTest test completed" << endl; } #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP /** * Verify Client Ciphers obtained using SSL MSG Callback. */ TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); serverCtx->loadTrustedCertificates(kTestCA); serverCtx->loadClientCAList(kTestCA); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); clientCtx->ciphers("AES256-SHA:AES128-SHA"); clientCtx->loadPrivateKey(kTestKey); clientCtx->loadCertificate(kTestCert); clientCtx->loadTrustedCertificates(kTestCA); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), true, true); SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true); eventBase.loop(); #if defined(OPENSSL_IS_BORINGSSL) EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA"); #else EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff"); #endif EXPECT_EQ(server.chosenCipher_, "AES256-SHA"); EXPECT_TRUE(client.handshakeVerify_); EXPECT_TRUE(client.handshakeSuccess_); EXPECT_TRUE(!client.handshakeError_); EXPECT_TRUE(server.handshakeVerify_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_TRUE(!server.handshakeError_); } /** * Verify that server is able to get client cert by getPeerCert() API. */ TEST(AsyncSSLSocketTest, GetClientCertificate) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); serverCtx->loadTrustedCertificates(kClientTestCA); serverCtx->loadClientCAList(kClientTestCA); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); clientCtx->ciphers("AES256-SHA:AES128-SHA"); clientCtx->loadPrivateKey(kClientTestKey); clientCtx->loadCertificate(kClientTestCert); clientCtx->loadTrustedCertificates(kTestCA); std::array fds; getfds(fds.data()); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), true, true); SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true); eventBase.loop(); // Handshake should succeed. EXPECT_TRUE(client.handshakeSuccess_); EXPECT_TRUE(server.handshakeSuccess_); // Reclaim the sockets from SSLHandshakeBase. auto cliSocket = std::move(client).moveSocket(); auto srvSocket = std::move(server).moveSocket(); // Client cert retrieved from server side. folly::ssl::X509UniquePtr serverPeerCert = srvSocket->getPeerCert(); CHECK(serverPeerCert); // Client cert retrieved from client side. const X509* clientSelfCert = cliSocket->getSelfCert(); CHECK(clientSelfCert); // The two certs should be the same. EXPECT_EQ(0, X509_cmp(clientSelfCert, serverPeerCert.get())); } TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) { EventBase eventBase; auto ctx = std::make_shared(); int fds[2]; getfds(fds); int bufLen = 42; uint8_t majorVersion = 18; uint8_t minorVersion = 25; // Create callback buf auto buf = IOBuf::create(bufLen); buf->append(bufLen); folly::io::RWPrivateCursor cursor(buf.get()); cursor.write(SSL3_MT_CLIENT_HELLO); cursor.write(0); cursor.write(38); cursor.write(majorVersion); cursor.write(minorVersion); cursor.skip(32); cursor.write(0); SSL* ssl = ctx->createSSL(); SCOPE_EXIT { SSL_free(ssl); }; AsyncSSLSocket::UniquePtr sock( new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); sock->enableClientHelloParsing(); // Test client hello parsing in one packet AsyncSSLSocket::clientHelloParsingCallback( 0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get()); buf.reset(); auto parsedClientHello = sock->getClientHelloInfo(); EXPECT_TRUE(parsedClientHello != nullptr); EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion); EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion); } TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) { EventBase eventBase; auto ctx = std::make_shared(); int fds[2]; getfds(fds); int bufLen = 42; uint8_t majorVersion = 18; uint8_t minorVersion = 25; // Create callback buf auto buf = IOBuf::create(bufLen); buf->append(bufLen); folly::io::RWPrivateCursor cursor(buf.get()); cursor.write(SSL3_MT_CLIENT_HELLO); cursor.write(0); cursor.write(38); cursor.write(majorVersion); cursor.write(minorVersion); cursor.skip(32); cursor.write(0); SSL* ssl = ctx->createSSL(); SCOPE_EXIT { SSL_free(ssl); }; AsyncSSLSocket::UniquePtr sock( new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); sock->enableClientHelloParsing(); // Test parsing with two packets with first packet size < 3 auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2); AsyncSSLSocket::clientHelloParsingCallback( 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), ssl, sock.get()); bufCopy.reset(); bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2); AsyncSSLSocket::clientHelloParsingCallback( 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), ssl, sock.get()); bufCopy.reset(); auto parsedClientHello = sock->getClientHelloInfo(); EXPECT_TRUE(parsedClientHello != nullptr); EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion); EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion); } TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) { EventBase eventBase; auto ctx = std::make_shared(); int fds[2]; getfds(fds); int bufLen = 42; uint8_t majorVersion = 18; uint8_t minorVersion = 25; // Create callback buf auto buf = IOBuf::create(bufLen); buf->append(bufLen); folly::io::RWPrivateCursor cursor(buf.get()); cursor.write(SSL3_MT_CLIENT_HELLO); cursor.write(0); cursor.write(38); cursor.write(majorVersion); cursor.write(minorVersion); cursor.skip(32); cursor.write(0); SSL* ssl = ctx->createSSL(); SCOPE_EXIT { SSL_free(ssl); }; AsyncSSLSocket::UniquePtr sock( new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); sock->enableClientHelloParsing(); // Test parsing with multiple small packets for (std::size_t i = 0; i < buf->length(); i += 3) { auto bufCopy = folly::IOBuf::copyBuffer( buf->data() + i, std::min((std::size_t)3, buf->length() - i)); AsyncSSLSocket::clientHelloParsingCallback( 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), ssl, sock.get()); bufCopy.reset(); } auto parsedClientHello = sock->getClientHelloInfo(); EXPECT_TRUE(parsedClientHello != nullptr); EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion); EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion); } /** * Verify sucessful behavior of SSL certificate validation. */ TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) { EventBase eventBase; auto clientCtx = std::make_shared(); auto dfServerCtx = std::make_shared(); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), true, true); clientCtx->loadTrustedCertificates(kTestCA); SSLHandshakeServer server(std::move(serverSock), true, true); eventBase.loop(); EXPECT_TRUE(client.handshakeVerify_); EXPECT_TRUE(client.handshakeSuccess_); EXPECT_TRUE(!client.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_TRUE(!server.handshakeVerify_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_TRUE(!server.handshakeError_); EXPECT_LE(0, server.handshakeTime.count()); } /** * Verify that the client's verification callback is able to fail SSL * connection establishment. */ TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) { EventBase eventBase; auto clientCtx = std::make_shared(); auto dfServerCtx = std::make_shared(); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), true, false); clientCtx->loadTrustedCertificates(kTestCA); SSLHandshakeServer server(std::move(serverSock), true, true); eventBase.loop(); EXPECT_TRUE(client.handshakeVerify_); EXPECT_TRUE(!client.handshakeSuccess_); EXPECT_TRUE(client.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_TRUE(!server.handshakeVerify_); EXPECT_TRUE(!server.handshakeSuccess_); EXPECT_TRUE(server.handshakeError_); EXPECT_LE(0, server.handshakeTime.count()); } /** * Verify that the options in SSLContext can be overridden in * sslConnect/Accept.i.e specifying that no validation should be performed * allows an otherwise-invalid certificate to be accepted and doesn't fire * the validation callback. */ TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) { EventBase eventBase; auto clientCtx = std::make_shared(); auto dfServerCtx = std::make_shared(); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClientNoVerify client(std::move(clientSock), false, false); clientCtx->loadTrustedCertificates(kTestCA); SSLHandshakeServerNoVerify server(std::move(serverSock), false, false); eventBase.loop(); EXPECT_TRUE(!client.handshakeVerify_); EXPECT_TRUE(client.handshakeSuccess_); EXPECT_TRUE(!client.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_TRUE(!server.handshakeVerify_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_TRUE(!server.handshakeError_); EXPECT_LE(0, server.handshakeTime.count()); } /** * Verify that the options in SSLContext can be overridden in * sslConnect/Accept. Enable verification even if context says otherwise. * Test requireClientCert with client cert */ TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); serverCtx->loadTrustedCertificates(kTestCA); serverCtx->loadClientCAList(kTestCA); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); clientCtx->loadPrivateKey(kTestKey); clientCtx->loadCertificate(kTestCert); clientCtx->loadTrustedCertificates(kTestCA); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); SSLHandshakeClientDoVerify client(std::move(clientSock), true, true); SSLHandshakeServerDoVerify server(std::move(serverSock), true, true); eventBase.loop(); EXPECT_TRUE(client.handshakeVerify_); EXPECT_TRUE(client.handshakeSuccess_); EXPECT_FALSE(client.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_TRUE(server.handshakeVerify_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_FALSE(server.handshakeError_); EXPECT_LE(0, server.handshakeTime.count()); } /** * Verify that the client's verification callback is able to override * the preverification failure and allow a successful connection. */ TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) { EventBase eventBase; auto clientCtx = std::make_shared(); auto dfServerCtx = std::make_shared(); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, true); SSLHandshakeServer server(std::move(serverSock), true, true); eventBase.loop(); EXPECT_TRUE(client.handshakeVerify_); EXPECT_TRUE(client.handshakeSuccess_); EXPECT_TRUE(!client.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_TRUE(!server.handshakeVerify_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_TRUE(!server.handshakeError_); EXPECT_LE(0, server.handshakeTime.count()); } /** * Verify that specifying that no validation should be performed allows an * otherwise-invalid certificate to be accepted and doesn't fire the validation * callback. */ TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) { EventBase eventBase; auto clientCtx = std::make_shared(); auto dfServerCtx = std::make_shared(); int fds[2]; getfds(fds); getctx(clientCtx, dfServerCtx); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServer server(std::move(serverSock), false, false); eventBase.loop(); EXPECT_TRUE(!client.handshakeVerify_); EXPECT_TRUE(client.handshakeSuccess_); EXPECT_TRUE(!client.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_TRUE(!server.handshakeVerify_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_TRUE(!server.handshakeError_); EXPECT_LE(0, server.handshakeTime.count()); } /** * Test requireClientCert with client cert */ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->setVerificationOption( SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); serverCtx->loadTrustedCertificates(kTestCA); serverCtx->loadClientCAList(kTestCA); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); clientCtx->loadPrivateKey(kTestKey); clientCtx->loadCertificate(kTestCert); clientCtx->loadTrustedCertificates(kTestCA); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), true, true); SSLHandshakeServer server(std::move(serverSock), true, true); eventBase.loop(); EXPECT_TRUE(client.handshakeVerify_); EXPECT_TRUE(client.handshakeSuccess_); EXPECT_FALSE(client.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_TRUE(server.handshakeVerify_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_FALSE(server.handshakeError_); EXPECT_LE(0, server.handshakeTime.count()); // check certificates auto clientSsl = std::move(client).moveSocket(); auto serverSsl = std::move(server).moveSocket(); auto clientPeer = clientSsl->getPeerCertificate(); auto clientSelf = clientSsl->getSelfCertificate(); auto serverPeer = serverSsl->getPeerCertificate(); auto serverSelf = serverSsl->getSelfCertificate(); EXPECT_NE(clientPeer, nullptr); EXPECT_NE(clientSelf, nullptr); EXPECT_NE(serverPeer, nullptr); EXPECT_NE(serverSelf, nullptr); EXPECT_EQ(clientPeer->getIdentity(), serverSelf->getIdentity()); EXPECT_EQ(clientSelf->getIdentity(), serverPeer->getIdentity()); } /** * Test requireClientCert with no client cert */ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->setVerificationOption( SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); serverCtx->loadTrustedCertificates(kTestCA); serverCtx->loadClientCAList(kTestCA); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServer server(std::move(serverSock), false, false); eventBase.loop(); EXPECT_FALSE(server.handshakeVerify_); EXPECT_FALSE(server.handshakeSuccess_); EXPECT_TRUE(server.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_LE(0, server.handshakeTime.count()); } /** * Test OpenSSL 1.1.0's async functionality */ #if FOLLY_OPENSSL_IS_110 static void makeNonBlockingPipe(int pipefds[2]) { if (pipe(pipefds) != 0) { throw std::runtime_error("Cannot create pipe"); } if (::fcntl(pipefds[0], F_SETFL, O_NONBLOCK) != 0) { throw std::runtime_error("Cannot set pipe to nonblocking"); } if (::fcntl(pipefds[1], F_SETFL, O_NONBLOCK) != 0) { throw std::runtime_error("Cannot set pipe to nonblocking"); } } // Custom RSA private key encryption method static int kRSAExIndex = -1; static int kRSAEvbExIndex = -1; static int kRSASocketExIndex = -1; static constexpr StringPiece kEngineId = "AsyncSSLSocketTest"; static int customRsaPrivEnc( int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding) { LOG(INFO) << "rsa_priv_enc"; EventBase* asyncJobEvb = reinterpret_cast(RSA_get_ex_data(rsa, kRSAEvbExIndex)); CHECK(asyncJobEvb); RSA* actualRSA = reinterpret_cast(RSA_get_ex_data(rsa, kRSAExIndex)); CHECK(actualRSA); AsyncSSLSocket* socket = reinterpret_cast( RSA_get_ex_data(rsa, kRSASocketExIndex)); ASYNC_JOB* job = ASYNC_get_current_job(); if (job == nullptr) { throw std::runtime_error("Expected call in job context"); } ASYNC_WAIT_CTX* waitctx = ASYNC_get_wait_ctx(job); OSSL_ASYNC_FD pipefds[2] = {0, 0}; makeNonBlockingPipe(pipefds); if (!ASYNC_WAIT_CTX_set_wait_fd( waitctx, kEngineId.data(), pipefds[0], nullptr, nullptr)) { throw std::runtime_error("Cannot set wait fd"); } int ret = 0; int* retptr = &ret; auto asyncPipeWriter = folly::AsyncPipeWriter::newWriter(asyncJobEvb, pipefds[1]); asyncJobEvb->runInEventBaseThread([retptr = retptr, flen = flen, from = from, to = to, padding = padding, actualRSA = actualRSA, writer = std::move(asyncPipeWriter), socket = socket]() { LOG(INFO) << "Running job"; if (socket) { LOG(INFO) << "Got a socket passed in, closing it..."; socket->closeNow(); } *retptr = RSA_meth_get_priv_enc(RSA_PKCS1_OpenSSL())( flen, from, to, actualRSA, padding); LOG(INFO) << "Finished job, writing to pipe"; uint8_t byte = *retptr > 0 ? 1 : 0; writer->write(nullptr, &byte, 1); }); LOG(INFO) << "About to pause job"; ASYNC_pause_job(); LOG(INFO) << "Resumed job with ret: " << ret; return ret; } void rsaFree(void*, void* ptr, CRYPTO_EX_DATA*, int, long, void*) { LOG(INFO) << "RSA_free is called with ptr " << std::hex << ptr; if (ptr == nullptr) { LOG(INFO) << "Returning early from rsaFree because ptr is null"; return; } RSA* rsa = (RSA*)ptr; auto meth = RSA_get_method(rsa); if (meth != RSA_get_default_method()) { auto nonconst = const_cast(meth); RSA_meth_free(nonconst); RSA_set_method(rsa, RSA_get_default_method()); } RSA_free(rsa); } struct RSAPointers { RSA* actualrsa{nullptr}; RSA* dummyrsa{nullptr}; RSA_METHOD* meth{nullptr}; }; inline void RSAPointersFree(RSAPointers* p) { if (p->meth && p->dummyrsa && RSA_get_method(p->dummyrsa) == p->meth) { RSA_set_method(p->dummyrsa, RSA_get_default_method()); } if (p->meth) { LOG(INFO) << "Freeing meth"; RSA_meth_free(p->meth); } if (p->actualrsa) { LOG(INFO) << "Freeing actualrsa"; RSA_free(p->actualrsa); } if (p->dummyrsa) { LOG(INFO) << "Freeing dummyrsa"; RSA_free(p->dummyrsa); } delete p; } using RSAPointersDeleter = folly::static_function_deleter; std::unique_ptr setupCustomRSA(const char* certPath, const char* keyPath, EventBase* jobEvb) { auto certPEM = getFileAsBuf(certPath); auto keyPEM = getFileAsBuf(keyPath); ssl::BioUniquePtr certBio( BIO_new_mem_buf((void*)certPEM.data(), certPEM.size())); ssl::BioUniquePtr keyBio( BIO_new_mem_buf((void*)keyPEM.data(), keyPEM.size())); ssl::X509UniquePtr cert( PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr)); ssl::EvpPkeyUniquePtr evpPkey( PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr)); ssl::EvpPkeyUniquePtr publicEvpPkey(X509_get_pubkey(cert.get())); std::unique_ptr ret(new RSAPointers()); RSA* actualrsa = EVP_PKEY_get1_RSA(evpPkey.get()); LOG(INFO) << "actualrsa ptr " << std::hex << (void*)actualrsa; RSA* dummyrsa = EVP_PKEY_get1_RSA(publicEvpPkey.get()); if (dummyrsa == nullptr) { throw std::runtime_error("Couldn't get RSA cert public factors"); } RSA_METHOD* meth = RSA_meth_dup(RSA_get_default_method()); if (meth == nullptr || RSA_meth_set1_name(meth, "Async RSA method") == 0 || RSA_meth_set_priv_enc(meth, customRsaPrivEnc) == 0 || RSA_meth_set_flags(meth, RSA_METHOD_FLAG_NO_CHECK) == 0) { throw std::runtime_error("Cannot create async RSA_METHOD"); } RSA_set_method(dummyrsa, meth); RSA_set_flags(dummyrsa, RSA_FLAG_EXT_PKEY); kRSAExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); kRSAEvbExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); kRSASocketExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); CHECK_NE(kRSAExIndex, -1); CHECK_NE(kRSAEvbExIndex, -1); CHECK_NE(kRSASocketExIndex, -1); RSA_set_ex_data(dummyrsa, kRSAExIndex, actualrsa); RSA_set_ex_data(dummyrsa, kRSAEvbExIndex, jobEvb); ret->actualrsa = actualrsa; ret->dummyrsa = dummyrsa; ret->meth = meth; return ret; } // TODO: disabled with ASAN doesn't play nice with ASYNC for some reason #ifndef FOLLY_SANITIZE_ADDRESS TEST(AsyncSSLSocketTest, OpenSSL110AsyncTest) { ASYNC_init_thread(1, 1); EventBase eventBase; ScopedEventBaseThread jobEvbThread; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadCertificate(kTestCert); serverCtx->loadTrustedCertificates(kTestCA); serverCtx->loadClientCAList(kTestCA); auto rsaPointers = setupCustomRSA(kTestCert, kTestKey, jobEvbThread.getEventBase()); CHECK(rsaPointers->dummyrsa); // up-refs dummyrsa SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa); SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServer server(std::move(serverSock), false, false); eventBase.loop(); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_TRUE(client.handshakeSuccess_); ASYNC_cleanup_thread(); } TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) { ASYNC_init_thread(1, 1); EventBase eventBase; ScopedEventBaseThread jobEvbThread; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadCertificate(kTestCert); serverCtx->loadTrustedCertificates(kTestCA); serverCtx->loadClientCAList(kTestCA); // Set the wrong key for the cert auto rsaPointers = setupCustomRSA(kTestCert, kClientTestKey, jobEvbThread.getEventBase()); CHECK(rsaPointers->dummyrsa); SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa); SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServer server(std::move(serverSock), false, false); eventBase.loop(); EXPECT_TRUE(server.handshakeError_); EXPECT_TRUE(client.handshakeError_); ASYNC_cleanup_thread(); } TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestClosedWithCallbackPending) { ASYNC_init_thread(1, 1); EventBase eventBase; ScopedEventBaseThread jobEvbThread; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadCertificate(kTestCert); serverCtx->loadTrustedCertificates(kTestCA); serverCtx->loadClientCAList(kTestCA); auto rsaPointers = setupCustomRSA(kTestCert, kTestKey, jobEvbThread.getEventBase()); CHECK(rsaPointers->dummyrsa); // up-refs dummyrsa SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa); SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); RSA_set_ex_data(rsaPointers->dummyrsa, kRSASocketExIndex, serverSock.get()); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServer server(std::move(serverSock), false, false); eventBase.loop(); EXPECT_TRUE(server.handshakeError_); EXPECT_TRUE(client.handshakeError_); ASYNC_cleanup_thread(); } #endif // FOLLY_SANITIZE_ADDRESS #endif // FOLLY_OPENSSL_IS_110 TEST(AsyncSSLSocketTest, LoadCertFromMemory) { using folly::ssl::OpenSSLUtils; auto cert = getFileAsBuf(kTestCert); auto key = getFileAsBuf(kTestKey); ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem())); BIO_write(certBio.get(), cert.data(), cert.size()); ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem())); BIO_write(keyBio.get(), key.data(), key.size()); // Create SSL structs from buffers to get properties ssl::X509UniquePtr certStruct( PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr)); ssl::EvpPkeyUniquePtr keyStruct( PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr)); certBio = nullptr; keyBio = nullptr; auto origCommonName = OpenSSLUtils::getCommonName(certStruct.get()); auto origKeySize = EVP_PKEY_bits(keyStruct.get()); certStruct = nullptr; keyStruct = nullptr; auto ctx = std::make_shared(); ctx->loadPrivateKeyFromBufferPEM(key); ctx->loadCertificateFromBufferPEM(cert); ctx->loadTrustedCertificates(kTestCA); ssl::SSLUniquePtr ssl(ctx->createSSL()); auto newCert = SSL_get_certificate(ssl.get()); auto newKey = SSL_get_privatekey(ssl.get()); // Get properties from SSL struct auto newCommonName = OpenSSLUtils::getCommonName(newCert); auto newKeySize = EVP_PKEY_bits(newKey); // Check that the key and cert have the expected properties EXPECT_EQ(origCommonName, newCommonName); EXPECT_EQ(origKeySize, newKeySize); } TEST(AsyncSSLSocketTest, MinWriteSizeTest) { EventBase eb; // Set up SSL context. auto sslContext = std::make_shared(); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); // create SSL socket AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb)); EXPECT_EQ(1500, socket->getMinWriteSize()); socket->setMinWriteSize(0); EXPECT_EQ(0, socket->getMinWriteSize()); socket->setMinWriteSize(50000); EXPECT_EQ(50000, socket->getMinWriteSize()); } class ReadCallbackTerminator : public ReadCallback { public: ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb) : ReadCallback(wcb), base_(base) {} // Do not write data back, terminate the loop. void readDataAvailable(size_t len) noexcept override { std::cerr << "readDataAvailable, len " << len << std::endl; currentBuffer.length = len; buffers.push_back(currentBuffer); currentBuffer.reset(); state = STATE_SUCCEEDED; socket_->setReadCB(nullptr); base_->terminateLoopSoon(); } private: EventBase* base_; }; /** * Test a full unencrypted codepath */ TEST(AsyncSSLSocketTest, UnencryptedTest) { EventBase base; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); int fds[2]; getfds(fds); getctx(clientCtx, serverCtx); auto client = AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true); auto server = AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true); ReadCallbackTerminator readCallback(&base, nullptr); server->setReadCB(&readCallback); readCallback.setSocket(server); uint8_t buf[128]; memset(buf, 'a', sizeof(buf)); client->write(nullptr, buf, sizeof(buf)); // Check that bytes are unencrypted char c; EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK)); EXPECT_EQ('a', c); EventBaseAborter eba(&base, 3000); base.loop(); EXPECT_EQ(1, readCallback.buffers.size()); EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState()); server->setReadCB(&readCallback); // Unencrypted server->sslAccept(nullptr); client->sslConn(nullptr); // Do NOT wait for handshake, writing should be queued and happen after client->write(nullptr, buf, sizeof(buf)); // Check that bytes are *not* unencrypted char c2; EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK)); EXPECT_NE('a', c2); base.loop(); EXPECT_EQ(2, readCallback.buffers.size()); EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState()); } TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) { auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); getctx(clientCtx, serverCtx); WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); EventBase evb; std::shared_ptr socket = AsyncSSLSocket::newSocket(clientCtx, &evb, true); socket->connect(nullptr, server.getAddress(), 0); evb.loop(); EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState()); socket->sslConn(nullptr); evb.loop(); EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState()); // write() std::array buf; memset(buf.data(), 'a', buf.size()); socket->write(nullptr, buf.data(), buf.size()); socket->close(); } /** * Test acceptrunner in various situations */ TEST(AsyncSSLSocketTest, SSLAcceptRunnerBasic) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); clientCtx->loadTrustedCertificates(kTestCA); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); serverCtx->sslAcceptRunner(std::make_unique(&eventBase)); SSLHandshakeClient client(std::move(clientSock), true, true); SSLHandshakeServer server(std::move(serverSock), true, true); eventBase.loop(); EXPECT_TRUE(client.handshakeSuccess_); EXPECT_FALSE(client.handshakeError_); EXPECT_LE(0, client.handshakeTime.count()); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_FALSE(server.handshakeError_); EXPECT_LE(0, server.handshakeTime.count()); } TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptError) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); clientCtx->loadTrustedCertificates(kTestCA); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); serverCtx->sslAcceptRunner( std::make_unique(&eventBase)); SSLHandshakeClient client(std::move(clientSock), true, true); SSLHandshakeServer server(std::move(serverSock), true, true); eventBase.loop(); EXPECT_FALSE(client.handshakeSuccess_); EXPECT_TRUE(client.handshakeError_); EXPECT_FALSE(server.handshakeSuccess_); EXPECT_TRUE(server.handshakeError_); } TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptClose) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); clientCtx->loadTrustedCertificates(kTestCA); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); serverCtx->sslAcceptRunner( std::make_unique(&eventBase, serverSock.get())); SSLHandshakeClient client(std::move(clientSock), true, true); SSLHandshakeServer server(std::move(serverSock), true, true); eventBase.loop(); EXPECT_FALSE(client.handshakeSuccess_); EXPECT_TRUE(client.handshakeError_); EXPECT_FALSE(server.handshakeSuccess_); EXPECT_TRUE(server.handshakeError_); } TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) { EventBase eventBase; auto clientCtx = std::make_shared(); auto serverCtx = std::make_shared(); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadPrivateKey(kTestKey); serverCtx->loadCertificate(kTestCert); clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY); clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); clientCtx->loadTrustedCertificates(kTestCA); int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), true, true); SSLHandshakeServer server(std::move(serverSock), true, true); serverCtx->sslAcceptRunner( std::make_unique(&eventBase, &server)); eventBase.loop(); EXPECT_FALSE(client.handshakeSuccess_); EXPECT_TRUE(client.handshakeError_); EXPECT_FALSE(server.handshakeSuccess_); EXPECT_TRUE(server.handshakeError_); } TEST(AsyncSSLSocketTest, ConnResetErrorString) { // Start listening on a local port WriteCallbackBase writeCallback; WriteErrorCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback( &readCallback, HandshakeCallback::EXPECT_ERROR); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); auto socket = std::make_shared(server.getAddress(), nullptr); socket->open(); uint8_t buf[3] = {0x16, 0x03, 0x01}; socket->write(buf, sizeof(buf)); socket->closeWithReset(); handshakeCallback.waitForHandshake(); EXPECT_NE( handshakeCallback.errorString_.find("Network error"), std::string::npos); EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos); } TEST(AsyncSSLSocketTest, ConnEOFErrorString) { // Start listening on a local port WriteCallbackBase writeCallback; WriteErrorCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback( &readCallback, HandshakeCallback::EXPECT_ERROR); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); auto socket = std::make_shared(server.getAddress(), nullptr); socket->open(); uint8_t buf[3] = {0x16, 0x03, 0x01}; socket->write(buf, sizeof(buf)); socket->close(); handshakeCallback.waitForHandshake(); #if FOLLY_OPENSSL_IS_110 EXPECT_NE( handshakeCallback.errorString_.find("Network error"), std::string::npos); #else EXPECT_NE( handshakeCallback.errorString_.find("Connection EOF"), std::string::npos); #endif } TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) { // Start listening on a local port WriteCallbackBase writeCallback; WriteErrorCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback( &readCallback, HandshakeCallback::EXPECT_ERROR); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); auto socket = std::make_shared(server.getAddress(), nullptr); socket->open(); uint8_t buf[256] = {0x16, 0x03}; memset(buf + 2, 'a', sizeof(buf) - 2); socket->write(buf, sizeof(buf)); socket->close(); handshakeCallback.waitForHandshake(); EXPECT_NE( handshakeCallback.errorString_.find("SSL routines"), std::string::npos); #if defined(OPENSSL_IS_BORINGSSL) EXPECT_NE( handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"), std::string::npos); #elif FOLLY_OPENSSL_IS_110 EXPECT_NE( handshakeCallback.errorString_.find("packet length too long"), std::string::npos); #else EXPECT_NE( handshakeCallback.errorString_.find("unknown protocol"), std::string::npos); #endif } TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) { using folly::ssl::OpenSSLUtils; EXPECT_EQ( OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384"); // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), ""); // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), ""); } #if defined __linux__ /** * Ensure TransparentTLS flag is disabled with AsyncSSLSocket */ TEST(AsyncSSLSocketTest, TTLSDisabled) { // clear all setsockopt tracking history globalStatic.reset(); // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback, false); // Set up SSL context. auto sslContext = std::make_shared(); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->open(); EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getSocketFD())); // write() std::array buf; memset(buf.data(), 'a', buf.size()); socket->write(buf.data(), buf.size()); // close() socket->close(); } #endif #if FOLLY_ALLOW_TFO class MockAsyncTFOSSLSocket : public AsyncSSLSocket { public: using UniquePtr = std::unique_ptr; explicit MockAsyncTFOSSLSocket( std::shared_ptr sslCtx, EventBase* evb) : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {} MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags)); }; #if defined __linux__ /** * Ensure TransparentTLS flag is disabled with AsyncSSLSocket + TFO */ TEST(AsyncSSLSocketTest, TTLSDisabledWithTFO) { // clear all setsockopt tracking history globalStatic.reset(); // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback, true); // Set up SSL context. auto sslContext = std::make_shared(); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->enableTFO(); socket->open(); EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getSocketFD())); // write() std::array buf; memset(buf.data(), 'a', buf.size()); socket->write(buf.data(), buf.size()); // close() socket->close(); } #endif /** * Test connecting to, writing to, reading from, and closing the * connection to the SSL server with TFO. */ TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback, true); // Set up SSL context. auto sslContext = std::make_shared(); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->enableTFO(); socket->open(); // write() std::array buf; memset(buf.data(), 'a', buf.size()); socket->write(buf.data(), buf.size()); // read() std::array readbuf; uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size()); EXPECT_EQ(bytesRead, 128); EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); // close() socket->close(); } /** * Test connecting to, writing to, reading from, and closing the * connection to the SSL server with TFO. */ TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback, false); // Set up SSL context. auto sslContext = std::make_shared(); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->enableTFO(); socket->open(); // write() std::array buf; memset(buf.data(), 'a', buf.size()); socket->write(buf.data(), buf.size()); // read() std::array readbuf; uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size()); EXPECT_EQ(bytesRead, 128); EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); // close() socket->close(); } class ConnCallback : public AsyncSocket::ConnectCallback { public: void connectSuccess() noexcept override { state = State::SUCCESS; } void connectErr(const AsyncSocketException& ex) noexcept override { state = State::ERROR; error = ex.what(); } enum class State { WAITING, SUCCESS, ERROR }; State state{State::WAITING}; std::string error; }; template MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback( EventBase* evb, const SocketAddress& address, Cardinality cardinality) { // Set up SSL context. auto sslContext = std::make_shared(); // connect auto socket = MockAsyncTFOSSLSocket::UniquePtr( new MockAsyncTFOSSLSocket(sslContext, evb)); socket->enableTFO(); EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) .Times(cardinality) .WillOnce(Invoke([&](int fd, struct msghdr*, int) { sockaddr_storage addr; auto len = address.getAddress(&addr); return connect(fd, (const struct sockaddr*)&addr, len); })); return socket; } TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) { // Start listening on a local port WriteCallbackBase writeCallback; ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback, true); EventBase evb; auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1); ConnCallback ccb; socket->connect(&ccb, server.getAddress(), 30); evb.loop(); EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state); evb.runInEventBaseThread([&] { socket->detachEventBase(); }); evb.loop(); BlockingSocket sock(std::move(socket)); // write() std::array buf; memset(buf.data(), 'a', buf.size()); sock.write(buf.data(), buf.size()); // read() std::array readbuf; uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size()); EXPECT_EQ(bytesRead, 128); EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); // close() sock.close(); } #if !defined(OPENSSL_IS_BORINGSSL) TEST(AsyncSSLSocketTest, ConnectTFOTimeout) { // Start listening on a local port ConnectTimeoutCallback acceptCallback; TestSSLServer server(&acceptCallback, true); // Set up SSL context. auto sslContext = std::make_shared(); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->enableTFO(); EXPECT_THROW( socket->open(std::chrono::milliseconds(20)), AsyncSocketException); } #endif #if !defined(OPENSSL_IS_BORINGSSL) TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) { // Start listening on a local port ConnectTimeoutCallback acceptCallback; TestSSLServer server(&acceptCallback, true); EventBase evb; auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1)); ConnCallback ccb; // Set a short timeout socket->connect(&ccb, server.getAddress(), 1); evb.loop(); EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); } #endif TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) { // Start listening on a local port EmptyReadCallback readCallback; HandshakeCallback handshakeCallback( &readCallback, HandshakeCallback::EXPECT_ERROR); HandshakeTimeoutCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback, true); EventBase evb; auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1)); ConnCallback ccb; socket->connect(&ccb, server.getAddress(), 100); evb.loop(); EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out")); } TEST(AsyncSSLSocketTest, HandshakeTFORefused) { // Start listening on a local port EventBase evb; // Hopefully nothing is listening on this address SocketAddress addr("127.0.0.1", 65535); auto socket = setupSocketWithFallback(&evb, addr, AtMost(1)); ConnCallback ccb; socket->connect(&ccb, addr, 100); evb.loop(); EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); EXPECT_THAT(ccb.error, testing::HasSubstr("refused")); } TEST(AsyncSSLSocketTest, TestPreReceivedData) { EventBase eventBase; auto clientCtx = std::make_shared(); auto dfServerCtx = std::make_shared(); std::array fds; getfds(fds.data()); getctx(clientCtx, dfServerCtx); AsyncSSLSocket::UniquePtr clientSockPtr( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSSLSocket::UniquePtr serverSockPtr( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); auto clientSock = clientSockPtr.get(); auto serverSock = serverSockPtr.get(); SSLHandshakeClient client(std::move(clientSockPtr), true, true); // Steal some data from the server. std::array buf; auto bytesReceived = recv(fds[1], buf.data(), buf.size(), 0); checkUnixError(bytesReceived, "recv failed"); serverSock->setPreReceivedData( IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived))); SSLHandshakeServer server(std::move(serverSockPtr), true, true); while (!client.handshakeSuccess_ && !client.handshakeError_) { eventBase.loopOnce(); } EXPECT_TRUE(client.handshakeSuccess_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_EQ( serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten()); } TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) { EventBase eventBase; auto clientCtx = std::make_shared(); auto dfServerCtx = std::make_shared(); std::array fds; getfds(fds.data()); getctx(clientCtx, dfServerCtx); AsyncSSLSocket::UniquePtr clientSockPtr( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); AsyncSocket::UniquePtr serverSockPtr(new AsyncSocket(&eventBase, fds[1])); auto clientSock = clientSockPtr.get(); auto serverSock = serverSockPtr.get(); SSLHandshakeClient client(std::move(clientSockPtr), true, true); // Steal some data from the server. std::array buf; auto bytesReceived = recv(fds[1], buf.data(), buf.size(), 0); checkUnixError(bytesReceived, "recv failed"); serverSock->setPreReceivedData( IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived))); AsyncSSLSocket::UniquePtr serverSSLSockPtr( new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true)); auto serverSSLSock = serverSSLSockPtr.get(); SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true); while (!client.handshakeSuccess_ && !client.handshakeError_) { eventBase.loopOnce(); } EXPECT_TRUE(client.handshakeSuccess_); EXPECT_TRUE(server.handshakeSuccess_); EXPECT_EQ( serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten()); } /** * Test overriding the flags passed to "sendmsg()" system call, * and verifying that write requests fail properly. */ TEST(AsyncSSLSocketTest, SendMsgParamsCallback) { // Start listening on a local port SendMsgFlagsCallback msgCallback; ExpectWriteErrorCallback writeCallback(&msgCallback); ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL context. auto sslContext = std::make_shared(); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->open(); // Setting flags to "-1" to trigger "Invalid argument" error // on attempt to use this flags in sendmsg() system call. msgCallback.resetFlags(-1); // write() std::vector buf(128, 'a'); ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size()); // close() socket->close(); cerr << "SendMsgParamsCallback test completed" << endl; } #ifdef FOLLY_HAVE_MSG_ERRQUEUE /** * Test connecting to, writing to, reading from, and closing the * connection to the SSL server. */ TEST(AsyncSSLSocketTest, SendMsgDataCallback) { // This test requires Linux kernel v4.6 or later struct utsname s_uname; memset(&s_uname, 0, sizeof(s_uname)); ASSERT_EQ(uname(&s_uname), 0); int major, minor; folly::StringPiece extra; if (folly::split( '.', std::string(s_uname.release) + ".", major, minor, extra)) { if (major < 4 || (major == 4 && minor < 6)) { LOG(INFO) << "Kernel version: 4.6 and newer required for this test (" << "kernel ver. " << s_uname.release << " detected)."; return; } } // Start listening on a local port SendMsgDataCallback msgCallback; WriteCheckTimestampCallback writeCallback(&msgCallback); ReadCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); // Set up SSL context. auto sslContext = std::make_shared(); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); // connect auto socket = std::make_shared(server.getAddress(), sslContext); socket->open(); // Adding MSG_EOR flag to the message flags - it'll trigger // timestamp generation for the last byte of the message. msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR); // Init ancillary data buffer to trigger timestamp notification union { uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))]; struct cmsghdr cmsg; } u; u.cmsg.cmsg_level = SOL_SOCKET; u.cmsg.cmsg_type = SO_TIMESTAMPING; u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t)); uint32_t flags = SOF_TIMESTAMPING_TX_SCHED | SOF_TIMESTAMPING_TX_SOFTWARE | SOF_TIMESTAMPING_TX_ACK; memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t)); std::vector ctrl(CMSG_LEN(sizeof(uint32_t))); memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t))); msgCallback.resetData(std::move(ctrl)); // write() std::vector buf(128, 'a'); socket->write(buf.data(), buf.size()); // read() std::vector readbuf(buf.size()); uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size()); EXPECT_EQ(bytesRead, buf.size()); EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin())); writeCallback.checkForTimestampNotifications(); // close() socket->close(); cerr << "SendMsgDataCallback test completed" << endl; } #endif // FOLLY_HAVE_MSG_ERRQUEUE #endif } // namespace folly #ifdef SIGPIPE /////////////////////////////////////////////////////////////////////////// // init_unit_test_suite /////////////////////////////////////////////////////////////////////////// namespace { struct Initializer { Initializer() { signal(SIGPIPE, SIG_IGN); } }; Initializer initializer; } // namespace #endif