From ee642fc6c12851d729bf6695abf3d6ad36c470ff Mon Sep 17 00:00:00 2001 From: Thomas Fussell Date: Sun, 30 Apr 2017 20:21:47 -0400 Subject: [PATCH] getting there --- source/detail/binary.hpp | 1 + .../detail/cryptography/compound_document.cpp | 213 ++++++++---------- .../cryptography/xlsx_crypto_consumer.cpp | 18 +- .../detail/serialization/vector_streambuf.cpp | 11 + tests/runner.cpp | 21 -- 5 files changed, 115 insertions(+), 149 deletions(-) diff --git a/source/detail/binary.hpp b/source/detail/binary.hpp index 26f06e49..a4180f20 100644 --- a/source/detail/binary.hpp +++ b/source/detail/binary.hpp @@ -25,6 +25,7 @@ #include #include +#include #include #include diff --git a/source/detail/cryptography/compound_document.cpp b/source/detail/cryptography/compound_document.cpp index ea733ec9..5032a436 100644 --- a/source/detail/cryptography/compound_document.cpp +++ b/source/detail/cryptography/compound_document.cpp @@ -105,13 +105,15 @@ public: compound_document_istreambuf(const compound_document_entry &entry, compound_document &document) : entry_(entry), document_(document), - position_(0), - sector_writer_(current_sector_) + sector_writer_(current_sector_), + position_(0) { } compound_document_istreambuf(const compound_document_istreambuf &) = delete; compound_document_istreambuf &operator=(const compound_document_istreambuf &) = delete; + + virtual ~compound_document_istreambuf(); private: int_type underflow() @@ -220,6 +222,10 @@ private: std::size_t position_; }; +compound_document_istreambuf::~compound_document_istreambuf() +{ +} + /// /// Allows a std::vector to be written through a std::ostream. /// @@ -231,164 +237,119 @@ public: compound_document_ostreambuf(compound_document_entry &entry, compound_document &document) : entry_(entry), document_(document), - position_(0), sector_reader_(current_sector_), - current_sector_(sector_size(), 0), - chain_(document_.follow_chain(entry_.start, table())) + current_sector_(document.header_.threshold), + position_(0) { + setp(reinterpret_cast(current_sector_.data()), + reinterpret_cast(current_sector_.data() + current_sector_.size())); } compound_document_ostreambuf(const compound_document_ostreambuf &) = delete; compound_document_ostreambuf &operator=(const compound_document_ostreambuf &) = delete; - virtual ~compound_document_ostreambuf() - { - if (position_ % 64 != 0) - { - write_sector(); - } - } + virtual ~compound_document_ostreambuf(); private: + int sync() + { + auto written = pptr() - pbase(); + + if (written == 0) + { + return 0; + } + + sector_reader_.reset(); + + if (short_stream()) + { + if (position_ + written >= document_.header_.threshold) + { + convert_to_long_stream(); + } + else + { + if (entry_.start < 0) + { + auto num_sectors = (position_ + written + document_.short_sector_size() - 1) / document_.short_sector_size(); + chain_ = document_.allocate_short_sectors(num_sectors); + entry_.start = chain_.front(); + } + + for (auto link : chain_) + { + document_.write_short_sector(sector_reader_, link); + sector_reader_.offset(sector_reader_.offset() + document_.short_sector_size()); + } + } + } + else + { + const auto sector_index = position_ / document_.sector_size(); + document_.write_sector(sector_reader_, chain_[sector_index]); + } + + position_ += written; + entry_.size = std::max(entry_.size, static_cast(position_)); + + std::fill(current_sector_.begin(), current_sector_.end(), byte(0)); + setp(reinterpret_cast(current_sector_.data()), + reinterpret_cast(current_sector_.data() + current_sector_.size())); + + return 0; + } + bool short_stream() { return entry_.size < document_.header_.threshold; } - sector_chain &table() + int_type overflow(int_type c = traits_type::eof()) { - return short_stream() - ? document_.ssat_ - : document_.sat_; - } + sync(); - std::size_t sector_size() - { - return short_stream() - ? document_.short_sector_size() - : document_.short_sector_size(); - } - - void write_sector() - { if (short_stream()) { auto next_sector = document_.allocate_short_sector(); document_.ssat_[chain_.back()] = next_sector; chain_.push_back(next_sector); - document_.write_short_sector(sector_reader_, next_sector); } else { auto next_sector = document_.allocate_sector(); document_.sat_[chain_.back()] = next_sector; chain_.push_back(next_sector); - document_.write_sector(sector_reader_, next_sector); } - } - - int_type overflow(int_type c = traits_type::eof()) - { + auto value = static_cast(c); if (c != traits_type::eof()) { - current_sector_[position_ % sector_size()] = value; + current_sector_[position_ % current_sector_.size()] = value; } + + pbump(1); - if (entry_.start < 0) - { - entry_.start = entry_.size == document_.header_.threshold - ? document_.allocate_sector() - : document_.allocate_short_sector(); - chain_.push_back(entry_.start); - } - - if (position_ % 64 == 0 && position_ > 0) - { - write_sector(); - std::fill(current_sector_.begin(), current_sector_.end(), byte(0)); - } - - if (c != traits_type::eof()) - { - ++position_; - - auto previous_size = entry_.size; - entry_.size = std::max(entry_.size, static_cast(position_)); - - if (entry_.size >= document_.header_.threshold && previous_size < document_.header_.threshold) - { - convert_to_long_stream(); - } - - return traits_type::to_int_type(static_cast(value)); - } - else - { - return traits_type::eof(); - } + return traits_type::to_int_type(static_cast(value)); } void convert_to_long_stream() { - const auto sectors_per_sector = document_.sector_size() / document_.short_sector_size(); + sector_reader_.reset(); - current_sector_.resize(sector_size(), 0); + auto num_sectors = current_sector_.size() / document_.sector_size(); + auto new_chain = document_.allocate_sectors(num_sectors); + + for (auto link : new_chain) + { + document_.write_sector(sector_reader_, link); + sector_reader_.offset(sector_reader_.offset() + document_.short_sector_size()); + } + + current_sector_.resize(document_.sector_size(), 0); std::fill(current_sector_.begin(), current_sector_.end(), byte(0)); - auto sector_writer = binary_writer(current_sector_); - auto index = std::size_t(0); - auto long_chain = sector_chain(); - entry_.start = document_.allocate_sector(); - long_chain.push_back(entry_.start); - - for (auto link : chain_) - { - document_.read_short_sector(link, sector_writer); - document_.header_.num_short_sectors--; - document_.ssat_[link] = FreeSector; - - if (index % sectors_per_sector == 0 && index > 0) - { - document_.write_sector(sector_reader_, long_chain.back()); - auto next_sector = document_.allocate_sector(); - document_.sat_[long_chain.back()] = next_sector; - long_chain.push_back(next_sector); - } - } - - if (index % sectors_per_sector != 0) - { - document_.write_sector(sector_reader_, long_chain.back()); - } - - index = 0; - auto previous = sector_id(0); - - for (auto link : document_.follow_chain(document_.entries_[0].start, document_.sat_)) - { - auto ssat_index_start = document_.ssat_.begin() + index * sectors_per_sector; - auto ssat_index_end = document_.ssat_.begin() + (index + 1) * sectors_per_sector; - - if (std::size_t(std::count(ssat_index_start, ssat_index_end, FreeSector)) == sectors_per_sector) - { - if (index > 0) - { - document_.sat_[previous] = document_.sat_[link]; - } - else - { - document_.entries_[0].start = document_.sat_[link]; - } - - document_.sat_[link] = FreeSector; - } - - previous = link; - index++; - } - if (document_.header_.num_short_sectors == 0) { document_.entries_[0].start = EndOfChain; @@ -396,7 +357,8 @@ private: // TODO: deallocate short sectors here - chain_ = long_chain; + chain_ = new_chain; + entry_.start = chain_.front(); } std::streampos seekoff(std::streamoff off, std::ios_base::seekdir way, std::ios_base::openmode) @@ -465,6 +427,10 @@ private: sector_chain chain_; }; +compound_document_ostreambuf::~compound_document_ostreambuf() +{ + sync(); +} compound_document::compound_document(std::ostream &out) : out_(&out), @@ -509,6 +475,11 @@ std::size_t compound_document::short_sector_size() std::istream &compound_document::open_read_stream(const std::string &name) { + if (!contains_entry(name, compound_document_entry::entry_type::UserStream)) + { + throw xlnt::exception("not found"); + } + const auto entry_id = find_entry(name, compound_document_entry::entry_type::UserStream); const auto &entry = entries_.at(entry_id); @@ -662,6 +633,8 @@ sector_chain compound_document::allocate_sectors(std::size_t count) sat_[current] = next; current = next; } + + chain.push_back(current); return chain; } diff --git a/source/detail/cryptography/xlsx_crypto_consumer.cpp b/source/detail/cryptography/xlsx_crypto_consumer.cpp index 71eee198..e2afcfc8 100644 --- a/source/detail/cryptography/xlsx_crypto_consumer.cpp +++ b/source/detail/cryptography/xlsx_crypto_consumer.cpp @@ -145,14 +145,14 @@ encryption_info::standard_encryption_info read_standard_encryption_info(std::ist throw xlnt::exception("invalid header"); } - const auto csp_name_length = header_length - (info_stream.tellg() - index_at_start); - auto csp_name = xlnt::detail::read_string(info_stream, csp_name_length); - if (csp_name != u"Microsoft Enhanced RSA and AES Cryptographic Provider (Prototype)" - && csp_name != u"Microsoft Enhanced RSA and AES Cryptographic Provider") + const auto csp_name_length = (header_length - (info_stream.tellg() - index_at_start) - 1) / 2; + auto csp_name = xlnt::detail::utf16_to_utf8(xlnt::detail::read_string(info_stream, csp_name_length)); + if (csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider (Prototype)" + && csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider") { throw xlnt::exception("invalid cryptographic provider"); } - info_stream.seekg(csp_name_length); + //info_stream.seekg((csp_name_length + 1) * 2); const auto salt_size = read(info_stream); result.salt = xlnt::detail::read_vector(info_stream, salt_size); @@ -312,12 +312,14 @@ std::vector decrypt_xlsx( std::istream stream(&buffer); xlnt::detail::compound_document document(stream); - auto &encryption_info_stream = document.open_read_stream("EncryptionInfo"); + auto &encryption_info_stream = document.open_read_stream("/EncryptionInfo"); auto encryption_info = read_encryption_info(encryption_info_stream, password); + auto &encrypted_package_stream = document.open_read_stream("/EncryptedPackage"); + return encryption_info.is_agile - ? decrypt_xlsx_agile(encryption_info, document.open_read_stream("EncryptedPackage")) - : decrypt_xlsx_standard(encryption_info, document.open_read_stream("EncryptedPackage")); + ? decrypt_xlsx_agile(encryption_info, encrypted_package_stream) + : decrypt_xlsx_standard(encryption_info, encrypted_package_stream); } } // namespace diff --git a/source/detail/serialization/vector_streambuf.cpp b/source/detail/serialization/vector_streambuf.cpp index c68398b0..2869b8c2 100644 --- a/source/detail/serialization/vector_streambuf.cpp +++ b/source/detail/serialization/vector_streambuf.cpp @@ -22,6 +22,7 @@ // @author: see AUTHORS file #include +#include namespace xlnt { namespace detail { @@ -214,6 +215,11 @@ std::streampos vector_ostreambuf::seekpos(std::streampos sp, std::ios_base::open XLNT_API std::vector to_vector(std::istream &in_stream) { + if (!in_stream) + { + throw xlnt::exception("bad stream"); + } + return std::vector( std::istreambuf_iterator(in_stream), std::istreambuf_iterator()); @@ -221,6 +227,11 @@ XLNT_API std::vector to_vector(std::istream &in_stream) XLNT_API void to_stream(const std::vector &bytes, std::ostream &out_stream) { + if (!out_stream) + { + throw xlnt::exception("bad stream"); + } + out_stream.write(reinterpret_cast(bytes.data()), bytes.size()); } diff --git a/tests/runner.cpp b/tests/runner.cpp index dc4601b6..b01ea246 100644 --- a/tests/runner.cpp +++ b/tests/runner.cpp @@ -73,27 +73,6 @@ void print_summary() int main() { - std::ifstream file("C:/Users/Thomas/Development/xlnt/tests/data/6_encrypted_libre.xlsx", std::ios::binary); - const auto bytes2 = xlnt::detail::to_vector(file); - xlnt::detail::vector_istreambuf buffer(bytes2); - std::istream buffer_stream(&buffer); - xlnt::detail::compound_document doc2(buffer_stream); - auto info = xlnt::detail::to_vector(doc2.open_read_stream("/EncryptionInfo")); - auto package = xlnt::detail::to_vector(doc2.open_read_stream("/EncryptedPackage")); - - std::vector bytes; - xlnt::detail::vector_ostreambuf byte_buffer(bytes); - std::ostream byte_buffer_stream(&byte_buffer); - xlnt::detail::compound_document doc(byte_buffer_stream); - auto &a_stream = doc.open_write_stream("/aaa"); - xlnt::detail::to_stream(std::vector(4095, 'a'), a_stream); - auto &b_stream = doc.open_write_stream("/bbb"); - xlnt::detail::to_stream(std::vector(4095, 'b'), b_stream); - auto &c_stream = doc.open_write_stream("/ccc"); - xlnt::detail::to_stream(std::vector(4095, 'c'), c_stream); - std::ofstream file2("cd.xlsx", std::ios::binary); - xlnt::detail::to_stream(bytes, file2); - // cell run_tests(); run_tests();