diff --git a/sandboxed_api/sandbox2/BUILD.bazel b/sandboxed_api/sandbox2/BUILD.bazel index afb57d2..04d797c 100644 --- a/sandboxed_api/sandbox2/BUILD.bazel +++ b/sandboxed_api/sandbox2/BUILD.bazel @@ -835,6 +835,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":util", + "//sandboxed_api/util:fileops", "//sandboxed_api/util:raw_logging", "//sandboxed_api/util:status", "//sandboxed_api/util:status_cc_proto", diff --git a/sandboxed_api/sandbox2/CMakeLists.txt b/sandboxed_api/sandbox2/CMakeLists.txt index ccff2ea..4966d1d 100644 --- a/sandboxed_api/sandbox2/CMakeLists.txt +++ b/sandboxed_api/sandbox2/CMakeLists.txt @@ -763,6 +763,7 @@ target_link_libraries(sandbox2_comms PUBLIC absl::core_headers absl::status protobuf::libprotobuf + sapi::fileops sapi::status ) diff --git a/sandboxed_api/sandbox2/comms.cc b/sandboxed_api/sandbox2/comms.cc index da595d6..79cab93 100644 --- a/sandboxed_api/sandbox2/comms.cc +++ b/sandboxed_api/sandbox2/comms.cc @@ -43,11 +43,14 @@ #include "absl/status/statusor.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "google/protobuf/message_lite.h" #include "sandboxed_api/sandbox2/util.h" +#include "sandboxed_api/util/fileops.h" #include "sandboxed_api/util/raw_logging.h" #include "sandboxed_api/util/status.h" #include "sandboxed_api/util/status.pb.h" +#include "sandboxed_api/util/status_macros.h" namespace sandbox2 { @@ -59,6 +62,9 @@ class PotentiallyBlockingRegion { }; namespace { + +using sapi::file_util::fileops::FDCloser; + bool IsFatalError(int saved_errno) { return saved_errno != EAGAIN && saved_errno != EWOULDBLOCK && saved_errno != EFAULT && saved_errno != EINTR && @@ -74,16 +80,47 @@ int GetDefaultCommsFd() { } return Comms::kSandbox2ClientCommsFD; } + +socklen_t CreateSockaddrUn(const std::string& socket_name, bool abstract_uds, + sockaddr_un* sun) { + sun->sun_family = AF_UNIX; + bzero(sun->sun_path, sizeof(sun->sun_path)); + socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name.c_str()); + if (abstract_uds) { + // Create an 'abstract socket address' by specifying a leading null byte. + // The remainder of the path is used as a unique name, but no file is + // created on the filesystem. No need to NUL-terminate the string. See `man + // 7 unix` for further explanation. + strncpy(&sun->sun_path[1], socket_name.c_str(), sizeof(sun->sun_path) - 1); + // Len is complicated - it's essentially size of the path, plus initial + // NUL-byte, minus size of the sun.sun_family. + slen++; + } else { + // Create the socket address as it was passed from the constructor. + strncpy(&sun->sun_path[0], socket_name.c_str(), sizeof(sun->sun_path)); + } + + // This takes care of the socket address overflow. + if (slen > sizeof(sockaddr_un)) { + SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated"); + slen = sizeof(sockaddr_un); + } + return slen; +} } // namespace Comms::Comms(const std::string& socket_name, bool abstract_uds) - : socket_name_(socket_name), abstract_uds_(abstract_uds) {} + : name_(socket_name), abstract_uds_(abstract_uds) {} -Comms::Comms(int fd) : connection_fd_(fd) { +Comms::Comms(int fd, absl::string_view name) : connection_fd_(fd) { // Generate a unique and meaningful socket name for this FD. // Note: getpid()/gettid() are non-blocking syscalls. - socket_name_ = absl::StrFormat("sandbox2::Comms:FD=%d/PID=%d/TID=%ld", fd, - getpid(), syscall(__NR_gettid)); + if (name.empty()) { + name_ = absl::StrFormat("sandbox2::Comms:FD=%d/PID=%d/TID=%ld", fd, + getpid(), syscall(__NR_gettid)); + } else { + name_ = std::string(name); + } // File descriptor is already connected. state_ = State::kConnected; @@ -94,7 +131,36 @@ Comms::Comms(Comms::DefaultConnectionTag) : Comms(GetDefaultCommsFd()) {} Comms::~Comms() { Terminate(); } int Comms::GetConnectionFD() const { - return connection_fd_; + return connection_fd_.get(); +} + +absl::StatusOr ListeningComms::Create( + absl::string_view socket_name, bool abstract_uds) { + ListeningComms comms(std::string(socket_name), abstract_uds); + SAPI_RETURN_IF_ERROR(comms.Listen()); + return comms; +} + +absl::Status ListeningComms::Listen() { + bind_fd_ = FDCloser(socket(AF_UNIX, SOCK_STREAM, 0)); // Non-blocking + if (bind_fd_.get() == -1) { + return absl::ErrnoToStatus(errno, "socket(AF_UNIX) failed"); + } + + sockaddr_un sus; + socklen_t slen = CreateSockaddrUn(socket_name_, abstract_uds_, &sus); + // bind() is non-blocking. + if (bind(bind_fd_.get(), reinterpret_cast(&sus), slen) == -1) { + return absl::ErrnoToStatus(errno, "bind failed"); + } + + // listen() non-blocking. + if (listen(bind_fd_.get(), 0) == -1) { + return absl::ErrnoToStatus(errno, "listen failed"); + } + + SAPI_RAW_VLOG(1, "Listening at: %s", socket_name_.c_str()); + return absl::OkStatus(); } bool Comms::Listen() { @@ -102,123 +168,98 @@ bool Comms::Listen() { return true; } - bind_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // Non-blocking - if (bind_fd_ == -1) { - SAPI_RAW_PLOG(ERROR, "socket(AF_UNIX)"); + absl::StatusOr listening_comms = + ListeningComms::Create(name_, abstract_uds_); + if (!listening_comms.ok()) { + SAPI_RAW_LOG(ERROR, "Listening failed: %s", + std::string(listening_comms.status().message()).c_str()); return false; } - - sockaddr_un sus; - socklen_t slen = CreateSockaddrUn(&sus); - // bind() is non-blocking. - if (bind(bind_fd_, reinterpret_cast(&sus), slen) == -1) { - SAPI_RAW_PLOG(ERROR, "bind(bind_fd)"); - - // Note: checking for EINTR on close() syscall is useless and possibly - // harmful, see https://lwn.net/Articles/576478/. - { - PotentiallyBlockingRegion region; - close(bind_fd_); - } - bind_fd_ = -1; - return false; - } - - // listen() non-blocking. - if (listen(bind_fd_, 0) == -1) { - SAPI_RAW_PLOG(ERROR, "listen(bind_fd)"); - { - PotentiallyBlockingRegion region; - close(bind_fd_); - } - bind_fd_ = -1; - return false; - } - - SAPI_RAW_VLOG(1, "Listening at: %s", socket_name_.c_str()); + listening_comms_ = + std::make_unique(*std::move(listening_comms)); return true; } +absl::StatusOr ListeningComms::Accept() { + sockaddr_un suc; + socklen_t len = sizeof(suc); + int connection_fd; + { + PotentiallyBlockingRegion region; + connection_fd = TEMP_FAILURE_RETRY( + accept(bind_fd_.get(), reinterpret_cast(&suc), &len)); + } + if (connection_fd == -1) { + return absl::ErrnoToStatus(errno, "accept failed"); + } + SAPI_RAW_VLOG(1, "Accepted connection at: %s, fd: %d", socket_name_.c_str(), + connection_fd); + return Comms(connection_fd, socket_name_); +} + bool Comms::Accept() { if (IsConnected()) { return true; } - sockaddr_un suc; - socklen_t len = sizeof(suc); - { - PotentiallyBlockingRegion region; - connection_fd_ = TEMP_FAILURE_RETRY( - accept(bind_fd_, reinterpret_cast(&suc), &len)); - } - if (connection_fd_ == -1) { - SAPI_RAW_PLOG(ERROR, "accept(bind_fd)"); - { - PotentiallyBlockingRegion region; - close(bind_fd_); - } - bind_fd_ = -1; + if (listening_comms_ == nullptr) { + SAPI_RAW_LOG(ERROR, "Comms::Listen needs to be called first"); return false; } - state_ = State::kConnected; - - SAPI_RAW_VLOG(1, "Accepted connection at: %s, fd: %d", socket_name_.c_str(), - connection_fd_); + absl::StatusOr comms = listening_comms_->Accept(); + if (!comms.ok()) { + SAPI_RAW_LOG(ERROR, "%s", std::string(comms.status().message()).c_str()); + return false; + } + *this = *std::move(comms); return true; } +absl::StatusOr Comms::Connect(const std::string& socket_name, + bool abstract_uds) { + FDCloser connection_fd(socket(AF_UNIX, SOCK_STREAM, 0)); // Non-blocking + if (connection_fd.get() == -1) { + return absl::ErrnoToStatus(errno, "socket(AF_UNIX)"); + } + + sockaddr_un suc; + socklen_t slen = CreateSockaddrUn(socket_name, abstract_uds, &suc); + int ret; + { + PotentiallyBlockingRegion region; + ret = TEMP_FAILURE_RETRY( + connect(connection_fd.get(), reinterpret_cast(&suc), slen)); + } + if (ret == -1) { + return absl::ErrnoToStatus(errno, "connect(connection_fd)"); + } + + SAPI_RAW_VLOG(1, "Connected to: %s, fd: %d", socket_name.c_str(), + connection_fd.get()); + return Comms(connection_fd.Release(), socket_name); +} + bool Comms::Connect() { if (IsConnected()) { return true; } - connection_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // Non-blocking - if (connection_fd_ == -1) { - SAPI_RAW_PLOG(ERROR, "socket(AF_UNIX)"); + absl::StatusOr connected = Connect(name_, abstract_uds_); + if (!connected.ok()) { + SAPI_RAW_LOG(ERROR, "%s", + std::string(connected.status().message()).c_str()); return false; } - - sockaddr_un suc; - socklen_t slen = CreateSockaddrUn(&suc); - int ret; - { - PotentiallyBlockingRegion region; - ret = TEMP_FAILURE_RETRY( - connect(connection_fd_, reinterpret_cast(&suc), slen)); - } - if (ret == -1) { - SAPI_RAW_PLOG(ERROR, "connect(connection_fd)"); - { - PotentiallyBlockingRegion region; - close(connection_fd_); - } - connection_fd_ = -1; - return false; - } - - state_ = State::kConnected; - - SAPI_RAW_VLOG(1, "Connected to: %s, fd: %d", socket_name_.c_str(), - connection_fd_); + *this = *std::move(connected); return true; } void Comms::Terminate() { - { - PotentiallyBlockingRegion region; + state_ = State::kTerminated; - state_ = State::kTerminated; - - if (bind_fd_ != -1) { - close(bind_fd_); - bind_fd_ = -1; - } - if (connection_fd_ != -1) { - close(connection_fd_); - connection_fd_ = -1; - } - } + connection_fd_.Close(); + listening_comms_.reset(); } bool Comms::SendTLV(uint32_t tag, size_t length, const void* value) { @@ -349,7 +390,7 @@ bool Comms::RecvFD(int* fd) { util::Syscall(__NR_recvmsg, fd, reinterpret_cast(&msg), 0)); }; ssize_t len; - len = op(connection_fd_); + len = op(connection_fd_.get()); if (len < 0) { if (IsFatalError(errno)) { Terminate(); @@ -432,7 +473,7 @@ bool Comms::SendFD(int fd) { util::Syscall(__NR_sendmsg, fd, reinterpret_cast(&msg), 0)); }; ssize_t len; - len = op(connection_fd_); + len = op(connection_fd_.get()); if (len == -1 && errno == EPIPE) { Terminate(); SAPI_RAW_LOG(ERROR, "sendmsg(SCM_RIGHTS): Peer disconnected"); @@ -458,10 +499,10 @@ bool Comms::RecvProtoBuf(google::protobuf::MessageLite* message) { std::vector bytes; if (!RecvTLV(&tag, &bytes)) { if (IsConnected()) { - SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", socket_name_); + SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", name_); } else { Terminate(); - SAPI_RAW_VLOG(2, "Connection terminated (%s)", socket_name_.c_str()); + SAPI_RAW_VLOG(2, "Connection terminated (%s)", name_.c_str()); } return false; } @@ -488,32 +529,6 @@ bool Comms::SendProtoBuf(const google::protobuf::MessageLite& message) { // All methods below are private, for internal use only. // ***************************************************************************** -socklen_t Comms::CreateSockaddrUn(sockaddr_un* sun) { - sun->sun_family = AF_UNIX; - bzero(sun->sun_path, sizeof(sun->sun_path)); - socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name_.c_str()); - if (abstract_uds_) { - // Create an 'abstract socket address' by specifying a leading null byte. - // The remainder of the path is used as a unique name, but no file is - // created on the filesystem. No need to NUL-terminate the string. See `man - // 7 unix` for further explanation. - strncpy(&sun->sun_path[1], socket_name_.c_str(), sizeof(sun->sun_path) - 1); - // Len is complicated - it's essentially size of the path, plus initial - // NUL-byte, minus size of the sun.sun_family. - slen++; - } else { - // Create the socket address as it was passed from the constructor. - strncpy(&sun->sun_path[0], socket_name_.c_str(), sizeof(sun->sun_path)); - } - - // This takes care of the socket address overflow. - if (slen > sizeof(sockaddr_un)) { - SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated"); - slen = sizeof(sockaddr_un); - } - return slen; -} - bool Comms::Send(const void* data, size_t len) { size_t total_sent = 0; const char* bytes = reinterpret_cast(data); @@ -523,7 +538,7 @@ bool Comms::Send(const void* data, size_t len) { }; while (total_sent < len) { ssize_t s; - s = op(connection_fd_); + s = op(connection_fd_.get()); if (s == -1 && errno == EPIPE) { Terminate(); // We do not expect the other end to disappear. @@ -557,7 +572,7 @@ bool Comms::Recv(void* data, size_t len) { }; while (total_recv < len) { ssize_t s; - s = op(connection_fd_); + s = op(connection_fd_.get()); if (s == -1) { SAPI_RAW_PLOG(ERROR, "read"); if (IsFatalError(errno)) { @@ -675,12 +690,11 @@ bool Comms::SendStatus(const absl::Status& status) { } void Comms::MoveToAnotherFd() { - SAPI_RAW_CHECK(connection_fd_ != -1, + SAPI_RAW_CHECK(connection_fd_.get() != -1, "Cannot move comms fd as it's not connected"); - int new_fd = dup(connection_fd_); - SAPI_RAW_CHECK(new_fd != -1, "Failed to move comms to another fd"); - close(connection_fd_); - connection_fd_ = new_fd; + FDCloser new_fd(dup(connection_fd_.get())); + SAPI_RAW_CHECK(new_fd.get() != -1, "Failed to move comms to another fd"); + connection_fd_.Swap(new_fd); } } // namespace sandbox2 diff --git a/sandboxed_api/sandbox2/comms.h b/sandboxed_api/sandbox2/comms.h index c9d22d0..39f29ef 100644 --- a/sandboxed_api/sandbox2/comms.h +++ b/sandboxed_api/sandbox2/comms.h @@ -41,6 +41,7 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "google/protobuf/message_lite.h" +#include "sandboxed_api/util/fileops.h" namespace proto2 { class Message; @@ -49,6 +50,7 @@ class Message; namespace sandbox2 { class Client; +class ListeningComms; class Comms { public: @@ -93,8 +95,13 @@ class Comms { static constexpr const char* kSandbox2CommsFDEnvVar = "SANDBOX2_COMMS_FD"; + static absl::StatusOr Connect(const std::string& socket_name, + bool abstract_uds = true); + // This object will have to be connected later on. // When not specified the constructor uses abstract unix domain sockets. + ABSL_DEPRECATED( + "Use ListeningComms or absl::StatusOr Connect() instead") explicit Comms(const std::string& socket_name, bool abstract_uds = true); Comms(Comms&& other) { *this = std::move(other); } @@ -112,7 +119,7 @@ class Comms { // Instantiates a pre-connected object. // Takes ownership over fd, which will be closed on object's destruction. - explicit Comms(int fd); + explicit Comms(int fd, absl::string_view name = ""); // Instantiates a pre-connected object using the default connection params. explicit Comms(DefaultConnectionTag); @@ -120,12 +127,17 @@ class Comms { ~Comms(); // Binds to an address and make it listen to connections. + ABSL_DEPRECATED("Use ListeningComms::Create() instead") bool Listen(); // Accepts the connection. + ABSL_DEPRECATED("Use ListeningComms::Accept() instead") bool Accept(); // Connects to a remote socket. + ABSL_DEPRECATED( + "Use absl::StatusOr Comms::Connect(std::string& socket_name, bool " + "abstract_uds) instead") bool Connect(); // Terminates all underlying file descriptors, and sets the status of the @@ -201,11 +213,11 @@ class Comms { return; } using std::swap; - swap(socket_name_, other.socket_name_); + swap(name_, other.name_); swap(abstract_uds_, other.abstract_uds_); swap(connection_fd_, other.connection_fd_); - swap(bind_fd_, other.bind_fd_); swap(state_, other.state_); + swap(listening_comms_, other.listening_comms_); } friend void swap(Comms& x, Comms& y) { return x.Swap(y); } @@ -221,10 +233,11 @@ class Comms { }; // Connection parameters. - std::string socket_name_; + std::string name_; bool abstract_uds_ = true; - int connection_fd_ = -1; - int bind_fd_ = -1; + sapi::file_util::fileops::FDCloser connection_fd_; + + std::unique_ptr listening_comms_; // State of the channel (enum), socket will have to be connected later on. State state_ = State::kUnconnected; @@ -239,9 +252,6 @@ class Comms { size_t len; }; - // Fills sockaddr_un struct with proper values. - socklen_t CreateSockaddrUn(sockaddr_un* sun); - // Moves the comms fd to an other free file descriptor. void MoveToAnotherFd(); @@ -270,6 +280,26 @@ class Comms { } }; +class ListeningComms { + public: + static absl::StatusOr Create(absl::string_view socket_name, + bool abstract_uds = true); + + ListeningComms(ListeningComms&& other) = default; + ListeningComms& operator=(ListeningComms&& other) = default; + ~ListeningComms() = default; + absl::StatusOr Accept(); + + private: + ListeningComms(absl::string_view socket_name, bool abstract_uds) + : socket_name_(socket_name), abstract_uds_(abstract_uds), bind_fd_(-1) {} + absl::Status Listen(); + + std::string socket_name_; + bool abstract_uds_; + sapi::file_util::fileops::FDCloser bind_fd_; +}; + } // namespace sandbox2 #endif // SANDBOXED_API_SANDBOX2_COMMS_H_