Add sleep callback

This commit is contained in:
momo5502 2022-03-26 15:56:36 +01:00
parent eef4a9a5a2
commit da7204ee90
8 changed files with 245 additions and 83 deletions

View File

@ -1,6 +1,7 @@
#include "std_include.hpp"
#include "logging.hpp"
#include "thread.hpp"
#include "sleep_callback.hpp"
#define DOS_DEV_NAME L"\\DosDevices\\HelloDev"
#define DEV_NAME L"\\Device\\HelloDev"
@ -157,85 +158,7 @@ NTSTATUS create_io_device(const PDRIVER_OBJECT DriverObject)
return Status;
}
_Function_class_(CALLBACK_FUNCTION)
VOID
PowerCallback(
_In_opt_ PVOID CallbackContext,
_In_opt_ PVOID Argument1,
_In_opt_ PVOID Argument2
)
{
UNREFERENCED_PARAMETER(CallbackContext);
//
// Ignore non-Sx changes
//
if (Argument1 != (PVOID)PO_CB_SYSTEM_STATE_LOCK)
{
return;
}
//
// Check if this is S0->Sx, or Sx->S0
//
if (ARGUMENT_PRESENT(Argument2))
{
//
// Reload the hypervisor
//
debug_log("Waking up!\n");
}
else
{
//
// Unload the hypervisor
//
debug_log("Going to sleep!\n");
}
}
PVOID g_PowerCallbackRegistration{nullptr};
NTSTATUS register_sleep_callback()
{
PCALLBACK_OBJECT callbackObject;
UNICODE_STRING callbackName =
RTL_CONSTANT_STRING(L"\\Callback\\PowerState");
OBJECT_ATTRIBUTES objectAttributes =
RTL_CONSTANT_OBJECT_ATTRIBUTES(&callbackName,
OBJ_CASE_INSENSITIVE |
OBJ_KERNEL_HANDLE);
auto status = ExCreateCallback(&callbackObject, &objectAttributes, FALSE, TRUE);
if (!NT_SUCCESS(status))
{
return status;
}
//
// Now register our routine with this callback
//
g_PowerCallbackRegistration = ExRegisterCallback(callbackObject,
PowerCallback,
NULL);
//
// Dereference it in both cases -- either it's registered, so that is now
// taking a reference, and we'll unregister later, or it failed to register
// so we failing now, and it's gone.
//
ObDereferenceObject(callbackObject);
//
// Fail if we couldn't register the power callback
//
if (g_PowerCallbackRegistration == NULL)
{
return STATUS_INSUFFICIENT_RESOURCES;
}
return STATUS_SUCCESS;
}
sleep_callback* sleep_cb{nullptr};
_Function_class_(DRIVER_UNLOAD)
@ -243,7 +166,7 @@ void unload(PDRIVER_OBJECT DriverObject)
{
debug_log("Leaving World\n");
IrpUnloadHandler(DriverObject);
ExUnregisterCallback(g_PowerCallbackRegistration);
delete sleep_cb;
}
void throw_test()
@ -287,7 +210,29 @@ extern "C" NTSTATUS DriverEntry(const PDRIVER_OBJECT DriverObject, PUNICODE_STRI
debug_log("Final i = %i\n", i);
throw_test();
register_sleep_callback();
try
{
sleep_cb = new sleep_callback([](const sleep_callback::type type)
{
if (type == sleep_callback::type::sleep)
{
debug_log("Going to sleep!");
}
if (type == sleep_callback::type::wakeup)
{
debug_log("Waking up!");
}
});
sleep_cb->dispatcher(sleep_callback::type::sleep);
sleep_cb->dispatcher(sleep_callback::type::wakeup);
}
catch (...)
{
debug_log("Failed to register sleep callback");
}
return create_io_device(DriverObject);

38
src/driver/exception.hpp Normal file
View File

@ -0,0 +1,38 @@
#pragma once
#include "type_traits.hpp"
namespace std
{
class exception
{
public:
exception& operator=(const exception& obj) noexcept = default;
exception& operator=(exception&& obj) noexcept = default;
virtual ~exception() = default;
virtual const char* what() const noexcept = 0;
};
class runtime_error : public exception
{
public:
runtime_error(const char* message)
: message_(message)
{
}
runtime_error(const runtime_error& obj) noexcept = default;
runtime_error& operator=(const runtime_error& obj) noexcept = default;
runtime_error(runtime_error&& obj) noexcept = default;
runtime_error& operator=(runtime_error&& obj) noexcept = default;
const char* what() const noexcept override
{
return message_;
}
private:
const char* message_{};
};
}

55
src/driver/finally.hpp Normal file
View File

@ -0,0 +1,55 @@
#pragma once
#include "type_traits.hpp"
namespace utils
{
/*
* Copied from here: https://github.com/microsoft/GSL/blob/e0880931ae5885eb988d1a8a57acf8bc2b8dacda/include/gsl/util#L57
*/
template <class F>
class final_action
{
public:
/*static_assert(!std::is_reference<F>::value && !std::is_const<F>::value &&
!std::is_volatile<F>::value,
"Final_action should store its callable by value");*/
explicit final_action(F f) noexcept : f_(std::move(f))
{
}
final_action(final_action&& other) noexcept
: f_(std::move(other.f_)), invoke_(other.invoke_)
{
other.invoke_ = false;
}
final_action(const final_action&) = delete;
final_action& operator=(const final_action&) = delete;
final_action& operator=(final_action&&) = delete;
~final_action() noexcept
{
if (invoke_) f_();
}
// Added by momo5502
void cancel()
{
invoke_ = false;
}
private:
F f_;
bool invoke_{true};
};
template <class F>
final_action<typename std::remove_cv<typename std::remove_reference<F>::type>::type>
finally(F&& f) noexcept
{
return final_action<typename std::remove_cv<typename std::remove_reference<F>::type>::type>(
std::forward<F>(f));
}
}

View File

@ -34,6 +34,8 @@ namespace std
std::unique_ptr<fn_interface> fn{};
public:
function() = default;
template <typename T>
function(T&& t)
: fn(new fn_implementation<T>(std::forward<T>(t)))
@ -49,7 +51,12 @@ namespace std
Result operator()(Args ... args) const
{
return (*fn)(std::forward<Args>(args)...);
return (*this->fn)(std::forward<Args>(args)...);
}
operator bool() const
{
return this->fn;
}
};
}

