This commit is contained in:
Thomas Fussell 2017-04-28 17:56:05 -04:00
parent be37df3c86
commit f5b5d67594
7 changed files with 396 additions and 288 deletions

View File

@ -293,5 +293,32 @@ std::vector<byte> string_to_bytes(const std::basic_string<T> &string)
return bytes; return bytes;
} }
template<typename T>
T read(std::istream &in)
{
T result;
in.read(reinterpret_cast<char *>(&result), sizeof(T));
return result;
}
template<typename T>
std::vector<T> read_vector(std::istream &in, std::size_t count)
{
std::vector<T> result(count, T());
in.read(reinterpret_cast<char *>(&result[0]), sizeof(T) * count);
return result;
}
template<typename T>
std::basic_string<T> read_string(std::istream &in, std::size_t count)
{
std::basic_string<T> result(count, T());
in.read(reinterpret_cast<char *>(&result[0]), sizeof(T) * count);
return result;
}
} // namespace detail } // namespace detail
} // namespace xlnt } // namespace xlnt

View File

@ -76,77 +76,246 @@ const directory_id End = -1;
namespace xlnt { namespace xlnt {
namespace detail { namespace detail {
compound_document::compound_document(std::vector<std::uint8_t> &data) /// <summary>
: writer_(new binary_writer<byte>(data)), /// Allows a std::vector to be read through a std::istream.
reader_(new binary_reader<byte>(data)) /// </summary>
class compound_document_istreambuf : public std::streambuf
{ {
header_.msat.fill(FreeSector); using int_type = std::streambuf::int_type;
writer_->offset(0);
writer_->write(header_);
header_.directory_start = allocate_sector(); public:
writer_->offset(0); compound_document_istreambuf(const std::string &filename)
writer_->write(header_); : data_(filename.begin(), filename.end()),
position_(0)
insert_entry("Root Entry", compound_document_entry::entry_type::RootStorage);
}
compound_document::compound_document(const std::vector<std::uint8_t> &data)
: writer_(nullptr),
reader_(new binary_reader<byte>(data))
{
header_ = reader_->read<compound_document_header>();
// read msat
auto current = header_.extra_msat_start;
for (auto i = std::size_t(0); i < header_.num_msat_sectors; ++i)
{ {
if (i < 109) }
compound_document_istreambuf(const compound_document_istreambuf &) = delete;
compound_document_istreambuf &operator=(const compound_document_istreambuf &) = delete;
private:
int_type underflow()
{
if (position_ == data_.size())
{ {
msat_.push_back(header_.msat[i]); return traits_type::eof();
}
return traits_type::to_int_type(static_cast<char>(data_[position_]));
}
int_type uflow()
{
if (position_ == data_.size())
{
return traits_type::eof();
}
return traits_type::to_int_type(static_cast<char>(data_[position_++]));
}
std::streamsize showmanyc()
{
if (position_ == data_.size())
{
return static_cast<std::streamsize>(-1);
}
return static_cast<std::streamsize>(data_.size() - position_);
}
std::streampos seekoff(std::streamoff off, std::ios_base::seekdir way, std::ios_base::openmode)
{
if (way == std::ios_base::beg)
{
position_ = 0;
}
else if (way == std::ios_base::end)
{
position_ = data_.size();
}
if (off < 0)
{
if (static_cast<std::size_t>(-off) > position_)
{
position_ = 0;
return static_cast<std::ptrdiff_t>(-1);
}
else
{
position_ -= static_cast<std::size_t>(-off);
}
}
else if (off > 0)
{
if (static_cast<std::size_t>(off) + position_ > data_.size())
{
position_ = data_.size();
return static_cast<std::ptrdiff_t>(-1);
}
else
{
position_ += static_cast<std::size_t>(off);
}
}
return static_cast<std::ptrdiff_t>(position_);
}
std::streampos seekpos(std::streampos sp, std::ios_base::openmode)
{
if (sp < 0)
{
position_ = 0;
}
else if (static_cast<std::size_t>(sp) > data_.size())
{
position_ = data_.size();
} }
else else
{ {
auto extra_msat_sector = std::vector<sector_id>(); position_ = static_cast<std::size_t>(sp);
auto extra_msat_sector_writer = binary_writer<sector_id>(extra_msat_sector);
read_sector(current, extra_msat_sector_writer);
std::copy(extra_msat_sector.begin(),
extra_msat_sector.end() - 1,
std::back_inserter(msat_));
current = extra_msat_sector.back();
} }
}
return static_cast<std::ptrdiff_t>(position_);
for (auto sat_sector_id : msat_)
{
if (sat_sector_id < 0) continue;
auto sat_sector = std::vector<sector_id>();
auto sat_sector_writer = binary_writer<sector_id>(sat_sector);
read_sector(sat_sector_id, sat_sector_writer);
std::copy(sat_sector.begin(),
sat_sector.end(),
std::back_inserter(sat_));
} }
for (auto ssat_sector_id : follow_chain(header_.ssat_start, sat_)) private:
std::vector<std::uint8_t> data_;
std::size_t position_;
};
/// <summary>
/// Allows a std::vector to be written through a std::ostream.
/// </summary>
class compound_document_ostreambuf : public std::streambuf
{
using int_type = std::streambuf::int_type;
public:
compound_document_ostreambuf(const std::string &filename)
: data_(filename.begin(), filename.end()),
position_(0)
{ {
auto ssat_sector = std::vector<sector_id>();
auto ssat_sector_writer = binary_writer<sector_id>(ssat_sector);
read_sector(ssat_sector_id, ssat_sector_writer);
std::copy(ssat_sector.begin(),
ssat_sector.end(),
std::back_inserter(ssat_));
} }
tree_initialize_parent_maps(); compound_document_ostreambuf(const compound_document_ostreambuf &) = delete;
compound_document_ostreambuf &operator=(const compound_document_ostreambuf &) = delete;
private:
int_type overflow(int_type c = traits_type::eof())
{
if (c != traits_type::eof())
{
data_.push_back(static_cast<std::uint8_t>(c));
position_ = data_.size() - 1;
}
return traits_type::to_int_type(static_cast<char>(data_[position_]));
}
std::streamsize xsputn(const char *s, std::streamsize n)
{
if (data_.empty())
{
data_.resize(static_cast<std::size_t>(n));
}
else
{
auto position_size = data_.size();
auto required_size = static_cast<std::size_t>(position_ + static_cast<std::size_t>(n));
data_.resize(std::max(position_size, required_size));
}
std::copy(s, s + n, data_.begin() + static_cast<std::ptrdiff_t>(position_));
position_ += static_cast<std::size_t>(n);
return n;
}
std::streampos seekoff(std::streamoff off, std::ios_base::seekdir way, std::ios_base::openmode)
{
if (way == std::ios_base::beg)
{
position_ = 0;
}
else if (way == std::ios_base::end)
{
position_ = data_.size();
}
if (off < 0)
{
if (static_cast<std::size_t>(-off) > position_)
{
position_ = 0;
return static_cast<std::ptrdiff_t>(-1);
}
else
{
position_ -= static_cast<std::size_t>(-off);
}
}
else if (off > 0)
{
if (static_cast<std::size_t>(off) + position_ > data_.size())
{
position_ = data_.size();
return static_cast<std::ptrdiff_t>(-1);
}
else
{
position_ += static_cast<std::size_t>(off);
}
}
return static_cast<std::ptrdiff_t>(position_);
}
std::streampos seekpos(std::streampos sp, std::ios_base::openmode)
{
if (sp < 0)
{
position_ = 0;
}
else if (static_cast<std::size_t>(sp) > data_.size())
{
position_ = data_.size();
}
else
{
position_ = static_cast<std::size_t>(sp);
}
return static_cast<std::ptrdiff_t>(position_);
}
private:
std::vector<std::uint8_t> data_;
std::size_t position_;
};
compound_document::compound_document(std::ostream &out)
: out_(&out),
stream_in_(nullptr),
stream_out_(nullptr)
{
write_header();
insert_entry("Root Entry", compound_document_entry::entry_type::RootStorage);
}
compound_document::compound_document(std::istream &in)
: in_(&in),
stream_in_(nullptr),
stream_out_(nullptr)
{
read_header();
read_msat();
read_sat();
read_ssat();
read_directory();
} }
compound_document::~compound_document() compound_document::~compound_document()
@ -163,84 +332,36 @@ std::size_t compound_document::short_sector_size()
return static_cast<std::size_t>(1) << header_.short_sector_size_power; return static_cast<std::size_t>(1) << header_.short_sector_size_power;
} }
std::vector<byte> compound_document::read_stream(const std::string &name) std::istream &compound_document::open_read_stream(const std::string &name)
{ {
const auto entry_id = find_entry(name, compound_document_entry::entry_type::UserStream); const auto entry_id = find_entry(name, compound_document_entry::entry_type::UserStream);
const auto &entry = entries_.at(entry_id); const auto &entry = entries_.at(entry_id);
auto stream_data = std::vector<byte>(); stream_in_buffer_.reset(new compound_document_istreambuf(name));
auto stream_data_writer = binary_writer<byte>(stream_data); stream_in_.rdbuf(stream_in_buffer_.get());
if (entry.size < header_.threshold) return stream_in_;
{
for (auto sector : follow_chain(entry.start, ssat_))
{
read_short_sector<byte>(sector, stream_data_writer);
}
}
else
{
for (auto sector : follow_chain(entry.start, sat_))
{
read_sector<byte>(sector, stream_data_writer);
}
}
stream_data.resize(entry.size);
return stream_data;
} }
void compound_document::write_stream(const std::string &name, const std::vector<std::uint8_t> &data) std::ostream &compound_document::open_write_stream(const std::string &name)
{ {
auto entry_id = contains_entry(name, compound_document_entry::entry_type::UserStream) auto entry_id = contains_entry(name, compound_document_entry::entry_type::UserStream)
? find_entry(name, compound_document_entry::entry_type::UserStream) ? find_entry(name, compound_document_entry::entry_type::UserStream)
: insert_entry(name, compound_document_entry::entry_type::UserStream); : insert_entry(name, compound_document_entry::entry_type::UserStream);
auto &entry = entries_.at(entry_id); auto &entry = entries_.at(entry_id);
entry.size = static_cast<std::uint32_t>(data.size());
auto stream_data_reader = binary_reader<byte>(data); stream_out_buffer_.reset(new compound_document_ostreambuf(name));
stream_out_.rdbuf(stream_out_buffer_.get());
if (entry.size < header_.threshold) return stream_out_;
{
const auto num_sectors = data.size() / short_sector_size()
+ (data.size() % short_sector_size() ? 1 : 0);
auto chain = allocate_short_sectors(num_sectors);
entry.start = chain.front();
for (auto sector : follow_chain(entry.start, ssat_))
{
write_short_sector(stream_data_reader, sector);
}
}
else
{
const auto num_sectors = data.size() / short_sector_size()
+ (data.size() % short_sector_size() ? 1 : 0);
auto chain = allocate_sectors(num_sectors);
entry.start = chain.front();
for (auto sector : follow_chain(entry.start, sat_))
{
write_sector(stream_data_reader, sector);
}
}
auto directory_chain = follow_chain(header_.directory_start, sat_);
const auto entries_per_sector = sector_size() / sizeof(compound_document_entry);
auto entry_directory_sector = directory_chain[entry_id / entries_per_sector];
writer_->offset(sector_data_start()
+ entry_directory_sector * sector_size()
+ (entry_id % entries_per_sector) * sizeof(compound_document_entry));
writer_->write(entry);
} }
template<typename T> template<typename T>
void compound_document::write_sector(binary_reader<T> &reader, sector_id id) void compound_document::write_sector(binary_reader<T> &reader, sector_id id)
{ {
writer_->offset(sector_data_start() + sector_size() * id); out_->seekp(sector_data_start() + sector_size() * id);
writer_->append(reader, std::min(sector_size(), reader.bytes()) / sizeof(T)); out_->write(reinterpret_cast<const char *>(reader.data() + reader.offset()),
std::min(sector_size(), reader.bytes() - reader.offset()));
} }
template<typename T> template<typename T>
@ -249,15 +370,18 @@ void compound_document::write_short_sector(binary_reader<T> &reader, sector_id i
auto chain = follow_chain(entries_[0].start, sat_); auto chain = follow_chain(entries_[0].start, sat_);
auto sector_id = chain[id / (sector_size() / short_sector_size())]; auto sector_id = chain[id / (sector_size() / short_sector_size())];
auto sector_offset = id % (sector_size() / short_sector_size()) * short_sector_size(); auto sector_offset = id % (sector_size() / short_sector_size()) * short_sector_size();
writer_->offset(sector_data_start() + sector_size() * sector_id + sector_offset); out_->seekp(sector_data_start() + sector_size() * sector_id + sector_offset);
writer_->append(reader, std::min(short_sector_size() / sizeof(T), reader.count() - reader.offset())); out_->write(reinterpret_cast<const char *>(reader.data() + reader.offset()),
std::min(short_sector_size(), reader.bytes() - reader.offset()));
} }
template<typename T> template<typename T>
void compound_document::read_sector(sector_id id, binary_writer<T> &writer) void compound_document::read_sector(sector_id id, binary_writer<T> &writer)
{ {
reader_->offset(sector_data_start() + sector_size() * id); in_->seekg(sector_data_start() + sector_size() * id);
writer.append(*reader_, sector_size()); std::vector<byte> sector(sector_size(), 0);
in_->read(reinterpret_cast<char *>(sector.data()), sector_size());
writer.append(sector);
} }
template<typename T> template<typename T>
@ -300,12 +424,7 @@ sector_id compound_document::allocate_sector()
auto next_free = sector_id(next_free_iter - sat_.begin()); auto next_free = sector_id(next_free_iter - sat_.begin());
sat_[next_free] = EndOfChain; sat_[next_free] = EndOfChain;
auto next_free_msat_index = next_free / sectors_per_sector;; write_sat();
auto sat_index = msat_[next_free_msat_index];
writer_->offset(sector_data_start()
+ (sat_index * sector_size())
+ (next_free % sectors_per_sector) * sizeof(sector_id));
writer_->write(EndOfChain);
auto empty_sector = std::vector<byte>(sector_size()); auto empty_sector = std::vector<byte>(sector_size());
auto empty_sector_reader = binary_reader<byte>(empty_sector); auto empty_sector_reader = binary_reader<byte>(empty_sector);
@ -387,8 +506,7 @@ sector_id compound_document::allocate_short_sector()
sat_[ssat_chain.back()] = new_ssat_sector_id; sat_[ssat_chain.back()] = new_ssat_sector_id;
} }
writer_->offset(0); write_header();
writer_->write(header_);
auto old_size = ssat_.size(); auto old_size = ssat_.size();
ssat_.resize(old_size + sectors_per_sector, FreeSector); ssat_.resize(old_size + sectors_per_sector, FreeSector);
@ -402,14 +520,8 @@ sector_id compound_document::allocate_short_sector()
auto next_free = sector_id(next_free_iter - ssat_.begin()); auto next_free = sector_id(next_free_iter - ssat_.begin());
ssat_[next_free] = EndOfChain; ssat_[next_free] = EndOfChain;
auto sat_chain = follow_chain(header_.ssat_start, sat_); write_ssat();
auto next_free_sat_chain_index = next_free / sectors_per_sector;
auto sat_index = sat_chain[next_free_sat_chain_index];
writer_->offset(sector_data_start()
+ (sat_index * sectors_per_sector) * sizeof(sector_id)
+ (next_free % sectors_per_sector) * sizeof(sector_id));
writer_->write(EndOfChain);
const auto short_sectors_per_sector = sector_size() / short_sector_size(); const auto short_sectors_per_sector = sector_size() / short_sector_size();
const auto required_container_sectors = std::size_t(next_free / short_sectors_per_sector + 1); const auto required_container_sectors = std::size_t(next_free / short_sectors_per_sector + 1);
@ -419,8 +531,7 @@ sector_id compound_document::allocate_short_sector()
if (entries_[0].start < 0) if (entries_[0].start < 0)
{ {
entries_[0].start = allocate_sector(); entries_[0].start = allocate_sector();
writer_->offset(sector_data_start() + header_.directory_start * sector_size()); write_entry(0);
writer_->write(entries_[0]);
} }
auto container_chain = follow_chain(entries_[0].start, sat_); auto container_chain = follow_chain(entries_[0].start, sat_);
@ -428,6 +539,7 @@ sector_id compound_document::allocate_short_sector()
if (required_container_sectors > container_chain.size()) if (required_container_sectors > container_chain.size())
{ {
sat_[container_chain.back()] = allocate_sector(); sat_[container_chain.back()] = allocate_sector();
write_sat();
} }
} }
@ -454,17 +566,13 @@ directory_id compound_document::next_empty_entry()
/ sizeof(compound_document_entry); / sizeof(compound_document_entry);
auto new_sector = allocate_sector(); auto new_sector = allocate_sector();
// TODO: connect chains here // TODO: connect chains here
writer_->offset(sector_data_start() + new_sector * sector_size());
reader_->offset(sector_data_start() + new_sector * sector_size());
for (auto i = std::size_t(0); i < entries_per_sector; ++i) for (auto i = std::size_t(0); i < entries_per_sector; ++i)
{ {
auto empty_entry = compound_document_entry(); auto empty_entry = compound_document_entry();
empty_entry.type = compound_document_entry::entry_type::Empty; empty_entry.type = compound_document_entry::entry_type::Empty;
entries_.push_back(empty_entry);
writer_->write(empty_entry); write_entry(entry_id + directory_id(i));
entries_.push_back(reader_->read<compound_document_entry>());
} }
return entry_id; return entry_id;
@ -480,14 +588,7 @@ directory_id compound_document::insert_entry(
entry.name(name); entry.name(name);
entry.type = type; entry.type = type;
//TODO: move this to a "write_entry" function write_entry(entry_id);
auto directory_chain = follow_chain(header_.directory_start, sat_);
const auto entries_per_sector = sector_size() / sizeof(compound_document_entry);
auto entry_directory_sector = directory_chain[entry_id / entries_per_sector];
writer_->offset(sector_data_start()
+ entry_directory_sector * sector_size()
+ (entry_id % entries_per_sector) * sizeof(compound_document_entry));
writer_->write(entry);
// TODO: parse path from name and use correct parent storage instead of 0 // TODO: parse path from name and use correct parent storage instead of 0
tree_insert(entry_id, 0); tree_insert(entry_id, 0);
@ -542,18 +643,17 @@ void compound_document::print_directory()
} }
} }
void compound_document::tree_initialize_parent_maps() void compound_document::read_directory()
{ {
const auto entries_per_sector = sector_size() / sizeof(compound_document_entry); const auto entries_per_sector = sector_size() / sizeof(compound_document_entry);
auto entry_id = directory_id(0); auto entry_id = directory_id(0);
for (auto sector : follow_chain(header_.directory_start, sat_)) for (auto sector : follow_chain(header_.directory_start, sat_))
{ {
reader_->offset(sector_data_start() + sector * sector_size());
for (auto i = std::size_t(0); i < entries_per_sector; ++i) for (auto i = std::size_t(0); i < entries_per_sector; ++i)
{ {
entries_.push_back(reader_->read<compound_document_entry>()); entries_.push_back(compound_document_entry());
read_entry(entry_id++);
} }
} }
@ -843,8 +943,8 @@ compound_document_entry::entry_color &compound_document::tree_color(directory_id
void compound_document::read_header() void compound_document::read_header()
{ {
reader_->offset(0); in_->seekg(0, std::ios::beg);
header_ = reader_->read<compound_document_header>(); in_->read(reinterpret_cast<char *>(&header_), sizeof(compound_document_header));
} }
void compound_document::read_msat() void compound_document::read_msat()
@ -905,14 +1005,14 @@ void compound_document::read_entry(directory_id id)
const auto offset = sector_size() * directory_sector const auto offset = sector_size() * directory_sector
+ ((id % entries_per_sector) * sizeof(compound_document_entry)); + ((id % entries_per_sector) * sizeof(compound_document_entry));
reader_->offset(offset); in_->seekg(sector_data_start() + offset, std::ios::beg);
entries_[id] = reader_->read<compound_document_entry>(); in_->read(reinterpret_cast<char *>(&entries_[id]), sizeof(compound_document_entry));
} }
void compound_document::write_header() void compound_document::write_header()
{ {
writer_->offset(0); out_->seekp(0, std::ios::beg);
writer_->write<compound_document_header>(header_); out_->write(reinterpret_cast<char *>(&header_), sizeof(compound_document_header));
} }
void compound_document::write_msat() void compound_document::write_msat()
@ -968,8 +1068,8 @@ void compound_document::write_entry(directory_id id)
const auto offset = sector_size() * directory_sector const auto offset = sector_size() * directory_sector
+ ((id % entries_per_sector) * sizeof(compound_document_entry)); + ((id % entries_per_sector) * sizeof(compound_document_entry));
writer_->offset(offset); out_->seekp(offset, std::ios::beg);
writer_->write<compound_document_entry>(entries_[id]); out_->write(reinterpret_cast<char *>(&entries_[id]), sizeof(compound_document_entry));
} }
} // namespace detail } // namespace detail

