Use containers for ept allocations

This commit is contained in:
Maurice Heumann 2022-12-24 08:36:23 +01:00
parent 33b44f1dc1
commit 1d23c10734
5 changed files with 121 additions and 157 deletions

View File

@ -4,7 +4,7 @@
namespace utils
{
template <typename T>
concept IsAllocator = requires(size_t size, void* ptr)
concept is_allocator = requires(size_t size, void* ptr)
{
T().free(T().allocate(size));
T().free(ptr);

View File

@ -95,18 +95,16 @@ namespace vmx
}
}
void reset_all_watch_point_pages(ept_code_watch_point* watch_point)
void reset_all_watch_point_pages(utils::list<ept_code_watch_point>& watch_points)
{
while (watch_point)
for(const auto& watch_point : watch_points)
{
if (watch_point->target_page)
if (watch_point.target_page)
{
watch_point->target_page->write_access = 0;
watch_point->target_page->read_access = 0;
watch_point->target_page->execute_access = 1;
watch_point.target_page->write_access = 0;
watch_point.target_page->read_access = 0;
watch_point.target_page->execute_access = 1;
}
watch_point = watch_point->next_watch_point;
}
}
}
@ -142,29 +140,7 @@ namespace vmx
ept::~ept()
{
auto* split = this->ept_splits;
while (split)
{
auto* current_split = split;
split = split->next_split;
memory::free_aligned_object(current_split);
}
auto* hook = this->ept_hooks;
while (hook)
{
auto* current_hook = hook;
hook = hook->next_hook;
memory::free_aligned_object(current_hook);
}
auto* watch_point = this->ept_code_watch_points;
while (watch_point)
{
auto* current_watch_point = watch_point;
watch_point = watch_point->next_watch_point;
memory::free_non_paged_object(current_watch_point);
}
this->disable_all_hooks();
}
void ept::install_page_hook(void* destination, const void* source, const size_t length,
@ -230,11 +206,9 @@ namespace vmx
void ept::disable_all_hooks() const
{
auto* hook = this->ept_hooks;
while (hook)
for(auto& hook : this->ept_hooks)
{
hook->target_page->flags = hook->original_entry.flags;
hook = hook->next_hook;
hook.target_page->flags = hook.original_entry.flags;
}
}
@ -368,23 +342,19 @@ namespace vmx
return;
}
auto* watch_point = this->allocate_ept_code_watch_point();
if (!watch_point)
{
throw std::runtime_error("Failed to allocate watch point");
}
auto& watch_point = this->allocate_ept_code_watch_point();
this->split_large_page(physical_base_address);
watch_point->physical_base_address = physical_base_address;
watch_point->target_page = this->get_pml1_entry(physical_base_address);
if (!watch_point->target_page)
watch_point.physical_base_address = physical_base_address;
watch_point.target_page = this->get_pml1_entry(physical_base_address);
if (!watch_point.target_page)
{
throw std::runtime_error("Failed to get PML1 entry for target address");
}
watch_point->target_page->write_access = 0;
watch_point->target_page->read_access = 0;
watch_point.target_page->write_access = 0;
watch_point.target_page->read_access = 0;
}
ept_pointer ept::get_ept_pointer() const
@ -448,91 +418,55 @@ namespace vmx
return &pml1[ADDRMASK_EPT_PML1_INDEX(physical_address)];
}
pml1* ept::find_pml1_table(const uint64_t physical_address) const
pml1* ept::find_pml1_table(const uint64_t physical_address)
{
auto* split = this->ept_splits;
while (split)
for(auto& split : this->ept_splits)
{
if (memory::get_physical_address(&split->pml1[0]) == physical_address)
if (memory::get_physical_address(&split.pml1[0]) == physical_address)
{
return split->pml1;
return split.pml1;
}
split = split->next_split;
}
return nullptr;
}
ept_split* ept::allocate_ept_split()
ept_split& ept::allocate_ept_split()
{
auto* split = memory::allocate_aligned_object<ept_split>();
if (!split)
{
throw std::runtime_error("Failed to allocate ept split object");
}
split->next_split = this->ept_splits;
this->ept_splits = split;
return split;
return this->ept_splits.emplace_back();
}
ept_hook* ept::allocate_ept_hook(const uint64_t physical_address)
ept_hook& ept::allocate_ept_hook(const uint64_t physical_address)
{
auto* hook = memory::allocate_aligned_object<ept_hook>(physical_address);
if (!hook)
{
throw std::runtime_error("Failed to allocate ept hook object");
}
hook->next_hook = this->ept_hooks;
this->ept_hooks = hook;
return hook;
return this->ept_hooks.emplace_back(physical_address);
}
ept_hook* ept::find_ept_hook(const uint64_t physical_address) const
ept_hook* ept::find_ept_hook(const uint64_t physical_address)
{
auto* hook = this->ept_hooks;
while (hook)
for (auto& hook : this->ept_hooks)
{
if (hook->physical_base_address == physical_address)
if (hook.physical_base_address == physical_address)
{
return hook;
return &hook;
}
hook = hook->next_hook;
}
return nullptr;
}
ept_code_watch_point* ept::allocate_ept_code_watch_point()
ept_code_watch_point& ept::allocate_ept_code_watch_point()
{
auto* watch_point = memory::allocate_non_paged_object<ept_code_watch_point>();
if (!watch_point)
{
throw std::runtime_error("Failed to allocate ept watch point object");
}
watch_point->next_watch_point = this->ept_code_watch_points;
this->ept_code_watch_points = watch_point;
return watch_point;
return this->ept_code_watch_points.emplace_back();
}
ept_code_watch_point* ept::find_ept_code_watch_point(const uint64_t physical_address) const
ept_code_watch_point* ept::find_ept_code_watch_point(const uint64_t physical_address)
{
auto* watch_point = this->ept_code_watch_points;
while (watch_point)
for(auto& watch_point : this->ept_code_watch_points)
{
if (watch_point->physical_base_address == physical_address)
if (watch_point.physical_base_address == physical_address)
{
return watch_point;
return &watch_point;
}
watch_point = watch_point->next_watch_point;
}
return nullptr;
@ -573,12 +507,7 @@ namespace vmx
return hook;
}
hook = this->allocate_ept_hook(physical_base_address);
if (!hook)
{
throw std::runtime_error("Failed to allocate hook");
}
hook = &this->allocate_ept_hook(physical_base_address);
this->split_large_page(physical_address);
@ -624,7 +553,7 @@ namespace vmx
return;
}
auto* split = this->allocate_ept_split();
auto& split = this->allocate_ept_split();
epte pml1_template{};
pml1_template.flags = 0;
@ -635,11 +564,11 @@ namespace vmx
pml1_template.ignore_pat = target_entry->ignore_pat;
pml1_template.suppress_ve = target_entry->suppress_ve;
__stosq(reinterpret_cast<uint64_t*>(&split->pml1[0]), pml1_template.flags, EPT_PTE_ENTRY_COUNT);
__stosq(reinterpret_cast<uint64_t*>(&split.pml1[0]), pml1_template.flags, EPT_PTE_ENTRY_COUNT);
for (auto i = 0; i < EPT_PTE_ENTRY_COUNT; ++i)
{
split->pml1[i].page_frame_number = ((target_entry->page_frame_number * 2_mb) / PAGE_SIZE) + i;
split.pml1[i].page_frame_number = ((target_entry->page_frame_number * 2_mb) / PAGE_SIZE) + i;
}
pml2_ptr new_pointer{};
@ -648,7 +577,7 @@ namespace vmx
new_pointer.write_access = 1;
new_pointer.execute_access = 1;
new_pointer.page_frame_number = memory::get_physical_address(&split->pml1[0]) / PAGE_SIZE;
new_pointer.page_frame_number = memory::get_physical_address(&split.pml1[0]) / PAGE_SIZE;
target_entry->flags = new_pointer.flags;
}

