Refactor the tests and strings example a bit

PiperOrigin-RevId: 268865491
Change-Id: Ie16e5f17e2eb22e25821c34edf0068cb81bcc2fe
This commit is contained in:
Christian Blichmann 2019-09-13 02:28:09 -07:00 committed by Copybara-Service
parent d6ca9d9564
commit 0aa7183502
8 changed files with 133 additions and 114 deletions

View File

@ -28,7 +28,6 @@ struct ReallocRequest {
}; };
// Types of TAGs used with Comms channel. // Types of TAGs used with Comms channel.
// TODO(cblichmann): Mark these as "inline" once we're on C++17.
// Call: // Call:
constexpr uint32_t kMsgCall = 0x101; constexpr uint32_t kMsgCall = 0x101;
constexpr uint32_t kMsgAllocate = 0x102; constexpr uint32_t kMsgAllocate = 0x102;

View File

@ -45,7 +45,7 @@ namespace sapi {
namespace { namespace {
// Guess the FFI type on the basis of data size and float/non-float/bool. // Guess the FFI type on the basis of data size and float/non-float/bool.
static ffi_type* GetFFIType(size_t size, v::Type type) { ffi_type* GetFFIType(size_t size, v::Type type) {
switch (type) { switch (type) {
case v::Type::kVoid: case v::Type::kVoid:
return &ffi_type_void; return &ffi_type_void;
@ -94,11 +94,11 @@ class FunctionCallPreparer {
explicit FunctionCallPreparer(const FuncCall& call) { explicit FunctionCallPreparer(const FuncCall& call) {
CHECK(call.argc <= FuncCall::kArgsMax) CHECK(call.argc <= FuncCall::kArgsMax)
<< "Number of arguments of a sandbox call exceeds limits."; << "Number of arguments of a sandbox call exceeds limits.";
for (size_t i = 0; i < call.argc; i++) { for (int i = 0; i < call.argc; ++i) {
arg_types_[i] = GetFFIType(call.arg_size[i], call.arg_type[i]); arg_types_[i] = GetFFIType(call.arg_size[i], call.arg_type[i]);
} }
ret_type_ = GetFFIType(call.ret_size, call.ret_type); ret_type_ = GetFFIType(call.ret_size, call.ret_type);
for (size_t i = 0; i < call.argc; i++) { for (int i = 0; i < call.argc; ++i) {
if (call.arg_type[i] == v::Type::kPointer && if (call.arg_type[i] == v::Type::kPointer &&
call.aux_type[i] == v::Type::kProto) { call.aux_type[i] == v::Type::kProto) {
// Deserialize protobuf stored in the LenValueStruct and keep a // Deserialize protobuf stored in the LenValueStruct and keep a
@ -107,13 +107,10 @@ class FunctionCallPreparer {
// This will also make sure that the protobuf is freed afterwards. // This will also make sure that the protobuf is freed afterwards.
arg_values_[i] = GetDeserializedProto( arg_values_[i] = GetDeserializedProto(
reinterpret_cast<LenValStruct*>(call.args[i].arg_int)); reinterpret_cast<LenValStruct*>(call.args[i].arg_int));
} else if (call.arg_type[i] == v::Type::kFloat) {
arg_values_[i] = reinterpret_cast<const void*>(&call.args[i].arg_float);
} else { } else {
if (call.arg_type[i] == v::Type::kFloat) { arg_values_[i] = reinterpret_cast<const void*>(&call.args[i].arg_int);
arg_values_[i] =
reinterpret_cast<const void*>(&call.args[i].arg_float);
} else {
arg_values_[i] = reinterpret_cast<const void*>(&call.args[i].arg_int);
}
} }
} }
} }
@ -125,8 +122,8 @@ class FunctionCallPreparer {
// There is no way to figure out whether the protobuf structure has // There is no way to figure out whether the protobuf structure has
// changed or not, so we always serialize the protobuf again and replace // changed or not, so we always serialize the protobuf again and replace
// the LenValStruct content. // the LenValStruct content.
std::vector<uint8_t> serialized = SerializeProto(*proto); std::vector<uint8_t> serialized = SerializeProto(*proto).ValueOrDie();
// Reallocate the LV memory to match it's length. // Reallocate the LV memory to match its length.
if (lvs->size != serialized.size()) { if (lvs->size != serialized.size()) {
void* newdata = realloc(lvs->data, serialized.size()); void* newdata = realloc(lvs->data, serialized.size());
if (!newdata) { if (!newdata) {

View File

@ -30,25 +30,32 @@
#include "sandboxed_api/vars.h" #include "sandboxed_api/vars.h"
#include "sandboxed_api/util/canonical_errors.h" #include "sandboxed_api/util/canonical_errors.h"
#include "sandboxed_api/util/status.h" #include "sandboxed_api/util/status.h"
#include "sandboxed_api/util/status_macros.h"
using ::sapi::IsOk; using ::sapi::IsOk;
using ::testing::Eq;
using ::testing::Ne;
using ::testing::SizeIs;
using ::testing::StrEq;
namespace { namespace {
// Tests using the simple transaction (and function pointers): // Tests using a simple transaction (and function pointers):
TEST(StringopTest, ProtobufStringDuplication) { TEST(StringopTest, ProtobufStringDuplication) {
sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>()); sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>());
EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status { EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status {
StringopApi f(sandbox); StringopApi api(sandbox);
stringop::StringDuplication proto; stringop::StringDuplication proto;
proto.set_input("Hello"); proto.set_input("Hello");
sapi::v::Proto<stringop::StringDuplication> pp(proto); sapi::v::Proto<stringop::StringDuplication> pp(proto);
SAPI_ASSIGN_OR_RETURN(int v, f.pb_duplicate_string(pp.PtrBoth())); {
TRANSACTION_FAIL_IF_NOT(v, "pb_duplicate_string failed"); SAPI_ASSIGN_OR_RETURN(int return_value, api.pb_duplicate_string(pp.PtrBoth()));
auto pb_result = pp.GetProtoCopy(); TRANSACTION_FAIL_IF_NOT(return_value, "pb_duplicate_string() failed");
TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result"); }
LOG(INFO) << "Result PB: " << pb_result->DebugString();
TRANSACTION_FAIL_IF_NOT(pb_result->output() == "HelloHello", SAPI_ASSIGN_OR_RETURN(auto pb_result, pp.GetMessage());
LOG(INFO) << "Result PB: " << pb_result.DebugString();
TRANSACTION_FAIL_IF_NOT(pb_result.output() == "HelloHello",
"Incorrect output"); "Incorrect output");
return sapi::OkStatus(); return sapi::OkStatus();
}), }),
@ -56,84 +63,71 @@ TEST(StringopTest, ProtobufStringDuplication) {
} }
TEST(StringopTest, ProtobufStringReversal) { TEST(StringopTest, ProtobufStringReversal) {
sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>()); StringopSapiSandbox sandbox;
EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status { ASSERT_THAT(sandbox.Init(), IsOk());
StringopApi f(sandbox); StringopApi api(&sandbox);
stringop::StringReverse proto;
proto.set_input("Hello"); stringop::StringReverse proto;
sapi::v::Proto<stringop::StringReverse> pp(proto); proto.set_input("Hello");
SAPI_ASSIGN_OR_RETURN(int v, f.pb_reverse_string(pp.PtrBoth())); sapi::v::Proto<stringop::StringReverse> pp(proto);
TRANSACTION_FAIL_IF_NOT(v, "pb_reverse_string failed"); SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.pb_reverse_string(pp.PtrBoth()));
auto pb_result = pp.GetProtoCopy(); EXPECT_THAT(return_value, Ne(0)) << "pb_reverse_string() failed";
TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result");
LOG(INFO) << "Result PB: " << pb_result->DebugString(); SAPI_ASSERT_OK_AND_ASSIGN(auto pb_result, pp.GetMessage());
TRANSACTION_FAIL_IF_NOT(pb_result->output() == "olleH", "Incorrect output"); LOG(INFO) << "Result PB: " << pb_result.DebugString();
return sapi::OkStatus(); EXPECT_THAT(pb_result.output(), StrEq("olleH"));
}),
IsOk());
} }
// Tests using raw dynamic buffers.
TEST(StringopTest, RawStringDuplication) { TEST(StringopTest, RawStringDuplication) {
sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>()); StringopSapiSandbox sandbox;
EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status { ASSERT_THAT(sandbox.Init(), IsOk());
StringopApi f(sandbox); StringopApi api(&sandbox);
sapi::v::LenVal param("0123456789", 10);
SAPI_ASSIGN_OR_RETURN(int return_value, f.duplicate_string(param.PtrBoth())); sapi::v::LenVal param("0123456789", 10);
TRANSACTION_FAIL_IF_NOT(return_value == 1, SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.duplicate_string(param.PtrBoth()));
"duplicate_string() returned incorrect value"); EXPECT_THAT(return_value, Eq(1)) << "duplicate_string() failed";
TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 20,
"duplicate_string() did not return enough data"); absl::string_view data(reinterpret_cast<const char*>(param.GetData()),
absl::string_view data(reinterpret_cast<const char*>(param.GetData()), param.GetDataSize());
param.GetDataSize()); EXPECT_THAT(data, SizeIs(20))
TRANSACTION_FAIL_IF_NOT( << "duplicate_string() did not return enough data";
data == "01234567890123456789", EXPECT_THAT(std::string(data), StrEq("01234567890123456789"));
"duplicate_string() did not return the expected data");
return sapi::OkStatus();
}),
IsOk());
} }
TEST(StringopTest, RawStringReversal) { TEST(StringopTest, RawStringReversal) {
sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>()); StringopSapiSandbox sandbox;
EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status { ASSERT_THAT(sandbox.Init(), IsOk());
StringopApi f(sandbox); StringopApi api(&sandbox);
sapi::v::LenVal param("0123456789", 10);
{ sapi::v::LenVal param("0123456789", 10);
SAPI_ASSIGN_OR_RETURN(int return_value, f.reverse_string(param.PtrBoth())); {
TRANSACTION_FAIL_IF_NOT(return_value == 1, SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.reverse_string(param.PtrBoth()));
"reverse_string() returned incorrect value"); EXPECT_THAT(return_value, Eq(1))
TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 10, << "reverse_string() returned incorrect value";
"reverse_string() did not return enough data"); absl::string_view data(reinterpret_cast<const char*>(param.GetData()),
absl::string_view data(reinterpret_cast<const char*>(param.GetData()), param.GetDataSize());
EXPECT_THAT(param.GetDataSize(), Eq(10))
<< "reverse_string() did not return enough data";
EXPECT_THAT(std::string(data), StrEq("9876543210"))
<< "reverse_string() did not return the expected data";
}
{
// Let's call it again with different data as argument, reusing the
// existing LenVal object.
EXPECT_THAT(param.ResizeData(sandbox.GetRpcChannel(), 16), IsOk());
memcpy(param.GetData() + 10, "ABCDEF", 6);
absl::string_view data(reinterpret_cast<const char*>(param.GetData()),
param.GetDataSize());
EXPECT_THAT(data, SizeIs(16)) << "Resize did not behave correctly";
EXPECT_THAT(std::string(data), StrEq("9876543210ABCDEF"));
SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.reverse_string(param.PtrBoth()));
EXPECT_THAT(return_value, Eq(1))
<< "reverse_string() returned incorrect value";
data = absl::string_view(reinterpret_cast<const char*>(param.GetData()),
param.GetDataSize()); param.GetDataSize());
TRANSACTION_FAIL_IF_NOT( EXPECT_THAT(std::string(data), StrEq("FEDCBA0123456789"));
data == "9876543210", }
"reverse_string() did not return the expected data");
}
{
// Let's call it again with different data as argument, reusing the
// existing LenVal object.
SAPI_RETURN_IF_ERROR(param.ResizeData(sandbox->GetRpcChannel(), 16));
memcpy(param.GetData() + 10, "ABCDEF", 6);
TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 16,
"Resize did not behave correctly");
absl::string_view data(reinterpret_cast<const char*>(param.GetData()),
param.GetDataSize());
TRANSACTION_FAIL_IF_NOT(data == "9876543210ABCDEF",
"Data not as expected");
SAPI_ASSIGN_OR_RETURN(int return_value, f.reverse_string(param.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(return_value == 1,
"reverse_string() returned incorrect value");
data = absl::string_view(reinterpret_cast<const char*>(param.GetData()),
param.GetDataSize());
TRANSACTION_FAIL_IF_NOT(
data == "FEDCBA0123456789",
"reverse_string() did not return the expected data");
}
return sapi::OkStatus();
}),
IsOk());
} }
} // namespace } // namespace