View File

@ -113,15 +113,20 @@ struct compound_document_entry
std::uint32_t ignore2; std::uint32_t ignore2;
}; };
class compound_document_istreambuf;
class compound_document_ostreambuf;
class compound_document class compound_document
{ {
public: public:
compound_document(std::vector<std::uint8_t> &data); compound_document(std::istream &in);
compound_document(const std::vector<std::uint8_t> &data); compound_document(std::ostream &out);
~compound_document(); ~compound_document();
std::vector<std::uint8_t> read_stream(const std::string &filename); void close();
void write_stream(const std::string &filename, const std::vector<std::uint8_t> &data);
std::istream &open_read_stream(const std::string &filename);
std::ostream &open_write_stream(const std::string &filename);
private: private:
template<typename T> template<typename T>
@ -139,6 +144,7 @@ private:
void read_sat(); void read_sat();
void read_ssat(); void read_ssat();
void read_entry(directory_id id); void read_entry(directory_id id);
void read_directory();
void write_header(); void write_header();
void write_msat(); void write_msat();
@ -172,7 +178,6 @@ private:
compound_document_entry::entry_type type); compound_document_entry::entry_type type);
// Red black tree helper functions // Red black tree helper functions
void tree_initialize_parent_maps();
void tree_insert(directory_id new_id, directory_id storage_id); void tree_insert(directory_id new_id, directory_id storage_id);
void tree_insert_fixup(directory_id x); void tree_insert_fixup(directory_id x);
std::string tree_path(directory_id id); std::string tree_path(directory_id id);
@ -186,9 +191,6 @@ private:
std::string tree_key(directory_id id); std::string tree_key(directory_id id);
compound_document_entry::entry_color &tree_color(directory_id id); compound_document_entry::entry_color &tree_color(directory_id id);
std::unique_ptr<binary_writer<byte>> writer_;
std::unique_ptr<binary_reader<byte>> reader_;
compound_document_header header_; compound_document_header header_;
sector_chain msat_; sector_chain msat_;
sector_chain sat_; sector_chain sat_;
@ -197,6 +199,14 @@ private:
std::unordered_map<directory_id, directory_id> parent_storage_; std::unordered_map<directory_id, directory_id> parent_storage_;
std::unordered_map<directory_id, directory_id> parent_; std::unordered_map<directory_id, directory_id> parent_;
std::istream *in_;
std::ostream *out_;
std::unique_ptr<compound_document_istreambuf> stream_in_buffer_;
std::istream stream_in_;
std::unique_ptr<compound_document_ostreambuf> stream_out_buffer_;
std::ostream stream_out_;
}; };
} // namespace detail } // namespace detail