View File

@ -1,2 +1,71 @@
#include "std_include.hpp"
#include "sleep_callback.hpp"
#include "exception.hpp"
#include "finally.hpp"
sleep_callback::sleep_callback(callback_function&& callback)
: callback_(std::move(callback))
{
PCALLBACK_OBJECT callback_object{};
UNICODE_STRING callback_name = RTL_CONSTANT_STRING(L"\\Callback\\PowerState");
OBJECT_ATTRIBUTES object_attributes = RTL_CONSTANT_OBJECT_ATTRIBUTES(
&callback_name, OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE);
const auto _ = utils::finally([&callback_object]()
{
ObDereferenceObject(callback_object);
});
const auto status = ExCreateCallback(&callback_object, &object_attributes, FALSE, TRUE);
if (!NT_SUCCESS(status))
{
throw std::runtime_error("Unable to create callback object");
}
this->handle_ = ExRegisterCallback(callback_object, sleep_callback::static_callback, this);
if (!this->handle_)
{
throw std::runtime_error("Unable to register callback");
}
}
sleep_callback::~sleep_callback()
{
if (this->handle_)
{
ExUnregisterCallback(this->handle_);
}
}
void sleep_callback::dispatcher(const type type) const
{
try
{
if (this->callback_)
{
this->callback_(type);
}
}
catch (...)
{
}
}
_Function_class_(CALLBACK_FUNCTION)
void sleep_callback::static_callback(void* context, void* argument1, void* argument2)
{
if (!context || argument1 != reinterpret_cast<PVOID>(PO_CB_SYSTEM_STATE_LOCK))
{
return;
}
auto type = type::sleep;
if(ARGUMENT_PRESENT(argument2))
{
type = type::wakeup;
}
static_cast<sleep_callback*>(context)->dispatcher(type);
}

View File

@ -1,2 +1,33 @@
#pragma once
#include "functional.hpp"
class sleep_callback
{
public:
enum class type
{
sleep,
wakeup,
};
using callback_function = std::function<void(type)>;
sleep_callback(callback_function&& callback);
~sleep_callback();
sleep_callback(sleep_callback&& obj) noexcept = delete;
sleep_callback& operator=(sleep_callback&& obj) noexcept = delete;
sleep_callback(const sleep_callback& obj) = delete;
sleep_callback& operator=(const sleep_callback& obj) = delete;
void dispatcher(type type) const;
private:
void* handle_{nullptr};
callback_function callback_{};
_Function_class_(CALLBACK_FUNCTION)
static void static_callback(void* context, void* argument1, void* argument2);
};

View File

@ -51,4 +51,15 @@ namespace std
// forward an rvalue as an rvalue
return (static_cast<_Ty&&>(_Arg));
}
template< class T > struct remove_cv { typedef T type; };
template< class T > struct remove_cv<const T> { typedef T type; };
template< class T > struct remove_cv<volatile T> { typedef T type; };
template< class T > struct remove_cv<const volatile T> { typedef T type; };
template< class T > struct remove_const { typedef T type; };
template< class T > struct remove_const<const T> { typedef T type; };
template< class T > struct remove_volatile { typedef T type; };
template< class T > struct remove_volatile<volatile T> { typedef T type; };
}

View File

@ -6,6 +6,7 @@ namespace std
template <typename T>
class unique_ptr
{
public:
unique_ptr() = default;
unique_ptr(T* pointer)
@ -63,6 +64,11 @@ namespace std
return *this->pointer_;
}
operator bool() const
{
return this->pointer_;
}
private:
T* pointer_{nullptr};
};