Use containers for ept allocations

This commit is contained in:
Maurice Heumann 2022-12-24 08:36:23 +01:00
parent 33b44f1dc1
commit 1d23c10734
5 changed files with 121 additions and 157 deletions

View File

@ -4,7 +4,7 @@
namespace utils namespace utils
{ {
template <typename T> template <typename T>
concept IsAllocator = requires(size_t size, void* ptr) concept is_allocator = requires(size_t size, void* ptr)
{ {
T().free(T().allocate(size)); T().free(T().allocate(size));
T().free(ptr); T().free(ptr);

View File

@ -95,18 +95,16 @@ namespace vmx
} }
} }
void reset_all_watch_point_pages(ept_code_watch_point* watch_point) void reset_all_watch_point_pages(utils::list<ept_code_watch_point>& watch_points)
{ {
while (watch_point) for(const auto& watch_point : watch_points)
{ {
if (watch_point->target_page) if (watch_point.target_page)
{ {
watch_point->target_page->write_access = 0; watch_point.target_page->write_access = 0;
watch_point->target_page->read_access = 0; watch_point.target_page->read_access = 0;
watch_point->target_page->execute_access = 1; watch_point.target_page->execute_access = 1;
} }
watch_point = watch_point->next_watch_point;
} }
} }
} }
@ -142,29 +140,7 @@ namespace vmx
ept::~ept() ept::~ept()
{ {
auto* split = this->ept_splits; this->disable_all_hooks();
while (split)
{
auto* current_split = split;
split = split->next_split;
memory::free_aligned_object(current_split);
}
auto* hook = this->ept_hooks;
while (hook)
{
auto* current_hook = hook;
hook = hook->next_hook;
memory::free_aligned_object(current_hook);
}
auto* watch_point = this->ept_code_watch_points;
while (watch_point)
{
auto* current_watch_point = watch_point;
watch_point = watch_point->next_watch_point;
memory::free_non_paged_object(current_watch_point);
}
} }
void ept::install_page_hook(void* destination, const void* source, const size_t length, void ept::install_page_hook(void* destination, const void* source, const size_t length,
@ -230,11 +206,9 @@ namespace vmx
void ept::disable_all_hooks() const void ept::disable_all_hooks() const
{ {
auto* hook = this->ept_hooks; for(auto& hook : this->ept_hooks)
while (hook)
{ {
hook->target_page->flags = hook->original_entry.flags; hook.target_page->flags = hook.original_entry.flags;
hook = hook->next_hook;
} }
} }
@ -368,23 +342,19 @@ namespace vmx
return; return;
} }
auto* watch_point = this->allocate_ept_code_watch_point(); auto& watch_point = this->allocate_ept_code_watch_point();
if (!watch_point)
{
throw std::runtime_error("Failed to allocate watch point");
}
this->split_large_page(physical_base_address); this->split_large_page(physical_base_address);
watch_point->physical_base_address = physical_base_address; watch_point.physical_base_address = physical_base_address;
watch_point->target_page = this->get_pml1_entry(physical_base_address); watch_point.target_page = this->get_pml1_entry(physical_base_address);
if (!watch_point->target_page) if (!watch_point.target_page)
{ {
throw std::runtime_error("Failed to get PML1 entry for target address"); throw std::runtime_error("Failed to get PML1 entry for target address");
} }
watch_point->target_page->write_access = 0; watch_point.target_page->write_access = 0;
watch_point->target_page->read_access = 0; watch_point.target_page->read_access = 0;
} }
ept_pointer ept::get_ept_pointer() const ept_pointer ept::get_ept_pointer() const
@ -448,91 +418,55 @@ namespace vmx
return &pml1[ADDRMASK_EPT_PML1_INDEX(physical_address)]; return &pml1[ADDRMASK_EPT_PML1_INDEX(physical_address)];
} }
pml1* ept::find_pml1_table(const uint64_t physical_address) const pml1* ept::find_pml1_table(const uint64_t physical_address)
{ {
auto* split = this->ept_splits; for(auto& split : this->ept_splits)
while (split)
{ {
if (memory::get_physical_address(&split->pml1[0]) == physical_address) if (memory::get_physical_address(&split.pml1[0]) == physical_address)
{ {
return split->pml1; return split.pml1;
} }
split = split->next_split;
} }
return nullptr; return nullptr;
} }
ept_split* ept::allocate_ept_split() ept_split& ept::allocate_ept_split()
{ {
auto* split = memory::allocate_aligned_object<ept_split>(); return this->ept_splits.emplace_back();
if (!split)
{
throw std::runtime_error("Failed to allocate ept split object");
}
split->next_split = this->ept_splits;
this->ept_splits = split;
return split;
} }
ept_hook* ept::allocate_ept_hook(const uint64_t physical_address) ept_hook& ept::allocate_ept_hook(const uint64_t physical_address)
{ {
auto* hook = memory::allocate_aligned_object<ept_hook>(physical_address); return this->ept_hooks.emplace_back(physical_address);
if (!hook)
{
throw std::runtime_error("Failed to allocate ept hook object");
}
hook->next_hook = this->ept_hooks;
this->ept_hooks = hook;
return hook;
} }
ept_hook* ept::find_ept_hook(const uint64_t physical_address) const ept_hook* ept::find_ept_hook(const uint64_t physical_address)
{ {
auto* hook = this->ept_hooks; for (auto& hook : this->ept_hooks)
while (hook)
{ {
if (hook->physical_base_address == physical_address) if (hook.physical_base_address == physical_address)
{ {
return hook; return &hook;
} }
hook = hook->next_hook;
} }
return nullptr; return nullptr;
} }
ept_code_watch_point* ept::allocate_ept_code_watch_point() ept_code_watch_point& ept::allocate_ept_code_watch_point()
{ {
auto* watch_point = memory::allocate_non_paged_object<ept_code_watch_point>(); return this->ept_code_watch_points.emplace_back();
if (!watch_point)
{
throw std::runtime_error("Failed to allocate ept watch point object");
}
watch_point->next_watch_point = this->ept_code_watch_points;
this->ept_code_watch_points = watch_point;
return watch_point;
} }
ept_code_watch_point* ept::find_ept_code_watch_point(const uint64_t physical_address) const ept_code_watch_point* ept::find_ept_code_watch_point(const uint64_t physical_address)
{ {
auto* watch_point = this->ept_code_watch_points; for(auto& watch_point : this->ept_code_watch_points)
while (watch_point)
{ {
if (watch_point->physical_base_address == physical_address) if (watch_point.physical_base_address == physical_address)
{ {
return watch_point; return &watch_point;
} }
watch_point = watch_point->next_watch_point;
} }
return nullptr; return nullptr;
@ -573,12 +507,7 @@ namespace vmx
return hook; return hook;
} }
hook = this->allocate_ept_hook(physical_base_address); hook = &this->allocate_ept_hook(physical_base_address);
if (!hook)
{
throw std::runtime_error("Failed to allocate hook");
}
this->split_large_page(physical_address); this->split_large_page(physical_address);
@ -624,7 +553,7 @@ namespace vmx
return; return;
} }
auto* split = this->allocate_ept_split(); auto& split = this->allocate_ept_split();
epte pml1_template{}; epte pml1_template{};
pml1_template.flags = 0; pml1_template.flags = 0;
@ -635,11 +564,11 @@ namespace vmx
pml1_template.ignore_pat = target_entry->ignore_pat; pml1_template.ignore_pat = target_entry->ignore_pat;
pml1_template.suppress_ve = target_entry->suppress_ve; pml1_template.suppress_ve = target_entry->suppress_ve;
__stosq(reinterpret_cast<uint64_t*>(&split->pml1[0]), pml1_template.flags, EPT_PTE_ENTRY_COUNT); __stosq(reinterpret_cast<uint64_t*>(&split.pml1[0]), pml1_template.flags, EPT_PTE_ENTRY_COUNT);
for (auto i = 0; i < EPT_PTE_ENTRY_COUNT; ++i) for (auto i = 0; i < EPT_PTE_ENTRY_COUNT; ++i)
{ {
split->pml1[i].page_frame_number = ((target_entry->page_frame_number * 2_mb) / PAGE_SIZE) + i; split.pml1[i].page_frame_number = ((target_entry->page_frame_number * 2_mb) / PAGE_SIZE) + i;
} }
pml2_ptr new_pointer{}; pml2_ptr new_pointer{};
@ -648,7 +577,7 @@ namespace vmx
new_pointer.write_access = 1; new_pointer.write_access = 1;
new_pointer.execute_access = 1; new_pointer.execute_access = 1;
new_pointer.page_frame_number = memory::get_physical_address(&split->pml1[0]) / PAGE_SIZE; new_pointer.page_frame_number = memory::get_physical_address(&split.pml1[0]) / PAGE_SIZE;
target_entry->flags = new_pointer.flags; target_entry->flags = new_pointer.flags;
} }

