diff --git a/sandboxed_api/sandbox2/client.cc b/sandboxed_api/sandbox2/client.cc index bb5cb83..9347792 100644 --- a/sandboxed_api/sandbox2/client.cc +++ b/sandboxed_api/sandbox2/client.cc @@ -34,6 +34,7 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -69,8 +70,8 @@ std::string Client::GetFdMapEnvVar() const { absl::StrJoin(fd_map_, ",", absl::PairFormatter(","))); } -void Client::PrepareEnvironment() { - SetUpIPC(); +void Client::PrepareEnvironment(std::vector* preserve_fds) { + SetUpIPC(preserve_fds); SetUpCwd(); } @@ -123,7 +124,7 @@ void Client::SetUpCwd() { } } -void Client::SetUpIPC() { +void Client::SetUpIPC(std::vector* preserve_fds) { uint32_t num_of_fd_pairs; SAPI_RAW_CHECK(comms_->RecvUint32(&num_of_fd_pairs), "receiving number of fd pairs"); @@ -131,6 +132,13 @@ void Client::SetUpIPC() { SAPI_RAW_VLOG(1, "Will receive %d file descriptor pairs", num_of_fd_pairs); + absl::flat_hash_map preserve_fds_map; + if (preserve_fds) { + for (int& fd : *preserve_fds) { + preserve_fds_map.emplace(fd, &fd); + } + } + for (uint32_t i = 0; i < num_of_fd_pairs; ++i) { int32_t requested_fd; int32_t fd; @@ -140,6 +148,27 @@ void Client::SetUpIPC() { SAPI_RAW_CHECK(comms_->RecvFD(&fd), "receiving current fd"); SAPI_RAW_CHECK(comms_->RecvString(&name), "receiving name string"); + if (auto it = preserve_fds_map.find(requested_fd); + it != preserve_fds_map.end()) { + int old_fd = it->first; + int new_fd = dup(old_fd); + SAPI_RAW_PCHECK(new_fd != -1, "Failed to duplicate preserved fd=%d", + old_fd); + SAPI_RAW_LOG(INFO, "Moved preserved fd=%d to %d", old_fd, new_fd); + close(old_fd); + int* pfd = it->second; + *pfd = new_fd; + preserve_fds_map.erase(it); + preserve_fds_map.emplace(new_fd, pfd); + } + + if (requested_fd == comms_->GetConnectionFD()) { + comms_->MoveToAnotherFd(); + SAPI_RAW_LOG(INFO, + "Trying to map over comms fd (%d). Remapped comms to %d", + requested_fd, comms_->GetConnectionFD()); + } + if (requested_fd != -1 && fd != requested_fd) { if (requested_fd > STDERR_FILENO && fcntl(requested_fd, F_GETFD) != -1) { // Dup2 will silently close the FD if one is already at requested_fd. diff --git a/sandboxed_api/sandbox2/client.h b/sandboxed_api/sandbox2/client.h index 31e9b96..bee2685 100644 --- a/sandboxed_api/sandbox2/client.h +++ b/sandboxed_api/sandbox2/client.h @@ -89,7 +89,10 @@ class Client { std::string GetFdMapEnvVar() const; // Sets up communication channels with the sandbox. - void SetUpIPC(); + // preserve_fds contains file descriptors that should be kept open and alive. + // The FD numbers might be changed if needed and are updated in the vector. + // preserve_fds can be a nullptr, equivallent to an empty vector. + void SetUpIPC(std::vector* preserve_fds); // Sets up the current working directory. void SetUpCwd(); @@ -100,7 +103,7 @@ class Client { // Applies sandbox-bpf policy, have limits applied on us, and become ptrace'd. void ApplyPolicyAndBecomeTracee(); - void PrepareEnvironment(); + void PrepareEnvironment(std::vector* preserve_fds = nullptr); void EnableSandbox(); }; diff --git a/sandboxed_api/sandbox2/comms.cc b/sandboxed_api/sandbox2/comms.cc index 5047500..08132ed 100644 --- a/sandboxed_api/sandbox2/comms.cc +++ b/sandboxed_api/sandbox2/comms.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -63,6 +64,16 @@ bool IsFatalError(int saved_errno) { saved_errno != EFAULT && saved_errno != EINTR && saved_errno != EINVAL && saved_errno != ENOMEM; } + +int GetDefaultCommsFd() { + if (const char* var = getenv(Comms::kSandbox2CommsFDEnvVar); var) { + int fd; + SAPI_RAW_CHECK(absl::SimpleAtoi(var, &fd), "cannot parse comms fd var"); + unsetenv(Comms::kSandbox2CommsFDEnvVar); + return fd; + } + return Comms::kSandbox2ClientCommsFD; +} } // namespace Comms::Comms(const std::string& socket_name) : socket_name_(socket_name) {} @@ -77,6 +88,8 @@ Comms::Comms(int fd) : connection_fd_(fd) { state_ = State::kConnected; } +Comms::Comms(Comms::DefaultConnectionTag) : Comms(GetDefaultCommsFd()) {} + Comms::~Comms() { Terminate(); } int Comms::GetConnectionFD() const { @@ -649,4 +662,13 @@ bool Comms::SendStatus(const absl::Status& status) { return SendProtoBuf(proto); } +void Comms::MoveToAnotherFd() { + SAPI_RAW_CHECK(connection_fd_ != -1, + "Cannot move comms fd as it's not connected"); + int new_fd = dup(connection_fd_); + SAPI_RAW_CHECK(new_fd != -1, "Failed to move comms to another fd"); + close(connection_fd_); + connection_fd_ = new_fd; +} + } // namespace sandbox2 diff --git a/sandboxed_api/sandbox2/comms.h b/sandboxed_api/sandbox2/comms.h index 0ecd36a..cb2f8ea 100644 --- a/sandboxed_api/sandbox2/comms.h +++ b/sandboxed_api/sandbox2/comms.h @@ -33,6 +33,7 @@ #include "absl/base/attributes.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "sandboxed_api/util/status.pb.h" @@ -42,6 +43,8 @@ class Message; namespace sandbox2 { +class Client; + class Comms { public: struct DefaultConnectionTag {}; @@ -77,6 +80,8 @@ class Comms { static constexpr DefaultConnectionTag kDefaultConnection = {}; + static constexpr const char* kSandbox2CommsFDEnvVar = "SANDBOX2_COMMS_FD"; + // This object will have to be connected later on. explicit Comms(const std::string& socket_name); @@ -88,7 +93,7 @@ class Comms { explicit Comms(int fd); // Instantiates a pre-connected object using the default connection params. - explicit Comms(DefaultConnectionTag) : Comms(kSandbox2ClientCommsFD) {} + explicit Comms(DefaultConnectionTag); ~Comms(); @@ -170,6 +175,8 @@ class Comms { bool SendStatus(const absl::Status& status); private: + friend class Client; + // State of the channel enum class State { kUnconnected = 0, @@ -201,6 +208,9 @@ class Comms { // Fills sockaddr_un struct with proper values. socklen_t CreateSockaddrUn(sockaddr_un* sun); + // Moves the comms fd to an other free file descriptor. + void MoveToAnotherFd(); + // Support for EINTR and size completion. bool Send(const void* data, size_t len); bool Recv(void* data, size_t len); diff --git a/sandboxed_api/sandbox2/forkserver.cc b/sandboxed_api/sandbox2/forkserver.cc index 0976b34..d37f139 100644 --- a/sandboxed_api/sandbox2/forkserver.cc +++ b/sandboxed_api/sandbox2/forkserver.cc @@ -352,8 +352,19 @@ void ForkServer::LaunchChild(const ForkRequest& request, int execve_fd, // The following client calls are basically SandboxMeHere. We split it so // that we can set up the envp after we received the file descriptors but // before we enable the syscall filter. - c.PrepareEnvironment(); + std::vector preserved_fds; + if (request.mode() == FORKSERVER_FORK_EXECVE_SANDBOX) { + preserved_fds.push_back(execve_fd); + } + c.PrepareEnvironment(&preserved_fds); + if (request.mode() == FORKSERVER_FORK_EXECVE_SANDBOX) { + execve_fd = preserved_fds[0]; + } + if (client_comms.GetConnectionFD() != Comms::kSandbox2ClientCommsFD) { + envs.push_back(absl::StrCat(Comms::kSandbox2CommsFDEnvVar, "=", + client_comms.GetConnectionFD())); + } envs.push_back(c.GetFdMapEnvVar()); // Convert args and envs before enabling sandbox (it'll allocate which might // be blocked). diff --git a/sandboxed_api/sandbox2/ipc_test.cc b/sandboxed_api/sandbox2/ipc_test.cc index ca578a4..a2684d2 100644 --- a/sandboxed_api/sandbox2/ipc_test.cc +++ b/sandboxed_api/sandbox2/ipc_test.cc @@ -36,14 +36,18 @@ using ::sapi::GetTestSourcePath; constexpr int kPreferredIpcFd = 812; +class IPCTest : public testing::Test, + public testing::WithParamInterface {}; + // This test verifies that mapping fds by name works if the sandbox is enabled // before execve. -TEST(IPCTest, MapFDByNamePreExecve) { +TEST_P(IPCTest, MapFDByNamePreExecve) { SKIP_SANITIZERS_AND_COVERAGE; + const int fd = GetParam(); const std::string path = GetTestSourcePath("sandbox2/testcases/ipc"); - std::vector args = {path, "1", std::to_string(kPreferredIpcFd)}; + std::vector args = {path, "1", std::to_string(fd)}; auto executor = absl::make_unique(path, args); - Comms comms(executor->ipc()->ReceiveFd(kPreferredIpcFd, "ipc_test")); + Comms comms(executor->ipc()->ReceiveFd(fd, "ipc_test")); SAPI_ASSERT_OK_AND_ASSIGN(auto policy, PolicyBuilder() @@ -57,9 +61,14 @@ TEST(IPCTest, MapFDByNamePreExecve) { ASSERT_TRUE(comms.SendString("hello")); std::string resp; + ASSERT_TRUE(s2.comms()->RecvString(&resp)); + ASSERT_EQ(resp, "start"); + ASSERT_TRUE(s2.comms()->SendString("started")); ASSERT_TRUE(comms.RecvString(&resp)); - ASSERT_EQ(resp, "world"); + ASSERT_TRUE(s2.comms()->RecvString(&resp)); + ASSERT_EQ(resp, "finish"); + ASSERT_TRUE(s2.comms()->SendString("finished")); auto result = s2.AwaitResult(); @@ -69,13 +78,14 @@ TEST(IPCTest, MapFDByNamePreExecve) { // This test verifies that mapping fds by name works if SandboxMeHere() is // called by the sandboxee. -TEST(IPCTest, MapFDByNamePostExecve) { +TEST_P(IPCTest, MapFDByNamePostExecve) { SKIP_SANITIZERS_AND_COVERAGE; + const int fd = GetParam(); const std::string path = GetTestSourcePath("sandbox2/testcases/ipc"); - std::vector args = {path, "2", std::to_string(kPreferredIpcFd)}; + std::vector args = {path, "2", std::to_string(fd)}; auto executor = absl::make_unique(path, args); executor->set_enable_sandbox_before_exec(false); - Comms comms(executor->ipc()->ReceiveFd(kPreferredIpcFd, "ipc_test")); + Comms comms(executor->ipc()->ReceiveFd(fd, "ipc_test")); SAPI_ASSERT_OK_AND_ASSIGN(auto policy, PolicyBuilder() @@ -89,9 +99,14 @@ TEST(IPCTest, MapFDByNamePostExecve) { ASSERT_TRUE(comms.SendString("hello")); std::string resp; + ASSERT_TRUE(s2.comms()->RecvString(&resp)); + ASSERT_EQ(resp, "start"); + ASSERT_TRUE(s2.comms()->SendString("started")); ASSERT_TRUE(comms.RecvString(&resp)); - ASSERT_EQ(resp, "world"); + ASSERT_TRUE(s2.comms()->RecvString(&resp)); + ASSERT_EQ(resp, "finish"); + ASSERT_TRUE(s2.comms()->SendString("finished")); auto result = s2.AwaitResult(); @@ -119,5 +134,11 @@ TEST(IPCTest, NoMappedFDsPreExecve) { ASSERT_EQ(result.reason_code(), 0); } +INSTANTIATE_TEST_SUITE_P(NormalFds, IPCTest, testing::Values(kPreferredIpcFd)); + +INSTANTIATE_TEST_SUITE_P(RestrictedFds, IPCTest, + testing::Values(Comms::kSandbox2ClientCommsFD, + Comms::kSandbox2TargetExecFD)); + } // namespace } // namespace sandbox2 diff --git a/sandboxed_api/sandbox2/testcases/ipc.cc b/sandboxed_api/sandbox2/testcases/ipc.cc index d54c26d..fa91f03 100644 --- a/sandboxed_api/sandbox2/testcases/ipc.cc +++ b/sandboxed_api/sandbox2/testcases/ipc.cc @@ -62,9 +62,28 @@ int main(int argc, char* argv[]) { return EXIT_FAILURE; } sandbox2::Comms comms(fd); - std::string hello; - if (!comms.RecvString(&hello)) { - fputs("error on comms.RecvString(&hello)", stderr); + std::string resp; + if (!default_comms.SendString("start")) { + fputs("error on default_comms.RecvString(\"start\")", stderr); + return EXIT_FAILURE; + } + if (!default_comms.RecvString(&resp)) { + fputs("error on default_comms.RecvString(&resp)", stderr); + return EXIT_FAILURE; + } + if (resp != "started") { + fprintf(stderr, "unexpected response \"%s\" (expected \"started\")\n", + resp.c_str()); + return EXIT_FAILURE; + } + + if (!comms.RecvString(&resp)) { + fputs("error on comms.RecvString(&resp)", stderr); + return EXIT_FAILURE; + } + if (resp != "hello") { + fprintf(stderr, "unexpected response \"%s\" (expected \"hello\")\n", + resp.c_str()); return EXIT_FAILURE; } @@ -72,6 +91,19 @@ int main(int argc, char* argv[]) { fputs("error on comms.SendString(\"world\")", stderr); return EXIT_FAILURE; } + if (!default_comms.SendString("finish")) { + fputs("error on default_comms.RecvString(\"finish\")", stderr); + return EXIT_FAILURE; + } + if (!default_comms.RecvString(&resp)) { + fputs("error on default_comms.RecvString(&resp)", stderr); + return EXIT_FAILURE; + } + if (resp != "finished") { + fprintf(stderr, "unexpected response \"%s\" (expected \"finished\")\n", + resp.c_str()); + return EXIT_FAILURE; + } printf("OK: All tests went OK\n"); return EXIT_SUCCESS;