diff --git a/src/ZoneLoading/Game/IW4/ZoneLoaderFactoryIW4.cpp b/src/ZoneLoading/Game/IW4/ZoneLoaderFactoryIW4.cpp index 5634d742..050cd35a 100644 --- a/src/ZoneLoading/Game/IW4/ZoneLoaderFactoryIW4.cpp +++ b/src/ZoneLoading/Game/IW4/ZoneLoaderFactoryIW4.cpp @@ -16,7 +16,9 @@ #include "Game/IW4/GameAssetPoolIW4.h" #include "Game/IW4/GameIW4.h" #include "Game/GameLanguage.h" +#include "Loading/Processor/ProcessorAuthedBlocks.h" #include "Loading/Processor/ProcessorCaptureData.h" +#include "Loading/Processor/ProcessorInflate.h" #include "Loading/Steps/StepLoadHash.h" #include "Loading/Steps/StepRemoveProcessor.h" #include "Loading/Steps/StepVerifyHash.h" @@ -25,13 +27,7 @@ const std::string ZoneLoaderFactoryIW4::MAGIC_SIGNED_INFINITY_WARD = "IWff0100"; const std::string ZoneLoaderFactoryIW4::MAGIC_UNSIGNED = "IWffu100"; const int ZoneLoaderFactoryIW4::VERSION = 276; -const int ZoneLoaderFactoryIW4::STREAM_COUNT = 4; -const int ZoneLoaderFactoryIW4::VANILLA_BUFFER_SIZE = 0x80000; -const int ZoneLoaderFactoryIW4::OFFSET_BLOCK_BIT_COUNT = 4; -const block_t ZoneLoaderFactoryIW4::INSERT_BLOCK = IW4::XFILE_BLOCK_VIRTUAL; - const std::string ZoneLoaderFactoryIW4::MAGIC_AUTH_HEADER = "IWffs100"; - const uint8_t ZoneLoaderFactoryIW4::RSA_PUBLIC_KEY_INFINITY_WARD[] { 0x30, 0x82, 0x01, 0x0A, 0x02, 0x82, 0x01, 0x01, @@ -70,6 +66,12 @@ const uint8_t ZoneLoaderFactoryIW4::RSA_PUBLIC_KEY_INFINITY_WARD[] 0x89, 0x02, 0x03, 0x01, 0x00, 0x01 }; +const size_t ZoneLoaderFactoryIW4::AUTHED_CHUNK_SIZE = 0x2000; +const size_t ZoneLoaderFactoryIW4::AUTHED_CHUNK_COUNT_PER_GROUP = 256; + +const int ZoneLoaderFactoryIW4::OFFSET_BLOCK_BIT_COUNT = 4; +const block_t ZoneLoaderFactoryIW4::INSERT_BLOCK = IW4::XFILE_BLOCK_VIRTUAL; + class ZoneLoaderFactoryIW4::Impl { static GameLanguage GetZoneLanguage(std::string& zoneName) @@ -146,15 +148,15 @@ class ZoneLoaderFactoryIW4::Impl } } - static IHashProvider* AddAuthHeaderSteps(const bool isSecure, const bool isOfficial, ZoneLoader* zoneLoader, std::string& fileName) + static void AddAuthHeaderSteps(const bool isSecure, const bool isOfficial, ZoneLoader* zoneLoader, + std::string& fileName) { // Unsigned zones do not have an auth header if (!isSecure) - return nullptr; + return; // If file is signed setup a RSA instance. IPublicKeyAlgorithm* rsa = SetupRSA(isOfficial); - auto sha256 = std::unique_ptr(Crypto::CreateSHA256()); zoneLoader->AddLoadingStep(new StepVerifyMagic(MAGIC_AUTH_HEADER.c_str())); zoneLoader->AddLoadingStep(new StepSkipBytes(4)); // Skip reserved @@ -172,13 +174,21 @@ class ZoneLoaderFactoryIW4::Impl zoneLoader->AddLoadingStep(new StepVerifyFileName(fileName, sizeof IW4::DB_AuthSubHeader::fastfileName)); zoneLoader->AddLoadingStep(new StepSkipBytes(4)); // Skip reserved - auto* masterBlockHashes = new StepLoadHash(sizeof IW4::DB_AuthHash::bytes, _countof(IW4::DB_AuthSubHeader::masterBlockHashes)); + auto* masterBlockHashes = new StepLoadHash(sizeof IW4::DB_AuthHash::bytes, + _countof(IW4::DB_AuthSubHeader::masterBlockHashes)); zoneLoader->AddLoadingStep(masterBlockHashes); zoneLoader->AddLoadingStep(new StepRemoveProcessor(subHeaderCapture)); - zoneLoader->AddLoadingStep(new StepVerifyHash(std::move(sha256), 0, subheaderHash, subHeaderCapture)); + zoneLoader->AddLoadingStep(new StepVerifyHash(std::unique_ptr(Crypto::CreateSHA256()), 0, + subheaderHash, subHeaderCapture)); - return masterBlockHashes; + // Skip the rest of the first chunk + zoneLoader->AddLoadingStep(new StepSkipBytes(AUTHED_CHUNK_SIZE - sizeof(IW4::DB_AuthHeader))); + + zoneLoader->AddLoadingStep(new StepAddProcessor(new ProcessorAuthedBlocks( + AUTHED_CHUNK_COUNT_PER_GROUP, AUTHED_CHUNK_SIZE, _countof(IW4::DB_AuthSubHeader::masterBlockHashes), + std::unique_ptr(Crypto::CreateSHA256()), + masterBlockHashes))); } public: @@ -207,7 +217,9 @@ public: zoneLoader->AddLoadingStep(new StepSkipBytes(8)); // Add steps for loading the auth header which also contain the signature of the zone if it is signed. - IHashProvider* masterBlockHashProvider = AddAuthHeaderSteps(isSecure, isOfficial, zoneLoader, fileName); + AddAuthHeaderSteps(isSecure, isOfficial, zoneLoader, fileName); + + zoneLoader->AddLoadingStep(new StepAddProcessor(new ProcessorInflate(AUTHED_CHUNK_SIZE))); // Start of the XFile struct zoneLoader->AddLoadingStep(new StepSkipBytes(8)); @@ -218,11 +230,6 @@ public: zoneLoader->AddLoadingStep( new StepLoadZoneContent(new ContentLoaderIW4(), zone, OFFSET_BLOCK_BIT_COUNT, INSERT_BLOCK)); - /*if (isSecure) - { - zoneLoader->AddLoadingStep(new StepVerifySignature(rsa, signatureProvider, signatureDataProvider)); - }*/ - // Return the fully setup zoneloader return zoneLoader; } diff --git a/src/ZoneLoading/Game/IW4/ZoneLoaderFactoryIW4.h b/src/ZoneLoading/Game/IW4/ZoneLoaderFactoryIW4.h index 18d3bdc1..7e279867 100644 --- a/src/ZoneLoading/Game/IW4/ZoneLoaderFactoryIW4.h +++ b/src/ZoneLoading/Game/IW4/ZoneLoaderFactoryIW4.h @@ -9,14 +9,15 @@ class ZoneLoaderFactoryIW4 final : public IZoneLoaderFactory static const std::string MAGIC_UNSIGNED; static const int VERSION; - static const int STREAM_COUNT; - static const int VANILLA_BUFFER_SIZE; - static const int OFFSET_BLOCK_BIT_COUNT; - static const block_t INSERT_BLOCK; - static const std::string MAGIC_AUTH_HEADER; static const uint8_t RSA_PUBLIC_KEY_INFINITY_WARD[]; + static const size_t AUTHED_CHUNK_SIZE; + static const unsigned AUTHED_CHUNK_COUNT_PER_GROUP; + + static const int OFFSET_BLOCK_BIT_COUNT; + static const block_t INSERT_BLOCK; + class Impl; public: diff --git a/src/ZoneLoading/Loading/Exception/TooManyAuthedGroupsException.cpp b/src/ZoneLoading/Loading/Exception/TooManyAuthedGroupsException.cpp new file mode 100644 index 00000000..168829f6 --- /dev/null +++ b/src/ZoneLoading/Loading/Exception/TooManyAuthedGroupsException.cpp @@ -0,0 +1,11 @@ +#include "TooManyAuthedGroupsException.h" + +std::string TooManyAuthedGroupsException::DetailedMessage() +{ + return "Loaded fastfile has too many authed groups."; +} + +char const* TooManyAuthedGroupsException::what() const +{ + return "Loaded fastfile has too many authed groups."; +} \ No newline at end of file diff --git a/src/ZoneLoading/Loading/Exception/TooManyAuthedGroupsException.h b/src/ZoneLoading/Loading/Exception/TooManyAuthedGroupsException.h new file mode 100644 index 00000000..9a2b2cf2 --- /dev/null +++ b/src/ZoneLoading/Loading/Exception/TooManyAuthedGroupsException.h @@ -0,0 +1,9 @@ +#pragma once +#include "LoadingException.h" + +class TooManyAuthedGroupsException final : public LoadingException +{ +public: + std::string DetailedMessage() override; + char const* what() const override; +}; diff --git a/src/ZoneLoading/Loading/Processor/ProcessorAuthedBlocks.cpp b/src/ZoneLoading/Loading/Processor/ProcessorAuthedBlocks.cpp index 506652b1..5b690f79 100644 --- a/src/ZoneLoading/Loading/Processor/ProcessorAuthedBlocks.cpp +++ b/src/ZoneLoading/Loading/Processor/ProcessorAuthedBlocks.cpp @@ -1,32 +1,148 @@ #include "ProcessorAuthedBlocks.h" -class ProcessorAuthedBlocks::Impl final : public StreamProcessor +#include +#include + + +#include "Game/IW4/IW4.h" +#include "Loading/Exception/InvalidHashException.h" +#include "Loading/Exception/TooManyAuthedGroupsException.h" +#include "Loading/Exception/UnexpectedEndOfFileException.h" + +class ProcessorAuthedBlocks::Impl { - const int m_authed_chunk_count; - const int m_max_master_block_count; - IHashProvider* m_hash_provider; + ProcessorAuthedBlocks* const m_base; + + const unsigned m_authed_chunk_count; + const size_t m_chunk_size; + const unsigned m_max_master_block_count; + + const std::unique_ptr m_hash_function; + IHashProvider* const m_master_block_hash_provider; + const std::unique_ptr m_chunk_hashes_buffer; + const std::unique_ptr m_current_chunk_hash_buffer; + + const std::unique_ptr m_chunk_buffer; + unsigned m_current_group; + unsigned m_current_chunk_in_group; + + size_t m_current_chunk_offset; + size_t m_current_chunk_size; public: - Impl(const int authedChunkCount, const int maxMasterBlockCount, IHashProvider* masterBlockHashProvider) - : m_authed_chunk_count(authedChunkCount), + Impl(ProcessorAuthedBlocks* base, const unsigned authedChunkCount, const size_t chunkSize, + const unsigned maxMasterBlockCount, + std::unique_ptr hashFunction, + IHashProvider* masterBlockHashProvider) + : m_base(base), + m_authed_chunk_count(authedChunkCount), + m_chunk_size(chunkSize), m_max_master_block_count(maxMasterBlockCount), - m_hash_provider(masterBlockHashProvider) + m_hash_function(std::move(hashFunction)), + m_master_block_hash_provider(masterBlockHashProvider), + m_chunk_hashes_buffer(std::make_unique(m_authed_chunk_count * m_hash_function->GetHashSize())), + m_current_chunk_hash_buffer(std::make_unique(m_hash_function->GetHashSize())), + m_chunk_buffer(std::make_unique(m_chunk_size)), + m_current_group(1), + m_current_chunk_in_group(0), + m_current_chunk_offset(0), + m_current_chunk_size(0) { + assert(m_authed_chunk_count * m_hash_function->GetHashSize() <= m_chunk_size); } - size_t Load(void* buffer, size_t length) override + bool NextChunk() { - return 0; + m_current_chunk_offset = 0; + + while (true) + { + m_current_chunk_size = m_base->m_base_stream->Load(m_chunk_buffer.get(), m_chunk_size); + + if (m_current_chunk_size == 0) + return false; + + m_hash_function->Init(); + m_hash_function->Process(m_chunk_buffer.get(), m_current_chunk_size); + m_hash_function->Finish(m_current_chunk_hash_buffer.get()); + + if (m_current_chunk_in_group == 0) + { + if (m_current_chunk_size < m_authed_chunk_count * m_hash_function->GetHashSize()) + throw UnexpectedEndOfFileException(); + + const uint8_t* masterBlockHash = nullptr; + size_t masterBlockHashSize = 0; + m_master_block_hash_provider->GetHash(m_current_group - 1, &masterBlockHash, &masterBlockHashSize); + + if (masterBlockHashSize != m_hash_function->GetHashSize() + || std::memcmp(m_current_chunk_hash_buffer.get(), masterBlockHash, + m_hash_function->GetHashSize()) != 0) + throw InvalidHashException(); + + memcpy_s(m_chunk_hashes_buffer.get(), m_authed_chunk_count * m_hash_function->GetHashSize(), + m_chunk_buffer.get(), m_authed_chunk_count * m_hash_function->GetHashSize()); + + m_current_chunk_in_group++; + } + else + { + if (std::memcmp(m_current_chunk_hash_buffer.get(), + &m_chunk_hashes_buffer[(m_current_chunk_in_group - 1) * m_hash_function->GetHashSize()], + m_hash_function->GetHashSize()) != 0) + throw InvalidHashException(); + + if (++m_current_chunk_in_group > m_authed_chunk_count) + { + m_current_chunk_in_group = 0; + m_current_group++; + + if (m_current_group > m_max_master_block_count) + throw TooManyAuthedGroupsException(); + } + + return true; + } + } } - int64_t Pos() override + size_t Load(void* buffer, const size_t length) { - return 0; + size_t loadedSize = 0; + + while (loadedSize < length) + { + if (m_current_chunk_offset >= m_current_chunk_size) + { + if (!NextChunk()) + return loadedSize; + } + + size_t sizeToWrite = length - loadedSize; + if (sizeToWrite > m_current_chunk_size - m_current_chunk_offset) + sizeToWrite = m_current_chunk_size - m_current_chunk_offset; + + memcpy_s(&static_cast(buffer)[loadedSize], length - loadedSize, + &m_chunk_buffer[m_current_chunk_offset], sizeToWrite); + loadedSize += sizeToWrite; + m_current_chunk_offset += sizeToWrite; + } + + return loadedSize; + } + + int64_t Pos() + { + return m_base->m_base_stream->Pos() - (m_current_chunk_size - m_current_chunk_offset); } }; -ProcessorAuthedBlocks::ProcessorAuthedBlocks(const int authedChunkCount, const int maxMasterBlockCount, IHashProvider* masterBlockHashProvider) - : m_impl(new Impl(authedChunkCount, maxMasterBlockCount, masterBlockHashProvider)) +ProcessorAuthedBlocks::ProcessorAuthedBlocks(const unsigned authedChunkCount, const size_t chunkSize, + const unsigned maxMasterBlockCount, + std::unique_ptr hashFunction, + IHashProvider* masterBlockHashProvider) + : m_impl(new Impl(this, authedChunkCount, chunkSize, maxMasterBlockCount, std::move(hashFunction), + masterBlockHashProvider)) { } diff --git a/src/ZoneLoading/Loading/Processor/ProcessorAuthedBlocks.h b/src/ZoneLoading/Loading/Processor/ProcessorAuthedBlocks.h index 03db7597..b7be3891 100644 --- a/src/ZoneLoading/Loading/Processor/ProcessorAuthedBlocks.h +++ b/src/ZoneLoading/Loading/Processor/ProcessorAuthedBlocks.h @@ -1,15 +1,23 @@ #pragma once +#include + +#include "Crypto.h" #include "Loading/StreamProcessor.h" #include "Loading/IHashProvider.h" -class ProcessorAuthedBlocks : public StreamProcessor +class ProcessorAuthedBlocks final : public StreamProcessor { class Impl; Impl* m_impl; public: - ProcessorAuthedBlocks(int authedChunkCount, int maxMasterBlockCount, IHashProvider* masterBlockHashProvider); + ProcessorAuthedBlocks(unsigned authedChunkCount, size_t chunkSize, unsigned maxMasterBlockCount, + std::unique_ptr hashFunction, IHashProvider* masterBlockHashProvider); ~ProcessorAuthedBlocks() override; + ProcessorAuthedBlocks(const ProcessorAuthedBlocks& other) = delete; + ProcessorAuthedBlocks(ProcessorAuthedBlocks&& other) noexcept = default; + ProcessorAuthedBlocks& operator=(const ProcessorAuthedBlocks& other) = delete; + ProcessorAuthedBlocks& operator=(ProcessorAuthedBlocks&& other) noexcept = default; size_t Load(void* buffer, size_t length) override; int64_t Pos() override; diff --git a/src/ZoneLoading/Loading/Processor/ProcessorInflate.cpp b/src/ZoneLoading/Loading/Processor/ProcessorInflate.cpp index 9a933489..cf69193c 100644 --- a/src/ZoneLoading/Loading/Processor/ProcessorInflate.cpp +++ b/src/ZoneLoading/Loading/Processor/ProcessorInflate.cpp @@ -1,17 +1,25 @@ #include "ProcessorInflate.h" -#include "zlib.h" -#include -#include "zutil.h" -#include -class ProcessorInflate::ProcessorInflateImpl +#include +#include +#include +#include +#include + +#include "Loading/Exception/InvalidCompressionException.h" + +class ProcessorInflate::Impl { z_stream m_stream{}; - uint8_t m_in_buffer[0x800]; ProcessorInflate* m_base; + std::unique_ptr m_buffer; + size_t m_buffer_size; + public: - ProcessorInflateImpl(ProcessorInflate* baseClass) + Impl(ProcessorInflate* baseClass, const size_t bufferSize) + : m_buffer(std::make_unique(bufferSize)), + m_buffer_size(bufferSize) { m_base = baseClass; @@ -21,35 +29,43 @@ public: m_stream.avail_in = 0; m_stream.next_in = Z_NULL; - const int ret = inflateInit2(&m_stream, -DEF_WBITS); + const int ret = inflateInit(&m_stream); - if(ret != Z_OK) + if (ret != Z_OK) { throw std::exception("Initializing inflate failed"); } } - ~ProcessorInflateImpl() + ~Impl() { inflateEnd(&m_stream); } - size_t Load(void* buffer, size_t length) + Impl(const Impl& other) = delete; + Impl(Impl&& other) noexcept = default; + Impl& operator=(const Impl& other) = delete; + Impl& operator=(Impl&& other) noexcept = default; + + size_t Load(void* buffer, const size_t length) { m_stream.next_out = static_cast(buffer); m_stream.avail_out = length; - while(m_stream.avail_out > 0) + while (m_stream.avail_out > 0) { - if(m_stream.avail_in == 0) + if (m_stream.avail_in == 0) { - m_stream.avail_in = m_base->m_base_stream->Load(m_in_buffer, sizeof(m_in_buffer)); + m_stream.avail_in = m_base->m_base_stream->Load(m_buffer.get(), m_buffer_size); - if(m_stream.avail_in == 0) // EOF + if (m_stream.avail_in == 0) // EOF return length - m_stream.avail_out; } - inflate(&m_stream, Z_FULL_FLUSH); + auto ret = inflate(&m_stream, Z_SYNC_FLUSH); + + if(ret < 0) + throw InvalidCompressionException(); } return m_stream.avail_out; @@ -57,8 +73,13 @@ public: }; ProcessorInflate::ProcessorInflate() + : ProcessorInflate(DEFAULT_BUFFER_SIZE) +{ +} + +ProcessorInflate::ProcessorInflate(const size_t bufferSize) + : m_impl(new Impl(this, bufferSize)) { - m_impl = new ProcessorInflateImpl(this); } ProcessorInflate::~ProcessorInflate() @@ -70,4 +91,9 @@ ProcessorInflate::~ProcessorInflate() size_t ProcessorInflate::Load(void* buffer, const size_t length) { return m_impl->Load(buffer, length); -} \ No newline at end of file +} + +int64_t ProcessorInflate::Pos() +{ + return m_base_stream->Pos(); +} diff --git a/src/ZoneLoading/Loading/Processor/ProcessorInflate.h b/src/ZoneLoading/Loading/Processor/ProcessorInflate.h index f81d41d3..f1e17ed5 100644 --- a/src/ZoneLoading/Loading/Processor/ProcessorInflate.h +++ b/src/ZoneLoading/Loading/Processor/ProcessorInflate.h @@ -3,12 +3,20 @@ class ProcessorInflate final : public StreamProcessor { - class ProcessorInflateImpl; - ProcessorInflateImpl* m_impl; + class Impl; + Impl* m_impl; + + static constexpr size_t DEFAULT_BUFFER_SIZE = 0x2000; public: ProcessorInflate(); + ProcessorInflate(size_t bufferSize); ~ProcessorInflate() override; + ProcessorInflate(const ProcessorInflate& other) = delete; + ProcessorInflate(ProcessorInflate&& other) noexcept = default; + ProcessorInflate& operator=(const ProcessorInflate& other) = delete; + ProcessorInflate& operator=(ProcessorInflate&& other) noexcept = default; size_t Load(void* buffer, size_t length) override; + int64_t Pos() override; };