/* * Copyright 2015-present Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include class BlockingSocket : public folly::AsyncSocket::ConnectCallback, public folly::AsyncTransportWrapper::ReadCallback, public folly::AsyncTransportWrapper::WriteCallback { public: explicit BlockingSocket(int fd) : sock_(new folly::AsyncSocket(&eventBase_, fd)) {} BlockingSocket( folly::SocketAddress address, std::shared_ptr sslContext) : sock_( sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) : new folly::AsyncSocket(&eventBase_)), address_(address) {} explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket) : sock_(std::move(socket)) { sock_->attachEventBase(&eventBase_); } void enableTFO() { sock_->enableTFO(); } void setAddress(folly::SocketAddress address) { address_ = address; } void open( std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) { sock_->connect(this, address_, timeout.count()); eventBase_.loop(); if (err_.hasValue()) { throw err_.value(); } } void close() { sock_->close(); } void closeWithReset() { sock_->closeWithReset(); } int32_t write(uint8_t const* buf, size_t len) { sock_->write(this, buf, len); eventBase_.loop(); if (err_.hasValue()) { throw err_.value(); } return len; } void flush() {} int32_t readAll(uint8_t* buf, size_t len) { return readHelper(buf, len, true); } int32_t read(uint8_t* buf, size_t len) { return readHelper(buf, len, false); } int getSocketFD() const { return sock_->getFd(); } folly::AsyncSocket* getSocket() { return sock_.get(); } folly::AsyncSSLSocket* getSSLSocket() { return dynamic_cast(sock_.get()); } private: folly::EventBase eventBase_; folly::AsyncSocket::UniquePtr sock_; folly::Optional err_; uint8_t* readBuf_{nullptr}; size_t readLen_{0}; folly::SocketAddress address_; void connectSuccess() noexcept override {} void connectErr(const folly::AsyncSocketException& ex) noexcept override { err_ = ex; } void getReadBuffer(void** bufReturn, size_t* lenReturn) override { *bufReturn = readBuf_; *lenReturn = readLen_; } void readDataAvailable(size_t len) noexcept override { readBuf_ += len; readLen_ -= len; if (readLen_ == 0) { sock_->setReadCB(nullptr); } } void readEOF() noexcept override {} void readErr(const folly::AsyncSocketException& ex) noexcept override { err_ = ex; } void writeSuccess() noexcept override {} void writeErr( size_t /* bytesWritten */, const folly::AsyncSocketException& ex) noexcept override { err_ = ex; } int32_t readHelper(uint8_t* buf, size_t len, bool all) { if (!sock_->good()) { return 0; } readBuf_ = buf; readLen_ = len; sock_->setReadCB(this); while (!err_ && sock_->good() && readLen_ > 0) { eventBase_.loopOnce(); if (!all) { break; } } sock_->setReadCB(nullptr); if (err_.hasValue()) { throw err_.value(); } if (all && readLen_ > 0) { throw folly::AsyncSocketException( folly::AsyncSocketException::UNKNOWN, "eof"); } return len - readLen_; } };