Allow shutting down the global forkserver

PiperOrigin-RevId: 345198374
Change-Id: I3b5c49f6e5abb76d2b0a57078ffeb0609e0be008
This commit is contained in:
Wiktor Garbacz 2020-12-02 03:05:06 -08:00 committed by Copybara-Service
parent 6587e571f1
commit da64459e3f
4 changed files with 158 additions and 24 deletions

View File

@ -233,8 +233,10 @@ cc_library(
"//sandboxed_api/sandbox2/util:strerror", "//sandboxed_api/sandbox2/util:strerror",
"//sandboxed_api/util:flags", "//sandboxed_api/util:flags",
"//sandboxed_api/util:raw_logging", "//sandboxed_api/util:raw_logging",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_glog//:glog", "@com_google_glog//:glog",
], ],
) )

View File

@ -242,7 +242,9 @@ target_link_libraries(sandbox2_global_forkserver
sapi::embed_file sapi::embed_file
sapi::flags sapi::flags
sapi::raw_logging sapi::raw_logging
PUBLIC sandbox2::comms PUBLIC absl::core_headers
absl::synchronization
sandbox2::comms
sandbox2::fork_client sandbox2::fork_client
sandbox2::forkserver_proto sandbox2::forkserver_proto
) )

View File

@ -22,12 +22,20 @@
#include <syscall.h> #include <syscall.h>
#include <unistd.h> #include <unistd.h>
#include <bitset>
#include <csignal> #include <csignal>
#include <cstdlib> #include <cstdlib>
#include <string>
#include <vector>
#include <glog/logging.h> #include <glog/logging.h>
#include "sandboxed_api/util/flag.h" #include "sandboxed_api/util/flag.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "sandboxed_api/embed_file.h" #include "sandboxed_api/embed_file.h"
#include "sandboxed_api/sandbox2/comms.h" #include "sandboxed_api/sandbox2/comms.h"
#include "sandboxed_api/sandbox2/forkserver_bin_embed.h" #include "sandboxed_api/sandbox2/forkserver_bin_embed.h"
@ -36,13 +44,110 @@
#include "sandboxed_api/sandbox2/util/strerror.h" #include "sandboxed_api/sandbox2/util/strerror.h"
#include "sandboxed_api/util/raw_logging.h" #include "sandboxed_api/util/raw_logging.h"
ABSL_FLAG(bool, sandbox2_start_forkserver, true, namespace sandbox2 {
"Start Sandbox2 Forkserver process"); namespace {
enum class GlobalForkserverStartMode {
kOnDemand,
// MUST be the last element
kNumGlobalForkserverStartModes,
};
class GlobalForkserverStartModeSet {
public:
static constexpr size_t kSize = static_cast<size_t>(
GlobalForkserverStartMode::kNumGlobalForkserverStartModes);
GlobalForkserverStartModeSet() {}
explicit GlobalForkserverStartModeSet(GlobalForkserverStartMode value) {
value_[static_cast<size_t>(value)] = true;
}
GlobalForkserverStartModeSet& operator|=(GlobalForkserverStartMode value) {
value_[static_cast<size_t>(value)] = true;
return *this;
}
GlobalForkserverStartModeSet operator|(
GlobalForkserverStartMode value) const {
GlobalForkserverStartModeSet rv(*this);
rv |= value;
return rv;
}
bool contains(GlobalForkserverStartMode value) const {
return value_[static_cast<size_t>(value)];
}
bool empty() { return value_.none(); }
private:
std::bitset<kSize> value_;
};
bool AbslParseFlag(absl::string_view text, GlobalForkserverStartModeSet* out,
std::string* error) {
*out = {};
if (text == "never") {
return true;
}
for (absl::string_view mode : absl::StrSplit(text, ',')) {
mode = absl::StripAsciiWhitespace(mode);
if (mode == "ondemand") {
*out |= GlobalForkserverStartMode::kOnDemand;
} else {
*error = absl::StrCat("Invalid forkserver start mode: ", mode);
return false;
}
}
return true;
}
std::string ToString(GlobalForkserverStartMode mode) {
switch (mode) {
case GlobalForkserverStartMode::kOnDemand:
return "ondemand";
default:
return "unknown";
}
}
std::string AbslUnparseFlag(GlobalForkserverStartModeSet in) {
std::vector<std::string> str_modes;
for (size_t i = 0; i < GlobalForkserverStartModeSet::kSize; ++i) {
auto mode = static_cast<GlobalForkserverStartMode>(i);
if (in.contains(mode)) {
str_modes.push_back(ToString(mode));
}
}
if (str_modes.empty()) {
return "never";
}
return absl::StrJoin(str_modes, ",");
}
bool ValidateStartMode(const char*, const std::string& flag) {
GlobalForkserverStartModeSet unused;
std::string error;
if (!AbslParseFlag(flag, &unused, &error)) {
SAPI_RAW_LOG(ERROR, "%s", error);
return false;
}
return true;
}
} // namespace
} // namespace sandbox2
ABSL_FLAG(string, sandbox2_forkserver_start_mode, "ondemand",
"When Sandbox2 Forkserver process should be started");
DEFINE_validator(sandbox2_forkserver_start_mode, &sandbox2::ValidateStartMode);
namespace sandbox2 { namespace sandbox2 {
namespace { namespace {
GlobalForkserverStartModeSet GetForkserverStartMode() {
GlobalForkserverStartModeSet rv;
std::string error;
CHECK(AbslParseFlag(absl::GetFlag(FLAGS_sandbox2_forkserver_start_mode), &rv,
&error));
return rv;
}
std::unique_ptr<GlobalForkClient> StartGlobalForkServer() { std::unique_ptr<GlobalForkClient> StartGlobalForkServer() {
if (getenv(kForkServerDisableEnv)) { if (getenv(kForkServerDisableEnv)) {
SAPI_RAW_VLOG(1, SAPI_RAW_VLOG(1,
@ -52,7 +157,7 @@ std::unique_ptr<GlobalForkClient> StartGlobalForkServer() {
return {}; return {};
} }
if (!absl::GetFlag(FLAGS_sandbox2_start_forkserver)) { if (GetForkserverStartMode().empty()) {
SAPI_RAW_VLOG( SAPI_RAW_VLOG(
1, "Start of the Global Fork-Server prevented by commandline flag"); 1, "Start of the Global Fork-Server prevented by commandline flag");
return {}; return {};
@ -90,34 +195,49 @@ std::unique_ptr<GlobalForkClient> StartGlobalForkServer() {
return absl::make_unique<GlobalForkClient>(sv[1], pid); return absl::make_unique<GlobalForkClient>(sv[1], pid);
} }
GlobalForkClient* GetGlobalForkClient() {
static GlobalForkClient* global_fork_client =
StartGlobalForkServer().release();
return global_fork_client;
}
} // namespace } // namespace
void GlobalForkClient::EnsureStarted() { GetGlobalForkClient(); } absl::Mutex GlobalForkClient::instance_mutex_(absl::kConstInit);
GlobalForkClient* GlobalForkClient::instance_ = nullptr;
void GlobalForkClient::EnsureStarted() {
absl::MutexLock lock(&GlobalForkClient::instance_mutex_);
EnsureStartedLocked(
GetForkserverStartMode().contains(GlobalForkserverStartMode::kOnDemand));
}
void GlobalForkClient::Shutdown() {
absl::MutexLock lock(&GlobalForkClient::instance_mutex_);
delete instance_;
instance_ = nullptr;
}
void GlobalForkClient::EnsureStartedLocked(bool start_if_needed) {
if (!instance_ && start_if_needed) {
instance_ = StartGlobalForkServer().release();
}
SAPI_RAW_CHECK(instance_ != nullptr, "global fork client not initialized");
}
pid_t GlobalForkClient::SendRequest(const ForkRequest& request, int exec_fd, pid_t GlobalForkClient::SendRequest(const ForkRequest& request, int exec_fd,
int comms_fd, int user_ns_fd, int comms_fd, int user_ns_fd,
pid_t* init_pid) { pid_t* init_pid) {
GlobalForkClient* global_fork_client = GetGlobalForkClient(); absl::MutexLock lock(&GlobalForkClient::instance_mutex_);
SAPI_RAW_CHECK(global_fork_client != nullptr, EnsureStartedLocked(
"global fork client not initialized"); GetForkserverStartMode().contains(GlobalForkserverStartMode::kOnDemand));
pid_t pid = global_fork_client->fork_client_.SendRequest( pid_t pid = instance_->fork_client_.SendRequest(request, exec_fd, comms_fd,
request, exec_fd, comms_fd, user_ns_fd, init_pid); user_ns_fd, init_pid);
if (global_fork_client->comms_.IsTerminated()) { if (instance_->comms_.IsTerminated()) {
LOG(ERROR) << "Global forkserver connection terminated"; LOG(ERROR) << "Global forkserver connection terminated";
} }
return pid; return pid;
} }
pid_t GlobalForkClient::GetPid() { pid_t GlobalForkClient::GetPid() {
GlobalForkClient* global_fork_client = GetGlobalForkClient(); absl::MutexLock lock(&instance_mutex_);
SAPI_RAW_CHECK(global_fork_client != nullptr, EnsureStartedLocked(
"global fork client not initialized"); GetForkserverStartMode().contains(GlobalForkserverStartMode::kOnDemand));
return global_fork_client->fork_client_.pid(); SAPI_RAW_CHECK(instance_ != nullptr, "global fork client not initialized");
return instance_->fork_client_.pid();
} }
} // namespace sandbox2 } // namespace sandbox2

View File

@ -20,6 +20,8 @@
#include <sys/types.h> #include <sys/types.h>
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "sandboxed_api/sandbox2/comms.h" #include "sandboxed_api/sandbox2/comms.h"
#include "sandboxed_api/sandbox2/fork_client.h" #include "sandboxed_api/sandbox2/fork_client.h"
#include "sandboxed_api/sandbox2/forkserver.pb.h" #include "sandboxed_api/sandbox2/forkserver.pb.h"
@ -33,12 +35,20 @@ class GlobalForkClient {
static pid_t SendRequest(const ForkRequest& request, int exec_fd, static pid_t SendRequest(const ForkRequest& request, int exec_fd,
int comms_fd, int user_ns_fd = -1, int comms_fd, int user_ns_fd = -1,
pid_t* init_pid = nullptr); pid_t* init_pid = nullptr)
static pid_t GetPid(); ABSL_LOCKS_EXCLUDED(instance_mutex_);
static pid_t GetPid() ABSL_LOCKS_EXCLUDED(instance_mutex_);
static void EnsureStarted(); static void EnsureStarted() ABSL_LOCKS_EXCLUDED(instance_mutex_);
static void Shutdown() ABSL_LOCKS_EXCLUDED(instance_mutex_);
private: private:
static absl::Mutex instance_mutex_;
static GlobalForkClient* instance_ ABSL_GUARDED_BY(instance_mutex_);
static void EnsureStartedLocked(bool start_if_needed = true)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(instance_mutex_);
Comms comms_; Comms comms_;
ForkClient fork_client_; ForkClient fork_client_;
}; };