diff --git a/src/driver/driver_main.cpp b/src/driver/driver_main.cpp index cad9cfc..1ffeda3 100644 --- a/src/driver/driver_main.cpp +++ b/src/driver/driver_main.cpp @@ -1,6 +1,7 @@ #include "std_include.hpp" #include "logging.hpp" #include "thread.hpp" +#include "sleep_callback.hpp" #define DOS_DEV_NAME L"\\DosDevices\\HelloDev" #define DEV_NAME L"\\Device\\HelloDev" @@ -157,85 +158,7 @@ NTSTATUS create_io_device(const PDRIVER_OBJECT DriverObject) return Status; } -_Function_class_(CALLBACK_FUNCTION) -VOID -PowerCallback( - _In_opt_ PVOID CallbackContext, - _In_opt_ PVOID Argument1, - _In_opt_ PVOID Argument2 -) -{ - UNREFERENCED_PARAMETER(CallbackContext); - - // - // Ignore non-Sx changes - // - if (Argument1 != (PVOID)PO_CB_SYSTEM_STATE_LOCK) - { - return; - } - - // - // Check if this is S0->Sx, or Sx->S0 - // - if (ARGUMENT_PRESENT(Argument2)) - { - // - // Reload the hypervisor - // - debug_log("Waking up!\n"); - } - else - { - // - // Unload the hypervisor - // - debug_log("Going to sleep!\n"); - } -} - -PVOID g_PowerCallbackRegistration{nullptr}; - -NTSTATUS register_sleep_callback() -{ - PCALLBACK_OBJECT callbackObject; - UNICODE_STRING callbackName = - RTL_CONSTANT_STRING(L"\\Callback\\PowerState"); - OBJECT_ATTRIBUTES objectAttributes = - RTL_CONSTANT_OBJECT_ATTRIBUTES(&callbackName, - OBJ_CASE_INSENSITIVE | - OBJ_KERNEL_HANDLE); - - auto status = ExCreateCallback(&callbackObject, &objectAttributes, FALSE, TRUE); - if (!NT_SUCCESS(status)) - { - return status; - } - - // - // Now register our routine with this callback - // - g_PowerCallbackRegistration = ExRegisterCallback(callbackObject, - PowerCallback, - NULL); - - // - // Dereference it in both cases -- either it's registered, so that is now - // taking a reference, and we'll unregister later, or it failed to register - // so we failing now, and it's gone. - // - ObDereferenceObject(callbackObject); - - // - // Fail if we couldn't register the power callback - // - if (g_PowerCallbackRegistration == NULL) - { - return STATUS_INSUFFICIENT_RESOURCES; - } - - return STATUS_SUCCESS; -} +sleep_callback* sleep_cb{nullptr}; _Function_class_(DRIVER_UNLOAD) @@ -243,7 +166,7 @@ void unload(PDRIVER_OBJECT DriverObject) { debug_log("Leaving World\n"); IrpUnloadHandler(DriverObject); - ExUnregisterCallback(g_PowerCallbackRegistration); + delete sleep_cb; } void throw_test() @@ -287,7 +210,29 @@ extern "C" NTSTATUS DriverEntry(const PDRIVER_OBJECT DriverObject, PUNICODE_STRI debug_log("Final i = %i\n", i); throw_test(); - register_sleep_callback(); + + try + { + sleep_cb = new sleep_callback([](const sleep_callback::type type) + { + if (type == sleep_callback::type::sleep) + { + debug_log("Going to sleep!"); + } + + if (type == sleep_callback::type::wakeup) + { + debug_log("Waking up!"); + } + }); + + sleep_cb->dispatcher(sleep_callback::type::sleep); + sleep_cb->dispatcher(sleep_callback::type::wakeup); + } + catch (...) + { + debug_log("Failed to register sleep callback"); + } return create_io_device(DriverObject); diff --git a/src/driver/exception.hpp b/src/driver/exception.hpp new file mode 100644 index 0000000..585bb53 --- /dev/null +++ b/src/driver/exception.hpp @@ -0,0 +1,38 @@ +#pragma once +#include "type_traits.hpp" + +namespace std +{ + class exception + { + public: + exception& operator=(const exception& obj) noexcept = default; + exception& operator=(exception&& obj) noexcept = default; + + virtual ~exception() = default; + virtual const char* what() const noexcept = 0; + }; + + class runtime_error : public exception + { + public: + runtime_error(const char* message) + : message_(message) + { + + } + + runtime_error(const runtime_error& obj) noexcept = default; + runtime_error& operator=(const runtime_error& obj) noexcept = default; + + runtime_error(runtime_error&& obj) noexcept = default; + runtime_error& operator=(runtime_error&& obj) noexcept = default; + + const char* what() const noexcept override + { + return message_; + } + private: + const char* message_{}; + }; +} diff --git a/src/driver/finally.hpp b/src/driver/finally.hpp new file mode 100644 index 0000000..9c3cfda --- /dev/null +++ b/src/driver/finally.hpp @@ -0,0 +1,55 @@ +#pragma once +#include "type_traits.hpp" + +namespace utils +{ + /* + * Copied from here: https://github.com/microsoft/GSL/blob/e0880931ae5885eb988d1a8a57acf8bc2b8dacda/include/gsl/util#L57 + */ + + template + class final_action + { + public: + /*static_assert(!std::is_reference::value && !std::is_const::value && + !std::is_volatile::value, + "Final_action should store its callable by value");*/ + + explicit final_action(F f) noexcept : f_(std::move(f)) + { + } + + final_action(final_action&& other) noexcept + : f_(std::move(other.f_)), invoke_(other.invoke_) + { + other.invoke_ = false; + } + + final_action(const final_action&) = delete; + final_action& operator=(const final_action&) = delete; + final_action& operator=(final_action&&) = delete; + + ~final_action() noexcept + { + if (invoke_) f_(); + } + + // Added by momo5502 + void cancel() + { + invoke_ = false; + } + + private: + F f_; + bool invoke_{true}; + }; + + template + final_action::type>::type> + finally(F&& f) noexcept + { + return final_action::type>::type>( + std::forward(f)); + } +} \ No newline at end of file diff --git a/src/driver/functional.hpp b/src/driver/functional.hpp index 70aed0f..6bb8587 100644 --- a/src/driver/functional.hpp +++ b/src/driver/functional.hpp @@ -34,6 +34,8 @@ namespace std std::unique_ptr fn{}; public: + function() = default; + template function(T&& t) : fn(new fn_implementation(std::forward(t))) @@ -49,7 +51,12 @@ namespace std Result operator()(Args ... args) const { - return (*fn)(std::forward(args)...); + return (*this->fn)(std::forward(args)...); + } + + operator bool() const + { + return this->fn; } }; } diff --git a/src/driver/sleep_callback.cpp b/src/driver/sleep_callback.cpp index f1fea29..a3fea9b 100644 --- a/src/driver/sleep_callback.cpp +++ b/src/driver/sleep_callback.cpp @@ -1,2 +1,71 @@ #include "std_include.hpp" #include "sleep_callback.hpp" +#include "exception.hpp" +#include "finally.hpp" + +sleep_callback::sleep_callback(callback_function&& callback) + : callback_(std::move(callback)) +{ + PCALLBACK_OBJECT callback_object{}; + UNICODE_STRING callback_name = RTL_CONSTANT_STRING(L"\\Callback\\PowerState"); + OBJECT_ATTRIBUTES object_attributes = RTL_CONSTANT_OBJECT_ATTRIBUTES( + &callback_name, OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE); + + const auto _ = utils::finally([&callback_object]() + { + ObDereferenceObject(callback_object); + }); + + const auto status = ExCreateCallback(&callback_object, &object_attributes, FALSE, TRUE); + if (!NT_SUCCESS(status)) + { + throw std::runtime_error("Unable to create callback object"); + } + + this->handle_ = ExRegisterCallback(callback_object, sleep_callback::static_callback, this); + if (!this->handle_) + { + throw std::runtime_error("Unable to register callback"); + } +} + +sleep_callback::~sleep_callback() +{ + if (this->handle_) + { + ExUnregisterCallback(this->handle_); + } +} + + +void sleep_callback::dispatcher(const type type) const +{ + try + { + if (this->callback_) + { + this->callback_(type); + } + } + catch (...) + { + } +} + +_Function_class_(CALLBACK_FUNCTION) + +void sleep_callback::static_callback(void* context, void* argument1, void* argument2) +{ + if (!context || argument1 != reinterpret_cast(PO_CB_SYSTEM_STATE_LOCK)) + { + return; + } + + auto type = type::sleep; + if(ARGUMENT_PRESENT(argument2)) + { + type = type::wakeup; + } + + static_cast(context)->dispatcher(type); +} diff --git a/src/driver/sleep_callback.hpp b/src/driver/sleep_callback.hpp index 3cff8f7..e0afd3b 100644 --- a/src/driver/sleep_callback.hpp +++ b/src/driver/sleep_callback.hpp @@ -1,2 +1,33 @@ #pragma once -#include "functional.hpp" \ No newline at end of file +#include "functional.hpp" + +class sleep_callback +{ +public: + enum class type + { + sleep, + wakeup, + }; + + using callback_function = std::function; + + sleep_callback(callback_function&& callback); + ~sleep_callback(); + + sleep_callback(sleep_callback&& obj) noexcept = delete; + sleep_callback& operator=(sleep_callback&& obj) noexcept = delete; + + sleep_callback(const sleep_callback& obj) = delete; + sleep_callback& operator=(const sleep_callback& obj) = delete; + void dispatcher(type type) const; + +private: + void* handle_{nullptr}; + callback_function callback_{}; + + + + _Function_class_(CALLBACK_FUNCTION) + static void static_callback(void* context, void* argument1, void* argument2); +}; diff --git a/src/driver/type_traits.hpp b/src/driver/type_traits.hpp index 625972c..5b3062e 100644 --- a/src/driver/type_traits.hpp +++ b/src/driver/type_traits.hpp @@ -51,4 +51,15 @@ namespace std // forward an rvalue as an rvalue return (static_cast<_Ty&&>(_Arg)); } + + template< class T > struct remove_cv { typedef T type; }; + template< class T > struct remove_cv { typedef T type; }; + template< class T > struct remove_cv { typedef T type; }; + template< class T > struct remove_cv { typedef T type; }; + + template< class T > struct remove_const { typedef T type; }; + template< class T > struct remove_const { typedef T type; }; + + template< class T > struct remove_volatile { typedef T type; }; + template< class T > struct remove_volatile { typedef T type; }; } diff --git a/src/driver/unique_ptr.hpp b/src/driver/unique_ptr.hpp index 8585fa5..9a6b04a 100644 --- a/src/driver/unique_ptr.hpp +++ b/src/driver/unique_ptr.hpp @@ -6,6 +6,7 @@ namespace std template class unique_ptr { + public: unique_ptr() = default; unique_ptr(T* pointer) @@ -63,6 +64,11 @@ namespace std return *this->pointer_; } + operator bool() const + { + return this->pointer_; + } + private: T* pointer_{nullptr}; };