AsyncSSLSocket.cpp 63 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043
  1. /*
  2. * Copyright 2014-present Facebook, Inc.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <folly/io/async/AsyncSSLSocket.h>
  17. #include <folly/io/async/EventBase.h>
  18. #include <folly/portability/Sockets.h>
  19. #include <errno.h>
  20. #include <fcntl.h>
  21. #include <sys/types.h>
  22. #include <chrono>
  23. #include <memory>
  24. #include <folly/Format.h>
  25. #include <folly/SocketAddress.h>
  26. #include <folly/SpinLock.h>
  27. #include <folly/io/Cursor.h>
  28. #include <folly/io/IOBuf.h>
  29. #include <folly/lang/Bits.h>
  30. #include <folly/portability/OpenSSL.h>
  31. using folly::SocketAddress;
  32. using folly::SSLContext;
  33. using std::shared_ptr;
  34. using std::string;
  35. using folly::Endian;
  36. using folly::IOBuf;
  37. using folly::SpinLock;
  38. using folly::SpinLockGuard;
  39. using folly::io::Cursor;
  40. using std::bind;
  41. using std::unique_ptr;
  42. namespace {
  43. using folly::AsyncSocket;
  44. using folly::AsyncSocketException;
  45. using folly::AsyncSSLSocket;
  46. using folly::Optional;
  47. using folly::SSLContext;
  48. // For OpenSSL portability API
  49. using namespace folly::ssl;
  50. using folly::ssl::OpenSSLUtils;
  51. // We have one single dummy SSL context so that we can implement attach
  52. // and detach methods in a thread safe fashion without modifying opnessl.
  53. static SSLContext* dummyCtx = nullptr;
  54. static SpinLock dummyCtxLock;
  55. // If given min write size is less than this, buffer will be allocated on
  56. // stack, otherwise it is allocated on heap
  57. const size_t MAX_STACK_BUF_SIZE = 2048;
  58. // This converts "illegal" shutdowns into ZERO_RETURN
  59. inline bool zero_return(int error, int rc) {
  60. return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
  61. }
  62. class AsyncSSLCertificate : public folly::AsyncTransportCertificate {
  63. public:
  64. // assumed to be non null
  65. explicit AsyncSSLCertificate(folly::ssl::X509UniquePtr x509)
  66. : x509_(std::move(x509)) {}
  67. folly::ssl::X509UniquePtr getX509() const override {
  68. X509_up_ref(x509_.get());
  69. return folly::ssl::X509UniquePtr(x509_.get());
  70. }
  71. std::string getIdentity() const override {
  72. return OpenSSLUtils::getCommonName(x509_.get());
  73. }
  74. private:
  75. folly::ssl::X509UniquePtr x509_;
  76. };
  77. class AsyncSSLSocketConnector : public AsyncSocket::ConnectCallback,
  78. public AsyncSSLSocket::HandshakeCB {
  79. private:
  80. AsyncSSLSocket* sslSocket_;
  81. AsyncSSLSocket::ConnectCallback* callback_;
  82. std::chrono::milliseconds timeout_;
  83. std::chrono::steady_clock::time_point startTime_;
  84. protected:
  85. ~AsyncSSLSocketConnector() override {}
  86. public:
  87. AsyncSSLSocketConnector(
  88. AsyncSSLSocket* sslSocket,
  89. AsyncSocket::ConnectCallback* callback,
  90. std::chrono::milliseconds timeout)
  91. : sslSocket_(sslSocket),
  92. callback_(callback),
  93. timeout_(timeout),
  94. startTime_(std::chrono::steady_clock::now()) {}
  95. void connectSuccess() noexcept override {
  96. VLOG(7) << "client socket connected";
  97. std::chrono::milliseconds timeoutLeft{0};
  98. if (timeout_ > std::chrono::milliseconds::zero()) {
  99. auto curTime = std::chrono::steady_clock::now();
  100. timeoutLeft = std::chrono::duration_cast<std::chrono::milliseconds>(
  101. timeout_ - (curTime - startTime_));
  102. if (timeoutLeft <= std::chrono::milliseconds::zero()) {
  103. AsyncSocketException ex(
  104. AsyncSocketException::TIMED_OUT,
  105. folly::sformat(
  106. "SSL connect timed out after {}ms", timeout_.count()));
  107. fail(ex);
  108. delete this;
  109. return;
  110. }
  111. }
  112. sslSocket_->sslConn(this, timeoutLeft);
  113. }
  114. void connectErr(const AsyncSocketException& ex) noexcept override {
  115. VLOG(1) << "TCP connect failed: " << ex.what();
  116. fail(ex);
  117. delete this;
  118. }
  119. void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
  120. VLOG(7) << "client handshake success";
  121. if (callback_) {
  122. callback_->connectSuccess();
  123. }
  124. delete this;
  125. }
  126. void handshakeErr(
  127. AsyncSSLSocket* /* socket */,
  128. const AsyncSocketException& ex) noexcept override {
  129. VLOG(1) << "client handshakeErr: " << ex.what();
  130. fail(ex);
  131. delete this;
  132. }
  133. void fail(const AsyncSocketException& ex) {
  134. // fail is a noop if called twice
  135. if (callback_) {
  136. AsyncSSLSocket::ConnectCallback* cb = callback_;
  137. callback_ = nullptr;
  138. cb->connectErr(ex);
  139. sslSocket_->closeNow();
  140. // closeNow can call handshakeErr if it hasn't been called already.
  141. // So this may have been deleted, no member variable access beyond this
  142. // point
  143. // Note that closeNow may invoke writeError callbacks if the socket had
  144. // write data pending connection completion.
  145. }
  146. }
  147. };
  148. void setup_SSL_CTX(SSL_CTX* ctx) {
  149. #ifdef SSL_MODE_RELEASE_BUFFERS
  150. SSL_CTX_set_mode(
  151. ctx,
  152. SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE |
  153. SSL_MODE_RELEASE_BUFFERS);
  154. #else
  155. SSL_CTX_set_mode(
  156. ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
  157. #endif
  158. // SSL_CTX_set_mode is a Macro
  159. #ifdef SSL_MODE_WRITE_IOVEC
  160. SSL_CTX_set_mode(ctx, SSL_CTX_get_mode(ctx) | SSL_MODE_WRITE_IOVEC);
  161. #endif
  162. }
  163. // Note: This is a Leaky Meyer's Singleton. The reason we can't use a non-leaky
  164. // thing is because we will be setting this BIO_METHOD* inside BIOs owned by
  165. // various SSL objects which may get callbacks even during teardown. We may
  166. // eventually try to fix this
  167. static BIO_METHOD* getSSLBioMethod() {
  168. static auto const instance = OpenSSLUtils::newSocketBioMethod().release();
  169. return instance;
  170. }
  171. void* initsslBioMethod() {
  172. auto sslBioMethod = getSSLBioMethod();
  173. // override the bwrite method for MSG_EOR support
  174. OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite);
  175. OpenSSLUtils::setCustomBioReadMethod(sslBioMethod, AsyncSSLSocket::bioRead);
  176. // Note that the sslBioMethod.type and sslBioMethod.name are not
  177. // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
  178. // then have specific handlings. The sslWriteBioWrite should be compatible
  179. // with the one in openssl.
  180. // Return something here to enable AsyncSSLSocket to call this method using
  181. // a function-scoped static.
  182. return nullptr;
  183. }
  184. } // namespace
  185. namespace folly {
  186. /**
  187. * Create a client AsyncSSLSocket
  188. */
  189. AsyncSSLSocket::AsyncSSLSocket(
  190. const shared_ptr<SSLContext>& ctx,
  191. EventBase* evb,
  192. bool deferSecurityNegotiation)
  193. : AsyncSocket(evb),
  194. ctx_(ctx),
  195. handshakeTimeout_(this, evb),
  196. connectionTimeout_(this, evb) {
  197. init();
  198. if (deferSecurityNegotiation) {
  199. sslState_ = STATE_UNENCRYPTED;
  200. }
  201. }
  202. /**
  203. * Create a server/client AsyncSSLSocket
  204. */
  205. AsyncSSLSocket::AsyncSSLSocket(
  206. const shared_ptr<SSLContext>& ctx,
  207. EventBase* evb,
  208. int fd,
  209. bool server,
  210. bool deferSecurityNegotiation)
  211. : AsyncSocket(evb, fd),
  212. server_(server),
  213. ctx_(ctx),
  214. handshakeTimeout_(this, evb),
  215. connectionTimeout_(this, evb) {
  216. noTransparentTls_ = true;
  217. init();
  218. if (server) {
  219. SSL_CTX_set_info_callback(
  220. ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback);
  221. }
  222. if (deferSecurityNegotiation) {
  223. sslState_ = STATE_UNENCRYPTED;
  224. }
  225. }
  226. AsyncSSLSocket::AsyncSSLSocket(
  227. const shared_ptr<SSLContext>& ctx,
  228. AsyncSocket::UniquePtr oldAsyncSocket,
  229. bool server,
  230. bool deferSecurityNegotiation)
  231. : AsyncSocket(std::move(oldAsyncSocket)),
  232. server_(server),
  233. ctx_(ctx),
  234. handshakeTimeout_(this, AsyncSocket::getEventBase()),
  235. connectionTimeout_(this, AsyncSocket::getEventBase()) {
  236. noTransparentTls_ = true;
  237. init();
  238. if (server) {
  239. SSL_CTX_set_info_callback(
  240. ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback);
  241. }
  242. if (deferSecurityNegotiation) {
  243. sslState_ = STATE_UNENCRYPTED;
  244. }
  245. }
  246. #if FOLLY_OPENSSL_HAS_SNI
  247. /**
  248. * Create a client AsyncSSLSocket and allow tlsext_hostname
  249. * to be sent in Client Hello.
  250. */
  251. AsyncSSLSocket::AsyncSSLSocket(
  252. const shared_ptr<SSLContext>& ctx,
  253. EventBase* evb,
  254. const std::string& serverName,
  255. bool deferSecurityNegotiation)
  256. : AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
  257. tlsextHostname_ = serverName;
  258. }
  259. /**
  260. * Create a client AsyncSSLSocket from an already connected fd
  261. * and allow tlsext_hostname to be sent in Client Hello.
  262. */
  263. AsyncSSLSocket::AsyncSSLSocket(
  264. const shared_ptr<SSLContext>& ctx,
  265. EventBase* evb,
  266. int fd,
  267. const std::string& serverName,
  268. bool deferSecurityNegotiation)
  269. : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
  270. tlsextHostname_ = serverName;
  271. }
  272. #endif // FOLLY_OPENSSL_HAS_SNI
  273. AsyncSSLSocket::~AsyncSSLSocket() {
  274. VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
  275. << ", evb=" << eventBase_ << ", fd=" << fd_
  276. << ", state=" << int(state_) << ", sslState=" << sslState_
  277. << ", events=" << eventFlags_ << ")";
  278. }
  279. void AsyncSSLSocket::init() {
  280. // Do this here to ensure we initialize this once before any use of
  281. // AsyncSSLSocket instances and not as part of library load.
  282. static const auto sslBioMethodInitializer = initsslBioMethod();
  283. (void)sslBioMethodInitializer;
  284. setup_SSL_CTX(ctx_->getSSLCtx());
  285. }
  286. void AsyncSSLSocket::closeNow() {
  287. // Close the SSL connection.
  288. if (ssl_ != nullptr && fd_ != -1) {
  289. int rc = SSL_shutdown(ssl_.get());
  290. if (rc == 0) {
  291. rc = SSL_shutdown(ssl_.get());
  292. }
  293. if (rc < 0) {
  294. ERR_clear_error();
  295. }
  296. }
  297. if (sslSession_ != nullptr) {
  298. SSL_SESSION_free(sslSession_);
  299. sslSession_ = nullptr;
  300. }
  301. sslState_ = STATE_CLOSED;
  302. if (handshakeTimeout_.isScheduled()) {
  303. handshakeTimeout_.cancelTimeout();
  304. }
  305. DestructorGuard dg(this);
  306. invokeHandshakeErr(AsyncSocketException(
  307. AsyncSocketException::END_OF_FILE, "SSL connection closed locally"));
  308. // Close the socket.
  309. AsyncSocket::closeNow();
  310. }
  311. void AsyncSSLSocket::shutdownWrite() {
  312. // SSL sockets do not support half-shutdown, so just perform a full shutdown.
  313. //
  314. // (Performing a full shutdown here is more desirable than doing nothing at
  315. // all. The purpose of shutdownWrite() is normally to notify the other end
  316. // of the connection that no more data will be sent. If we do nothing, the
  317. // other end will never know that no more data is coming, and this may result
  318. // in protocol deadlock.)
  319. close();
  320. }
  321. void AsyncSSLSocket::shutdownWriteNow() {
  322. closeNow();
  323. }
  324. bool AsyncSSLSocket::good() const {
  325. return (
  326. AsyncSocket::good() &&
  327. (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
  328. sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED ||
  329. sslState_ == STATE_UNINIT));
  330. }
  331. // The TAsyncTransport definition of 'good' states that the transport is
  332. // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
  333. // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
  334. // is connected but we haven't initiated the call to SSL_connect.
  335. bool AsyncSSLSocket::connecting() const {
  336. return (
  337. !server_ &&
  338. (AsyncSocket::connecting() ||
  339. (AsyncSocket::good() &&
  340. (sslState_ == STATE_UNINIT || sslState_ == STATE_CONNECTING))));
  341. }
  342. std::string AsyncSSLSocket::getApplicationProtocol() const noexcept {
  343. const unsigned char* protoName = nullptr;
  344. unsigned protoLength;
  345. if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
  346. return std::string(reinterpret_cast<const char*>(protoName), protoLength);
  347. }
  348. return "";
  349. }
  350. void AsyncSSLSocket::setEorTracking(bool track) {
  351. if (isEorTrackingEnabled() != track) {
  352. AsyncSocket::setEorTracking(track);
  353. appEorByteNo_ = 0;
  354. minEorRawByteNo_ = 0;
  355. }
  356. }
  357. size_t AsyncSSLSocket::getRawBytesWritten() const {
  358. // The bio(s) in the write path are in a chain
  359. // each bio flushes to the next and finally written into the socket
  360. // to get the rawBytesWritten on the socket,
  361. // get the write bytes of the last bio
  362. BIO* b;
  363. if (!ssl_ || !(b = SSL_get_wbio(ssl_.get()))) {
  364. return 0;
  365. }
  366. BIO* next = BIO_next(b);
  367. while (next != nullptr) {
  368. b = next;
  369. next = BIO_next(b);
  370. }
  371. return BIO_number_written(b);
  372. }
  373. size_t AsyncSSLSocket::getRawBytesReceived() const {
  374. BIO* b;
  375. if (!ssl_ || !(b = SSL_get_rbio(ssl_.get()))) {
  376. return 0;
  377. }
  378. return BIO_number_read(b);
  379. }
  380. void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
  381. LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
  382. << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
  383. << "events=" << eventFlags_ << ", server=" << short(server_)
  384. << "): "
  385. << "sslAccept/Connect() called in invalid "
  386. << "state, handshake callback " << handshakeCallback_
  387. << ", new callback " << callback;
  388. assert(!handshakeTimeout_.isScheduled());
  389. sslState_ = STATE_ERROR;
  390. AsyncSocketException ex(
  391. AsyncSocketException::INVALID_STATE,
  392. "sslAccept() called with socket in invalid state");
  393. handshakeEndTime_ = std::chrono::steady_clock::now();
  394. if (callback) {
  395. callback->handshakeErr(this, ex);
  396. }
  397. failHandshake(__func__, ex);
  398. }
  399. void AsyncSSLSocket::sslAccept(
  400. HandshakeCB* callback,
  401. std::chrono::milliseconds timeout,
  402. const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
  403. DestructorGuard dg(this);
  404. eventBase_->dcheckIsInEventBaseThread();
  405. verifyPeer_ = verifyPeer;
  406. // Make sure we're in the uninitialized state
  407. if (!server_ ||
  408. (sslState_ != STATE_UNINIT && sslState_ != STATE_UNENCRYPTED) ||
  409. handshakeCallback_ != nullptr) {
  410. return invalidState(callback);
  411. }
  412. // Cache local and remote socket addresses to keep them available
  413. // after socket file descriptor is closed.
  414. if (cacheAddrOnFailure_) {
  415. cacheAddresses();
  416. }
  417. handshakeStartTime_ = std::chrono::steady_clock::now();
  418. // Make end time at least >= start time.
  419. handshakeEndTime_ = handshakeStartTime_;
  420. sslState_ = STATE_ACCEPTING;
  421. handshakeCallback_ = callback;
  422. if (timeout > std::chrono::milliseconds::zero()) {
  423. handshakeTimeout_.scheduleTimeout(timeout);
  424. }
  425. /* register for a read operation (waiting for CLIENT HELLO) */
  426. updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
  427. checkForImmediateRead();
  428. }
  429. void AsyncSSLSocket::attachSSLContext(const std::shared_ptr<SSLContext>& ctx) {
  430. // Check to ensure we are in client mode. Changing a server's ssl
  431. // context doesn't make sense since clients of that server would likely
  432. // become confused when the server's context changes.
  433. DCHECK(!server_);
  434. DCHECK(!ctx_);
  435. DCHECK(ctx);
  436. DCHECK(ctx->getSSLCtx());
  437. ctx_ = ctx;
  438. // It's possible this could be attached before ssl_ is set up
  439. if (!ssl_) {
  440. return;
  441. }
  442. // In order to call attachSSLContext, detachSSLContext must have been
  443. // previously called.
  444. // We need to update the initial_ctx if necessary
  445. // The 'initial_ctx' inside an SSL* points to the context that it was created
  446. // with, which is also where session callbacks and servername callbacks
  447. // happen.
  448. // When we switch to a different SSL_CTX, we want to update the initial_ctx as
  449. // well so that any callbacks don't go to a different object
  450. // NOTE: this will only work if we have access to ssl_ internals, so it may
  451. // not work on
  452. // OpenSSL version >= 1.1.0
  453. auto sslCtx = ctx->getSSLCtx();
  454. OpenSSLUtils::setSSLInitialCtx(ssl_.get(), sslCtx);
  455. // Detach sets the socket's context to the dummy context. Thus we must acquire
  456. // this lock.
  457. SpinLockGuard guard(dummyCtxLock);
  458. SSL_set_SSL_CTX(ssl_.get(), sslCtx);
  459. }
  460. void AsyncSSLSocket::detachSSLContext() {
  461. DCHECK(ctx_);
  462. ctx_.reset();
  463. // It's possible for this to be called before ssl_ has been
  464. // set up
  465. if (!ssl_) {
  466. return;
  467. }
  468. // The 'initial_ctx' inside an SSL* points to the context that it was created
  469. // with, which is also where session callbacks and servername callbacks
  470. // happen.
  471. // Detach the initial_ctx as well. It will be reattached in attachSSLContext
  472. // it is used for session info.
  473. // NOTE: this will only work if we have access to ssl_ internals, so it may
  474. // not work on
  475. // OpenSSL version >= 1.1.0
  476. SSL_CTX* initialCtx = OpenSSLUtils::getSSLInitialCtx(ssl_.get());
  477. if (initialCtx) {
  478. SSL_CTX_free(initialCtx);
  479. OpenSSLUtils::setSSLInitialCtx(ssl_.get(), nullptr);
  480. }
  481. SpinLockGuard guard(dummyCtxLock);
  482. if (nullptr == dummyCtx) {
  483. // We need to lazily initialize the dummy context so we don't
  484. // accidentally override any programmatic settings to openssl
  485. dummyCtx = new SSLContext;
  486. }
  487. // We must remove this socket's references to its context right now
  488. // since this socket could get passed to any thread. If the context has
  489. // had its locking disabled, just doing a set in attachSSLContext()
  490. // would not be thread safe.
  491. SSL_set_SSL_CTX(ssl_.get(), dummyCtx->getSSLCtx());
  492. }
  493. #if FOLLY_OPENSSL_HAS_SNI
  494. void AsyncSSLSocket::switchServerSSLContext(
  495. const std::shared_ptr<SSLContext>& handshakeCtx) {
  496. CHECK(server_);
  497. if (sslState_ != STATE_ACCEPTING) {
  498. // We log it here and allow the switch.
  499. // It should not affect our re-negotiation support (which
  500. // is not supported now).
  501. VLOG(6) << "fd=" << getFd()
  502. << " renegotation detected when switching SSL_CTX";
  503. }
  504. setup_SSL_CTX(handshakeCtx->getSSLCtx());
  505. SSL_CTX_set_info_callback(
  506. handshakeCtx->getSSLCtx(), AsyncSSLSocket::sslInfoCallback);
  507. handshakeCtx_ = handshakeCtx;
  508. SSL_set_SSL_CTX(ssl_.get(), handshakeCtx->getSSLCtx());
  509. }
  510. bool AsyncSSLSocket::isServerNameMatch() const {
  511. CHECK(!server_);
  512. if (!ssl_) {
  513. return false;
  514. }
  515. SSL_SESSION* ss = SSL_get_session(ssl_.get());
  516. if (!ss) {
  517. return false;
  518. }
  519. auto tlsextHostname = SSL_SESSION_get0_hostname(ss);
  520. return (tlsextHostname && !tlsextHostname_.compare(tlsextHostname));
  521. }
  522. void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
  523. tlsextHostname_ = std::move(serverName);
  524. }
  525. #endif // FOLLY_OPENSSL_HAS_SNI
  526. void AsyncSSLSocket::timeoutExpired(
  527. std::chrono::milliseconds timeout) noexcept {
  528. if (state_ == StateEnum::ESTABLISHED &&
  529. (sslState_ == STATE_CACHE_LOOKUP || sslState_ == STATE_ASYNC_PENDING)) {
  530. sslState_ = STATE_ERROR;
  531. // We are expecting a callback in restartSSLAccept. The cache lookup
  532. // and rsa-call necessarily have pointers to this ssl socket, so delay
  533. // the cleanup until he calls us back.
  534. } else if (state_ == StateEnum::CONNECTING) {
  535. assert(sslState_ == STATE_CONNECTING);
  536. DestructorGuard dg(this);
  537. AsyncSocketException ex(
  538. AsyncSocketException::TIMED_OUT,
  539. "Fallback connect timed out during TFO");
  540. failHandshake(__func__, ex);
  541. } else {
  542. assert(
  543. state_ == StateEnum::ESTABLISHED &&
  544. (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
  545. DestructorGuard dg(this);
  546. AsyncSocketException ex(
  547. AsyncSocketException::TIMED_OUT,
  548. folly::sformat(
  549. "SSL {} timed out after {}ms",
  550. (sslState_ == STATE_CONNECTING) ? "connect" : "accept",
  551. timeout.count()));
  552. failHandshake(__func__, ex);
  553. }
  554. }
  555. int AsyncSSLSocket::getSSLExDataIndex() {
  556. static auto index = SSL_get_ex_new_index(
  557. 0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
  558. return index;
  559. }
  560. AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL* ssl) {
  561. return static_cast<AsyncSSLSocket*>(
  562. SSL_get_ex_data(ssl, getSSLExDataIndex()));
  563. }
  564. void AsyncSSLSocket::failHandshake(
  565. const char* /* fn */,
  566. const AsyncSocketException& ex) {
  567. startFail();
  568. if (handshakeTimeout_.isScheduled()) {
  569. handshakeTimeout_.cancelTimeout();
  570. }
  571. invokeHandshakeErr(ex);
  572. finishFail();
  573. }
  574. void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
  575. handshakeEndTime_ = std::chrono::steady_clock::now();
  576. if (handshakeCallback_ != nullptr) {
  577. HandshakeCB* callback = handshakeCallback_;
  578. handshakeCallback_ = nullptr;
  579. callback->handshakeErr(this, ex);
  580. }
  581. }
  582. void AsyncSSLSocket::invokeHandshakeCB() {
  583. handshakeEndTime_ = std::chrono::steady_clock::now();
  584. if (handshakeTimeout_.isScheduled()) {
  585. handshakeTimeout_.cancelTimeout();
  586. }
  587. if (handshakeCallback_) {
  588. HandshakeCB* callback = handshakeCallback_;
  589. handshakeCallback_ = nullptr;
  590. callback->handshakeSuc(this);
  591. }
  592. }
  593. void AsyncSSLSocket::connect(
  594. ConnectCallback* callback,
  595. const folly::SocketAddress& address,
  596. int timeout,
  597. const OptionMap& options,
  598. const folly::SocketAddress& bindAddr) noexcept {
  599. auto timeoutChrono = std::chrono::milliseconds(timeout);
  600. connect(callback, address, timeoutChrono, timeoutChrono, options, bindAddr);
  601. }
  602. void AsyncSSLSocket::connect(
  603. ConnectCallback* callback,
  604. const folly::SocketAddress& address,
  605. std::chrono::milliseconds connectTimeout,
  606. std::chrono::milliseconds totalConnectTimeout,
  607. const OptionMap& options,
  608. const folly::SocketAddress& bindAddr) noexcept {
  609. assert(!server_);
  610. assert(state_ == StateEnum::UNINIT);
  611. assert(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
  612. noTransparentTls_ = true;
  613. totalConnectTimeout_ = totalConnectTimeout;
  614. if (sslState_ != STATE_UNENCRYPTED) {
  615. callback = new AsyncSSLSocketConnector(this, callback, totalConnectTimeout);
  616. }
  617. AsyncSocket::connect(
  618. callback, address, int(connectTimeout.count()), options, bindAddr);
  619. }
  620. bool AsyncSSLSocket::needsPeerVerification() const {
  621. if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
  622. return ctx_->needsPeerVerification();
  623. }
  624. return (
  625. verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
  626. verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
  627. }
  628. void AsyncSSLSocket::applyVerificationOptions(const ssl::SSLUniquePtr& ssl) {
  629. // apply the settings specified in verifyPeer_
  630. if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
  631. if (ctx_->needsPeerVerification()) {
  632. SSL_set_verify(
  633. ssl.get(),
  634. ctx_->getVerificationMode(),
  635. AsyncSSLSocket::sslVerifyCallback);
  636. }
  637. } else {
  638. if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
  639. verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
  640. SSL_set_verify(
  641. ssl.get(),
  642. SSLContext::getVerificationMode(verifyPeer_),
  643. AsyncSSLSocket::sslVerifyCallback);
  644. }
  645. }
  646. }
  647. bool AsyncSSLSocket::setupSSLBio() {
  648. auto sslBio = BIO_new(getSSLBioMethod());
  649. if (!sslBio) {
  650. return false;
  651. }
  652. OpenSSLUtils::setBioAppData(sslBio, this);
  653. OpenSSLUtils::setBioFd(sslBio, fd_, BIO_NOCLOSE);
  654. SSL_set_bio(ssl_.get(), sslBio, sslBio);
  655. return true;
  656. }
  657. void AsyncSSLSocket::sslConn(
  658. HandshakeCB* callback,
  659. std::chrono::milliseconds timeout,
  660. const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
  661. DestructorGuard dg(this);
  662. eventBase_->dcheckIsInEventBaseThread();
  663. // Cache local and remote socket addresses to keep them available
  664. // after socket file descriptor is closed.
  665. if (cacheAddrOnFailure_) {
  666. cacheAddresses();
  667. }
  668. verifyPeer_ = verifyPeer;
  669. // Make sure we're in the uninitialized state
  670. if (server_ ||
  671. (sslState_ != STATE_UNINIT && sslState_ != STATE_UNENCRYPTED) ||
  672. handshakeCallback_ != nullptr) {
  673. return invalidState(callback);
  674. }
  675. sslState_ = STATE_CONNECTING;
  676. handshakeCallback_ = callback;
  677. try {
  678. ssl_.reset(ctx_->createSSL());
  679. } catch (std::exception& e) {
  680. sslState_ = STATE_ERROR;
  681. AsyncSocketException ex(
  682. AsyncSocketException::INTERNAL_ERROR,
  683. "error calling SSLContext::createSSL()");
  684. LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd=" << fd_
  685. << "): " << e.what();
  686. return failHandshake(__func__, ex);
  687. }
  688. if (!setupSSLBio()) {
  689. sslState_ = STATE_ERROR;
  690. AsyncSocketException ex(
  691. AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
  692. return failHandshake(__func__, ex);
  693. }
  694. applyVerificationOptions(ssl_);
  695. if (sslSession_ != nullptr) {
  696. sessionResumptionAttempted_ = true;
  697. SSL_set_session(ssl_.get(), sslSession_);
  698. SSL_SESSION_free(sslSession_);
  699. sslSession_ = nullptr;
  700. }
  701. #if FOLLY_OPENSSL_HAS_SNI
  702. if (tlsextHostname_.size()) {
  703. SSL_set_tlsext_host_name(ssl_.get(), tlsextHostname_.c_str());
  704. }
  705. #endif
  706. SSL_set_ex_data(ssl_.get(), getSSLExDataIndex(), this);
  707. handshakeConnectTimeout_ = timeout;
  708. startSSLConnect();
  709. }
  710. // This could be called multiple times, during normal ssl connections
  711. // and after TFO fallback.
  712. void AsyncSSLSocket::startSSLConnect() {
  713. handshakeStartTime_ = std::chrono::steady_clock::now();
  714. // Make end time at least >= start time.
  715. handshakeEndTime_ = handshakeStartTime_;
  716. if (handshakeConnectTimeout_ > std::chrono::milliseconds::zero()) {
  717. handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_);
  718. }
  719. handleConnect();
  720. }
  721. SSL_SESSION* AsyncSSLSocket::getSSLSession() {
  722. if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
  723. return SSL_get1_session(ssl_.get());
  724. }
  725. return sslSession_;
  726. }
  727. const SSL* AsyncSSLSocket::getSSL() const {
  728. return ssl_.get();
  729. }
  730. void AsyncSSLSocket::setSSLSession(SSL_SESSION* session, bool takeOwnership) {
  731. if (sslSession_) {
  732. SSL_SESSION_free(sslSession_);
  733. }
  734. sslSession_ = session;
  735. if (!takeOwnership && session != nullptr) {
  736. // Increment the reference count
  737. // This API exists in BoringSSL and OpenSSL 1.1.0
  738. SSL_SESSION_up_ref(session);
  739. }
  740. }
  741. void AsyncSSLSocket::getSelectedNextProtocol(
  742. const unsigned char** protoName,
  743. unsigned* protoLen) const {
  744. if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
  745. throw AsyncSocketException(
  746. AsyncSocketException::NOT_SUPPORTED, "ALPN not supported");
  747. }
  748. }
  749. bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
  750. const unsigned char** protoName,
  751. unsigned* protoLen) const {
  752. *protoName = nullptr;
  753. *protoLen = 0;
  754. #if FOLLY_OPENSSL_HAS_ALPN
  755. SSL_get0_alpn_selected(ssl_.get(), protoName, protoLen);
  756. return true;
  757. #else
  758. return false;
  759. #endif
  760. }
  761. bool AsyncSSLSocket::getSSLSessionReused() const {
  762. if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
  763. return SSL_session_reused(ssl_.get());
  764. }
  765. return false;
  766. }
  767. const char* AsyncSSLSocket::getNegotiatedCipherName() const {
  768. return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_.get()) : nullptr;
  769. }
  770. /* static */
  771. const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
  772. if (ssl == nullptr) {
  773. return nullptr;
  774. }
  775. #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
  776. return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
  777. #else
  778. return nullptr;
  779. #endif
  780. }
  781. const char* AsyncSSLSocket::getSSLServerName() const {
  782. #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
  783. return getSSLServerNameFromSSL(ssl_.get());
  784. #else
  785. throw AsyncSocketException(
  786. AsyncSocketException::NOT_SUPPORTED, "SNI not supported");
  787. #endif
  788. }
  789. const char* AsyncSSLSocket::getSSLServerNameNoThrow() const {
  790. return getSSLServerNameFromSSL(ssl_.get());
  791. }
  792. int AsyncSSLSocket::getSSLVersion() const {
  793. return (ssl_ != nullptr) ? SSL_version(ssl_.get()) : 0;
  794. }
  795. const char* AsyncSSLSocket::getSSLCertSigAlgName() const {
  796. X509* cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_.get()) : nullptr;
  797. if (cert) {
  798. int nid = X509_get_signature_nid(cert);
  799. return OBJ_nid2ln(nid);
  800. }
  801. return nullptr;
  802. }
  803. int AsyncSSLSocket::getSSLCertSize() const {
  804. int certSize = 0;
  805. X509* cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_.get()) : nullptr;
  806. if (cert) {
  807. EVP_PKEY* key = X509_get_pubkey(cert);
  808. certSize = EVP_PKEY_bits(key);
  809. EVP_PKEY_free(key);
  810. }
  811. return certSize;
  812. }
  813. const AsyncTransportCertificate* AsyncSSLSocket::getPeerCertificate() const {
  814. if (peerCertData_) {
  815. return peerCertData_.get();
  816. }
  817. if (ssl_ != nullptr) {
  818. auto peerX509 = SSL_get_peer_certificate(ssl_.get());
  819. if (peerX509) {
  820. // already up ref'd
  821. folly::ssl::X509UniquePtr peer(peerX509);
  822. peerCertData_ = std::make_unique<AsyncSSLCertificate>(std::move(peer));
  823. }
  824. }
  825. return peerCertData_.get();
  826. }
  827. const AsyncTransportCertificate* AsyncSSLSocket::getSelfCertificate() const {
  828. if (selfCertData_) {
  829. return selfCertData_.get();
  830. }
  831. if (ssl_ != nullptr) {
  832. auto selfX509 = SSL_get_certificate(ssl_.get());
  833. if (selfX509) {
  834. // need to upref
  835. X509_up_ref(selfX509);
  836. folly::ssl::X509UniquePtr peer(selfX509);
  837. selfCertData_ = std::make_unique<AsyncSSLCertificate>(std::move(peer));
  838. }
  839. }
  840. return selfCertData_.get();
  841. }
  842. // TODO: deprecate/remove in favor of getSelfCertificate.
  843. const X509* AsyncSSLSocket::getSelfCert() const {
  844. return (ssl_ != nullptr) ? SSL_get_certificate(ssl_.get()) : nullptr;
  845. }
  846. bool AsyncSSLSocket::willBlock(
  847. int ret,
  848. int* sslErrorOut,
  849. unsigned long* errErrorOut) noexcept {
  850. *errErrorOut = 0;
  851. int error = *sslErrorOut = SSL_get_error(ssl_.get(), ret);
  852. if (error == SSL_ERROR_WANT_READ) {
  853. // Register for read event if not already.
  854. updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
  855. return true;
  856. } else if (error == SSL_ERROR_WANT_WRITE) {
  857. VLOG(3) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
  858. << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
  859. << "SSL_ERROR_WANT_WRITE";
  860. // Register for write event if not already.
  861. updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
  862. return true;
  863. #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
  864. } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
  865. // We will block but we can't register our own socket. The callback that
  866. // triggered this code will re-call handleAccept at the appropriate time.
  867. // We can only get here if the linked libssl.so has support for this feature
  868. // as well, otherwise SSL_get_error cannot return our error code.
  869. sslState_ = STATE_CACHE_LOOKUP;
  870. // Unregister for all events while blocked here
  871. updateEventRegistration(
  872. EventHandler::NONE, EventHandler::READ | EventHandler::WRITE);
  873. // The timeout (if set) keeps running here
  874. return true;
  875. #endif
  876. } else if ((false
  877. #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
  878. || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
  879. #endif
  880. #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
  881. || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
  882. #endif
  883. #ifdef SSL_ERROR_WANT_ASYNC // OpenSSL 1.1.0 Async API
  884. || error == SSL_ERROR_WANT_ASYNC
  885. #endif
  886. )) {
  887. // Our custom openssl function has kicked off an async request to do
  888. // rsa/ecdsa private key operation. When that call returns, a callback will
  889. // be invoked that will re-call handleAccept.
  890. sslState_ = STATE_ASYNC_PENDING;
  891. // Unregister for all events while blocked here
  892. updateEventRegistration(
  893. EventHandler::NONE, EventHandler::READ | EventHandler::WRITE);
  894. #ifdef SSL_ERROR_WANT_ASYNC
  895. if (error == SSL_ERROR_WANT_ASYNC) {
  896. size_t numfds;
  897. if (SSL_get_all_async_fds(ssl_.get(), NULL, &numfds) <= 0) {
  898. VLOG(4) << "SSL_ERROR_WANT_ASYNC but no async FDs set!";
  899. return false;
  900. }
  901. if (numfds != 1) {
  902. VLOG(4) << "SSL_ERROR_WANT_ASYNC expected exactly 1 async fd, got "
  903. << numfds;
  904. return false;
  905. }
  906. OSSL_ASYNC_FD ofd; // This should just be an int in POSIX
  907. if (SSL_get_all_async_fds(ssl_.get(), &ofd, &numfds) <= 0) {
  908. VLOG(4) << "SSL_ERROR_WANT_ASYNC cant get async fd";
  909. return false;
  910. }
  911. auto asyncPipeReader = AsyncPipeReader::newReader(eventBase_, ofd);
  912. auto asyncPipeReaderPtr = asyncPipeReader.get();
  913. if (!asyncOperationFinishCallback_) {
  914. asyncOperationFinishCallback_.reset(
  915. new DefaultOpenSSLAsyncFinishCallback(
  916. std::move(asyncPipeReader), this, DestructorGuard(this)));
  917. }
  918. asyncPipeReaderPtr->setReadCB(asyncOperationFinishCallback_.get());
  919. }
  920. #endif
  921. // The timeout (if set) keeps running here
  922. return true;
  923. } else {
  924. unsigned long lastError = *errErrorOut = ERR_get_error();
  925. VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
  926. << "state=" << state_ << ", "
  927. << "sslState=" << sslState_ << ", "
  928. << "events=" << std::hex << eventFlags_ << "): "
  929. << "SSL error: " << error << ", "
  930. << "errno: " << errno << ", "
  931. << "ret: " << ret << ", "
  932. << "read: " << BIO_number_read(SSL_get_rbio(ssl_.get())) << ", "
  933. << "written: " << BIO_number_written(SSL_get_wbio(ssl_.get()))
  934. << ", "
  935. << "func: " << ERR_func_error_string(lastError) << ", "
  936. << "reason: " << ERR_reason_error_string(lastError);
  937. return false;
  938. }
  939. }
  940. void AsyncSSLSocket::checkForImmediateRead() noexcept {
  941. // openssl may have buffered data that it read from the socket already.
  942. // In this case we have to process it immediately, rather than waiting for
  943. // the socket to become readable again.
  944. if (ssl_ != nullptr && SSL_pending(ssl_.get()) > 0) {
  945. AsyncSocket::handleRead();
  946. } else {
  947. AsyncSocket::checkForImmediateRead();
  948. }
  949. }
  950. void AsyncSSLSocket::restartSSLAccept() {
  951. VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this
  952. << ", fd=" << fd_ << ", state=" << int(state_) << ", "
  953. << "sslState=" << sslState_ << ", events=" << eventFlags_;
  954. DestructorGuard dg(this);
  955. assert(
  956. sslState_ == STATE_CACHE_LOOKUP || sslState_ == STATE_ASYNC_PENDING ||
  957. sslState_ == STATE_ERROR || sslState_ == STATE_CLOSED);
  958. if (sslState_ == STATE_CLOSED) {
  959. // I sure hope whoever closed this socket didn't delete it already,
  960. // but this is not strictly speaking an error
  961. return;
  962. }
  963. if (sslState_ == STATE_ERROR) {
  964. // go straight to fail if timeout expired during lookup
  965. AsyncSocketException ex(
  966. AsyncSocketException::TIMED_OUT, "SSL accept timed out");
  967. failHandshake(__func__, ex);
  968. return;
  969. }
  970. sslState_ = STATE_ACCEPTING;
  971. this->handleAccept();
  972. }
  973. void AsyncSSLSocket::handleAccept() noexcept {
  974. VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this << ", fd=" << fd_
  975. << ", state=" << int(state_) << ", "
  976. << "sslState=" << sslState_ << ", events=" << eventFlags_;
  977. assert(server_);
  978. assert(state_ == StateEnum::ESTABLISHED && sslState_ == STATE_ACCEPTING);
  979. if (!ssl_) {
  980. /* lazily create the SSL structure */
  981. try {
  982. ssl_.reset(ctx_->createSSL());
  983. } catch (std::exception& e) {
  984. sslState_ = STATE_ERROR;
  985. AsyncSocketException ex(
  986. AsyncSocketException::INTERNAL_ERROR,
  987. "error calling SSLContext::createSSL()");
  988. LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
  989. << ", fd=" << fd_ << "): " << e.what();
  990. return failHandshake(__func__, ex);
  991. }
  992. if (!setupSSLBio()) {
  993. sslState_ = STATE_ERROR;
  994. AsyncSocketException ex(
  995. AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
  996. return failHandshake(__func__, ex);
  997. }
  998. SSL_set_ex_data(ssl_.get(), getSSLExDataIndex(), this);
  999. applyVerificationOptions(ssl_);
  1000. }
  1001. if (server_ && parseClientHello_) {
  1002. SSL_set_msg_callback(
  1003. ssl_.get(), &AsyncSSLSocket::clientHelloParsingCallback);
  1004. SSL_set_msg_callback_arg(ssl_.get(), this);
  1005. }
  1006. DCHECK(ctx_->sslAcceptRunner());
  1007. updateEventRegistration(
  1008. EventHandler::NONE, EventHandler::READ | EventHandler::WRITE);
  1009. DelayedDestruction::DestructorGuard dg(this);
  1010. ctx_->sslAcceptRunner()->run(
  1011. [this, dg]() { return SSL_accept(ssl_.get()); },
  1012. [this, dg](int ret) { handleReturnFromSSLAccept(ret); });
  1013. }
  1014. void AsyncSSLSocket::handleReturnFromSSLAccept(int ret) {
  1015. if (sslState_ != STATE_ACCEPTING) {
  1016. return;
  1017. }
  1018. if (ret <= 0) {
  1019. VLOG(3) << "SSL_accept returned: " << ret;
  1020. int sslError;
  1021. unsigned long errError;
  1022. int errnoCopy = errno;
  1023. if (willBlock(ret, &sslError, &errError)) {
  1024. return;
  1025. } else {
  1026. sslState_ = STATE_ERROR;
  1027. SSLException ex(sslError, errError, ret, errnoCopy);
  1028. return failHandshake(__func__, ex);
  1029. }
  1030. }
  1031. handshakeComplete_ = true;
  1032. updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
  1033. // Move into STATE_ESTABLISHED in the normal case that we are in
  1034. // STATE_ACCEPTING.
  1035. sslState_ = STATE_ESTABLISHED;
  1036. VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
  1037. << " successfully accepted; state=" << int(state_)
  1038. << ", sslState=" << sslState_ << ", events=" << eventFlags_;
  1039. // Remember the EventBase we are attached to, before we start invoking any
  1040. // callbacks (since the callbacks may call detachEventBase()).
  1041. EventBase* originalEventBase = eventBase_;
  1042. // Call the accept callback.
  1043. invokeHandshakeCB();
  1044. // Note that the accept callback may have changed our state.
  1045. // (set or unset the read callback, called write(), closed the socket, etc.)
  1046. // The following code needs to handle these situations correctly.
  1047. //
  1048. // If the socket has been closed, readCallback_ and writeReqHead_ will
  1049. // always be nullptr, so that will prevent us from trying to read or write.
  1050. //
  1051. // The main thing to check for is if eventBase_ is still originalEventBase.
  1052. // If not, we have been detached from this event base, so we shouldn't
  1053. // perform any more operations.
  1054. if (eventBase_ != originalEventBase) {
  1055. return;
  1056. }
  1057. AsyncSocket::handleInitialReadWrite();
  1058. }
  1059. void AsyncSSLSocket::handleConnect() noexcept {
  1060. VLOG(3) << "AsyncSSLSocket::handleConnect() this=" << this << ", fd=" << fd_
  1061. << ", state=" << int(state_) << ", "
  1062. << "sslState=" << sslState_ << ", events=" << eventFlags_;
  1063. assert(!server_);
  1064. if (state_ < StateEnum::ESTABLISHED) {
  1065. return AsyncSocket::handleConnect();
  1066. }
  1067. assert(
  1068. (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
  1069. sslState_ == STATE_CONNECTING);
  1070. assert(ssl_);
  1071. auto originalState = state_;
  1072. int ret = SSL_connect(ssl_.get());
  1073. if (ret <= 0) {
  1074. int sslError;
  1075. unsigned long errError;
  1076. int errnoCopy = errno;
  1077. if (willBlock(ret, &sslError, &errError)) {
  1078. // We fell back to connecting state due to TFO
  1079. if (state_ == StateEnum::CONNECTING) {
  1080. DCHECK_EQ(StateEnum::FAST_OPEN, originalState);
  1081. if (handshakeTimeout_.isScheduled()) {
  1082. handshakeTimeout_.cancelTimeout();
  1083. }
  1084. }
  1085. return;
  1086. } else {
  1087. sslState_ = STATE_ERROR;
  1088. SSLException ex(sslError, errError, ret, errnoCopy);
  1089. return failHandshake(__func__, ex);
  1090. }
  1091. }
  1092. handshakeComplete_ = true;
  1093. updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
  1094. // Move into STATE_ESTABLISHED in the normal case that we are in
  1095. // STATE_CONNECTING.
  1096. sslState_ = STATE_ESTABLISHED;
  1097. VLOG(3) << "AsyncSSLSocket " << this << ": "
  1098. << "fd " << fd_ << " successfully connected; "
  1099. << "state=" << int(state_) << ", sslState=" << sslState_
  1100. << ", events=" << eventFlags_;
  1101. // Remember the EventBase we are attached to, before we start invoking any
  1102. // callbacks (since the callbacks may call detachEventBase()).
  1103. EventBase* originalEventBase = eventBase_;
  1104. // Call the handshake callback.
  1105. invokeHandshakeCB();
  1106. // Note that the connect callback may have changed our state.
  1107. // (set or unset the read callback, called write(), closed the socket, etc.)
  1108. // The following code needs to handle these situations correctly.
  1109. //
  1110. // If the socket has been closed, readCallback_ and writeReqHead_ will
  1111. // always be nullptr, so that will prevent us from trying to read or write.
  1112. //
  1113. // The main thing to check for is if eventBase_ is still originalEventBase.
  1114. // If not, we have been detached from this event base, so we shouldn't
  1115. // perform any more operations.
  1116. if (eventBase_ != originalEventBase) {
  1117. return;
  1118. }
  1119. AsyncSocket::handleInitialReadWrite();
  1120. }
  1121. void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
  1122. connectionTimeout_.cancelTimeout();
  1123. AsyncSocket::invokeConnectErr(ex);
  1124. if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
  1125. if (handshakeTimeout_.isScheduled()) {
  1126. handshakeTimeout_.cancelTimeout();
  1127. }
  1128. // If we fell back to connecting state during TFO and the connection
  1129. // failed, it would be an SSL failure as well.
  1130. invokeHandshakeErr(ex);
  1131. }
  1132. }
  1133. void AsyncSSLSocket::invokeConnectSuccess() {
  1134. connectionTimeout_.cancelTimeout();
  1135. if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
  1136. assert(tfoAttempted_);
  1137. // If we failed TFO, we'd fall back to trying to connect the socket,
  1138. // to setup things like timeouts.
  1139. startSSLConnect();
  1140. }
  1141. // still invoke the base class since it re-sets the connect time.
  1142. AsyncSocket::invokeConnectSuccess();
  1143. }
  1144. void AsyncSSLSocket::scheduleConnectTimeout() {
  1145. if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
  1146. // We fell back from TFO, and need to set the timeouts.
  1147. // We will not have a connect callback in this case, thus if the timer
  1148. // expires we would have no-one to notify.
  1149. // Thus we should reset even the connect timers to point to the handshake
  1150. // timeouts.
  1151. assert(connectCallback_ == nullptr);
  1152. // We use a different connect timeout here than the handshake timeout, so
  1153. // that we can disambiguate the 2 timers.
  1154. if (connectTimeout_.count() > 0) {
  1155. if (!connectionTimeout_.scheduleTimeout(connectTimeout_)) {
  1156. throw AsyncSocketException(
  1157. AsyncSocketException::INTERNAL_ERROR,
  1158. withAddr("failed to schedule AsyncSSLSocket connect timeout"));
  1159. }
  1160. }
  1161. return;
  1162. }
  1163. AsyncSocket::scheduleConnectTimeout();
  1164. }
  1165. void AsyncSSLSocket::setReadCB(ReadCallback* callback) {
  1166. #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
  1167. // turn on the buffer movable in openssl
  1168. if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
  1169. callback != nullptr && callback->isBufferMovable()) {
  1170. SSL_set_mode(
  1171. ssl_.get(), SSL_get_mode(ssl_.get()) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
  1172. isBufferMovable_ = true;
  1173. }
  1174. #endif
  1175. AsyncSocket::setReadCB(callback);
  1176. }
  1177. void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
  1178. bufferMovableEnabled_ = enabled;
  1179. }
  1180. void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) {
  1181. CHECK(readCallback_);
  1182. if (isBufferMovable_) {
  1183. *buf = nullptr;
  1184. *buflen = 0;
  1185. } else {
  1186. // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
  1187. readCallback_->getReadBuffer(buf, buflen);
  1188. }
  1189. }
  1190. void AsyncSSLSocket::handleRead() noexcept {
  1191. VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
  1192. << ", state=" << int(state_) << ", "
  1193. << "sslState=" << sslState_ << ", events=" << eventFlags_;
  1194. if (state_ < StateEnum::ESTABLISHED) {
  1195. return AsyncSocket::handleRead();
  1196. }
  1197. if (sslState_ == STATE_ACCEPTING) {
  1198. assert(server_);
  1199. handleAccept();
  1200. return;
  1201. } else if (sslState_ == STATE_CONNECTING) {
  1202. assert(!server_);
  1203. handleConnect();
  1204. return;
  1205. }
  1206. // Normal read
  1207. AsyncSocket::handleRead();
  1208. }
  1209. AsyncSocket::ReadResult
  1210. AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
  1211. VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
  1212. << ", buflen=" << *buflen;
  1213. if (sslState_ == STATE_UNENCRYPTED) {
  1214. return AsyncSocket::performRead(buf, buflen, offset);
  1215. }
  1216. int bytes = 0;
  1217. if (!isBufferMovable_) {
  1218. bytes = SSL_read(ssl_.get(), *buf, int(*buflen));
  1219. }
  1220. #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
  1221. else {
  1222. bytes = SSL_read_buf(ssl_.get(), buf, (int*)offset, (int*)buflen);
  1223. }
  1224. #endif
  1225. if (server_ && renegotiateAttempted_) {
  1226. LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
  1227. << ", sslstate=" << sslState_ << ", events=" << eventFlags_
  1228. << "): client intitiated SSL renegotiation not permitted";
  1229. return ReadResult(
  1230. READ_ERROR,
  1231. std::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
  1232. }
  1233. if (bytes <= 0) {
  1234. int error = SSL_get_error(ssl_.get(), bytes);
  1235. if (error == SSL_ERROR_WANT_READ) {
  1236. // The caller will register for read event if not already.
  1237. if (errno == EWOULDBLOCK || errno == EAGAIN) {
  1238. return ReadResult(READ_BLOCKING);
  1239. } else {
  1240. return ReadResult(READ_ERROR);
  1241. }
  1242. } else if (error == SSL_ERROR_WANT_WRITE) {
  1243. // TODO: Even though we are attempting to read data, SSL_read() may
  1244. // need to write data if renegotiation is being performed. We currently
  1245. // don't support this and just fail the read.
  1246. LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
  1247. << ", sslState=" << sslState_ << ", events=" << eventFlags_
  1248. << "): unsupported SSL renegotiation during read";
  1249. return ReadResult(
  1250. READ_ERROR,
  1251. std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
  1252. } else {
  1253. if (zero_return(error, bytes)) {
  1254. return ReadResult(bytes);
  1255. }
  1256. auto errError = ERR_get_error();
  1257. VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
  1258. << "state=" << state_ << ", "
  1259. << "sslState=" << sslState_ << ", "
  1260. << "events=" << std::hex << eventFlags_ << "): "
  1261. << "bytes: " << bytes << ", "
  1262. << "error: " << error << ", "
  1263. << "errno: " << errno << ", "
  1264. << "func: " << ERR_func_error_string(errError) << ", "
  1265. << "reason: " << ERR_reason_error_string(errError);
  1266. return ReadResult(
  1267. READ_ERROR,
  1268. std::make_unique<SSLException>(error, errError, bytes, errno));
  1269. }
  1270. } else {
  1271. appBytesReceived_ += bytes;
  1272. return ReadResult(bytes);
  1273. }
  1274. }
  1275. void AsyncSSLSocket::handleWrite() noexcept {
  1276. VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
  1277. << ", state=" << int(state_) << ", "
  1278. << "sslState=" << sslState_ << ", events=" << eventFlags_;
  1279. if (state_ < StateEnum::ESTABLISHED) {
  1280. return AsyncSocket::handleWrite();
  1281. }
  1282. if (sslState_ == STATE_ACCEPTING) {
  1283. assert(server_);
  1284. handleAccept();
  1285. return;
  1286. }
  1287. if (sslState_ == STATE_CONNECTING) {
  1288. assert(!server_);
  1289. handleConnect();
  1290. return;
  1291. }
  1292. // Normal write
  1293. AsyncSocket::handleWrite();
  1294. }
  1295. AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
  1296. if (error == SSL_ERROR_WANT_READ) {
  1297. // Even though we are attempting to write data, SSL_write() may
  1298. // need to read data if renegotiation is being performed. We currently
  1299. // don't support this and just fail the write.
  1300. LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
  1301. << ", sslState=" << sslState_ << ", events=" << eventFlags_
  1302. << "): "
  1303. << "unsupported SSL renegotiation during write";
  1304. return WriteResult(
  1305. WRITE_ERROR,
  1306. std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
  1307. } else {
  1308. if (zero_return(error, rc)) {
  1309. return WriteResult(0);
  1310. }
  1311. auto errError = ERR_get_error();
  1312. VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
  1313. << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
  1314. << "SSL error: " << error << ", errno: " << errno
  1315. << ", func: " << ERR_func_error_string(errError)
  1316. << ", reason: " << ERR_reason_error_string(errError);
  1317. return WriteResult(
  1318. WRITE_ERROR,
  1319. std::make_unique<SSLException>(error, errError, rc, errno));
  1320. }
  1321. }
  1322. AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
  1323. const iovec* vec,
  1324. uint32_t count,
  1325. WriteFlags flags,
  1326. uint32_t* countWritten,
  1327. uint32_t* partialWritten) {
  1328. if (sslState_ == STATE_UNENCRYPTED) {
  1329. return AsyncSocket::performWrite(
  1330. vec, count, flags, countWritten, partialWritten);
  1331. }
  1332. if (sslState_ != STATE_ESTABLISHED) {
  1333. LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
  1334. << ", sslState=" << sslState_ << ", events=" << eventFlags_
  1335. << "): "
  1336. << "TODO: AsyncSSLSocket currently does not support calling "
  1337. << "write() before the handshake has fully completed";
  1338. return WriteResult(
  1339. WRITE_ERROR, std::make_unique<SSLException>(SSLError::EARLY_WRITE));
  1340. }
  1341. // Declare a buffer used to hold small write requests. It could point to a
  1342. // memory block either on stack or on heap. If it is on heap, we release it
  1343. // manually when scope exits
  1344. char* combinedBuf{nullptr};
  1345. SCOPE_EXIT {
  1346. // Note, always keep this check consistent with what we do below
  1347. if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
  1348. delete[] combinedBuf;
  1349. }
  1350. };
  1351. *countWritten = 0;
  1352. *partialWritten = 0;
  1353. ssize_t totalWritten = 0;
  1354. size_t bytesStolenFromNextBuffer = 0;
  1355. for (uint32_t i = 0; i < count; i++) {
  1356. const iovec* v = vec + i;
  1357. size_t offset = bytesStolenFromNextBuffer;
  1358. bytesStolenFromNextBuffer = 0;
  1359. size_t len = v->iov_len - offset;
  1360. const void* buf;
  1361. if (len == 0) {
  1362. (*countWritten)++;
  1363. continue;
  1364. }
  1365. buf = ((const char*)v->iov_base) + offset;
  1366. ssize_t bytes;
  1367. uint32_t buffersStolen = 0;
  1368. auto sslWriteBuf = buf;
  1369. if ((len < minWriteSize_) && ((i + 1) < count)) {
  1370. // Combine this buffer with part or all of the next buffers in
  1371. // order to avoid really small-grained calls to SSL_write().
  1372. // Each call to SSL_write() produces a separate record in
  1373. // the egress SSL stream, and we've found that some low-end
  1374. // mobile clients can't handle receiving an HTTP response
  1375. // header and the first part of the response body in two
  1376. // separate SSL records (even if those two records are in
  1377. // the same TCP packet).
  1378. if (combinedBuf == nullptr) {
  1379. if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
  1380. // Allocate the buffer on heap
  1381. combinedBuf = new char[minWriteSize_];
  1382. } else {
  1383. // Allocate the buffer on stack
  1384. combinedBuf = (char*)alloca(minWriteSize_);
  1385. }
  1386. }
  1387. assert(combinedBuf != nullptr);
  1388. sslWriteBuf = combinedBuf;
  1389. memcpy(combinedBuf, buf, len);
  1390. do {
  1391. // INVARIANT: i + buffersStolen == complete chunks serialized
  1392. uint32_t nextIndex = i + buffersStolen + 1;
  1393. bytesStolenFromNextBuffer =
  1394. std::min(vec[nextIndex].iov_len, minWriteSize_ - len);
  1395. if (bytesStolenFromNextBuffer > 0) {
  1396. assert(vec[nextIndex].iov_base != nullptr);
  1397. ::memcpy(
  1398. combinedBuf + len,
  1399. vec[nextIndex].iov_base,
  1400. bytesStolenFromNextBuffer);
  1401. }
  1402. len += bytesStolenFromNextBuffer;
  1403. if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
  1404. // couldn't steal the whole buffer
  1405. break;
  1406. } else {
  1407. bytesStolenFromNextBuffer = 0;
  1408. buffersStolen++;
  1409. }
  1410. } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
  1411. }
  1412. // Advance any empty buffers immediately after.
  1413. if (bytesStolenFromNextBuffer == 0) {
  1414. while ((i + buffersStolen + 1) < count &&
  1415. vec[i + buffersStolen + 1].iov_len == 0) {
  1416. buffersStolen++;
  1417. }
  1418. }
  1419. corkCurrentWrite_ =
  1420. isSet(flags, WriteFlags::CORK) || (i + buffersStolen + 1 < count);
  1421. bytes = eorAwareSSLWrite(
  1422. ssl_,
  1423. sslWriteBuf,
  1424. int(len),
  1425. (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
  1426. if (bytes <= 0) {
  1427. int error = SSL_get_error(ssl_.get(), int(bytes));
  1428. if (error == SSL_ERROR_WANT_WRITE) {
  1429. // The caller will register for write event if not already.
  1430. *partialWritten = uint32_t(offset);
  1431. return WriteResult(totalWritten);
  1432. }
  1433. auto writeResult = interpretSSLError(int(bytes), error);
  1434. if (writeResult.writeReturn < 0) {
  1435. return writeResult;
  1436. } // else fall through to below to correctly record totalWritten
  1437. }
  1438. totalWritten += bytes;
  1439. if (bytes == (ssize_t)len) {
  1440. // The full iovec is written.
  1441. (*countWritten) += 1 + buffersStolen;
  1442. i += buffersStolen;
  1443. // continue
  1444. } else {
  1445. bytes += offset; // adjust bytes to account for all of v
  1446. while (bytes >= (ssize_t)v->iov_len) {
  1447. // We combined this buf with part or all of the next one, and
  1448. // we managed to write all of this buf but not all of the bytes
  1449. // from the next one that we'd hoped to write.
  1450. bytes -= v->iov_len;
  1451. (*countWritten)++;
  1452. v = &(vec[++i]);
  1453. }
  1454. *partialWritten = uint32_t(bytes);
  1455. return WriteResult(totalWritten);
  1456. }
  1457. }
  1458. return WriteResult(totalWritten);
  1459. }
  1460. int AsyncSSLSocket::eorAwareSSLWrite(
  1461. const ssl::SSLUniquePtr& ssl,
  1462. const void* buf,
  1463. int n,
  1464. bool eor) {
  1465. if (eor && isEorTrackingEnabled()) {
  1466. if (appEorByteNo_) {
  1467. // cannot track for more than one app byte EOR
  1468. CHECK(appEorByteNo_ == appBytesWritten_ + n);
  1469. } else {
  1470. appEorByteNo_ = appBytesWritten_ + n;
  1471. }
  1472. // 1. It is fine to keep updating minEorRawByteNo_.
  1473. // 2. It is _min_ in the sense that SSL record will add some overhead.
  1474. minEorRawByteNo_ = getRawBytesWritten() + n;
  1475. }
  1476. n = sslWriteImpl(ssl.get(), buf, n);
  1477. if (n > 0) {
  1478. appBytesWritten_ += n;
  1479. if (appEorByteNo_) {
  1480. if (getRawBytesWritten() >= minEorRawByteNo_) {
  1481. minEorRawByteNo_ = 0;
  1482. }
  1483. if (appBytesWritten_ == appEorByteNo_) {
  1484. appEorByteNo_ = 0;
  1485. } else {
  1486. CHECK(appBytesWritten_ < appEorByteNo_);
  1487. }
  1488. }
  1489. }
  1490. return n;
  1491. }
  1492. void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
  1493. AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
  1494. if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
  1495. sslSocket->renegotiateAttempted_ = true;
  1496. }
  1497. if (where & SSL_CB_READ_ALERT) {
  1498. const char* type = SSL_alert_type_string(ret);
  1499. if (type) {
  1500. const char* desc = SSL_alert_desc_string(ret);
  1501. sslSocket->alertsReceived_.emplace_back(
  1502. *type, StringPiece(desc, std::strlen(desc)));
  1503. }
  1504. }
  1505. }
  1506. int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
  1507. struct msghdr msg;
  1508. struct iovec iov;
  1509. AsyncSSLSocket* tsslSock;
  1510. iov.iov_base = const_cast<char*>(in);
  1511. iov.iov_len = size_t(inl);
  1512. memset(&msg, 0, sizeof(msg));
  1513. msg.msg_iov = &iov;
  1514. msg.msg_iovlen = 1;
  1515. auto appData = OpenSSLUtils::getBioAppData(b);
  1516. CHECK(appData);
  1517. tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
  1518. CHECK(tsslSock);
  1519. WriteFlags flags = WriteFlags::NONE;
  1520. if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
  1521. tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
  1522. flags |= WriteFlags::EOR;
  1523. }
  1524. if (tsslSock->corkCurrentWrite_) {
  1525. flags |= WriteFlags::CORK;
  1526. }
  1527. int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(
  1528. flags, false /*zeroCopyEnabled*/);
  1529. msg.msg_controllen =
  1530. tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
  1531. CHECK_GE(
  1532. AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
  1533. msg.msg_controllen);
  1534. if (msg.msg_controllen != 0) {
  1535. msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
  1536. tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
  1537. }
  1538. auto result = tsslSock->sendSocketMessage(
  1539. OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags);
  1540. BIO_clear_retry_flags(b);
  1541. if (!result.exception && result.writeReturn <= 0) {
  1542. if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
  1543. BIO_set_retry_write(b);
  1544. }
  1545. }
  1546. return int(result.writeReturn);
  1547. }
  1548. int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
  1549. if (!out) {
  1550. return 0;
  1551. }
  1552. BIO_clear_retry_flags(b);
  1553. auto appData = OpenSSLUtils::getBioAppData(b);
  1554. CHECK(appData);
  1555. auto sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
  1556. if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) {
  1557. VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock
  1558. << ", reading pre-received data";
  1559. Cursor cursor(sslSock->preReceivedData_.get());
  1560. auto len = cursor.pullAtMost(out, outl);
  1561. IOBufQueue queue;
  1562. queue.append(std::move(sslSock->preReceivedData_));
  1563. queue.trimStart(len);
  1564. sslSock->preReceivedData_ = queue.move();
  1565. return static_cast<int>(len);
  1566. } else {
  1567. auto result = int(recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0));
  1568. if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
  1569. BIO_set_retry_read(b);
  1570. }
  1571. return result;
  1572. }
  1573. }
  1574. int AsyncSSLSocket::sslVerifyCallback(
  1575. int preverifyOk,
  1576. X509_STORE_CTX* x509Ctx) {
  1577. SSL* ssl = (SSL*)X509_STORE_CTX_get_ex_data(
  1578. x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
  1579. AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
  1580. VLOG(3) << "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
  1581. << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
  1582. return (self->handshakeCallback_)
  1583. ? self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx)
  1584. : preverifyOk;
  1585. }
  1586. void AsyncSSLSocket::enableClientHelloParsing() {
  1587. parseClientHello_ = true;
  1588. clientHelloInfo_ = std::make_unique<ssl::ClientHelloInfo>();
  1589. }
  1590. void AsyncSSLSocket::resetClientHelloParsing(SSL* ssl) {
  1591. SSL_set_msg_callback(ssl, nullptr);
  1592. SSL_set_msg_callback_arg(ssl, nullptr);
  1593. clientHelloInfo_->clientHelloBuf_.clear();
  1594. }
  1595. void AsyncSSLSocket::clientHelloParsingCallback(
  1596. int written,
  1597. int /* version */,
  1598. int contentType,
  1599. const void* buf,
  1600. size_t len,
  1601. SSL* ssl,
  1602. void* arg) {
  1603. AsyncSSLSocket* sock = static_cast<AsyncSSLSocket*>(arg);
  1604. if (written != 0) {
  1605. sock->resetClientHelloParsing(ssl);
  1606. return;
  1607. }
  1608. if (contentType != SSL3_RT_HANDSHAKE) {
  1609. return;
  1610. }
  1611. if (len == 0) {
  1612. return;
  1613. }
  1614. auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
  1615. clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
  1616. try {
  1617. Cursor cursor(clientHelloBuf.front());
  1618. if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
  1619. sock->resetClientHelloParsing(ssl);
  1620. return;
  1621. }
  1622. if (cursor.totalLength() < 3) {
  1623. clientHelloBuf.trimEnd(len);
  1624. clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
  1625. return;
  1626. }
  1627. uint32_t messageLength = cursor.read<uint8_t>();
  1628. messageLength <<= 8;
  1629. messageLength |= cursor.read<uint8_t>();
  1630. messageLength <<= 8;
  1631. messageLength |= cursor.read<uint8_t>();
  1632. if (cursor.totalLength() < messageLength) {
  1633. clientHelloBuf.trimEnd(len);
  1634. clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
  1635. return;
  1636. }
  1637. sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
  1638. sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
  1639. cursor.skip(4); // gmt_unix_time
  1640. cursor.skip(28); // random_bytes
  1641. cursor.skip(cursor.read<uint8_t>()); // session_id
  1642. uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
  1643. for (int i = 0; i < cipherSuitesLength; i += 2) {
  1644. sock->clientHelloInfo_->clientHelloCipherSuites_.push_back(
  1645. cursor.readBE<uint16_t>());
  1646. }
  1647. uint8_t compressionMethodsLength = cursor.read<uint8_t>();
  1648. for (int i = 0; i < compressionMethodsLength; ++i) {
  1649. sock->clientHelloInfo_->clientHelloCompressionMethods_.push_back(
  1650. cursor.readBE<uint8_t>());
  1651. }
  1652. if (cursor.totalLength() > 0) {
  1653. uint16_t extensionsLength = cursor.readBE<uint16_t>();
  1654. while (extensionsLength) {
  1655. ssl::TLSExtension extensionType =
  1656. static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
  1657. sock->clientHelloInfo_->clientHelloExtensions_.push_back(extensionType);
  1658. extensionsLength -= 2;
  1659. uint16_t extensionDataLength = cursor.readBE<uint16_t>();
  1660. extensionsLength -= 2;
  1661. extensionsLength -= extensionDataLength;
  1662. if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
  1663. cursor.skip(2);
  1664. extensionDataLength -= 2;
  1665. while (extensionDataLength) {
  1666. ssl::HashAlgorithm hashAlg =
  1667. static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
  1668. ssl::SignatureAlgorithm sigAlg =
  1669. static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
  1670. extensionDataLength -= 2;
  1671. sock->clientHelloInfo_->clientHelloSigAlgs_.emplace_back(
  1672. hashAlg, sigAlg);
  1673. }
  1674. } else if (extensionType == ssl::TLSExtension::SUPPORTED_VERSIONS) {
  1675. cursor.skip(1);
  1676. extensionDataLength -= 1;
  1677. while (extensionDataLength) {
  1678. sock->clientHelloInfo_->clientHelloSupportedVersions_.push_back(
  1679. cursor.readBE<uint16_t>());
  1680. extensionDataLength -= 2;
  1681. }
  1682. } else {
  1683. cursor.skip(extensionDataLength);
  1684. }
  1685. }
  1686. }
  1687. } catch (std::out_of_range&) {
  1688. // we'll use what we found and cleanup below.
  1689. VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
  1690. << "buffer finished unexpectedly."
  1691. << " AsyncSSLSocket socket=" << sock;
  1692. }
  1693. sock->resetClientHelloParsing(ssl);
  1694. }
  1695. void AsyncSSLSocket::getSSLClientCiphers(
  1696. std::string& clientCiphers,
  1697. bool convertToString) const {
  1698. std::string ciphers;
  1699. if (parseClientHello_ == false ||
  1700. clientHelloInfo_->clientHelloCipherSuites_.empty()) {
  1701. clientCiphers = "";
  1702. return;
  1703. }
  1704. bool first = true;
  1705. for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_) {
  1706. if (first) {
  1707. first = false;
  1708. } else {
  1709. ciphers += ":";
  1710. }
  1711. bool nameFound = convertToString;
  1712. if (convertToString) {
  1713. const auto& name = OpenSSLUtils::getCipherName(originalCipherCode);
  1714. if (name.empty()) {
  1715. nameFound = false;
  1716. } else {
  1717. ciphers += name;
  1718. }
  1719. }
  1720. if (!nameFound) {
  1721. folly::hexlify(
  1722. std::array<uint8_t, 2>{
  1723. {static_cast<uint8_t>((originalCipherCode >> 8) & 0xffL),
  1724. static_cast<uint8_t>(originalCipherCode & 0x00ffL)}},
  1725. ciphers,
  1726. /* append to ciphers = */ true);
  1727. }
  1728. }
  1729. clientCiphers = std::move(ciphers);
  1730. }
  1731. std::string AsyncSSLSocket::getSSLClientComprMethods() const {
  1732. if (!parseClientHello_) {
  1733. return "";
  1734. }
  1735. return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_);
  1736. }
  1737. std::string AsyncSSLSocket::getSSLClientExts() const {
  1738. if (!parseClientHello_) {
  1739. return "";
  1740. }
  1741. return folly::join(":", clientHelloInfo_->clientHelloExtensions_);
  1742. }
  1743. std::string AsyncSSLSocket::getSSLClientSigAlgs() const {
  1744. if (!parseClientHello_) {
  1745. return "";
  1746. }
  1747. std::string sigAlgs;
  1748. sigAlgs.reserve(clientHelloInfo_->clientHelloSigAlgs_.size() * 4);
  1749. for (size_t i = 0; i < clientHelloInfo_->clientHelloSigAlgs_.size(); i++) {
  1750. if (i) {
  1751. sigAlgs.push_back(':');
  1752. }
  1753. sigAlgs.append(
  1754. folly::to<std::string>(clientHelloInfo_->clientHelloSigAlgs_[i].first));
  1755. sigAlgs.push_back(',');
  1756. sigAlgs.append(folly::to<std::string>(
  1757. clientHelloInfo_->clientHelloSigAlgs_[i].second));
  1758. }
  1759. return sigAlgs;
  1760. }
  1761. std::string AsyncSSLSocket::getSSLClientSupportedVersions() const {
  1762. if (!parseClientHello_) {
  1763. return "";
  1764. }
  1765. return folly::join(":", clientHelloInfo_->clientHelloSupportedVersions_);
  1766. }
  1767. std::string AsyncSSLSocket::getSSLAlertsReceived() const {
  1768. std::string ret;
  1769. for (const auto& alert : alertsReceived_) {
  1770. if (!ret.empty()) {
  1771. ret.append(",");
  1772. }
  1773. ret.append(folly::to<std::string>(alert.first, ": ", alert.second));
  1774. }
  1775. return ret;
  1776. }
  1777. void AsyncSSLSocket::setSSLCertVerificationAlert(std::string alert) {
  1778. sslVerificationAlert_ = std::move(alert);
  1779. }
  1780. std::string AsyncSSLSocket::getSSLCertVerificationAlert() const {
  1781. return sslVerificationAlert_;
  1782. }
  1783. void AsyncSSLSocket::getSSLSharedCiphers(std::string& sharedCiphers) const {
  1784. char ciphersBuffer[1024];
  1785. ciphersBuffer[0] = '\0';
  1786. SSL_get_shared_ciphers(ssl_.get(), ciphersBuffer, sizeof(ciphersBuffer) - 1);
  1787. sharedCiphers = ciphersBuffer;
  1788. }
  1789. void AsyncSSLSocket::getSSLServerCiphers(std::string& serverCiphers) const {
  1790. serverCiphers = SSL_get_cipher_list(ssl_.get(), 0);
  1791. int i = 1;
  1792. const char* cipher;
  1793. while ((cipher = SSL_get_cipher_list(ssl_.get(), i)) != nullptr) {
  1794. serverCiphers.append(":");
  1795. serverCiphers.append(cipher);
  1796. i++;
  1797. }
  1798. }
  1799. } // namespace folly