Remove mutexes from Comms

It was never fully thread-safe.
e.g. calling SendProtoBuf concurrently from 2 threads
could result in a data race.
Also not all users need the thread-safety thus it's better left off to be done externally by the ones that require it.

PiperOrigin-RevId: 562548941
Change-Id: Ie32dfca366be9e0c32841e55b688907f4f5f7704
This commit is contained in:
Wiktor Garbacz 2023-09-04 07:00:19 -07:00 committed by Copybara-Service
parent 197f03aa5b
commit 3ea315858d
5 changed files with 2 additions and 108 deletions

View File

@ -843,7 +843,6 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_protobuf//:protobuf",
],
)

View File

@ -762,7 +762,6 @@ target_link_libraries(sandbox2_comms
sapi::status_proto
PUBLIC absl::core_headers
absl::status
absl::synchronization
protobuf::libprotobuf
sapi::status
)

View File

@ -43,7 +43,6 @@
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "google/protobuf/message_lite.h"
#include "sandboxed_api/sandbox2/util.h"
#include "sandboxed_api/util/raw_logging.h"
@ -250,7 +249,6 @@ bool Comms::SendTLV(uint32_t tag, size_t length, const void* value) {
.len = length,
};
absl::MutexLock lock(&tlv_send_transmission_mutex_);
if (length + sizeof(tl) > kSendTLVTempBufferSize) {
if (!Send(&tl, sizeof(tl))) {
return false;
@ -615,7 +613,6 @@ bool Comms::RecvTLV(uint32_t* tag, std::string* value) {
template <typename T>
bool Comms::RecvTLVGeneric(uint32_t* tag, T* value) {
absl::MutexLock lock(&tlv_recv_transmission_mutex_);
size_t length;
if (!RecvTL(tag, &length)) {
return false;
@ -627,8 +624,6 @@ bool Comms::RecvTLVGeneric(uint32_t* tag, T* value) {
bool Comms::RecvTLV(uint32_t* tag, size_t* length, void* buffer,
size_t buffer_size) {
absl::MutexLock lock(&tlv_recv_transmission_mutex_);
if (!RecvTL(tag, length)) {
return false;
}

View File

@ -36,11 +36,10 @@
#include <vector>
#include "absl/base/attributes.h"
#include "absl/base/thread_annotations.h"
#include "absl/log/die_if_null.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/message_lite.h"
namespace proto2 {
@ -227,11 +226,6 @@ class Comms {
int connection_fd_ = -1;
int bind_fd_ = -1;
// Mutex making sure that we serialize TLV messages (which consist out of
// three different calls to send / receive).
absl::Mutex tlv_send_transmission_mutex_;
absl::Mutex tlv_recv_transmission_mutex_;
// State of the channel (enum), socket will have to be connected later on.
State state_ = State::kUnconnected;
@ -257,8 +251,7 @@ class Comms {
// Receives tag and length. Assumes that the `tlv_transmission_mutex_` mutex
// is locked.
bool RecvTL(uint32_t* tag, size_t* length)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(tlv_recv_transmission_mutex_);
bool RecvTL(uint32_t* tag, size_t* length);
// T has to be a ContiguousContainer
template <typename T>

View File

@ -419,98 +419,6 @@ TEST_P(CommsTest, TestSendRecvBytes) {
HandleCommunication(sockname_, use_abstract_socket_, a, b);
}
class SenderThread {
public:
SenderThread(Comms* comms, size_t rounds) : comms_(comms), rounds_(rounds) {}
void operator()() {
for (size_t i = 0; i < rounds_; i++) {
ASSERT_THAT(
comms_->SendBytes(reinterpret_cast<const uint8_t*>("Test"), 4),
IsTrue());
}
}
private:
Comms* comms_;
size_t rounds_;
};
class ReceiverThread {
public:
ReceiverThread(Comms* comms, size_t rounds)
: comms_(comms), rounds_(rounds) {}
void operator()() {
for (size_t i = 0; i < rounds_; i++) {
std::vector<uint8_t> buffer;
EXPECT_THAT(comms_->RecvBytes(&buffer), IsTrue());
EXPECT_THAT(buffer.size(), Eq(4));
}
}
private:
Comms* comms_;
size_t rounds_;
};
TEST_P(CommsTest, TestMultipleThreads) {
// The comms object should be thread safe, this testcase covers this.
constexpr size_t kNumThreads = 20;
constexpr size_t kNumRoundsPerThread = 50;
constexpr size_t kNumRounds = kNumThreads * kNumRoundsPerThread;
Comms c(sockname_);
c.Listen();
// Start the client thread.
std::string socketname = sockname_;
std::thread ct([&socketname]() {
Comms comms(socketname);
CHECK(comms.Connect());
std::vector<uint8_t> buffer;
// Receive N_ROUND times. We keep the local buffer and send it back
// later to increase our A/MSAN coverage.
for (size_t i = 0; i < kNumRounds; i++) {
ASSERT_THAT(comms.RecvBytes(&buffer), IsTrue());
}
for (size_t i = 0; i < kNumRounds; i++) {
ASSERT_THAT(comms.SendBytes(buffer), IsTrue());
}
});
// Accept connection.
ASSERT_THAT(c.Accept(), IsTrue());
// Start sender threads.
{
std::thread sender_threads[kNumThreads];
for (size_t i = 0; i < kNumThreads; i++) {
sender_threads[i] = std::thread(SenderThread(&c, kNumRoundsPerThread));
}
// Join threads.
for (size_t i = 0; i < kNumThreads; i++) {
sender_threads[i].join();
}
}
// Start receiver threads.
{
std::thread receiver_threads[kNumThreads];
for (size_t i = 0; i < kNumThreads; i++) {
receiver_threads[i] =
std::thread(ReceiverThread(&c, kNumRoundsPerThread));
}
// Join threads.
for (size_t i = 0; i < kNumThreads; i++) {
receiver_threads[i].join();
}
}
ct.join();
}
// We cannot test this in the Client or Server tests, as the endpoint needs to
// be unconnected.
TEST_P(CommsTest, TestMsgSize) {