Refactor Comms to split out listening/connecting part

Deprecated APIs slated for removal after migration of internal
clients.

PiperOrigin-RevId: 566598245
Change-Id: I5d7b920f3a788d4eccc6e78f239b660ba903adcc
This commit is contained in:
Wiktor Garbacz 2023-09-19 05:13:40 -07:00 committed by Copybara-Service
parent d26262d82e
commit 1cf45be7df
4 changed files with 186 additions and 140 deletions

View File

@ -835,6 +835,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":util", ":util",
"//sandboxed_api/util:fileops",
"//sandboxed_api/util:raw_logging", "//sandboxed_api/util:raw_logging",
"//sandboxed_api/util:status", "//sandboxed_api/util:status",
"//sandboxed_api/util:status_cc_proto", "//sandboxed_api/util:status_cc_proto",

View File

@ -763,6 +763,7 @@ target_link_libraries(sandbox2_comms
PUBLIC absl::core_headers PUBLIC absl::core_headers
absl::status absl::status
protobuf::libprotobuf protobuf::libprotobuf
sapi::fileops
sapi::status sapi::status
) )

View File

@ -43,11 +43,14 @@
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/message_lite.h" #include "google/protobuf/message_lite.h"
#include "sandboxed_api/sandbox2/util.h" #include "sandboxed_api/sandbox2/util.h"
#include "sandboxed_api/util/fileops.h"
#include "sandboxed_api/util/raw_logging.h" #include "sandboxed_api/util/raw_logging.h"
#include "sandboxed_api/util/status.h" #include "sandboxed_api/util/status.h"
#include "sandboxed_api/util/status.pb.h" #include "sandboxed_api/util/status.pb.h"
#include "sandboxed_api/util/status_macros.h"
namespace sandbox2 { namespace sandbox2 {
@ -59,6 +62,9 @@ class PotentiallyBlockingRegion {
}; };
namespace { namespace {
using sapi::file_util::fileops::FDCloser;
bool IsFatalError(int saved_errno) { bool IsFatalError(int saved_errno) {
return saved_errno != EAGAIN && saved_errno != EWOULDBLOCK && return saved_errno != EAGAIN && saved_errno != EWOULDBLOCK &&
saved_errno != EFAULT && saved_errno != EINTR && saved_errno != EFAULT && saved_errno != EINTR &&
@ -74,16 +80,47 @@ int GetDefaultCommsFd() {
} }
return Comms::kSandbox2ClientCommsFD; return Comms::kSandbox2ClientCommsFD;
} }
socklen_t CreateSockaddrUn(const std::string& socket_name, bool abstract_uds,
sockaddr_un* sun) {
sun->sun_family = AF_UNIX;
bzero(sun->sun_path, sizeof(sun->sun_path));
socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name.c_str());
if (abstract_uds) {
// Create an 'abstract socket address' by specifying a leading null byte.
// The remainder of the path is used as a unique name, but no file is
// created on the filesystem. No need to NUL-terminate the string. See `man
// 7 unix` for further explanation.
strncpy(&sun->sun_path[1], socket_name.c_str(), sizeof(sun->sun_path) - 1);
// Len is complicated - it's essentially size of the path, plus initial
// NUL-byte, minus size of the sun.sun_family.
slen++;
} else {
// Create the socket address as it was passed from the constructor.
strncpy(&sun->sun_path[0], socket_name.c_str(), sizeof(sun->sun_path));
}
// This takes care of the socket address overflow.
if (slen > sizeof(sockaddr_un)) {
SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated");
slen = sizeof(sockaddr_un);
}
return slen;
}
} // namespace } // namespace
Comms::Comms(const std::string& socket_name, bool abstract_uds) Comms::Comms(const std::string& socket_name, bool abstract_uds)
: socket_name_(socket_name), abstract_uds_(abstract_uds) {} : name_(socket_name), abstract_uds_(abstract_uds) {}
Comms::Comms(int fd) : connection_fd_(fd) { Comms::Comms(int fd, absl::string_view name) : connection_fd_(fd) {
// Generate a unique and meaningful socket name for this FD. // Generate a unique and meaningful socket name for this FD.
// Note: getpid()/gettid() are non-blocking syscalls. // Note: getpid()/gettid() are non-blocking syscalls.
socket_name_ = absl::StrFormat("sandbox2::Comms:FD=%d/PID=%d/TID=%ld", fd, if (name.empty()) {
name_ = absl::StrFormat("sandbox2::Comms:FD=%d/PID=%d/TID=%ld", fd,
getpid(), syscall(__NR_gettid)); getpid(), syscall(__NR_gettid));
} else {
name_ = std::string(name);
}
// File descriptor is already connected. // File descriptor is already connected.
state_ = State::kConnected; state_ = State::kConnected;
@ -94,7 +131,36 @@ Comms::Comms(Comms::DefaultConnectionTag) : Comms(GetDefaultCommsFd()) {}
Comms::~Comms() { Terminate(); } Comms::~Comms() { Terminate(); }
int Comms::GetConnectionFD() const { int Comms::GetConnectionFD() const {
return connection_fd_; return connection_fd_.get();
}
absl::StatusOr<ListeningComms> ListeningComms::Create(
absl::string_view socket_name, bool abstract_uds) {
ListeningComms comms(std::string(socket_name), abstract_uds);
SAPI_RETURN_IF_ERROR(comms.Listen());
return comms;
}
absl::Status ListeningComms::Listen() {
bind_fd_ = FDCloser(socket(AF_UNIX, SOCK_STREAM, 0)); // Non-blocking
if (bind_fd_.get() == -1) {
return absl::ErrnoToStatus(errno, "socket(AF_UNIX) failed");
}
sockaddr_un sus;
socklen_t slen = CreateSockaddrUn(socket_name_, abstract_uds_, &sus);
// bind() is non-blocking.
if (bind(bind_fd_.get(), reinterpret_cast<sockaddr*>(&sus), slen) == -1) {
return absl::ErrnoToStatus(errno, "bind failed");
}
// listen() non-blocking.
if (listen(bind_fd_.get(), 0) == -1) {
return absl::ErrnoToStatus(errno, "listen failed");
}
SAPI_RAW_VLOG(1, "Listening at: %s", socket_name_.c_str());
return absl::OkStatus();
} }
bool Comms::Listen() { bool Comms::Listen() {
@ -102,123 +168,98 @@ bool Comms::Listen() {
return true; return true;
} }
bind_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // Non-blocking absl::StatusOr<ListeningComms> listening_comms =
if (bind_fd_ == -1) { ListeningComms::Create(name_, abstract_uds_);
SAPI_RAW_PLOG(ERROR, "socket(AF_UNIX)"); if (!listening_comms.ok()) {
SAPI_RAW_LOG(ERROR, "Listening failed: %s",
std::string(listening_comms.status().message()).c_str());
return false; return false;
} }
listening_comms_ =
sockaddr_un sus; std::make_unique<ListeningComms>(*std::move(listening_comms));
socklen_t slen = CreateSockaddrUn(&sus);
// bind() is non-blocking.
if (bind(bind_fd_, reinterpret_cast<sockaddr*>(&sus), slen) == -1) {
SAPI_RAW_PLOG(ERROR, "bind(bind_fd)");
// Note: checking for EINTR on close() syscall is useless and possibly
// harmful, see https://lwn.net/Articles/576478/.
{
PotentiallyBlockingRegion region;
close(bind_fd_);
}
bind_fd_ = -1;
return false;
}
// listen() non-blocking.
if (listen(bind_fd_, 0) == -1) {
SAPI_RAW_PLOG(ERROR, "listen(bind_fd)");
{
PotentiallyBlockingRegion region;
close(bind_fd_);
}
bind_fd_ = -1;
return false;
}
SAPI_RAW_VLOG(1, "Listening at: %s", socket_name_.c_str());
return true; return true;
} }
absl::StatusOr<Comms> ListeningComms::Accept() {
sockaddr_un suc;
socklen_t len = sizeof(suc);
int connection_fd;
{
PotentiallyBlockingRegion region;
connection_fd = TEMP_FAILURE_RETRY(
accept(bind_fd_.get(), reinterpret_cast<sockaddr*>(&suc), &len));
}
if (connection_fd == -1) {
return absl::ErrnoToStatus(errno, "accept failed");
}
SAPI_RAW_VLOG(1, "Accepted connection at: %s, fd: %d", socket_name_.c_str(),
connection_fd);
return Comms(connection_fd, socket_name_);
}
bool Comms::Accept() { bool Comms::Accept() {
if (IsConnected()) { if (IsConnected()) {
return true; return true;
} }
sockaddr_un suc; if (listening_comms_ == nullptr) {
socklen_t len = sizeof(suc); SAPI_RAW_LOG(ERROR, "Comms::Listen needs to be called first");
{
PotentiallyBlockingRegion region;
connection_fd_ = TEMP_FAILURE_RETRY(
accept(bind_fd_, reinterpret_cast<sockaddr*>(&suc), &len));
}
if (connection_fd_ == -1) {
SAPI_RAW_PLOG(ERROR, "accept(bind_fd)");
{
PotentiallyBlockingRegion region;
close(bind_fd_);
}
bind_fd_ = -1;
return false; return false;
} }
state_ = State::kConnected; absl::StatusOr<Comms> comms = listening_comms_->Accept();
if (!comms.ok()) {
SAPI_RAW_VLOG(1, "Accepted connection at: %s, fd: %d", socket_name_.c_str(), SAPI_RAW_LOG(ERROR, "%s", std::string(comms.status().message()).c_str());
connection_fd_); return false;
}
*this = *std::move(comms);
return true; return true;
} }
absl::StatusOr<Comms> Comms::Connect(const std::string& socket_name,
bool abstract_uds) {
FDCloser connection_fd(socket(AF_UNIX, SOCK_STREAM, 0)); // Non-blocking
if (connection_fd.get() == -1) {
return absl::ErrnoToStatus(errno, "socket(AF_UNIX)");
}
sockaddr_un suc;
socklen_t slen = CreateSockaddrUn(socket_name, abstract_uds, &suc);
int ret;
{
PotentiallyBlockingRegion region;
ret = TEMP_FAILURE_RETRY(
connect(connection_fd.get(), reinterpret_cast<sockaddr*>(&suc), slen));
}
if (ret == -1) {
return absl::ErrnoToStatus(errno, "connect(connection_fd)");
}
SAPI_RAW_VLOG(1, "Connected to: %s, fd: %d", socket_name.c_str(),
connection_fd.get());
return Comms(connection_fd.Release(), socket_name);
}
bool Comms::Connect() { bool Comms::Connect() {
if (IsConnected()) { if (IsConnected()) {
return true; return true;
} }
connection_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // Non-blocking absl::StatusOr<Comms> connected = Connect(name_, abstract_uds_);
if (connection_fd_ == -1) { if (!connected.ok()) {
SAPI_RAW_PLOG(ERROR, "socket(AF_UNIX)"); SAPI_RAW_LOG(ERROR, "%s",
std::string(connected.status().message()).c_str());
return false; return false;
} }
*this = *std::move(connected);
sockaddr_un suc;
socklen_t slen = CreateSockaddrUn(&suc);
int ret;
{
PotentiallyBlockingRegion region;
ret = TEMP_FAILURE_RETRY(
connect(connection_fd_, reinterpret_cast<sockaddr*>(&suc), slen));
}
if (ret == -1) {
SAPI_RAW_PLOG(ERROR, "connect(connection_fd)");
{
PotentiallyBlockingRegion region;
close(connection_fd_);
}
connection_fd_ = -1;
return false;
}
state_ = State::kConnected;
SAPI_RAW_VLOG(1, "Connected to: %s, fd: %d", socket_name_.c_str(),
connection_fd_);
return true; return true;
} }
void Comms::Terminate() { void Comms::Terminate() {
{
PotentiallyBlockingRegion region;
state_ = State::kTerminated; state_ = State::kTerminated;
if (bind_fd_ != -1) { connection_fd_.Close();
close(bind_fd_); listening_comms_.reset();
bind_fd_ = -1;
}
if (connection_fd_ != -1) {
close(connection_fd_);
connection_fd_ = -1;
}
}
} }
bool Comms::SendTLV(uint32_t tag, size_t length, const void* value) { bool Comms::SendTLV(uint32_t tag, size_t length, const void* value) {
@ -349,7 +390,7 @@ bool Comms::RecvFD(int* fd) {
util::Syscall(__NR_recvmsg, fd, reinterpret_cast<uintptr_t>(&msg), 0)); util::Syscall(__NR_recvmsg, fd, reinterpret_cast<uintptr_t>(&msg), 0));
}; };
ssize_t len; ssize_t len;
len = op(connection_fd_); len = op(connection_fd_.get());
if (len < 0) { if (len < 0) {
if (IsFatalError(errno)) { if (IsFatalError(errno)) {
Terminate(); Terminate();
@ -432,7 +473,7 @@ bool Comms::SendFD(int fd) {
util::Syscall(__NR_sendmsg, fd, reinterpret_cast<uintptr_t>(&msg), 0)); util::Syscall(__NR_sendmsg, fd, reinterpret_cast<uintptr_t>(&msg), 0));
}; };
ssize_t len; ssize_t len;
len = op(connection_fd_); len = op(connection_fd_.get());
if (len == -1 && errno == EPIPE) { if (len == -1 && errno == EPIPE) {
Terminate(); Terminate();
SAPI_RAW_LOG(ERROR, "sendmsg(SCM_RIGHTS): Peer disconnected"); SAPI_RAW_LOG(ERROR, "sendmsg(SCM_RIGHTS): Peer disconnected");
@ -458,10 +499,10 @@ bool Comms::RecvProtoBuf(google::protobuf::MessageLite* message) {
std::vector<uint8_t> bytes; std::vector<uint8_t> bytes;
if (!RecvTLV(&tag, &bytes)) { if (!RecvTLV(&tag, &bytes)) {
if (IsConnected()) { if (IsConnected()) {
SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", socket_name_); SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", name_);
} else { } else {
Terminate(); Terminate();
SAPI_RAW_VLOG(2, "Connection terminated (%s)", socket_name_.c_str()); SAPI_RAW_VLOG(2, "Connection terminated (%s)", name_.c_str());
} }
return false; return false;
} }
@ -488,32 +529,6 @@ bool Comms::SendProtoBuf(const google::protobuf::MessageLite& message) {
// All methods below are private, for internal use only. // All methods below are private, for internal use only.
// ***************************************************************************** // *****************************************************************************
socklen_t Comms::CreateSockaddrUn(sockaddr_un* sun) {
sun->sun_family = AF_UNIX;
bzero(sun->sun_path, sizeof(sun->sun_path));
socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name_.c_str());
if (abstract_uds_) {
// Create an 'abstract socket address' by specifying a leading null byte.
// The remainder of the path is used as a unique name, but no file is
// created on the filesystem. No need to NUL-terminate the string. See `man
// 7 unix` for further explanation.
strncpy(&sun->sun_path[1], socket_name_.c_str(), sizeof(sun->sun_path) - 1);
// Len is complicated - it's essentially size of the path, plus initial
// NUL-byte, minus size of the sun.sun_family.
slen++;
} else {
// Create the socket address as it was passed from the constructor.
strncpy(&sun->sun_path[0], socket_name_.c_str(), sizeof(sun->sun_path));
}
// This takes care of the socket address overflow.
if (slen > sizeof(sockaddr_un)) {
SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated");
slen = sizeof(sockaddr_un);
}
return slen;
}
bool Comms::Send(const void* data, size_t len) { bool Comms::Send(const void* data, size_t len) {
size_t total_sent = 0; size_t total_sent = 0;
const char* bytes = reinterpret_cast<const char*>(data); const char* bytes = reinterpret_cast<const char*>(data);
@ -523,7 +538,7 @@ bool Comms::Send(const void* data, size_t len) {
}; };
while (total_sent < len) { while (total_sent < len) {
ssize_t s; ssize_t s;
s = op(connection_fd_); s = op(connection_fd_.get());
if (s == -1 && errno == EPIPE) { if (s == -1 && errno == EPIPE) {
Terminate(); Terminate();
// We do not expect the other end to disappear. // We do not expect the other end to disappear.
@ -557,7 +572,7 @@ bool Comms::Recv(void* data, size_t len) {
}; };
while (total_recv < len) { while (total_recv < len) {
ssize_t s; ssize_t s;
s = op(connection_fd_); s = op(connection_fd_.get());
if (s == -1) { if (s == -1) {
SAPI_RAW_PLOG(ERROR, "read"); SAPI_RAW_PLOG(ERROR, "read");
if (IsFatalError(errno)) { if (IsFatalError(errno)) {
@ -675,12 +690,11 @@ bool Comms::SendStatus(const absl::Status& status) {
} }
void Comms::MoveToAnotherFd() { void Comms::MoveToAnotherFd() {
SAPI_RAW_CHECK(connection_fd_ != -1, SAPI_RAW_CHECK(connection_fd_.get() != -1,
"Cannot move comms fd as it's not connected"); "Cannot move comms fd as it's not connected");
int new_fd = dup(connection_fd_); FDCloser new_fd(dup(connection_fd_.get()));
SAPI_RAW_CHECK(new_fd != -1, "Failed to move comms to another fd"); SAPI_RAW_CHECK(new_fd.get() != -1, "Failed to move comms to another fd");
close(connection_fd_); connection_fd_.Swap(new_fd);
connection_fd_ = new_fd;
} }
} // namespace sandbox2 } // namespace sandbox2

View File

@ -41,6 +41,7 @@
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "google/protobuf/message_lite.h" #include "google/protobuf/message_lite.h"
#include "sandboxed_api/util/fileops.h"
namespace proto2 { namespace proto2 {
class Message; class Message;
@ -49,6 +50,7 @@ class Message;
namespace sandbox2 { namespace sandbox2 {
class Client; class Client;
class ListeningComms;
class Comms { class Comms {
public: public:
@ -93,8 +95,13 @@ class Comms {
static constexpr const char* kSandbox2CommsFDEnvVar = "SANDBOX2_COMMS_FD"; static constexpr const char* kSandbox2CommsFDEnvVar = "SANDBOX2_COMMS_FD";
static absl::StatusOr<Comms> Connect(const std::string& socket_name,
bool abstract_uds = true);
// This object will have to be connected later on. // This object will have to be connected later on.
// When not specified the constructor uses abstract unix domain sockets. // When not specified the constructor uses abstract unix domain sockets.
ABSL_DEPRECATED(
"Use ListeningComms or absl::StatusOr<Comms> Connect() instead")
explicit Comms(const std::string& socket_name, bool abstract_uds = true); explicit Comms(const std::string& socket_name, bool abstract_uds = true);
Comms(Comms&& other) { *this = std::move(other); } Comms(Comms&& other) { *this = std::move(other); }
@ -112,7 +119,7 @@ class Comms {
// Instantiates a pre-connected object. // Instantiates a pre-connected object.
// Takes ownership over fd, which will be closed on object's destruction. // Takes ownership over fd, which will be closed on object's destruction.
explicit Comms(int fd); explicit Comms(int fd, absl::string_view name = "");
// Instantiates a pre-connected object using the default connection params. // Instantiates a pre-connected object using the default connection params.
explicit Comms(DefaultConnectionTag); explicit Comms(DefaultConnectionTag);
@ -120,12 +127,17 @@ class Comms {
~Comms(); ~Comms();
// Binds to an address and make it listen to connections. // Binds to an address and make it listen to connections.
ABSL_DEPRECATED("Use ListeningComms::Create() instead")
bool Listen(); bool Listen();
// Accepts the connection. // Accepts the connection.
ABSL_DEPRECATED("Use ListeningComms::Accept() instead")
bool Accept(); bool Accept();
// Connects to a remote socket. // Connects to a remote socket.
ABSL_DEPRECATED(
"Use absl::StatusOr<Comms> Comms::Connect(std::string& socket_name, bool "
"abstract_uds) instead")
bool Connect(); bool Connect();
// Terminates all underlying file descriptors, and sets the status of the // Terminates all underlying file descriptors, and sets the status of the
@ -201,11 +213,11 @@ class Comms {
return; return;
} }
using std::swap; using std::swap;
swap(socket_name_, other.socket_name_); swap(name_, other.name_);
swap(abstract_uds_, other.abstract_uds_); swap(abstract_uds_, other.abstract_uds_);
swap(connection_fd_, other.connection_fd_); swap(connection_fd_, other.connection_fd_);
swap(bind_fd_, other.bind_fd_);
swap(state_, other.state_); swap(state_, other.state_);
swap(listening_comms_, other.listening_comms_);
} }
friend void swap(Comms& x, Comms& y) { return x.Swap(y); } friend void swap(Comms& x, Comms& y) { return x.Swap(y); }
@ -221,10 +233,11 @@ class Comms {
}; };
// Connection parameters. // Connection parameters.
std::string socket_name_; std::string name_;
bool abstract_uds_ = true; bool abstract_uds_ = true;
int connection_fd_ = -1; sapi::file_util::fileops::FDCloser connection_fd_;
int bind_fd_ = -1;
std::unique_ptr<ListeningComms> listening_comms_;
// State of the channel (enum), socket will have to be connected later on. // State of the channel (enum), socket will have to be connected later on.
State state_ = State::kUnconnected; State state_ = State::kUnconnected;
@ -239,9 +252,6 @@ class Comms {
size_t len; size_t len;
}; };
// Fills sockaddr_un struct with proper values.
socklen_t CreateSockaddrUn(sockaddr_un* sun);
// Moves the comms fd to an other free file descriptor. // Moves the comms fd to an other free file descriptor.
void MoveToAnotherFd(); void MoveToAnotherFd();
@ -270,6 +280,26 @@ class Comms {
} }
}; };
class ListeningComms {
public:
static absl::StatusOr<ListeningComms> Create(absl::string_view socket_name,
bool abstract_uds = true);
ListeningComms(ListeningComms&& other) = default;
ListeningComms& operator=(ListeningComms&& other) = default;
~ListeningComms() = default;
absl::StatusOr<Comms> Accept();
private:
ListeningComms(absl::string_view socket_name, bool abstract_uds)
: socket_name_(socket_name), abstract_uds_(abstract_uds), bind_fd_(-1) {}
absl::Status Listen();
std::string socket_name_;
bool abstract_uds_;
sapi::file_util::fileops::FDCloser bind_fd_;
};
} // namespace sandbox2 } // namespace sandbox2
#endif // SANDBOXED_API_SANDBOX2_COMMS_H_ #endif // SANDBOXED_API_SANDBOX2_COMMS_H_