From 6cbaaead8b53421454b533dc9b13098a3c19fc52 Mon Sep 17 00:00:00 2001 From: Wiktor Garbacz Date: Tue, 23 Apr 2019 10:29:37 -0700 Subject: [PATCH] Make StatusMatcher more flexible PiperOrigin-RevId: 244879203 Change-Id: I5f7994130a898e84f041b18c0b5313d7e8b32780 --- sandboxed_api/rpcchannel.cc | 86 +++++++++---------- sandboxed_api/rpcchannel.h | 24 +++--- sandboxed_api/sandbox2/BUILD.bazel | 1 + sandboxed_api/sandbox2/comms.cc | 2 +- sandboxed_api/sandbox2/comms.h | 1 + sandboxed_api/sandbox2/comms_test.cc | 2 +- sandboxed_api/sandbox2/util/temp_file.cc | 10 +-- sandboxed_api/sandbox2/util/temp_file.h | 6 +- sandboxed_api/sandbox2/util/temp_file_test.cc | 1 + sandboxed_api/util/status_matchers.h | 27 ++++-- 10 files changed, 87 insertions(+), 73 deletions(-) diff --git a/sandboxed_api/rpcchannel.cc b/sandboxed_api/rpcchannel.cc index fb3fe92..6081f10 100644 --- a/sandboxed_api/rpcchannel.cc +++ b/sandboxed_api/rpcchannel.cc @@ -23,62 +23,62 @@ namespace sapi { -sapi::Status RPCChannel::Call(const FuncCall& call, uint32_t tag, FuncRet* ret, - v::Type exp_type) { +::sapi::Status RPCChannel::Call(const FuncCall& call, uint32_t tag, FuncRet* ret, + v::Type exp_type) { absl::MutexLock lock(&mutex_); if (!comms_->SendTLV(tag, sizeof(call), reinterpret_cast(&call))) { - return sapi::UnavailableError("Sending TLV value failed"); + return ::sapi::UnavailableError("Sending TLV value failed"); } SAPI_ASSIGN_OR_RETURN(auto fret, Return(exp_type)); *ret = fret; - return sapi::OkStatus(); + return ::sapi::OkStatus(); } -sapi::StatusOr RPCChannel::Return(v::Type exp_type) { +::sapi::StatusOr RPCChannel::Return(v::Type exp_type) { uint32_t tag; uint64_t len; FuncRet ret; if (!comms_->RecvTLV(&tag, &len, &ret, sizeof(ret))) { - return sapi::UnavailableError("Receiving TLV value failed"); + return ::sapi::UnavailableError("Receiving TLV value failed"); } if (tag != comms::kMsgReturn) { LOG(ERROR) << "tag != comms::kMsgReturn (" << absl::StrCat(absl::Hex(tag)) << " != " << absl::StrCat(absl::Hex(comms::kMsgReturn)) << ")"; - return sapi::UnavailableError("Received TLV has incorrect tag"); + return ::sapi::UnavailableError("Received TLV has incorrect tag"); } if (len != sizeof(FuncRet)) { LOG(ERROR) << "len != sizeof(FuncReturn) (" << len << " != " << sizeof(FuncRet) << ")"; - return sapi::UnavailableError("Received TLV has incorrect length"); + return ::sapi::UnavailableError("Received TLV has incorrect length"); } if (ret.ret_type != exp_type) { LOG(ERROR) << "FuncRet->type != exp_type (" << ret.ret_type << " != " << exp_type << ")"; - return sapi::UnavailableError("Received TLV has incorrect return type"); + return ::sapi::UnavailableError("Received TLV has incorrect return type"); } if (!ret.success) { LOG(ERROR) << "FuncRet->success == false"; - return sapi::UnavailableError("Function call failed"); + return ::sapi::UnavailableError("Function call failed"); } return ret; } -sapi::Status RPCChannel::Allocate(size_t size, void** addr) { +::sapi::Status RPCChannel::Allocate(size_t size, void** addr) { absl::MutexLock lock(&mutex_); uint64_t sz = size; if (!comms_->SendTLV(comms::kMsgAllocate, sizeof(sz), reinterpret_cast(&sz))) { - return sapi::UnavailableError("Sending TLV value failed"); + return ::sapi::UnavailableError("Sending TLV value failed"); } SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kPointer)); *addr = reinterpret_cast(fret.int_val); - return sapi::OkStatus(); + return ::sapi::OkStatus(); } -sapi::Status RPCChannel::Reallocate(void* old_addr, size_t size, - void** new_addr) { +::sapi::Status RPCChannel::Reallocate(void* old_addr, size_t size, + void** new_addr) { absl::MutexLock lock(&mutex_); comms::ReallocRequest req; req.old_addr = reinterpret_cast(old_addr); @@ -86,54 +86,54 @@ sapi::Status RPCChannel::Reallocate(void* old_addr, size_t size, if (!comms_->SendTLV(comms::kMsgReallocate, sizeof(comms::ReallocRequest), reinterpret_cast(&req))) { - return sapi::UnavailableError("Sending TLV value failed"); + return ::sapi::UnavailableError("Sending TLV value failed"); } auto fret_or = Return(v::Type::kPointer); if (!fret_or.ok()) { *new_addr = nullptr; - return sapi::UnavailableError( + return ::sapi::UnavailableError( absl::StrCat("Reallocate() failed on the remote side: ", fret_or.status().message())); } auto fret = std::move(fret_or).ValueOrDie(); *new_addr = reinterpret_cast(fret.int_val); - return sapi::OkStatus(); + return ::sapi::OkStatus(); } -sapi::Status RPCChannel::Free(void* addr) { +::sapi::Status RPCChannel::Free(void* addr) { absl::MutexLock lock(&mutex_); uint64_t remote = reinterpret_cast(addr); if (!comms_->SendTLV(comms::kMsgFree, sizeof(remote), reinterpret_cast(&remote))) { - return sapi::UnavailableError("Sending TLV value failed"); + return ::sapi::UnavailableError("Sending TLV value failed"); } SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kVoid)); if (!fret.success) { - return sapi::UnavailableError("Free() failed on the remote side"); + return ::sapi::UnavailableError("Free() failed on the remote side"); } - return sapi::OkStatus(); + return ::sapi::OkStatus(); } -sapi::Status RPCChannel::Symbol(const char* symname, void** addr) { +::sapi::Status RPCChannel::Symbol(const char* symname, void** addr) { absl::MutexLock lock(&mutex_); if (!comms_->SendTLV(comms::kMsgSymbol, strlen(symname) + 1, reinterpret_cast(symname))) { - return sapi::UnavailableError("Sending TLV value failed"); + return ::sapi::UnavailableError("Sending TLV value failed"); } SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kPointer)); *addr = reinterpret_cast(fret.int_val); - return sapi::OkStatus(); + return ::sapi::OkStatus(); } -sapi::Status RPCChannel::Exit() { +::sapi::Status RPCChannel::Exit() { absl::MutexLock lock(&mutex_); if (comms_->IsTerminated()) { VLOG(2) << "Comms channel already terminated"; - return sapi::OkStatus(); + return ::sapi::OkStatus(); } // Try the RPC exit sequence. But, the only thing that matters as a success @@ -146,62 +146,62 @@ sapi::Status RPCChannel::Exit() { if (!comms_->IsTerminated()) { LOG(ERROR) << "Comms channel not terminated in Exit()"; // TODO(hamacher): Better error code - return sapi::FailedPreconditionError( + return ::sapi::FailedPreconditionError( "Comms channel not terminated in Exit()"); } - return sapi::OkStatus(); + return ::sapi::OkStatus(); } -sapi::Status RPCChannel::SendFD(int local_fd, int* remote_fd) { +::sapi::Status RPCChannel::SendFD(int local_fd, int* remote_fd) { absl::MutexLock lock(&mutex_); bool unused = true; if (!comms_->SendTLV(comms::kMsgSendFd, sizeof(unused), reinterpret_cast(&unused))) { - return sapi::UnavailableError("Sending TLV value failed"); + return ::sapi::UnavailableError("Sending TLV value failed"); } if (!comms_->SendFD(local_fd)) { - return sapi::UnavailableError("Sending FD failed"); + return ::sapi::UnavailableError("Sending FD failed"); } SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kInt)); if (!fret.success) { - return sapi::UnavailableError("SendFD failed on the remote side"); + return ::sapi::UnavailableError("SendFD failed on the remote side"); } *remote_fd = fret.int_val; - return sapi::OkStatus(); + return ::sapi::OkStatus(); } -sapi::Status RPCChannel::RecvFD(int remote_fd, int* local_fd) { +::sapi::Status RPCChannel::RecvFD(int remote_fd, int* local_fd) { absl::MutexLock lock(&mutex_); if (!comms_->SendTLV(comms::kMsgRecvFd, sizeof(remote_fd), reinterpret_cast(&remote_fd))) { - return sapi::UnavailableError("Sending TLV value failed"); + return ::sapi::UnavailableError("Sending TLV value failed"); } if (!comms_->RecvFD(local_fd)) { - return sapi::UnavailableError("Receving FD failed"); + return ::sapi::UnavailableError("Receving FD failed"); } SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kVoid)); if (!fret.success) { - return sapi::UnavailableError("RecvFD failed on the remote side"); + return ::sapi::UnavailableError("RecvFD failed on the remote side"); } - return sapi::OkStatus(); + return ::sapi::OkStatus(); } -sapi::Status RPCChannel::Close(int remote_fd) { +::sapi::Status RPCChannel::Close(int remote_fd) { absl::MutexLock lock(&mutex_); if (!comms_->SendTLV(comms::kMsgClose, sizeof(remote_fd), reinterpret_cast(&remote_fd))) { - return sapi::UnavailableError("Sending TLV value failed"); + return ::sapi::UnavailableError("Sending TLV value failed"); } SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kVoid)); if (!fret.success) { - return sapi::UnavailableError("Close() failed on the remote side"); + return ::sapi::UnavailableError("Close() failed on the remote side"); } - return sapi::OkStatus(); + return ::sapi::OkStatus(); } } // namespace sapi diff --git a/sandboxed_api/rpcchannel.h b/sandboxed_api/rpcchannel.h index 4465d4b..a49b329 100644 --- a/sandboxed_api/rpcchannel.h +++ b/sandboxed_api/rpcchannel.h @@ -20,9 +20,9 @@ #include "absl/synchronization/mutex.h" #include "sandboxed_api/call.h" #include "sandboxed_api/sandbox2/comms.h" +#include "sandboxed_api/var_type.h" #include "sandboxed_api/util/status.h" #include "sandboxed_api/util/statusor.h" -#include "sandboxed_api/var_type.h" namespace sapi { @@ -33,38 +33,38 @@ class RPCChannel { explicit RPCChannel(sandbox2::Comms* comms) : comms_(comms) {} // Calls a function. - sapi::Status Call(const FuncCall& call, uint32_t tag, FuncRet* ret, - v::Type exp_type); + ::sapi::Status Call(const FuncCall& call, uint32_t tag, FuncRet* ret, + v::Type exp_type); // Allocates memory. - sapi::Status Allocate(size_t size, void** addr); + ::sapi::Status Allocate(size_t size, void** addr); // Reallocates memory. - sapi::Status Reallocate(void* old_addr, size_t size, void** new_addr); + ::sapi::Status Reallocate(void* old_addr, size_t size, void** new_addr); // Frees memory. - sapi::Status Free(void* addr); + ::sapi::Status Free(void* addr); // Returns address of a symbol. - sapi::Status Symbol(const char* symname, void** addr); + ::sapi::Status Symbol(const char* symname, void** addr); // Makes the remote part exit. - sapi::Status Exit(); + ::sapi::Status Exit(); // Transfers fd to sandboxee. - sapi::Status SendFD(int local_fd, int* remote_fd); + ::sapi::Status SendFD(int local_fd, int* remote_fd); // Retrieves fd from sandboxee. - sapi::Status RecvFD(int remote_fd, int* local_fd); + ::sapi::Status RecvFD(int remote_fd, int* local_fd); // Closes fd in sandboxee. - sapi::Status Close(int remote_fd); + ::sapi::Status Close(int remote_fd); sandbox2::Comms* comms() const { return comms_; } private: // Receives the result after a call. - sapi::StatusOr Return(v::Type exp_type); + ::sapi::StatusOr Return(v::Type exp_type); sandbox2::Comms* comms_; // Owned by sandbox2; absl::Mutex mutex_; diff --git a/sandboxed_api/sandbox2/BUILD.bazel b/sandboxed_api/sandbox2/BUILD.bazel index 2271b46..6f1cc1a 100644 --- a/sandboxed_api/sandbox2/BUILD.bazel +++ b/sandboxed_api/sandbox2/BUILD.bazel @@ -526,6 +526,7 @@ cc_library( "//sandboxed_api/sandbox2/util:strerror", "//sandboxed_api/util:raw_logging", "//sandboxed_api/util:status", + "//sandboxed_api/util:status_proto", "//sandboxed_api/util:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", diff --git a/sandboxed_api/sandbox2/comms.cc b/sandboxed_api/sandbox2/comms.cc index d88412f..b2ff8b3 100644 --- a/sandboxed_api/sandbox2/comms.cc +++ b/sandboxed_api/sandbox2/comms.cc @@ -24,9 +24,9 @@ #include #include #include - #include #include + #include #include #include diff --git a/sandboxed_api/sandbox2/comms.h b/sandboxed_api/sandbox2/comms.h index 5cf9696..7238c30 100644 --- a/sandboxed_api/sandbox2/comms.h +++ b/sandboxed_api/sandbox2/comms.h @@ -34,6 +34,7 @@ #include "absl/base/attributes.h" #include "absl/synchronization/mutex.h" #include "sandboxed_api/util/status.h" +#include "sandboxed_api/util/status.pb.h" namespace proto2 { class Message; diff --git a/sandboxed_api/sandbox2/comms_test.cc b/sandboxed_api/sandbox2/comms_test.cc index 1ec1b46..a6058c1 100644 --- a/sandboxed_api/sandbox2/comms_test.cc +++ b/sandboxed_api/sandbox2/comms_test.cc @@ -307,7 +307,7 @@ TEST_F(CommsTest, TestSendRecvStatusOK) { }; auto b = [](Comms* comms) { // Send a good status. - ASSERT_THAT(comms->SendStatus(sapi::OkStatus()), IsTrue()); + ASSERT_THAT(comms->SendStatus(::sapi::OkStatus()), IsTrue()); }; HandleCommunication(sockname_, a, b); } diff --git a/sandboxed_api/sandbox2/util/temp_file.cc b/sandboxed_api/sandbox2/util/temp_file.cc index 0b226ed..063e1e3 100644 --- a/sandboxed_api/sandbox2/util/temp_file.cc +++ b/sandboxed_api/sandbox2/util/temp_file.cc @@ -33,17 +33,17 @@ namespace { constexpr absl::string_view kMktempSuffix = "XXXXXX"; } // namespace -sapi::StatusOr> CreateNamedTempFile( +::sapi::StatusOr> CreateNamedTempFile( absl::string_view prefix) { std::string name_template = absl::StrCat(prefix, kMktempSuffix); int fd = mkstemp(&name_template[0]); if (fd < 0) { - return sapi::UnknownError(absl::StrCat("mkstemp():", StrError(errno))); + return ::sapi::UnknownError(absl::StrCat("mkstemp():", StrError(errno))); } return std::pair{std::move(name_template), fd}; } -sapi::StatusOr CreateNamedTempFileAndClose( +::sapi::StatusOr CreateNamedTempFileAndClose( absl::string_view prefix) { auto result_or = CreateNamedTempFile(prefix); if (result_or.ok()) { @@ -56,10 +56,10 @@ sapi::StatusOr CreateNamedTempFileAndClose( return result_or.status(); } -sapi::StatusOr CreateTempDir(absl::string_view prefix) { +::sapi::StatusOr CreateTempDir(absl::string_view prefix) { std::string name_template = absl::StrCat(prefix, kMktempSuffix); if (mkdtemp(&name_template[0]) == nullptr) { - return sapi::UnknownError(absl::StrCat("mkdtemp():", StrError(errno))); + return ::sapi::UnknownError(absl::StrCat("mkdtemp():", StrError(errno))); } return name_template; } diff --git a/sandboxed_api/sandbox2/util/temp_file.h b/sandboxed_api/sandbox2/util/temp_file.h index 708708b..d1156a4 100644 --- a/sandboxed_api/sandbox2/util/temp_file.h +++ b/sandboxed_api/sandbox2/util/temp_file.h @@ -23,18 +23,18 @@ namespace sandbox2 { // Creates a temporary file under a path starting with prefix. File is not // unlinked and its path is returned together with an open fd. -sapi::StatusOr> CreateNamedTempFile( +::sapi::StatusOr> CreateNamedTempFile( absl::string_view prefix); // Creates a temporary file under a path starting with prefix. File is not // unlinked and its path is returned. FD of the created file is closed just // after creation. -sapi::StatusOr CreateNamedTempFileAndClose( +::sapi::StatusOr CreateNamedTempFileAndClose( absl::string_view prefix); // Creates a temporary directory under a path starting with prefix. // Returns the path of the created directory. -sapi::StatusOr CreateTempDir(absl::string_view prefix); +::sapi::StatusOr CreateTempDir(absl::string_view prefix); } // namespace sandbox2 diff --git a/sandboxed_api/sandbox2/util/temp_file_test.cc b/sandboxed_api/sandbox2/util/temp_file_test.cc index 0e38d06..6a05074 100644 --- a/sandboxed_api/sandbox2/util/temp_file_test.cc +++ b/sandboxed_api/sandbox2/util/temp_file_test.cc @@ -25,6 +25,7 @@ #include "sandboxed_api/util/status_matchers.h" using sapi::IsOk; +using sapi::StatusIs; using testing::Eq; using testing::IsTrue; using testing::Ne; diff --git a/sandboxed_api/util/status_matchers.h b/sandboxed_api/util/status_matchers.h index d933768..e7b89cf 100644 --- a/sandboxed_api/util/status_matchers.h +++ b/sandboxed_api/util/status_matchers.h @@ -15,6 +15,8 @@ #ifndef SANDBOXED_API_UTIL_STATUS_MATCHERS_H_ #define SANDBOXED_API_UTIL_STATUS_MATCHERS_H_ +#include + #include "gmock/gmock.h" #include "absl/types/optional.h" #include "sandboxed_api/util/status.h" @@ -59,9 +61,10 @@ class StatusIsMatcher { StatusIsMatcher(Enum code, absl::optional message) : code_{code}, message_{message} {} - template - bool MatchAndExplain(const StatusT& status, + template + bool MatchAndExplain(const T& value, ::testing::MatchResultListener* listener) const { + auto status = GetStatus(value); if (code_ != status.code()) { *listener << "whose error code is generic::" << internal::CodeEnumToString(status.code()); @@ -74,12 +77,6 @@ class StatusIsMatcher { return true; } - template - bool MatchAndExplain(const StatusOr& status_or, - ::testing::MatchResultListener* listener) const { - return MatchAndExplain(status_or.status(), listener); - } - void DescribeTo(std::ostream* os) const { *os << "has a status code that is generic::" << internal::CodeEnumToString(code_); @@ -98,6 +95,20 @@ class StatusIsMatcher { } private: + template ().code())>::value, + int>::type = 0> + static const StatusT& GetStatus(const StatusT& status) { + return status; + } + + template ().status())> + static StatusT GetStatus(const StatusOrT& status_or) { + return status_or.status(); + } + const Enum code_; const absl::optional message_; };