diff --git a/sandboxed_api/sandbox2/comms.cc b/sandboxed_api/sandbox2/comms.cc index a4845c8..5195a9d 100644 --- a/sandboxed_api/sandbox2/comms.cc +++ b/sandboxed_api/sandbox2/comms.cc @@ -265,17 +265,16 @@ bool Comms::SendTLV(uint32_t tag, uint64_t length, const uint8_t* bytes) { } bool Comms::RecvString(std::string* v) { - TLV tlv; - if (!RecvTLV(&tlv)) { + uint32_t tag; + if (!RecvTLV(&tag, v)) { return false; } - if (tlv.tag != kTagString) { + if (tag != kTagString) { SAPI_RAW_LOG(ERROR, "Expected (kTagString == 0x%x), got: 0x%x", kTagString, - tlv.tag); + tag); return false; } - v->assign(reinterpret_cast(tlv.value.data()), tlv.value.size()); return true; } @@ -285,16 +284,16 @@ bool Comms::SendString(const std::string& v) { } bool Comms::RecvBytes(std::vector* buffer) { - TLV tlv; - if (!RecvTLV(&tlv)) { + uint32_t tag; + if (!RecvTLV(&tag, buffer)) { return false; } - if (tlv.tag != kTagBytes) { + if (tag != kTagBytes) { + buffer->clear(); SAPI_RAW_LOG(ERROR, "Expected (kTagBytes == 0x%x), got: 0x%u", kTagBytes, - tlv.tag); + tag); return false; } - buffer->swap(tlv.value); return true; } @@ -462,8 +461,9 @@ bool Comms::SendFD(int fd) { } bool Comms::RecvProtoBuf(google::protobuf::Message* message) { - TLV tlv; - if (!RecvTLV(&tlv)) { + uint32_t tag; + std::vector bytes; + if (!RecvTLV(&tag, &bytes)) { if (IsConnected()) { SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", socket_name_); } else { @@ -473,11 +473,11 @@ bool Comms::RecvProtoBuf(google::protobuf::Message* message) { return false; } - if (tlv.tag != kTagProto2) { - SAPI_RAW_LOG(ERROR, "Expected tag: 0x%x, got: 0x%u", kTagProto2, tlv.tag); + if (tag != kTagProto2) { + SAPI_RAW_LOG(ERROR, "Expected tag: 0x%x, got: 0x%u", kTagProto2, tag); return false; } - return message->ParseFromArray(tlv.value.data(), tlv.value.size()); + return message->ParseFromArray(bytes.data(), bytes.size()); } bool Comms::SendProtoBuf(const google::protobuf::Message& message) { @@ -599,18 +599,16 @@ bool Comms::RecvTL(uint32_t* tag, uint64_t* length) { return true; } -bool Comms::RecvTLV(TLV* tlv) { - absl::MutexLock lock(&tlv_recv_transmission_mutex_); - uint64_t length; - if (!RecvTL(&tlv->tag, &length)) { - return false; - } - - tlv->value.resize(length); - return length == 0 || Recv(tlv->value.data(), length); +bool Comms::RecvTLV(uint32_t* tag, std::vector* value) { + return RecvTLVGeneric(tag, value); } -bool Comms::RecvTLV(uint32_t* tag, std::vector* value) { +bool Comms::RecvTLV(uint32_t* tag, std::string* value) { + return RecvTLVGeneric(tag, value); +} + +template +bool Comms::RecvTLVGeneric(uint32_t* tag, T* value) { absl::MutexLock lock(&tlv_recv_transmission_mutex_); uint64_t length; if (!RecvTL(tag, &length)) { @@ -618,7 +616,7 @@ bool Comms::RecvTLV(uint32_t* tag, std::vector* value) { } value->resize(length); - return length == 0 || Recv(value->data(), length); + return length == 0 || Recv(reinterpret_cast(value->data()), length); } bool Comms::RecvTLV(uint32_t* tag, uint64_t* length, void* buffer, diff --git a/sandboxed_api/sandbox2/comms.h b/sandboxed_api/sandbox2/comms.h index c4ab84e..7863288 100644 --- a/sandboxed_api/sandbox2/comms.h +++ b/sandboxed_api/sandbox2/comms.h @@ -109,6 +109,9 @@ class Comms { // Receive a TLV structure, the memory for the value will be allocated // by std::vector. bool RecvTLV(uint32_t* tag, std::vector* value); + // Receive a TLV structure, the memory for the value will be allocated + // by std::string. + bool RecvTLV(uint32_t* tag, std::string* value); // Receives a TLV value into a specified buffer without allocating memory. bool RecvTLV(uint32_t* tag, uint64_t* length, void* buffer, uint64_t buffer_size); @@ -174,12 +177,6 @@ class Comms { // State of the channel (enum), socket will have to be connected later on. State state_ = State::kUnconnected; - // TLV structure used to pass messages around. - struct TLV { - uint32_t tag; - std::vector value; - }; - // Special struct for passing credentials or FDs. Different from the one above // as it inlines the value. This is important as the data is transmitted using // sendmsg/recvmsg instead of send/recv. @@ -201,8 +198,9 @@ class Comms { bool RecvTL(uint32_t* tag, uint64_t* length) ABSL_EXCLUSIVE_LOCKS_REQUIRED(tlv_recv_transmission_mutex_); - // Receives whole TLV structure, allocates memory for the data. - bool RecvTLV(TLV* tlv); + // T has to be a ContiguousContainer + template + bool RecvTLVGeneric(uint32_t* tag, T* value); // Receives arbitrary integers. bool RecvInt(void* buffer, uint64_t len, uint32_t tag);