separate compound file into a reader and a writer class

This commit is contained in:
Thomas Fussell 2017-04-23 16:56:01 -04:00
parent be11002a93
commit 2fc692d694
6 changed files with 195 additions and 222 deletions

View File

@ -162,7 +162,7 @@ if(MSVC)
endif()
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/detail/serialization/miniz.cpp PROPERTIES COMPILE_FLAGS "-Wno-comma -Wno-undef")
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/detail/serialization/miniz.cpp PROPERTIES COMPILE_FLAGS "-Wno-undef")
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/detail/serialization/zstream.cpp PROPERTIES COMPILE_FLAGS "-Wno-undef -Wno-shorten-64-to-32")
endif()

View File

@ -133,12 +133,7 @@ public:
{
return byte_vector::from(data_);
}
/*
std::size_t size_in_bytes()
{
return count() * 4;
}
*/
std::size_t sector_size() const
{
return sector_size_;
@ -480,13 +475,6 @@ public:
return result;
}
/*
std::size_t size()
{
return entry_count() * sizeof(directory_entry);
}
*/
directory_entry create_root_entry() const
{
directory_entry root;
@ -499,11 +487,6 @@ public:
return root;
}
bool contains(const std::u16string &name) const
{
return find_entry(name).second;
}
private:
// helper function: recursively find siblings of index
void find_siblings(std::vector<directory_id> &result, directory_id index) const
@ -620,13 +603,26 @@ private:
namespace xlnt {
namespace detail {
class compound_document_impl
class compound_document_reader_impl
{
public:
compound_document_impl()
compound_document_reader_impl(const std::vector<std::uint8_t> &data)
{
auto reader = byte_vector(data);
header_.load(reader);
sector_table_.sector_size(header_.sector_size());
short_sector_table_.sector_size(header_.short_sector_size());
sectors_.append(data, 512, data.size() - 512);
sector_table_.load(load_sectors(load_msat(reader)));
short_sector_table_.load(load_sectors(sector_table_.follow(header_.short_table_start())));
auto directory_data = load_sectors(sector_table_.follow(header_.directory_start()));
directory_.load(directory_data);
auto first_short_sector = directory_.entry(u"/Root Entry", false).first;
short_container_stream_ = sector_table_.follow(first_short_sector);
}
byte_vector load_sectors(const std::vector<sector_id> &sectors) const
@ -643,32 +639,6 @@ public:
return result;
}
void write_sectors(const byte_vector &data, directory_entry &/*entry*/)
{
const auto sector_size = sector_table_.sector_size();
const auto num_sectors = data.size() / sector_size;
for (auto i = std::size_t(0); i < num_sectors; ++i)
{
auto position = sector_size * i;
auto current_sector_size = data.size() % sector_size;
sectors_.append(data.data(), position, current_sector_size);
}
}
void write_short_sectors(const byte_vector &data, directory_entry &/*entry*/)
{
const auto sector_size = sector_table_.sector_size();
const auto num_sectors = data.size() / sector_size;
for (auto i = std::size_t(0); i < num_sectors; ++i)
{
auto position = sector_size * i;
auto current_sector_size = data.size() % sector_size;
sectors_.append(data.data(), position, current_sector_size);
}
}
byte_vector load_short_sectors(const std::vector<sector_id> &sectors) const
{
auto result = byte_vector();
@ -715,67 +685,9 @@ public:
return master_sectors;
}
void load(byte_vector &data)
byte_vector read_stream(const std::u16string &name) const
{
header_.load(data);
const auto sector_size = header_.sector_size();
const auto short_sector_size = header_.short_sector_size();
sector_table_.sector_size(sector_size);
short_sector_table_.sector_size(short_sector_size);
sectors_.append(data.data(), 512, data.size() - 512);
sector_table_.load(load_sectors(load_msat(data)));
short_sector_table_.load(load_sectors(sector_table_.follow(header_.short_table_start())));
auto directory_data = load_sectors(sector_table_.follow(header_.directory_start()));
directory_.load(directory_data);
auto first_short_sector = directory_.entry(u"/Root Entry", false).first;
short_container_stream_ = sector_table_.follow(first_short_sector);
}
// Serializes the document by concatenating, in this order: the header, the
// sector allocation table, the short-sector allocation table, the directory
// and finally the raw sector data.
byte_vector save() const
{
    byte_vector serialized;

    auto header_bytes = header_.save();
    serialized.append(header_bytes.data());

    auto sector_table_bytes = sector_table_.save();
    serialized.append(sector_table_bytes.data());

    auto short_sector_table_bytes = short_sector_table_.save();
    serialized.append(short_sector_table_bytes.data());

    auto directory_bytes = directory_.save();
    serialized.append(directory_bytes.data());

    serialized.append(sectors_.data());

    return serialized;
}
// Reports whether the directory contains an entry named `filename`.
bool has_stream(const std::u16string &filename) const
{
    const bool is_present = directory_.contains(filename);
    return is_present;
}
void add_stream(const std::u16string &name, const byte_vector &data)
{
auto entry = directory_.entry(name, !has_stream(name));
if (entry.size < header_.threshold())
{
write_short_sectors(data, entry);
}
else
{
write_sectors(data, entry);
}
}
byte_vector stream(const std::u16string &name) const
{
if (!has_stream(name))
{
throw xlnt::exception("document doesn't contain stream with the given name");
}
auto entry = directory_.entry(name);
const auto entry = directory_.entry(name);
byte_vector result;
if (entry.size < header_.threshold())
@ -802,51 +714,106 @@ private:
std::vector<sector_id> short_container_stream_;
};
compound_document::compound_document()
: d_(new compound_document_impl())
// Accumulates the parts of a compound document in memory (header, allocation
// tables, directory and sector data) and serializes them with save().
class compound_document_writer_impl
{
public:
    // Configures both allocation tables with the sector sizes reported by the
    // default-constructed header.
    //
    // NOTE(review): `data` is accepted but neither read nor written anywhere
    // in this class -- confirm whether it is meant to receive the serialized
    // document.
    compound_document_writer_impl(std::vector<std::uint8_t> &data)
    {
        sector_table_.sector_size(header_.sector_size());
        short_sector_table_.sector_size(header_.short_sector_size());
    }

