From e3da821ee919db3f8222d4c93f9b34412b1f874a Mon Sep 17 00:00:00 2001 From: momo5502 Date: Sat, 26 Mar 2022 19:50:54 +0100 Subject: [PATCH] More cleanup --- src/driver/driver_main.cpp | 69 +++++++--- src/driver/irp.cpp | 231 +++++++++++++++++----------------- src/driver/irp.hpp | 21 +++- src/driver/sleep_callback.hpp | 1 + 4 files changed, 185 insertions(+), 137 deletions(-) diff --git a/src/driver/driver_main.cpp b/src/driver/driver_main.cpp index 2f3a8b0..d928811 100644 --- a/src/driver/driver_main.cpp +++ b/src/driver/driver_main.cpp @@ -3,50 +3,79 @@ #include "sleep_callback.hpp" #include "irp.hpp" #include "exception.hpp" -#include "finally.hpp" -sleep_callback* sleep_cb{nullptr}; +#define DOS_DEV_NAME L"\\DosDevices\\HelloDev" +#define DEV_NAME L"\\Device\\HelloDev" -void sleep_notification(const sleep_callback::type type) +class global_driver { - if (type == sleep_callback::type::sleep) +public: + global_driver(const PDRIVER_OBJECT driver_object) + : sleep_callback_([this](const sleep_callback::type type) + { + this->sleep_notification(type); + }) + , irp_(driver_object, DEV_NAME, DOS_DEV_NAME) { - debug_log("Going to sleep!"); + debug_log("Driver started\n"); } - if (type == sleep_callback::type::wakeup) + ~global_driver() { - debug_log("Waking up!"); + debug_log("Unloading driver\n"); } -} + + global_driver(global_driver&&) noexcept = delete; + global_driver& operator=(global_driver&&) noexcept = delete; + + global_driver(const global_driver&) = delete; + global_driver& operator=(const global_driver&) = delete; + + void pre_destroy(const PDRIVER_OBJECT /*driver_object*/) + { + } + +private: + sleep_callback sleep_callback_{}; + irp irp_{}; + + void sleep_notification(const sleep_callback::type type) + { + if (type == sleep_callback::type::sleep) + { + debug_log("Going to sleep!"); + } + + if (type == sleep_callback::type::wakeup) + { + debug_log("Waking up!"); + } + } +}; + +global_driver* global_driver_instance{nullptr}; extern "C" void __cdecl __std_terminate() { KeBugCheckEx(DRIVER_VIOLATION, 14, 0, 0, 0); } -void destroy_sleep_callback() -{ - delete sleep_cb; -} - _Function_class_(DRIVER_UNLOAD) void unload(const PDRIVER_OBJECT driver_object) { - irp::uninitialize(driver_object); - destroy_sleep_callback(); + if (global_driver_instance) + { + global_driver_instance->pre_destroy(driver_object); + delete global_driver_instance; + } } extern "C" NTSTATUS DriverEntry(const PDRIVER_OBJECT driver_object, PUNICODE_STRING /*registry_path*/) { driver_object->DriverUnload = unload; - auto sleep_destructor = utils::finally(&destroy_sleep_callback); - try { - sleep_cb = new sleep_callback(sleep_notification); - irp::initialize(driver_object); - sleep_destructor.cancel(); + global_driver_instance = new global_driver(driver_object); } catch (std::exception& e) { diff --git a/src/driver/irp.cpp b/src/driver/irp.cpp index 7b241f4..82393bf 100644 --- a/src/driver/irp.cpp +++ b/src/driver/irp.cpp @@ -3,140 +3,145 @@ #include "logging.hpp" #include "exception.hpp" -#define DOS_DEV_NAME L"\\DosDevices\\HelloDev" -#define DEV_NAME L"\\Device\\HelloDev" - #define HELLO_DRV_IOCTL CTL_CODE(FILE_DEVICE_UNKNOWN, 0x800, METHOD_NEITHER, FILE_ANY_ACCESS) -namespace irp +namespace { - namespace + UNICODE_STRING get_unicode_string(const wchar_t* string) { - UNICODE_STRING get_unicode_string(const wchar_t* string) - { - UNICODE_STRING unicode_string{}; - RtlInitUnicodeString(&unicode_string, string); - return unicode_string; - } - - UNICODE_STRING get_device_name() - { - return get_unicode_string(DEV_NAME); - } - - UNICODE_STRING get_dos_device_name() - { - return get_unicode_string(DOS_DEV_NAME); - } - - _Function_class_(DRIVER_DISPATCH) NTSTATUS not_supported_handler( - PDEVICE_OBJECT /*device_object*/, const PIRP irp) - { - PAGED_CODE() - - irp->IoStatus.Information = 0; - irp->IoStatus.Status = STATUS_NOT_SUPPORTED; - - IoCompleteRequest(irp, IO_NO_INCREMENT); - - return STATUS_NOT_SUPPORTED; - } - - _Function_class_(DRIVER_DISPATCH) NTSTATUS success_handler(PDEVICE_OBJECT /*device_object*/, const PIRP irp) - { - PAGED_CODE() - - irp->IoStatus.Information = 0; - irp->IoStatus.Status = STATUS_SUCCESS; - - IoCompleteRequest(irp, IO_NO_INCREMENT); - - return STATUS_SUCCESS; - } - - _Function_class_(DRIVER_DISPATCH) NTSTATUS io_ctl_handler( - PDEVICE_OBJECT /*device_object*/, const PIRP irp) - { - PAGED_CODE() - - irp->IoStatus.Information = 0; - irp->IoStatus.Status = STATUS_NOT_SUPPORTED; - - const auto irp_sp = IoGetCurrentIrpStackLocation(irp); - - if (irp_sp) - { - const auto ioctr_code = irp_sp->Parameters.DeviceIoControl.IoControlCode; - - switch (ioctr_code) - { - case HELLO_DRV_IOCTL: - debug_log("[< HelloDriver >] Hello from the Driver!\n"); - break; - default: - debug_log("[-] Invalid IOCTL Code: 0x%X\n", ioctr_code); - irp->IoStatus.Status = STATUS_INVALID_DEVICE_REQUEST; - break; - } - } - - IoCompleteRequest(irp, IO_NO_INCREMENT); - - return irp->IoStatus.Status; - } + UNICODE_STRING unicode_string{}; + RtlInitUnicodeString(&unicode_string, string); + return unicode_string; } - _Function_class_(DRIVER_DISPATCH) void uninitialize(const PDRIVER_OBJECT driver_object) + _Function_class_(DRIVER_DISPATCH) NTSTATUS not_supported_handler(PDEVICE_OBJECT /*device_object*/, const PIRP irp) { PAGED_CODE() - auto dos_device_name = get_dos_device_name(); + irp->IoStatus.Information = 0; + irp->IoStatus.Status = STATUS_NOT_SUPPORTED; - IoDeleteSymbolicLink(&dos_device_name); - IoDeleteDevice(driver_object->DeviceObject); + IoCompleteRequest(irp, IO_NO_INCREMENT); + + return STATUS_NOT_SUPPORTED; } - void initialize(const PDRIVER_OBJECT driver_object) + _Function_class_(DRIVER_DISPATCH) NTSTATUS success_handler(PDEVICE_OBJECT /*device_object*/, const PIRP irp) { PAGED_CODE() - auto device_name = get_device_name(); - auto dos_device_name = get_dos_device_name(); + irp->IoStatus.Information = 0; + irp->IoStatus.Status = STATUS_SUCCESS; - PDEVICE_OBJECT device_object{}; - auto destructor = utils::finally([&device_object]() + IoCompleteRequest(irp, IO_NO_INCREMENT); + + return STATUS_SUCCESS; + } + + _Function_class_(DRIVER_DISPATCH) NTSTATUS io_ctl_handler( + PDEVICE_OBJECT /*device_object*/, const PIRP irp) + { + PAGED_CODE() + + irp->IoStatus.Information = 0; + irp->IoStatus.Status = STATUS_NOT_SUPPORTED; + + const auto irp_sp = IoGetCurrentIrpStackLocation(irp); + + if (irp_sp) { - if (device_object) + const auto ioctr_code = irp_sp->Parameters.DeviceIoControl.IoControlCode; + + switch (ioctr_code) { - IoDeleteDevice(device_object); + case HELLO_DRV_IOCTL: + debug_log("[< HelloDriver >] Hello from the Driver!\n"); + break; + default: + debug_log("[-] Invalid IOCTL Code: 0x%X\n", ioctr_code); + irp->IoStatus.Status = STATUS_INVALID_DEVICE_REQUEST; + break; } - }); - - auto status = IoCreateDevice(driver_object, 0, &device_name, FILE_DEVICE_UNKNOWN, FILE_DEVICE_SECURE_OPEN, - FALSE, &device_object); - if (!NT_SUCCESS(status)) - { - throw std::runtime_error("Unable to create device"); } - for (auto i = 0u; i <= IRP_MJ_MAXIMUM_FUNCTION; i++) - { - driver_object->MajorFunction[i] = not_supported_handler; - } + IoCompleteRequest(irp, IO_NO_INCREMENT); - driver_object->MajorFunction[IRP_MJ_CREATE] = success_handler; - driver_object->MajorFunction[IRP_MJ_CLOSE] = success_handler; - driver_object->MajorFunction[IRP_MJ_DEVICE_CONTROL] = io_ctl_handler; - - device_object->Flags |= DO_DIRECT_IO; - device_object->Flags &= ~DO_DEVICE_INITIALIZING; - - status = IoCreateSymbolicLink(&dos_device_name, &device_name); - if (!NT_SUCCESS(status)) - { - throw std::runtime_error("Unable to create symbolic link"); - } - - destructor.cancel(); + return irp->IoStatus.Status; } } + +irp::irp(const PDRIVER_OBJECT driver_object, const wchar_t* device_name, const wchar_t* dos_device_name) +{ + PAGED_CODE() + + this->device_name_ = get_unicode_string(device_name); + this->dos_device_name_ = get_unicode_string(dos_device_name); + + auto destructor = utils::finally([this]() + { + if (this->device_object_) + { + IoDeleteDevice(this->device_object_); + } + }); + + auto status = IoCreateDevice(driver_object, 0, &this->device_name_, FILE_DEVICE_UNKNOWN, FILE_DEVICE_SECURE_OPEN, + FALSE, &this->device_object_); + if (!NT_SUCCESS(status)) + { + throw std::runtime_error("Unable to create device"); + } + + for (auto i = 0u; i <= IRP_MJ_MAXIMUM_FUNCTION; i++) + { + driver_object->MajorFunction[i] = not_supported_handler; + } + + driver_object->MajorFunction[IRP_MJ_CREATE] = success_handler; + driver_object->MajorFunction[IRP_MJ_CLOSE] = success_handler; + driver_object->MajorFunction[IRP_MJ_DEVICE_CONTROL] = io_ctl_handler; + + this->device_object_->Flags |= DO_DIRECT_IO; + this->device_object_->Flags &= ~DO_DEVICE_INITIALIZING; + + status = IoCreateSymbolicLink(&this->dos_device_name_, &this->device_name_); + if (!NT_SUCCESS(status)) + { + throw std::runtime_error("Unable to create symbolic link"); + } + + destructor.cancel(); +} + +irp::~irp() +{ + PAGED_CODE() + + if (this->device_object_) + { + IoDeleteSymbolicLink(&this->dos_device_name_); + IoDeleteDevice(this->device_object_); + } +} + +irp::irp(irp&& obj) noexcept + : irp() +{ + this->operator=(std::move(obj)); +} + +irp& irp::operator=(irp&& obj) noexcept +{ + if (this != &obj) + { + this->~irp(); + + this->device_name_ = obj.device_name_; + this->dos_device_name_ = obj.dos_device_name_; + this->device_object_ = obj.device_object_; + + obj.device_object_ = nullptr; + } + + return *this; +} diff --git a/src/driver/irp.hpp b/src/driver/irp.hpp index 2598dbc..fdec58c 100644 --- a/src/driver/irp.hpp +++ b/src/driver/irp.hpp @@ -1,7 +1,20 @@ #pragma once -namespace irp +class irp { - void initialize(PDRIVER_OBJECT driver_object); - void uninitialize(PDRIVER_OBJECT driver_object); -} +public: + irp() = default; + irp(PDRIVER_OBJECT driver_object, const wchar_t* device_name, const wchar_t* dos_device_name); + ~irp(); + + irp(irp&& obj) noexcept; + irp& operator=(irp&& obj) noexcept; + + irp(const irp&) = delete; + irp& operator=(const irp&) = delete; + +private: + UNICODE_STRING device_name_{}; + UNICODE_STRING dos_device_name_{}; + PDEVICE_OBJECT device_object_{}; +}; diff --git a/src/driver/sleep_callback.hpp b/src/driver/sleep_callback.hpp index d08ce79..d931b78 100644 --- a/src/driver/sleep_callback.hpp +++ b/src/driver/sleep_callback.hpp @@ -12,6 +12,7 @@ public: using callback_function = std::function; + sleep_callback() = default; sleep_callback(callback_function&& callback); ~sleep_callback();