View File

@ -18,40 +18,49 @@
#define SANDBOXED_API_PROTO_HELPER_H_ #define SANDBOXED_API_PROTO_HELPER_H_
#include <cinttypes> #include <cinttypes>
#include <type_traits>
#include <vector> #include <vector>
#include <glog/logging.h>
#include "sandboxed_api/proto_arg.pb.h" #include "sandboxed_api/proto_arg.pb.h"
#include "sandboxed_api/util/canonical_errors.h"
#include "sandboxed_api/util/status.h"
#include "sandboxed_api/util/statusor.h"
namespace sapi { namespace sapi {
template <typename T> template <typename T>
std::vector<uint8_t> SerializeProto(const T& proto) { sapi::StatusOr<std::vector<uint8_t>> SerializeProto(const T& proto) {
static_assert(std::is_base_of<google::protobuf::Message, T>::value,
"Template argument must be a proto message");
// Wrap protobuf in a envelope so that we know the name of the protobuf // Wrap protobuf in a envelope so that we know the name of the protobuf
// structure when deserializing in the sandboxee. // structure when deserializing in the sandboxee.
ProtoArg proto_arg; ProtoArg proto_arg;
proto_arg.set_protobuf_data(proto.SerializeAsString()); proto_arg.set_protobuf_data(proto.SerializeAsString());
proto_arg.set_full_name(proto.GetDescriptor()->full_name()); proto_arg.set_full_name(proto.GetDescriptor()->full_name());
std::vector<uint8_t> serialized_proto(proto_arg.ByteSizeLong());
std::vector<uint8_t> serialized_proto(proto_arg.ByteSizeLong());
if (!proto_arg.SerializeToArray(serialized_proto.data(), if (!proto_arg.SerializeToArray(serialized_proto.data(),
serialized_proto.size())) { serialized_proto.size())) {
LOG(ERROR) << "Unable to serialize array"; return sapi::InternalError("Unable to serialize proto to array");
} }
return serialized_proto; return serialized_proto;
} }
template <typename T> template <typename T>
bool DeserializeProto(T* result, const char* data, size_t len) { sapi::StatusOr<T> DeserializeProto(const char* data, size_t len) {
static_assert(std::is_base_of<google::protobuf::Message, T>::value,
"Template argument must be a proto message");
ProtoArg envelope; ProtoArg envelope;
if (!envelope.ParseFromArray(data, len)) { if (!envelope.ParseFromArray(data, len)) {
LOG(ERROR) << "Unable to deserialize envelope"; return sapi::InternalError("Unable to parse proto from array");
return false;
} }
auto pb_data = envelope.protobuf_data(); auto pb_data = envelope.protobuf_data();
return result->ParseFromArray(pb_data.c_str(), pb_data.size()); T result;
if (!result.ParseFromArray(pb_data.data(), pb_data.size())) {
return sapi::InternalError("Unable to parse proto from envelope data");
}
return result;
} }
} // namespace sapi } // namespace sapi

View File

@ -53,9 +53,8 @@ sapi::Status InvokeStringReversal(Sandbox* sandbox) {
v::Proto<stringop::StringReverse> pp(proto); v::Proto<stringop::StringReverse> pp(proto);
SAPI_ASSIGN_OR_RETURN(int return_code, api.pb_reverse_string(pp.PtrBoth())); SAPI_ASSIGN_OR_RETURN(int return_code, api.pb_reverse_string(pp.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(return_code != 0, "pb_reverse_string failed"); TRANSACTION_FAIL_IF_NOT(return_code != 0, "pb_reverse_string failed");
std::unique_ptr<stringop::StringReverse> pb_result = pp.GetProtoCopy(); SAPI_ASSIGN_OR_RETURN(auto pb_result, pp.GetMessage());
TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result"); TRANSACTION_FAIL_IF_NOT(pb_result.output() == "olleH", "Incorrect output");
TRANSACTION_FAIL_IF_NOT(pb_result->output() == "olleH", "Incorrect output");
return sapi::OkStatus(); return sapi::OkStatus();
} }

View File

@ -149,6 +149,7 @@ class Transaction : public TransactionBase {
// Callback style transactions: // Callback style transactions:
class BasicTransaction final : public TransactionBase { class BasicTransaction final : public TransactionBase {
private:
using InitFunction = std::function<sapi::Status(Sandbox*)>; using InitFunction = std::function<sapi::Status(Sandbox*)>;
using FinishFunction = std::function<sapi::Status(Sandbox*)>; using FinishFunction = std::function<sapi::Status(Sandbox*)>;

View File

@ -23,6 +23,7 @@
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "absl/base/attributes.h"
#include "absl/meta/type_traits.h" #include "absl/meta/type_traits.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "sandboxed_api/util/status.pb.h" #include "sandboxed_api/util/status.pb.h"

View File

@ -18,12 +18,16 @@
#define SANDBOXED_API_VAR_PROTO_H_ #define SANDBOXED_API_VAR_PROTO_H_
#include <cinttypes> #include <cinttypes>
#include <cstdint>
#include <vector>
#include "absl/base/macros.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "sandboxed_api/proto_helper.h" #include "sandboxed_api/proto_helper.h"
#include "sandboxed_api/var_lenval.h" #include "sandboxed_api/var_lenval.h"
#include "sandboxed_api/var_pointable.h" #include "sandboxed_api/var_pointable.h"
#include "sandboxed_api/var_ptr.h" #include "sandboxed_api/var_ptr.h"
#include "sandboxed_api/util/status_macros.h"
namespace sapi { namespace sapi {
namespace v { namespace v {
@ -31,7 +35,17 @@ namespace v {
template <typename T> template <typename T>
class Proto : public Pointable, public Var { class Proto : public Pointable, public Var {
public: public:
explicit Proto(const T& proto) : wrapped_var_(SerializeProto(proto)) {} static_assert(std::is_base_of<google::protobuf::Message, T>::value,
"Template argument must be a proto message");
ABSL_DEPRECATED("Use Proto<>::FromMessage() instead")
explicit Proto(const T& proto)
: wrapped_var_(SerializeProto(proto).ValueOrDie()) {}
static sapi::StatusOr<Proto<T>> FromMessage(const T& proto) {
SAPI_ASSIGN_OR_RETURN(std::vector<uint8_t> len_val, SerializeProto(proto));
return Proto(len_val);
}
size_t GetSize() const final { return wrapped_var_.GetSize(); } size_t GetSize() const final { return wrapped_var_.GetSize(); }
Type GetType() const final { return Type::kProto; } Type GetType() const final { return Type::kProto; }
@ -46,18 +60,21 @@ class Proto : public Pointable, public Var {
void* GetLocal() const override { return wrapped_var_.GetLocal(); } void* GetLocal() const override { return wrapped_var_.GetLocal(); }
// Returns a copy of the stored protobuf object. // Returns a copy of the stored protobuf object.
std::unique_ptr<T> GetProtoCopy() const { sapi::StatusOr<T> GetMessage() const {
auto res = absl::make_unique<T>(); return DeserializeProto<T>(
if (!res || reinterpret_cast<const char*>(wrapped_var_.GetData()),
!DeserializeProto(res.get(), wrapped_var_.GetDataSize());
reinterpret_cast<const char*>(wrapped_var_.GetData()),
wrapped_var_.GetDataSize())) {
res.reset();
}
return res;
} }
void SetRemote(void* remote) override { ABSL_DEPRECATED("Use GetMessage() instead")
std::unique_ptr<T> GetProtoCopy() const {
if (auto result_or = GetMessage(); result_or.ok()) {
return absl::make_unique<T>(std::move(result_or).ValueOrDie());
}
return nullptr;
}
void SetRemote(void* /* remote */) override {
// We do not support that much indirection (pointer to a pointer to a // We do not support that much indirection (pointer to a pointer to a
// protobuf) as it is unlikely that this is wanted behavior. If you expect // protobuf) as it is unlikely that this is wanted behavior. If you expect
// this to work, please get in touch with us. // this to work, please get in touch with us.
@ -85,6 +102,8 @@ class Proto : public Pointable, public Var {
} }
private: private:
explicit Proto(std::vector<uint8_t> data) : wrapped_var_(data) {}
// The management of reading/writing the data to the sandboxee is handled by // The management of reading/writing the data to the sandboxee is handled by
// the LenVal class. // the LenVal class.
LenVal wrapped_var_; LenVal wrapped_var_;