intermediate commit

This commit is contained in:
Thomas Fussell 2017-04-22 14:25:27 -04:00
parent 3a57242b68
commit 89858e32b3
6 changed files with 976 additions and 889 deletions

View File

@ -30,73 +30,230 @@ namespace xlnt {
namespace detail { namespace detail {
using byte = std::uint8_t; using byte = std::uint8_t;
using byte_vector = std::vector<byte>;
class byte_reader
{
public:
/// A reader cannot be created without a buffer to view.
byte_reader() = delete;

/// Creates a reader over `bytes` with the read position at 0.
/// Only the address of `bytes` is kept, so the vector has to stay alive
/// for as long as this reader is used.
byte_reader(const std::vector<byte> &bytes)
    : bytes_{&bytes}
{
}
/// Makes this reader view the same buffer as `other` and adopts its read
/// position. No bytes are copied; only the pointer and the offset are.
byte_reader &operator=(const byte_reader &other)
{
    bytes_ = other.bytes_;
    offset_ = other.offset_;

    return *this;
}
/// Nothing to release: the reader only holds a pointer to a vector it was
/// given by reference.
~byte_reader() {}
/// Returns the byte buffer this reader currently views (no copy).
const std::vector<std::uint8_t> &data() const
{
    const std::vector<std::uint8_t> &buffer = *bytes_;
    return buffer;
}
void offset(std::size_t offset)
{
offset_ = offset;
}
std::size_t offset() const
{
return offset_;
}
void reset()
{
offset_ = 0;
}
template<typename T> template<typename T>
T read_int(const byte_vector &raw_data, std::size_t &index) T read()
{ {
auto result = *reinterpret_cast<const T *>(&raw_data[index]); T result;
index += sizeof(T); std::memcpy(&result, bytes_->data() + offset_, sizeof(T));
return result; return result;
} }
template<typename T> template<typename T>
void write_int(T value, byte_vector &raw_data, std::size_t &index) std::vector<T> as_vector_of() const
{ {
*reinterpret_cast<T *>(&raw_data[index]) = value; auto result = std::vector<T>(size() / sizeof(T), 0);
index += sizeof(T); std::memcpy(result.data(), bytes_->data(), size());
return result;
} }
static inline void writeU16(std::uint8_t *ptr, std::uint16_t data) std::size_t size() const
{ {
ptr[0] = static_cast<std::uint8_t>(data & 0xff); return bytes_->size();
ptr[1] = static_cast<std::uint8_t>((data >> 8) & 0xff);
} }
static inline void writeU32(std::uint8_t *ptr, std::uint32_t data) private:
std::size_t offset_ = 0;
const std::vector<std::uint8_t> *bytes_;
};
class byte_vector
{ {
ptr[0] = static_cast<std::uint8_t>(data & 0xff); public:
ptr[1] = static_cast<std::uint8_t>((data >> 8) & 0xff);
ptr[2] = static_cast<std::uint8_t>((data >> 16) & 0xff);
ptr[3] = static_cast<std::uint8_t>((data >> 24) & 0xff);
}
template<typename T> template<typename T>
byte *vector_byte(std::vector<T> &v, std::size_t offset) static byte_vector from(const std::vector<T> &ints)
{ {
return reinterpret_cast<byte *>(v.data() + offset); byte_vector result;
result.resize(ints.size() / sizeof(T));
std::memcpy(result.bytes_.data(), ints.data(), result.bytes_.size());
return result;
} }
template<typename T> template<typename T>
byte *first_byte(std::vector<T> &v) static byte_vector from(const std::basic_string<T> &string)
{
byte_vector result;
result.resize(string.size() / sizeof(T));
std::memcpy(result.bytes_.data(), string.data(), result.bytes_.size());
return result;
}
byte_vector()
: reader_(bytes_)
{
}
byte_vector(std::vector<byte> &bytes)
: bytes_(bytes),
reader_(bytes_)
{ {
return vector_byte(v, 0);
} }
template<typename T> template<typename T>
byte *last_byte(std::vector<T> &v) byte_vector(const std::vector<T> &ints)
: byte_vector()
{ {
return vector_byte(v, v.size()); bytes_ = from(ints).data();
} }
template <typename InIt> byte_vector(const byte_vector &other)
byte_vector to_bytes(InIt begin, InIt end) : byte_vector()
{ {
byte_vector bytes; *this = other;
for (auto i = begin; i != end; ++i)
{
auto c = *i;
bytes.insert(
bytes.end(),
reinterpret_cast<char *>(&c),
reinterpret_cast<char *>(&c) + sizeof(c));
} }
return bytes; ~byte_vector()
{
} }
/// Copies the bytes held by `other` into this object.
/// The internal reader is rebuilt over this object's own storage, so the
/// position starts again at 0 instead of being taken from `other`.
byte_vector &operator=(const byte_vector &other)
{
    bytes_ = other.bytes_;

    byte_reader fresh_reader(bytes_);
    reader_ = fresh_reader;

    return *this;
}
const std::vector<byte> &data() const
{
return bytes_;
}
std::vector<byte> data()
{
return bytes_;
}
void data(std::vector<byte> &bytes)
{
bytes_ = bytes;
}
/// Sets the position used by later read() and write() calls.
void offset(std::size_t offset)
{
    reader_.offset(offset);
}

/// Returns the current position, as tracked by the internal reader.
std::size_t offset() const
{
    const auto position = reader_.offset();
    return position;
}
/// Discards all stored bytes and rewinds the position to 0.
void reset()
{
    bytes_.clear();
    reader_.reset();
}
template<typename T>
T read()
{
return reader_.read<T>();
}
template<typename T>
std::vector<T> as_vector_of() const
{
return reader_.as_vector_of<T>();
}
/// Copies the object representation of `value` into the buffer at the
/// current offset, then advances the offset by sizeof(T).
/// If the write would run past the end of the buffer, the buffer is first
/// grown (zero-filled) by exactly the missing amount.
template<typename T>
void write(T value)
{
    const std::size_t start = offset();
    const std::size_t end = start + sizeof(T);

    if (end > size())
    {
        extend(end - size());
    }

    std::memcpy(bytes_.data() + start, &value, sizeof(T));
    reader_.offset(end);
}
std::size_t size() const
{
return bytes_.size();
}
void resize(std::size_t new_size, byte fill = 0)
{
bytes_.resize(new_size, fill);
}
/// Grows the storage by `amount` bytes, each new byte set to `fill`.
void extend(std::size_t amount, byte fill = 0)
{
    const std::size_t grown_size = size() + amount;
    bytes_.resize(grown_size, fill);
}
/// Returns an iterator into the stored bytes, positioned at the current
/// offset.
std::vector<byte>::iterator iterator()
{
    auto position = bytes_.begin();
    position += offset();
    return position;
}
void append(const std::vector<std::uint8_t> &data, std::size_t offset, std::size_t count)
{
extend(count);
std::memcpy(bytes_.data(), data.data() + offset, count);
}
void append(const std::vector<std::uint8_t> &data)
{
append(data, 0, data.size());
}
private:
std::vector<byte> bytes_;
byte_reader reader_;
};
} // namespace detail } // namespace detail
} // namespace xlnt } // namespace xlnt

File diff suppressed because it is too large Load Diff

View File

@ -33,7 +33,7 @@
namespace xlnt { namespace xlnt {
namespace detail { namespace detail {
struct compound_document_impl; class compound_document_impl;
class compound_document class compound_document
{ {
@ -41,14 +41,17 @@ public:
compound_document(); compound_document();
~compound_document(); ~compound_document();
void load(const std::vector<std::uint8_t> &data); void load(std::vector<std::uint8_t> &data);
std::vector<std::uint8_t> save() const; std::vector<std::uint8_t> save() const;
bool has_stream(const std::string &filename) const; bool has_stream(const std::u16string &filename) const;
void add_stream(const std::string &filename, const std::vector<std::uint8_t> &data); void add_stream(const std::u16string &filename, const std::vector<std::uint8_t> &data);
std::vector<std::uint8_t> stream(const std::string &filename) const; std::vector<std::uint8_t> stream(const std::u16string &filename) const;
private: private:
compound_document_impl &impl();
//TODO: can this return a const reference?
compound_document_impl &impl() const;
std::unique_ptr<compound_document_impl> d_; std::unique_ptr<compound_document_impl> d_;
}; };

View File

@ -37,7 +37,7 @@ std::vector<std::uint8_t> calculate_standard_key(
{ {
// H_0 = H(salt + password) // H_0 = H(salt + password)
auto salt_plus_password = info.salt; auto salt_plus_password = info.salt;
auto password_bytes = xlnt::detail::to_bytes(password.begin(), password.end()); auto password_bytes = xlnt::detail::byte_vector::from(password).data();
std::copy(password_bytes.begin(), std::copy(password_bytes.begin(),
password_bytes.end(), password_bytes.end(),
std::back_inserter(salt_plus_password)); std::back_inserter(salt_plus_password));
@ -107,7 +107,7 @@ std::vector<std::uint8_t> calculate_agile_key(
{ {
// H_0 = H(salt + password) // H_0 = H(salt + password)
auto salt_plus_password = info.key_encryptor.salt_value; auto salt_plus_password = info.key_encryptor.salt_value;
auto password_bytes = xlnt::detail::to_bytes(password.begin(), password.end()); auto password_bytes = xlnt::detail::byte_vector::from(password).data();
std::copy(password_bytes.begin(), std::copy(password_bytes.begin(),
password_bytes.end(), password_bytes.end(),
std::back_inserter(salt_plus_password)); std::back_inserter(salt_plus_password));

View File

@ -46,13 +46,13 @@ using xlnt::detail::encryption_info;
std::vector<std::uint8_t> decrypt_xlsx_standard( std::vector<std::uint8_t> decrypt_xlsx_standard(
encryption_info info, encryption_info info,
const byte_vector &encrypted_package) const std::vector<std::uint8_t> &encrypted_package)
{ {
const auto key = info.calculate_key(); const auto key = info.calculate_key();
auto offset = std::size_t(0); auto reader = xlnt::detail::byte_reader(encrypted_package);
auto decrypted_size = xlnt::detail::read_int<std::uint64_t>(encrypted_package, offset); auto decrypted_size = reader.read<std::uint64_t>();
auto decrypted = xlnt::detail::aes_ecb_decrypt(encrypted_package, key, offset); auto decrypted = xlnt::detail::aes_ecb_decrypt(encrypted_package, key, reader.offset());
decrypted.resize(static_cast<std::size_t>(decrypted_size)); decrypted.resize(static_cast<std::size_t>(decrypted_size));
return decrypted; return decrypted;
@ -108,14 +108,17 @@ encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &i
result.is_agile = false; result.is_agile = false;
auto &standard_info = result.standard; auto &standard_info = result.standard;
using xlnt::detail::read_int; auto reader = xlnt::detail::byte_reader(info_bytes);
auto offset = std::size_t(8); // skip version info
auto header_length = read_int<std::uint32_t>(info_bytes, offset); // skip version info
auto index_at_start = offset; reader.read<std::uint32_t>();
/*auto skip_flags = */ read_int<std::uint32_t>(info_bytes, offset); reader.read<std::uint32_t>();
/*auto size_extra = */ read_int<std::uint32_t>(info_bytes, offset);
auto alg_id = read_int<std::uint32_t>(info_bytes, offset); auto header_length = reader.read<std::uint32_t>();
auto index_at_start = reader.offset();
/*auto skip_flags = */ reader.read<std::uint32_t>();
/*auto size_extra = */ reader.read<std::uint32_t>();
auto alg_id = reader.read<std::uint32_t>();
if (alg_id == 0 || alg_id == 0x0000660E || alg_id == 0x0000660F || alg_id == 0x00006610) if (alg_id == 0 || alg_id == 0x0000660E || alg_id == 0x0000660F || alg_id == 0x00006610)
{ {
@ -126,60 +129,59 @@ encryption_info read_standard_encryption_info(const std::vector<std::uint8_t> &i
throw xlnt::exception("invalid cipher algorithm"); throw xlnt::exception("invalid cipher algorithm");
} }
auto alg_id_hash = read_int<std::uint32_t>(info_bytes, offset); auto alg_id_hash = reader.read<std::uint32_t>();
if (alg_id_hash != 0x00008004 && alg_id_hash == 0) if (alg_id_hash != 0x00008004 && alg_id_hash == 0)
{ {
throw xlnt::exception("invalid hash algorithm"); throw xlnt::exception("invalid hash algorithm");
} }
standard_info.key_bits = read_int<std::uint32_t>(info_bytes, offset); standard_info.key_bits = reader.read<std::uint32_t>();
standard_info.key_bytes = standard_info.key_bits / 8; standard_info.key_bytes = standard_info.key_bits / 8;
auto provider_type = read_int<std::uint32_t>(info_bytes, offset); auto provider_type = reader.read<std::uint32_t>();
if (provider_type != 0 && provider_type != 0x00000018) if (provider_type != 0 && provider_type != 0x00000018)
{ {
throw xlnt::exception("invalid provider type"); throw xlnt::exception("invalid provider type");
} }
read_int<std::uint32_t>(info_bytes, offset); // reserved 1 reader.read<std::uint32_t>(); // reserved 1
if (read_int<std::uint32_t>(info_bytes, offset) != 0) // reserved 2 if (reader.read<std::uint32_t>() != 0) // reserved 2
{ {
throw xlnt::exception("invalid header"); throw xlnt::exception("invalid header");
} }
const auto csp_name_length = header_length - (offset - index_at_start); const auto csp_name_length = header_length - (reader.offset() - index_at_start);
std::vector<std::uint16_t> csp_name_wide( std::vector<std::uint16_t> csp_name_wide(
reinterpret_cast<const std::uint16_t *>(&*(info_bytes.begin() + static_cast<std::ptrdiff_t>(offset))), reinterpret_cast<const std::uint16_t *>(&*(info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()))),
reinterpret_cast<const std::uint16_t *>( reinterpret_cast<const std::uint16_t *>(
&*(info_bytes.begin() + static_cast<std::ptrdiff_t>(offset + csp_name_length)))); &*(info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + csp_name_length))));
std::string csp_name(csp_name_wide.begin(), csp_name_wide.end() - 1); // without trailing null std::string csp_name(csp_name_wide.begin(), csp_name_wide.end() - 1); // without trailing null
if (csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider (Prototype)" if (csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider (Prototype)"
&& csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider") && csp_name != "Microsoft Enhanced RSA and AES Cryptographic Provider")
{ {
throw xlnt::exception("invalid cryptographic provider"); throw xlnt::exception("invalid cryptographic provider");
} }
offset += csp_name_length;
const auto salt_size = read_int<std::uint32_t>(info_bytes, offset); const auto salt_size = reader.read<std::uint32_t>();
standard_info.salt = std::vector<std::uint8_t>( standard_info.salt = std::vector<std::uint8_t>(
info_bytes.begin() + static_cast<std::ptrdiff_t>(offset), info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(offset + salt_size)); info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + salt_size));
offset += salt_size; reader.offset(reader.offset() + salt_size);
static const auto verifier_size = std::size_t(16); static const auto verifier_size = std::size_t(16);
standard_info.encrypted_verifier = std::vector<std::uint8_t>( standard_info.encrypted_verifier = std::vector<std::uint8_t>(
info_bytes.begin() + static_cast<std::ptrdiff_t>(offset), info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(offset + verifier_size)); info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + verifier_size));
offset += verifier_size; reader.offset(reader.offset() + verifier_size);
const auto verifier_hash_size = read_int<std::uint32_t>(info_bytes, offset); const auto verifier_hash_size = reader.read<std::uint32_t>();
const auto encrypted_verifier_hash_size = std::size_t(32); const auto encrypted_verifier_hash_size = std::size_t(32);
standard_info.encrypted_verifier_hash = std::vector<std::uint8_t>( standard_info.encrypted_verifier_hash = std::vector<std::uint8_t>(
info_bytes.begin() + static_cast<std::ptrdiff_t>(offset), info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset()),
info_bytes.begin() + static_cast<std::ptrdiff_t>(offset + encrypted_verifier_hash_size)); info_bytes.begin() + static_cast<std::ptrdiff_t>(reader.offset() + encrypted_verifier_hash_size));
offset += encrypted_verifier_hash_size; reader.offset(reader.offset() + encrypted_verifier_hash_size);
if (offset != info_bytes.size()) if (reader.offset() != info_bytes.size())
{ {
throw xlnt::exception("extra data after encryption info"); throw xlnt::exception("extra data after encryption info");
} }
@ -274,12 +276,11 @@ encryption_info read_encryption_info(const std::vector<std::uint8_t> &info_bytes
{ {
encryption_info info; encryption_info info;
using xlnt::detail::read_int; auto reader = xlnt::detail::byte_reader(info_bytes);
std::size_t offset = 0;
auto version_major = read_int<std::uint16_t>(info_bytes, offset); auto version_major = reader.read<std::uint16_t>();
auto version_minor = read_int<std::uint16_t>(info_bytes, offset); auto version_minor = reader.read<std::uint16_t>();
auto encryption_flags = read_int<std::uint32_t>(info_bytes, offset); auto encryption_flags = reader.read<std::uint32_t>();
// version 4.4 is agile // version 4.4 is agile
if (version_major == 4 && version_minor == 4) if (version_major == 4 && version_minor == 4)
@ -327,11 +328,11 @@ std::vector<std::uint8_t> decrypt_xlsx(
} }
xlnt::detail::compound_document document; xlnt::detail::compound_document document;
document.load(bytes); document.load(const_cast<std::vector<std::uint8_t> &>(bytes));
auto encryption_info = read_encryption_info(document.stream("EncryptionInfo")); auto encryption_info = read_encryption_info(document.stream(u"EncryptionInfo"));
encryption_info.password = password; encryption_info.password = password;
auto encrypted_package = document.stream("EncryptedPackage"); auto encrypted_package = document.stream(u"EncryptedPackage");
return encryption_info.is_agile return encryption_info.is_agile
? decrypt_xlsx_agile(encryption_info, encrypted_package) ? decrypt_xlsx_agile(encryption_info, encrypted_package)

View File

@ -171,10 +171,10 @@ std::vector<std::uint8_t> encrypt_xlsx(
xlnt::detail::compound_document document; xlnt::detail::compound_document document;
document.add_stream("EncryptionInfo", encryption_info.is_agile document.add_stream(u"EncryptionInfo", encryption_info.is_agile
? write_agile_encryption_info(encryption_info) ? write_agile_encryption_info(encryption_info)
: write_standard_encryption_info(encryption_info)); : write_standard_encryption_info(encryption_info));
document.add_stream("EncryptedPackage", encryption_info.is_agile document.add_stream(u"EncryptedPackage", encryption_info.is_agile
? encrypt_xlsx_agile(encryption_info, plaintext) ? encrypt_xlsx_agile(encryption_info, plaintext)
: encrypt_xlsx_standard(encryption_info, plaintext)); : encrypt_xlsx_standard(encryption_info, plaintext));