diff --git a/src/driver/allocator.hpp b/src/driver/allocator.hpp index ab976db..0ea6756 100644 --- a/src/driver/allocator.hpp +++ b/src/driver/allocator.hpp @@ -4,7 +4,7 @@ namespace utils { template - concept IsAllocator = requires(size_t size, void* ptr) + concept is_allocator = requires(size_t size, void* ptr) { T().free(T().allocate(size)); T().free(ptr); diff --git a/src/driver/ept.cpp b/src/driver/ept.cpp index 1ed057b..5954257 100644 --- a/src/driver/ept.cpp +++ b/src/driver/ept.cpp @@ -95,18 +95,16 @@ namespace vmx } } - void reset_all_watch_point_pages(ept_code_watch_point* watch_point) + void reset_all_watch_point_pages(utils::list& watch_points) { - while (watch_point) + for(const auto& watch_point : watch_points) { - if (watch_point->target_page) + if (watch_point.target_page) { - watch_point->target_page->write_access = 0; - watch_point->target_page->read_access = 0; - watch_point->target_page->execute_access = 1; + watch_point.target_page->write_access = 0; + watch_point.target_page->read_access = 0; + watch_point.target_page->execute_access = 1; } - - watch_point = watch_point->next_watch_point; } } } @@ -142,29 +140,7 @@ namespace vmx ept::~ept() { - auto* split = this->ept_splits; - while (split) - { - auto* current_split = split; - split = split->next_split; - memory::free_aligned_object(current_split); - } - - auto* hook = this->ept_hooks; - while (hook) - { - auto* current_hook = hook; - hook = hook->next_hook; - memory::free_aligned_object(current_hook); - } - - auto* watch_point = this->ept_code_watch_points; - while (watch_point) - { - auto* current_watch_point = watch_point; - watch_point = watch_point->next_watch_point; - memory::free_non_paged_object(current_watch_point); - } + this->disable_all_hooks(); } void ept::install_page_hook(void* destination, const void* source, const size_t length, @@ -230,11 +206,9 @@ namespace vmx void ept::disable_all_hooks() const { - auto* hook = this->ept_hooks; - while (hook) + for(auto& hook : this->ept_hooks) { - hook->target_page->flags = hook->original_entry.flags; - hook = hook->next_hook; + hook.target_page->flags = hook.original_entry.flags; } } @@ -368,23 +342,19 @@ namespace vmx return; } - auto* watch_point = this->allocate_ept_code_watch_point(); - if (!watch_point) - { - throw std::runtime_error("Failed to allocate watch point"); - } + auto& watch_point = this->allocate_ept_code_watch_point(); this->split_large_page(physical_base_address); - watch_point->physical_base_address = physical_base_address; - watch_point->target_page = this->get_pml1_entry(physical_base_address); - if (!watch_point->target_page) + watch_point.physical_base_address = physical_base_address; + watch_point.target_page = this->get_pml1_entry(physical_base_address); + if (!watch_point.target_page) { throw std::runtime_error("Failed to get PML1 entry for target address"); } - watch_point->target_page->write_access = 0; - watch_point->target_page->read_access = 0; + watch_point.target_page->write_access = 0; + watch_point.target_page->read_access = 0; } ept_pointer ept::get_ept_pointer() const @@ -448,91 +418,55 @@ namespace vmx return &pml1[ADDRMASK_EPT_PML1_INDEX(physical_address)]; } - pml1* ept::find_pml1_table(const uint64_t physical_address) const + pml1* ept::find_pml1_table(const uint64_t physical_address) { - auto* split = this->ept_splits; - while (split) + for(auto& split : this->ept_splits) { - if (memory::get_physical_address(&split->pml1[0]) == physical_address) + if (memory::get_physical_address(&split.pml1[0]) == physical_address) { - return split->pml1; + return split.pml1; } - - split = split->next_split; } return nullptr; } - ept_split* ept::allocate_ept_split() + ept_split& ept::allocate_ept_split() { - auto* split = memory::allocate_aligned_object(); - if (!split) - { - throw std::runtime_error("Failed to allocate ept split object"); - } - - split->next_split = this->ept_splits; - this->ept_splits = split; - - return split; + return this->ept_splits.emplace_back(); } - ept_hook* ept::allocate_ept_hook(const uint64_t physical_address) + ept_hook& ept::allocate_ept_hook(const uint64_t physical_address) { - auto* hook = memory::allocate_aligned_object(physical_address); - if (!hook) - { - throw std::runtime_error("Failed to allocate ept hook object"); - } - - hook->next_hook = this->ept_hooks; - this->ept_hooks = hook; - - return hook; + return this->ept_hooks.emplace_back(physical_address); } - ept_hook* ept::find_ept_hook(const uint64_t physical_address) const + ept_hook* ept::find_ept_hook(const uint64_t physical_address) { - auto* hook = this->ept_hooks; - while (hook) + for (auto& hook : this->ept_hooks) { - if (hook->physical_base_address == physical_address) + if (hook.physical_base_address == physical_address) { - return hook; + return &hook; } - - hook = hook->next_hook; } return nullptr; } - ept_code_watch_point* ept::allocate_ept_code_watch_point() + ept_code_watch_point& ept::allocate_ept_code_watch_point() { - auto* watch_point = memory::allocate_non_paged_object(); - if (!watch_point) - { - throw std::runtime_error("Failed to allocate ept watch point object"); - } - - watch_point->next_watch_point = this->ept_code_watch_points; - this->ept_code_watch_points = watch_point; - - return watch_point; + return this->ept_code_watch_points.emplace_back(); } - ept_code_watch_point* ept::find_ept_code_watch_point(const uint64_t physical_address) const + ept_code_watch_point* ept::find_ept_code_watch_point(const uint64_t physical_address) { - auto* watch_point = this->ept_code_watch_points; - while (watch_point) + for(auto& watch_point : this->ept_code_watch_points) { - if (watch_point->physical_base_address == physical_address) + if (watch_point.physical_base_address == physical_address) { - return watch_point; + return &watch_point; } - - watch_point = watch_point->next_watch_point; } return nullptr; @@ -573,12 +507,7 @@ namespace vmx return hook; } - hook = this->allocate_ept_hook(physical_base_address); - - if (!hook) - { - throw std::runtime_error("Failed to allocate hook"); - } + hook = &this->allocate_ept_hook(physical_base_address); this->split_large_page(physical_address); @@ -624,7 +553,7 @@ namespace vmx return; } - auto* split = this->allocate_ept_split(); + auto& split = this->allocate_ept_split(); epte pml1_template{}; pml1_template.flags = 0; @@ -635,11 +564,11 @@ namespace vmx pml1_template.ignore_pat = target_entry->ignore_pat; pml1_template.suppress_ve = target_entry->suppress_ve; - __stosq(reinterpret_cast(&split->pml1[0]), pml1_template.flags, EPT_PTE_ENTRY_COUNT); + __stosq(reinterpret_cast(&split.pml1[0]), pml1_template.flags, EPT_PTE_ENTRY_COUNT); for (auto i = 0; i < EPT_PTE_ENTRY_COUNT; ++i) { - split->pml1[i].page_frame_number = ((target_entry->page_frame_number * 2_mb) / PAGE_SIZE) + i; + split.pml1[i].page_frame_number = ((target_entry->page_frame_number * 2_mb) / PAGE_SIZE) + i; } pml2_ptr new_pointer{}; @@ -648,7 +577,7 @@ namespace vmx new_pointer.write_access = 1; new_pointer.execute_access = 1; - new_pointer.page_frame_number = memory::get_physical_address(&split->pml1[0]) / PAGE_SIZE; + new_pointer.page_frame_number = memory::get_physical_address(&split.pml1[0]) / PAGE_SIZE; target_entry->flags = new_pointer.flags; } diff --git a/src/driver/ept.hpp b/src/driver/ept.hpp index 8e36379..5feef1f 100644 --- a/src/driver/ept.hpp +++ b/src/driver/ept.hpp @@ -20,20 +20,17 @@ namespace vmx pml2 entry{}; pml2_ptr pointer; }; - - ept_split* next_split{nullptr}; }; struct ept_code_watch_point { uint64_t physical_base_address{}; pml1* target_page{}; - ept_code_watch_point* next_watch_point{nullptr}; }; struct ept_hook { - ept_hook(const uint64_t physical_base); + ept_hook(uint64_t physical_base); ~ept_hook(); DECLSPEC_PAGE_ALIGN uint8_t fake_page[PAGE_SIZE]{}; @@ -46,8 +43,6 @@ namespace vmx pml1 original_entry{}; pml1 execute_entry{}; pml1 readwrite_entry{}; - - ept_hook* next_hook{nullptr}; }; struct ept_translation_hint @@ -76,7 +71,7 @@ namespace vmx void install_code_watch_point(uint64_t physical_page); void install_hook(const void* destination, const void* source, size_t length, - const utils::list& hints = {}); + const utils::list& hints = {}); void disable_all_hooks() const; void handle_violation(guest_context& guest_context); @@ -96,20 +91,20 @@ namespace vmx uint64_t access_records[1024]; - ept_split* ept_splits{nullptr}; - ept_hook* ept_hooks{nullptr}; - ept_code_watch_point* ept_code_watch_points{nullptr}; + utils::list ept_splits{}; + utils::list ept_hooks{}; + utils::list ept_code_watch_points{}; pml2* get_pml2_entry(uint64_t physical_address); pml1* get_pml1_entry(uint64_t physical_address); - pml1* find_pml1_table(uint64_t physical_address) const; + pml1* find_pml1_table(uint64_t physical_address); - ept_split* allocate_ept_split(); - ept_hook* allocate_ept_hook(uint64_t physical_address); - ept_hook* find_ept_hook(uint64_t physical_address) const; + ept_split& allocate_ept_split(); + ept_hook& allocate_ept_hook(uint64_t physical_address); + ept_hook* find_ept_hook(uint64_t physical_address); - ept_code_watch_point* allocate_ept_code_watch_point(); - ept_code_watch_point* find_ept_code_watch_point(uint64_t physical_address) const; + ept_code_watch_point& allocate_ept_code_watch_point(); + ept_code_watch_point* find_ept_code_watch_point(uint64_t physical_address); ept_hook* get_or_create_ept_hook(void* destination, const ept_translation_hint* translation_hint = nullptr); diff --git a/src/driver/list.hpp b/src/driver/list.hpp index d353d4b..9f31b02 100644 --- a/src/driver/list.hpp +++ b/src/driver/list.hpp @@ -6,13 +6,13 @@ namespace utils { template - requires IsAllocator && IsAllocator + requires is_allocator && is_allocator class list { - struct ListEntry + struct list_entry { T* entry{nullptr}; - ListEntry* next{nullptr}; + list_entry* next{nullptr}; void* this_base{nullptr}; void* entry_base{nullptr}; @@ -26,7 +26,7 @@ namespace utils friend list; public: - iterator(ListEntry* entry = nullptr) + iterator(list_entry* entry = nullptr) : entry_(entry) { } @@ -71,9 +71,9 @@ namespace utils } private: - ListEntry* entry_{nullptr}; + list_entry* entry_{nullptr}; - ListEntry* get_entry() const + list_entry* get_entry() const { return entry_; } @@ -84,7 +84,7 @@ namespace utils friend list; public: - const_iterator(ListEntry* entry = nullptr) + const_iterator(list_entry* entry = nullptr) : entry_(entry) { } @@ -111,9 +111,9 @@ namespace utils } private: - ListEntry* entry_{nullptr}; + list_entry* entry_{nullptr}; - ListEntry* get_entry() const + list_entry* get_entry() const { return entry_; } @@ -274,22 +274,22 @@ namespace utils iterator erase(iterator iterator) { auto* list_entry = iterator.get_entry(); - auto** inseration_point = &this->entries_; - while (*inseration_point && list_entry) + auto** insertion_point = &this->entries_; + while (*insertion_point && list_entry) { - if (*inseration_point != list_entry) + if (*insertion_point != list_entry) { - inseration_point = &(*inseration_point)->next; + insertion_point = &(*insertion_point)->next; continue; } - *inseration_point = list_entry->next; + *insertion_point = list_entry->next; list_entry->entry->~T(); this->object_allocator_.free(list_entry->entry_base); this->list_allocator_.free(list_entry->this_base); - return {*inseration_point}; + return {*insertion_point}; } throw std::runtime_error("Bad iterator"); @@ -305,7 +305,7 @@ namespace utils ObjectAllocator object_allocator_{}; ListAllocator list_allocator_{}; - ListEntry* entries_{nullptr}; + list_entry* entries_{nullptr}; template static U* align_pointer(V* pointer) @@ -317,28 +317,63 @@ namespace utils return reinterpret_cast(ptr); } - T& add_uninitialized_entry() + void allocate_entry(void*& list_base, void* entry_base) { - auto** inseration_point = &this->entries_; - while (*inseration_point) + list_base = nullptr; + entry_base = nullptr; + + auto destructor = utils::finally([&] { - inseration_point = &(*inseration_point)->next; + if (list_base) + { + this->list_allocator_.free(list_base); + } + + if (entry_base) + { + this->object_allocator_.free(entry_base); + } + }); + + list_base = this->list_allocator_.allocate(sizeof(list_entry) + alignof(list_entry)); + if (!list_base) + { + throw std::runtime_error("Memory allocation failed"); } - auto* list_base = this->list_allocator_.allocate(sizeof(ListEntry) + alignof(ListEntry)); - auto* entry_base = this->object_allocator_.allocate(sizeof(T) + alignof(T)); + entry_base = this->object_allocator_.allocate(sizeof(T) + alignof(T)); + if (!entry_base) + { + throw std::runtime_error("Memory allocation failed"); + } - auto* entry = align_pointer(entry_base); - auto* list_entry = align_pointer(list_base); + destructor.cancel(); + } - list_entry->this_base = list_base; - list_entry->entry_base = entry_base; - list_entry->next = nullptr; - list_entry->entry = entry; - *inseration_point = list_entry; + T& add_uninitialized_entry() + { + void* list_base = {}; + void* entry_base = {}; + this->allocate_entry(list_base, entry_base); - return *entry; + auto** insertion_point = &this->entries_; + while (*insertion_point) + { + insertion_point = &(*insertion_point)->next; + } + + auto* obj = align_pointer(entry_base); + auto* entry = align_pointer(list_base); + + entry->this_base = list_base; + entry->entry_base = entry_base; + entry->next = nullptr; + entry->entry = obj; + + *insertion_point = entry; + + return *obj; } }; } diff --git a/src/driver/vector.hpp b/src/driver/vector.hpp index 023b8b8..0d6348c 100644 --- a/src/driver/vector.hpp +++ b/src/driver/vector.hpp @@ -6,7 +6,7 @@ namespace utils { template - requires IsAllocator + requires is_allocator class vector { public: @@ -258,6 +258,11 @@ namespace utils { constexpr auto alignment = alignof(T); auto* memory = this->allocator_.allocate(capacity * sizeof(T) + alignment); + if (!memory) + { + throw std::runtime_error("Failed to allocate memory"); + } + return memory; }