From 0aa7183502c27a88497fa813d009b76c15ee9134 Mon Sep 17 00:00:00 2001 From: Christian Blichmann Date: Fri, 13 Sep 2019 02:28:09 -0700 Subject: [PATCH] Refactor the tests and strings example a bit PiperOrigin-RevId: 268865491 Change-Id: Ie16e5f17e2eb22e25821c34edf0068cb81bcc2fe --- sandboxed_api/call.h | 1 - sandboxed_api/client.cc | 19 +-- .../examples/stringop/main_stringop.cc | 152 +++++++++--------- sandboxed_api/proto_helper.h | 27 ++-- sandboxed_api/sapi_test.cc | 5 +- sandboxed_api/transaction.h | 1 + sandboxed_api/util/status.h | 1 + sandboxed_api/var_proto.h | 41 +++-- 8 files changed, 133 insertions(+), 114 deletions(-) diff --git a/sandboxed_api/call.h b/sandboxed_api/call.h index 1da3fd7..a13311f 100644 --- a/sandboxed_api/call.h +++ b/sandboxed_api/call.h @@ -28,7 +28,6 @@ struct ReallocRequest { }; // Types of TAGs used with Comms channel. -// TODO(cblichmann): Mark these as "inline" once we're on C++17. // Call: constexpr uint32_t kMsgCall = 0x101; constexpr uint32_t kMsgAllocate = 0x102; diff --git a/sandboxed_api/client.cc b/sandboxed_api/client.cc index bc5e318..3036ded 100644 --- a/sandboxed_api/client.cc +++ b/sandboxed_api/client.cc @@ -45,7 +45,7 @@ namespace sapi { namespace { // Guess the FFI type on the basis of data size and float/non-float/bool. -static ffi_type* GetFFIType(size_t size, v::Type type) { +ffi_type* GetFFIType(size_t size, v::Type type) { switch (type) { case v::Type::kVoid: return &ffi_type_void; @@ -94,11 +94,11 @@ class FunctionCallPreparer { explicit FunctionCallPreparer(const FuncCall& call) { CHECK(call.argc <= FuncCall::kArgsMax) << "Number of arguments of a sandbox call exceeds limits."; - for (size_t i = 0; i < call.argc; i++) { + for (int i = 0; i < call.argc; ++i) { arg_types_[i] = GetFFIType(call.arg_size[i], call.arg_type[i]); } ret_type_ = GetFFIType(call.ret_size, call.ret_type); - for (size_t i = 0; i < call.argc; i++) { + for (int i = 0; i < call.argc; ++i) { if (call.arg_type[i] == v::Type::kPointer && call.aux_type[i] == v::Type::kProto) { // Deserialize protobuf stored in the LenValueStruct and keep a @@ -107,13 +107,10 @@ class FunctionCallPreparer { // This will also make sure that the protobuf is freed afterwards. arg_values_[i] = GetDeserializedProto( reinterpret_cast(call.args[i].arg_int)); + } else if (call.arg_type[i] == v::Type::kFloat) { + arg_values_[i] = reinterpret_cast(&call.args[i].arg_float); } else { - if (call.arg_type[i] == v::Type::kFloat) { - arg_values_[i] = - reinterpret_cast(&call.args[i].arg_float); - } else { - arg_values_[i] = reinterpret_cast(&call.args[i].arg_int); - } + arg_values_[i] = reinterpret_cast(&call.args[i].arg_int); } } } @@ -125,8 +122,8 @@ class FunctionCallPreparer { // There is no way to figure out whether the protobuf structure has // changed or not, so we always serialize the protobuf again and replace // the LenValStruct content. - std::vector serialized = SerializeProto(*proto); - // Reallocate the LV memory to match it's length. + std::vector serialized = SerializeProto(*proto).ValueOrDie(); + // Reallocate the LV memory to match its length. if (lvs->size != serialized.size()) { void* newdata = realloc(lvs->data, serialized.size()); if (!newdata) { diff --git a/sandboxed_api/examples/stringop/main_stringop.cc b/sandboxed_api/examples/stringop/main_stringop.cc index 2501474..943d408 100644 --- a/sandboxed_api/examples/stringop/main_stringop.cc +++ b/sandboxed_api/examples/stringop/main_stringop.cc @@ -30,25 +30,32 @@ #include "sandboxed_api/vars.h" #include "sandboxed_api/util/canonical_errors.h" #include "sandboxed_api/util/status.h" +#include "sandboxed_api/util/status_macros.h" using ::sapi::IsOk; +using ::testing::Eq; +using ::testing::Ne; +using ::testing::SizeIs; +using ::testing::StrEq; namespace { -// Tests using the simple transaction (and function pointers): +// Tests using a simple transaction (and function pointers): TEST(StringopTest, ProtobufStringDuplication) { sapi::BasicTransaction st(absl::make_unique()); EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status { - StringopApi f(sandbox); + StringopApi api(sandbox); stringop::StringDuplication proto; proto.set_input("Hello"); sapi::v::Proto pp(proto); - SAPI_ASSIGN_OR_RETURN(int v, f.pb_duplicate_string(pp.PtrBoth())); - TRANSACTION_FAIL_IF_NOT(v, "pb_duplicate_string failed"); - auto pb_result = pp.GetProtoCopy(); - TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result"); - LOG(INFO) << "Result PB: " << pb_result->DebugString(); - TRANSACTION_FAIL_IF_NOT(pb_result->output() == "HelloHello", + { + SAPI_ASSIGN_OR_RETURN(int return_value, api.pb_duplicate_string(pp.PtrBoth())); + TRANSACTION_FAIL_IF_NOT(return_value, "pb_duplicate_string() failed"); + } + + SAPI_ASSIGN_OR_RETURN(auto pb_result, pp.GetMessage()); + LOG(INFO) << "Result PB: " << pb_result.DebugString(); + TRANSACTION_FAIL_IF_NOT(pb_result.output() == "HelloHello", "Incorrect output"); return sapi::OkStatus(); }), @@ -56,84 +63,71 @@ TEST(StringopTest, ProtobufStringDuplication) { } TEST(StringopTest, ProtobufStringReversal) { - sapi::BasicTransaction st(absl::make_unique()); - EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status { - StringopApi f(sandbox); - stringop::StringReverse proto; - proto.set_input("Hello"); - sapi::v::Proto pp(proto); - SAPI_ASSIGN_OR_RETURN(int v, f.pb_reverse_string(pp.PtrBoth())); - TRANSACTION_FAIL_IF_NOT(v, "pb_reverse_string failed"); - auto pb_result = pp.GetProtoCopy(); - TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result"); - LOG(INFO) << "Result PB: " << pb_result->DebugString(); - TRANSACTION_FAIL_IF_NOT(pb_result->output() == "olleH", "Incorrect output"); - return sapi::OkStatus(); - }), - IsOk()); + StringopSapiSandbox sandbox; + ASSERT_THAT(sandbox.Init(), IsOk()); + StringopApi api(&sandbox); + + stringop::StringReverse proto; + proto.set_input("Hello"); + sapi::v::Proto pp(proto); + SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.pb_reverse_string(pp.PtrBoth())); + EXPECT_THAT(return_value, Ne(0)) << "pb_reverse_string() failed"; + + SAPI_ASSERT_OK_AND_ASSIGN(auto pb_result, pp.GetMessage()); + LOG(INFO) << "Result PB: " << pb_result.DebugString(); + EXPECT_THAT(pb_result.output(), StrEq("olleH")); } -// Tests using raw dynamic buffers. TEST(StringopTest, RawStringDuplication) { - sapi::BasicTransaction st(absl::make_unique()); - EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status { - StringopApi f(sandbox); - sapi::v::LenVal param("0123456789", 10); - SAPI_ASSIGN_OR_RETURN(int return_value, f.duplicate_string(param.PtrBoth())); - TRANSACTION_FAIL_IF_NOT(return_value == 1, - "duplicate_string() returned incorrect value"); - TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 20, - "duplicate_string() did not return enough data"); - absl::string_view data(reinterpret_cast(param.GetData()), - param.GetDataSize()); - TRANSACTION_FAIL_IF_NOT( - data == "01234567890123456789", - "duplicate_string() did not return the expected data"); - return sapi::OkStatus(); - }), - IsOk()); + StringopSapiSandbox sandbox; + ASSERT_THAT(sandbox.Init(), IsOk()); + StringopApi api(&sandbox); + + sapi::v::LenVal param("0123456789", 10); + SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.duplicate_string(param.PtrBoth())); + EXPECT_THAT(return_value, Eq(1)) << "duplicate_string() failed"; + + absl::string_view data(reinterpret_cast(param.GetData()), + param.GetDataSize()); + EXPECT_THAT(data, SizeIs(20)) + << "duplicate_string() did not return enough data"; + EXPECT_THAT(std::string(data), StrEq("01234567890123456789")); } TEST(StringopTest, RawStringReversal) { - sapi::BasicTransaction st(absl::make_unique()); - EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status { - StringopApi f(sandbox); - sapi::v::LenVal param("0123456789", 10); - { - SAPI_ASSIGN_OR_RETURN(int return_value, f.reverse_string(param.PtrBoth())); - TRANSACTION_FAIL_IF_NOT(return_value == 1, - "reverse_string() returned incorrect value"); - TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 10, - "reverse_string() did not return enough data"); - absl::string_view data(reinterpret_cast(param.GetData()), + StringopSapiSandbox sandbox; + ASSERT_THAT(sandbox.Init(), IsOk()); + StringopApi api(&sandbox); + + sapi::v::LenVal param("0123456789", 10); + { + SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.reverse_string(param.PtrBoth())); + EXPECT_THAT(return_value, Eq(1)) + << "reverse_string() returned incorrect value"; + absl::string_view data(reinterpret_cast(param.GetData()), + param.GetDataSize()); + EXPECT_THAT(param.GetDataSize(), Eq(10)) + << "reverse_string() did not return enough data"; + EXPECT_THAT(std::string(data), StrEq("9876543210")) + << "reverse_string() did not return the expected data"; + } + { + // Let's call it again with different data as argument, reusing the + // existing LenVal object. + EXPECT_THAT(param.ResizeData(sandbox.GetRpcChannel(), 16), IsOk()); + memcpy(param.GetData() + 10, "ABCDEF", 6); + absl::string_view data(reinterpret_cast(param.GetData()), + param.GetDataSize()); + EXPECT_THAT(data, SizeIs(16)) << "Resize did not behave correctly"; + EXPECT_THAT(std::string(data), StrEq("9876543210ABCDEF")); + + SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.reverse_string(param.PtrBoth())); + EXPECT_THAT(return_value, Eq(1)) + << "reverse_string() returned incorrect value"; + data = absl::string_view(reinterpret_cast(param.GetData()), param.GetDataSize()); - TRANSACTION_FAIL_IF_NOT( - data == "9876543210", - "reverse_string() did not return the expected data"); - } - { - // Let's call it again with different data as argument, reusing the - // existing LenVal object. - SAPI_RETURN_IF_ERROR(param.ResizeData(sandbox->GetRpcChannel(), 16)); - memcpy(param.GetData() + 10, "ABCDEF", 6); - TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 16, - "Resize did not behave correctly"); - absl::string_view data(reinterpret_cast(param.GetData()), - param.GetDataSize()); - TRANSACTION_FAIL_IF_NOT(data == "9876543210ABCDEF", - "Data not as expected"); - SAPI_ASSIGN_OR_RETURN(int return_value, f.reverse_string(param.PtrBoth())); - TRANSACTION_FAIL_IF_NOT(return_value == 1, - "reverse_string() returned incorrect value"); - data = absl::string_view(reinterpret_cast(param.GetData()), - param.GetDataSize()); - TRANSACTION_FAIL_IF_NOT( - data == "FEDCBA0123456789", - "reverse_string() did not return the expected data"); - } - return sapi::OkStatus(); - }), - IsOk()); + EXPECT_THAT(std::string(data), StrEq("FEDCBA0123456789")); + } } } // namespace diff --git a/sandboxed_api/proto_helper.h b/sandboxed_api/proto_helper.h index 4c95975..b29d64d 100644 --- a/sandboxed_api/proto_helper.h +++ b/sandboxed_api/proto_helper.h @@ -18,40 +18,49 @@ #define SANDBOXED_API_PROTO_HELPER_H_ #include +#include #include -#include #include "sandboxed_api/proto_arg.pb.h" +#include "sandboxed_api/util/canonical_errors.h" +#include "sandboxed_api/util/status.h" +#include "sandboxed_api/util/statusor.h" namespace sapi { template -std::vector SerializeProto(const T& proto) { +sapi::StatusOr> SerializeProto(const T& proto) { + static_assert(std::is_base_of::value, + "Template argument must be a proto message"); // Wrap protobuf in a envelope so that we know the name of the protobuf // structure when deserializing in the sandboxee. ProtoArg proto_arg; proto_arg.set_protobuf_data(proto.SerializeAsString()); proto_arg.set_full_name(proto.GetDescriptor()->full_name()); - std::vector serialized_proto(proto_arg.ByteSizeLong()); + std::vector serialized_proto(proto_arg.ByteSizeLong()); if (!proto_arg.SerializeToArray(serialized_proto.data(), serialized_proto.size())) { - LOG(ERROR) << "Unable to serialize array"; + return sapi::InternalError("Unable to serialize proto to array"); } - return serialized_proto; } template -bool DeserializeProto(T* result, const char* data, size_t len) { +sapi::StatusOr DeserializeProto(const char* data, size_t len) { + static_assert(std::is_base_of::value, + "Template argument must be a proto message"); ProtoArg envelope; if (!envelope.ParseFromArray(data, len)) { - LOG(ERROR) << "Unable to deserialize envelope"; - return false; + return sapi::InternalError("Unable to parse proto from array"); } auto pb_data = envelope.protobuf_data(); - return result->ParseFromArray(pb_data.c_str(), pb_data.size()); + T result; + if (!result.ParseFromArray(pb_data.data(), pb_data.size())) { + return sapi::InternalError("Unable to parse proto from envelope data"); + } + return result; } } // namespace sapi diff --git a/sandboxed_api/sapi_test.cc b/sandboxed_api/sapi_test.cc index 43b625a..f3d1b01 100644 --- a/sandboxed_api/sapi_test.cc +++ b/sandboxed_api/sapi_test.cc @@ -53,9 +53,8 @@ sapi::Status InvokeStringReversal(Sandbox* sandbox) { v::Proto pp(proto); SAPI_ASSIGN_OR_RETURN(int return_code, api.pb_reverse_string(pp.PtrBoth())); TRANSACTION_FAIL_IF_NOT(return_code != 0, "pb_reverse_string failed"); - std::unique_ptr pb_result = pp.GetProtoCopy(); - TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result"); - TRANSACTION_FAIL_IF_NOT(pb_result->output() == "olleH", "Incorrect output"); + SAPI_ASSIGN_OR_RETURN(auto pb_result, pp.GetMessage()); + TRANSACTION_FAIL_IF_NOT(pb_result.output() == "olleH", "Incorrect output"); return sapi::OkStatus(); } diff --git a/sandboxed_api/transaction.h b/sandboxed_api/transaction.h index e3199e9..e8669e6 100644 --- a/sandboxed_api/transaction.h +++ b/sandboxed_api/transaction.h @@ -149,6 +149,7 @@ class Transaction : public TransactionBase { // Callback style transactions: class BasicTransaction final : public TransactionBase { + private: using InitFunction = std::function; using FinishFunction = std::function; diff --git a/sandboxed_api/util/status.h b/sandboxed_api/util/status.h index 8a2aa21..0410f65 100644 --- a/sandboxed_api/util/status.h +++ b/sandboxed_api/util/status.h @@ -23,6 +23,7 @@ #include #include +#include "absl/base/attributes.h" #include "absl/meta/type_traits.h" #include "absl/strings/string_view.h" #include "sandboxed_api/util/status.pb.h" diff --git a/sandboxed_api/var_proto.h b/sandboxed_api/var_proto.h index 52c034b..17aaa33 100644 --- a/sandboxed_api/var_proto.h +++ b/sandboxed_api/var_proto.h @@ -18,12 +18,16 @@ #define SANDBOXED_API_VAR_PROTO_H_ #include +#include +#include +#include "absl/base/macros.h" #include "absl/memory/memory.h" #include "sandboxed_api/proto_helper.h" #include "sandboxed_api/var_lenval.h" #include "sandboxed_api/var_pointable.h" #include "sandboxed_api/var_ptr.h" +#include "sandboxed_api/util/status_macros.h" namespace sapi { namespace v { @@ -31,7 +35,17 @@ namespace v { template class Proto : public Pointable, public Var { public: - explicit Proto(const T& proto) : wrapped_var_(SerializeProto(proto)) {} + static_assert(std::is_base_of::value, + "Template argument must be a proto message"); + + ABSL_DEPRECATED("Use Proto<>::FromMessage() instead") + explicit Proto(const T& proto) + : wrapped_var_(SerializeProto(proto).ValueOrDie()) {} + + static sapi::StatusOr> FromMessage(const T& proto) { + SAPI_ASSIGN_OR_RETURN(std::vector len_val, SerializeProto(proto)); + return Proto(len_val); + } size_t GetSize() const final { return wrapped_var_.GetSize(); } Type GetType() const final { return Type::kProto; } @@ -46,18 +60,21 @@ class Proto : public Pointable, public Var { void* GetLocal() const override { return wrapped_var_.GetLocal(); } // Returns a copy of the stored protobuf object. - std::unique_ptr GetProtoCopy() const { - auto res = absl::make_unique(); - if (!res || - !DeserializeProto(res.get(), - reinterpret_cast(wrapped_var_.GetData()), - wrapped_var_.GetDataSize())) { - res.reset(); - } - return res; + sapi::StatusOr GetMessage() const { + return DeserializeProto( + reinterpret_cast(wrapped_var_.GetData()), + wrapped_var_.GetDataSize()); } - void SetRemote(void* remote) override { + ABSL_DEPRECATED("Use GetMessage() instead") + std::unique_ptr GetProtoCopy() const { + if (auto result_or = GetMessage(); result_or.ok()) { + return absl::make_unique(std::move(result_or).ValueOrDie()); + } + return nullptr; + } + + void SetRemote(void* /* remote */) override { // We do not support that much indirection (pointer to a pointer to a // protobuf) as it is unlikely that this is wanted behavior. If you expect // this to work, please get in touch with us. @@ -85,6 +102,8 @@ class Proto : public Pointable, public Var { } private: + explicit Proto(std::vector data) : wrapped_var_(data) {} + // The management of reading/writing the data to the sandboxee is handled by // the LenVal class. LenVal wrapped_var_;