diff --git a/source/detail/binary.hpp b/source/detail/binary.hpp index 5d85cb0a..26f06e49 100644 --- a/source/detail/binary.hpp +++ b/source/detail/binary.hpp @@ -293,5 +293,32 @@ std::vector string_to_bytes(const std::basic_string &string) return bytes; } +template +T read(std::istream &in) +{ + T result; + in.read(reinterpret_cast(&result), sizeof(T)); + + return result; +} + +template +std::vector read_vector(std::istream &in, std::size_t count) +{ + std::vector result(count, T()); + in.read(reinterpret_cast(&result[0]), sizeof(T) * count); + + return result; +} + +template +std::basic_string read_string(std::istream &in, std::size_t count) +{ + std::basic_string result(count, T()); + in.read(reinterpret_cast(&result[0]), sizeof(T) * count); + + return result; +} + } // namespace detail } // namespace xlnt diff --git a/source/detail/cryptography/compound_document.cpp b/source/detail/cryptography/compound_document.cpp index ce476d2b..eb94cbdf 100644 --- a/source/detail/cryptography/compound_document.cpp +++ b/source/detail/cryptography/compound_document.cpp @@ -76,77 +76,246 @@ const directory_id End = -1; namespace xlnt { namespace detail { -compound_document::compound_document(std::vector &data) - : writer_(new binary_writer(data)), - reader_(new binary_reader(data)) +/// +/// Allows a std::vector to be read through a std::istream. +/// +class compound_document_istreambuf : public std::streambuf { - header_.msat.fill(FreeSector); - writer_->offset(0); - writer_->write(header_); + using int_type = std::streambuf::int_type; - header_.directory_start = allocate_sector(); - writer_->offset(0); - writer_->write(header_); - - insert_entry("Root Entry", compound_document_entry::entry_type::RootStorage); -} - -compound_document::compound_document(const std::vector &data) - : writer_(nullptr), - reader_(new binary_reader(data)) -{ - header_ = reader_->read(); - - // read msat - auto current = header_.extra_msat_start; - - for (auto i = std::size_t(0); i < header_.num_msat_sectors; ++i) +public: + compound_document_istreambuf(const std::string &filename) + : data_(filename.begin(), filename.end()), + position_(0) { - if (i < 109) + } + + compound_document_istreambuf(const compound_document_istreambuf &) = delete; + compound_document_istreambuf &operator=(const compound_document_istreambuf &) = delete; + +private: + int_type underflow() + { + if (position_ == data_.size()) { - msat_.push_back(header_.msat[i]); + return traits_type::eof(); + } + + return traits_type::to_int_type(static_cast(data_[position_])); + } + + int_type uflow() + { + if (position_ == data_.size()) + { + return traits_type::eof(); + } + + return traits_type::to_int_type(static_cast(data_[position_++])); + } + + std::streamsize showmanyc() + { + if (position_ == data_.size()) + { + return static_cast(-1); + } + + return static_cast(data_.size() - position_); + } + + std::streampos seekoff(std::streamoff off, std::ios_base::seekdir way, std::ios_base::openmode) + { + if (way == std::ios_base::beg) + { + position_ = 0; + } + else if (way == std::ios_base::end) + { + position_ = data_.size(); + } + + if (off < 0) + { + if (static_cast(-off) > position_) + { + position_ = 0; + return static_cast(-1); + } + else + { + position_ -= static_cast(-off); + } + } + else if (off > 0) + { + if (static_cast(off) + position_ > data_.size()) + { + position_ = data_.size(); + return static_cast(-1); + } + else + { + position_ += static_cast(off); + } + } + + return static_cast(position_); + } + + std::streampos seekpos(std::streampos sp, std::ios_base::openmode) + { + if (sp < 0) + { + position_ = 0; + } + else if (static_cast(sp) > data_.size()) + { + position_ = data_.size(); } else { - auto extra_msat_sector = std::vector(); - auto extra_msat_sector_writer = binary_writer(extra_msat_sector); - - read_sector(current, extra_msat_sector_writer); - - std::copy(extra_msat_sector.begin(), - extra_msat_sector.end() - 1, - std::back_inserter(msat_)); - current = extra_msat_sector.back(); + position_ = static_cast(sp); } - } - - for (auto sat_sector_id : msat_) - { - if (sat_sector_id < 0) continue; - - auto sat_sector = std::vector(); - auto sat_sector_writer = binary_writer(sat_sector); - - read_sector(sat_sector_id, sat_sector_writer); - - std::copy(sat_sector.begin(), - sat_sector.end(), - std::back_inserter(sat_)); + + return static_cast(position_); } - for (auto ssat_sector_id : follow_chain(header_.ssat_start, sat_)) +private: + std::vector data_; + std::size_t position_; +}; + +/// +/// Allows a std::vector to be written through a std::ostream. +/// +class compound_document_ostreambuf : public std::streambuf +{ + using int_type = std::streambuf::int_type; + +public: + compound_document_ostreambuf(const std::string &filename) + : data_(filename.begin(), filename.end()), + position_(0) { - auto ssat_sector = std::vector(); - auto ssat_sector_writer = binary_writer(ssat_sector); - - read_sector(ssat_sector_id, ssat_sector_writer); - - std::copy(ssat_sector.begin(), - ssat_sector.end(), - std::back_inserter(ssat_)); } - tree_initialize_parent_maps(); + compound_document_ostreambuf(const compound_document_ostreambuf &) = delete; + compound_document_ostreambuf &operator=(const compound_document_ostreambuf &) = delete; + +private: + int_type overflow(int_type c = traits_type::eof()) + { + if (c != traits_type::eof()) + { + data_.push_back(static_cast(c)); + position_ = data_.size() - 1; + } + + return traits_type::to_int_type(static_cast(data_[position_])); + } + + std::streamsize xsputn(const char *s, std::streamsize n) + { + if (data_.empty()) + { + data_.resize(static_cast(n)); + } + else + { + auto position_size = data_.size(); + auto required_size = static_cast(position_ + static_cast(n)); + data_.resize(std::max(position_size, required_size)); + } + + std::copy(s, s + n, data_.begin() + static_cast(position_)); + position_ += static_cast(n); + + return n; + } + + std::streampos seekoff(std::streamoff off, std::ios_base::seekdir way, std::ios_base::openmode) + { + if (way == std::ios_base::beg) + { + position_ = 0; + } + else if (way == std::ios_base::end) + { + position_ = data_.size(); + } + + if (off < 0) + { + if (static_cast(-off) > position_) + { + position_ = 0; + return static_cast(-1); + } + else + { + position_ -= static_cast(-off); + } + } + else if (off > 0) + { + if (static_cast(off) + position_ > data_.size()) + { + position_ = data_.size(); + return static_cast(-1); + } + else + { + position_ += static_cast(off); + } + } + + return static_cast(position_); + } + + std::streampos seekpos(std::streampos sp, std::ios_base::openmode) + { + if (sp < 0) + { + position_ = 0; + } + else if (static_cast(sp) > data_.size()) + { + position_ = data_.size(); + } + else + { + position_ = static_cast(sp); + } + + return static_cast(position_); + } + +private: + std::vector data_; + std::size_t position_; +}; + + +compound_document::compound_document(std::ostream &out) + : out_(&out), + stream_in_(nullptr), + stream_out_(nullptr) +{ + write_header(); + insert_entry("Root Entry", compound_document_entry::entry_type::RootStorage); +} + +compound_document::compound_document(std::istream &in) + : in_(&in), + stream_in_(nullptr), + stream_out_(nullptr) +{ + read_header(); + read_msat(); + read_sat(); + read_ssat(); + read_directory(); } compound_document::~compound_document() @@ -163,84 +332,36 @@ std::size_t compound_document::short_sector_size() return static_cast(1) << header_.short_sector_size_power; } -std::vector compound_document::read_stream(const std::string &name) +std::istream &compound_document::open_read_stream(const std::string &name) { const auto entry_id = find_entry(name, compound_document_entry::entry_type::UserStream); const auto &entry = entries_.at(entry_id); - auto stream_data = std::vector(); - auto stream_data_writer = binary_writer(stream_data); + stream_in_buffer_.reset(new compound_document_istreambuf(name)); + stream_in_.rdbuf(stream_in_buffer_.get()); - if (entry.size < header_.threshold) - { - for (auto sector : follow_chain(entry.start, ssat_)) - { - read_short_sector(sector, stream_data_writer); - } - } - else - { - for (auto sector : follow_chain(entry.start, sat_)) - { - read_sector(sector, stream_data_writer); - } - } - stream_data.resize(entry.size); - - return stream_data; + return stream_in_; } -void compound_document::write_stream(const std::string &name, const std::vector &data) +std::ostream &compound_document::open_write_stream(const std::string &name) { auto entry_id = contains_entry(name, compound_document_entry::entry_type::UserStream) ? find_entry(name, compound_document_entry::entry_type::UserStream) : insert_entry(name, compound_document_entry::entry_type::UserStream); auto &entry = entries_.at(entry_id); - entry.size = static_cast(data.size()); - auto stream_data_reader = binary_reader(data); + stream_out_buffer_.reset(new compound_document_ostreambuf(name)); + stream_out_.rdbuf(stream_out_buffer_.get()); - if (entry.size < header_.threshold) - { - const auto num_sectors = data.size() / short_sector_size() - + (data.size() % short_sector_size() ? 1 : 0); - - auto chain = allocate_short_sectors(num_sectors); - entry.start = chain.front(); - - for (auto sector : follow_chain(entry.start, ssat_)) - { - write_short_sector(stream_data_reader, sector); - } - } - else - { - const auto num_sectors = data.size() / short_sector_size() - + (data.size() % short_sector_size() ? 1 : 0); - - auto chain = allocate_sectors(num_sectors); - entry.start = chain.front(); - - for (auto sector : follow_chain(entry.start, sat_)) - { - write_sector(stream_data_reader, sector); - } - } - - auto directory_chain = follow_chain(header_.directory_start, sat_); - const auto entries_per_sector = sector_size() / sizeof(compound_document_entry); - auto entry_directory_sector = directory_chain[entry_id / entries_per_sector]; - writer_->offset(sector_data_start() - + entry_directory_sector * sector_size() - + (entry_id % entries_per_sector) * sizeof(compound_document_entry)); - writer_->write(entry); + return stream_out_; } template void compound_document::write_sector(binary_reader &reader, sector_id id) { - writer_->offset(sector_data_start() + sector_size() * id); - writer_->append(reader, std::min(sector_size(), reader.bytes()) / sizeof(T)); + out_->seekp(sector_data_start() + sector_size() * id); + out_->write(reinterpret_cast(reader.data() + reader.offset()), + std::min(sector_size(), reader.bytes() - reader.offset())); } template @@ -249,15 +370,18 @@ void compound_document::write_short_sector(binary_reader &reader, sector_id i auto chain = follow_chain(entries_[0].start, sat_); auto sector_id = chain[id / (sector_size() / short_sector_size())]; auto sector_offset = id % (sector_size() / short_sector_size()) * short_sector_size(); - writer_->offset(sector_data_start() + sector_size() * sector_id + sector_offset); - writer_->append(reader, std::min(short_sector_size() / sizeof(T), reader.count() - reader.offset())); + out_->seekp(sector_data_start() + sector_size() * sector_id + sector_offset); + out_->write(reinterpret_cast(reader.data() + reader.offset()), + std::min(short_sector_size(), reader.bytes() - reader.offset())); } template void compound_document::read_sector(sector_id id, binary_writer &writer) { - reader_->offset(sector_data_start() + sector_size() * id); - writer.append(*reader_, sector_size()); + in_->seekg(sector_data_start() + sector_size() * id); + std::vector sector(sector_size(), 0); + in_->read(reinterpret_cast(sector.data()), sector_size()); + writer.append(sector); } template @@ -300,12 +424,7 @@ sector_id compound_document::allocate_sector() auto next_free = sector_id(next_free_iter - sat_.begin()); sat_[next_free] = EndOfChain; - auto next_free_msat_index = next_free / sectors_per_sector;; - auto sat_index = msat_[next_free_msat_index]; - writer_->offset(sector_data_start() - + (sat_index * sector_size()) - + (next_free % sectors_per_sector) * sizeof(sector_id)); - writer_->write(EndOfChain); + write_sat(); auto empty_sector = std::vector(sector_size()); auto empty_sector_reader = binary_reader(empty_sector); @@ -387,8 +506,7 @@ sector_id compound_document::allocate_short_sector() sat_[ssat_chain.back()] = new_ssat_sector_id; } - writer_->offset(0); - writer_->write(header_); + write_header(); auto old_size = ssat_.size(); ssat_.resize(old_size + sectors_per_sector, FreeSector); @@ -402,14 +520,8 @@ sector_id compound_document::allocate_short_sector() auto next_free = sector_id(next_free_iter - ssat_.begin()); ssat_[next_free] = EndOfChain; - - auto sat_chain = follow_chain(header_.ssat_start, sat_); - auto next_free_sat_chain_index = next_free / sectors_per_sector; - auto sat_index = sat_chain[next_free_sat_chain_index]; - writer_->offset(sector_data_start() - + (sat_index * sectors_per_sector) * sizeof(sector_id) - + (next_free % sectors_per_sector) * sizeof(sector_id)); - writer_->write(EndOfChain); + + write_ssat(); const auto short_sectors_per_sector = sector_size() / short_sector_size(); const auto required_container_sectors = std::size_t(next_free / short_sectors_per_sector + 1); @@ -419,8 +531,7 @@ sector_id compound_document::allocate_short_sector() if (entries_[0].start < 0) { entries_[0].start = allocate_sector(); - writer_->offset(sector_data_start() + header_.directory_start * sector_size()); - writer_->write(entries_[0]); + write_entry(0); } auto container_chain = follow_chain(entries_[0].start, sat_); @@ -428,6 +539,7 @@ sector_id compound_document::allocate_short_sector() if (required_container_sectors > container_chain.size()) { sat_[container_chain.back()] = allocate_sector(); + write_sat(); } } @@ -454,17 +566,13 @@ directory_id compound_document::next_empty_entry() / sizeof(compound_document_entry); auto new_sector = allocate_sector(); // TODO: connect chains here - writer_->offset(sector_data_start() + new_sector * sector_size()); - reader_->offset(sector_data_start() + new_sector * sector_size()); for (auto i = std::size_t(0); i < entries_per_sector; ++i) { auto empty_entry = compound_document_entry(); empty_entry.type = compound_document_entry::entry_type::Empty; - - writer_->write(empty_entry); - - entries_.push_back(reader_->read()); + entries_.push_back(empty_entry); + write_entry(entry_id + directory_id(i)); } return entry_id; @@ -480,14 +588,7 @@ directory_id compound_document::insert_entry( entry.name(name); entry.type = type; - //TODO: move this to a "write_entry" function - auto directory_chain = follow_chain(header_.directory_start, sat_); - const auto entries_per_sector = sector_size() / sizeof(compound_document_entry); - auto entry_directory_sector = directory_chain[entry_id / entries_per_sector]; - writer_->offset(sector_data_start() - + entry_directory_sector * sector_size() - + (entry_id % entries_per_sector) * sizeof(compound_document_entry)); - writer_->write(entry); + write_entry(entry_id); // TODO: parse path from name and use correct parent storage instead of 0 tree_insert(entry_id, 0); @@ -542,18 +643,17 @@ void compound_document::print_directory() } } -void compound_document::tree_initialize_parent_maps() +void compound_document::read_directory() { const auto entries_per_sector = sector_size() / sizeof(compound_document_entry); auto entry_id = directory_id(0); for (auto sector : follow_chain(header_.directory_start, sat_)) { - reader_->offset(sector_data_start() + sector * sector_size()); - for (auto i = std::size_t(0); i < entries_per_sector; ++i) { - entries_.push_back(reader_->read()); + entries_.push_back(compound_document_entry()); + read_entry(entry_id++); } } @@ -843,8 +943,8 @@ compound_document_entry::entry_color &compound_document::tree_color(directory_id void compound_document::read_header() { - reader_->offset(0); - header_ = reader_->read(); + in_->seekg(0, std::ios::beg); + in_->read(reinterpret_cast(&header_), sizeof(compound_document_header)); } void compound_document::read_msat() @@ -905,14 +1005,14 @@ void compound_document::read_entry(directory_id id) const auto offset = sector_size() * directory_sector + ((id % entries_per_sector) * sizeof(compound_document_entry)); - reader_->offset(offset); - entries_[id] = reader_->read(); + in_->seekg(sector_data_start() + offset, std::ios::beg); + in_->read(reinterpret_cast(&entries_[id]), sizeof(compound_document_entry)); } void compound_document::write_header() { - writer_->offset(0); - writer_->write(header_); + out_->seekp(0, std::ios::beg); + out_->write(reinterpret_cast(&header_), sizeof(compound_document_header)); } void compound_document::write_msat() @@ -968,8 +1068,8 @@ void compound_document::write_entry(directory_id id) const auto offset = sector_size() * directory_sector + ((id % entries_per_sector) * sizeof(compound_document_entry)); - writer_->offset(offset); - writer_->write(entries_[id]); + out_->seekp(offset, std::ios::beg); + out_->write(reinterpret_cast(&entries_[id]), sizeof(compound_document_entry)); } } // namespace detail diff --git a/source/detail/cryptography/compound_document.hpp b/source/detail/cryptography/compound_document.hpp index 69748604..9f003b58 100644 --- a/source/detail/cryptography/compound_document.hpp +++ b/source/detail/cryptography/compound_document.hpp @@ -113,15 +113,20 @@ struct compound_document_entry std::uint32_t ignore2; }; +class compound_document_istreambuf; +class compound_document_ostreambuf; + class compound_document { public: - compound_document(std::vector &data); - compound_document(const std::vector &data); + compound_document(std::istream &in); + compound_document(std::ostream &out); ~compound_document(); - std::vector read_stream(const std::string &filename); - void write_stream(const std::string &filename, const std::vector &data); + void close(); + + std::istream &open_read_stream(const std::string &filename); + std::ostream &open_write_stream(const std::string &filename); private: template @@ -139,6 +144,7 @@ private: void read_sat(); void read_ssat(); void read_entry(directory_id id); + void read_directory(); void write_header(); void write_msat(); @@ -172,7 +178,6 @@ private: compound_document_entry::entry_type type); // Red black tree helper functions - void tree_initialize_parent_maps(); void tree_insert(directory_id new_id, directory_id storage_id); void tree_insert_fixup(directory_id x); std::string tree_path(directory_id id); @@ -186,9 +191,6 @@ private: std::string tree_key(directory_id id); compound_document_entry::entry_color &tree_color(directory_id id); - std::unique_ptr> writer_; - std::unique_ptr> reader_; - compound_document_header header_; sector_chain msat_; sector_chain sat_; @@ -197,6 +199,14 @@ private: std::unordered_map parent_storage_; std::unordered_map parent_; + + std::istream *in_; + std::ostream *out_; + + std::unique_ptr stream_in_buffer_; + std::istream stream_in_; + std::unique_ptr stream_out_buffer_; + std::ostream stream_out_; }; } // namespace detail diff --git a/source/detail/cryptography/xlsx_crypto_consumer.cpp b/source/detail/cryptography/xlsx_crypto_consumer.cpp index 7f8f4ea4..71eee198 100644 --- a/source/detail/cryptography/xlsx_crypto_consumer.cpp +++ b/source/detail/cryptography/xlsx_crypto_consumer.cpp @@ -42,18 +42,25 @@ namespace { using xlnt::detail::byte; -using xlnt::detail::binary_reader; +using xlnt::detail::read; using xlnt::detail::encryption_info; std::vector decrypt_xlsx_standard( encryption_info info, - const std::vector &encrypted_package) + std::istream &encrypted_package_stream) { const auto key = info.calculate_key(); - auto reader = binary_reader(encrypted_package); - auto decrypted_size = reader.read(); - auto decrypted = xlnt::detail::aes_ecb_decrypt(encrypted_package, key, reader.offset()); + auto encrypted_package = std::vector( + std::istreambuf_iterator(encrypted_package_stream), + std::istreambuf_iterator()); + auto decrypted_size = read(encrypted_package_stream); + + auto decrypted = xlnt::detail::aes_ecb_decrypt( + encrypted_package, + key, + encrypted_package_stream.tellg()); + decrypted.resize(static_cast(decrypted_size)); return decrypted; @@ -61,10 +68,8 @@ std::vector decrypt_xlsx_standard( std::vector decrypt_xlsx_agile( const encryption_info &info, - const std::vector &encrypted_package) + std::istream &encrypted_package_stream) { - static const auto segment_length = std::size_t(4096); - const auto key = info.calculate_key(); auto salt_size = info.agile.key_data.salt_size; @@ -72,23 +77,20 @@ std::vector decrypt_xlsx_agile( salt_with_block_key.resize(salt_size + sizeof(std::uint32_t), 0); auto &segment = *reinterpret_cast(salt_with_block_key.data() + salt_size); - auto total_size = static_cast(*reinterpret_cast(encrypted_package.data())); + auto total_size = read(encrypted_package_stream); - std::vector encrypted_segment(segment_length, 0); + std::vector encrypted_segment(4096, 0); std::vector decrypted_package; - decrypted_package.reserve(encrypted_package.size() - 8); - for (std::size_t i = 8; i < encrypted_package.size(); i += segment_length) + while (encrypted_package_stream) { auto iv = hash(info.agile.key_encryptor.hash, salt_with_block_key); iv.resize(16); - auto segment_begin = encrypted_package.begin() + static_cast(i); - auto current_segment_length = std::min(segment_length, encrypted_package.size() - i); - auto segment_end = encrypted_package.begin() + static_cast(i + current_segment_length); - encrypted_segment.assign(segment_begin, segment_end); + encrypted_package_stream.read( + reinterpret_cast(encrypted_segment.data()), + encrypted_segment.size()); auto decrypted_segment = xlnt::detail::aes_cbc_decrypt(encrypted_segment, key, iv); - decrypted_segment.resize(current_segment_length); decrypted_package.insert( decrypted_package.end(), @@ -103,21 +105,15 @@ std::vector decrypt_xlsx_agile( return decrypted_package; } -encryption_info::standard_encryption_info read_standard_encryption_info(const std::vector &info_bytes) +encryption_info::standard_encryption_info read_standard_encryption_info(std::istream &info_stream) { encryption_info::standard_encryption_info result; - auto reader = binary_reader(info_bytes); - - // skip version info - reader.read(); - reader.read(); - - auto header_length = reader.read(); - auto index_at_start = reader.offset(); - /*auto skip_flags = */ reader.read(); - /*auto size_extra = */ reader.read(); - auto alg_id = reader.read(); + auto header_length = read(info_stream); + auto index_at_start = info_stream.tellg(); + /*auto skip_flags = */ read(info_stream); + /*auto size_extra = */ read(info_stream); + auto alg_id = read(info_stream); if (alg_id == 0 || alg_id == 0x0000660E || alg_id == 0x0000660F || alg_id == 0x00006610) { @@ -128,68 +124,50 @@ encryption_info::standard_encryption_info read_standard_encryption_info(const st throw xlnt::exception("invalid cipher algorithm"); } - auto alg_id_hash = reader.read(); + auto alg_id_hash = read(info_stream); if (alg_id_hash != 0x00008004 && alg_id_hash == 0) { throw xlnt::exception("invalid hash algorithm"); } - result.key_bits = reader.read(); + result.key_bits = read(info_stream); result.key_bytes = result.key_bits / 8; - auto provider_type = reader.read(); + auto provider_type = read(info_stream); if (provider_type != 0 && provider_type != 0x00000018) { throw xlnt::exception("invalid provider type"); } - reader.read(); // reserved 1 - if (reader.read() != 0) // reserved 2 + read(info_stream); // reserved 1 + if (read(info_stream) != 0) // reserved 2 { throw xlnt::exception("invalid header"); } - const auto csp_name_length = header_length - (reader.offset() - index_at_start); - std::vector csp_name_wide( - reinterpret_cast(&*(info_bytes.begin() + static_cast(reader.offset()))), - reinterpret_cast( - &*(info_bytes.begin() + static_cast(reader.offset() + csp_name_length)))); - std::string csp_name(csp_name_wide.begin(), csp_name_wide.end() - 1); // without trailing null - if (csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider (Prototype)" - && csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider") + const auto csp_name_length = header_length - (info_stream.tellg() - index_at_start); + auto csp_name = xlnt::detail::read_string(info_stream, csp_name_length); + if (csp_name != u"Microsoft Enhanced RSA and AES Cryptographic Provider (Prototype)" + && csp_name != u"Microsoft Enhanced RSA and AES Cryptographic Provider") { throw xlnt::exception("invalid cryptographic provider"); } - reader.offset(reader.offset() + csp_name_length); + info_stream.seekg(csp_name_length); - const auto salt_size = reader.read(); - result.salt = std::vector( - info_bytes.begin() + static_cast(reader.offset()), - info_bytes.begin() + static_cast(reader.offset() + salt_size)); - reader.offset(reader.offset() + salt_size); + const auto salt_size = read(info_stream); + result.salt = xlnt::detail::read_vector(info_stream, salt_size); static const auto verifier_size = std::size_t(16); - result.encrypted_verifier = std::vector( - info_bytes.begin() + static_cast(reader.offset()), - info_bytes.begin() + static_cast(reader.offset() + verifier_size)); - reader.offset(reader.offset() + verifier_size); + result.encrypted_verifier = xlnt::detail::read_vector(info_stream, verifier_size); - /*const auto verifier_hash_size = */reader.read(); + /*const auto verifier_hash_size = */read(info_stream); const auto encrypted_verifier_hash_size = std::size_t(32); - result.encrypted_verifier_hash = std::vector( - info_bytes.begin() + static_cast(reader.offset()), - info_bytes.begin() + static_cast(reader.offset() + encrypted_verifier_hash_size)); - reader.offset(reader.offset() + encrypted_verifier_hash_size); - - if (reader.offset() != info_bytes.size()) - { - throw xlnt::exception("extra data after encryption info"); - } + result.encrypted_verifier_hash = xlnt::detail::read_vector(info_stream, encrypted_verifier_hash_size); return result; } -encryption_info::agile_encryption_info read_agile_encryption_info(const std::vector &info_bytes) +encryption_info::agile_encryption_info read_agile_encryption_info(std::istream &info_stream) { using xlnt::detail::decode_base64; @@ -199,8 +177,10 @@ encryption_info::agile_encryption_info read_agile_encryption_info(const std::vec encryption_info::agile_encryption_info result; - auto header_size = std::size_t(8); - xml::parser parser(info_bytes.data() + header_size, info_bytes.size() - header_size, "EncryptionInfo"); + auto xml_string = std::string( + std::istreambuf_iterator(info_stream), + std::istreambuf_iterator()); + xml::parser parser(xml_string.data(), xml_string.size(), "EncryptionInfo"); parser.next_expect(xml::parser::event_type::start_element, xmlns, "encryption"); @@ -269,17 +249,15 @@ encryption_info::agile_encryption_info read_agile_encryption_info(const std::vec return result; } -encryption_info read_encryption_info(const std::vector &info_bytes, const std::u16string &password) +encryption_info read_encryption_info(std::istream &info_stream, const std::u16string &password) { encryption_info info; info.password = password; - auto reader = binary_reader(info_bytes); - - auto version_major = reader.read(); - auto version_minor = reader.read(); - auto encryption_flags = reader.read(); + auto version_major = read(info_stream); + auto version_minor = read(info_stream); + auto encryption_flags = read(info_stream); info.is_agile = version_major == 4 && version_minor == 4; @@ -290,7 +268,7 @@ encryption_info read_encryption_info(const std::vector &info_bytes throw xlnt::exception("bad header"); } - info.agile = read_agile_encryption_info(info_bytes); + info.agile = read_agile_encryption_info(info_stream); } else { @@ -315,7 +293,7 @@ encryption_info read_encryption_info(const std::vector &info_bytes throw xlnt::exception("not an OOXML document"); } - info.standard = read_standard_encryption_info(info_bytes); + info.standard = read_standard_encryption_info(info_stream); } return info; @@ -330,15 +308,16 @@ std::vector decrypt_xlsx( throw xlnt::exception("empty file"); } - xlnt::detail::compound_document document(bytes); + xlnt::detail::vector_istreambuf buffer(bytes); + std::istream stream(&buffer); + xlnt::detail::compound_document document(stream); - auto encryption_info = read_encryption_info( - document.read_stream("EncryptionInfo"), password); - auto encrypted_package = document.read_stream("EncryptedPackage"); + auto &encryption_info_stream = document.open_read_stream("EncryptionInfo"); + auto encryption_info = read_encryption_info(encryption_info_stream, password); return encryption_info.is_agile - ? decrypt_xlsx_agile(encryption_info, encrypted_package) - : decrypt_xlsx_standard(encryption_info, encrypted_package); + ? decrypt_xlsx_agile(encryption_info, document.open_read_stream("EncryptedPackage")) + : decrypt_xlsx_standard(encryption_info, document.open_read_stream("EncryptedPackage")); } } // namespace diff --git a/source/detail/cryptography/xlsx_crypto_producer.cpp b/source/detail/cryptography/xlsx_crypto_producer.cpp index bebadc4e..9cb25ff2 100644 --- a/source/detail/cryptography/xlsx_crypto_producer.cpp +++ b/source/detail/cryptography/xlsx_crypto_producer.cpp @@ -105,8 +105,9 @@ encryption_info generate_encryption_info(const std::u16string &/*password*/) return result; } -std::vector write_agile_encryption_info( - const encryption_info &info) +void write_agile_encryption_info( + const encryption_info &info, + std::ostream &info_stream) { static const auto &xmlns = xlnt::constants::ns("encryption"); static const auto &xmlns_p = xlnt::constants::ns("encryption-password"); @@ -166,10 +167,10 @@ std::vector write_agile_encryption_info( serializer.end_element(xmlns, "encryption"); - return encryption_info; + info_stream.write(reinterpret_cast(encryption_info.data()), encryption_info.size()); } -std::vector write_standard_encryption_info(const encryption_info &info) +void write_standard_encryption_info(const encryption_info &info, std::ostream &info_stream) { auto result = std::vector(); auto writer = xlnt::detail::binary_writer(result); @@ -205,15 +206,15 @@ std::vector write_standard_encryption_info(const encryption_info & writer.write(std::uint32_t(20)); writer.append(info.standard.encrypted_verifier_hash); - return result; + info_stream.write(reinterpret_cast(result.data()), result.size()); } -std::vector encrypt_xlsx_agile( +void encrypt_xlsx_agile( const encryption_info &info, - const std::vector &plaintext) + std::ostream &plaintext) { auto key = info.calculate_key(); - + /* auto padded = plaintext; padded.resize((plaintext.size() / 16 + (plaintext.size() % 16 == 0 ? 0 : 1)) * 16); auto ciphertext = xlnt::detail::aes_ecb_encrypt(padded, key); @@ -221,16 +222,15 @@ std::vector encrypt_xlsx_agile( ciphertext.insert(ciphertext.begin(), reinterpret_cast(&length), reinterpret_cast(&length + sizeof(std::uint64_t))); - - return ciphertext; + */ } -std::vector encrypt_xlsx_standard( +void encrypt_xlsx_standard( const encryption_info &info, - const std::vector &plaintext) + std::ostream &plaintext) { auto key = info.calculate_key(); - + /* auto padded = plaintext; padded.resize((plaintext.size() / 16 + (plaintext.size() % 16 == 0 ? 0 : 1)) * 16); auto ciphertext = xlnt::detail::aes_ecb_encrypt(padded, key); @@ -238,8 +238,7 @@ std::vector encrypt_xlsx_standard( ciphertext.insert(ciphertext.begin(), reinterpret_cast(&length), reinterpret_cast(&length + sizeof(std::uint64_t))); - - return ciphertext; + */ } std::vector encrypt_xlsx( @@ -250,14 +249,20 @@ std::vector encrypt_xlsx( encryption_info.password = u"secret"; auto ciphertext = std::vector(); - xlnt::detail::compound_document document(ciphertext); + xlnt::detail::vector_ostreambuf buffer(ciphertext); + std::ostream stream(&buffer); + xlnt::detail::compound_document document(stream); - document.write_stream("EncryptionInfo", encryption_info.is_agile - ? write_agile_encryption_info(encryption_info) - : write_standard_encryption_info(encryption_info)); - document.write_stream("EncryptedPackage", encryption_info.is_agile - ? encrypt_xlsx_agile(encryption_info, plaintext) - : encrypt_xlsx_standard(encryption_info, plaintext)); + if (encryption_info.is_agile) + { + write_agile_encryption_info(encryption_info, document.open_write_stream("/EncryptionInfo")); + encrypt_xlsx_agile(encryption_info, document.open_write_stream("/EncryptedPackage")); + } + else + { + write_standard_encryption_info(encryption_info, document.open_write_stream("/EncryptionInfo")); + encrypt_xlsx_standard(encryption_info, document.open_write_stream("/EncryptedPackage")); + } return ciphertext; } diff --git a/source/detail/serialization/zstream.cpp b/source/detail/serialization/zstream.cpp index 95a8c30e..6fee8ff1 100644 --- a/source/detail/serialization/zstream.cpp +++ b/source/detail/serialization/zstream.cpp @@ -197,7 +197,7 @@ public: strm.avail_in = 0; strm.next_in = Z_NULL; - setg(in.data(), in.data(), in.data()); + setg(in.data(), in.data(), in.data() + buffer_size); setp(0, 0); // skip the header diff --git a/tests/runner.cpp b/tests/runner.cpp index bc7a14c6..b01ea246 100644 --- a/tests/runner.cpp +++ b/tests/runner.cpp @@ -73,19 +73,6 @@ void print_summary() int main() { - std::ifstream file("C:/Users/Thomas/Development/xlnt/tests/data/6_encrypted_libre.xlsx", std::ios::binary); - const auto bytes2 = xlnt::detail::to_vector(file); - xlnt::detail::compound_document doc2(bytes2); - auto info = doc2.read_stream("/EncryptionInfo"); - - std::vector bytes; - xlnt::detail::compound_document doc(bytes); - doc.write_stream("aaa", std::vector(4095, 'a')); - doc.write_stream("bbb", std::vector(4095, 'b')); - doc.write_stream("ccc", std::vector(4095, 'c')); - std::ofstream file2("cd.xlsx", std::ios::binary); - xlnt::detail::to_stream(bytes, file2); - // cell run_tests(); run_tests();