View File

@ -20,20 +20,17 @@ namespace vmx
pml2 entry{}; pml2 entry{};
pml2_ptr pointer; pml2_ptr pointer;
}; };
ept_split* next_split{nullptr};
}; };
struct ept_code_watch_point struct ept_code_watch_point
{ {
uint64_t physical_base_address{}; uint64_t physical_base_address{};
pml1* target_page{}; pml1* target_page{};
ept_code_watch_point* next_watch_point{nullptr};
}; };
struct ept_hook struct ept_hook
{ {
ept_hook(const uint64_t physical_base); ept_hook(uint64_t physical_base);
~ept_hook(); ~ept_hook();
DECLSPEC_PAGE_ALIGN uint8_t fake_page[PAGE_SIZE]{}; DECLSPEC_PAGE_ALIGN uint8_t fake_page[PAGE_SIZE]{};
@ -46,8 +43,6 @@ namespace vmx
pml1 original_entry{}; pml1 original_entry{};
pml1 execute_entry{}; pml1 execute_entry{};
pml1 readwrite_entry{}; pml1 readwrite_entry{};
ept_hook* next_hook{nullptr};
}; };
struct ept_translation_hint struct ept_translation_hint
@ -76,7 +71,7 @@ namespace vmx
void install_code_watch_point(uint64_t physical_page); void install_code_watch_point(uint64_t physical_page);
void install_hook(const void* destination, const void* source, size_t length, void install_hook(const void* destination, const void* source, size_t length,
const utils::list<ept_translation_hint>& hints = {}); const utils::list<ept_translation_hint>& hints = {});
void disable_all_hooks() const; void disable_all_hooks() const;
void handle_violation(guest_context& guest_context); void handle_violation(guest_context& guest_context);
@ -96,20 +91,20 @@ namespace vmx
uint64_t access_records[1024]; uint64_t access_records[1024];
ept_split* ept_splits{nullptr}; utils::list<ept_split, utils::AlignedAllocator> ept_splits{};
ept_hook* ept_hooks{nullptr}; utils::list<ept_hook, utils::AlignedAllocator> ept_hooks{};
ept_code_watch_point* ept_code_watch_points{nullptr}; utils::list<ept_code_watch_point> ept_code_watch_points{};
pml2* get_pml2_entry(uint64_t physical_address); pml2* get_pml2_entry(uint64_t physical_address);
pml1* get_pml1_entry(uint64_t physical_address); pml1* get_pml1_entry(uint64_t physical_address);
pml1* find_pml1_table(uint64_t physical_address) const; pml1* find_pml1_table(uint64_t physical_address);
ept_split* allocate_ept_split(); ept_split& allocate_ept_split();
ept_hook* allocate_ept_hook(uint64_t physical_address); ept_hook& allocate_ept_hook(uint64_t physical_address);
ept_hook* find_ept_hook(uint64_t physical_address) const; ept_hook* find_ept_hook(uint64_t physical_address);
ept_code_watch_point* allocate_ept_code_watch_point(); ept_code_watch_point& allocate_ept_code_watch_point();
ept_code_watch_point* find_ept_code_watch_point(uint64_t physical_address) const; ept_code_watch_point* find_ept_code_watch_point(uint64_t physical_address);
ept_hook* get_or_create_ept_hook(void* destination, const ept_translation_hint* translation_hint = nullptr); ept_hook* get_or_create_ept_hook(void* destination, const ept_translation_hint* translation_hint = nullptr);

