mirror of
https://github.com/google/sandboxed-api.git
synced 2024-03-22 13:11:30 +08:00
Remove mutexes from Comms
It was never fully thread-safe. e.g. calling SendProtoBuf concurrently from 2 threads could result in a data race. Also not all users need the thread-safety thus it's better left off to be done externally by the ones that require it. PiperOrigin-RevId: 562548941 Change-Id: Ie32dfca366be9e0c32841e55b688907f4f5f7704
This commit is contained in:
parent
197f03aa5b
commit
3ea315858d
|
@ -843,7 +843,6 @@ cc_library(
|
|||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_protobuf//:protobuf",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -762,7 +762,6 @@ target_link_libraries(sandbox2_comms
|
|||
sapi::status_proto
|
||||
PUBLIC absl::core_headers
|
||||
absl::status
|
||||
absl::synchronization
|
||||
protobuf::libprotobuf
|
||||
sapi::status
|
||||
)
|
||||
|
|
|
@ -43,7 +43,6 @@
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "google/protobuf/message_lite.h"
|
||||
#include "sandboxed_api/sandbox2/util.h"
|
||||
#include "sandboxed_api/util/raw_logging.h"
|
||||
|
@ -250,7 +249,6 @@ bool Comms::SendTLV(uint32_t tag, size_t length, const void* value) {
|
|||
.len = length,
|
||||
};
|
||||
|
||||
absl::MutexLock lock(&tlv_send_transmission_mutex_);
|
||||
if (length + sizeof(tl) > kSendTLVTempBufferSize) {
|
||||
if (!Send(&tl, sizeof(tl))) {
|
||||
return false;
|
||||
|
@ -615,7 +613,6 @@ bool Comms::RecvTLV(uint32_t* tag, std::string* value) {
|
|||
|
||||
template <typename T>
|
||||
bool Comms::RecvTLVGeneric(uint32_t* tag, T* value) {
|
||||
absl::MutexLock lock(&tlv_recv_transmission_mutex_);
|
||||
size_t length;
|
||||
if (!RecvTL(tag, &length)) {
|
||||
return false;
|
||||
|
@ -627,8 +624,6 @@ bool Comms::RecvTLVGeneric(uint32_t* tag, T* value) {
|
|||
|
||||
bool Comms::RecvTLV(uint32_t* tag, size_t* length, void* buffer,
|
||||
size_t buffer_size) {
|
||||
absl::MutexLock lock(&tlv_recv_transmission_mutex_);
|
||||
|
||||
if (!RecvTL(tag, length)) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -36,11 +36,10 @@
|
|||
#include <vector>
|
||||
|
||||
#include "absl/base/attributes.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/log/die_if_null.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "google/protobuf/message_lite.h"
|
||||
|
||||
namespace proto2 {
|
||||
|
@ -227,11 +226,6 @@ class Comms {
|
|||
int connection_fd_ = -1;
|
||||
int bind_fd_ = -1;
|
||||
|
||||
// Mutex making sure that we serialize TLV messages (which consist out of
|
||||
// three different calls to send / receive).
|
||||
absl::Mutex tlv_send_transmission_mutex_;
|
||||
absl::Mutex tlv_recv_transmission_mutex_;
|
||||
|
||||
// State of the channel (enum), socket will have to be connected later on.
|
||||
State state_ = State::kUnconnected;
|
||||
|
||||
|
@ -257,8 +251,7 @@ class Comms {
|
|||
|
||||
// Receives tag and length. Assumes that the `tlv_transmission_mutex_` mutex
|
||||
// is locked.
|
||||
bool RecvTL(uint32_t* tag, size_t* length)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(tlv_recv_transmission_mutex_);
|
||||
bool RecvTL(uint32_t* tag, size_t* length);
|
||||
|
||||
// T has to be a ContiguousContainer
|
||||
template <typename T>
|
||||
|
|
|
@ -419,98 +419,6 @@ TEST_P(CommsTest, TestSendRecvBytes) {
|
|||
HandleCommunication(sockname_, use_abstract_socket_, a, b);
|
||||
}
|
||||
|
||||
class SenderThread {
|
||||
public:
|
||||
SenderThread(Comms* comms, size_t rounds) : comms_(comms), rounds_(rounds) {}
|
||||
void operator()() {
|
||||
for (size_t i = 0; i < rounds_; i++) {
|
||||
ASSERT_THAT(
|
||||
comms_->SendBytes(reinterpret_cast<const uint8_t*>("Test"), 4),
|
||||
IsTrue());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Comms* comms_;
|
||||
size_t rounds_;
|
||||
};
|
||||
|
||||
class ReceiverThread {
|
||||
public:
|
||||
ReceiverThread(Comms* comms, size_t rounds)
|
||||
: comms_(comms), rounds_(rounds) {}
|
||||
void operator()() {
|
||||
for (size_t i = 0; i < rounds_; i++) {
|
||||
std::vector<uint8_t> buffer;
|
||||
EXPECT_THAT(comms_->RecvBytes(&buffer), IsTrue());
|
||||
EXPECT_THAT(buffer.size(), Eq(4));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Comms* comms_;
|
||||
size_t rounds_;
|
||||
};
|
||||
|
||||
TEST_P(CommsTest, TestMultipleThreads) {
|
||||
// The comms object should be thread safe, this testcase covers this.
|
||||
constexpr size_t kNumThreads = 20;
|
||||
constexpr size_t kNumRoundsPerThread = 50;
|
||||
constexpr size_t kNumRounds = kNumThreads * kNumRoundsPerThread;
|
||||
Comms c(sockname_);
|
||||
c.Listen();
|
||||
|
||||
// Start the client thread.
|
||||
std::string socketname = sockname_;
|
||||
std::thread ct([&socketname]() {
|
||||
Comms comms(socketname);
|
||||
CHECK(comms.Connect());
|
||||
std::vector<uint8_t> buffer;
|
||||
|
||||
// Receive N_ROUND times. We keep the local buffer and send it back
|
||||
// later to increase our A/MSAN coverage.
|
||||
for (size_t i = 0; i < kNumRounds; i++) {
|
||||
ASSERT_THAT(comms.RecvBytes(&buffer), IsTrue());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kNumRounds; i++) {
|
||||
ASSERT_THAT(comms.SendBytes(buffer), IsTrue());
|
||||
}
|
||||
});
|
||||
|
||||
// Accept connection.
|
||||
ASSERT_THAT(c.Accept(), IsTrue());
|
||||
|
||||
// Start sender threads.
|
||||
{
|
||||
std::thread sender_threads[kNumThreads];
|
||||
for (size_t i = 0; i < kNumThreads; i++) {
|
||||
sender_threads[i] = std::thread(SenderThread(&c, kNumRoundsPerThread));
|
||||
}
|
||||
|
||||
// Join threads.
|
||||
for (size_t i = 0; i < kNumThreads; i++) {
|
||||
sender_threads[i].join();
|
||||
}
|
||||
}
|
||||
|
||||
// Start receiver threads.
|
||||
{
|
||||
std::thread receiver_threads[kNumThreads];
|
||||
for (size_t i = 0; i < kNumThreads; i++) {
|
||||
receiver_threads[i] =
|
||||
std::thread(ReceiverThread(&c, kNumRoundsPerThread));
|
||||
}
|
||||
|
||||
// Join threads.
|
||||
for (size_t i = 0; i < kNumThreads; i++) {
|
||||
receiver_threads[i].join();
|
||||
}
|
||||
}
|
||||
|
||||
ct.join();
|
||||
}
|
||||
|
||||
// We cannot test this in the Client or Server tests, as the endpoint needs to
|
||||
// be unconnected.
|
||||
TEST_P(CommsTest, TestMsgSize) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user