Refactor Comms to split out listening/connecting part

Deprecated APIs slated for removal after migration of internal
clients.

PiperOrigin-RevId: 566598245
Change-Id: I5d7b920f3a788d4eccc6e78f239b660ba903adcc
This commit is contained in:
Wiktor Garbacz 2023-09-19 05:13:40 -07:00 committed by Copybara-Service
parent d26262d82e
commit 1cf45be7df
4 changed files with 186 additions and 140 deletions

View File

@ -835,6 +835,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":util",
"//sandboxed_api/util:fileops",
"//sandboxed_api/util:raw_logging",
"//sandboxed_api/util:status",
"//sandboxed_api/util:status_cc_proto",

View File

@ -763,6 +763,7 @@ target_link_libraries(sandbox2_comms
PUBLIC absl::core_headers
absl::status
protobuf::libprotobuf
sapi::fileops
sapi::status
)

View File

@ -43,11 +43,14 @@
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/message_lite.h"
#include "sandboxed_api/sandbox2/util.h"
#include "sandboxed_api/util/fileops.h"
#include "sandboxed_api/util/raw_logging.h"
#include "sandboxed_api/util/status.h"
#include "sandboxed_api/util/status.pb.h"
#include "sandboxed_api/util/status_macros.h"
namespace sandbox2 {
@ -59,6 +62,9 @@ class PotentiallyBlockingRegion {
};
namespace {
using sapi::file_util::fileops::FDCloser;
bool IsFatalError(int saved_errno) {
return saved_errno != EAGAIN && saved_errno != EWOULDBLOCK &&
saved_errno != EFAULT && saved_errno != EINTR &&
@ -74,16 +80,47 @@ int GetDefaultCommsFd() {
}
return Comms::kSandbox2ClientCommsFD;
}
socklen_t CreateSockaddrUn(const std::string& socket_name, bool abstract_uds,
sockaddr_un* sun) {
sun->sun_family = AF_UNIX;
bzero(sun->sun_path, sizeof(sun->sun_path));
socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name.c_str());
if (abstract_uds) {
// Create an 'abstract socket address' by specifying a leading null byte.
// The remainder of the path is used as a unique name, but no file is
// created on the filesystem. No need to NUL-terminate the string. See `man
// 7 unix` for further explanation.
strncpy(&sun->sun_path[1], socket_name.c_str(), sizeof(sun->sun_path) - 1);
// Len is complicated - it's essentially size of the path, plus initial
// NUL-byte, minus size of the sun.sun_family.
slen++;
} else {
// Create the socket address as it was passed from the constructor.
strncpy(&sun->sun_path[0], socket_name.c_str(), sizeof(sun->sun_path));
}
// This takes care of the socket address overflow.
if (slen > sizeof(sockaddr_un)) {
SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated");
slen = sizeof(sockaddr_un);
}
return slen;
}
} // namespace
Comms::Comms(const std::string& socket_name, bool abstract_uds)
: socket_name_(socket_name), abstract_uds_(abstract_uds) {}
: name_(socket_name), abstract_uds_(abstract_uds) {}
Comms::Comms(int fd) : connection_fd_(fd) {
Comms::Comms(int fd, absl::string_view name) : connection_fd_(fd) {
// Generate a unique and meaningful socket name for this FD.
// Note: getpid()/gettid() are non-blocking syscalls.
socket_name_ = absl::StrFormat("sandbox2::Comms:FD=%d/PID=%d/TID=%ld", fd,
getpid(), syscall(__NR_gettid));
if (name.empty()) {
name_ = absl::StrFormat("sandbox2::Comms:FD=%d/PID=%d/TID=%ld", fd,
getpid(), syscall(__NR_gettid));
} else {
name_ = std::string(name);
}
// File descriptor is already connected.
state_ = State::kConnected;
@ -94,7 +131,36 @@ Comms::Comms(Comms::DefaultConnectionTag) : Comms(GetDefaultCommsFd()) {}
Comms::~Comms() { Terminate(); }
int Comms::GetConnectionFD() const {
return connection_fd_;
return connection_fd_.get();
}
absl::StatusOr<ListeningComms> ListeningComms::Create(
absl::string_view socket_name, bool abstract_uds) {
ListeningComms comms(std::string(socket_name), abstract_uds);
SAPI_RETURN_IF_ERROR(comms.Listen());
return comms;
}
absl::Status ListeningComms::Listen() {
bind_fd_ = FDCloser(socket(AF_UNIX, SOCK_STREAM, 0)); // Non-blocking
if (bind_fd_.get() == -1) {
return absl::ErrnoToStatus(errno, "socket(AF_UNIX) failed");
}
sockaddr_un sus;
socklen_t slen = CreateSockaddrUn(socket_name_, abstract_uds_, &sus);
// bind() is non-blocking.
if (bind(bind_fd_.get(), reinterpret_cast<sockaddr*>(&sus), slen) == -1) {
return absl::ErrnoToStatus(errno, "bind failed");
}
// listen() non-blocking.
if (listen(bind_fd_.get(), 0) == -1) {
return absl::ErrnoToStatus(errno, "listen failed");
}
SAPI_RAW_VLOG(1, "Listening at: %s", socket_name_.c_str());
return absl::OkStatus();
}
bool Comms::Listen() {
@ -102,123 +168,98 @@ bool Comms::Listen() {
return true;
}
bind_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // Non-blocking
if (bind_fd_ == -1) {
SAPI_RAW_PLOG(ERROR, "socket(AF_UNIX)");
absl::StatusOr<ListeningComms> listening_comms =
ListeningComms::Create(name_, abstract_uds_);
if (!listening_comms.ok()) {
SAPI_RAW_LOG(ERROR, "Listening failed: %s",
std::string(listening_comms.status().message()).c_str());
return false;
}
sockaddr_un sus;
socklen_t slen = CreateSockaddrUn(&sus);
// bind() is non-blocking.
if (bind(bind_fd_, reinterpret_cast<sockaddr*>(&sus), slen) == -1) {
SAPI_RAW_PLOG(ERROR, "bind(bind_fd)");
// Note: checking for EINTR on close() syscall is useless and possibly
// harmful, see https://lwn.net/Articles/576478/.
{
PotentiallyBlockingRegion region;
close(bind_fd_);
}
bind_fd_ = -1;
return false;
}
// listen() non-blocking.
if (listen(bind_fd_, 0) == -1) {
SAPI_RAW_PLOG(ERROR, "listen(bind_fd)");
{
PotentiallyBlockingRegion region;
close(bind_fd_);
}
bind_fd_ = -1;
return false;
}
SAPI_RAW_VLOG(1, "Listening at: %s", socket_name_.c_str());
listening_comms_ =
std::make_unique<ListeningComms>(*std::move(listening_comms));
return true;
}
absl::StatusOr<Comms> ListeningComms::Accept() {
sockaddr_un suc;
socklen_t len = sizeof(suc);
int connection_fd;
{
PotentiallyBlockingRegion region;
connection_fd = TEMP_FAILURE_RETRY(
accept(bind_fd_.get(), reinterpret_cast<sockaddr*>(&suc), &len));
}
if (connection_fd == -1) {
return absl::ErrnoToStatus(errno, "accept failed");
}
SAPI_RAW_VLOG(1, "Accepted connection at: %s, fd: %d", socket_name_.c_str(),
connection_fd);
return Comms(connection_fd, socket_name_);
}
bool Comms::Accept() {
if (IsConnected()) {
return true;
}
sockaddr_un suc;
socklen_t len = sizeof(suc);
{
PotentiallyBlockingRegion region;
connection_fd_ = TEMP_FAILURE_RETRY(
accept(bind_fd_, reinterpret_cast<sockaddr*>(&suc), &len));
}
if (connection_fd_ == -1) {
SAPI_RAW_PLOG(ERROR, "accept(bind_fd)");
{
PotentiallyBlockingRegion region;
close(bind_fd_);
}
bind_fd_ = -1;
if (listening_comms_ == nullptr) {
SAPI_RAW_LOG(ERROR, "Comms::Listen needs to be called first");
return false;
}
state_ = State::kConnected;
SAPI_RAW_VLOG(1, "Accepted connection at: %s, fd: %d", socket_name_.c_str(),
connection_fd_);
absl::StatusOr<Comms> comms = listening_comms_->Accept();
if (!comms.ok()) {
SAPI_RAW_LOG(ERROR, "%s", std::string(comms.status().message()).c_str());
return false;
}
*this = *std::move(comms);
return true;
}
absl::StatusOr<Comms> Comms::Connect(const std::string& socket_name,
bool abstract_uds) {
FDCloser connection_fd(socket(AF_UNIX, SOCK_STREAM, 0)); // Non-blocking
if (connection_fd.get() == -1) {
return absl::ErrnoToStatus(errno, "socket(AF_UNIX)");
}
sockaddr_un suc;
socklen_t slen = CreateSockaddrUn(socket_name, abstract_uds, &suc);
int ret;
{
PotentiallyBlockingRegion region;
ret = TEMP_FAILURE_RETRY(
connect(connection_fd.get(), reinterpret_cast<sockaddr*>(&suc), slen));
}
if (ret == -1) {
return absl::ErrnoToStatus(errno, "connect(connection_fd)");
}
SAPI_RAW_VLOG(1, "Connected to: %s, fd: %d", socket_name.c_str(),
connection_fd.get());
return Comms(connection_fd.Release(), socket_name);
}
bool Comms::Connect() {
if (IsConnected()) {
return true;
}
connection_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // Non-blocking
if (connection_fd_ == -1) {
SAPI_RAW_PLOG(ERROR, "socket(AF_UNIX)");
absl::StatusOr<Comms> connected = Connect(name_, abstract_uds_);
if (!connected.ok()) {
SAPI_RAW_LOG(ERROR, "%s",
std::string(connected.status().message()).c_str());
return false;
}
sockaddr_un suc;
socklen_t slen = CreateSockaddrUn(&suc);
int ret;
{
PotentiallyBlockingRegion region;
ret = TEMP_FAILURE_RETRY(
connect(connection_fd_, reinterpret_cast<sockaddr*>(&suc), slen));
}
if (ret == -1) {
SAPI_RAW_PLOG(ERROR, "connect(connection_fd)");
{
PotentiallyBlockingRegion region;
close(connection_fd_);
}
connection_fd_ = -1;
return false;
}
state_ = State::kConnected;
SAPI_RAW_VLOG(1, "Connected to: %s, fd: %d", socket_name_.c_str(),
connection_fd_);
*this = *std::move(connected);
return true;
}
void Comms::Terminate() {
{
PotentiallyBlockingRegion region;
state_ = State::kTerminated;
state_ = State::kTerminated;
if (bind_fd_ != -1) {
close(bind_fd_);
bind_fd_ = -1;
}
if (connection_fd_ != -1) {
close(connection_fd_);
connection_fd_ = -1;
}
}
connection_fd_.Close();
listening_comms_.reset();
}
bool Comms::SendTLV(uint32_t tag, size_t length, const void* value) {
@ -349,7 +390,7 @@ bool Comms::RecvFD(int* fd) {
util::Syscall(__NR_recvmsg, fd, reinterpret_cast<uintptr_t>(&msg), 0));
};
ssize_t len;
len = op(connection_fd_);
len = op(connection_fd_.get());
if (len < 0) {
if (IsFatalError(errno)) {
Terminate();
@ -432,7 +473,7 @@ bool Comms::SendFD(int fd) {
util::Syscall(__NR_sendmsg, fd, reinterpret_cast<uintptr_t>(&msg), 0));
};
ssize_t len;
len = op(connection_fd_);
len = op(connection_fd_.get());
if (len == -1 && errno == EPIPE) {
Terminate();
SAPI_RAW_LOG(ERROR, "sendmsg(SCM_RIGHTS): Peer disconnected");
@ -458,10 +499,10 @@ bool Comms::RecvProtoBuf(google::protobuf::MessageLite* message) {
std::vector<uint8_t> bytes;
if (!RecvTLV(&tag, &bytes)) {
if (IsConnected()) {
SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", socket_name_);
SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", name_);
} else {
Terminate();
SAPI_RAW_VLOG(2, "Connection terminated (%s)", socket_name_.c_str());
SAPI_RAW_VLOG(2, "Connection terminated (%s)", name_.c_str());
}
return false;
}
@ -488,32 +529,6 @@ bool Comms::SendProtoBuf(const google::protobuf::MessageLite& message) {
// All methods below are private, for internal use only.
// *****************************************************************************
socklen_t Comms::CreateSockaddrUn(sockaddr_un* sun) {
sun->sun_family = AF_UNIX;
bzero(sun->sun_path, sizeof(sun->sun_path));
socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name_.c_str());
if (abstract_uds_) {
// Create an 'abstract socket address' by specifying a leading null byte.
// The remainder of the path is used as a unique name, but no file is
// created on the filesystem. No need to NUL-terminate the string. See `man
// 7 unix` for further explanation.
strncpy(&sun->sun_path[1], socket_name_.c_str(), sizeof(sun->sun_path) - 1);
// Len is complicated - it's essentially size of the path, plus initial
// NUL-byte, minus size of the sun.sun_family.
slen++;
} else {
// Create the socket address as it was passed from the constructor.
strncpy(&sun->sun_path[0], socket_name_.c_str(), sizeof(sun->sun_path));
}
// This takes care of the socket address overflow.
if (slen > sizeof(sockaddr_un)) {
SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated");
slen = sizeof(sockaddr_un);
}
return slen;
}
bool Comms::Send(const void* data, size_t len) {
size_t total_sent = 0;
const char* bytes = reinterpret_cast<const char*>(data);
@ -523,7 +538,7 @@ bool Comms::Send(const void* data, size_t len) {
};
while (total_sent < len) {
ssize_t s;
s = op(connection_fd_);
s = op(connection_fd_.get());
if (s == -1 && errno == EPIPE) {
Terminate();
// We do not expect the other end to disappear.
@ -557,7 +572,7 @@ bool Comms::Recv(void* data, size_t len) {
};
while (total_recv < len) {
ssize_t s;
s = op(connection_fd_);
s = op(connection_fd_.get());
if (s == -1) {
SAPI_RAW_PLOG(ERROR, "read");
if (IsFatalError(errno)) {
@ -675,12 +690,11 @@ bool Comms::SendStatus(const absl::Status& status) {
}
void Comms::MoveToAnotherFd() {
SAPI_RAW_CHECK(connection_fd_ != -1,
SAPI_RAW_CHECK(connection_fd_.get() != -1,
"Cannot move comms fd as it's not connected");
int new_fd = dup(connection_fd_);
SAPI_RAW_CHECK(new_fd != -1, "Failed to move comms to another fd");
close(connection_fd_);
connection_fd_ = new_fd;
FDCloser new_fd(dup(connection_fd_.get()));
SAPI_RAW_CHECK(new_fd.get() != -1, "Failed to move comms to another fd");
connection_fd_.Swap(new_fd);
}
} // namespace sandbox2

View File

@ -41,6 +41,7 @@
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/message_lite.h"
#include "sandboxed_api/util/fileops.h"
namespace proto2 {
class Message;
@ -49,6 +50,7 @@ class Message;
namespace sandbox2 {
class Client;
class ListeningComms;
class Comms {
public:
@ -93,8 +95,13 @@ class Comms {
static constexpr const char* kSandbox2CommsFDEnvVar = "SANDBOX2_COMMS_FD";
static absl::StatusOr<Comms> Connect(const std::string& socket_name,
bool abstract_uds = true);
// This object will have to be connected later on.
// When not specified the constructor uses abstract unix domain sockets.
ABSL_DEPRECATED(
"Use ListeningComms or absl::StatusOr<Comms> Connect() instead")
explicit Comms(const std::string& socket_name, bool abstract_uds = true);
Comms(Comms&& other) { *this = std::move(other); }
@ -112,7 +119,7 @@ class Comms {
// Instantiates a pre-connected object.
// Takes ownership over fd, which will be closed on object's destruction.
explicit Comms(int fd);
explicit Comms(int fd, absl::string_view name = "");
// Instantiates a pre-connected object using the default connection params.
explicit Comms(DefaultConnectionTag);
@ -120,12 +127,17 @@ class Comms {
~Comms();
// Binds to an address and make it listen to connections.
ABSL_DEPRECATED("Use ListeningComms::Create() instead")
bool Listen();
// Accepts the connection.
ABSL_DEPRECATED("Use ListeningComms::Accept() instead")
bool Accept();
// Connects to a remote socket.
ABSL_DEPRECATED(
"Use absl::StatusOr<Comms> Comms::Connect(std::string& socket_name, bool "
"abstract_uds) instead")
bool Connect();
// Terminates all underlying file descriptors, and sets the status of the
@ -201,11 +213,11 @@ class Comms {
return;
}
using std::swap;
swap(socket_name_, other.socket_name_);
swap(name_, other.name_);
swap(abstract_uds_, other.abstract_uds_);
swap(connection_fd_, other.connection_fd_);
swap(bind_fd_, other.bind_fd_);
swap(state_, other.state_);
swap(listening_comms_, other.listening_comms_);
}
friend void swap(Comms& x, Comms& y) { return x.Swap(y); }
@ -221,10 +233,11 @@ class Comms {
};
// Connection parameters.
std::string socket_name_;
std::string name_;
bool abstract_uds_ = true;
int connection_fd_ = -1;
int bind_fd_ = -1;
sapi::file_util::fileops::FDCloser connection_fd_;
std::unique_ptr<ListeningComms> listening_comms_;
// State of the channel (enum), socket will have to be connected later on.
State state_ = State::kUnconnected;
@ -239,9 +252,6 @@ class Comms {
size_t len;
};
// Fills sockaddr_un struct with proper values.
socklen_t CreateSockaddrUn(sockaddr_un* sun);
// Moves the comms fd to an other free file descriptor.
void MoveToAnotherFd();
@ -270,6 +280,26 @@ class Comms {
}
};
class ListeningComms {
public:
static absl::StatusOr<ListeningComms> Create(absl::string_view socket_name,
bool abstract_uds = true);
ListeningComms(ListeningComms&& other) = default;
ListeningComms& operator=(ListeningComms&& other) = default;
~ListeningComms() = default;
absl::StatusOr<Comms> Accept();
private:
ListeningComms(absl::string_view socket_name, bool abstract_uds)
: socket_name_(socket_name), abstract_uds_(abstract_uds), bind_fd_(-1) {}
absl::Status Listen();
std::string socket_name_;
bool abstract_uds_;
sapi::file_util::fileops::FDCloser bind_fd_;
};
} // namespace sandbox2
#endif // SANDBOXED_API_SANDBOX2_COMMS_H_