ZStandard: introduce a wrapper

The goal is to use a file descriptor as an input for ZStandard
library. Thanks to that we shouldn't send a chunk of memory over
expensive protocol.
reviewable/pr108/r6
Mariusz Zaborski 2022-02-02 15:45:58 -05:00
parent dc03c38df1
commit 5c154af744
9 changed files with 640 additions and 23 deletions

View File

@ -34,6 +34,8 @@ FetchContent_Declare(libzstd
FetchContent_MakeAvailable(libzstd)
set(libzstd_INCLUDE_DIR "${libzstd_SOURCE_DIR}/lib")
add_subdirectory(wrapper)
add_sapi_library(
sapi_zstd
@ -70,10 +72,16 @@ add_sapi_library(
ZSTD_getFrameContentSize
ZSTD_compress_fd
ZSTD_compressStream_fd
ZSTD_decompress_fd
ZSTD_decompressStream_fd
INPUTS
${libzstd_INCLUDE_DIR}/zstd.h
wrapper/wrapper_zstd.h
LIBRARY libzstd_static
LIBRARY wrapper_zstd
LIBRARY_NAME Zstd
NAMESPACE ""
)

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fcntl.h>
#include <unistd.h>
#include <cstdlib>
@ -25,10 +26,58 @@
#include "contrib/zstd/sandboxed.h"
#include "contrib/zstd/utils/utils_zstd.h"
ABSL_FLAG(bool, stream, false, "stream data to sandbox");
ABSL_FLAG(bool, decompress, false, "decompress");
ABSL_FLAG(bool, memory_mode, false, "in memory operations");
ABSL_FLAG(uint32_t, level, 0, "compression level");
absl::Status Stream(ZstdApi& api, std::string infile_s, std::string outfile_s) {
std::ifstream infile(infile_s, std::ios::binary);
if (!infile.is_open()) {
return absl::UnavailableError(absl::StrCat("Unable to open ", infile_s));
}
std::ofstream outfile(outfile_s, std::ios::binary);
if (!outfile.is_open()) {
return absl::UnavailableError(absl::StrCat("Unable to open ", outfile_s));
}
if (absl::GetFlag(FLAGS_memory_mode) && absl::GetFlag(FLAGS_decompress)) {
return DecompressInMemory(api, infile, outfile);
}
if (absl::GetFlag(FLAGS_memory_mode) && !absl::GetFlag(FLAGS_decompress)) {
return CompressInMemory(api, infile, outfile, absl::GetFlag(FLAGS_level));
}
if (!absl::GetFlag(FLAGS_memory_mode) && absl::GetFlag(FLAGS_decompress)) {
return DecompressStream(api, infile, outfile);
}
return CompressStream(api, infile, outfile, absl::GetFlag(FLAGS_level));
}
absl::Status FileDescriptor(ZstdApi& api, std::string infile_s,
std::string outfile_s) {
sapi::v::Fd infd(open(infile_s.c_str(), O_RDONLY));
if (infd.GetValue() < 0) {
return absl::UnavailableError(absl::StrCat("Unable to open ", infile_s));
}
sapi::v::Fd outfd(open(outfile_s.c_str(), O_WRONLY | O_CREAT));
if (outfd.GetValue() < 0) {
return absl::UnavailableError(absl::StrCat("Unable to open ", outfile_s));
}
if (absl::GetFlag(FLAGS_memory_mode) && absl::GetFlag(FLAGS_decompress)) {
return DecompressInMemoryFD(api, infd, outfd);
}
if (absl::GetFlag(FLAGS_memory_mode) && !absl::GetFlag(FLAGS_decompress)) {
return CompressInMemoryFD(api, infd, outfd, absl::GetFlag(FLAGS_level));
}
if (!absl::GetFlag(FLAGS_memory_mode) && absl::GetFlag(FLAGS_decompress)) {
return DecompressStreamFD(api, infd, outfd);
}
return CompressStreamFD(api, infd, outfd, absl::GetFlag(FLAGS_level));
}
int main(int argc, char* argv[]) {
std::string prog_name(argv[0]);
google::InitGoogleLogging(argv[0]);
@ -39,17 +88,6 @@ int main(int argc, char* argv[]) {
return EXIT_FAILURE;
}
std::ifstream infile(args[1], std::ios::binary);
if (!infile.is_open()) {
std::cerr << "Unable to open " << args[1] << std::endl;
return EXIT_FAILURE;
}
std::ofstream outfile(args[2], std::ios::binary);
if (!outfile.is_open()) {
std::cerr << "Unable to open " << args[2] << std::endl;
return EXIT_FAILURE;
}
ZstdSapiSandbox sandbox;
if (!sandbox.Init().ok()) {
std::cerr << "Unable to start sandbox\n";
@ -59,16 +97,10 @@ int main(int argc, char* argv[]) {
ZstdApi api(&sandbox);
absl::Status status;
if (absl::GetFlag(FLAGS_memory_mode) && absl::GetFlag(FLAGS_decompress)) {
status = DecompressInMemory(api, infile, outfile);
} else if (absl::GetFlag(FLAGS_memory_mode) &&
!absl::GetFlag(FLAGS_decompress)) {
status = CompressInMemory(api, infile, outfile, absl::GetFlag(FLAGS_level));
} else if (!absl::GetFlag(FLAGS_memory_mode) &&
absl::GetFlag(FLAGS_decompress)) {
status = DecompressStream(api, infile, outfile);
if (absl::GetFlag(FLAGS_stream)) {
status = Stream(api, argv[1], argv[2]);
} else {
status = CompressStream(api, infile, outfile, absl::GetFlag(FLAGS_level));
status = FileDescriptor(api, argv[1], argv[2]);
}
if (!status.ok()) {

View File

@ -27,10 +27,14 @@ class ZstdSapiSandbox : public ZstdSandbox {
std::unique_ptr<sandbox2::Policy> ModifyPolicy(
sandbox2::PolicyBuilder*) override {
return sandbox2::PolicyBuilder()
.AllowDynamicStartup()
.AllowRead()
.AllowWrite()
.AllowSystemMalloc()
.AllowExit()
.AllowSyscalls({
__NR_recvmsg
})
.BuildOrDie();
}
};

View File

@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fcntl.h>
#include <unistd.h>
#include <fstream>
#include <string>
@ -278,8 +281,239 @@ TEST(SandboxTest, CheckCompressAndDecompressStream) {
ASSERT_TRUE(outfile.is_open());
status = DecompressStream(api, inmiddle, outfile);
ASSERT_THAT(status, IsOk()) << "Unable to decompress";
ASSERT_TRUE(CompareFiles(infile_s, outfile_s));
}
TEST(SandboxTest, CheckCompressInMemoryFD) {
ZstdSapiSandbox sandbox;
ASSERT_THAT(sandbox.Init(), IsOk()) << "Couldn't initialize Sandboxed API";
ZstdApi api = ZstdApi(&sandbox);
std::string infile_s = GetTestFilePath("text");
absl::StatusOr<std::string> path =
sapi::CreateNamedTempFileAndClose("out.zstd");
ASSERT_THAT(path, IsOk()) << "Could not create temp output file";
std::string outfile_s =
sapi::file::JoinPath(sapi::file_util::fileops::GetCWD(), *path);
sapi::v::Fd infd(open(infile_s.c_str(), O_RDONLY));
ASSERT_GE(infd.GetValue(), 0);
sapi::v::Fd outfd(open(outfile_s.c_str(), O_WRONLY));
ASSERT_GE(outfd.GetValue(), 0);
absl::Status status = CompressInMemoryFD(api, infd, outfd, 0);
ASSERT_THAT(status, IsOk()) << "Unable to compress file in memory";
off_t inpos = lseek(infd.GetValue(), 0, SEEK_END);
EXPECT_GE(inpos, 0);
off_t outpos = lseek(outfd.GetValue(), 0, SEEK_END);
EXPECT_GE(outpos, 0);
EXPECT_LT(outpos, inpos);
}
TEST(SandboxTest, CheckDecompressInMemoryFD) {
ZstdSapiSandbox sandbox;
ASSERT_THAT(sandbox.Init(), IsOk()) << "Couldn't initialize Sandboxed API";
ZstdApi api = ZstdApi(&sandbox);
std::string infile_s = GetTestFilePath("text.blob.zstd");
sapi::v::Fd infd(open(infile_s.c_str(), O_RDONLY));
ASSERT_GE(infd.GetValue(), 0);
absl::StatusOr<std::string> path = sapi::CreateNamedTempFileAndClose("out");
ASSERT_THAT(path, IsOk()) << "Could not create temp output file";
std::string outfile_s =
sapi::file::JoinPath(sapi::file_util::fileops::GetCWD(), *path);
sapi::v::Fd outfd(open(outfile_s.c_str(), O_WRONLY));
ASSERT_GE(outfd.GetValue(), 0);
absl::Status status = DecompressInMemoryFD(api, infd, outfd);
ASSERT_THAT(status, IsOk()) << "Unable to compress file in memory";
off_t inpos = lseek(infd.GetValue(), 0, SEEK_END);
EXPECT_GE(inpos, 0);
off_t outpos = lseek(outfd.GetValue(), 0, SEEK_END);
EXPECT_GE(outpos, 0);
EXPECT_GT(outpos, inpos);
ASSERT_TRUE(CompareFiles(GetTestFilePath("text"), outfile_s));
}
TEST(SandboxTest, CheckCompressAndDecompressInMemoryFD) {
ZstdSapiSandbox sandbox;
absl::Status status;
int ret;
ASSERT_THAT(sandbox.Init(), IsOk()) << "Couldn't initialize Sandboxed API";
ZstdApi api = ZstdApi(&sandbox);
std::string infile_s = GetTestFilePath("text");
absl::StatusOr<std::string> path_middle =
sapi::CreateNamedTempFileAndClose("middle.zstd");
ASSERT_THAT(path_middle, IsOk()) << "Could not create temp output file";
std::string middle_s =
sapi::file::JoinPath(sapi::file_util::fileops::GetCWD(), *path_middle);
absl::StatusOr<std::string> path = sapi::CreateNamedTempFileAndClose("out");
ASSERT_THAT(path, IsOk()) << "Could not create temp output file";
std::string outfile_s =
sapi::file::JoinPath(sapi::file_util::fileops::GetCWD(), *path);
sapi::v::Fd infd(open(infile_s.c_str(), O_RDONLY));
ASSERT_GE(infd.GetValue(), 0);
sapi::v::Fd outmiddlefd(open(middle_s.c_str(), O_WRONLY));
ASSERT_GE(outmiddlefd.GetValue(), 0);
status = CompressInMemoryFD(api, infd, outmiddlefd, 0);
ASSERT_THAT(status, IsOk()) << "Unable to compress file in memory";
off_t inpos = lseek(infd.GetValue(), 0, SEEK_END);
EXPECT_GE(inpos, 0);
off_t outpos = lseek(outmiddlefd.GetValue(), 0, SEEK_END);
EXPECT_GE(outpos, 0);
EXPECT_LT(outpos, inpos);
infd.CloseLocalFd();
outmiddlefd.CloseLocalFd();
sapi::v::Fd inmiddlefd(open(middle_s.c_str(), O_RDONLY));
ASSERT_GE(inmiddlefd.GetValue(), 0);
sapi::v::Fd outfd(open(outfile_s.c_str(), O_WRONLY));
ASSERT_GE(outfd.GetValue(), 0);
status = DecompressInMemoryFD(api, inmiddlefd, outfd);
ASSERT_THAT(status, IsOk()) << "Unable to decompress file in memory";
outfd.CloseLocalFd();
inmiddlefd.CloseLocalFd();
ASSERT_TRUE(CompareFiles(infile_s, outfile_s));
}
TEST(SandboxTest, CheckCompressStreamFD) {
absl::Status status;
ZstdSapiSandbox sandbox;
ASSERT_THAT(sandbox.Init(), IsOk()) << "Couldn't initialize Sandboxed API";
ZstdApi api = ZstdApi(&sandbox);
std::string infile_s = GetTestFilePath("text");
absl::StatusOr<std::string> path =
sapi::CreateNamedTempFileAndClose("out.zstd");
ASSERT_THAT(path, IsOk()) << "Could not create temp output file";
std::string outfile_s =
sapi::file::JoinPath(sapi::file_util::fileops::GetCWD(), *path);
sapi::v::Fd infd(open(infile_s.c_str(), O_RDONLY));
ASSERT_GE(infd.GetValue(), 0);
sapi::v::Fd outfd(open(outfile_s.c_str(), O_WRONLY));
ASSERT_GE(outfd.GetValue(), 0);
status = CompressStreamFD(api, infd, outfd, 0);
ASSERT_THAT(status, IsOk()) << "Unable to compress stream";
off_t inpos = lseek(infd.GetValue(), 0, SEEK_END);
EXPECT_GE(inpos, 0);
off_t outpos = lseek(outfd.GetValue(), 0, SEEK_END);
EXPECT_GE(outpos, 0);
EXPECT_LT(outpos, inpos);
}
TEST(SandboxTest, CheckDecompressStreamFD) {
absl::Status status;
ZstdSapiSandbox sandbox;
ASSERT_THAT(sandbox.Init(), IsOk()) << "Couldn't initialize Sandboxed API";
ZstdApi api = ZstdApi(&sandbox);
std::string infile_s = GetTestFilePath("text.stream.zstd");
absl::StatusOr<std::string> path = sapi::CreateNamedTempFileAndClose("out");
ASSERT_THAT(path, IsOk()) << "Could not create temp output file";
std::string outfile_s =
sapi::file::JoinPath(sapi::file_util::fileops::GetCWD(), *path);
sapi::v::Fd infd(open(infile_s.c_str(), O_RDONLY));
ASSERT_GE(infd.GetValue(), 0);
sapi::v::Fd outfd(open(outfile_s.c_str(), O_WRONLY));
ASSERT_GE(outfd.GetValue(), 0);
status = DecompressStreamFD(api, infd, outfd);
ASSERT_THAT(status, IsOk()) << "Unable to decompress stream";
off_t inpos = lseek(infd.GetValue(), 0, SEEK_END);
EXPECT_GE(inpos, 0);
off_t outpos = lseek(outfd.GetValue(), 0, SEEK_END);
EXPECT_GE(outpos, 0);
EXPECT_GT(outpos, inpos);
ASSERT_TRUE(CompareFiles(GetTestFilePath("text"), outfile_s));
}
TEST(SandboxTest, CheckCompressAndDecompressStreamFD) {
ZstdSapiSandbox sandbox;
absl::Status status;
int ret;
ASSERT_THAT(sandbox.Init(), IsOk()) << "Couldn't initialize Sandboxed API";
ZstdApi api = ZstdApi(&sandbox);
std::string infile_s = GetTestFilePath("text");
absl::StatusOr<std::string> path_middle =
sapi::CreateNamedTempFileAndClose("middle.zstd");
ASSERT_THAT(path_middle, IsOk()) << "Could not create temp output file";
std::string middle_s =
sapi::file::JoinPath(sapi::file_util::fileops::GetCWD(), *path_middle);
absl::StatusOr<std::string> path = sapi::CreateNamedTempFileAndClose("out");
ASSERT_THAT(path, IsOk()) << "Could not create temp output file";
std::string outfile_s =
sapi::file::JoinPath(sapi::file_util::fileops::GetCWD(), *path);
sapi::v::Fd infd(open(infile_s.c_str(), O_RDONLY));
ASSERT_GE(infd.GetValue(), 0);
sapi::v::Fd outmiddlefd(open(middle_s.c_str(), O_WRONLY));
ASSERT_GE(outmiddlefd.GetValue(), 0);
status = CompressStreamFD(api, infd, outmiddlefd, 0);
ASSERT_THAT(status, IsOk()) << "Unable to compress stream";
off_t inpos = lseek(infd.GetValue(), 0, SEEK_END);
EXPECT_GE(inpos, 0);
off_t outmiddlepos = lseek(outmiddlefd.GetValue(), 0, SEEK_END);
EXPECT_GE(outmiddlepos, 0);
EXPECT_LT(outmiddlepos, inpos);
infd.CloseLocalFd();
outmiddlefd.CloseLocalFd();
sapi::v::Fd inmiddlefd(open(middle_s.c_str(), O_RDONLY));
ASSERT_GE(inmiddlefd.GetValue(), 0);
sapi::v::Fd outfd(open(outfile_s.c_str(), O_WRONLY));
ASSERT_GE(outfd.GetValue(), 0);
status = DecompressStreamFD(api, inmiddlefd, outfd);
ASSERT_THAT(status, IsOk()) << "Unable to decompress stream";
ASSERT_TRUE(CompareFiles(infile_s, outfile_s));
}

View File

@ -110,7 +110,7 @@ absl::Status CompressStream(ZstdApi& api, std::ifstream& in_stream,
}
// Create Zstd context.
SAPI_ASSIGN_OR_RETURN(ZSTD_CCtx * cctx, api.ZSTD_createCCtx());
SAPI_ASSIGN_OR_RETURN(ZSTD_CCtx* cctx, api.ZSTD_createCCtx());
sapi::v::RemotePtr rcctx(cctx);
SAPI_ASSIGN_OR_RETURN(iserr, api.ZSTD_CCtx_setParameter(
@ -196,7 +196,7 @@ absl::Status DecompressStream(ZstdApi& api, std::ifstream& in_stream,
}
// Create Zstd context.
SAPI_ASSIGN_OR_RETURN(ZSTD_DCtx * dctx, api.ZSTD_createDCtx());
SAPI_ASSIGN_OR_RETURN(ZSTD_DCtx* dctx, api.ZSTD_createDCtx());
sapi::v::RemotePtr rdctx(dctx);
// Decompress.
@ -241,3 +241,95 @@ absl::Status DecompressStream(ZstdApi& api, std::ifstream& in_stream,
return absl::OkStatus();
}
absl::Status CompressInMemoryFD(ZstdApi& api, sapi::v::Fd& infd,
sapi::v::Fd& outfd, int level) {
SAPI_RETURN_IF_ERROR(api.GetSandbox()->TransferToSandboxee(&infd));
SAPI_RETURN_IF_ERROR(api.GetSandbox()->TransferToSandboxee(&outfd));
SAPI_ASSIGN_OR_RETURN(
int iserr,
api.ZSTD_compress_fd(infd.GetRemoteFd(), outfd.GetRemoteFd(), 0));
SAPI_ASSIGN_OR_RETURN(iserr, api.ZSTD_isError(iserr))
if (iserr) {
return absl::UnavailableError("Unable to compress file");
}
infd.CloseRemoteFd(api.GetSandbox()->rpc_channel()).IgnoreError();
outfd.CloseRemoteFd(api.GetSandbox()->rpc_channel()).IgnoreError();
return absl::OkStatus();
}
absl::Status DecompressInMemoryFD(ZstdApi& api, sapi::v::Fd& infd,
sapi::v::Fd& outfd) {
SAPI_RETURN_IF_ERROR(api.GetSandbox()->TransferToSandboxee(&infd));
SAPI_RETURN_IF_ERROR(api.GetSandbox()->TransferToSandboxee(&outfd));
SAPI_ASSIGN_OR_RETURN(int iserr, api.ZSTD_decompress_fd(infd.GetRemoteFd(),
outfd.GetRemoteFd()));
SAPI_ASSIGN_OR_RETURN(iserr, api.ZSTD_isError(iserr))
if (iserr) {
return absl::UnavailableError("Unable to compress file");
}
infd.CloseRemoteFd(api.GetSandbox()->rpc_channel()).IgnoreError();
outfd.CloseRemoteFd(api.GetSandbox()->rpc_channel()).IgnoreError();
return absl::OkStatus();
}
absl::Status CompressStreamFD(ZstdApi& api, sapi::v::Fd& infd,
sapi::v::Fd& outfd, int level) {
SAPI_ASSIGN_OR_RETURN(ZSTD_CCtx* cctx, api.ZSTD_createCCtx());
sapi::v::RemotePtr rcctx(cctx);
int iserr;
SAPI_ASSIGN_OR_RETURN(iserr, api.ZSTD_CCtx_setParameter(
&rcctx, ZSTD_c_compressionLevel, level));
SAPI_ASSIGN_OR_RETURN(iserr, api.ZSTD_isError(iserr));
if (iserr) {
return absl::UnavailableError("Unable to set parameter l");
}
SAPI_ASSIGN_OR_RETURN(
iserr, api.ZSTD_CCtx_setParameter(&rcctx, ZSTD_c_checksumFlag, 1));
SAPI_ASSIGN_OR_RETURN(iserr, api.ZSTD_isError(iserr));
if (iserr) {
return absl::UnavailableError("Unable to set parameter c");
}
SAPI_RETURN_IF_ERROR(api.GetSandbox()->TransferToSandboxee(&infd));
SAPI_RETURN_IF_ERROR(api.GetSandbox()->TransferToSandboxee(&outfd));
SAPI_ASSIGN_OR_RETURN(iserr,
api.ZSTD_compressStream_fd(&rcctx, infd.GetRemoteFd(),
outfd.GetRemoteFd()));
if (iserr) {
return absl::UnavailableError("Unable to compress");
}
infd.CloseRemoteFd(api.GetSandbox()->rpc_channel()).IgnoreError();
outfd.CloseRemoteFd(api.GetSandbox()->rpc_channel()).IgnoreError();
return absl::OkStatus();
}
absl::Status DecompressStreamFD(ZstdApi& api, sapi::v::Fd& infd,
sapi::v::Fd& outfd) {
SAPI_ASSIGN_OR_RETURN(ZSTD_DCtx* dctx, api.ZSTD_createDCtx());
sapi::v::RemotePtr rdctx(dctx);
SAPI_RETURN_IF_ERROR(api.GetSandbox()->TransferToSandboxee(&infd));
SAPI_RETURN_IF_ERROR(api.GetSandbox()->TransferToSandboxee(&outfd));
SAPI_ASSIGN_OR_RETURN(int iserr,
api.ZSTD_decompressStream_fd(&rdctx, infd.GetRemoteFd(),
outfd.GetRemoteFd()));
if (iserr) {
return absl::UnavailableError("Unable to decompress");
}
infd.CloseRemoteFd(api.GetSandbox()->rpc_channel()).IgnoreError();
outfd.CloseRemoteFd(api.GetSandbox()->rpc_channel()).IgnoreError();
return absl::OkStatus();
}

View File

@ -24,10 +24,18 @@ absl::Status CompressInMemory(ZstdApi& api, std::ifstream& in_stream,
std::ofstream& out_stream, int level);
absl::Status DecompressInMemory(ZstdApi& api, std::ifstream& in_stream,
std::ofstream& out_stream);
absl::Status CompressInMemoryFD(ZstdApi& api, sapi::v::Fd& infd,
sapi::v::Fd& outfd, int level);
absl::Status DecompressInMemoryFD(ZstdApi& api, sapi::v::Fd& infd,
sapi::v::Fd& outfd);
absl::Status CompressStream(ZstdApi& api, std::ifstream& in_stream,
std::ofstream& out_stream, int level);
absl::Status DecompressStream(ZstdApi& api, std::ifstream& in_stream,
std::ofstream& out_stream);
absl::Status CompressStreamFD(ZstdApi& api, sapi::v::Fd& infd,
sapi::v::Fd& outfd, int level);
absl::Status DecompressStreamFD(ZstdApi& api, sapi::v::Fd& infd,
sapi::v::Fd& outfd);
#endif // CONTRIB_ZSTD_UTILS_UTILS_ZSTD_H_

View File

@ -0,0 +1,27 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
add_library(
wrapper_zstd STATIC
wrapper_zstd.cc
)
target_link_libraries(wrapper_zstd PUBLIC
libzstd_static
)
target_include_directories(wrapper_zstd PUBLIC
${libzstd_INCLUDE_DIR}
)

View File

@ -0,0 +1,184 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "wrapper_zstd.h"
#include <errno.h>
#include <fcntl.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <memory>
#include "zstd.h"
static constexpr size_t kFileMaxSize = 1024 * 1024 * 1024; // 1GB
off_t FDGetSize(int fd) {
off_t size = lseek(fd, 0, SEEK_END);
if (size < 0) {
return -1;
}
if (lseek(fd, 0, SEEK_SET) < 0) {
return -1;
}
return size;
}
int ZSTD_compress_fd(int fdin, int fdout, int level) {
off_t sizein = FDGetSize(fdin);
if (sizein <= 0) {
return -1;
}
size_t sizeout = ZSTD_compressBound(sizein);
auto bufin = std::make_unique<int8_t[]>(sizein);
auto bufout = std::make_unique<int8_t[]>(sizeout);
if (read(fdin, bufin.get(), sizein) != sizein) {
return -1;
}
int retsize =
ZSTD_compress(bufout.get(), sizeout, bufin.get(), sizein, level);
if (ZSTD_isError(retsize)) {
return -1;
}
if (write(fdout, bufout.get(), retsize) != retsize) {
return -1;
}
return 0;
}
int ZSTD_compressStream_fd(ZSTD_CCtx* cctx, int fdin, int fdout) {
size_t sizein = ZSTD_CStreamInSize();
size_t sizeout = ZSTD_CStreamOutSize();
auto bufin = std::make_unique<int8_t[]>(sizein);
auto bufout = std::make_unique<int8_t[]>(sizeout);
ssize_t size;
while ((size = read(fdin, bufin.get(), sizein)) > 0) {
ZSTD_inBuffer_s struct_in;
struct_in.src = bufin.get();
struct_in.pos = 0;
struct_in.size = size;
ZSTD_EndDirective mode = ZSTD_e_continue;
if (size < sizein) {
mode = ZSTD_e_end;
}
bool isdone = false;
while (!isdone) {
ZSTD_outBuffer_s struct_out;
struct_out.dst = bufout.get();
struct_out.pos = 0;
struct_out.size = sizeout;
size_t remaining =
ZSTD_compressStream2(cctx, &struct_out, &struct_in, mode);
if (ZSTD_isError(remaining)) {
return -1;
}
if (write(fdout, bufout.get(), struct_out.pos) != struct_out.pos) {
return -1;
}
if (mode == ZSTD_e_continue) {
isdone = (struct_in.pos == size);
} else {
isdone = (remaining == 0);
}
}
}
if (size != 0) {
return -1;
}
return 0;
}
int ZSTD_decompress_fd(int fdin, int fdout) {
off_t sizein = FDGetSize(fdin);
if (sizein <= 0) {
return -1;
}
auto bufin = std::make_unique<int8_t[]>(sizein);
if (read(fdin, bufin.get(), sizein) != sizein) {
return -1;
}
size_t sizeout = ZSTD_getFrameContentSize(bufin.get(), sizein);
if (ZSTD_isError(sizeout) || sizeout > kFileMaxSize) {
return -1;
}
auto bufout = std::make_unique<int8_t[]>(sizeout);
size_t desize = ZSTD_decompress(bufout.get(), sizeout, bufin.get(), sizein);
if (ZSTD_isError(desize) || desize != sizeout) {
return -1;
}
if (write(fdout, bufout.get(), sizeout) != sizeout) {
return -1;
}
return 0;
}
int ZSTD_decompressStream_fd(ZSTD_DCtx* dctx, int fdin, int fdout) {
size_t sizein = ZSTD_CStreamInSize();
size_t sizeout = ZSTD_CStreamOutSize();
auto bufin = std::make_unique<int8_t[]>(sizein);
auto bufout = std::make_unique<int8_t[]>(sizeout);
ssize_t size;
while ((size = read(fdin, bufin.get(), sizein)) > 0) {
ZSTD_inBuffer_s struct_in;
struct_in.src = bufin.get();
struct_in.pos = 0;
struct_in.size = size;
while (struct_in.pos < size) {
ZSTD_outBuffer_s struct_out;
struct_out.dst = bufout.get();
struct_out.pos = 0;
struct_out.size = sizeout;
size_t ret = ZSTD_decompressStream(dctx, &struct_out, &struct_in);
if (ZSTD_isError(ret)) {
return -1;
}
if (write(fdout, bufout.get(), struct_out.pos) != struct_out.pos) {
return -1;
}
}
}
if (size != 0) {
return -1;
}
return 0;
}

View File

@ -0,0 +1,28 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CONTRIB_ZSTD_WRAPPER_WRAPPER_ZSTD_H_
#define CONTRIB_ZSTD_WRAPPER_WRAPPER_ZSTD_H_
#include "zstd.h"
extern "C" {
int ZSTD_compress_fd(int fdin, int fdout, int level);
int ZSTD_compressStream_fd(ZSTD_CCtx* cctx, int fdin, int fdout);
int ZSTD_decompress_fd(int fdin, int fdout);
int ZSTD_decompressStream_fd(ZSTD_DCtx* dctx, int fdin, int fdout);
};
#endif // CONTRIB_ZSTD_WRAPPER_WRAPPER_ZSTD_H_