Better memory management

This commit is contained in:
momo5502 2022-03-27 11:57:26 +02:00
parent 520bdf3aea
commit 1bbd9e9c73
5 changed files with 114 additions and 56 deletions

View File

@ -4,6 +4,7 @@
#include "exception.hpp" #include "exception.hpp"
#include "logging.hpp" #include "logging.hpp"
#include "finally.hpp" #include "finally.hpp"
#include "memory.hpp"
#include "thread.hpp" #include "thread.hpp"
#define IA32_FEATURE_CONTROL_MSR 0x3A #define IA32_FEATURE_CONTROL_MSR 0x3A
@ -34,34 +35,6 @@ namespace
{ {
return is_vmx_supported() && is_vmx_available(); return is_vmx_supported() && is_vmx_available();
} }
_IRQL_requires_max_(DISPATCH_LEVEL)
void free_aligned_memory(void* memory)
{
MmFreeContiguousMemory(memory);
}
_Must_inspect_result_
_IRQL_requires_max_(DISPATCH_LEVEL)
void* allocate_aligned_memory(const size_t size)
{
PHYSICAL_ADDRESS lowest{}, highest{};
lowest.QuadPart = 0;
highest.QuadPart = lowest.QuadPart - 1;
#if (NTDDI_VERSION >= NTDDI_VISTA)
return MmAllocateContiguousNodeMemory(size,
lowest,
highest,
lowest,
PAGE_READWRITE,
KeGetCurrentNodeNumber());
#else
return MmAllocateContiguousMemory(size, highest);
#endif
}
} }
hypervisor::hypervisor() hypervisor::hypervisor()
@ -125,23 +98,27 @@ void hypervisor::disable_core()
void hypervisor::allocate_vm_states() void hypervisor::allocate_vm_states()
{ {
if (this->vm_states_)
{
throw std::runtime_error("VM states are still in use");
}
const auto core_count = thread::get_processor_count(); const auto core_count = thread::get_processor_count();
const auto allocation_size = sizeof(vmx::vm_state) * core_count; const auto allocation_size = sizeof(vmx::vm_state) * core_count;
this->vm_states_ = static_cast<vmx::vm_state*>(allocate_aligned_memory(allocation_size)); this->vm_states_ = static_cast<vmx::vm_state*>(memory::allocate_aligned_memory(allocation_size));
if(!this->vm_states_) if (!this->vm_states_)
{ {
throw std::runtime_error("Failed to allocate vm states"); throw std::runtime_error("Failed to allocate VM states");
} }
RtlSecureZeroMemory(this->vm_states_, allocation_size);
} }
void hypervisor::free_vm_states() void hypervisor::free_vm_states()
{ {
if(this->vm_states_) memory::free_aligned_memory(this->vm_states_);
{ this->vm_states_ = nullptr;
free_aligned_memory(this->vm_states_);
this->vm_states_ = nullptr;
}
} }
vmx::vm_state* hypervisor::get_current_vm_state() const vmx::vm_state* hypervisor::get_current_vm_state() const

74
src/driver/memory.cpp Normal file
View File

@ -0,0 +1,74 @@
#include "std_include.hpp"
#include "memory.hpp"
namespace memory
{
namespace
{
void* allocate_aligned_memory_internal(const size_t size)
{
PHYSICAL_ADDRESS lowest{}, highest{};
lowest.QuadPart = 0;
highest.QuadPart = lowest.QuadPart - 1;
#if (NTDDI_VERSION >= NTDDI_VISTA)
return MmAllocateContiguousNodeMemory(size,
lowest,
highest,
lowest,
PAGE_READWRITE,
KeGetCurrentNodeNumber());
#else
return MmAllocateContiguousMemory(size, highest);
#endif
}
}
_IRQL_requires_max_(DISPATCH_LEVEL)
void free_aligned_memory(void* memory)
{
if (memory)
{
MmFreeContiguousMemory(memory);
}
}
_Must_inspect_result_
_IRQL_requires_max_(DISPATCH_LEVEL)
void* allocate_aligned_memory(const size_t size)
{
void* memory = allocate_aligned_memory_internal(size);
if (memory)
{
RtlSecureZeroMemory(memory, size);
}
return memory;
}
_Must_inspect_result_
_IRQL_requires_max_(DISPATCH_LEVEL)
void* allocate_non_paged_memory(const size_t size)
{
void* memory = ExAllocatePoolWithTag(NonPagedPool, size, 'MOMO');
if (memory)
{
RtlSecureZeroMemory(memory, size);
}
return memory;
}
_IRQL_requires_max_(DISPATCH_LEVEL)
void free_non_paged_memory(void* memory)
{
if (memory)
{
ExFreePool(memory);
}
}
}