View File

@ -20,20 +20,17 @@ namespace vmx
pml2 entry{};
pml2_ptr pointer;
};
ept_split* next_split{nullptr};
};
struct ept_code_watch_point
{
uint64_t physical_base_address{};
pml1* target_page{};
ept_code_watch_point* next_watch_point{nullptr};
};
struct ept_hook
{
ept_hook(const uint64_t physical_base);
ept_hook(uint64_t physical_base);
~ept_hook();
DECLSPEC_PAGE_ALIGN uint8_t fake_page[PAGE_SIZE]{};
@ -46,8 +43,6 @@ namespace vmx
pml1 original_entry{};
pml1 execute_entry{};
pml1 readwrite_entry{};
ept_hook* next_hook{nullptr};
};
struct ept_translation_hint
@ -76,7 +71,7 @@ namespace vmx
void install_code_watch_point(uint64_t physical_page);
void install_hook(const void* destination, const void* source, size_t length,
const utils::list<ept_translation_hint>& hints = {});
const utils::list<ept_translation_hint>& hints = {});
void disable_all_hooks() const;
void handle_violation(guest_context& guest_context);
@ -96,20 +91,20 @@ namespace vmx
uint64_t access_records[1024];
ept_split* ept_splits{nullptr};
ept_hook* ept_hooks{nullptr};
ept_code_watch_point* ept_code_watch_points{nullptr};
utils::list<ept_split, utils::AlignedAllocator> ept_splits{};
utils::list<ept_hook, utils::AlignedAllocator> ept_hooks{};
utils::list<ept_code_watch_point> ept_code_watch_points{};
pml2* get_pml2_entry(uint64_t physical_address);
pml1* get_pml1_entry(uint64_t physical_address);
pml1* find_pml1_table(uint64_t physical_address) const;
pml1* find_pml1_table(uint64_t physical_address);
ept_split* allocate_ept_split();
ept_hook* allocate_ept_hook(uint64_t physical_address);
ept_hook* find_ept_hook(uint64_t physical_address) const;
ept_split& allocate_ept_split();
ept_hook& allocate_ept_hook(uint64_t physical_address);
ept_hook* find_ept_hook(uint64_t physical_address);
ept_code_watch_point* allocate_ept_code_watch_point();
ept_code_watch_point* find_ept_code_watch_point(uint64_t physical_address) const;
ept_code_watch_point& allocate_ept_code_watch_point();
ept_code_watch_point* find_ept_code_watch_point(uint64_t physical_address);
ept_hook* get_or_create_ept_hook(void* destination, const ept_translation_hint* translation_hint = nullptr);