    // Appends all of `data` to sectors_ in chunks of one regular sector.
    //
    // NOTE(review): the directory entry is not updated and the final partial
    // sector is not padded -- confirm whether callers expect either.
    void write_sectors(const byte_vector &data, directory_entry &/*entry*/)
    {
        append_in_chunks(data, sector_table_.sector_size());
    }

    // Appends all of `data` to sectors_ in chunks of one short sector.
    //
    // NOTE(review): the bytes go to sectors_ while short_sectors_ stays
    // unused, and the directory entry is not updated -- confirm the intended
    // destination for short streams.
    void write_short_sectors(const byte_vector &data, directory_entry &/*entry*/)
    {
        append_in_chunks(data, short_sector_table_.sector_size());
    }

    // Serializes the document by concatenating, in this order: the header,
    // the sector allocation table, the short-sector allocation table, the
    // directory and the raw sector data.
    byte_vector save() const
    {
        auto result = byte_vector();

        result.append(header_.save().data());
        result.append(sector_table_.save().data());
        result.append(short_sector_table_.save().data());
        result.append(directory_.save().data());
        result.append(sectors_.data());

        return result;
    }

    // Stores `data` under the directory entry called `name`, creating the
    // entry when needed. Entries whose recorded size is below the header's
    // threshold go through write_short_sectors; all others go through
    // write_sectors.
    //
    // NOTE(review): the comparison reads entry.size, which nothing in this
    // class sets from `data` -- confirm whether data.size() was intended.
    void write_stream(const std::u16string &name, const byte_vector &data)
    {
        auto &entry = directory_.entry(name, true);

        if (entry.size < header_.threshold())
        {
            write_short_sectors(data, entry);
        }
        else
        {
            write_sectors(data, entry);
        }
    }

private:
    // Appends every byte of `data` to sectors_ in pieces of at most
    // `chunk_size` bytes; only the last piece may be shorter.
    //
    // This replaces a loop that visited only whole sectors and used
    // `data.size() % sector_size` as the length of each chunk, which dropped
    // most of the data (and all of it when the size was an exact multiple).
    void append_in_chunks(const byte_vector &data, std::size_t chunk_size)
    {
        const std::size_t total = data.size();
        // A zero chunk size would never advance the loop; write everything
        // as a single chunk instead.
        const std::size_t step = chunk_size != 0 ? chunk_size : total;

        for (std::size_t position = 0; position < total; position += step)
        {
            const std::size_t remaining = total - position;
            sectors_.append(data.data(), position, remaining < step ? remaining : step);
        }
    }