18
src/driver/memory.hpp Normal file
View File

@ -0,0 +1,18 @@
#pragma once
namespace memory
{
_IRQL_requires_max_(DISPATCH_LEVEL)
void free_aligned_memory(void* memory);
_Must_inspect_result_
_IRQL_requires_max_(DISPATCH_LEVEL)
void* allocate_aligned_memory(size_t size);
_Must_inspect_result_
_IRQL_requires_max_(DISPATCH_LEVEL)
void* allocate_non_paged_memory(size_t size);
_IRQL_requires_max_(DISPATCH_LEVEL)
void free_non_paged_memory(void* memory);
}

View File

@ -1,12 +1,13 @@
#include "std_include.hpp" #include "std_include.hpp"
#include "new.hpp" #include "new.hpp"
#include "exception.hpp" #include "exception.hpp"
#include "memory.hpp"
namespace namespace
{ {
void* perform_allocation(const size_t size, const POOL_TYPE pool, const unsigned long tag) void* perform_checked_non_paged_allocation(const size_t size)
{ {
auto* memory = ExAllocatePoolWithTag(pool, size, tag); auto* memory = memory::allocate_non_paged_memory(size);
if (!memory) if (!memory)
{ {
throw std::runtime_error("Memory allocation failed"); throw std::runtime_error("Memory allocation failed");
@ -16,24 +17,14 @@ namespace
} }
} }
void* operator new(const size_t size, const POOL_TYPE pool, const unsigned long tag)
{
return perform_allocation(size, pool, tag);
}
void* operator new[](const size_t size, const POOL_TYPE pool, const unsigned long tag)
{
return perform_allocation(size, pool, tag);
}
void* operator new(const size_t size) void* operator new(const size_t size)
{ {
return operator new(size, NonPagedPool); return perform_checked_non_paged_allocation(size);
} }
void* operator new[](const size_t size) void* operator new[](const size_t size)
{ {
return operator new[](size, NonPagedPool); return perform_checked_non_paged_allocation(size);
} }
// Placement new // Placement new
@ -44,22 +35,22 @@ void* operator new(size_t, void* where)
void operator delete(void* ptr, size_t) void operator delete(void* ptr, size_t)
{ {
ExFreePool(ptr); memory::free_non_paged_memory(ptr);
} }
void operator delete(void* ptr) void operator delete(void* ptr)
{ {
ExFreePool(ptr); memory::free_non_paged_memory(ptr);
} }
void operator delete[](void* ptr, size_t) void operator delete[](void* ptr, size_t)
{ {
ExFreePool(ptr); memory::free_non_paged_memory(ptr);
} }
void operator delete[](void* ptr) void operator delete[](void* ptr)
{ {
ExFreePool(ptr); memory::free_non_paged_memory(ptr);
} }
extern "C" void __std_terminate() extern "C" void __std_terminate()

View File

@ -1,7 +1,5 @@
#pragma once #pragma once
void* operator new(size_t size, POOL_TYPE pool, unsigned long tag = 'momo');
void* operator new[](size_t size, POOL_TYPE pool, unsigned long tag = 'momo');
void* operator new(size_t size); void* operator new(size_t size);
void* operator new[](size_t size); void* operator new[](size_t size);