Make StatusMatcher more flexible

PiperOrigin-RevId: 244879203
Change-Id: I5f7994130a898e84f041b18c0b5313d7e8b32780
This commit is contained in:
Wiktor Garbacz 2019-04-23 10:29:37 -07:00 committed by Copybara-Service
parent 726b1fb451
commit 6cbaaead8b
10 changed files with 87 additions and 73 deletions

View File

@ -23,61 +23,61 @@
namespace sapi {
sapi::Status RPCChannel::Call(const FuncCall& call, uint32_t tag, FuncRet* ret,
::sapi::Status RPCChannel::Call(const FuncCall& call, uint32_t tag, FuncRet* ret,
v::Type exp_type) {
absl::MutexLock lock(&mutex_);
if (!comms_->SendTLV(tag, sizeof(call),
reinterpret_cast<const uint8_t*>(&call))) {
return sapi::UnavailableError("Sending TLV value failed");
return ::sapi::UnavailableError("Sending TLV value failed");
}
SAPI_ASSIGN_OR_RETURN(auto fret, Return(exp_type));
*ret = fret;
return sapi::OkStatus();
return ::sapi::OkStatus();
}
sapi::StatusOr<FuncRet> RPCChannel::Return(v::Type exp_type) {
::sapi::StatusOr<FuncRet> RPCChannel::Return(v::Type exp_type) {
uint32_t tag;
uint64_t len;
FuncRet ret;
if (!comms_->RecvTLV(&tag, &len, &ret, sizeof(ret))) {
return sapi::UnavailableError("Receiving TLV value failed");
return ::sapi::UnavailableError("Receiving TLV value failed");
}
if (tag != comms::kMsgReturn) {
LOG(ERROR) << "tag != comms::kMsgReturn (" << absl::StrCat(absl::Hex(tag))
<< " != " << absl::StrCat(absl::Hex(comms::kMsgReturn)) << ")";
return sapi::UnavailableError("Received TLV has incorrect tag");
return ::sapi::UnavailableError("Received TLV has incorrect tag");
}
if (len != sizeof(FuncRet)) {
LOG(ERROR) << "len != sizeof(FuncReturn) (" << len
<< " != " << sizeof(FuncRet) << ")";
return sapi::UnavailableError("Received TLV has incorrect length");
return ::sapi::UnavailableError("Received TLV has incorrect length");
}
if (ret.ret_type != exp_type) {
LOG(ERROR) << "FuncRet->type != exp_type (" << ret.ret_type
<< " != " << exp_type << ")";
return sapi::UnavailableError("Received TLV has incorrect return type");
return ::sapi::UnavailableError("Received TLV has incorrect return type");
}
if (!ret.success) {
LOG(ERROR) << "FuncRet->success == false";
return sapi::UnavailableError("Function call failed");
return ::sapi::UnavailableError("Function call failed");
}
return ret;
}
sapi::Status RPCChannel::Allocate(size_t size, void** addr) {
::sapi::Status RPCChannel::Allocate(size_t size, void** addr) {
absl::MutexLock lock(&mutex_);
uint64_t sz = size;
if (!comms_->SendTLV(comms::kMsgAllocate, sizeof(sz),
reinterpret_cast<uint8_t*>(&sz))) {
return sapi::UnavailableError("Sending TLV value failed");
return ::sapi::UnavailableError("Sending TLV value failed");
}
SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kPointer));
*addr = reinterpret_cast<void*>(fret.int_val);
return sapi::OkStatus();
return ::sapi::OkStatus();
}
sapi::Status RPCChannel::Reallocate(void* old_addr, size_t size,
::sapi::Status RPCChannel::Reallocate(void* old_addr, size_t size,
void** new_addr) {
absl::MutexLock lock(&mutex_);
comms::ReallocRequest req;
@ -86,54 +86,54 @@ sapi::Status RPCChannel::Reallocate(void* old_addr, size_t size,
if (!comms_->SendTLV(comms::kMsgReallocate, sizeof(comms::ReallocRequest),
reinterpret_cast<uint8_t*>(&req))) {
return sapi::UnavailableError("Sending TLV value failed");
return ::sapi::UnavailableError("Sending TLV value failed");
}
auto fret_or = Return(v::Type::kPointer);
if (!fret_or.ok()) {
*new_addr = nullptr;
return sapi::UnavailableError(
return ::sapi::UnavailableError(
absl::StrCat("Reallocate() failed on the remote side: ",
fret_or.status().message()));
}
auto fret = std::move(fret_or).ValueOrDie();
*new_addr = reinterpret_cast<void*>(fret.int_val);
return sapi::OkStatus();
return ::sapi::OkStatus();
}
sapi::Status RPCChannel::Free(void* addr) {
::sapi::Status RPCChannel::Free(void* addr) {
absl::MutexLock lock(&mutex_);
uint64_t remote = reinterpret_cast<uint64_t>(addr);
if (!comms_->SendTLV(comms::kMsgFree, sizeof(remote),
reinterpret_cast<uint8_t*>(&remote))) {
return sapi::UnavailableError("Sending TLV value failed");
return ::sapi::UnavailableError("Sending TLV value failed");
}
SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kVoid));
if (!fret.success) {
return sapi::UnavailableError("Free() failed on the remote side");
return ::sapi::UnavailableError("Free() failed on the remote side");
}
return sapi::OkStatus();
return ::sapi::OkStatus();
}
sapi::Status RPCChannel::Symbol(const char* symname, void** addr) {
::sapi::Status RPCChannel::Symbol(const char* symname, void** addr) {
absl::MutexLock lock(&mutex_);
if (!comms_->SendTLV(comms::kMsgSymbol, strlen(symname) + 1,
reinterpret_cast<const uint8_t*>(symname))) {
return sapi::UnavailableError("Sending TLV value failed");
return ::sapi::UnavailableError("Sending TLV value failed");
}
SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kPointer));
*addr = reinterpret_cast<void*>(fret.int_val);
return sapi::OkStatus();
return ::sapi::OkStatus();
}
sapi::Status RPCChannel::Exit() {
::sapi::Status RPCChannel::Exit() {
absl::MutexLock lock(&mutex_);
if (comms_->IsTerminated()) {
VLOG(2) << "Comms channel already terminated";
return sapi::OkStatus();
return ::sapi::OkStatus();
}
// Try the RPC exit sequence. But, the only thing that matters as a success
@ -146,62 +146,62 @@ sapi::Status RPCChannel::Exit() {
if (!comms_->IsTerminated()) {
LOG(ERROR) << "Comms channel not terminated in Exit()";
// TODO(hamacher): Better error code
return sapi::FailedPreconditionError(
return ::sapi::FailedPreconditionError(
"Comms channel not terminated in Exit()");
}
return sapi::OkStatus();
return ::sapi::OkStatus();
}
sapi::Status RPCChannel::SendFD(int local_fd, int* remote_fd) {
::sapi::Status RPCChannel::SendFD(int local_fd, int* remote_fd) {
absl::MutexLock lock(&mutex_);
bool unused = true;
if (!comms_->SendTLV(comms::kMsgSendFd, sizeof(unused),
reinterpret_cast<uint8_t*>(&unused))) {
return sapi::UnavailableError("Sending TLV value failed");
return ::sapi::UnavailableError("Sending TLV value failed");
}
if (!comms_->SendFD(local_fd)) {
return sapi::UnavailableError("Sending FD failed");
return ::sapi::UnavailableError("Sending FD failed");
}
SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kInt));
if (!fret.success) {
return sapi::UnavailableError("SendFD failed on the remote side");
return ::sapi::UnavailableError("SendFD failed on the remote side");
}
*remote_fd = fret.int_val;
return sapi::OkStatus();
return ::sapi::OkStatus();
}
sapi::Status RPCChannel::RecvFD(int remote_fd, int* local_fd) {
::sapi::Status RPCChannel::RecvFD(int remote_fd, int* local_fd) {
absl::MutexLock lock(&mutex_);
if (!comms_->SendTLV(comms::kMsgRecvFd, sizeof(remote_fd),
reinterpret_cast<uint8_t*>(&remote_fd))) {
return sapi::UnavailableError("Sending TLV value failed");
return ::sapi::UnavailableError("Sending TLV value failed");
}
if (!comms_->RecvFD(local_fd)) {
return sapi::UnavailableError("Receving FD failed");
return ::sapi::UnavailableError("Receving FD failed");
}
SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kVoid));
if (!fret.success) {
return sapi::UnavailableError("RecvFD failed on the remote side");
return ::sapi::UnavailableError("RecvFD failed on the remote side");
}
return sapi::OkStatus();
return ::sapi::OkStatus();
}
sapi::Status RPCChannel::Close(int remote_fd) {
::sapi::Status RPCChannel::Close(int remote_fd) {
absl::MutexLock lock(&mutex_);
if (!comms_->SendTLV(comms::kMsgClose, sizeof(remote_fd),
reinterpret_cast<uint8_t*>(&remote_fd))) {
return sapi::UnavailableError("Sending TLV value failed");
return ::sapi::UnavailableError("Sending TLV value failed");
}
SAPI_ASSIGN_OR_RETURN(auto fret, Return(v::Type::kVoid));
if (!fret.success) {
return sapi::UnavailableError("Close() failed on the remote side");
return ::sapi::UnavailableError("Close() failed on the remote side");
}
return sapi::OkStatus();
return ::sapi::OkStatus();
}
} // namespace sapi

View File

@ -20,9 +20,9 @@
#include "absl/synchronization/mutex.h"
#include "sandboxed_api/call.h"
#include "sandboxed_api/sandbox2/comms.h"
#include "sandboxed_api/var_type.h"
#include "sandboxed_api/util/status.h"
#include "sandboxed_api/util/statusor.h"
#include "sandboxed_api/var_type.h"
namespace sapi {
@ -33,38 +33,38 @@ class RPCChannel {
explicit RPCChannel(sandbox2::Comms* comms) : comms_(comms) {}
// Calls a function.
sapi::Status Call(const FuncCall& call, uint32_t tag, FuncRet* ret,
::sapi::Status Call(const FuncCall& call, uint32_t tag, FuncRet* ret,
v::Type exp_type);
// Allocates memory.
sapi::Status Allocate(size_t size, void** addr);
::sapi::Status Allocate(size_t size, void** addr);
// Reallocates memory.
sapi::Status Reallocate(void* old_addr, size_t size, void** new_addr);
::sapi::Status Reallocate(void* old_addr, size_t size, void** new_addr);
// Frees memory.
sapi::Status Free(void* addr);
::sapi::Status Free(void* addr);
// Returns address of a symbol.
sapi::Status Symbol(const char* symname, void** addr);
::sapi::Status Symbol(const char* symname, void** addr);
// Makes the remote part exit.
sapi::Status Exit();
::sapi::Status Exit();
// Transfers fd to sandboxee.
sapi::Status SendFD(int local_fd, int* remote_fd);
::sapi::Status SendFD(int local_fd, int* remote_fd);
// Retrieves fd from sandboxee.
sapi::Status RecvFD(int remote_fd, int* local_fd);
::sapi::Status RecvFD(int remote_fd, int* local_fd);
// Closes fd in sandboxee.
sapi::Status Close(int remote_fd);
::sapi::Status Close(int remote_fd);
sandbox2::Comms* comms() const { return comms_; }
private:
// Receives the result after a call.
sapi::StatusOr<FuncRet> Return(v::Type exp_type);
::sapi::StatusOr<FuncRet> Return(v::Type exp_type);
sandbox2::Comms* comms_; // Owned by sandbox2;
absl::Mutex mutex_;

View File

@ -526,6 +526,7 @@ cc_library(
"//sandboxed_api/sandbox2/util:strerror",
"//sandboxed_api/util:raw_logging",
"//sandboxed_api/util:status",
"//sandboxed_api/util:status_proto",
"//sandboxed_api/util:statusor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",

View File

@ -24,9 +24,9 @@
#include <sys/socket.h>
#include <sys/uio.h>
#include <sys/un.h>
#include <syscall.h>
#include <unistd.h>
#include <cerrno>
#include <cinttypes>
#include <cstddef>

View File

@ -34,6 +34,7 @@
#include "absl/base/attributes.h"
#include "absl/synchronization/mutex.h"
#include "sandboxed_api/util/status.h"
#include "sandboxed_api/util/status.pb.h"
namespace proto2 {
class Message;

View File

@ -307,7 +307,7 @@ TEST_F(CommsTest, TestSendRecvStatusOK) {
};
auto b = [](Comms* comms) {
// Send a good status.
ASSERT_THAT(comms->SendStatus(sapi::OkStatus()), IsTrue());
ASSERT_THAT(comms->SendStatus(::sapi::OkStatus()), IsTrue());
};
HandleCommunication(sockname_, a, b);
}

View File

@ -33,17 +33,17 @@ namespace {
constexpr absl::string_view kMktempSuffix = "XXXXXX";
} // namespace
sapi::StatusOr<std::pair<std::string, int>> CreateNamedTempFile(
::sapi::StatusOr<std::pair<std::string, int>> CreateNamedTempFile(
absl::string_view prefix) {
std::string name_template = absl::StrCat(prefix, kMktempSuffix);
int fd = mkstemp(&name_template[0]);
if (fd < 0) {
return sapi::UnknownError(absl::StrCat("mkstemp():", StrError(errno)));
return ::sapi::UnknownError(absl::StrCat("mkstemp():", StrError(errno)));
}
return std::pair<std::string, int>{std::move(name_template), fd};
}
sapi::StatusOr<std::string> CreateNamedTempFileAndClose(
::sapi::StatusOr<std::string> CreateNamedTempFileAndClose(
absl::string_view prefix) {
auto result_or = CreateNamedTempFile(prefix);
if (result_or.ok()) {
@ -56,10 +56,10 @@ sapi::StatusOr<std::string> CreateNamedTempFileAndClose(
return result_or.status();
}
sapi::StatusOr<std::string> CreateTempDir(absl::string_view prefix) {
::sapi::StatusOr<std::string> CreateTempDir(absl::string_view prefix) {
std::string name_template = absl::StrCat(prefix, kMktempSuffix);
if (mkdtemp(&name_template[0]) == nullptr) {
return sapi::UnknownError(absl::StrCat("mkdtemp():", StrError(errno)));
return ::sapi::UnknownError(absl::StrCat("mkdtemp():", StrError(errno)));
}
return name_template;
}

View File

@ -23,18 +23,18 @@ namespace sandbox2 {
// Creates a temporary file under a path starting with prefix. File is not
// unlinked and its path is returned together with an open fd.
sapi::StatusOr<std::pair<std::string, int>> CreateNamedTempFile(
::sapi::StatusOr<std::pair<std::string, int>> CreateNamedTempFile(
absl::string_view prefix);
// Creates a temporary file under a path starting with prefix. File is not
// unlinked and its path is returned. FD of the created file is closed just
// after creation.
sapi::StatusOr<std::string> CreateNamedTempFileAndClose(
::sapi::StatusOr<std::string> CreateNamedTempFileAndClose(
absl::string_view prefix);
// Creates a temporary directory under a path starting with prefix.
// Returns the path of the created directory.
sapi::StatusOr<std::string> CreateTempDir(absl::string_view prefix);
::sapi::StatusOr<std::string> CreateTempDir(absl::string_view prefix);
} // namespace sandbox2

View File

@ -25,6 +25,7 @@
#include "sandboxed_api/util/status_matchers.h"
using sapi::IsOk;
using sapi::StatusIs;
using testing::Eq;
using testing::IsTrue;
using testing::Ne;

View File

@ -15,6 +15,8 @@
#ifndef SANDBOXED_API_UTIL_STATUS_MATCHERS_H_
#define SANDBOXED_API_UTIL_STATUS_MATCHERS_H_
#include <type_traits>
#include "gmock/gmock.h"
#include "absl/types/optional.h"
#include "sandboxed_api/util/status.h"
@ -59,9 +61,10 @@ class StatusIsMatcher {
StatusIsMatcher(Enum code, absl::optional<absl::string_view> message)
: code_{code}, message_{message} {}
template <typename StatusT>
bool MatchAndExplain(const StatusT& status,
template <typename T>
bool MatchAndExplain(const T& value,
::testing::MatchResultListener* listener) const {
auto status = GetStatus(value);
if (code_ != status.code()) {
*listener << "whose error code is generic::"
<< internal::CodeEnumToString(status.code());
@ -74,12 +77,6 @@ class StatusIsMatcher {
return true;
}
template <typename T>
bool MatchAndExplain(const StatusOr<T>& status_or,
::testing::MatchResultListener* listener) const {
return MatchAndExplain(status_or.status(), listener);
}
void DescribeTo(std::ostream* os) const {
*os << "has a status code that is generic::"
<< internal::CodeEnumToString(code_);
@ -98,6 +95,20 @@ class StatusIsMatcher {
}
private:
template <typename StatusT,
typename std::enable_if<
!std::is_void<decltype(std::declval<StatusT>().code())>::value,
int>::type = 0>
static const StatusT& GetStatus(const StatusT& status) {
return status;
}
template <typename StatusOrT,
typename StatusT = decltype(std::declval<StatusOrT>().status())>
static StatusT GetStatus(const StatusOrT& status_or) {
return status_or.status();
}
const Enum code_;
const absl::optional<std::string> message_;
};