    directory_tree directory_;
    header header_;
    allocation_table sector_table_;
    byte_vector sectors_;
    allocation_table short_sector_table_;
    byte_vector short_sectors_;
    std::vector<sector_id> short_container_stream_;
};
// Creates the private implementation object, which parses `data` in its
// constructor. std::make_unique is used instead of a raw `new` expression so
// the allocation is owned from the moment it is created.
compound_document_reader::compound_document_reader(const std::vector<std::uint8_t> &data)
    : d_(std::make_unique<compound_document_reader_impl>(data))
{
}
compound_document::~compound_document()
compound_document_reader::~compound_document_reader()
{
}
compound_document_impl &compound_document::impl()
std::vector<std::uint8_t> compound_document_reader::read_stream(const std::u16string &name) const
{
return *d_;
return d_->read_stream(name).data();
}
compound_document_impl &compound_document::impl() const
compound_document_writer::compound_document_writer(std::vector<std::uint8_t> &data)
: d_(new compound_document_writer_impl(data))
{
return *d_;
}
void compound_document::load(std::vector<std::uint8_t> &data)
compound_document_writer::~compound_document_writer()
{
byte_vector vec(data);
return impl().load(vec);
}
std::vector<std::uint8_t> compound_document::save() const
void compound_document_writer::write_stream(const std::u16string &name, const std::vector<std::uint8_t> &data)
{
return impl().save().data();
d_->write_stream(name, data);
}
bool compound_document::has_stream(const std::u16string &filename) const
{
return impl().has_stream(filename);
}
void compound_document::add_stream(const std::u16string &name, const std::vector<std::uint8_t> &data)
{
return impl().add_stream(name, data);
}
std::vector<std::uint8_t> compound_document::stream(const std::u16string &name) const
{
return impl().stream(name).data();
}
} // namespace detail
} // namespace xlnt

View File

