From 9f2ba9d6a14a838496063d284de95749d4d24dda Mon Sep 17 00:00:00 2001 From: Sandboxed API Team Date: Thu, 23 Mar 2023 07:33:49 -0700 Subject: [PATCH] Comms constructor for non abstract sockets Allows to create a Comms with unix domain sockets that are not abstract. This allows to use Comms to talk across network namespaces PiperOrigin-RevId: 518854724 Change-Id: I4fd65466bba9512f448b73bde367f38a0fbb584d --- sandboxed_api/sandbox2/comms.cc | 27 ++++--- sandboxed_api/sandbox2/comms.h | 10 ++- sandboxed_api/sandbox2/comms_test.cc | 112 +++++++++++++++------------ 3 files changed, 88 insertions(+), 61 deletions(-) diff --git a/sandboxed_api/sandbox2/comms.cc b/sandboxed_api/sandbox2/comms.cc index 663fcdf..0b17b7c 100644 --- a/sandboxed_api/sandbox2/comms.cc +++ b/sandboxed_api/sandbox2/comms.cc @@ -77,7 +77,8 @@ int GetDefaultCommsFd() { } } // namespace -Comms::Comms(const std::string& socket_name) : socket_name_(socket_name) {} +Comms::Comms(const std::string& socket_name, bool abstract_uds) + : socket_name_(socket_name), abstract_uds_(abstract_uds) {} Comms::Comms(int fd) : connection_fd_(fd) { // Generate a unique and meaningful socket name for this FD. @@ -485,16 +486,24 @@ bool Comms::SendProtoBuf(const google::protobuf::MessageLite& message) { socklen_t Comms::CreateSockaddrUn(sockaddr_un* sun) { sun->sun_family = AF_UNIX; bzero(sun->sun_path, sizeof(sun->sun_path)); - // Create an 'abstract socket address' by specifying a leading null byte. The - // remainder of the path is used as a unique name, but no file is created on - // the filesystem. No need to NUL-terminate the string. - // See `man 7 unix` for further explanation. - strncpy(&sun->sun_path[1], socket_name_.c_str(), sizeof(sun->sun_path) - 1); + socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name_.c_str()); + if (abstract_uds_) { + // Create an 'abstract socket address' by specifying a leading null byte. + // The remainder of the path is used as a unique name, but no file is + // created on the filesystem. No need to NUL-terminate the string. See `man + // 7 unix` for further explanation. + strncpy(&sun->sun_path[1], socket_name_.c_str(), sizeof(sun->sun_path) - 1); + // Len is complicated - it's essentially size of the path, plus initial + // NUL-byte, minus size of the sun.sun_family. + slen++; + } else { + // Create the socket address as it was passed from the constructor. + strncpy(&sun->sun_path[0], socket_name_.c_str(), sizeof(sun->sun_path)); + } - // Len is complicated - it's essentially size of the path, plus initial - // NUL-byte, minus size of the sun.sun_family. - socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name_.c_str()) + 1; + // This takes care of the socket address overflow. if (slen > sizeof(sockaddr_un)) { + SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated"); slen = sizeof(sockaddr_un); } return slen; diff --git a/sandboxed_api/sandbox2/comms.h b/sandboxed_api/sandbox2/comms.h index cb2f8ea..8211519 100644 --- a/sandboxed_api/sandbox2/comms.h +++ b/sandboxed_api/sandbox2/comms.h @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// The sandbox2::Comms class uses AF_UNIX sockets in the abstract namespace -// (man 7 unix) to send pieces of data between processes. It uses the TLV -// encoding and provides some useful helpers. +// The sandbox2::Comms class uses AF_UNIX sockets (man 7 unix) to send pieces of +// data between processes. It uses the TLV encoding and provides some useful +// helpers. // // The endianess is platform-specific, but as it can be used over abstract // sockets only, that's not a problem. Is some poor soul decides to rewrite it @@ -83,7 +83,8 @@ class Comms { static constexpr const char* kSandbox2CommsFDEnvVar = "SANDBOX2_COMMS_FD"; // This object will have to be connected later on. - explicit Comms(const std::string& socket_name); + // When not specified the constructor uses abstract unix domain sockets. + explicit Comms(const std::string& socket_name, bool abstract_uds = true); Comms(const Comms&) = delete; Comms& operator=(const Comms&) = delete; @@ -186,6 +187,7 @@ class Comms { // Connection parameters. std::string socket_name_; + bool abstract_uds_ = true; int connection_fd_ = -1; int bind_fd_ = -1; diff --git a/sandboxed_api/sandbox2/comms_test.cc b/sandboxed_api/sandbox2/comms_test.cc index b68c897..93d4490 100644 --- a/sandboxed_api/sandbox2/comms_test.cc +++ b/sandboxed_api/sandbox2/comms_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include // NOLINT(build/c++11) #include @@ -31,6 +32,7 @@ #include "absl/container/fixed_array.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "sandboxed_api/sandbox2/comms_test.pb.h" #include "sandboxed_api/util/status_matchers.h" @@ -45,18 +47,27 @@ namespace sandbox2 { using CommunicationHandler = std::function; -class CommsTest : public ::testing::Test { +class CommsTest : public ::testing::TestWithParam { void SetUp() override { - // Comms channel using an abstract socket namespace (initialized with socket - // name). + use_abstract_socket_ = GetParam(); + timespec ts1, ts2; CHECK_NE(clock_gettime(CLOCK_REALTIME, &ts1), -1); CHECK_NE(clock_gettime(CLOCK_REALTIME, &ts2), -1); - snprintf( - sockname_, sizeof(sockname_), "comms-test-%u-%u-%u-%u", - static_cast(ts1.tv_sec), static_cast(ts1.tv_nsec), - static_cast(ts2.tv_sec), static_cast(ts2.tv_nsec)); + // If the test does not use an abstract socket, create the socket in the + // '/tmp' directory. The reason to put it in tmp is that we want to + // guarantee that the sockname_ does not go over the limit of 108 (107 char + // + '/0'). + if (!use_abstract_socket_) { + sockname_ = "/tmp/"; + } + absl::StrAppend(&sockname_, "comms-test-", + static_cast(ts1.tv_sec), "-", + static_cast(ts1.tv_nsec)); + absl::StrAppend(&sockname_, static_cast(ts2.tv_sec), "-", + static_cast(ts2.tv_nsec)); LOG(INFO) << "Sockname: " << sockname_; + CHECK_LT(sockname_.size(), 108); // Comms channel using a descriptor (initialized with a file descriptor). int sv[2]; @@ -71,13 +82,14 @@ class CommsTest : public ::testing::Test { } protected: - char sockname_[256]; + std::string sockname_ = ""; + bool use_abstract_socket_ = true; int fd_client_; int fd_server_; }; constexpr char kProtoStr[] = "ABCD"; -static const absl::string_view NullTestString() { +static absl::string_view NullTestString() { static constexpr char kHelperStr[] = "test\0\n\r\t\x01\x02"; return absl::string_view(kHelperStr, sizeof(kHelperStr) - 1); } @@ -85,14 +97,15 @@ static const absl::string_view NullTestString() { // Helper function that handles the communication between the two handler // functions. void HandleCommunication(const std::string& socketname, + bool use_abstract_socket, const CommunicationHandler& a, const CommunicationHandler& b) { - Comms comms(socketname); + Comms comms(socketname, use_abstract_socket); comms.Listen(); // Start handler a. - std::thread remote([&socketname, &a]() { - Comms my_comms(socketname); + std::thread remote([&socketname, &a, use_abstract_socket]() { + Comms my_comms(socketname, use_abstract_socket); CHECK(my_comms.Connect()); a(&my_comms); }); @@ -103,7 +116,10 @@ void HandleCommunication(const std::string& socketname, remote.join(); } -TEST_F(CommsTest, TestSendRecv8) { +INSTANTIATE_TEST_SUITE_P(Comms, CommsTest, ::testing::Bool(), + ::testing::PrintToStringParamName()); + +TEST_P(CommsTest, TestSendRecv8) { auto a = [](Comms* comms) { // Send Uint8. ASSERT_THAT(comms->SendUint8(192), IsTrue()); @@ -122,10 +138,10 @@ TEST_F(CommsTest, TestSendRecv8) { // Send Int8. ASSERT_THAT(comms->SendInt8(-7), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecv16) { +TEST_P(CommsTest, TestSendRecv16) { auto a = [](Comms* comms) { // Send Uint16. ASSERT_THAT(comms->SendUint16(40001), IsTrue()); @@ -144,10 +160,10 @@ TEST_F(CommsTest, TestSendRecv16) { // Send Int16. ASSERT_THAT(comms->SendInt16(-22050), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecv32) { +TEST_P(CommsTest, TestSendRecv32) { auto a = [](Comms* comms) { // SendUint32. ASSERT_THAT(comms->SendUint32(3221225472UL), IsTrue()); @@ -166,10 +182,10 @@ TEST_F(CommsTest, TestSendRecv32) { // Send Int32. ASSERT_THAT(comms->SendInt32(-1073741824), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecv64) { +TEST_P(CommsTest, TestSendRecv64) { auto a = [](Comms* comms) { // SendUint64. ASSERT_THAT(comms->SendUint64(1099511627776ULL), IsTrue()); @@ -188,10 +204,10 @@ TEST_F(CommsTest, TestSendRecv64) { // Send Int64. ASSERT_THAT(comms->SendInt64(-1099511627776LL), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestTypeMismatch) { +TEST_P(CommsTest, TestTypeMismatch) { auto a = [](Comms* comms) { uint8_t tmpu8; // Receive Int8 (but Uint8 expected). @@ -201,10 +217,10 @@ TEST_F(CommsTest, TestTypeMismatch) { // Send Int8 (but Uint8 expected). ASSERT_THAT(comms->SendInt8(-93), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvString) { +TEST_P(CommsTest, TestSendRecvString) { auto a = [](Comms* comms) { std::string tmps; ASSERT_THAT(comms->RecvString(&tmps), IsTrue()); @@ -214,10 +230,10 @@ TEST_F(CommsTest, TestSendRecvString) { auto b = [](Comms* comms) { ASSERT_THAT(comms->SendString(std::string(NullTestString())), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvArray) { +TEST_P(CommsTest, TestSendRecvArray) { auto a = [](Comms* comms) { // Receive 1M bytes. std::vector buffer; @@ -230,10 +246,10 @@ TEST_F(CommsTest, TestSendRecvArray) { memset(buffer.data(), 0, buffer.size()); ASSERT_THAT(comms->SendBytes(buffer), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvFD) { +TEST_P(CommsTest, TestSendRecvFD) { auto a = [](Comms* comms) { // Receive FD and test it. int fd = -1; @@ -245,10 +261,10 @@ TEST_F(CommsTest, TestSendRecvFD) { // Send our STDERR to the thread. ASSERT_THAT(comms->SendFD(STDERR_FILENO), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvEmptyTLV) { +TEST_P(CommsTest, TestSendRecvEmptyTLV) { auto a = [](Comms* comms) { // Receive TLV without a value. uint32_t tag; @@ -261,10 +277,10 @@ TEST_F(CommsTest, TestSendRecvEmptyTLV) { // Send TLV without a value. ASSERT_THAT(comms->SendTLV(0x00DEADBE, 0, nullptr), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvEmptyTLV2) { +TEST_P(CommsTest, TestSendRecvEmptyTLV2) { auto a = [](Comms* comms) { // Receive TLV without a value. uint32_t tag; @@ -277,10 +293,10 @@ TEST_F(CommsTest, TestSendRecvEmptyTLV2) { // Send TLV without a value. ASSERT_THAT(comms->SendTLV(0x00DEADBE, 0, nullptr), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvProto) { +TEST_P(CommsTest, TestSendRecvProto) { auto a = [](Comms* comms) { // Receive a ProtoBuf. std::unique_ptr comms_msg(new CommsTestMsg()); @@ -295,10 +311,10 @@ TEST_F(CommsTest, TestSendRecvProto) { ASSERT_THAT(comms_msg->value_size(), Eq(1)); ASSERT_THAT(comms->SendProtoBuf(*comms_msg), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvStatusOK) { +TEST_P(CommsTest, TestSendRecvStatusOK) { auto a = [](Comms* comms) { // Receive a good status. absl::Status status; @@ -309,10 +325,10 @@ TEST_F(CommsTest, TestSendRecvStatusOK) { // Send a good status. ASSERT_THAT(comms->SendStatus(absl::OkStatus()), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvStatusFailing) { +TEST_P(CommsTest, TestSendRecvStatusFailing) { auto a = [](Comms* comms) { // Receive a failing status. absl::Status status; @@ -326,10 +342,10 @@ TEST_F(CommsTest, TestSendRecvStatusFailing) { absl::Status{absl::StatusCode::kInternal, "something odd"}), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestUsesDistinctBuffers) { +TEST_P(CommsTest, TestUsesDistinctBuffers) { auto a = [](Comms* comms) { // Receive 1M bytes. std::vector buffer1, buffer2; @@ -352,10 +368,10 @@ TEST_F(CommsTest, TestUsesDistinctBuffers) { ASSERT_THAT(comms->SendBytes(buf.data(), buf.size()), IsTrue()); ASSERT_THAT(comms->SendBytes(buf.data(), buf.size()), IsTrue()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvCredentials) { +TEST_P(CommsTest, TestSendRecvCredentials) { auto a = [](Comms* comms) { // Check credentials. pid_t pid; @@ -369,10 +385,10 @@ TEST_F(CommsTest, TestSendRecvCredentials) { auto b = [](Comms* comms) { // Nothing to do here. }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendTooMuchData) { +TEST_P(CommsTest, TestSendTooMuchData) { auto a = [](Comms* comms) { // Nothing to do here. }; @@ -381,10 +397,10 @@ TEST_F(CommsTest, TestSendTooMuchData) { ASSERT_THAT(comms->SendBytes(nullptr, comms->GetMaxMsgSize() + 1), IsFalse()); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } -TEST_F(CommsTest, TestSendRecvBytes) { +TEST_P(CommsTest, TestSendRecvBytes) { auto a = [](Comms* comms) { std::vector buffer; ASSERT_THAT(comms->RecvBytes(&buffer), IsTrue()); @@ -398,7 +414,7 @@ TEST_F(CommsTest, TestSendRecvBytes) { ASSERT_THAT(comms->RecvBytes(&response), IsTrue()); EXPECT_THAT(request, Eq(response)); }; - HandleCommunication(sockname_, a, b); + HandleCommunication(sockname_, use_abstract_socket_, a, b); } class SenderThread { @@ -434,7 +450,7 @@ class ReceiverThread { size_t rounds_; }; -TEST_F(CommsTest, TestMultipleThreads) { +TEST_P(CommsTest, TestMultipleThreads) { // The comms object should be thread safe, this testcase covers this. constexpr size_t kNumThreads = 20; constexpr size_t kNumRoundsPerThread = 50; @@ -495,7 +511,7 @@ TEST_F(CommsTest, TestMultipleThreads) { // We cannot test this in the Client or Server tests, as the endpoint needs to // be unconnected. -TEST_F(CommsTest, TestMsgSize) { +TEST_P(CommsTest, TestMsgSize) { // There will be no actual connection to this socket. const std::string socket_name = "sandbox2_comms_msg_size_test"; Comms c(socket_name);