View File

@ -6,13 +6,13 @@
namespace utils namespace utils
{ {
template <typename T, typename ObjectAllocator = NonPagedAllocator, typename ListAllocator = NonPagedAllocator> template <typename T, typename ObjectAllocator = NonPagedAllocator, typename ListAllocator = NonPagedAllocator>
requires IsAllocator<ObjectAllocator> && IsAllocator<ListAllocator> requires is_allocator<ObjectAllocator> && is_allocator<ListAllocator>
class list class list
{ {
struct ListEntry struct list_entry
{ {
T* entry{nullptr}; T* entry{nullptr};
ListEntry* next{nullptr}; list_entry* next{nullptr};
void* this_base{nullptr}; void* this_base{nullptr};
void* entry_base{nullptr}; void* entry_base{nullptr};
@ -26,7 +26,7 @@ namespace utils
friend list; friend list;
public: public:
iterator(ListEntry* entry = nullptr) iterator(list_entry* entry = nullptr)
: entry_(entry) : entry_(entry)
{ {
} }
@ -71,9 +71,9 @@ namespace utils
} }
private: private:
ListEntry* entry_{nullptr}; list_entry* entry_{nullptr};
ListEntry* get_entry() const list_entry* get_entry() const
{ {
return entry_; return entry_;
} }
@ -84,7 +84,7 @@ namespace utils
friend list; friend list;
public: public:
const_iterator(ListEntry* entry = nullptr) const_iterator(list_entry* entry = nullptr)
: entry_(entry) : entry_(entry)
{ {
} }
@ -111,9 +111,9 @@ namespace utils
} }
private: private:
ListEntry* entry_{nullptr}; list_entry* entry_{nullptr};
ListEntry* get_entry() const list_entry* get_entry() const
{ {
return entry_; return entry_;
} }
@ -274,22 +274,22 @@ namespace utils
iterator erase(iterator iterator) iterator erase(iterator iterator)
{ {
auto* list_entry = iterator.get_entry(); auto* list_entry = iterator.get_entry();
auto** inseration_point = &this->entries_; auto** insertion_point = &this->entries_;
while (*inseration_point && list_entry) while (*insertion_point && list_entry)
{ {
if (*inseration_point != list_entry) if (*insertion_point != list_entry)
{ {
inseration_point = &(*inseration_point)->next; insertion_point = &(*insertion_point)->next;
continue; continue;
} }
*inseration_point = list_entry->next; *insertion_point = list_entry->next;
list_entry->entry->~T(); list_entry->entry->~T();
this->object_allocator_.free(list_entry->entry_base); this->object_allocator_.free(list_entry->entry_base);
this->list_allocator_.free(list_entry->this_base); this->list_allocator_.free(list_entry->this_base);
return {*inseration_point}; return {*insertion_point};
} }
throw std::runtime_error("Bad iterator"); throw std::runtime_error("Bad iterator");
@ -305,7 +305,7 @@ namespace utils
ObjectAllocator object_allocator_{}; ObjectAllocator object_allocator_{};
ListAllocator list_allocator_{}; ListAllocator list_allocator_{};
ListEntry* entries_{nullptr}; list_entry* entries_{nullptr};
template <typename U, typename V> template <typename U, typename V>
static U* align_pointer(V* pointer) static U* align_pointer(V* pointer)
@ -317,28 +317,63 @@ namespace utils
return reinterpret_cast<U*>(ptr); return reinterpret_cast<U*>(ptr);
} }
T& add_uninitialized_entry() void allocate_entry(void*& list_base, void* entry_base)
{ {
auto** inseration_point = &this->entries_; list_base = nullptr;
while (*inseration_point) entry_base = nullptr;
auto destructor = utils::finally([&]
{ {
inseration_point = &(*inseration_point)->next; if (list_base)
{
this->list_allocator_.free(list_base);
}
if (entry_base)
{
this->object_allocator_.free(entry_base);
}
});
list_base = this->list_allocator_.allocate(sizeof(list_entry) + alignof(list_entry));
if (!list_base)
{
throw std::runtime_error("Memory allocation failed");
} }
auto* list_base = this->list_allocator_.allocate(sizeof(ListEntry) + alignof(ListEntry)); entry_base = this->object_allocator_.allocate(sizeof(T) + alignof(T));
auto* entry_base = this->object_allocator_.allocate(sizeof(T) + alignof(T)); if (!entry_base)
{
throw std::runtime_error("Memory allocation failed");
}
auto* entry = align_pointer<T>(entry_base); destructor.cancel();
auto* list_entry = align_pointer<ListEntry>(list_base); }
list_entry->this_base = list_base;
list_entry->entry_base = entry_base;
list_entry->next = nullptr;
list_entry->entry = entry;
*inseration_point = list_entry; T& add_uninitialized_entry()
{
void* list_base = {};
void* entry_base = {};
this->allocate_entry(list_base, entry_base);
return *entry; auto** insertion_point = &this->entries_;
while (*insertion_point)
{
insertion_point = &(*insertion_point)->next;
}
auto* obj = align_pointer<T>(entry_base);
auto* entry = align_pointer<list_entry>(list_base);
entry->this_base = list_base;
entry->entry_base = entry_base;
entry->next = nullptr;
entry->entry = obj;
*insertion_point = entry;
return *obj;
} }
}; };
} }

View File

@ -6,7 +6,7 @@
namespace utils namespace utils
{ {
template <typename T, typename Allocator = NonPagedAllocator> template <typename T, typename Allocator = NonPagedAllocator>
requires IsAllocator<Allocator> requires is_allocator<Allocator>
class vector class vector
{ {
public: public:
@ -258,6 +258,11 @@ namespace utils
{ {
constexpr auto alignment = alignof(T); constexpr auto alignment = alignof(T);
auto* memory = this->allocator_.allocate(capacity * sizeof(T) + alignment); auto* memory = this->allocator_.allocate(capacity * sizeof(T) + alignment);
if (!memory)
{
throw std::runtime_error("Failed to allocate memory");
}
return memory; return memory;
} }