@ -33,26 +33,31 @@
namespace xlnt {
namespace detail {
class compound_document_impl;
class compound_document_reader_impl;
class compound_document_writer_impl;
class compound_document
class compound_document_reader
{
public:
compound_document();
~compound_document();
compound_document_reader(const std::vector<std::uint8_t> &data);
~compound_document_reader();
void load(std::vector<std::uint8_t> &data);
std::vector<std::uint8_t> save() const;
bool has_stream(const std::u16string &filename) const;
void add_stream(const std::u16string &filename, const std::vector<std::uint8_t> &data);
std::vector<std::uint8_t> stream(const std::u16string &filename) const;
std::vector<std::uint8_t> read_stream(const std::u16string &filename) const;
private:
compound_document_impl &impl();
//TODO: can this return a const reference?
compound_document_impl &impl() const;
std::unique_ptr<compound_document_impl> d_;
std::unique_ptr<compound_document_reader_impl> d_;
};
// Public facade for writing streams into a compound document. All work is
// forwarded to a compound_document_writer_impl held through d_, so the
// implementation type only needs a forward declaration in this header.
class compound_document_writer
{
public:
// Creates the writer and its implementation object, passing `data` through.
// NOTE(review): `data` is taken by non-const reference, presumably as the
// output buffer -- confirm against the implementation.
compound_document_writer(std::vector<std::uint8_t> &data);
// Declared here and defined out of line, where the implementation type is
// complete.
~compound_document_writer();
// Writes the bytes in `data` as the stream called `filename`.
void write_stream(const std::u16string &filename, const std::vector<std::uint8_t> &data);
private:
// Owning pointer to the implementation object.
std::unique_ptr<compound_document_writer_impl> d_;
};
} // namespace detail

View File

@ -39,14 +39,14 @@ struct encryption_info
struct standard_encryption_info
{
const std::size_t spin_count = 50000;
std::size_t spin_count = 50000;
std::size_t block_size;
std::size_t key_bits;
std::size_t key_bytes;
std::size_t hash_size;
cipher_algorithm cipher;
cipher_chaining chaining;
const hash_algorithm hash = hash_algorithm::sha1;
hash_algorithm hash = hash_algorithm::sha1;
std::vector<std::uint8_t> salt;
std::vector<std::uint8_t> encrypted_verifier;
std::vector<std::uint8_t> encrypted_verifier_hash;

View File

@ -103,11 +103,9 @@ std::vector<std::uint8_t> decrypt_xlsx_agile(
return decrypted_package;
}
encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &info_bytes)
encryption_info::standard_encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &info_bytes)
{
encryption_info result;
result.is_agile = false;
auto &standard_info = result.standard;
encryption_info::standard_encryption_info result;
auto reader = xlnt::detail::byte_reader(info_bytes);
@ -123,7 +121,7 @@ encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &i
if (alg_id == 0 || alg_id == 0x0000660E || alg_id == 0x0000660F || alg_id == 0x00006610)
{
standard_info.cipher = xlnt::detail::cipher_algorithm::aes;
result.cipher = xlnt::detail::cipher_algorithm::aes;
}
else
{
@ -136,8 +134,8 @@ encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &i
throw xlnt::exception("invalid hash algorithm");
}
standard_info.key_bits = reader.read<std::uint32_t>();
standard_info.key_bytes = standard_info.key_bits / 8;
result.key_bits = reader.read<std::uint32_t>();
result.key_bytes = result.key_bits / 8;
auto provider_type = reader.read<std::uint32_t>();
if (provider_type != 0 && provider_type != 0x00000018)
@ -165,20 +163,20 @@ encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &i
reader.offset(reader.offset() + csp_name_length);
const auto salt_size = reader.read<std::uint32_t>();
standard_info.salt = std::vector<std::uint8_t>(
result.salt = std::vector<std::uint8_t>(
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + salt_size));
reader.offset(reader.offset() + salt_size);
static const auto verifier_size = std::size_t(16);
standard_info.encrypted_verifier = std::vector<std::uint8_t>(
result.encrypted_verifier = std::vector<std::uint8_t>(
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + verifier_size));
reader.offset(reader.offset() + verifier_size);
/*const auto verifier_hash_size = */reader.read<std::uint32_t>();
const auto encrypted_verifier_hash_size = std::size_t(32);
standard_info.encrypted_verifier_hash = std::vector<std::uint8_t>(
result.encrypted_verifier_hash = std::vector<std::uint8_t>(
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + encrypted_verifier_hash_size));
reader.offset(reader.offset() + encrypted_verifier_hash_size);
@ -191,7 +189,7 @@ encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &i
return result;
}
encryption_info read_agile_encryption_info(const std::vector<std::uint8_t> &info_bytes)
encryption_info::agile_encryption_info read_agile_encryption_info(const std::vector<std::uint8_t> &info_bytes)
{
using xlnt::detail::decode_base64;
@ -199,31 +197,32 @@ encryption_info read_agile_encryption_info(const std::vector<std::uint8_t> &info
static const auto &xmlns_p = xlnt::constants::ns("encryption-password");
// static const auto &xmlns_c = xlnt::constants::namespace_("encryption-certificate");
encryption_info result;
result.is_agile = true;
auto &agile_info = result.agile;
encryption_info::agile_encryption_info result;
auto header_size = std::size_t(8);
xml::parser parser(info_bytes.data() + header_size, info_bytes.size() - header_size, "EncryptionInfo");
parser.next_expect(xml::parser::event_type::start_element, xmlns, "encryption");
auto &key_data = result.key_data;
parser.next_expect(xml::parser::event_type::start_element, xmlns, "keyData");
agile_info.key_data.salt_size = parser.attribute<std::size_t>("saltSize");
agile_info.key_data.block_size = parser.attribute<std::size_t>("blockSize");
agile_info.key_data.key_bits = parser.attribute<std::size_t>("keyBits");
agile_info.key_data.hash_size = parser.attribute<std::size_t>("hashSize");
agile_info.key_data.cipher_algorithm = parser.attribute("cipherAlgorithm");
agile_info.key_data.cipher_chaining = parser.attribute("cipherChaining");
agile_info.key_data.hash_algorithm = parser.attribute("hashAlgorithm");
agile_info.key_data.salt_value = decode_base64(parser.attribute("saltValue"));
key_data.salt_size = parser.attribute<std::size_t>("saltSize");
key_data.block_size = parser.attribute<std::size_t>("blockSize");
key_data.key_bits = parser.attribute<std::size_t>("keyBits");
key_data.hash_size = parser.attribute<std::size_t>("hashSize");
key_data.cipher_algorithm = parser.attribute("cipherAlgorithm");
key_data.cipher_chaining = parser.attribute("cipherChaining");
key_data.hash_algorithm = parser.attribute("hashAlgorithm");
key_data.salt_value = decode_base64(parser.attribute("saltValue"));
parser.next_expect(xml::parser::event_type::end_element, xmlns, "keyData");
auto &data_integrity = result.data_integrity;
parser.next_expect(xml::parser::event_type::start_element, xmlns, "dataIntegrity");
agile_info.data_integrity.hmac_key = decode_base64(parser.attribute("encryptedHmacKey"));
agile_info.data_integrity.hmac_value = decode_base64(parser.attribute("encryptedHmacValue"));
data_integrity.hmac_key = decode_base64(parser.attribute("encryptedHmacKey"));
data_integrity.hmac_value = decode_base64(parser.attribute("encryptedHmacValue"));
parser.next_expect(xml::parser::event_type::end_element, xmlns, "dataIntegrity");
auto &key_encryptor = result.key_encryptor;
parser.next_expect(xml::parser::event_type::start_element, xmlns, "keyEncryptors");
parser.next_expect(xml::parser::event_type::start_element, xmlns, "keyEncryptor");
parser.attribute("uri");
@ -236,22 +235,18 @@ encryption_info read_agile_encryption_info(const std::vector<std::uint8_t> &info
if (parser.namespace_() == xmlns_p && parser.name() == "encryptedKey")
{
any_password_key = true;
agile_info.key_encryptor.spin_count = parser.attribute<std::size_t>("spinCount");
agile_info.key_encryptor.salt_size = parser.attribute<std::size_t>("saltSize");
agile_info.key_encryptor.block_size = parser.attribute<std::size_t>("blockSize");
agile_info.key_encryptor.key_bits = parser.attribute<std::size_t>("keyBits");
agile_info.key_encryptor.hash_size = parser.attribute<std::size_t>("hashSize");
agile_info.key_encryptor.cipher_algorithm = parser.attribute("cipherAlgorithm");
agile_info.key_encryptor.cipher_chaining = parser.attribute("cipherChaining");
agile_info.key_encryptor.hash = parser.attribute<xlnt::detail::hash_algorithm>("hashAlgorithm");
agile_info.key_encryptor.salt_value =
decode_base64(parser.attribute("saltValue"));
agile_info.key_encryptor.verifier_hash_input =
decode_base64(parser.attribute("encryptedVerifierHashInput"));
agile_info.key_encryptor.verifier_hash_value =
decode_base64(parser.attribute("encryptedVerifierHashValue"));
agile_info.key_encryptor.encrypted_key_value =
decode_base64(parser.attribute("encryptedKeyValue"));
key_encryptor.spin_count = parser.attribute<std::size_t>("spinCount");
key_encryptor.salt_size = parser.attribute<std::size_t>("saltSize");
key_encryptor.block_size = parser.attribute<std::size_t>("blockSize");
key_encryptor.key_bits = parser.attribute<std::size_t>("keyBits");
key_encryptor.hash_size = parser.attribute<std::size_t>("hashSize");
key_encryptor.cipher_algorithm = parser.attribute("cipherAlgorithm");
key_encryptor.cipher_chaining = parser.attribute("cipherChaining");
key_encryptor.hash = parser.attribute<xlnt::detail::hash_algorithm>("hashAlgorithm");
key_encryptor.salt_value = decode_base64(parser.attribute("saltValue"));
key_encryptor.verifier_hash_input = decode_base64(parser.attribute("encryptedVerifierHashInput"));
key_encryptor.verifier_hash_value = decode_base64(parser.attribute("encryptedVerifierHashValue"));
key_encryptor.encrypted_key_value = decode_base64(parser.attribute("encryptedKeyValue"));
}
else
{
@ -274,50 +269,56 @@ encryption_info read_agile_encryption_info(const std::vector<std::uint8_t> &info
return result;
}
encryption_info read_encryption_info(const std::vector<std::uint8_t> &info_bytes)
encryption_info read_encryption_info(const std::vector<std::uint8_t> &info_bytes, const std::u16string &password)
{
encryption_info info;
info.password = password;
auto reader = xlnt::detail::byte_reader(info_bytes);
auto version_major = reader.read<std::uint16_t>();
auto version_minor = reader.read<std::uint16_t>();
auto encryption_flags = reader.read<std::uint32_t>();
info.is_agile = version_major == 4 && version_minor == 4;
// version 4.4 is agile
if (version_major == 4 && version_minor == 4)
if (info.is_agile)
{
if (encryption_flags != 0x40)
{
throw xlnt::exception("bad header");
}
return read_agile_encryption_info(info_bytes);
info.agile = read_agile_encryption_info(info_bytes);
}
// not agile, only try to decrypt versions 3.2 and 4.2
if (version_minor != 2 || (version_major != 2 && version_major != 3 && version_major != 4))
else
{
throw xlnt::exception("unsupported encryption version");
}
if (version_minor != 2 || (version_major != 2 && version_major != 3 && version_major != 4))
{
throw xlnt::exception("unsupported encryption version");
}
if ((encryption_flags & 0b00000011) != 0) // Reserved1 and Reserved2, MUST be 0
{
throw xlnt::exception("bad header");
}
if ((encryption_flags & 0b00000011) != 0) // Reserved1 and Reserved2, MUST be 0
{
throw xlnt::exception("bad header");
}
if ((encryption_flags & 0b00000100) == 0 // fCryptoAPI
|| (encryption_flags & 0b00010000) != 0) // fExternal
{
throw xlnt::exception("extensible encryption is not supported");
}
if ((encryption_flags & 0b00000100) == 0 // fCryptoAPI
|| (encryption_flags & 0b00010000) != 0) // fExternal
{
throw xlnt::exception("extensible encryption is not supported");
}
if ((encryption_flags & 0b00100000) == 0) // fAES
{
throw xlnt::exception("not an OOXML document");
}
if ((encryption_flags & 0b00100000) == 0) // fAES
{
throw xlnt::exception("not an OOXML document");
}
return read_standard_encryption_info(info_bytes);
info.standard = read_standard_encryption_info(info_bytes);
}
return info;
}
std::vector<std::uint8_t> decrypt_xlsx(
@ -329,12 +330,11 @@ std::vector<std::uint8_t> decrypt_xlsx(
throw xlnt::exception("empty file");
}
xlnt::detail::compound_document document;
document.load(const_cast<std::vector<std::uint8_t> &>(bytes));
xlnt::detail::compound_document_reader document(bytes);
auto encryption_info = read_encryption_info(document.stream(u"EncryptionInfo"));
encryption_info.password = password;
auto encrypted_package = document.stream(u"EncryptedPackage");
auto encryption_info = read_encryption_info(
document.read_stream(u"EncryptionInfo"), password);
auto encrypted_package = document.read_stream(u"EncryptedPackage");
return encryption_info.is_agile
? decrypt_xlsx_agile(encryption_info, encrypted_package)

View File

@ -248,16 +248,17 @@ std::vector<std::uint8_t> encrypt_xlsx(
auto encryption_info = generate_encryption_info(password);
encryption_info.password = u"secret";
xlnt::detail::compound_document document;
auto ciphertext = std::vector<std::uint8_t>();
xlnt::detail::compound_document_writer document(ciphertext);
document.add_stream(u"EncryptionInfo", encryption_info.is_agile
document.write_stream(u"EncryptionInfo", encryption_info.is_agile
? write_agile_encryption_info(encryption_info)
: write_standard_encryption_info(encryption_info));
document.add_stream(u"EncryptedPackage", encryption_info.is_agile
document.write_stream(u"EncryptedPackage", encryption_info.is_agile
? encrypt_xlsx_agile(encryption_info, plaintext)
: encrypt_xlsx_standard(encryption_info, plaintext));
return document.save();
return ciphertext;
}
} // namespace