123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422 |
- /*
- * Copyright 2013-present Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
#include <folly/io/Cursor.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>

#include <cstring>
#include <memory>
#include <string>
#include <vector>

using std::string;
using namespace testing;
- namespace folly {
- class MockAsyncSSLSocket : public AsyncSSLSocket {
- public:
- static std::shared_ptr<MockAsyncSSLSocket> newSocket(
- const std::shared_ptr<SSLContext>& ctx,
- EventBase* evb) {
- auto sock = std::shared_ptr<MockAsyncSSLSocket>(
- new MockAsyncSSLSocket(ctx, evb), Destructor());
- sock->ssl_.reset(SSL_new(ctx->getSSLCtx()));
- SSL_set_fd(sock->ssl_.get(), -1);
- return sock;
- }
- // Fake constructor sets the state to established without call to connect
- // or accept
- MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx, EventBase* evb)
- : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
- state_ = AsyncSocket::StateEnum::ESTABLISHED;
- sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
- }
- // mock the calls to SSL_write to see the buffer length and contents
- MOCK_METHOD3(sslWriteImpl, int(SSL* ssl, const void* buf, int n));
- // mock the calls to getRawBytesWritten()
- MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
- // public wrapper for protected interface
- WriteResult testPerformWrite(
- const iovec* vec,
- uint32_t count,
- WriteFlags flags,
- uint32_t* countWritten,
- uint32_t* partialWritten) {
- return performWrite(vec, count, flags, countWritten, partialWritten);
- }
- void checkEor(size_t appEor, size_t rawEor) {
- EXPECT_EQ(appEor, appEorByteNo_);
- EXPECT_EQ(rawEor, minEorRawByteNo_);
- }
- void setAppBytesWritten(size_t n) {
- appBytesWritten_ = n;
- }
- };
- class AsyncSSLSocketWriteTest : public testing::Test {
- public:
- AsyncSSLSocketWriteTest()
- : sslContext_(new SSLContext()),
- sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
- for (int i = 0; i < 500; i++) {
- memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
- }
- }
- // Make an iovec containing chunks of the reference text with requested sizes
- // for each chunk
- std::unique_ptr<iovec[]> makeVec(std::vector<uint32_t> sizes) {
- std::unique_ptr<iovec[]> vec(new iovec[sizes.size()]);
- int i = 0;
- int pos = 0;
- for (auto size : sizes) {
- vec[i].iov_base = (void*)(source_ + pos);
- vec[i++].iov_len = size;
- pos += size;
- }
- return vec;
- }
- // Verify that the given buf/pos matches the reference text
- void verifyVec(const void* buf, int n, int pos) {
- ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
- }
- // Update a vec on partial write
- void consumeVec(iovec* vec, uint32_t countWritten, uint32_t partialWritten) {
- vec[countWritten].iov_base =
- ((char*)vec[countWritten].iov_base) + partialWritten;
- vec[countWritten].iov_len -= partialWritten;
- }
- EventBase eventBase_;
- std::shared_ptr<SSLContext> sslContext_;
- std::shared_ptr<MockAsyncSSLSocket> sock_;
- char source_[26 * 500];
- };
- // The entire vec fits in one packet
- TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
- int n = 3;
- auto vec = makeVec({3, 3, 3});
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
- .WillOnce(Invoke([this](SSL*, const void* buf, int m) {
- verifyVec(buf, m, 0);
- return 9;
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, n);
- EXPECT_EQ(partialWritten, 0);
- }
- // First packet is full, second two go in one packet
- TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
- int n = 3;
- auto vec = makeVec({1500, 3, 3});
- int pos = 0;
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, n);
- EXPECT_EQ(partialWritten, 0);
- }
- // Two exactly full packets (coalesce ends midway through second chunk)
- TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
- int n = 3;
- auto vec = makeVec({1000, 1000, 1000});
- int pos = 0;
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
- .Times(2)
- .WillRepeatedly(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, n);
- EXPECT_EQ(partialWritten, 0);
- }
- // Partial write success midway through a coalesced vec
- TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
- int n = 5;
- auto vec = makeVec({300, 300, 300, 300, 300});
- int pos = 0;
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += 1000;
- return 1000; /* 500 bytes "pending" */
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, 3);
- EXPECT_EQ(partialWritten, 100);
- consumeVec(vec.get(), countWritten, partialWritten);
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += m;
- return 500;
- }));
- sock_->testPerformWrite(
- vec.get() + countWritten,
- n - countWritten,
- WriteFlags::NONE,
- &countWritten,
- &partialWritten);
- EXPECT_EQ(countWritten, 2);
- EXPECT_EQ(partialWritten, 0);
- }
- // coalesce ends exactly on a buffer boundary
- TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
- int n = 3;
- auto vec = makeVec({1000, 500, 500});
- int pos = 0;
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, 3);
- EXPECT_EQ(partialWritten, 0);
- }
- // partial write midway through first chunk
- TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
- int n = 2;
- auto vec = makeVec({1000, 500});
- int pos = 0;
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += 700;
- return 700;
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, 0);
- EXPECT_EQ(partialWritten, 700);
- consumeVec(vec.get(), countWritten, partialWritten);
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- sock_->testPerformWrite(
- vec.get() + countWritten,
- n - countWritten,
- WriteFlags::NONE,
- &countWritten,
- &partialWritten);
- EXPECT_EQ(countWritten, 2);
- EXPECT_EQ(partialWritten, 0);
- }
- // Repeat coalescing2 with WriteFlags::EOR
- TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
- int n = 3;
- auto vec = makeVec({1500, 3, 3});
- int pos = 0;
- const size_t initAppBytesWritten = 500;
- const size_t appEor = initAppBytesWritten + 1506;
- sock_->setAppBytesWritten(initAppBytesWritten);
- EXPECT_FALSE(sock_->isEorTrackingEnabled());
- sock_->setEorTracking(true);
- EXPECT_TRUE(sock_->isEorTrackingEnabled());
- EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
- // rawBytesWritten after writting initAppBytesWritten + 1500
- // + some random SSL overhead
- .WillOnce(Return(3600u))
- // rawBytesWritten after writting last 6 bytes
- // + some random SSL overhead
- .WillOnce(Return(3728u));
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
- .WillOnce(Invoke([=, &pos](SSL*, const void* buf, int m) {
- // the first 1500 does not have the EOR byte
- sock_->checkEor(0, 0);
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
- .WillOnce(Invoke([=, &pos](SSL*, const void* buf, int m) {
- sock_->checkEor(appEor, 3600 + m);
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::EOR, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, n);
- EXPECT_EQ(partialWritten, 0);
- sock_->checkEor(0, 0);
- }
- // coalescing with left over at the last chunk
- // WriteFlags::EOR turned on
- TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
- int n = 3;
- auto vec = makeVec({600, 600, 600});
- int pos = 0;
- const size_t initAppBytesWritten = 500;
- const size_t appEor = initAppBytesWritten + 1800;
- sock_->setAppBytesWritten(initAppBytesWritten);
- sock_->setEorTracking(true);
- EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
- // rawBytesWritten after writting initAppBytesWritten + 1500 bytes
- // + some random SSL overhead
- .WillOnce(Return(3600))
- // rawBytesWritten after writting last 300 bytes
- // + some random SSL overhead
- .WillOnce(Return(4100));
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
- .WillOnce(Invoke([=, &pos](SSL*, const void* buf, int m) {
- // the first 1500 does not have the EOR byte
- sock_->checkEor(0, 0);
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
- .WillOnce(Invoke([=, &pos](SSL*, const void* buf, int m) {
- sock_->checkEor(appEor, 3600 + m);
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::EOR, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, n);
- EXPECT_EQ(partialWritten, 0);
- sock_->checkEor(0, 0);
- }
- // WriteFlags::EOR set
- // One buf in iovec
- // Partial write at 1000-th byte
- TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
- int n = 1;
- auto vec = makeVec({1600});
- int pos = 0;
- static constexpr size_t initAppBytesWritten = 500;
- static constexpr size_t appEor = initAppBytesWritten + 1600;
- sock_->setAppBytesWritten(initAppBytesWritten);
- sock_->setEorTracking(true);
- EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
- // rawBytesWritten after the initAppBytesWritten
- // + some random SSL overhead
- .WillOnce(Return(2000))
- // rawBytesWritten after the initAppBytesWritten + 1000 (with 100
- // overhead)
- // + some random SSL overhead
- .WillOnce(Return(3100));
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- sock_->checkEor(appEor, 2000 + m);
- verifyVec(buf, m, pos);
- pos += 1000;
- return 1000;
- }));
- uint32_t countWritten = 0;
- uint32_t partialWritten = 0;
- sock_->testPerformWrite(
- vec.get(), n, WriteFlags::EOR, &countWritten, &partialWritten);
- EXPECT_EQ(countWritten, 0);
- EXPECT_EQ(partialWritten, 1000);
- sock_->checkEor(appEor, 2000 + 1600);
- consumeVec(vec.get(), countWritten, partialWritten);
- EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
- .WillOnce(Return(3100))
- .WillOnce(Return(3800));
- EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
- .WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
- sock_->checkEor(appEor, 3100 + m);
- verifyVec(buf, m, pos);
- pos += m;
- return m;
- }));
- sock_->testPerformWrite(
- vec.get() + countWritten,
- n - countWritten,
- WriteFlags::EOR,
- &countWritten,
- &partialWritten);
- EXPECT_EQ(countWritten, n);
- EXPECT_EQ(partialWritten, 0);
- sock_->checkEor(0, 0);
- }
- } // namespace folly
|