/* * Copyright 2016-present Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include using namespace std; using namespace testing; using folly::ssl::SSLSession; namespace folly { void getfds(int fds[2]) { if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) { LOG(ERROR) << "failed to create socketpair: " << errnoStr(errno); } for (int idx = 0; idx < 2; ++idx) { int flags = fcntl(fds[idx], F_GETFL, 0); if (flags == -1) { LOG(ERROR) << "failed to get flags for socket " << idx << ": " << errnoStr(errno); } if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) { LOG(ERROR) << "failed to put socket " << idx << " in non-blocking mode: " << errnoStr(errno); } } } void getctx( std::shared_ptr clientCtx, std::shared_ptr serverCtx) { clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadCertificate(kTestCert); serverCtx->loadPrivateKey(kTestKey); } class SSLSessionTest : public testing::Test { public: void SetUp() override { clientCtx.reset(new folly::SSLContext()); dfServerCtx.reset(new folly::SSLContext()); hskServerCtx.reset(new folly::SSLContext()); serverName = "xyz.newdev.facebook.com"; getctx(clientCtx, dfServerCtx); } void TearDown() override {} folly::EventBase eventBase; std::shared_ptr clientCtx; std::shared_ptr dfServerCtx; // Use the same SSLContext to continue the handshake after // tlsext_hostname match. std::shared_ptr hskServerCtx; std::string serverName; }; /** * 1. Client sends TLSEXT_HOSTNAME in client hello. * 2. Server found a match SSL_CTX and use this SSL_CTX to * continue the SSL handshake. * 3. Server sends back TLSEXT_HOSTNAME in server hello. */ TEST_F(SSLSessionTest, BasicTest) { std::unique_ptr sess; { int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); auto clientPtr = clientSock.get(); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServerParseClientHello server( std::move(serverSock), false, false); eventBase.loop(); ASSERT_TRUE(client.handshakeSuccess_); sess = std::make_unique(clientPtr->getSSLSession()); ASSERT_NE(sess.get(), nullptr); } { int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); auto clientPtr = clientSock.get(); clientSock->setSSLSession(sess->getRawSSLSessionDangerous(), true); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServerParseClientHello server( std::move(serverSock), false, false); eventBase.loop(); ASSERT_TRUE(client.handshakeSuccess_); ASSERT_TRUE(clientPtr->getSSLSessionReused()); } } TEST_F(SSLSessionTest, SerializeDeserializeTest) { std::string sessiondata; { int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); auto clientPtr = clientSock.get(); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServerParseClientHello server( std::move(serverSock), false, false); eventBase.loop(); ASSERT_TRUE(client.handshakeSuccess_); std::unique_ptr sess = std::make_unique(clientPtr->getSSLSession()); sessiondata = sess->serialize(); ASSERT_TRUE(!sessiondata.empty()); } { int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); auto clientPtr = clientSock.get(); std::unique_ptr sess = std::make_unique(sessiondata); ASSERT_NE(sess.get(), nullptr); clientSock->setSSLSession(sess->getRawSSLSessionDangerous(), true); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServerParseClientHello server( std::move(serverSock), false, false); eventBase.loop(); ASSERT_TRUE(client.handshakeSuccess_); ASSERT_TRUE(clientPtr->getSSLSessionReused()); } } TEST_F(SSLSessionTest, GetSessionID) { int fds[2]; getfds(fds); AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); auto clientPtr = clientSock.get(); AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); SSLHandshakeServerParseClientHello server( std::move(serverSock), false, false); eventBase.loop(); ASSERT_TRUE(client.handshakeSuccess_); std::unique_ptr sess = std::make_unique(clientPtr->getSSLSession()); ASSERT_NE(sess, nullptr); auto sessID = sess->getSessionID(); ASSERT_GE(sessID.length(), 0); } } // namespace folly