diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ebf6d51..1eae14d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,5 +1,6 @@ -wdk_add_driver(driver +wdk_add_driver(driver main.cpp + thread.cpp ) cmake_path(NATIVE_PATH PROJECT_SOURCE_DIR NORMALIZE WINDOWS_PROJECT_DIR) diff --git a/src/logging.h b/src/logging.h deleted file mode 100644 index 00765f2..0000000 --- a/src/logging.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once - -#ifdef NDEBUG -#define DbgLog(...) -#else -#define DbgLog(...) DbgPrintEx(DPFLTR_IHVDRIVER_ID, DPFLTR_ERROR_LEVEL, __VA_ARGS__) -#endif \ No newline at end of file diff --git a/src/logging.hpp b/src/logging.hpp new file mode 100644 index 0000000..fb0e487 --- /dev/null +++ b/src/logging.hpp @@ -0,0 +1,7 @@ +#pragma once + +#ifdef NDEBUG +#define DbgLog(...) +#else +#define debug_log(...) DbgPrintEx(DPFLTR_IHVDRIVER_ID, DPFLTR_ERROR_LEVEL, __VA_ARGS__) +#endif \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index bad5cd6..34564c3 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,43 +1,36 @@ #include -#include "logging.h" -#include "nt_ext.h" - -_Function_class_(KDEFERRED_ROUTINE) - -void NTAPI test_function(struct _KDPC* /*Dpc*/, - PVOID param, - const PVOID arg1, - const PVOID arg2) -{ - const auto core_id = KeGetCurrentProcessorNumberEx(nullptr); - DbgLog("%s from CPU %u\n", static_cast(param), core_id); - - KeSignalCallDpcSynchronize(arg2); - KeSignalCallDpcDone(arg1); -} +#include "logging.hpp" +#include "nt_ext.hpp" +#include "thread.hpp" _Function_class_(DRIVER_UNLOAD) + void unload(PDRIVER_OBJECT /*DriverObject*/) { - DbgLog("Leaving World\n"); - KeGenericCallDpc(test_function, "Bye"); - DbgLog("Bye World\n"); + debug_log("Leaving World\n"); } -extern "C" { - -NTSTATUS DriverEntry(const PDRIVER_OBJECT DriverObject, PUNICODE_STRING /*RegistryPath*/) +extern "C" NTSTATUS DriverEntry(const PDRIVER_OBJECT DriverObject, PUNICODE_STRING /*RegistryPath*/) { DriverObject->DriverUnload = unload; - DbgLog("Hello World\n"); + debug_log("Hello World\n"); - KeGenericCallDpc(test_function, "Hello"); + volatile long i = 0; - DbgLog("Nice World\n"); + thread::dispatch_on_all_cores([&i]() + { + const auto index = thread::get_processor_index(); + while (i != index) + { + } + + debug_log("Hello from CPU %u/%u\n", thread::get_processor_index() + 1, thread::get_processor_count()); + ++i; + }); + + debug_log("Final i = %i\n", i); return STATUS_SUCCESS; } - -} diff --git a/src/nt_ext.h b/src/nt_ext.hpp similarity index 100% rename from src/nt_ext.h rename to src/nt_ext.hpp diff --git a/src/thread.cpp b/src/thread.cpp new file mode 100644 index 0000000..142b6c5 --- /dev/null +++ b/src/thread.cpp @@ -0,0 +1,57 @@ +#include "thread.hpp" +#include +#include "nt_ext.hpp" + +namespace thread +{ + namespace + { + struct dispatch_data + { + void (*callback)(void*){}; + void* data{}; + }; + + + _Function_class_(KDEFERRED_ROUTINE) + + void NTAPI callback_dispatcher(struct _KDPC* /*Dpc*/, + const PVOID param, + const PVOID arg1, + const PVOID arg2) + { + auto* const data = static_cast(param); + data->callback(data->data); + + KeSignalCallDpcSynchronize(arg2); + KeSignalCallDpcDone(arg1); + } + } + + uint32_t get_processor_count() + { + return static_cast(KeQueryActiveProcessorCountEx(0)); + } + + uint32_t get_processor_index() + { + return static_cast(KeGetCurrentProcessorNumberEx(nullptr)); + } + + bool sleep(const uint32_t milliseconds) + { + LARGE_INTEGER interval; + interval.QuadPart = -(10000ll * milliseconds); + + return STATUS_SUCCESS == KeDelayExecutionThread(KernelMode, FALSE, &interval); + } + + void dispatch_on_all_cores(void (*callback)(void*), void* data) + { + dispatch_data callback_data{}; + callback_data.callback = callback; + callback_data.data = data; + + KeGenericCallDpc(callback_dispatcher, &callback_data); + } +} diff --git a/src/thread.hpp b/src/thread.hpp new file mode 100644 index 0000000..b4acb74 --- /dev/null +++ b/src/thread.hpp @@ -0,0 +1,22 @@ +#pragma once + +using uint32_t = int; + +namespace thread +{ + uint32_t get_processor_count(); + uint32_t get_processor_index(); + + bool sleep(uint32_t milliseconds); + + void dispatch_on_all_cores(void(*callback)(void*), void* data); + + template + void dispatch_on_all_cores(F&& callback) + { + dispatch_on_all_cores([](void* data) + { + (*reinterpret_cast(data))(); + }, &callback); + } +}