Replace use of deprecated sandbox2::Comms functions

PiperOrigin-RevId: 566863078
Change-Id: Ida96eb8046ff96bdd41cec4a1427073ae43930d9
This commit is contained in:
Wiktor Garbacz 2023-09-19 23:54:18 -07:00 committed by Copybara-Service
parent 227daf4a42
commit 9a985f91a7

View File

@ -20,8 +20,8 @@
#include <sys/types.h> #include <sys/types.h>
#include <unistd.h> #include <unistd.h>
#include <cstdint>
#include <cstring> #include <cstring>
#include <ctime>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
@ -34,7 +34,6 @@
#include "absl/log/check.h" #include "absl/log/check.h"
#include "absl/log/log.h" #include "absl/log/log.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "sandboxed_api/sandbox2/comms_test.pb.h" #include "sandboxed_api/sandbox2/comms_test.pb.h"
#include "sandboxed_api/util/status_matchers.h" #include "sandboxed_api/util/status_matchers.h"
@ -49,47 +48,6 @@ namespace sandbox2 {
using CommunicationHandler = std::function<void(Comms* comms)>; using CommunicationHandler = std::function<void(Comms* comms)>;
class CommsTest : public ::testing::TestWithParam<bool> {
void SetUp() override {
use_abstract_socket_ = GetParam();
timespec ts1, ts2;
CHECK_NE(clock_gettime(CLOCK_REALTIME, &ts1), -1);
CHECK_NE(clock_gettime(CLOCK_REALTIME, &ts2), -1);
// If the test does not use an abstract socket, create the socket in the
// '/tmp' directory. The reason to put it in tmp is that we want to
// guarantee that the sockname_ does not go over the limit of 108 (107 char
// + '/0').
if (!use_abstract_socket_) {
sockname_ = "/tmp/";
}
absl::StrAppend(&sockname_, "comms-test-",
static_cast<uint32_t>(ts1.tv_sec), "-",
static_cast<uint32_t>(ts1.tv_nsec));
absl::StrAppend(&sockname_, static_cast<uint32_t>(ts2.tv_sec), "-",
static_cast<uint32_t>(ts2.tv_nsec));
LOG(INFO) << "Sockname: " << sockname_;
CHECK_LT(sockname_.size(), 108);
// Comms channel using a descriptor (initialized with a file descriptor).
int sv[2];
CHECK_NE(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), -1);
fd_server_ = sv[0];
fd_client_ = sv[1];
LOG(INFO) << "FD(client): " << fd_client_ << ", FD(server): " << fd_server_;
}
void TearDown() override {
close(fd_server_);
close(fd_client_);
}
protected:
std::string sockname_ = "";
bool use_abstract_socket_ = true;
int fd_client_;
int fd_server_;
};
constexpr char kProtoStr[] = "ABCD"; constexpr char kProtoStr[] = "ABCD";
static absl::string_view NullTestString() { static absl::string_view NullTestString() {
static constexpr char kHelperStr[] = "test\0\n\r\t\x01\x02"; static constexpr char kHelperStr[] = "test\0\n\r\t\x01\x02";
@ -98,30 +56,24 @@ static absl::string_view NullTestString() {
// Helper function that handles the communication between the two handler // Helper function that handles the communication between the two handler
// functions. // functions.
void HandleCommunication(const std::string& socketname, void HandleCommunication(const CommunicationHandler& a,
bool use_abstract_socket,
const CommunicationHandler& a,
const CommunicationHandler& b) { const CommunicationHandler& b) {
Comms comms(socketname, use_abstract_socket); int sv[2];
comms.Listen(); CHECK_NE(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), -1);
Comms comms(sv[0]);
// Start handler a. // Start handler a.
std::thread remote([&socketname, &a, use_abstract_socket]() { std::thread remote([sv, &a]() {
Comms my_comms(socketname, use_abstract_socket); Comms my_comms(sv[1]);
CHECK(my_comms.Connect());
a(&my_comms); a(&my_comms);
}); });
// Accept connection and run handler b. // Accept connection and run handler b.
CHECK(comms.Accept());
b(&comms); b(&comms);
remote.join(); remote.join();
} }
INSTANTIATE_TEST_SUITE_P(Comms, CommsTest, ::testing::Bool(), TEST(CommsTest, TestSendRecv8) {
::testing::PrintToStringParamName());
TEST_P(CommsTest, TestSendRecv8) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Send Uint8. // Send Uint8.
ASSERT_THAT(comms->SendUint8(192), IsTrue()); ASSERT_THAT(comms->SendUint8(192), IsTrue());
@ -140,10 +92,10 @@ TEST_P(CommsTest, TestSendRecv8) {
// Send Int8. // Send Int8.
ASSERT_THAT(comms->SendInt8(-7), IsTrue()); ASSERT_THAT(comms->SendInt8(-7), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecv16) { TEST(CommsTest, TestSendRecv16) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Send Uint16. // Send Uint16.
ASSERT_THAT(comms->SendUint16(40001), IsTrue()); ASSERT_THAT(comms->SendUint16(40001), IsTrue());
@ -162,10 +114,10 @@ TEST_P(CommsTest, TestSendRecv16) {
// Send Int16. // Send Int16.
ASSERT_THAT(comms->SendInt16(-22050), IsTrue()); ASSERT_THAT(comms->SendInt16(-22050), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecv32) { TEST(CommsTest, TestSendRecv32) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// SendUint32. // SendUint32.
ASSERT_THAT(comms->SendUint32(3221225472UL), IsTrue()); ASSERT_THAT(comms->SendUint32(3221225472UL), IsTrue());
@ -184,10 +136,10 @@ TEST_P(CommsTest, TestSendRecv32) {
// Send Int32. // Send Int32.
ASSERT_THAT(comms->SendInt32(-1073741824), IsTrue()); ASSERT_THAT(comms->SendInt32(-1073741824), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecv64) { TEST(CommsTest, TestSendRecv64) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// SendUint64. // SendUint64.
ASSERT_THAT(comms->SendUint64(1099511627776ULL), IsTrue()); ASSERT_THAT(comms->SendUint64(1099511627776ULL), IsTrue());
@ -206,10 +158,10 @@ TEST_P(CommsTest, TestSendRecv64) {
// Send Int64. // Send Int64.
ASSERT_THAT(comms->SendInt64(-1099511627776LL), IsTrue()); ASSERT_THAT(comms->SendInt64(-1099511627776LL), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestTypeMismatch) { TEST(CommsTest, TestTypeMismatch) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
uint8_t tmpu8; uint8_t tmpu8;
// Receive Int8 (but Uint8 expected). // Receive Int8 (but Uint8 expected).
@ -219,10 +171,10 @@ TEST_P(CommsTest, TestTypeMismatch) {
// Send Int8 (but Uint8 expected). // Send Int8 (but Uint8 expected).
ASSERT_THAT(comms->SendInt8(-93), IsTrue()); ASSERT_THAT(comms->SendInt8(-93), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvString) { TEST(CommsTest, TestSendRecvString) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
std::string tmps; std::string tmps;
ASSERT_THAT(comms->RecvString(&tmps), IsTrue()); ASSERT_THAT(comms->RecvString(&tmps), IsTrue());
@ -232,10 +184,10 @@ TEST_P(CommsTest, TestSendRecvString) {
auto b = [](Comms* comms) { auto b = [](Comms* comms) {
ASSERT_THAT(comms->SendString(std::string(NullTestString())), IsTrue()); ASSERT_THAT(comms->SendString(std::string(NullTestString())), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvArray) { TEST(CommsTest, TestSendRecvArray) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Receive 1M bytes. // Receive 1M bytes.
std::vector<uint8_t> buffer; std::vector<uint8_t> buffer;
@ -248,10 +200,10 @@ TEST_P(CommsTest, TestSendRecvArray) {
memset(buffer.data(), 0, buffer.size()); memset(buffer.data(), 0, buffer.size());
ASSERT_THAT(comms->SendBytes(buffer), IsTrue()); ASSERT_THAT(comms->SendBytes(buffer), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvFD) { TEST(CommsTest, TestSendRecvFD) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Receive FD and test it. // Receive FD and test it.
int fd = -1; int fd = -1;
@ -263,10 +215,10 @@ TEST_P(CommsTest, TestSendRecvFD) {
// Send our STDERR to the thread. // Send our STDERR to the thread.
ASSERT_THAT(comms->SendFD(STDERR_FILENO), IsTrue()); ASSERT_THAT(comms->SendFD(STDERR_FILENO), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvEmptyTLV) { TEST(CommsTest, TestSendRecvEmptyTLV) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Receive TLV without a value. // Receive TLV without a value.
uint32_t tag; uint32_t tag;
@ -279,10 +231,10 @@ TEST_P(CommsTest, TestSendRecvEmptyTLV) {
// Send TLV without a value. // Send TLV without a value.
ASSERT_THAT(comms->SendTLV(0x00DEADBE, 0, nullptr), IsTrue()); ASSERT_THAT(comms->SendTLV(0x00DEADBE, 0, nullptr), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvEmptyTLV2) { TEST(CommsTest, TestSendRecvEmptyTLV2) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Receive TLV without a value. // Receive TLV without a value.
uint32_t tag; uint32_t tag;
@ -295,10 +247,10 @@ TEST_P(CommsTest, TestSendRecvEmptyTLV2) {
// Send TLV without a value. // Send TLV without a value.
ASSERT_THAT(comms->SendTLV(0x00DEADBE, 0, nullptr), IsTrue()); ASSERT_THAT(comms->SendTLV(0x00DEADBE, 0, nullptr), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvProto) { TEST(CommsTest, TestSendRecvProto) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Receive a ProtoBuf. // Receive a ProtoBuf.
std::unique_ptr<CommsTestMsg> comms_msg(new CommsTestMsg()); std::unique_ptr<CommsTestMsg> comms_msg(new CommsTestMsg());
@ -313,10 +265,10 @@ TEST_P(CommsTest, TestSendRecvProto) {
ASSERT_THAT(comms_msg->value_size(), Eq(1)); ASSERT_THAT(comms_msg->value_size(), Eq(1));
ASSERT_THAT(comms->SendProtoBuf(*comms_msg), IsTrue()); ASSERT_THAT(comms->SendProtoBuf(*comms_msg), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvStatusOK) { TEST(CommsTest, TestSendRecvStatusOK) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Receive a good status. // Receive a good status.
absl::Status status; absl::Status status;
@ -327,10 +279,10 @@ TEST_P(CommsTest, TestSendRecvStatusOK) {
// Send a good status. // Send a good status.
ASSERT_THAT(comms->SendStatus(absl::OkStatus()), IsTrue()); ASSERT_THAT(comms->SendStatus(absl::OkStatus()), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvStatusFailing) { TEST(CommsTest, TestSendRecvStatusFailing) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Receive a failing status. // Receive a failing status.
absl::Status status; absl::Status status;
@ -344,10 +296,10 @@ TEST_P(CommsTest, TestSendRecvStatusFailing) {
absl::Status{absl::StatusCode::kInternal, "something odd"}), absl::Status{absl::StatusCode::kInternal, "something odd"}),
IsTrue()); IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestUsesDistinctBuffers) { TEST(CommsTest, TestUsesDistinctBuffers) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Receive 1M bytes. // Receive 1M bytes.
std::vector<uint8_t> buffer1, buffer2; std::vector<uint8_t> buffer1, buffer2;
@ -370,10 +322,10 @@ TEST_P(CommsTest, TestUsesDistinctBuffers) {
ASSERT_THAT(comms->SendBytes(buf.data(), buf.size()), IsTrue()); ASSERT_THAT(comms->SendBytes(buf.data(), buf.size()), IsTrue());
ASSERT_THAT(comms->SendBytes(buf.data(), buf.size()), IsTrue()); ASSERT_THAT(comms->SendBytes(buf.data(), buf.size()), IsTrue());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvCredentials) { TEST(CommsTest, TestSendRecvCredentials) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Check credentials. // Check credentials.
pid_t pid; pid_t pid;
@ -387,10 +339,10 @@ TEST_P(CommsTest, TestSendRecvCredentials) {
auto b = [](Comms* comms) { auto b = [](Comms* comms) {
// Nothing to do here. // Nothing to do here.
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendTooMuchData) { TEST(CommsTest, TestSendTooMuchData) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
// Nothing to do here. // Nothing to do here.
}; };
@ -399,10 +351,10 @@ TEST_P(CommsTest, TestSendTooMuchData) {
ASSERT_THAT(comms->SendBytes(nullptr, comms->GetMaxMsgSize() + 1), ASSERT_THAT(comms->SendBytes(nullptr, comms->GetMaxMsgSize() + 1),
IsFalse()); IsFalse());
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
TEST_P(CommsTest, TestSendRecvBytes) { TEST(CommsTest, TestSendRecvBytes) {
auto a = [](Comms* comms) { auto a = [](Comms* comms) {
std::vector<uint8_t> buffer; std::vector<uint8_t> buffer;
ASSERT_THAT(comms->RecvBytes(&buffer), IsTrue()); ASSERT_THAT(comms->RecvBytes(&buffer), IsTrue());
@ -416,12 +368,12 @@ TEST_P(CommsTest, TestSendRecvBytes) {
ASSERT_THAT(comms->RecvBytes(&response), IsTrue()); ASSERT_THAT(comms->RecvBytes(&response), IsTrue());
EXPECT_THAT(request, Eq(response)); EXPECT_THAT(request, Eq(response));
}; };
HandleCommunication(sockname_, use_abstract_socket_, a, b); HandleCommunication(a, b);
} }
// We cannot test this in the Client or Server tests, as the endpoint needs to // We cannot test this in the Client or Server tests, as the endpoint needs to
// be unconnected. // be unconnected.
TEST_P(CommsTest, TestMsgSize) { TEST(CommsTest, TestMsgSize) {
// There will be no actual connection to this socket. // There will be no actual connection to this socket.
const std::string socket_name = "sandbox2_comms_msg_size_test"; const std::string socket_name = "sandbox2_comms_msg_size_test";
Comms c(socket_name); Comms c(socket_name);