Sandbox2: Graciously handle mapping over Comms/Exec fds

Try to move the affected FDs transparently to avoid conflict.

PiperOrigin-RevId: 480105375
Change-Id: I0cd093fce120505d1cd4a1d081b3c0e63bf0210a
This commit is contained in:
Wiktor Garbacz 2022-10-10 09:38:14 -07:00 committed by Copybara-Service
parent b9c2830ebc
commit cb8efdc270
7 changed files with 146 additions and 18 deletions

View File

@ -34,6 +34,7 @@
#include "absl/base/attributes.h" #include "absl/base/attributes.h"
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
@ -69,8 +70,8 @@ std::string Client::GetFdMapEnvVar() const {
absl::StrJoin(fd_map_, ",", absl::PairFormatter(","))); absl::StrJoin(fd_map_, ",", absl::PairFormatter(",")));
} }
void Client::PrepareEnvironment() { void Client::PrepareEnvironment(std::vector<int>* preserve_fds) {
SetUpIPC(); SetUpIPC(preserve_fds);
SetUpCwd(); SetUpCwd();
} }
@ -123,7 +124,7 @@ void Client::SetUpCwd() {
} }
} }
void Client::SetUpIPC() { void Client::SetUpIPC(std::vector<int>* preserve_fds) {
uint32_t num_of_fd_pairs; uint32_t num_of_fd_pairs;
SAPI_RAW_CHECK(comms_->RecvUint32(&num_of_fd_pairs), SAPI_RAW_CHECK(comms_->RecvUint32(&num_of_fd_pairs),
"receiving number of fd pairs"); "receiving number of fd pairs");
@ -131,6 +132,13 @@ void Client::SetUpIPC() {
SAPI_RAW_VLOG(1, "Will receive %d file descriptor pairs", num_of_fd_pairs); SAPI_RAW_VLOG(1, "Will receive %d file descriptor pairs", num_of_fd_pairs);
absl::flat_hash_map<int, int*> preserve_fds_map;
if (preserve_fds) {
for (int& fd : *preserve_fds) {
preserve_fds_map.emplace(fd, &fd);
}
}
for (uint32_t i = 0; i < num_of_fd_pairs; ++i) { for (uint32_t i = 0; i < num_of_fd_pairs; ++i) {
int32_t requested_fd; int32_t requested_fd;
int32_t fd; int32_t fd;
@ -140,6 +148,27 @@ void Client::SetUpIPC() {
SAPI_RAW_CHECK(comms_->RecvFD(&fd), "receiving current fd"); SAPI_RAW_CHECK(comms_->RecvFD(&fd), "receiving current fd");
SAPI_RAW_CHECK(comms_->RecvString(&name), "receiving name string"); SAPI_RAW_CHECK(comms_->RecvString(&name), "receiving name string");
if (auto it = preserve_fds_map.find(requested_fd);
it != preserve_fds_map.end()) {
int old_fd = it->first;
int new_fd = dup(old_fd);
SAPI_RAW_PCHECK(new_fd != -1, "Failed to duplicate preserved fd=%d",
old_fd);
SAPI_RAW_LOG(INFO, "Moved preserved fd=%d to %d", old_fd, new_fd);
close(old_fd);
int* pfd = it->second;
*pfd = new_fd;
preserve_fds_map.erase(it);
preserve_fds_map.emplace(new_fd, pfd);
}
if (requested_fd == comms_->GetConnectionFD()) {
comms_->MoveToAnotherFd();
SAPI_RAW_LOG(INFO,
"Trying to map over comms fd (%d). Remapped comms to %d",
requested_fd, comms_->GetConnectionFD());
}
if (requested_fd != -1 && fd != requested_fd) { if (requested_fd != -1 && fd != requested_fd) {
if (requested_fd > STDERR_FILENO && fcntl(requested_fd, F_GETFD) != -1) { if (requested_fd > STDERR_FILENO && fcntl(requested_fd, F_GETFD) != -1) {
// Dup2 will silently close the FD if one is already at requested_fd. // Dup2 will silently close the FD if one is already at requested_fd.

View File

@ -89,7 +89,10 @@ class Client {
std::string GetFdMapEnvVar() const; std::string GetFdMapEnvVar() const;
// Sets up communication channels with the sandbox. // Sets up communication channels with the sandbox.
void SetUpIPC(); // preserve_fds contains file descriptors that should be kept open and alive.
// The FD numbers might be changed if needed and are updated in the vector.
// preserve_fds can be a nullptr, equivallent to an empty vector.
void SetUpIPC(std::vector<int>* preserve_fds);
// Sets up the current working directory. // Sets up the current working directory.
void SetUpCwd(); void SetUpCwd();
@ -100,7 +103,7 @@ class Client {
// Applies sandbox-bpf policy, have limits applied on us, and become ptrace'd. // Applies sandbox-bpf policy, have limits applied on us, and become ptrace'd.
void ApplyPolicyAndBecomeTracee(); void ApplyPolicyAndBecomeTracee();
void PrepareEnvironment(); void PrepareEnvironment(std::vector<int>* preserve_fds = nullptr);
void EnableSandbox(); void EnableSandbox();
}; };

View File

@ -31,6 +31,7 @@
#include <cerrno> #include <cerrno>
#include <cinttypes> #include <cinttypes>
#include <cstddef> #include <cstddef>
#include <cstdlib>
#include <cstring> #include <cstring>
#include <functional> #include <functional>
@ -63,6 +64,16 @@ bool IsFatalError(int saved_errno) {
saved_errno != EFAULT && saved_errno != EINTR && saved_errno != EFAULT && saved_errno != EINTR &&
saved_errno != EINVAL && saved_errno != ENOMEM; saved_errno != EINVAL && saved_errno != ENOMEM;
} }
int GetDefaultCommsFd() {
if (const char* var = getenv(Comms::kSandbox2CommsFDEnvVar); var) {
int fd;
SAPI_RAW_CHECK(absl::SimpleAtoi(var, &fd), "cannot parse comms fd var");
unsetenv(Comms::kSandbox2CommsFDEnvVar);
return fd;
}
return Comms::kSandbox2ClientCommsFD;
}
} // namespace } // namespace
Comms::Comms(const std::string& socket_name) : socket_name_(socket_name) {} Comms::Comms(const std::string& socket_name) : socket_name_(socket_name) {}
@ -77,6 +88,8 @@ Comms::Comms(int fd) : connection_fd_(fd) {
state_ = State::kConnected; state_ = State::kConnected;
} }
Comms::Comms(Comms::DefaultConnectionTag) : Comms(GetDefaultCommsFd()) {}
Comms::~Comms() { Terminate(); } Comms::~Comms() { Terminate(); }
int Comms::GetConnectionFD() const { int Comms::GetConnectionFD() const {
@ -649,4 +662,13 @@ bool Comms::SendStatus(const absl::Status& status) {
return SendProtoBuf(proto); return SendProtoBuf(proto);
} }
void Comms::MoveToAnotherFd() {
SAPI_RAW_CHECK(connection_fd_ != -1,
"Cannot move comms fd as it's not connected");
int new_fd = dup(connection_fd_);
SAPI_RAW_CHECK(new_fd != -1, "Failed to move comms to another fd");
close(connection_fd_);
connection_fd_ = new_fd;
}
} // namespace sandbox2 } // namespace sandbox2

View File

@ -33,6 +33,7 @@
#include "absl/base/attributes.h" #include "absl/base/attributes.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "sandboxed_api/util/status.pb.h" #include "sandboxed_api/util/status.pb.h"
@ -42,6 +43,8 @@ class Message;
namespace sandbox2 { namespace sandbox2 {
class Client;
class Comms { class Comms {
public: public:
struct DefaultConnectionTag {}; struct DefaultConnectionTag {};
@ -77,6 +80,8 @@ class Comms {
static constexpr DefaultConnectionTag kDefaultConnection = {}; static constexpr DefaultConnectionTag kDefaultConnection = {};
static constexpr const char* kSandbox2CommsFDEnvVar = "SANDBOX2_COMMS_FD";
// This object will have to be connected later on. // This object will have to be connected later on.
explicit Comms(const std::string& socket_name); explicit Comms(const std::string& socket_name);
@ -88,7 +93,7 @@ class Comms {
explicit Comms(int fd); explicit Comms(int fd);
// Instantiates a pre-connected object using the default connection params. // Instantiates a pre-connected object using the default connection params.
explicit Comms(DefaultConnectionTag) : Comms(kSandbox2ClientCommsFD) {} explicit Comms(DefaultConnectionTag);
~Comms(); ~Comms();
@ -170,6 +175,8 @@ class Comms {
bool SendStatus(const absl::Status& status); bool SendStatus(const absl::Status& status);
private: private:
friend class Client;
// State of the channel // State of the channel
enum class State { enum class State {
kUnconnected = 0, kUnconnected = 0,
@ -201,6 +208,9 @@ class Comms {
// Fills sockaddr_un struct with proper values. // Fills sockaddr_un struct with proper values.
socklen_t CreateSockaddrUn(sockaddr_un* sun); socklen_t CreateSockaddrUn(sockaddr_un* sun);
// Moves the comms fd to an other free file descriptor.
void MoveToAnotherFd();
// Support for EINTR and size completion. // Support for EINTR and size completion.
bool Send(const void* data, size_t len); bool Send(const void* data, size_t len);
bool Recv(void* data, size_t len); bool Recv(void* data, size_t len);

View File

@ -352,8 +352,19 @@ void ForkServer::LaunchChild(const ForkRequest& request, int execve_fd,
// The following client calls are basically SandboxMeHere. We split it so // The following client calls are basically SandboxMeHere. We split it so
// that we can set up the envp after we received the file descriptors but // that we can set up the envp after we received the file descriptors but
// before we enable the syscall filter. // before we enable the syscall filter.
c.PrepareEnvironment(); std::vector<int> preserved_fds;
if (request.mode() == FORKSERVER_FORK_EXECVE_SANDBOX) {
preserved_fds.push_back(execve_fd);
}
c.PrepareEnvironment(&preserved_fds);
if (request.mode() == FORKSERVER_FORK_EXECVE_SANDBOX) {
execve_fd = preserved_fds[0];
}
if (client_comms.GetConnectionFD() != Comms::kSandbox2ClientCommsFD) {
envs.push_back(absl::StrCat(Comms::kSandbox2CommsFDEnvVar, "=",
client_comms.GetConnectionFD()));
}
envs.push_back(c.GetFdMapEnvVar()); envs.push_back(c.GetFdMapEnvVar());
// Convert args and envs before enabling sandbox (it'll allocate which might // Convert args and envs before enabling sandbox (it'll allocate which might
// be blocked). // be blocked).

View File

@ -36,14 +36,18 @@ using ::sapi::GetTestSourcePath;
constexpr int kPreferredIpcFd = 812; constexpr int kPreferredIpcFd = 812;
class IPCTest : public testing::Test,
public testing::WithParamInterface<int> {};
// This test verifies that mapping fds by name works if the sandbox is enabled // This test verifies that mapping fds by name works if the sandbox is enabled
// before execve. // before execve.
TEST(IPCTest, MapFDByNamePreExecve) { TEST_P(IPCTest, MapFDByNamePreExecve) {
SKIP_SANITIZERS_AND_COVERAGE; SKIP_SANITIZERS_AND_COVERAGE;
const int fd = GetParam();
const std::string path = GetTestSourcePath("sandbox2/testcases/ipc"); const std::string path = GetTestSourcePath("sandbox2/testcases/ipc");
std::vector<std::string> args = {path, "1", std::to_string(kPreferredIpcFd)}; std::vector<std::string> args = {path, "1", std::to_string(fd)};
auto executor = absl::make_unique<Executor>(path, args); auto executor = absl::make_unique<Executor>(path, args);
Comms comms(executor->ipc()->ReceiveFd(kPreferredIpcFd, "ipc_test")); Comms comms(executor->ipc()->ReceiveFd(fd, "ipc_test"));
SAPI_ASSERT_OK_AND_ASSIGN(auto policy, SAPI_ASSERT_OK_AND_ASSIGN(auto policy,
PolicyBuilder() PolicyBuilder()
@ -57,9 +61,14 @@ TEST(IPCTest, MapFDByNamePreExecve) {
ASSERT_TRUE(comms.SendString("hello")); ASSERT_TRUE(comms.SendString("hello"));
std::string resp; std::string resp;
ASSERT_TRUE(s2.comms()->RecvString(&resp));
ASSERT_EQ(resp, "start");
ASSERT_TRUE(s2.comms()->SendString("started"));
ASSERT_TRUE(comms.RecvString(&resp)); ASSERT_TRUE(comms.RecvString(&resp));
ASSERT_EQ(resp, "world"); ASSERT_EQ(resp, "world");
ASSERT_TRUE(s2.comms()->RecvString(&resp));
ASSERT_EQ(resp, "finish");
ASSERT_TRUE(s2.comms()->SendString("finished"));
auto result = s2.AwaitResult(); auto result = s2.AwaitResult();
@ -69,13 +78,14 @@ TEST(IPCTest, MapFDByNamePreExecve) {
// This test verifies that mapping fds by name works if SandboxMeHere() is // This test verifies that mapping fds by name works if SandboxMeHere() is
// called by the sandboxee. // called by the sandboxee.
TEST(IPCTest, MapFDByNamePostExecve) { TEST_P(IPCTest, MapFDByNamePostExecve) {
SKIP_SANITIZERS_AND_COVERAGE; SKIP_SANITIZERS_AND_COVERAGE;
const int fd = GetParam();
const std::string path = GetTestSourcePath("sandbox2/testcases/ipc"); const std::string path = GetTestSourcePath("sandbox2/testcases/ipc");
std::vector<std::string> args = {path, "2", std::to_string(kPreferredIpcFd)}; std::vector<std::string> args = {path, "2", std::to_string(fd)};
auto executor = absl::make_unique<Executor>(path, args); auto executor = absl::make_unique<Executor>(path, args);
executor->set_enable_sandbox_before_exec(false); executor->set_enable_sandbox_before_exec(false);
Comms comms(executor->ipc()->ReceiveFd(kPreferredIpcFd, "ipc_test")); Comms comms(executor->ipc()->ReceiveFd(fd, "ipc_test"));
SAPI_ASSERT_OK_AND_ASSIGN(auto policy, SAPI_ASSERT_OK_AND_ASSIGN(auto policy,
PolicyBuilder() PolicyBuilder()
@ -89,9 +99,14 @@ TEST(IPCTest, MapFDByNamePostExecve) {
ASSERT_TRUE(comms.SendString("hello")); ASSERT_TRUE(comms.SendString("hello"));
std::string resp; std::string resp;
ASSERT_TRUE(s2.comms()->RecvString(&resp));
ASSERT_EQ(resp, "start");
ASSERT_TRUE(s2.comms()->SendString("started"));
ASSERT_TRUE(comms.RecvString(&resp)); ASSERT_TRUE(comms.RecvString(&resp));
ASSERT_EQ(resp, "world"); ASSERT_EQ(resp, "world");
ASSERT_TRUE(s2.comms()->RecvString(&resp));
ASSERT_EQ(resp, "finish");
ASSERT_TRUE(s2.comms()->SendString("finished"));
auto result = s2.AwaitResult(); auto result = s2.AwaitResult();
@ -119,5 +134,11 @@ TEST(IPCTest, NoMappedFDsPreExecve) {
ASSERT_EQ(result.reason_code(), 0); ASSERT_EQ(result.reason_code(), 0);
} }
INSTANTIATE_TEST_SUITE_P(NormalFds, IPCTest, testing::Values(kPreferredIpcFd));
INSTANTIATE_TEST_SUITE_P(RestrictedFds, IPCTest,
testing::Values(Comms::kSandbox2ClientCommsFD,
Comms::kSandbox2TargetExecFD));
} // namespace } // namespace
} // namespace sandbox2 } // namespace sandbox2

View File

@ -62,9 +62,28 @@ int main(int argc, char* argv[]) {
return EXIT_FAILURE; return EXIT_FAILURE;
} }
sandbox2::Comms comms(fd); sandbox2::Comms comms(fd);
std::string hello; std::string resp;
if (!comms.RecvString(&hello)) { if (!default_comms.SendString("start")) {
fputs("error on comms.RecvString(&hello)", stderr); fputs("error on default_comms.RecvString(\"start\")", stderr);
return EXIT_FAILURE;
}
if (!default_comms.RecvString(&resp)) {
fputs("error on default_comms.RecvString(&resp)", stderr);
return EXIT_FAILURE;
}
if (resp != "started") {
fprintf(stderr, "unexpected response \"%s\" (expected \"started\")\n",
resp.c_str());
return EXIT_FAILURE;
}
if (!comms.RecvString(&resp)) {
fputs("error on comms.RecvString(&resp)", stderr);
return EXIT_FAILURE;
}
if (resp != "hello") {
fprintf(stderr, "unexpected response \"%s\" (expected \"hello\")\n",
resp.c_str());
return EXIT_FAILURE; return EXIT_FAILURE;
} }
@ -72,6 +91,19 @@ int main(int argc, char* argv[]) {
fputs("error on comms.SendString(\"world\")", stderr); fputs("error on comms.SendString(\"world\")", stderr);
return EXIT_FAILURE; return EXIT_FAILURE;
} }
if (!default_comms.SendString("finish")) {
fputs("error on default_comms.RecvString(\"finish\")", stderr);
return EXIT_FAILURE;
}
if (!default_comms.RecvString(&resp)) {
fputs("error on default_comms.RecvString(&resp)", stderr);
return EXIT_FAILURE;
}
if (resp != "finished") {
fprintf(stderr, "unexpected response \"%s\" (expected \"finished\")\n",
resp.c_str());
return EXIT_FAILURE;
}
printf("OK: All tests went OK\n"); printf("OK: All tests went OK\n");
return EXIT_SUCCESS; return EXIT_SUCCESS;