Refactor the tests and strings example a bit

PiperOrigin-RevId: 268865491
Change-Id: Ie16e5f17e2eb22e25821c34edf0068cb81bcc2fe
This commit is contained in:
Christian Blichmann 2019-09-13 02:28:09 -07:00 committed by Copybara-Service
parent d6ca9d9564
commit 0aa7183502
8 changed files with 133 additions and 114 deletions

View File

@ -28,7 +28,6 @@ struct ReallocRequest {
};
// Types of TAGs used with Comms channel.
// TODO(cblichmann): Mark these as "inline" once we're on C++17.
// Call:
constexpr uint32_t kMsgCall = 0x101;
constexpr uint32_t kMsgAllocate = 0x102;

View File

@ -45,7 +45,7 @@ namespace sapi {
namespace {
// Guess the FFI type on the basis of data size and float/non-float/bool.
static ffi_type* GetFFIType(size_t size, v::Type type) {
ffi_type* GetFFIType(size_t size, v::Type type) {
switch (type) {
case v::Type::kVoid:
return &ffi_type_void;
@ -94,11 +94,11 @@ class FunctionCallPreparer {
explicit FunctionCallPreparer(const FuncCall& call) {
CHECK(call.argc <= FuncCall::kArgsMax)
<< "Number of arguments of a sandbox call exceeds limits.";
for (size_t i = 0; i < call.argc; i++) {
for (int i = 0; i < call.argc; ++i) {
arg_types_[i] = GetFFIType(call.arg_size[i], call.arg_type[i]);
}
ret_type_ = GetFFIType(call.ret_size, call.ret_type);
for (size_t i = 0; i < call.argc; i++) {
for (int i = 0; i < call.argc; ++i) {
if (call.arg_type[i] == v::Type::kPointer &&
call.aux_type[i] == v::Type::kProto) {
// Deserialize protobuf stored in the LenValueStruct and keep a
@ -107,16 +107,13 @@ class FunctionCallPreparer {
// This will also make sure that the protobuf is freed afterwards.
arg_values_[i] = GetDeserializedProto(
reinterpret_cast<LenValStruct*>(call.args[i].arg_int));
} else {
if (call.arg_type[i] == v::Type::kFloat) {
arg_values_[i] =
reinterpret_cast<const void*>(&call.args[i].arg_float);
} else if (call.arg_type[i] == v::Type::kFloat) {
arg_values_[i] = reinterpret_cast<const void*>(&call.args[i].arg_float);
} else {
arg_values_[i] = reinterpret_cast<const void*>(&call.args[i].arg_int);
}
}
}
}
~FunctionCallPreparer() {
for (const auto& idx_proto : protos_to_be_destroyed_) {
@ -125,8 +122,8 @@ class FunctionCallPreparer {
// There is no way to figure out whether the protobuf structure has
// changed or not, so we always serialize the protobuf again and replace
// the LenValStruct content.
std::vector<uint8_t> serialized = SerializeProto(*proto);
// Reallocate the LV memory to match it's length.
std::vector<uint8_t> serialized = SerializeProto(*proto).ValueOrDie();
// Reallocate the LV memory to match its length.
if (lvs->size != serialized.size()) {
void* newdata = realloc(lvs->data, serialized.size());
if (!newdata) {

View File

@ -30,25 +30,32 @@
#include "sandboxed_api/vars.h"
#include "sandboxed_api/util/canonical_errors.h"
#include "sandboxed_api/util/status.h"
#include "sandboxed_api/util/status_macros.h"
using ::sapi::IsOk;
using ::testing::Eq;
using ::testing::Ne;
using ::testing::SizeIs;
using ::testing::StrEq;
namespace {
// Tests using the simple transaction (and function pointers):
// Tests using a simple transaction (and function pointers):
TEST(StringopTest, ProtobufStringDuplication) {
sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>());
EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status {
StringopApi f(sandbox);
StringopApi api(sandbox);
stringop::StringDuplication proto;
proto.set_input("Hello");
sapi::v::Proto<stringop::StringDuplication> pp(proto);
SAPI_ASSIGN_OR_RETURN(int v, f.pb_duplicate_string(pp.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(v, "pb_duplicate_string failed");
auto pb_result = pp.GetProtoCopy();
TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result");
LOG(INFO) << "Result PB: " << pb_result->DebugString();
TRANSACTION_FAIL_IF_NOT(pb_result->output() == "HelloHello",
{
SAPI_ASSIGN_OR_RETURN(int return_value, api.pb_duplicate_string(pp.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(return_value, "pb_duplicate_string() failed");
}
SAPI_ASSIGN_OR_RETURN(auto pb_result, pp.GetMessage());
LOG(INFO) << "Result PB: " << pb_result.DebugString();
TRANSACTION_FAIL_IF_NOT(pb_result.output() == "HelloHello",
"Incorrect output");
return sapi::OkStatus();
}),
@ -56,84 +63,71 @@ TEST(StringopTest, ProtobufStringDuplication) {
}
TEST(StringopTest, ProtobufStringReversal) {
sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>());
EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status {
StringopApi f(sandbox);
StringopSapiSandbox sandbox;
ASSERT_THAT(sandbox.Init(), IsOk());
StringopApi api(&sandbox);
stringop::StringReverse proto;
proto.set_input("Hello");
sapi::v::Proto<stringop::StringReverse> pp(proto);
SAPI_ASSIGN_OR_RETURN(int v, f.pb_reverse_string(pp.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(v, "pb_reverse_string failed");
auto pb_result = pp.GetProtoCopy();
TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result");
LOG(INFO) << "Result PB: " << pb_result->DebugString();
TRANSACTION_FAIL_IF_NOT(pb_result->output() == "olleH", "Incorrect output");
return sapi::OkStatus();
}),
IsOk());
SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.pb_reverse_string(pp.PtrBoth()));
EXPECT_THAT(return_value, Ne(0)) << "pb_reverse_string() failed";
SAPI_ASSERT_OK_AND_ASSIGN(auto pb_result, pp.GetMessage());
LOG(INFO) << "Result PB: " << pb_result.DebugString();
EXPECT_THAT(pb_result.output(), StrEq("olleH"));
}
// Tests using raw dynamic buffers.
TEST(StringopTest, RawStringDuplication) {
sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>());
EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status {
StringopApi f(sandbox);
StringopSapiSandbox sandbox;
ASSERT_THAT(sandbox.Init(), IsOk());
StringopApi api(&sandbox);
sapi::v::LenVal param("0123456789", 10);
SAPI_ASSIGN_OR_RETURN(int return_value, f.duplicate_string(param.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(return_value == 1,
"duplicate_string() returned incorrect value");
TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 20,
"duplicate_string() did not return enough data");
SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.duplicate_string(param.PtrBoth()));
EXPECT_THAT(return_value, Eq(1)) << "duplicate_string() failed";
absl::string_view data(reinterpret_cast<const char*>(param.GetData()),
param.GetDataSize());
TRANSACTION_FAIL_IF_NOT(
data == "01234567890123456789",
"duplicate_string() did not return the expected data");
return sapi::OkStatus();
}),
IsOk());
EXPECT_THAT(data, SizeIs(20))
<< "duplicate_string() did not return enough data";
EXPECT_THAT(std::string(data), StrEq("01234567890123456789"));
}
TEST(StringopTest, RawStringReversal) {
sapi::BasicTransaction st(absl::make_unique<StringopSapiSandbox>());
EXPECT_THAT(st.Run([](sapi::Sandbox* sandbox) -> sapi::Status {
StringopApi f(sandbox);
StringopSapiSandbox sandbox;
ASSERT_THAT(sandbox.Init(), IsOk());
StringopApi api(&sandbox);
sapi::v::LenVal param("0123456789", 10);
{
SAPI_ASSIGN_OR_RETURN(int return_value, f.reverse_string(param.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(return_value == 1,
"reverse_string() returned incorrect value");
TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 10,
"reverse_string() did not return enough data");
SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.reverse_string(param.PtrBoth()));
EXPECT_THAT(return_value, Eq(1))
<< "reverse_string() returned incorrect value";
absl::string_view data(reinterpret_cast<const char*>(param.GetData()),
param.GetDataSize());
TRANSACTION_FAIL_IF_NOT(
data == "9876543210",
"reverse_string() did not return the expected data");
EXPECT_THAT(param.GetDataSize(), Eq(10))
<< "reverse_string() did not return enough data";
EXPECT_THAT(std::string(data), StrEq("9876543210"))
<< "reverse_string() did not return the expected data";
}
{
// Let's call it again with different data as argument, reusing the
// existing LenVal object.
SAPI_RETURN_IF_ERROR(param.ResizeData(sandbox->GetRpcChannel(), 16));
EXPECT_THAT(param.ResizeData(sandbox.GetRpcChannel(), 16), IsOk());
memcpy(param.GetData() + 10, "ABCDEF", 6);
TRANSACTION_FAIL_IF_NOT(param.GetDataSize() == 16,
"Resize did not behave correctly");
absl::string_view data(reinterpret_cast<const char*>(param.GetData()),
param.GetDataSize());
TRANSACTION_FAIL_IF_NOT(data == "9876543210ABCDEF",
"Data not as expected");
SAPI_ASSIGN_OR_RETURN(int return_value, f.reverse_string(param.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(return_value == 1,
"reverse_string() returned incorrect value");
EXPECT_THAT(data, SizeIs(16)) << "Resize did not behave correctly";
EXPECT_THAT(std::string(data), StrEq("9876543210ABCDEF"));
SAPI_ASSERT_OK_AND_ASSIGN(int return_value, api.reverse_string(param.PtrBoth()));
EXPECT_THAT(return_value, Eq(1))
<< "reverse_string() returned incorrect value";
data = absl::string_view(reinterpret_cast<const char*>(param.GetData()),
param.GetDataSize());
TRANSACTION_FAIL_IF_NOT(
data == "FEDCBA0123456789",
"reverse_string() did not return the expected data");
EXPECT_THAT(std::string(data), StrEq("FEDCBA0123456789"));
}
return sapi::OkStatus();
}),
IsOk());
}
} // namespace

View File

@ -18,40 +18,49 @@
#define SANDBOXED_API_PROTO_HELPER_H_
#include <cinttypes>
#include <type_traits>
#include <vector>
#include <glog/logging.h>
#include "sandboxed_api/proto_arg.pb.h"
#include "sandboxed_api/util/canonical_errors.h"
#include "sandboxed_api/util/status.h"
#include "sandboxed_api/util/statusor.h"
namespace sapi {
template <typename T>
std::vector<uint8_t> SerializeProto(const T& proto) {
sapi::StatusOr<std::vector<uint8_t>> SerializeProto(const T& proto) {
static_assert(std::is_base_of<google::protobuf::Message, T>::value,
"Template argument must be a proto message");
// Wrap protobuf in a envelope so that we know the name of the protobuf
// structure when deserializing in the sandboxee.
ProtoArg proto_arg;
proto_arg.set_protobuf_data(proto.SerializeAsString());
proto_arg.set_full_name(proto.GetDescriptor()->full_name());
std::vector<uint8_t> serialized_proto(proto_arg.ByteSizeLong());
std::vector<uint8_t> serialized_proto(proto_arg.ByteSizeLong());
if (!proto_arg.SerializeToArray(serialized_proto.data(),
serialized_proto.size())) {
LOG(ERROR) << "Unable to serialize array";
return sapi::InternalError("Unable to serialize proto to array");
}
return serialized_proto;
}
template <typename T>
bool DeserializeProto(T* result, const char* data, size_t len) {
sapi::StatusOr<T> DeserializeProto(const char* data, size_t len) {
static_assert(std::is_base_of<google::protobuf::Message, T>::value,
"Template argument must be a proto message");
ProtoArg envelope;
if (!envelope.ParseFromArray(data, len)) {
LOG(ERROR) << "Unable to deserialize envelope";
return false;
return sapi::InternalError("Unable to parse proto from array");
}
auto pb_data = envelope.protobuf_data();
return result->ParseFromArray(pb_data.c_str(), pb_data.size());
T result;
if (!result.ParseFromArray(pb_data.data(), pb_data.size())) {
return sapi::InternalError("Unable to parse proto from envelope data");
}
return result;
}
} // namespace sapi

View File

@ -53,9 +53,8 @@ sapi::Status InvokeStringReversal(Sandbox* sandbox) {
v::Proto<stringop::StringReverse> pp(proto);
SAPI_ASSIGN_OR_RETURN(int return_code, api.pb_reverse_string(pp.PtrBoth()));
TRANSACTION_FAIL_IF_NOT(return_code != 0, "pb_reverse_string failed");
std::unique_ptr<stringop::StringReverse> pb_result = pp.GetProtoCopy();
TRANSACTION_FAIL_IF_NOT(pb_result, "Could not deserialize pb result");
TRANSACTION_FAIL_IF_NOT(pb_result->output() == "olleH", "Incorrect output");
SAPI_ASSIGN_OR_RETURN(auto pb_result, pp.GetMessage());
TRANSACTION_FAIL_IF_NOT(pb_result.output() == "olleH", "Incorrect output");
return sapi::OkStatus();
}

View File

@ -149,6 +149,7 @@ class Transaction : public TransactionBase {
// Callback style transactions:
class BasicTransaction final : public TransactionBase {
private:
using InitFunction = std::function<sapi::Status(Sandbox*)>;
using FinishFunction = std::function<sapi::Status(Sandbox*)>;

View File

@ -23,6 +23,7 @@
#include <string>
#include <type_traits>
#include "absl/base/attributes.h"
#include "absl/meta/type_traits.h"
#include "absl/strings/string_view.h"
#include "sandboxed_api/util/status.pb.h"

View File

@ -18,12 +18,16 @@
#define SANDBOXED_API_VAR_PROTO_H_
#include <cinttypes>
#include <cstdint>
#include <vector>
#include "absl/base/macros.h"
#include "absl/memory/memory.h"
#include "sandboxed_api/proto_helper.h"
#include "sandboxed_api/var_lenval.h"
#include "sandboxed_api/var_pointable.h"
#include "sandboxed_api/var_ptr.h"
#include "sandboxed_api/util/status_macros.h"
namespace sapi {
namespace v {
@ -31,7 +35,17 @@ namespace v {
template <typename T>
class Proto : public Pointable, public Var {
public:
explicit Proto(const T& proto) : wrapped_var_(SerializeProto(proto)) {}
static_assert(std::is_base_of<google::protobuf::Message, T>::value,
"Template argument must be a proto message");
ABSL_DEPRECATED("Use Proto<>::FromMessage() instead")
explicit Proto(const T& proto)
: wrapped_var_(SerializeProto(proto).ValueOrDie()) {}
static sapi::StatusOr<Proto<T>> FromMessage(const T& proto) {
SAPI_ASSIGN_OR_RETURN(std::vector<uint8_t> len_val, SerializeProto(proto));
return Proto(len_val);
}
size_t GetSize() const final { return wrapped_var_.GetSize(); }
Type GetType() const final { return Type::kProto; }
@ -46,18 +60,21 @@ class Proto : public Pointable, public Var {
void* GetLocal() const override { return wrapped_var_.GetLocal(); }
// Returns a copy of the stored protobuf object.
std::unique_ptr<T> GetProtoCopy() const {
auto res = absl::make_unique<T>();
if (!res ||
!DeserializeProto(res.get(),
sapi::StatusOr<T> GetMessage() const {
return DeserializeProto<T>(
reinterpret_cast<const char*>(wrapped_var_.GetData()),
wrapped_var_.GetDataSize())) {
res.reset();
}
return res;
wrapped_var_.GetDataSize());
}
void SetRemote(void* remote) override {
ABSL_DEPRECATED("Use GetMessage() instead")
std::unique_ptr<T> GetProtoCopy() const {
if (auto result_or = GetMessage(); result_or.ok()) {
return absl::make_unique<T>(std::move(result_or).ValueOrDie());
}
return nullptr;
}
void SetRemote(void* /* remote */) override {
// We do not support that much indirection (pointer to a pointer to a
// protobuf) as it is unlikely that this is wanted behavior. If you expect
// this to work, please get in touch with us.
@ -85,6 +102,8 @@ class Proto : public Pointable, public Var {
}
private:
explicit Proto(std::vector<uint8_t> data) : wrapped_var_(data) {}
// The management of reading/writing the data to the sandboxee is handled by
// the LenVal class.
LenVal wrapped_var_;