diff --git a/src/driver/ept.cpp b/src/driver/ept.cpp index 01b2ee2..5f146f7 100644 --- a/src/driver/ept.cpp +++ b/src/driver/ept.cpp @@ -7,16 +7,6 @@ #include "memory.hpp" #include "vmx.hpp" -#define MTRR_PAGE_SIZE 4096 -#define MTRR_PAGE_MASK (~(MTRR_PAGE_SIZE-1)) - -#define ADDRMASK_EPT_PML1_OFFSET(_VAR_) ((_VAR_) & 0xFFFULL) - -#define ADDRMASK_EPT_PML1_INDEX(_VAR_) (((_VAR_) & 0x1FF000ULL) >> 12) -#define ADDRMASK_EPT_PML2_INDEX(_VAR_) (((_VAR_) & 0x3FE00000ULL) >> 21) -#define ADDRMASK_EPT_PML3_INDEX(_VAR_) (((_VAR_) & 0x7FC0000000ULL) >> 30) -#define ADDRMASK_EPT_PML4_INDEX(_VAR_) (((_VAR_) & 0xFF8000000000ULL) >> 39) - namespace vmx { namespace @@ -301,7 +291,7 @@ namespace vmx // -------------------------- - epdpte temp_epdpte; + epdpte temp_epdpte{}; temp_epdpte.flags = 0; temp_epdpte.read_access = 1; temp_epdpte.write_access = 1; diff --git a/src/driver/ept.hpp b/src/driver/ept.hpp index 69b2ac3..8c6e6c3 100644 --- a/src/driver/ept.hpp +++ b/src/driver/ept.hpp @@ -3,6 +3,18 @@ #define DECLSPEC_PAGE_ALIGN DECLSPEC_ALIGN(PAGE_SIZE) #include "list.hpp" + +#define MTRR_PAGE_SIZE 4096 +#define MTRR_PAGE_MASK (~(MTRR_PAGE_SIZE-1)) + +#define ADDRMASK_EPT_PML1_OFFSET(_VAR_) ((_VAR_) & 0xFFFULL) + +#define ADDRMASK_EPT_PML1_INDEX(_VAR_) (((_VAR_) & 0x1FF000ULL) >> 12) +#define ADDRMASK_EPT_PML2_INDEX(_VAR_) (((_VAR_) & 0x3FE00000ULL) >> 21) +#define ADDRMASK_EPT_PML3_INDEX(_VAR_) (((_VAR_) & 0x7FC0000000ULL) >> 30) +#define ADDRMASK_EPT_PML4_INDEX(_VAR_) (((_VAR_) & 0xFF8000000000ULL) >> 39) + + namespace vmx { using pml4 = ept_pml4; @@ -11,6 +23,11 @@ namespace vmx using pml2_ptr = epde; using pml1 = epte; + using pml4_entry = pml4e_64; + using pml3_entry = pdpte_64; + using pml2_entry = pde_64; + using pml1_entry = pte_64; + struct ept_split { DECLSPEC_PAGE_ALIGN pml1 pml1[EPT_PTE_ENTRY_COUNT]{}; diff --git a/src/driver/hypervisor.cpp b/src/driver/hypervisor.cpp index 1cfef86..73198ed 100644 --- a/src/driver/hypervisor.cpp +++ b/src/driver/hypervisor.cpp @@ -15,10 +15,10 @@ typedef struct _EPROCESS { DISPATCHER_HEADER Header; - LIST_ENTRY ProfileListHead; - ULONG_PTR DirectoryTableBase; - UCHAR Data[1]; -} EPROCESS, * PEPROCESS; + LIST_ENTRY ProfileListHead; + ULONG_PTR DirectoryTableBase; + UCHAR Data[1]; +} EPROCESS, *PEPROCESS; namespace { @@ -485,10 +485,24 @@ void inject_interuption(const interruption_type type, const exception_vector vec } } -void inject_invalid_opcode(vmx::guest_context& guest_context) +void inject_invalid_opcode() { inject_interuption(hardware_exception, invalid_opcode, false, 0); - guest_context.increment_rip = false; +} + +void inject_page_fault(const uint64_t page_fault_address) +{ + __writecr2(page_fault_address); + + page_fault_exception error_code{}; + error_code.flags = 0; + + inject_interuption(hardware_exception, page_fault, true, error_code.flags); +} + +void inject_page_fault(const void* page_fault_address) +{ + inject_page_fault(reinterpret_cast(page_fault_address)); } cr3 get_current_process_cr3() @@ -517,41 +531,118 @@ enum class syscall_state { is_sysret, is_syscall, + page_fault, none, }; +class scoped_cr3_switch +{ +public: + scoped_cr3_switch() + { + original_cr3_.flags = __readcr3(); + } + + scoped_cr3_switch(const cr3 new_cr3) + : scoped_cr3_switch() + { + this->set_cr3(new_cr3); + } + + scoped_cr3_switch(const scoped_cr3_switch&) = delete; + scoped_cr3_switch& operator=(const scoped_cr3_switch&) = delete; + + scoped_cr3_switch(scoped_cr3_switch&&) = delete; + scoped_cr3_switch& operator=(scoped_cr3_switch&&) = delete; + + ~scoped_cr3_switch() + { + __writecr3(original_cr3_.flags); + } + + void set_cr3(const cr3 new_cr3) + { + this->must_restore_ = true; + __writecr3(new_cr3.flags); + } + +private: + bool must_restore_{false}; + cr3 original_cr3_{}; +}; + +template +bool read_data_or_page_fault(uint8_t (&array)[Length], const uint8_t* base) +{ + for (size_t offset = 0; offset < Length;) + { + auto* current_base = base + offset; + auto* current_destination = array + offset; + auto read_length = Length - offset; + + const auto* page_start = static_cast(PAGE_ALIGN(current_base)); + const auto* next_page = page_start + PAGE_SIZE; + + if (current_base + read_length > next_page) + { + read_length = next_page - current_base; + } + + offset += read_length; + + const auto physical_base = memory::get_physical_address(const_cast(current_base)); + + if (!physical_base) + { + inject_page_fault(current_base); + return false; + } + + if (!memory::read_physical_memory(current_destination, physical_base, read_length)) + { + // Not sure if we can recover from that :( + return false; + } + } + + return true; +} + syscall_state get_syscall_state(const vmx::guest_context& guest_context) { - cr3 orignal_cr3{}; - orignal_cr3.flags = __readcr3(); - - const auto _ = utils::finally([&] - { - __writecr3(orignal_cr3.flags); - }); + scoped_cr3_switch cr3_switch{}; constexpr auto PCID_NONE = 0x000; constexpr auto PCID_MASK = 0x003; - const auto guest_cr3 = read_vmx(VMCS_GUEST_CR3); - if ((guest_cr3 & PCID_MASK) != PCID_NONE) + cr3 guest_cr3{}; + guest_cr3.flags = read_vmx(VMCS_GUEST_CR3); + + if ((guest_cr3.flags & PCID_MASK) != PCID_NONE) { - const auto process_cr3 = get_current_process_cr3(); - __writecr3(process_cr3.flags); + cr3_switch.set_cr3(get_current_process_cr3()); } - // TODO: Check for potential page fault const auto* rip = reinterpret_cast(guest_context.guest_rip); - constexpr uint8_t syscall_bytes[] = { 0x0F, 0x05 }; + constexpr uint8_t syscall_bytes[] = {0x0F, 0x05}; constexpr uint8_t sysret_bytes[] = {0x48, 0x0F, 0x07}; - if (is_mem_equal(rip, syscall_bytes)) + constexpr auto max_byte_length = max(sizeof(sysret_bytes), sizeof(syscall_bytes)); + + uint8_t data[max_byte_length]; + + if (!read_data_or_page_fault(data, rip)) + { + return syscall_state::page_fault; + } + + if (is_mem_equal(data, syscall_bytes)) { return syscall_state::is_syscall; } - if (is_mem_equal(rip, sysret_bytes)) + if (is_mem_equal(data, sysret_bytes)) { return syscall_state::is_sysret; } @@ -573,12 +664,17 @@ void vmx_handle_exception(vmx::guest_context& guest_context) if (interrupt.vector == invalid_opcode) { + guest_context.increment_rip = false; + const auto state = get_syscall_state(guest_context); + if (state == syscall_state::page_fault) + { + return; + } + if (state == syscall_state::is_syscall) { - guest_context.increment_rip = false; - rflags rflags{}; rflags.flags = read_vmx(VMCS_GUEST_RFLAGS); @@ -617,8 +713,6 @@ void vmx_handle_exception(vmx::guest_context& guest_context) } else if (state == syscall_state::is_sysret) { - guest_context.increment_rip = false; - __vmx_vmwrite(VMCS_GUEST_RIP, guest_context.vp_regs->Rcx); rflags rflags{}; @@ -659,7 +753,7 @@ void vmx_handle_exception(vmx::guest_context& guest_context) } else { - inject_invalid_opcode(guest_context); + inject_invalid_opcode(); } } else diff --git a/src/driver/memory.cpp b/src/driver/memory.cpp index 235d465..1f4f68c 100644 --- a/src/driver/memory.cpp +++ b/src/driver/memory.cpp @@ -68,6 +68,18 @@ namespace memory return memory; } + _IRQL_requires_max_(APC_LEVEL) + + bool read_physical_memory(void* destination, uint64_t physical_address, const size_t size) + { + size_t bytes_read{}; + MM_COPY_ADDRESS source{}; + source.PhysicalAddress.QuadPart = static_cast(physical_address); + + return MmCopyMemory(destination, source, size, MM_COPY_MEMORY_PHYSICAL, &bytes_read) == STATUS_SUCCESS && + bytes_read == size; + } + uint64_t get_physical_address(void* address) { return static_cast(MmGetPhysicalAddress(address).QuadPart); diff --git a/src/driver/memory.hpp b/src/driver/memory.hpp index f1b028e..10f87b5 100644 --- a/src/driver/memory.hpp +++ b/src/driver/memory.hpp @@ -10,6 +10,9 @@ namespace memory _IRQL_requires_max_(DISPATCH_LEVEL) void* allocate_aligned_memory(size_t size); + _IRQL_requires_max_(APC_LEVEL) + bool read_physical_memory(void* destination, uint64_t physical_address, size_t size); + uint64_t get_physical_address(void* address); void* get_virtual_address(uint64_t address); diff --git a/src/driver/nt_ext.hpp b/src/driver/nt_ext.hpp index f0c9511..b80d297 100644 --- a/src/driver/nt_ext.hpp +++ b/src/driver/nt_ext.hpp @@ -56,6 +56,29 @@ MmAllocateContiguousNodeMemory( // ---------------------------------------- +typedef struct _MM_COPY_ADDRESS { + union { + PVOID VirtualAddress; + PHYSICAL_ADDRESS PhysicalAddress; + }; +} MM_COPY_ADDRESS, * PMMCOPY_ADDRESS; + +#define MM_COPY_MEMORY_PHYSICAL 0x1 +#define MM_COPY_MEMORY_VIRTUAL 0x2 + +_IRQL_requires_max_(APC_LEVEL) +NTKERNELAPI +NTSTATUS +MmCopyMemory( + _In_ PVOID TargetAddress, + _In_ MM_COPY_ADDRESS SourceAddress, + _In_ SIZE_T NumberOfBytes, + _In_ ULONG Flags, + _Out_ PSIZE_T NumberOfBytesTransferred +); + +// ---------------------------------------- + NTSYSAPI VOID NTAPI