diff --git a/sandboxed_api/sandbox2/BUILD.bazel b/sandboxed_api/sandbox2/BUILD.bazel index 3141262..3e3671a 100644 --- a/sandboxed_api/sandbox2/BUILD.bazel +++ b/sandboxed_api/sandbox2/BUILD.bazel @@ -233,8 +233,10 @@ cc_library( "//sandboxed_api/sandbox2/util:strerror", "//sandboxed_api/util:flags", "//sandboxed_api/util:raw_logging", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_glog//:glog", ], ) diff --git a/sandboxed_api/sandbox2/CMakeLists.txt b/sandboxed_api/sandbox2/CMakeLists.txt index 7bfc02d..1bdf207 100644 --- a/sandboxed_api/sandbox2/CMakeLists.txt +++ b/sandboxed_api/sandbox2/CMakeLists.txt @@ -242,7 +242,9 @@ target_link_libraries(sandbox2_global_forkserver sapi::embed_file sapi::flags sapi::raw_logging - PUBLIC sandbox2::comms + PUBLIC absl::core_headers + absl::synchronization + sandbox2::comms sandbox2::fork_client sandbox2::forkserver_proto ) diff --git a/sandboxed_api/sandbox2/global_forkclient.cc b/sandboxed_api/sandbox2/global_forkclient.cc index 9348a50..0373a6b 100644 --- a/sandboxed_api/sandbox2/global_forkclient.cc +++ b/sandboxed_api/sandbox2/global_forkclient.cc @@ -22,12 +22,20 @@ #include #include +#include #include #include +#include +#include #include #include "sandboxed_api/util/flag.h" #include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "sandboxed_api/embed_file.h" #include "sandboxed_api/sandbox2/comms.h" #include "sandboxed_api/sandbox2/forkserver_bin_embed.h" @@ -36,13 +44,110 @@ #include "sandboxed_api/sandbox2/util/strerror.h" #include "sandboxed_api/util/raw_logging.h" -ABSL_FLAG(bool, sandbox2_start_forkserver, true, - "Start Sandbox2 Forkserver process"); +namespace sandbox2 { +namespace { +enum class GlobalForkserverStartMode { + kOnDemand, + // MUST be the last element + kNumGlobalForkserverStartModes, +}; + +class GlobalForkserverStartModeSet { + public: + static constexpr size_t kSize = static_cast( + GlobalForkserverStartMode::kNumGlobalForkserverStartModes); + + GlobalForkserverStartModeSet() {} + explicit GlobalForkserverStartModeSet(GlobalForkserverStartMode value) { + value_[static_cast(value)] = true; + } + GlobalForkserverStartModeSet& operator|=(GlobalForkserverStartMode value) { + value_[static_cast(value)] = true; + return *this; + } + GlobalForkserverStartModeSet operator|( + GlobalForkserverStartMode value) const { + GlobalForkserverStartModeSet rv(*this); + rv |= value; + return rv; + } + bool contains(GlobalForkserverStartMode value) const { + return value_[static_cast(value)]; + } + bool empty() { return value_.none(); } + + private: + std::bitset value_; +}; + +bool AbslParseFlag(absl::string_view text, GlobalForkserverStartModeSet* out, + std::string* error) { + *out = {}; + if (text == "never") { + return true; + } + for (absl::string_view mode : absl::StrSplit(text, ',')) { + mode = absl::StripAsciiWhitespace(mode); + if (mode == "ondemand") { + *out |= GlobalForkserverStartMode::kOnDemand; + } else { + *error = absl::StrCat("Invalid forkserver start mode: ", mode); + return false; + } + } + return true; +} + +std::string ToString(GlobalForkserverStartMode mode) { + switch (mode) { + case GlobalForkserverStartMode::kOnDemand: + return "ondemand"; + default: + return "unknown"; + } +} + +std::string AbslUnparseFlag(GlobalForkserverStartModeSet in) { + std::vector str_modes; + for (size_t i = 0; i < GlobalForkserverStartModeSet::kSize; ++i) { + auto mode = static_cast(i); + if (in.contains(mode)) { + str_modes.push_back(ToString(mode)); + } + } + if (str_modes.empty()) { + return "never"; + } + return absl::StrJoin(str_modes, ","); +} +bool ValidateStartMode(const char*, const std::string& flag) { + GlobalForkserverStartModeSet unused; + std::string error; + if (!AbslParseFlag(flag, &unused, &error)) { + SAPI_RAW_LOG(ERROR, "%s", error); + return false; + } + return true; +} +} // namespace +} // namespace sandbox2 + +ABSL_FLAG(string, sandbox2_forkserver_start_mode, "ondemand", + "When Sandbox2 Forkserver process should be started"); +DEFINE_validator(sandbox2_forkserver_start_mode, &sandbox2::ValidateStartMode); namespace sandbox2 { namespace { +GlobalForkserverStartModeSet GetForkserverStartMode() { + GlobalForkserverStartModeSet rv; + std::string error; + CHECK(AbslParseFlag(absl::GetFlag(FLAGS_sandbox2_forkserver_start_mode), &rv, + &error)); + return rv; +} + std::unique_ptr StartGlobalForkServer() { if (getenv(kForkServerDisableEnv)) { SAPI_RAW_VLOG(1, @@ -52,7 +157,7 @@ std::unique_ptr StartGlobalForkServer() { return {}; } - if (!absl::GetFlag(FLAGS_sandbox2_start_forkserver)) { + if (GetForkserverStartMode().empty()) { SAPI_RAW_VLOG( 1, "Start of the Global Fork-Server prevented by commandline flag"); return {}; @@ -90,34 +195,49 @@ std::unique_ptr StartGlobalForkServer() { return absl::make_unique(sv[1], pid); } -GlobalForkClient* GetGlobalForkClient() { - static GlobalForkClient* global_fork_client = - StartGlobalForkServer().release(); - return global_fork_client; -} - } // namespace -void GlobalForkClient::EnsureStarted() { GetGlobalForkClient(); } +absl::Mutex GlobalForkClient::instance_mutex_(absl::kConstInit); +GlobalForkClient* GlobalForkClient::instance_ = nullptr; + +void GlobalForkClient::EnsureStarted() { + absl::MutexLock lock(&GlobalForkClient::instance_mutex_); + EnsureStartedLocked( + GetForkserverStartMode().contains(GlobalForkserverStartMode::kOnDemand)); +} + +void GlobalForkClient::Shutdown() { + absl::MutexLock lock(&GlobalForkClient::instance_mutex_); + delete instance_; + instance_ = nullptr; +} + +void GlobalForkClient::EnsureStartedLocked(bool start_if_needed) { + if (!instance_ && start_if_needed) { + instance_ = StartGlobalForkServer().release(); + } + SAPI_RAW_CHECK(instance_ != nullptr, "global fork client not initialized"); +} pid_t GlobalForkClient::SendRequest(const ForkRequest& request, int exec_fd, int comms_fd, int user_ns_fd, pid_t* init_pid) { - GlobalForkClient* global_fork_client = GetGlobalForkClient(); - SAPI_RAW_CHECK(global_fork_client != nullptr, - "global fork client not initialized"); - pid_t pid = global_fork_client->fork_client_.SendRequest( - request, exec_fd, comms_fd, user_ns_fd, init_pid); - if (global_fork_client->comms_.IsTerminated()) { + absl::MutexLock lock(&GlobalForkClient::instance_mutex_); + EnsureStartedLocked( + GetForkserverStartMode().contains(GlobalForkserverStartMode::kOnDemand)); + pid_t pid = instance_->fork_client_.SendRequest(request, exec_fd, comms_fd, + user_ns_fd, init_pid); + if (instance_->comms_.IsTerminated()) { LOG(ERROR) << "Global forkserver connection terminated"; } return pid; } pid_t GlobalForkClient::GetPid() { - GlobalForkClient* global_fork_client = GetGlobalForkClient(); - SAPI_RAW_CHECK(global_fork_client != nullptr, - "global fork client not initialized"); - return global_fork_client->fork_client_.pid(); + absl::MutexLock lock(&instance_mutex_); + EnsureStartedLocked( + GetForkserverStartMode().contains(GlobalForkserverStartMode::kOnDemand)); + SAPI_RAW_CHECK(instance_ != nullptr, "global fork client not initialized"); + return instance_->fork_client_.pid(); } } // namespace sandbox2 diff --git a/sandboxed_api/sandbox2/global_forkclient.h b/sandboxed_api/sandbox2/global_forkclient.h index b7f22b4..2aff3dd 100644 --- a/sandboxed_api/sandbox2/global_forkclient.h +++ b/sandboxed_api/sandbox2/global_forkclient.h @@ -20,6 +20,8 @@ #include +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" #include "sandboxed_api/sandbox2/comms.h" #include "sandboxed_api/sandbox2/fork_client.h" #include "sandboxed_api/sandbox2/forkserver.pb.h" @@ -33,12 +35,20 @@ class GlobalForkClient { static pid_t SendRequest(const ForkRequest& request, int exec_fd, int comms_fd, int user_ns_fd = -1, - pid_t* init_pid = nullptr); - static pid_t GetPid(); + pid_t* init_pid = nullptr) + ABSL_LOCKS_EXCLUDED(instance_mutex_); + static pid_t GetPid() ABSL_LOCKS_EXCLUDED(instance_mutex_); - static void EnsureStarted(); + static void EnsureStarted() ABSL_LOCKS_EXCLUDED(instance_mutex_); + static void Shutdown() ABSL_LOCKS_EXCLUDED(instance_mutex_); private: + static absl::Mutex instance_mutex_; + static GlobalForkClient* instance_ ABSL_GUARDED_BY(instance_mutex_); + + static void EnsureStartedLocked(bool start_if_needed = true) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(instance_mutex_); + Comms comms_; ForkClient fork_client_; };