Fix syscall hooking

This commit is contained in:
momo5502 2024-05-10 20:18:38 +02:00
parent 0896133821
commit 53c24b8325

View File

@ -12,18 +12,12 @@
#define DPL_USER 3 #define DPL_USER 3
#define DPL_SYSTEM 0 #define DPL_SYSTEM 0
typedef struct _KPROCESS typedef struct _EPROCESS
{ {
DISPATCHER_HEADER Header; DISPATCHER_HEADER Header;
LIST_ENTRY ProfileListHead; LIST_ENTRY ProfileListHead;
ULONG DirectoryTableBase; ULONG_PTR DirectoryTableBase;
// ... UCHAR Data[1];
} KPROCESS, *PKPROCESS;
typedef struct _EPROCESS
{
KPROCESS Pcb;
// ...
} EPROCESS, * PEPROCESS; } EPROCESS, * PEPROCESS;
namespace namespace
@ -500,7 +494,7 @@ void inject_invalid_opcode(vmx::guest_context& guest_context)
cr3 get_current_process_cr3() cr3 get_current_process_cr3()
{ {
cr3 guest_cr3{}; cr3 guest_cr3{};
guest_cr3.flags = PsGetCurrentProcess()->Pcb.DirectoryTableBase; guest_cr3.flags = PsGetCurrentProcess()->DirectoryTableBase;
return guest_cr3; return guest_cr3;
} }
@ -519,6 +513,52 @@ bool is_mem_equal(const uint8_t* ptr, const uint8_t (&array)[Length])
return true; return true;
} }
enum class syscall_state
{
is_sysret,
is_syscall,
none,
};
syscall_state get_syscall_state(const vmx::guest_context& guest_context)
{
cr3 orignal_cr3{};
orignal_cr3.flags = __readcr3();
const auto _ = utils::finally([&]
{
__writecr3(orignal_cr3.flags);
});
constexpr auto PCID_NONE = 0x000;
constexpr auto PCID_MASK = 0x003;
const auto guest_cr3 = read_vmx(VMCS_GUEST_CR3);
if ((guest_cr3 & PCID_MASK) != PCID_NONE)
{
const auto process_cr3 = get_current_process_cr3();
__writecr3(process_cr3.flags);
}
// TODO: Check for potential page fault
const auto* rip = reinterpret_cast<uint8_t*>(guest_context.guest_rip);
constexpr uint8_t syscall_bytes[] = { 0x0F, 0x05 };
constexpr uint8_t sysret_bytes[] = {0x48, 0x0F, 0x07};
if (is_mem_equal(rip, syscall_bytes))
{
return syscall_state::is_syscall;
}
if (is_mem_equal(rip, sysret_bytes))
{
return syscall_state::is_sysret;
}
return syscall_state::none;
}
void vmx_handle_exception(vmx::guest_context& guest_context) void vmx_handle_exception(vmx::guest_context& guest_context)
{ {
vmexit_interrupt_information interrupt{}; vmexit_interrupt_information interrupt{};
@ -533,21 +573,9 @@ void vmx_handle_exception(vmx::guest_context& guest_context)
if (interrupt.vector == invalid_opcode) if (interrupt.vector == invalid_opcode)
{ {
auto* rip = reinterpret_cast<uint8_t*>(guest_context.guest_rip); const auto state = get_syscall_state(guest_context);
cr3 orignal_cr3{}; if (state == syscall_state::is_syscall)
orignal_cr3.flags = __readcr3();
const auto guest_cr3 = get_current_process_cr3();
__writecr3(guest_cr3.flags);
// TODO: Check for potential page fault
constexpr uint8_t sysret_bytes[] = {0x48, 0x05, 0x07};
constexpr uint8_t syscall_bytes[] = {0x0F, 0x05};
if (is_mem_equal(rip, syscall_bytes))
{ {
guest_context.increment_rip = false; guest_context.increment_rip = false;
@ -587,7 +615,7 @@ void vmx_handle_exception(vmx::guest_context& guest_context)
__vmx_vmwrite(VMCS_GUEST_SS_ACCESS_RIGHTS, gdt_entry.access_rights.flags); __vmx_vmwrite(VMCS_GUEST_SS_ACCESS_RIGHTS, gdt_entry.access_rights.flags);
__vmx_vmwrite(VMCS_GUEST_SS_BASE, gdt_entry.base); __vmx_vmwrite(VMCS_GUEST_SS_BASE, gdt_entry.base);
} }
else if (is_mem_equal(rip, sysret_bytes)) else if (state == syscall_state::is_sysret)
{ {
guest_context.increment_rip = false; guest_context.increment_rip = false;
@ -952,7 +980,8 @@ void setup_vmcs_for_cpu(vmx::state& vm_state)
__vmx_vmwrite(VMCS_GUEST_DEBUGCTL, state->debug_control); __vmx_vmwrite(VMCS_GUEST_DEBUGCTL, state->debug_control);
__vmx_vmwrite(VMCS_GUEST_DR7, state->kernel_dr7); __vmx_vmwrite(VMCS_GUEST_DR7, state->kernel_dr7);
const auto stack_pointer = reinterpret_cast<uintptr_t>(vm_state.stack_buffer) + KERNEL_STACK_SIZE - sizeof(CONTEXT); const auto stack_pointer = reinterpret_cast<uintptr_t>(vm_state.stack_buffer) + KERNEL_STACK_SIZE - sizeof(
CONTEXT);
__vmx_vmwrite(VMCS_GUEST_RSP, stack_pointer); __vmx_vmwrite(VMCS_GUEST_RSP, stack_pointer);
__vmx_vmwrite(VMCS_GUEST_RIP, reinterpret_cast<uintptr_t>(vm_launch)); __vmx_vmwrite(VMCS_GUEST_RIP, reinterpret_cast<uintptr_t>(vm_launch));