View File

@ -42,18 +42,25 @@
namespace { namespace {
using xlnt::detail::byte; using xlnt::detail::byte;
using xlnt::detail::binary_reader; using xlnt::detail::read;
using xlnt::detail::encryption_info; using xlnt::detail::encryption_info;
std::vector<std::uint8_t> decrypt_xlsx_standard( std::vector<std::uint8_t> decrypt_xlsx_standard(
encryption_info info, encryption_info info,
const std::vector<std::uint8_t> &encrypted_package) std::istream &encrypted_package_stream)
{ {
const auto key = info.calculate_key(); const auto key = info.calculate_key();
auto reader = binary_reader<byte>(encrypted_package); auto encrypted_package = std::vector<byte>(
auto decrypted_size = reader.read<std::uint64_t>(); std::istreambuf_iterator<char>(encrypted_package_stream),
auto decrypted = xlnt::detail::aes_ecb_decrypt(encrypted_package, key, reader.offset()); std::istreambuf_iterator<char>());
auto decrypted_size = read<std::uint64_t>(encrypted_package_stream);
auto decrypted = xlnt::detail::aes_ecb_decrypt(
encrypted_package,
key,
encrypted_package_stream.tellg());
decrypted.resize(static_cast<std::size_t>(decrypted_size)); decrypted.resize(static_cast<std::size_t>(decrypted_size));
return decrypted; return decrypted;
@ -61,10 +68,8 @@ std::vector<std::uint8_t> decrypt_xlsx_standard(
std::vector<std::uint8_t> decrypt_xlsx_agile( std::vector<std::uint8_t> decrypt_xlsx_agile(
const encryption_info &info, const encryption_info &info,
const std::vector<std::uint8_t> &encrypted_package) std::istream &encrypted_package_stream)
{ {
static const auto segment_length = std::size_t(4096);
const auto key = info.calculate_key(); const auto key = info.calculate_key();
auto salt_size = info.agile.key_data.salt_size; auto salt_size = info.agile.key_data.salt_size;
@ -72,23 +77,20 @@ std::vector<std::uint8_t> decrypt_xlsx_agile(
salt_with_block_key.resize(salt_size + sizeof(std::uint32_t), 0); salt_with_block_key.resize(salt_size + sizeof(std::uint32_t), 0);
auto &segment = *reinterpret_cast<std::uint32_t *>(salt_with_block_key.data() + salt_size); auto &segment = *reinterpret_cast<std::uint32_t *>(salt_with_block_key.data() + salt_size);
auto total_size = static_cast<std::size_t>(*reinterpret_cast<const std::uint64_t *>(encrypted_package.data())); auto total_size = read<std::uint64_t>(encrypted_package_stream);
std::vector<std::uint8_t> encrypted_segment(segment_length, 0); std::vector<std::uint8_t> encrypted_segment(4096, 0);
std::vector<std::uint8_t> decrypted_package; std::vector<std::uint8_t> decrypted_package;
decrypted_package.reserve(encrypted_package.size() - 8);
for (std::size_t i = 8; i < encrypted_package.size(); i += segment_length) while (encrypted_package_stream)
{ {
auto iv = hash(info.agile.key_encryptor.hash, salt_with_block_key); auto iv = hash(info.agile.key_encryptor.hash, salt_with_block_key);
iv.resize(16); iv.resize(16);
auto segment_begin = encrypted_package.begin() + static_cast<std::ptrdiff_t>(i); encrypted_package_stream.read(
auto current_segment_length = std::min(segment_length, encrypted_package.size() - i); reinterpret_cast<char *>(encrypted_segment.data()),
auto segment_end = encrypted_package.begin() + static_cast<std::ptrdiff_t>(i + current_segment_length); encrypted_segment.size());
encrypted_segment.assign(segment_begin, segment_end);
auto decrypted_segment = xlnt::detail::aes_cbc_decrypt(encrypted_segment, key, iv); auto decrypted_segment = xlnt::detail::aes_cbc_decrypt(encrypted_segment, key, iv);
decrypted_segment.resize(current_segment_length);
decrypted_package.insert( decrypted_package.insert(
decrypted_package.end(), decrypted_package.end(),
@ -103,21 +105,15 @@ std::vector<std::uint8_t> decrypt_xlsx_agile(
return decrypted_package; return decrypted_package;
} }
encryption_info::standard_encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &info_bytes) encryption_info::standard_encryption_info read_standard_encryption_info(std::istream &info_stream)
{ {
encryption_info::standard_encryption_info result; encryption_info::standard_encryption_info result;
auto reader = binary_reader<byte>(info_bytes); auto header_length = read<std::uint32_t>(info_stream);
auto index_at_start = info_stream.tellg();
// skip version info /*auto skip_flags = */ read<std::uint32_t>(info_stream);
reader.read<std::uint32_t>(); /*auto size_extra = */ read<std::uint32_t>(info_stream);
reader.read<std::uint32_t>(); auto alg_id = read<std::uint32_t>(info_stream);
auto header_length = reader.read<std::uint32_t>();
auto index_at_start = reader.offset();
/*auto skip_flags = */ reader.read<std::uint32_t>();
/*auto size_extra = */ reader.read<std::uint32_t>();
auto alg_id = reader.read<std::uint32_t>();
if (alg_id == 0 || alg_id == 0x0000660E || alg_id == 0x0000660F || alg_id == 0x00006610) if (alg_id == 0 || alg_id == 0x0000660E || alg_id == 0x0000660F || alg_id == 0x00006610)
{ {
@ -128,68 +124,50 @@ encryption_info::standard_encryption_info read_standard_encryption_info(const st
throw xlnt::exception("invalid cipher algorithm"); throw xlnt::exception("invalid cipher algorithm");
} }
auto alg_id_hash = reader.read<std::uint32_t>(); auto alg_id_hash = read<std::uint32_t>(info_stream);
if (alg_id_hash != 0x00008004 && alg_id_hash == 0) if (alg_id_hash != 0x00008004 && alg_id_hash == 0)
{ {
throw xlnt::exception("invalid hash algorithm"); throw xlnt::exception("invalid hash algorithm");
} }
result.key_bits = reader.read<std::uint32_t>(); result.key_bits = read<std::uint32_t>(info_stream);
result.key_bytes = result.key_bits / 8; result.key_bytes = result.key_bits / 8;
auto provider_type = reader.read<std::uint32_t>(); auto provider_type = read<std::uint32_t>(info_stream);
if (provider_type != 0 && provider_type != 0x00000018) if (provider_type != 0 && provider_type != 0x00000018)
{ {
throw xlnt::exception("invalid provider type"); throw xlnt::exception("invalid provider type");
} }
reader.read<std::uint32_t>(); // reserved 1 read<std::uint32_t>(info_stream); // reserved 1
if (reader.read<std::uint32_t>() != 0) // reserved 2 if (read<std::uint32_t>(info_stream) != 0) // reserved 2
{ {
throw xlnt::exception("invalid header"); throw xlnt::exception("invalid header");
} }
const auto csp_name_length = header_length - (reader.offset() - index_at_start); const auto csp_name_length = header_length - (info_stream.tellg() - index_at_start);
std::vector<std::uint16_t> csp_name_wide( auto csp_name = xlnt::detail::read_string<char16_t>(info_stream, csp_name_length);
reinterpret_cast<const std::uint16_t *>(&*(info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()))), if (csp_name != u"Microsoft Enhanced RSA and AES Cryptographic Provider (Prototype)"
reinterpret_cast<const std::uint16_t *>( && csp_name != u"Microsoft Enhanced RSA and AES Cryptographic Provider")
&*(info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + csp_name_length))));
std::string csp_name(csp_name_wide.begin(), csp_name_wide.end() - 1); // without trailing null
if (csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider (Prototype)"
&& csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider")
{ {
throw xlnt::exception("invalid cryptographic provider"); throw xlnt::exception("invalid cryptographic provider");
} }
reader.offset(reader.offset() + csp_name_length); info_stream.seekg(csp_name_length);
const auto salt_size = reader.read<std::uint32_t>(); const auto salt_size = read<std::uint32_t>(info_stream);
result.salt = std::vector<std::uint8_t>( result.salt = xlnt::detail::read_vector<byte>(info_stream, salt_size);
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + salt_size));
reader.offset(reader.offset() + salt_size);
static const auto verifier_size = std::size_t(16); static const auto verifier_size = std::size_t(16);
result.encrypted_verifier = std::vector<std::uint8_t>( result.encrypted_verifier = xlnt::detail::read_vector<byte>(info_stream, verifier_size);
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + verifier_size));
reader.offset(reader.offset() + verifier_size);
/*const auto verifier_hash_size = */reader.read<std::uint32_t>(); /*const auto verifier_hash_size = */read<std::uint32_t>(info_stream);
const auto encrypted_verifier_hash_size = std::size_t(32); const auto encrypted_verifier_hash_size = std::size_t(32);
result.encrypted_verifier_hash = std::vector<std::uint8_t>( result.encrypted_verifier_hash = xlnt::detail::read_vector<byte>(info_stream, encrypted_verifier_hash_size);
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + encrypted_verifier_hash_size));
reader.offset(reader.offset() + encrypted_verifier_hash_size);
if (reader.offset() != info_bytes.size())
{
throw xlnt::exception("extra data after encryption info");
}
return result; return result;
} }
encryption_info::agile_encryption_info read_agile_encryption_info(const std::vector<std::uint8_t> &info_bytes) encryption_info::agile_encryption_info read_agile_encryption_info(std::istream &info_stream)
{ {
using xlnt::detail::decode_base64; using xlnt::detail::decode_base64;
@ -199,8 +177,10 @@ encryption_info::agile_encryption_info read_agile_encryption_info(const std::vec
encryption_info::agile_encryption_info result; encryption_info::agile_encryption_info result;
auto header_size = std::size_t(8); auto xml_string = std::string(
xml::parser parser(info_bytes.data() + header_size, info_bytes.size() - header_size, "EncryptionInfo"); std::istreambuf_iterator<char>(info_stream),
std::istreambuf_iterator<char>());
xml::parser parser(xml_string.data(), xml_string.size(), "EncryptionInfo");
parser.next_expect(xml::parser::event_type::start_element, xmlns, "encryption"); parser.next_expect(xml::parser::event_type::start_element, xmlns, "encryption");
@ -269,17 +249,15 @@ encryption_info::agile_encryption_info read_agile_encryption_info(const std::vec
return result; return result;
} }
encryption_info read_encryption_info(const std::vector<std::uint8_t> &info_bytes, const std::u16string &password) encryption_info read_encryption_info(std::istream &info_stream, const std::u16string &password)
{ {
encryption_info info; encryption_info info;
info.password = password; info.password = password;
auto reader = binary_reader<byte>(info_bytes); auto version_major = read<std::uint16_t>(info_stream);
auto version_minor = read<std::uint16_t>(info_stream);
auto version_major = reader.read<std::uint16_t>(); auto encryption_flags = read<std::uint32_t>(info_stream);
auto version_minor = reader.read<std::uint16_t>();
auto encryption_flags = reader.read<std::uint32_t>();
info.is_agile = version_major == 4 && version_minor == 4; info.is_agile = version_major == 4 && version_minor == 4;
@ -290,7 +268,7 @@ encryption_info read_encryption_info(const std::vector<std::uint8_t> &info_bytes
throw xlnt::exception("bad header"); throw xlnt::exception("bad header");
} }
info.agile = read_agile_encryption_info(info_bytes); info.agile = read_agile_encryption_info(info_stream);
} }
else else
{ {
@ -315,7 +293,7 @@ encryption_info read_encryption_info(const std::vector<std::uint8_t> &info_bytes
throw xlnt::exception("not an OOXML document"); throw xlnt::exception("not an OOXML document");
} }
info.standard = read_standard_encryption_info(info_bytes); info.standard = read_standard_encryption_info(info_stream);
} }
return info; return info;
@ -330,15 +308,16 @@ std::vector<std::uint8_t> decrypt_xlsx(
throw xlnt::exception("empty file"); throw xlnt::exception("empty file");
} }
xlnt::detail::compound_document document(bytes); xlnt::detail::vector_istreambuf buffer(bytes);
std::istream stream(&buffer);
xlnt::detail::compound_document document(stream);
auto encryption_info = read_encryption_info( auto &encryption_info_stream = document.open_read_stream("EncryptionInfo");
document.read_stream("EncryptionInfo"), password); auto encryption_info = read_encryption_info(encryption_info_stream, password);
auto encrypted_package = document.read_stream("EncryptedPackage");
return encryption_info.is_agile return encryption_info.is_agile
? decrypt_xlsx_agile(encryption_info, encrypted_package) ? decrypt_xlsx_agile(encryption_info, document.open_read_stream("EncryptedPackage"))
: decrypt_xlsx_standard(encryption_info, encrypted_package); : decrypt_xlsx_standard(encryption_info, document.open_read_stream("EncryptedPackage"));
} }
} // namespace } // namespace

View File

@ -105,8 +105,9 @@ encryption_info generate_encryption_info(const std::u16string &/*password*/)
return result; return result;
} }
std::vector<std::uint8_t> write_agile_encryption_info( void write_agile_encryption_info(
const encryption_info &info) const encryption_info &info,
std::ostream &info_stream)
{ {
static const auto &xmlns = xlnt::constants::ns("encryption"); static const auto &xmlns = xlnt::constants::ns("encryption");
static const auto &xmlns_p = xlnt::constants::ns("encryption-password"); static const auto &xmlns_p = xlnt::constants::ns("encryption-password");
@ -166,10 +167,10 @@ std::vector<std::uint8_t> write_agile_encryption_info(
serializer.end_element(xmlns, "encryption"); serializer.end_element(xmlns, "encryption");
return encryption_info; info_stream.write(reinterpret_cast<char *>(encryption_info.data()), encryption_info.size());
} }
std::vector<std::uint8_t> write_standard_encryption_info(const encryption_info &info) void write_standard_encryption_info(const encryption_info &info, std::ostream &info_stream)
{ {
auto result = std::vector<std::uint8_t>(); auto result = std::vector<std::uint8_t>();
auto writer = xlnt::detail::binary_writer<std::uint8_t>(result); auto writer = xlnt::detail::binary_writer<std::uint8_t>(result);
@ -205,15 +206,15 @@ std::vector<std::uint8_t> write_standard_encryption_info(const encryption_info &
writer.write(std::uint32_t(20)); writer.write(std::uint32_t(20));
writer.append(info.standard.encrypted_verifier_hash); writer.append(info.standard.encrypted_verifier_hash);
return result; info_stream.write(reinterpret_cast<char *>(result.data()), result.size());
} }
std::vector<std::uint8_t> encrypt_xlsx_agile( void encrypt_xlsx_agile(
const encryption_info &info, const encryption_info &info,
const std::vector<std::uint8_t> &plaintext) std::ostream &plaintext)
{ {
auto key = info.calculate_key(); auto key = info.calculate_key();
/*
auto padded = plaintext; auto padded = plaintext;
padded.resize((plaintext.size() / 16 + (plaintext.size() % 16 == 0 ? 0 : 1)) * 16); padded.resize((plaintext.size() / 16 + (plaintext.size() % 16 == 0 ? 0 : 1)) * 16);
auto ciphertext = xlnt::detail::aes_ecb_encrypt(padded, key); auto ciphertext = xlnt::detail::aes_ecb_encrypt(padded, key);
@ -221,16 +222,15 @@ std::vector<std::uint8_t> encrypt_xlsx_agile(
ciphertext.insert(ciphertext.begin(), ciphertext.insert(ciphertext.begin(),
reinterpret_cast<const std::uint8_t *>(&length), reinterpret_cast<const std::uint8_t *>(&length),
reinterpret_cast<const std::uint8_t *>(&length + sizeof(std::uint64_t))); reinterpret_cast<const std::uint8_t *>(&length + sizeof(std::uint64_t)));
*/
return ciphertext;
} }
std::vector<std::uint8_t> encrypt_xlsx_standard( void encrypt_xlsx_standard(
const encryption_info &info, const encryption_info &info,
const std::vector<std::uint8_t> &plaintext) std::ostream &plaintext)
{ {
auto key = info.calculate_key(); auto key = info.calculate_key();
/*
auto padded = plaintext; auto padded = plaintext;
padded.resize((plaintext.size() / 16 + (plaintext.size() % 16 == 0 ? 0 : 1)) * 16); padded.resize((plaintext.size() / 16 + (plaintext.size() % 16 == 0 ? 0 : 1)) * 16);
auto ciphertext = xlnt::detail::aes_ecb_encrypt(padded, key); auto ciphertext = xlnt::detail::aes_ecb_encrypt(padded, key);
@ -238,8 +238,7 @@ std::vector<std::uint8_t> encrypt_xlsx_standard(
ciphertext.insert(ciphertext.begin(), ciphertext.insert(ciphertext.begin(),
reinterpret_cast<const std::uint8_t *>(&length), reinterpret_cast<const std::uint8_t *>(&length),
reinterpret_cast<const std::uint8_t *>(&length + sizeof(std::uint64_t))); reinterpret_cast<const std::uint8_t *>(&length + sizeof(std::uint64_t)));
*/
return ciphertext;
} }
std::vector<std::uint8_t> encrypt_xlsx( std::vector<std::uint8_t> encrypt_xlsx(
@ -250,14 +249,20 @@ std::vector<std::uint8_t> encrypt_xlsx(
encryption_info.password = u"secret"; encryption_info.password = u"secret";
auto ciphertext = std::vector<std::uint8_t>(); auto ciphertext = std::vector<std::uint8_t>();
xlnt::detail::compound_document document(ciphertext); xlnt::detail::vector_ostreambuf buffer(ciphertext);
std::ostream stream(&buffer);
xlnt::detail::compound_document document(stream);
document.write_stream("EncryptionInfo", encryption_info.is_agile if (encryption_info.is_agile)
? write_agile_encryption_info(encryption_info) {
: write_standard_encryption_info(encryption_info)); write_agile_encryption_info(encryption_info, document.open_write_stream("/EncryptionInfo"));
document.write_stream("EncryptedPackage", encryption_info.is_agile encrypt_xlsx_agile(encryption_info, document.open_write_stream("/EncryptedPackage"));
? encrypt_xlsx_agile(encryption_info, plaintext) }
: encrypt_xlsx_standard(encryption_info, plaintext)); else
{
write_standard_encryption_info(encryption_info, document.open_write_stream("/EncryptionInfo"));
encrypt_xlsx_standard(encryption_info, document.open_write_stream("/EncryptedPackage"));
}
return ciphertext; return ciphertext;
} }

View File

@ -197,7 +197,7 @@ public:
strm.avail_in = 0; strm.avail_in = 0;
strm.next_in = Z_NULL; strm.next_in = Z_NULL;
setg(in.data(), in.data(), in.data()); setg(in.data(), in.data(), in.data() + buffer_size);
setp(0, 0); setp(0, 0);
// skip the header // skip the header

View File

@ -73,19 +73,6 @@ void print_summary()
int main() int main()
{ {
std::ifstream file("C:/Users/Thomas/Development/xlnt/tests/data/6_encrypted_libre.xlsx", std::ios::binary);
const auto bytes2 = xlnt::detail::to_vector(file);
xlnt::detail::compound_document doc2(bytes2);
auto info = doc2.read_stream("/EncryptionInfo");
std::vector<std::uint8_t> bytes;
xlnt::detail::compound_document doc(bytes);
doc.write_stream("aaa", std::vector<std::uint8_t>(4095, 'a'));
doc.write_stream("bbb", std::vector<std::uint8_t>(4095, 'b'));
doc.write_stream("ccc", std::vector<std::uint8_t>(4095, 'c'));
std::ofstream file2("cd.xlsx", std::ios::binary);
xlnt::detail::to_stream(bytes, file2);
// cell // cell
run_tests<cell_test_suite>(); run_tests<cell_test_suite>();
run_tests<index_types_test_suite>(); run_tests<index_types_test_suite>();