diff --git a/source/detail/cryptography/compound_document.cpp b/source/detail/cryptography/compound_document.cpp index d1d4452c..d3f47814 100644 --- a/source/detail/cryptography/compound_document.cpp +++ b/source/detail/cryptography/compound_document.cpp @@ -204,25 +204,17 @@ private: int_type underflow() override { - if (position_ == entry_.size) + if (position_ >= entry_.size) { return traits_type::eof(); } - sector_writer_.reset(); + auto old_position = position_; + auto result = '\0'; + xsgetn(&result, 1); + position_ = old_position; - if (entry_.size < document_.header_.threshold) - { - document_.read_short_sector_chain(entry_.start, - sector_writer_, sector_id(position_ / document_.short_sector_size()), 1); - return current_sector_[position_ % document_.short_sector_size()]; - } - else - { - document_.read_sector_chain(entry_.start, - sector_writer_, sector_id(position_ / document_.sector_size()), 1); - return current_sector_[position_ % document_.sector_size()]; - } + return result; } int_type uflow() override @@ -378,6 +370,7 @@ private: position_ += written; entry_.size = std::max(entry_.size, static_cast(position_)); + document_.write_directory(); std::fill(current_sector_.begin(), current_sector_.end(), byte(0)); setp(reinterpret_cast(current_sector_.data()), @@ -400,12 +393,14 @@ private: auto next_sector = document_.allocate_short_sector(); document_.ssat_[chain_.back()] = next_sector; chain_.push_back(next_sector); + document_.write_ssat(); } else { auto next_sector = document_.allocate_sector(); document_.sat_[chain_.back()] = next_sector; chain_.push_back(next_sector); + document_.write_sat(); } auto value = static_cast(c); @@ -436,15 +431,18 @@ private: current_sector_.resize(document_.sector_size(), 0); std::fill(current_sector_.begin(), current_sector_.end(), byte(0)); - if (document_.header_.num_short_sectors == 0) + if (entry_.start < 0) { - document_.entries_[0].start = EndOfChain; + // TODO: deallocate short sectors here + if (document_.header_.num_short_sectors == 0) + { + document_.entries_[0].start = EndOfChain; + } } - // TODO: deallocate short sectors here - chain_ = new_chain; entry_.start = chain_.front(); + document_.write_directory(); } std::streampos seekoff(std::streamoff off, std::ios_base::seekdir way, std::ios_base::openmode) override @@ -523,6 +521,7 @@ compound_document::compound_document(std::ostream &out) stream_in_(nullptr), stream_out_(nullptr) { + header_.msat.fill(FreeSector); write_header(); insert_entry("/Root Entry", compound_document_entry::entry_type::RootStorage); } @@ -683,13 +682,21 @@ sector_id compound_document::allocate_sector() { auto next_msat_index = header_.num_msat_sectors; auto new_sat_sector_id = sector_id(sat_.size()); + msat_.push_back(new_sat_sector_id); + write_msat(); + header_.msat[msat_.size() - 1] = new_sat_sector_id; + ++header_.num_msat_sectors; + write_header(); + sat_.resize(sat_.size() + sectors_per_sector, FreeSector); sat_[new_sat_sector_id] = SATSector; + auto sat_reader = binary_reader(sat_); sat_reader.offset(next_msat_index * sectors_per_sector); write_sector(sat_reader, new_sat_sector_id); + next_free_iter = std::find(sat_.begin(), sat_.end(), FreeSector); } @@ -721,6 +728,7 @@ sector_chain compound_document::allocate_sectors(std::size_t count) } chain.push_back(current); + write_sat(); return chain; } @@ -755,6 +763,7 @@ sector_chain compound_document::allocate_short_sectors(std::size_t count) } chain.push_back(current); + write_ssat(); return chain; } @@ -768,8 +777,6 @@ sector_id compound_document::allocate_short_sector() { auto new_ssat_sector_id = allocate_sector(); - ++header_.num_short_sectors; - if (header_.ssat_start < 0) { header_.ssat_start = new_ssat_sector_id; @@ -778,6 +785,7 @@ sector_id compound_document::allocate_short_sector() { auto ssat_chain = follow_chain(header_.ssat_start, sat_); sat_[ssat_chain.back()] = new_ssat_sector_id; + write_sat(); } write_header(); @@ -791,7 +799,10 @@ sector_id compound_document::allocate_short_sector() next_free_iter = std::find(ssat_.begin(), ssat_.end(), FreeSector); } - + + ++header_.num_short_sectors; + write_header(); + auto next_free = sector_id(next_free_iter - ssat_.begin()); ssat_[next_free] = EndOfChain; @@ -844,6 +855,7 @@ directory_id compound_document::next_empty_entry() { auto directory_chain = follow_chain(header_.directory_start, sat_); sat_[directory_chain.back()] = allocate_sector(); + write_sat(); } const auto entries_per_sector = sector_size() @@ -875,6 +887,12 @@ directory_id compound_document::insert_entry( if (split.size() > 1) { parent_id = find_entry(join_path(split), compound_document_entry::entry_type::UserStorage); + + if (parent_id < 0) + { + throw xlnt::exception("bad path"); + } + parent_storage_[entry_id] = parent_id; } @@ -938,7 +956,7 @@ void compound_document::write_directory() { for (auto entry_id = std::size_t(0); entry_id < entries_.size(); ++entry_id) { - write_entry(directory_id(entry_id++)); + write_entry(directory_id(entry_id)); } } @@ -1269,7 +1287,6 @@ void compound_document::read_msat() void compound_document::read_sat() { sat_.clear(); - auto sat_writer = binary_writer(sat_); for (auto msat_sector : msat_) @@ -1281,15 +1298,11 @@ void compound_document::read_sat() void compound_document::read_ssat() { ssat_.clear(); + auto ssat_writer = binary_writer(ssat_); for (auto ssat_sector : follow_chain(header_.ssat_start, sat_)) { - auto sector = std::vector(); - auto sector_writer = binary_writer(sector); - - read_sector(ssat_sector, sector_writer); - - std::copy(sector.begin(), sector.end(), std::back_inserter(ssat_)); + read_sector(ssat_sector, ssat_writer); } } @@ -1361,7 +1374,7 @@ void compound_document::write_entry(directory_id id) const auto directory_chain = follow_chain(header_.directory_start, sat_); const auto entries_per_sector = sector_size() / sizeof(compound_document_entry); const auto directory_sector = directory_chain[id / entries_per_sector]; - const auto offset = sector_size() * directory_sector + const auto offset = sector_data_start() + sector_size() * directory_sector + ((id % entries_per_sector) * sizeof(compound_document_entry)); out_->seekp(offset, std::ios::beg); diff --git a/source/detail/cryptography/xlsx_crypto_consumer.cpp b/source/detail/cryptography/xlsx_crypto_consumer.cpp index 59b91183..10a319bf 100644 --- a/source/detail/cryptography/xlsx_crypto_consumer.cpp +++ b/source/detail/cryptography/xlsx_crypto_consumer.cpp @@ -52,15 +52,26 @@ std::vector decrypt_xlsx_standard( const auto key = info.calculate_key(); auto decrypted_size = read(encrypted_package_stream); - auto encrypted_package = std::vector( - std::istreambuf_iterator(encrypted_package_stream), - std::istreambuf_iterator()); - auto decrypted = xlnt::detail::aes_ecb_decrypt(encrypted_package, key); + std::vector encrypted_segment(4096, 0); + std::vector decrypted_package; - decrypted.resize(static_cast(decrypted_size)); + while (encrypted_package_stream) + { + encrypted_package_stream.read( + reinterpret_cast(encrypted_segment.data()), + encrypted_segment.size()); + auto decrypted_segment = xlnt::detail::aes_ecb_decrypt(encrypted_segment, key); - return decrypted; + decrypted_package.insert( + decrypted_package.end(), + decrypted_segment.begin(), + decrypted_segment.end()); + } + + decrypted_package.resize(static_cast(decrypted_size)); + + return decrypted_package; } std::vector decrypt_xlsx_agile( @@ -174,10 +185,7 @@ encryption_info::agile_encryption_info read_agile_encryption_info(std::istream & encryption_info::agile_encryption_info result; - auto xml_string = std::string( - std::istreambuf_iterator(info_stream), - std::istreambuf_iterator()); - xml::parser parser(xml_string.data(), xml_string.size(), "EncryptionInfo"); + xml::parser parser(info_stream, "EncryptionInfo"); parser.next_expect(xml::parser::event_type::start_element, xmlns, "encryption"); diff --git a/source/detail/cryptography/xlsx_crypto_producer.cpp b/source/detail/cryptography/xlsx_crypto_producer.cpp index 9cb25ff2..a4583eae 100644 --- a/source/detail/cryptography/xlsx_crypto_producer.cpp +++ b/source/detail/cryptography/xlsx_crypto_producer.cpp @@ -109,13 +109,18 @@ void write_agile_encryption_info( const encryption_info &info, std::ostream &info_stream) { + const auto version_major = std::uint16_t(4); + const auto version_minor = std::uint16_t(4); + const auto encryption_flags = std::uint32_t(0x40); + + info_stream.write(reinterpret_cast(&version_major), sizeof(std::uint16_t)); + info_stream.write(reinterpret_cast(&version_minor), sizeof(std::uint16_t)); + info_stream.write(reinterpret_cast(&encryption_flags), sizeof(std::uint32_t)); + static const auto &xmlns = xlnt::constants::ns("encryption"); static const auto &xmlns_p = xlnt::constants::ns("encryption-password"); - std::vector encryption_info; - xlnt::detail::vector_ostreambuf encryption_info_buffer(encryption_info); - std::ostream encryption_info_stream(&encryption_info_buffer); - xml::serializer serializer(encryption_info_stream, "EncryptionInfo"); + xml::serializer serializer(info_stream, "EncryptionInfo"); serializer.start_element(xmlns, "encryption"); @@ -166,8 +171,6 @@ void write_agile_encryption_info( serializer.end_element(xmlns, "keyEncryptors"); serializer.end_element(xmlns, "encryption"); - - info_stream.write(reinterpret_cast(encryption_info.data()), encryption_info.size()); } void write_standard_encryption_info(const encryption_info &info, std::ostream &info_stream) @@ -211,34 +214,55 @@ void write_standard_encryption_info(const encryption_info &info, std::ostream &i void encrypt_xlsx_agile( const encryption_info &info, - std::ostream &plaintext) + const std::vector &plaintext, + std::ostream &ciphertext_stream) { - auto key = info.calculate_key(); - /* - auto padded = plaintext; - padded.resize((plaintext.size() / 16 + (plaintext.size() % 16 == 0 ? 0 : 1)) * 16); - auto ciphertext = xlnt::detail::aes_ecb_encrypt(padded, key); const auto length = static_cast(plaintext.size()); - ciphertext.insert(ciphertext.begin(), - reinterpret_cast(&length), - reinterpret_cast(&length + sizeof(std::uint64_t))); - */ + ciphertext_stream.write(reinterpret_cast(&length), sizeof(std::uint64_t)); + + auto key = info.calculate_key(); + + auto salt_size = info.agile.key_data.salt_size; + auto salt_with_block_key = info.agile.key_data.salt_value; + salt_with_block_key.resize(salt_size + sizeof(std::uint32_t), 0); + auto &segment_index = *reinterpret_cast(salt_with_block_key.data() + salt_size); + + auto segment = std::vector(4096, 0); + + for (auto i = std::size_t(0); i < length; i += 4096) + { + auto iv = hash(info.agile.key_encryptor.hash, salt_with_block_key); + iv.resize(16); + + auto start = plaintext.begin() + i; + auto bytes = std::min(std::size_t(length - i), std::size_t(4096)); + std::copy(start, start + bytes, segment.begin()); + auto encrypted_segment = xlnt::detail::aes_cbc_encrypt(segment, key, iv); + ciphertext_stream.write(reinterpret_cast(encrypted_segment.data()), bytes); + + ++segment_index; + } } void encrypt_xlsx_standard( const encryption_info &info, - std::ostream &plaintext) + const std::vector &plaintext, + std::ostream &ciphertext_stream) { - auto key = info.calculate_key(); - /* - auto padded = plaintext; - padded.resize((plaintext.size() / 16 + (plaintext.size() % 16 == 0 ? 0 : 1)) * 16); - auto ciphertext = xlnt::detail::aes_ecb_encrypt(padded, key); const auto length = static_cast(plaintext.size()); - ciphertext.insert(ciphertext.begin(), - reinterpret_cast(&length), - reinterpret_cast(&length + sizeof(std::uint64_t))); - */ + ciphertext_stream.write(reinterpret_cast(&length), sizeof(std::uint64_t)); + + auto key = info.calculate_key(); + auto segment = std::vector(4096, 0); + + for (auto i = std::size_t(0); i < length; ++i) + { + auto start = plaintext.begin() + i; + auto bytes = std::min(std::size_t(length - i), std::size_t(4096)); + std::copy(start, start + bytes, segment.begin()); + auto encrypted_segment = xlnt::detail::aes_ecb_encrypt(segment, key); + ciphertext_stream.write(reinterpret_cast(encrypted_segment.data()), bytes); + } } std::vector encrypt_xlsx( @@ -249,19 +273,24 @@ std::vector encrypt_xlsx( encryption_info.password = u"secret"; auto ciphertext = std::vector(); + xlnt::detail::vector_ostreambuf buffer(ciphertext); std::ostream stream(&buffer); xlnt::detail::compound_document document(stream); if (encryption_info.is_agile) { - write_agile_encryption_info(encryption_info, document.open_write_stream("/EncryptionInfo")); - encrypt_xlsx_agile(encryption_info, document.open_write_stream("/EncryptedPackage")); + write_agile_encryption_info(encryption_info, + document.open_write_stream("/EncryptionInfo")); + encrypt_xlsx_agile(encryption_info, plaintext, + document.open_write_stream("/EncryptedPackage")); } else { - write_standard_encryption_info(encryption_info, document.open_write_stream("/EncryptionInfo")); - encrypt_xlsx_standard(encryption_info, document.open_write_stream("/EncryptedPackage")); + write_standard_encryption_info(encryption_info, + document.open_write_stream("/EncryptionInfo")); + encrypt_xlsx_standard(encryption_info, plaintext, + document.open_write_stream("/EncryptedPackage")); } return ciphertext; diff --git a/tests/workbook/serialization_test_suite.hpp b/tests/workbook/serialization_test_suite.hpp index f20bdaaf..b85c2065 100644 --- a/tests/workbook/serialization_test_suite.hpp +++ b/tests/workbook/serialization_test_suite.hpp @@ -413,6 +413,9 @@ public: std::vector destination_data; source_workbook.save(destination_data, password); source_workbook.save("encrypted.xlsx", password); + + xlnt::workbook temp; + temp.load("encrypted.xlsx", password); //TODO: finish implementing encryption and uncomment this //return source_data == destination_data;