View File

@ -6,13 +6,13 @@
namespace utils
{
template <typename T, typename ObjectAllocator = NonPagedAllocator, typename ListAllocator = NonPagedAllocator>
requires IsAllocator<ObjectAllocator> && IsAllocator<ListAllocator>
requires is_allocator<ObjectAllocator> && is_allocator<ListAllocator>
class list
{
struct ListEntry
struct list_entry
{
T* entry{nullptr};
ListEntry* next{nullptr};
list_entry* next{nullptr};
void* this_base{nullptr};
void* entry_base{nullptr};
@ -26,7 +26,7 @@ namespace utils
friend list;
public:
iterator(ListEntry* entry = nullptr)
iterator(list_entry* entry = nullptr)
: entry_(entry)
{
}
@ -71,9 +71,9 @@ namespace utils
}
private:
ListEntry* entry_{nullptr};
list_entry* entry_{nullptr};
ListEntry* get_entry() const
list_entry* get_entry() const
{
return entry_;
}
@ -84,7 +84,7 @@ namespace utils
friend list;
public:
const_iterator(ListEntry* entry = nullptr)
const_iterator(list_entry* entry = nullptr)
: entry_(entry)
{
}
@ -111,9 +111,9 @@ namespace utils
}
private:
ListEntry* entry_{nullptr};
list_entry* entry_{nullptr};
ListEntry* get_entry() const
list_entry* get_entry() const
{
return entry_;
}
@ -274,22 +274,22 @@ namespace utils
iterator erase(iterator iterator)
{
auto* list_entry = iterator.get_entry();
auto** inseration_point = &this->entries_;
while (*inseration_point && list_entry)
auto** insertion_point = &this->entries_;
while (*insertion_point && list_entry)
{
if (*inseration_point != list_entry)
if (*insertion_point != list_entry)
{
inseration_point = &(*inseration_point)->next;
insertion_point = &(*insertion_point)->next;
continue;
}
*inseration_point = list_entry->next;
*insertion_point = list_entry->next;
list_entry->entry->~T();
this->object_allocator_.free(list_entry->entry_base);
this->list_allocator_.free(list_entry->this_base);
return {*inseration_point};
return {*insertion_point};
}
throw std::runtime_error("Bad iterator");
@ -305,7 +305,7 @@ namespace utils
ObjectAllocator object_allocator_{};
ListAllocator list_allocator_{};
ListEntry* entries_{nullptr};
list_entry* entries_{nullptr};
template <typename U, typename V>
static U* align_pointer(V* pointer)
@ -317,28 +317,63 @@ namespace utils
return reinterpret_cast<U*>(ptr);
}
T& add_uninitialized_entry()
void allocate_entry(void*& list_base, void* entry_base)
{
auto** inseration_point = &this->entries_;
while (*inseration_point)
list_base = nullptr;
entry_base = nullptr;
auto destructor = utils::finally([&]
{
inseration_point = &(*inseration_point)->next;
if (list_base)
{
this->list_allocator_.free(list_base);
}
if (entry_base)
{
this->object_allocator_.free(entry_base);
}
});
list_base = this->list_allocator_.allocate(sizeof(list_entry) + alignof(list_entry));
if (!list_base)
{
throw std::runtime_error("Memory allocation failed");
}
auto* list_base = this->list_allocator_.allocate(sizeof(ListEntry) + alignof(ListEntry));
auto* entry_base = this->object_allocator_.allocate(sizeof(T) + alignof(T));
entry_base = this->object_allocator_.allocate(sizeof(T) + alignof(T));
if (!entry_base)
{
throw std::runtime_error("Memory allocation failed");
}
auto* entry = align_pointer<T>(entry_base);
auto* list_entry = align_pointer<ListEntry>(list_base);
destructor.cancel();
}
list_entry->this_base = list_base;
list_entry->entry_base = entry_base;
list_entry->next = nullptr;
list_entry->entry = entry;
*inseration_point = list_entry;
T& add_uninitialized_entry()
{
void* list_base = {};
void* entry_base = {};
this->allocate_entry(list_base, entry_base);
return *entry;
auto** insertion_point = &this->entries_;
while (*insertion_point)
{
insertion_point = &(*insertion_point)->next;
}
auto* obj = align_pointer<T>(entry_base);
auto* entry = align_pointer<list_entry>(list_base);
entry->this_base = list_base;
entry->entry_base = entry_base;
entry->next = nullptr;
entry->entry = obj;
*insertion_point = entry;
return *obj;
}
};
}

View File

@ -6,7 +6,7 @@
namespace utils
{
template <typename T, typename Allocator = NonPagedAllocator>
requires IsAllocator<Allocator>
requires is_allocator<Allocator>
class vector
{
public:
@ -258,6 +258,11 @@ namespace utils
{
constexpr auto alignment = alignof(T);
auto* memory = this->allocator_.allocate(capacity * sizeof(T) + alignment);
if (!memory)
{
throw std::runtime_error("Failed to allocate memory");
}
return memory;
}