diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 6dfec147d..439779e41 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -353,6 +353,14 @@ + + + + + + + + diff --git a/src/windows/inc/WslPluginApi.h b/src/windows/inc/WslPluginApi.h index 86372d326..c969c080b 100644 --- a/src/windows/inc/WslPluginApi.h +++ b/src/windows/inc/WslPluginApi.h @@ -26,6 +26,9 @@ extern "C" { #define WSLPLUGINAPI_ENTRYPOINTV1 WSLPluginAPIV1_EntryPoint #define WSL_E_PLUGIN_REQUIRES_UPDATE MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, 0x032A) +// Maximum size for mount points returned by WSLCPluginAPI_MountFolder. This includes the null terminator. +#define WSLC_MOUNTPOINT_LENGTH 256 + #define WSL_PLUGIN_REQUIRE_VERSION(_Major, _Minor, _Revision, Api) \ if (Api->Version.Major < (_Major) || (Api->Version.Major == (_Major) && Api->Version.Minor < (_Minor)) || \ (Api->Version.Major == (_Major) && Api->Version.Minor == (_Minor) && Api->Version.Revision < (_Revision))) \ @@ -85,6 +88,30 @@ struct WslOfflineDistributionInformation LPCWSTR Version; // Distribution version. Introduced in 2.4.4 }; +// Identifies a WSLC session inside the WSLC plugin API. Distinct from WSLSessionId. +typedef DWORD WSLCSessionId; + +// Information about a WSLC session passed to plugin notifications. +struct WSLCSessionInformation +{ + WSLCSessionId SessionId; + LPCWSTR DisplayName; + DWORD ApplicationPid; + HANDLE UserToken; + PSID UserSid; +}; + +// Opaque handle to a WSLC process created via WSLCPluginAPI_CreateProcess. +// Must be released with WSLCPluginAPI_ReleaseProcess. +typedef void* WSLCProcessHandle; + +typedef enum _WSLCProcessFd +{ + WSLCProcessFdStdin = 0, + WSLCProcessFdStdout = 1, + WSLCProcessFdStderr = 2 +} WSLCProcessFd; + // Create plan9 mount between Windows & Linux typedef HRESULT (*WSLPluginAPI_MountFolder)(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name); @@ -92,6 +119,63 @@ typedef HRESULT (*WSLPluginAPI_MountFolder)(WSLSessionId Session, LPCWSTR Window // On success, 'Socket' is connected to stdin & stdout (stderr goes to dmesg) // 'Arguments' is expected to be NULL terminated typedef HRESULT (*WSLPluginAPI_ExecuteBinary)(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket); +// +// WSLC plugin hooks. +// + +// Called when a WSLC session is created. Returning an error prevents the session creation. +typedef HRESULT (*WSLPluginAPI_OnSessionCreated)(const struct WSLCSessionInformation* Session); + +// Called when a WSLC session is about to stop. Errors are ignored. +typedef HRESULT (*WSLPluginAPI_OnSessionStopping)(const struct WSLCSessionInformation* Session); + +// Called when a container starts. Returning an error prevents the container creation. +// 'InspectContainer' is a JSON document that follows the wslc_schema::InspectContainer format. +typedef HRESULT (*WSLPluginAPI_ContainerStarted)(const struct WSLCSessionInformation* Session, LPCSTR InspectContainer); + +// Called when a container is about to stop. 'ContainerId' is the container identifier. Errors are ignored. +typedef HRESULT (*WSLPluginAPI_ContainerStopping)(const struct WSLCSessionInformation* Session, LPCSTR ContainerId); + +// Called when an image is created (either pulled, or imported). Errors are ignored. +// 'InspectImage' is a JSON document that follows the wslc_schema::InspectImage format. +// N.B. This callback is currently only invoked when images are pulled or imported. Images created via load or build are not reported. +typedef HRESULT (*WSLPluginAPI_ImageCreated)(const struct WSLCSessionInformation* Session, LPCSTR InspectImage); + +// Called when an image is deleted. 'ImageId' is the deleted image identifier. Errors are ignored. +typedef HRESULT (*WSLPluginAPI_ImageDeleted)(const struct WSLCSessionInformation* Session, LPCSTR ImageId); + +// +// WSLC plugin API calls. +// + +// Mount a Windows folder into the WSLC session VM. The mount path is returned via 'Mountpoint'. +// 'Mountpoint' must point to a buffer of at least WSLC_MOUNTPOINT_LENGTH chars, including the null terminator. +typedef HRESULT (*WSLCPluginAPI_MountFolder)(WSLCSessionId Session, LPCWSTR WindowsPath, BOOL ReadOnly, LPCWSTR Name, LPSTR Mountpoint); + +// Unmount a folder previously mounted via WSLCPluginAPI_MountFolder. +typedef HRESULT (*WSLCPluginAPI_UnmountFolder)(WSLCSessionId Session, LPCSTR Mountpoint); + +// Create a process in the WSLC session's root namespace. +// 'Arguments' and 'Env' are NULL-terminated arrays. 'Env' may be NULL. +// 'Errno' is optional and receives the errno value if the process creation fails. +// On success, 'Process' receives an opaque handle that must be released with WSLCPluginAPI_ReleaseProcess. +typedef HRESULT (*WSLCPluginAPI_CreateProcess)( + WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno); + +// Get a stdio handle from a WSLC process. The caller takes ownership and must close it with CloseHandle(). +typedef HRESULT (*WSLCPluginAPI_ProcessGetFd)(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle); + +// Get the exit event for a WSLC process. Signaled when the process exits. +// The caller takes ownership and must close it with CloseHandle(). +typedef HRESULT (*WSLCPluginAPI_ProcessGetExitEvent)(WSLCProcessHandle Process, HANDLE* ExitEvent); + +// Get the exit code of a WSLC process. The process must have exited. +typedef HRESULT (*WSLCPluginAPI_ProcessGetExitCode)(WSLCProcessHandle Process, int* ExitCode); + +// Release a WSLC process handle. All outstanding handles obtained via +// WSLCPluginAPI_ProcessGetFd/GetExitEvent must be closed before calling this. +typedef void (*WSLCPluginAPI_ReleaseProcess)(WSLCProcessHandle Process); + // Execute a program in a user distribution // On success, 'Socket' is connected to stdin & stdout (stderr goes to dmesg) // 'Arguments' is expected to be NULL terminated typedef HRESULT (*WSLPluginAPI_ExecuteBinaryInDistribution)(WSLSessionId Session, const GUID* Distribution, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket); @@ -132,6 +216,14 @@ struct WSLPluginHooksV1 WSLPluginAPI_OnDistributionStopping OnDistributionStopping; WSLPluginAPI_OnDistributionRegistered OnDistributionRegistered; // Introduced in 2.1.2 WSLPluginAPI_OnDistributionRegistered OnDistributionUnregistered; // Introduced in 2.1.2 + + // WSLC hooks. Plugins compiled against older headers leave these zero-initialized. + WSLPluginAPI_OnSessionCreated OnSessionCreated; + WSLPluginAPI_OnSessionStopping OnSessionStopping; + WSLPluginAPI_ContainerStarted ContainerStarted; + WSLPluginAPI_ContainerStopping ContainerStopping; + WSLPluginAPI_ImageCreated ImageCreated; + WSLPluginAPI_ImageDeleted ImageDeleted; }; struct WSLPluginAPIV1 @@ -141,10 +233,19 @@ struct WSLPluginAPIV1 WSLPluginAPI_ExecuteBinary ExecuteBinary; WSLPluginAPI_PluginError PluginError; WSLPluginAPI_ExecuteBinaryInDistribution ExecuteBinaryInDistribution; // Introduced in 2.1.2 + + // WSLC API calls. + WSLCPluginAPI_MountFolder WSLCMountFolder; // Introduced in 2.9.0 + WSLCPluginAPI_UnmountFolder WSLCUnmountFolder; // Introduced in 2.9.0 + WSLCPluginAPI_CreateProcess WSLCCreateProcess; // Introduced in 2.9.0 + WSLCPluginAPI_ProcessGetFd WSLCProcessGetFd; // Introduced in 2.9.0 + WSLCPluginAPI_ProcessGetExitEvent WSLCProcessGetExitEvent; // Introduced in 2.9.0 + WSLCPluginAPI_ProcessGetExitCode WSLCProcessGetExitCode; // Introduced in 2.9.0 + WSLCPluginAPI_ReleaseProcess WSLCReleaseProcess; // Introduced in 2.9.0 }; typedef HRESULT (*WSLPluginAPI_EntryPointV1)(const struct WSLPluginAPIV1* Api, struct WSLPluginHooksV1* Hooks); #ifdef __cplusplus } -#endif \ No newline at end of file +#endif diff --git a/src/windows/service/exe/CMakeLists.txt b/src/windows/service/exe/CMakeLists.txt index 7482eacf9..c7b4ebb0d 100644 --- a/src/windows/service/exe/CMakeLists.txt +++ b/src/windows/service/exe/CMakeLists.txt @@ -23,6 +23,7 @@ set(SOURCES HcsVirtualMachine.cpp WSLCSessionManager.cpp WSLCSessionManagerFactory.cpp + WSLCPluginNotifier.cpp main.rc ${CMAKE_CURRENT_BINARY_DIR}/../mc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wsleventschema.rc application.manifest) @@ -53,7 +54,8 @@ set(HEADERS WslCoreVm.h HcsVirtualMachine.h WSLCSessionManager.h - WSLCSessionManagerFactory.h) + WSLCSessionManagerFactory.h + WSLCPluginNotifier.h) add_executable(wslservice ${SOURCES} ${HEADERS}) add_dependencies(wslservice wslserviceidl wslservicemc) diff --git a/src/windows/service/exe/LxssUserSessionFactory.cpp b/src/windows/service/exe/LxssUserSessionFactory.cpp index ae40cec10..33e1fe7dc 100644 --- a/src/windows/service/exe/LxssUserSessionFactory.cpp +++ b/src/windows/service/exe/LxssUserSessionFactory.cpp @@ -30,7 +30,7 @@ srwlock g_sessionLock; std::optional>> g_sessions = std::make_optional>>(); -std::optional g_pluginManager; +extern wsl::windows::service::PluginManager g_pluginManager; extern unique_event g_networkingReady; extern bool g_lxcoreInitialized; @@ -53,9 +53,6 @@ void ClearSessionsAndBlockNewInstancesLockHeld(std::optional>>(); } - - if (!g_pluginManager.has_value()) - { - g_pluginManager.emplace(); - g_pluginManager->LoadPlugins(); - } } else { @@ -236,7 +227,7 @@ std::weak_ptr CreateInstanceForCurrentUser() if (!userSession) { - userSession.reset(new LxssUserSessionImpl(tokenInfo->User.Sid, sessionId, *g_pluginManager)); + userSession.reset(new LxssUserSessionImpl(tokenInfo->User.Sid, sessionId, g_pluginManager)); g_sessions->emplace_back(userSession); } } diff --git a/src/windows/service/exe/PluginManager.cpp b/src/windows/service/exe/PluginManager.cpp index e4d23f226..38c1bfbdf 100644 --- a/src/windows/service/exe/PluginManager.cpp +++ b/src/windows/service/exe/PluginManager.cpp @@ -17,6 +17,7 @@ Module Name: #include "PluginManager.h" #include "WslPluginApi.h" #include "LxssUserSessionFactory.h" +#include "WSLCSessionManager.h" using wsl::windows::common::Context; using wsl::windows::common::ExecutionContext; @@ -99,7 +100,264 @@ try CATCH_RETURN(); } -static constexpr WSLPluginAPIV1 ApiV1 = {Version, &MountFolder, &ExecuteBinary, &PluginError, &ExecuteBinaryInDistribution}; +namespace { + +// Opaque wrapper around IWSLCProcess, handed out as WSLCProcessHandle to plugins. +struct WslcProcessWrapper +{ + wil::com_ptr Process; +}; + +wil::com_ptr ResolveWslcSession(WSLCSessionId Session) +{ + auto* mgr = wsl::windows::service::wslc::WSLCSessionManagerImpl::Instance(); + THROW_HR_IF(RPC_E_DISCONNECTED, mgr == nullptr); + + return mgr->FindSession(static_cast(Session)); +} + +} // namespace + +extern "C" { + +HRESULT WSLCMountFolder(WSLCSessionId Session, LPCWSTR WindowsPath, BOOL ReadOnly, LPCWSTR Name, LPSTR Mountpoint) +try +{ + RETURN_HR_IF(E_POINTER, WindowsPath == nullptr || Name == nullptr || Mountpoint == nullptr); + auto nameLength = wcslen(Name); + + RETURN_HR_IF_MSG( + E_INVALIDARG, + nameLength == 0 || + !std::ranges::all_of(Name, Name + nameLength, [&](auto c) { return c == '-' || c == '_' || iswalnum(c); }), + "Invalid mount name: %ls", + Name); + + auto session = ResolveWslcSession(Session); + + // Mount the folder under /mnt/wsl-plugin/. Convert Name to UTF-8 for the Linux path. + const auto linuxPath = std::format("/mnt/wsl-plugin/{}", Name); + + THROW_HR_IF_MSG(E_INVALIDARG, linuxPath.length() >= WSLC_MOUNTPOINT_LENGTH, "Mountpoint too long: %hs", linuxPath.c_str()); + + auto result = session->MountWindowsFolder(WindowsPath, linuxPath.c_str(), ReadOnly); + + WSL_LOG( + "WslcPluginMountFolderCall", + TraceLoggingValue(Session, "SessionId"), + TraceLoggingValue(WindowsPath, "WindowsPath"), + TraceLoggingValue(linuxPath.c_str(), "LinuxPath"), + TraceLoggingValue(ReadOnly, "ReadOnly"), + TraceLoggingValue(Name, "Name"), + TraceLoggingValue(result, "Result")); + + if (SUCCEEDED(result)) + { + THROW_HR_IF(E_UNEXPECTED, strcpy_s(Mountpoint, WSLC_MOUNTPOINT_LENGTH, linuxPath.c_str()) != 0); + } + + return result; +} +CATCH_RETURN(); + +HRESULT WSLCUnmountFolder(WSLCSessionId Session, LPCSTR Mountpoint) +try +{ + RETURN_HR_IF(E_POINTER, Mountpoint == nullptr); + + auto session = ResolveWslcSession(Session); + + auto result = session->UnmountWindowsFolder(Mountpoint); + + WSL_LOG( + "WslcPluginUnmountFolderCall", + TraceLoggingValue(Session, "SessionId"), + TraceLoggingValue(Mountpoint, "Mountpoint"), + TraceLoggingValue(result, "Result")); + + return result; +} +CATCH_RETURN(); + +HRESULT WSLCCreateProcess(WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno) +try +{ + RETURN_HR_IF(E_POINTER, Executable == nullptr || Process == nullptr); + + *Process = nullptr; + if (Errno != nullptr) + { + *Errno = 0; + } + + auto session = ResolveWslcSession(Session); + + // Count NULL-terminated arrays. + auto countArray = [](LPCSTR* arr) -> ULONG { + if (arr == nullptr) + { + return 0; + } + ULONG count = 0; + while (arr[count] != nullptr) + { + ++count; + } + return count; + }; + + WSLCProcessOptions options{}; + options.CommandLine.Values = Arguments; + options.CommandLine.Count = countArray(Arguments); + options.Environment.Values = Env; + options.Environment.Count = countArray(Env); + options.Flags = WSLCProcessFlagsStdin; + + wil::com_ptr process; + int errnoValue = 0; + auto result = session->CreateRootNamespaceProcess(Executable, &options, &process, &errnoValue); + + if (Errno != nullptr) + { + *Errno = errnoValue; + } + + if (FAILED(result)) + { + WSL_LOG( + "WslcPluginCreateProcessCall", + TraceLoggingValue(Session, "SessionId"), + TraceLoggingValue(Executable, "Executable"), + TraceLoggingValue(result, "Result"), + TraceLoggingValue(errnoValue, "Errno")); + return result; + } + + auto wrapper = std::make_unique(); + wrapper->Process = std::move(process); + *Process = wrapper.release(); + + WSL_LOG( + "WslcPluginCreateProcessCall", + TraceLoggingValue(Session, "SessionId"), + TraceLoggingValue(Executable, "Executable"), + TraceLoggingValue(*Process, "Process"), + TraceLoggingValue(S_OK, "Result")); + + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLCProcessGetFd(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle) +try +{ + RETURN_HR_IF(E_POINTER, Process == nullptr || Handle == nullptr); + + *Handle = nullptr; + + auto* wrapper = static_cast(Process); + + WSLCFD wslcFd{}; + switch (Fd) + { + case WSLCProcessFdStdin: + wslcFd = WSLCFDStdin; + break; + case WSLCProcessFdStdout: + wslcFd = WSLCFDStdout; + break; + case WSLCProcessFdStderr: + wslcFd = WSLCFDStderr; + break; + default: + WSL_LOG( + "WslcPluginProcessGetFd", TraceLoggingValue(static_cast(Fd), "Fd"), TraceLoggingValue(E_INVALIDARG, "Result")); + return E_INVALIDARG; + } + + WSLCHandle handle{}; + auto result = wrapper->Process->GetStdHandle(wslcFd, &handle); + + WSL_LOG( + "WslcPluginProcessGetFd", + TraceLoggingValue(static_cast(Fd), "Fd"), + TraceLoggingValue(handle.Handle.Socket, "Handle"), + TraceLoggingValue(result, "Result")); + + RETURN_IF_FAILED(result); + WI_ASSERT(handle.Type == WSLCHandleTypeSocket); + + *Handle = handle.Handle.Socket; + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLCProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent) +try +{ + RETURN_HR_IF(E_POINTER, Process == nullptr || ExitEvent == nullptr); + + *ExitEvent = nullptr; + + auto* wrapper = static_cast(Process); + auto result = wrapper->Process->GetExitEvent(ExitEvent); + + WSL_LOG("WslcPluginProcessGetExitEvent", TraceLoggingValue(*ExitEvent, "ExitEvent"), TraceLoggingValue(result, "Result")); + + return result; +} +CATCH_RETURN(); + +HRESULT WSLCProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode) +try +{ + RETURN_HR_IF(E_POINTER, Process == nullptr || ExitCode == nullptr); + + *ExitCode = -1; + auto* wrapper = static_cast(Process); + + WSLCProcessState state{}; + auto result = wrapper->Process->GetState(&state, ExitCode); + + if (SUCCEEDED(result) && state != WslcProcessStateExited && state != WslcProcessStateSignalled) + { + result = HRESULT_FROM_WIN32(ERROR_INVALID_STATE); + } + + WSL_LOG( + "WslcPluginProcessGetExitCode", + TraceLoggingValue(*ExitCode, "ExitCode"), + TraceLoggingValue(static_cast(state), "State"), + TraceLoggingValue(result, "Result")); + + return result; +} +CATCH_RETURN(); + +void WSLCReleaseProcess(WSLCProcessHandle Process) +{ + if (Process != nullptr) + { + WSL_LOG("WslcPluginReleaseProcess", TraceLoggingValue(Process, "Process")); + delete static_cast(Process); + } +} + +} // extern "C" + +static constexpr WSLPluginAPIV1 ApiV1 = { + Version, + &MountFolder, + &ExecuteBinary, + &PluginError, + &ExecuteBinaryInDistribution, + &WSLCMountFolder, + &WSLCUnmountFolder, + &WSLCCreateProcess, + &WSLCProcessGetFd, + &WSLCProcessGetExitEvent, + &WSLCProcessGetExitCode, + &WSLCReleaseProcess}; void PluginManager::LoadPlugins() { @@ -181,7 +439,7 @@ void PluginManager::OnVmStarted(const WSLSessionInformation* Session, const WSLV WSL_LOG( "PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); - ThrowIfPluginError(e.hooks.OnVMStarted(Session, Settings), Session->SessionId, e.name.c_str()); + ThrowIfPluginError(e.hooks.OnVMStarted(Session, Settings), e.name.c_str()); } } } @@ -217,7 +475,7 @@ void PluginManager::OnDistributionStarted(const WSLSessionInformation* Session, TraceLoggingValue(Session->UserSid, "Sid"), TraceLoggingValue(Distribution->Id, "DistributionId")); - ThrowIfPluginError(e.hooks.OnDistributionStarted(Session, Distribution), Session->SessionId, e.name.c_str()); + ThrowIfPluginError(e.hooks.OnDistributionStarted(Session, Distribution), e.name.c_str()); } } } @@ -282,7 +540,7 @@ void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Sess } } -void PluginManager::ThrowIfPluginError(HRESULT Result, WSLSessionId Session, LPCWSTR Plugin) +void PluginManager::ThrowIfPluginError(HRESULT Result, LPCWSTR Plugin) { const auto message = std::move(g_pluginErrorMessage); g_pluginErrorMessage.reset(); // std::move() doesn't clear the previous std::optional @@ -322,3 +580,130 @@ void PluginManager::ThrowIfFatalPluginError() const THROW_HR_WITH_USER_ERROR(m_pluginError->error, wsl::shared::Localization::MessageFatalPluginError(m_pluginError->plugin)); } } + +void PluginManager::OnWslcSessionCreated(const WSLCSessionInformation* Session) +{ + ExecutionContext context(Context::Plugin); + + for (const auto& e : m_plugins) + { + if (e.hooks.OnSessionCreated != nullptr) + { + auto result = e.hooks.OnSessionCreated(Session); + WSL_LOG( + "PluginOnWslcSessionCreatedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(Session->DisplayName, "DisplayName"), + TraceLoggingValue(result, "Result")); + + ThrowIfPluginError(result, e.name.c_str()); + } + } +} + +void PluginManager::OnWslcSessionStopping(const WSLCSessionInformation* Session) const +{ + ExecutionContext context(Context::Plugin); + + for (const auto& e : m_plugins) + { + if (e.hooks.OnSessionStopping != nullptr) + { + const auto result = e.hooks.OnSessionStopping(Session); + WSL_LOG( + "PluginOnWslcSessionStoppingCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(result, "Result")); + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + } + } +} + +HRESULT PluginManager::OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) const +try +{ + ExecutionContext context(Context::Plugin); + + for (const auto& e : m_plugins) + { + if (e.hooks.ContainerStarted != nullptr) + { + // Failure here aborts the container creation. Surface the first error. + const auto result = e.hooks.ContainerStarted(Session, InspectJson); + WSL_LOG( + "PluginOnWslcContainerStartedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(result, "Result")); + + ThrowIfPluginError(result, e.name.c_str()); + } + } + return S_OK; +} +CATCH_RETURN() + +void PluginManager::OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) const +{ + ExecutionContext context(Context::Plugin); + + for (const auto& e : m_plugins) + { + if (e.hooks.ContainerStopping != nullptr) + { + + const auto result = e.hooks.ContainerStopping(Session, ContainerId); + WSL_LOG( + "PluginOnWslcContainerStoppingCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(ContainerId, "ContainerId"), + TraceLoggingValue(result, "Result")); + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + } + } +} + +void PluginManager::OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) const +{ + ExecutionContext context(Context::Plugin); + + for (const auto& e : m_plugins) + { + if (e.hooks.ImageCreated != nullptr) + { + const auto result = e.hooks.ImageCreated(Session, InspectJson); + WSL_LOG( + "PluginOnWslcImageCreatedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(result, "Result")); + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + } + } +} + +void PluginManager::OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) const +{ + ExecutionContext context(Context::Plugin); + + for (const auto& e : m_plugins) + { + if (e.hooks.ImageDeleted != nullptr) + { + const auto result = e.hooks.ImageDeleted(Session, ImageId); + WSL_LOG( + "PluginOnWslcImageDeletedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(ImageId, "ImageId"), + TraceLoggingValue(result, "Result")); + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + } + } +} diff --git a/src/windows/service/exe/PluginManager.h b/src/windows/service/exe/PluginManager.h index 746a11fa4..a99a3e332 100644 --- a/src/windows/service/exe/PluginManager.h +++ b/src/windows/service/exe/PluginManager.h @@ -43,11 +43,21 @@ class PluginManager void OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* distro) const; void OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const; void OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const; + + // WSLC notifications. Returning failure from OnSessionCreated/OnContainerStarted causes the + // corresponding operation to be aborted. Other notifications log errors and continue. + void OnWslcSessionCreated(const WSLCSessionInformation* Session); + void OnWslcSessionStopping(const WSLCSessionInformation* Session) const; + HRESULT OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) const; + void OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) const; + void OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) const; + void OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) const; + void ThrowIfFatalPluginError() const; private: void LoadPlugin(LPCWSTR Name, LPCWSTR Path); - static void ThrowIfPluginError(HRESULT Result, WSLSessionId session, LPCWSTR Plugin); + static void ThrowIfPluginError(HRESULT Result, LPCWSTR Plugin); struct LoadedPlugin { @@ -60,4 +70,4 @@ class PluginManager std::optional m_pluginError; }; -} // namespace wsl::windows::service \ No newline at end of file +} // namespace wsl::windows::service diff --git a/src/windows/service/exe/ServiceMain.cpp b/src/windows/service/exe/ServiceMain.cpp index 059f7d599..9cad83bf0 100644 --- a/src/windows/service/exe/ServiceMain.cpp +++ b/src/windows/service/exe/ServiceMain.cpp @@ -29,6 +29,8 @@ using namespace wsl::windows::policies; bool g_lxcoreInitialized{false}; wil::unique_event g_networkingReady{wil::EventOptions::ManualReset}; +wsl::windows::service::PluginManager g_pluginManager; + // Declare the LxssUserSession COM class. CoCreatableClassWrlCreatorMapInclude(LxssUserSession); @@ -178,6 +180,9 @@ try WSADATA Data; THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &Data)); + // Load plugins. + g_pluginManager.LoadPlugins(); + // Check if WSL is disabled via policy and set up a registry watcher to watch for changes. // // N.B. The registry watcher must be created before checking the policy to avoid missing notifications. diff --git a/src/windows/service/exe/WSLCPluginNotifier.cpp b/src/windows/service/exe/WSLCPluginNotifier.cpp new file mode 100644 index 000000000..7daab35bb --- /dev/null +++ b/src/windows/service/exe/WSLCPluginNotifier.cpp @@ -0,0 +1,66 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. + +#include "precomp.h" +#include "WSLCPluginNotifier.h" + +using wsl::windows::common::COMServiceExecutionContext; +using wsl::windows::service::wslc::WSLCPluginNotifier; + +WSLCPluginNotifier::WSLCPluginNotifier( + wsl::windows::service::PluginManager& Plugins, + ULONG SessionId, + DWORD CreatorPid, + std::wstring DisplayName, + wil::shared_handle UserToken, + std::vector&& UserSid) : + m_plugins(Plugins), m_displayName(std::move(DisplayName)), m_userToken(std::move(UserToken)), m_userSid(std::move(UserSid)) +{ + m_sessionInfo.SessionId = static_cast(SessionId); + m_sessionInfo.DisplayName = m_displayName.c_str(); + m_sessionInfo.ApplicationPid = CreatorPid; + m_sessionInfo.UserToken = m_userToken.get(); + m_sessionInfo.UserSid = m_userSid.empty() ? nullptr : reinterpret_cast(m_userSid.data()); +} + +HRESULT WSLCPluginNotifier::OnContainerStarted(LPCSTR InspectJson) +try +{ + COMServiceExecutionContext context; + + RETURN_HR_IF(E_POINTER, InspectJson == nullptr); + return m_plugins.OnWslcContainerStarted(&m_sessionInfo, InspectJson); +} +CATCH_RETURN(); + +HRESULT WSLCPluginNotifier::OnContainerStopping(LPCSTR ContainerId) +try +{ + COMServiceExecutionContext context; + + RETURN_HR_IF(E_POINTER, ContainerId == nullptr); + m_plugins.OnWslcContainerStopping(&m_sessionInfo, ContainerId); + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLCPluginNotifier::OnImageCreated(LPCSTR InspectJson) +try +{ + COMServiceExecutionContext context; + + RETURN_HR_IF(E_POINTER, InspectJson == nullptr); + m_plugins.OnWslcImageCreated(&m_sessionInfo, InspectJson); + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLCPluginNotifier::OnImageDeleted(LPCSTR ImageId) +try +{ + COMServiceExecutionContext context; + + RETURN_HR_IF(E_POINTER, ImageId == nullptr); + m_plugins.OnWslcImageDeleted(&m_sessionInfo, ImageId); + return S_OK; +} +CATCH_RETURN(); diff --git a/src/windows/service/exe/WSLCPluginNotifier.h b/src/windows/service/exe/WSLCPluginNotifier.h new file mode 100644 index 000000000..b8ba4e11e --- /dev/null +++ b/src/windows/service/exe/WSLCPluginNotifier.h @@ -0,0 +1,46 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. + +#pragma once + +#include "wslc.h" +#include "PluginManager.h" +#include +#include + +namespace wsl::windows::service::wslc { + +// +// WSLCPluginNotifier - SYSTEM service implementation of IWSLCPluginNotifier. +// Lives in the SYSTEM service and is passed (via COM marshalling) as a top-level +// parameter to the per-user WSLC session process via IWSLCSessionFactory::CreateSession. +// The per-user process invokes the On* methods, which dispatch to PluginManager. +// +class DECLSPEC_UUID("E29B0F1A-4E18-4F09-83A2-2D6B1B9F8C4D") WSLCPluginNotifier + : public Microsoft::WRL::RuntimeClass, IWSLCPluginNotifier, IFastRundown> +{ +public: + NON_COPYABLE(WSLCPluginNotifier); + NON_MOVABLE(WSLCPluginNotifier); + + WSLCPluginNotifier( + wsl::windows::service::PluginManager& Plugins, + ULONG SessionId, + DWORD CreatorPid, + std::wstring DisplayName, + wil::shared_handle UserToken, + std::vector&& UserSid); + + IFACEMETHOD(OnContainerStarted)(_In_ LPCSTR InspectJson) override; + IFACEMETHOD(OnContainerStopping)(_In_ LPCSTR ContainerId) override; + IFACEMETHOD(OnImageCreated)(_In_ LPCSTR InspectJson) override; + IFACEMETHOD(OnImageDeleted)(_In_ LPCSTR ImageId) override; + +private: + wsl::windows::service::PluginManager& m_plugins; + std::wstring m_displayName; + wil::shared_handle m_userToken; + std::vector m_userSid; + WSLCSessionInformation m_sessionInfo{}; +}; + +} // namespace wsl::windows::service::wslc diff --git a/src/windows/service/exe/WSLCSessionManager.cpp b/src/windows/service/exe/WSLCSessionManager.cpp index 318626cbe..fc3543882 100644 --- a/src/windows/service/exe/WSLCSessionManager.cpp +++ b/src/windows/service/exe/WSLCSessionManager.cpp @@ -31,17 +31,26 @@ Module Name: #include "HcsVirtualMachine.h" #include "WSLCUserSettings.h" #include "WSLCSessionDefaults.h" +#include "WSLCPluginNotifier.h" +#include "PluginManager.h" +#include "ExecutionContext.h" #include "wslutil.h" #include "filesystem.hpp" +extern wsl::windows::service::PluginManager g_pluginManager; + +using wsl::windows::common::COMServiceExecutionContext; using wsl::windows::service::wslc::CallingProcessTokenInfo; using wsl::windows::service::wslc::HcsVirtualMachine; +using wsl::windows::service::wslc::WSLCPluginNotifier; using wsl::windows::service::wslc::WSLCSessionManagerImpl; namespace wslutil = wsl::windows::common::wslutil; namespace settings = wsl::windows::wslc::settings; namespace { +std::atomic g_managerInstance{nullptr}; + // Session settings built server-side from the caller's settings.yaml. struct SessionSettings { @@ -114,8 +123,15 @@ struct SessionSettings } // namespace +WSLCSessionManagerImpl::WSLCSessionManagerImpl() +{ + g_managerInstance.store(this); +} + WSLCSessionManagerImpl::~WSLCSessionManagerImpl() { + g_managerInstance.store(nullptr); + // Terminate all sessions on shutdown. // Call Terminate() directly rather than going through ForEachSession(), // which would needlessly resolve weak references and call GetState(). @@ -123,10 +139,30 @@ WSLCSessionManagerImpl::~WSLCSessionManagerImpl() std::lock_guard lock(m_wslcSessionsLock); for (auto& entry : m_sessions) { + NotifySessionStoppingLockHeld(entry); LOG_IF_FAILED(entry.Ref->Terminate()); } } +void WSLCSessionManagerImpl::NotifySessionStoppingLockHeld(SessionEntry& entry) noexcept +try +{ + if (entry.StoppingNotified) + { + return; + } + + entry.StoppingNotified = true; + WSLCSessionInformation info{}; + info.SessionId = static_cast(entry.SessionId); + info.DisplayName = entry.DisplayName.c_str(); + info.ApplicationPid = entry.CreatorPid; + info.UserToken = entry.UserToken.get(); + info.UserSid = entry.UserSid.data(); + g_pluginManager.OnWslcSessionStopping(&info); +} +CATCH_LOG() + void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, WSLCSessionFlags Flags, IWSLCSession** WslcSession) { auto tokenInfo = GetCallingProcessTokenInfo(); @@ -144,7 +180,7 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, { THROW_HR_IF(WSLC_E_INVALID_SESSION_NAME, Settings->DisplayName == nullptr || wcslen(Settings->DisplayName) == 0); THROW_HR_IF(E_INVALIDARG, Settings->StoragePath != nullptr && wcslen(Settings->StoragePath) == 0); - THROW_HR_IF(WSLC_E_INVALID_SESSION_NAME, wcslen(Settings->DisplayName) >= std::size(WSLCSessionInformation{}.DisplayName)); + THROW_HR_IF(WSLC_E_INVALID_SESSION_NAME, wcslen(Settings->DisplayName) >= std::size(WSLCSessionListEntry{}.DisplayName)); THROW_HR_IF_MSG( E_INVALIDARG, WI_IsAnyFlagSet(Settings->StorageFlags, ~WSLCSessionStorageFlagsValid), @@ -208,6 +244,23 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, const auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); + // Capture a duplicated user token + raw SID so PluginManager can build + // WSLCSessionInformation later (e.g. on shutdown) without re-impersonating. + // The token is shared between the SessionEntry and the WSLCPluginNotifier. + wil::unique_handle dupToken; + THROW_IF_WIN32_BOOL_FALSE(DuplicateTokenEx( + userToken.get(), TOKEN_QUERY | TOKEN_DUPLICATE, nullptr, SecurityImpersonation, TokenImpersonation, &dupToken)); + wil::shared_handle sharedToken{dupToken.release()}; + + const DWORD sidLen = GetLengthSid(tokenInfo.TokenInfo->User.Sid); + std::vector storedSid(sidLen); + THROW_IF_WIN32_BOOL_FALSE(CopySid(sidLen, storedSid.data(), tokenInfo.TokenInfo->User.Sid)); + + // Build the plugin notifier service-side. Lifetime tracked via the SessionEntry. + Microsoft::WRL::ComPtr notifier; + notifier = wil::MakeOrThrow( + g_pluginManager, sessionId, creatorPid, std::wstring(resolvedDisplayName), wil::shared_handle(sharedToken), std::vector(storedSid)); + // Create the VM in the SYSTEM service (privileged). auto vm = Microsoft::WRL::Make(Settings); @@ -215,14 +268,14 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, auto factory = wslutil::CreateComServerAsUser(__uuidof(WSLCSessionFactory), userToken.get()); AddSessionProcessToJobObject(factory.get()); - // Create the session via the factory. - const auto sessionSettings = CreateSessionSettings(sessionId, creatorPid, Settings, resolvedDisplayName.c_str()); + auto sessionSettings = CreateSessionSettings(sessionId, creatorPid, Settings, resolvedDisplayName.c_str()); wil::com_ptr session; wil::com_ptr serviceRef; - THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), &session, &serviceRef)); + THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), notifier.Get(), &session, &serviceRef)); // Track the session via its service ref, along with metadata and security info. - m_sessions.push_back({std::move(serviceRef), sessionId, creatorPid, resolvedDisplayName, std::move(tokenInfo)}); + m_sessions.push_back(SessionEntry{ + std::move(serviceRef), sessionId, creatorPid, resolvedDisplayName, std::move(tokenInfo), notifier, false, sharedToken, std::move(storedSid)}); // For persistent sessions, also hold a strong reference to keep them alive. const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent); @@ -231,6 +284,33 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, m_persistentSessions.emplace_back(sessionId, session); } + // Notify plugins that the session was created. A failure here aborts session creation. + try + { + auto& entry = m_sessions.back(); + WSLCSessionInformation info{}; + info.SessionId = static_cast(entry.SessionId); + info.DisplayName = entry.DisplayName.c_str(); + info.ApplicationPid = entry.CreatorPid; + info.UserToken = entry.UserToken.get(); + info.UserSid = entry.UserSid.data(); + g_pluginManager.OnWslcSessionCreated(&info); + } + catch (...) + { + const auto error = wil::ResultFromCaughtException(); + + // Plugin rejected the session: tear it down before propagating. + m_sessions.back().StoppingNotified = true; // Don't fire stopping for a session that never started successfully. + LOG_IF_FAILED(m_sessions.back().Ref->Terminate()); + m_sessions.pop_back(); + + auto remove = std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return e.first == sessionId; }); + m_persistentSessions.erase(remove.begin(), remove.end()); + + THROW_HR(error); + } + *WslcSession = session.detach(); }); @@ -297,9 +377,9 @@ void WSLCSessionManagerImpl::OpenSessionByName(LPCWSTR DisplayName, IWSLCSession THROW_IF_FAILED_MSG(result.value_or(HRESULT_FROM_WIN32(ERROR_NOT_FOUND)), "Session '%ls' not found", DisplayName); } -void WSLCSessionManagerImpl::ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount) +void WSLCSessionManagerImpl::ListSessions(_Out_ WSLCSessionListEntry** Sessions, _Out_ ULONG* SessionsCount) { - std::vector sessionInfo; + std::vector sessionInfo; ForEachSession([&](auto& entry, const auto&) noexcept { try @@ -307,15 +387,15 @@ void WSLCSessionManagerImpl::ListSessions(_Out_ WSLCSessionInformation** Session wil::unique_hlocal_string sidString; THROW_IF_WIN32_BOOL_FALSE(ConvertSidToStringSidW(entry.Owner.TokenInfo->User.Sid, &sidString)); - auto& it = sessionInfo.emplace_back(WSLCSessionInformation{.SessionId = entry.SessionId, .CreatorPid = entry.CreatorPid}); + auto& it = sessionInfo.emplace_back(WSLCSessionListEntry{.SessionId = entry.SessionId, .CreatorPid = entry.CreatorPid}); wcscpy_s(it.Sid, _countof(it.Sid), sidString.get()); wcscpy_s(it.DisplayName, _countof(it.DisplayName), entry.DisplayName.c_str()); } CATCH_LOG() }); - auto output = wil::make_unique_cotaskmem(sessionInfo.size()); - memcpy(output.get(), sessionInfo.data(), sessionInfo.size() * sizeof(WSLCSessionInformation)); + auto output = wil::make_unique_cotaskmem(sessionInfo.size()); + memcpy(output.get(), sessionInfo.data(), sessionInfo.size() * sizeof(WSLCSessionListEntry)); *Sessions = output.release(); *SessionsCount = static_cast(sessionInfo.size()); @@ -478,26 +558,65 @@ try CATCH_RETURN(); HRESULT WSLCSessionManager::CreateSession(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession) +try { + COMServiceExecutionContext context; + return CallImpl(&WSLCSessionManagerImpl::CreateSession, WslcSessionSettings, Flags, WslcSession); } +CATCH_RETURN(); HRESULT WSLCSessionManager::EnterSession(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession) { + COMServiceExecutionContext context; + return CallImpl(&WSLCSessionManagerImpl::EnterSession, DisplayName, StoragePath, WslcSession); } -HRESULT WSLCSessionManager::ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount) +HRESULT WSLCSessionManager::ListSessions(_Out_ WSLCSessionListEntry** Sessions, _Out_ ULONG* SessionsCount) { + COMServiceExecutionContext context; + return CallImpl(&WSLCSessionManagerImpl::ListSessions, Sessions, SessionsCount); } HRESULT WSLCSessionManager::OpenSession(_In_ ULONG Id, _Out_ IWSLCSession** Session) { + COMServiceExecutionContext context; + return CallImpl(&WSLCSessionManagerImpl::OpenSession, Id, Session); } HRESULT WSLCSessionManager::OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session) { + COMServiceExecutionContext context; + return CallImpl(&WSLCSessionManagerImpl::OpenSessionByName, DisplayName, Session); } + +namespace wsl::windows::service::wslc { + +WSLCSessionManagerImpl* WSLCSessionManagerImpl::Instance() noexcept +{ + return g_managerInstance.load(); +} + +wil::com_ptr WSLCSessionManagerImpl::FindSession(ULONG Id) +{ + wil::com_ptr result; + + ForEachSession([&](SessionEntry& entry, const wil::com_ptr& session) noexcept -> std::optional { + if (entry.SessionId != Id) + { + return std::nullopt; + } + + result = session; + return S_OK; + }); + + THROW_HR_IF_MSG(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), !result, "WSLC session %lu not found", Id); + return result; +} + +} // namespace wsl::windows::service::wslc diff --git a/src/windows/service/exe/WSLCSessionManager.h b/src/windows/service/exe/WSLCSessionManager.h index 3ef386cdd..ce5d31eab 100644 --- a/src/windows/service/exe/WSLCSessionManager.h +++ b/src/windows/service/exe/WSLCSessionManager.h @@ -60,6 +60,14 @@ struct SessionEntry DWORD CreatorPid = 0; std::wstring DisplayName; CallingProcessTokenInfo Owner; + + Microsoft::WRL::ComPtr PluginNotifier; + + // Whether OnSessionStopping has been fired already; ensures it is fired exactly once. + bool StoppingNotified = false; + + wil::shared_handle UserToken; + std::vector UserSid; }; class WSLCSessionManagerImpl @@ -68,15 +76,20 @@ class WSLCSessionManagerImpl NON_COPYABLE(WSLCSessionManagerImpl); NON_MOVABLE(WSLCSessionManagerImpl); - WSLCSessionManagerImpl() = default; + WSLCSessionManagerImpl(); ~WSLCSessionManagerImpl(); void CreateSession(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession); void EnterSession(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession); - void ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount); + void ListSessions(_Out_ WSLCSessionListEntry** Sessions, _Out_ ULONG* SessionsCount); void OpenSession(_In_ ULONG Id, _Out_ IWSLCSession** Session); void OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session); + // Resolves a session by ID for plugin->API calls. Throws ERROR_NOT_FOUND if no session matches. + wil::com_ptr FindSession(ULONG Id); + + static WSLCSessionManagerImpl* Instance() noexcept; + private: // Resolves the default session name for a caller: appends the username // from the token SID so different users don't collide. @@ -109,7 +122,9 @@ class WSLCSessionManagerImpl wil::com_ptr lockedSession; if (FAILED_LOG(entry.Ref->OpenSession(&lockedSession))) { - // Session is gone, drop the persistent reference if any. + // Session is gone: notify plugins (if not already), then drop persistent reference if any. + NotifySessionStoppingLockHeld(entry); + auto remove = std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return e.first == entry.SessionId; }); m_persistentSessions.erase(remove.begin(), remove.end()); @@ -151,6 +166,8 @@ class WSLCSessionManagerImpl static CallingProcessTokenInfo GetCallingProcessTokenInfo(); static HRESULT CheckTokenAccess(const SessionEntry& Entry, const CallingProcessTokenInfo& TokenInfo); + void NotifySessionStoppingLockHeld(SessionEntry& entry) noexcept; + std::atomic m_nextSessionId{1}; std::recursive_mutex m_wslcSessionsLock; @@ -183,7 +200,7 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLCSessionManager IFACEMETHOD(IsClientVersionSupported)(_In_ const WSLCVersion* ClientVersion, _Out_ BOOL* IsSupported) override; IFACEMETHOD(CreateSession)(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession) override; IFACEMETHOD(EnterSession)(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession) override; - IFACEMETHOD(ListSessions)(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount) override; + IFACEMETHOD(ListSessions)(_Out_ WSLCSessionListEntry** Sessions, _Out_ ULONG* SessionsCount) override; IFACEMETHOD(OpenSession)(_In_ ULONG Id, _Out_ IWSLCSession** Session) override; IFACEMETHOD(OpenSessionByName)(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session) override; }; diff --git a/src/windows/service/exe/WSLCSessionManagerFactory.cpp b/src/windows/service/exe/WSLCSessionManagerFactory.cpp index 6929bffba..09f2ffb20 100644 --- a/src/windows/service/exe/WSLCSessionManagerFactory.cpp +++ b/src/windows/service/exe/WSLCSessionManagerFactory.cpp @@ -82,4 +82,4 @@ void wsl::windows::service::wslc::ClearWslcSessionsAndBlockNewInstances() } g_sessionManagerImpl.reset(); -} \ No newline at end of file +} diff --git a/src/windows/service/inc/wslc.idl b/src/windows/service/inc/wslc.idl index 2ad3216f8..f2ebc1a81 100644 --- a/src/windows/service/inc/wslc.idl +++ b/src/windows/service/inc/wslc.idl @@ -117,6 +117,27 @@ interface IProgressCallback : IUnknown HRESULT OnProgress(LPCSTR Status, LPCSTR Id, ULONGLONG Current, ULONGLONG Total); }; +[ + uuid(F3E6D5B2-1D40-4E8B-9C39-7A45D1C0F8A2), + pointer_default(unique), + object +] +interface IWSLCPluginNotifier : IUnknown +{ + // 'InspectJson' follows the wslc_schema::InspectContainer format. + // Returning failure prevents the container creation. + HRESULT OnContainerStarted([in] LPCSTR InspectJson); + + // Called when a container is about to stop. 'ContainerId' is the container identifier. Errors are logged but ignored. + HRESULT OnContainerStopping([in] LPCSTR ContainerId); + + // 'InspectJson' follows the wslc_schema::InspectImage format. Errors are logged but ignored. + HRESULT OnImageCreated([in] LPCSTR InspectJson); + + // Called when an image is deleted. 'ImageId' is the image identifier. Errors are logged but ignored. + HRESULT OnImageDeleted([in] LPCSTR ImageId); +}; + typedef struct _WSLCImageInformation { char Image[WSLC_MAX_IMAGE_NAME_LENGTH + 1]; @@ -773,7 +794,8 @@ interface IWSLCSession : IUnknown // Initializes the session with a pre-created VM. HRESULT Initialize( [in] const WSLCSessionInitSettings* Settings, - [in] IWSLCVirtualMachine* Vm); + [in] IWSLCVirtualMachine* Vm, + [in] IWSLCPluginNotifier* PluginNotifier); // Volume management. HRESULT CreateVolume([in] const WSLCVolumeOptions* Options, [out] WSLCVolumeInformation* VolumeInfo); @@ -827,6 +849,7 @@ interface IWSLCSessionFactory : IUnknown HRESULT CreateSession( [in] const WSLCSessionInitSettings* Settings, [in] IWSLCVirtualMachine* Vm, + [in] IWSLCPluginNotifier* PluginNotifier, [out] IWSLCSession** Session, [out] IWSLCSessionReference** ServiceRef); @@ -834,13 +857,13 @@ interface IWSLCSessionFactory : IUnknown HRESULT GetProcessHandle([out, system_handle(sh_process)] HANDLE* ProcessHandle); } -typedef struct _WSLCSessionInformation +typedef struct _WSLCSessionListEntry { ULONG SessionId; DWORD CreatorPid; wchar_t DisplayName[256]; wchar_t Sid[256 + 1]; // MAX_SID_SIZE = 256 -} WSLCSessionInformation; +} WSLCSessionListEntry; typedef enum _WSLCSessionFlags { @@ -865,7 +888,7 @@ interface IWSLCSessionManager : IUnknown // Session management. HRESULT CreateSession([in, unique] const WSLCSessionSettings* Settings, WSLCSessionFlags Flags, [out] IWSLCSession** Session); HRESULT EnterSession([in, ref] LPCWSTR DisplayName, [in, ref] LPCWSTR StoragePath, [out] IWSLCSession** Session); - HRESULT ListSessions([out, size_is(, *SessionsCount)] WSLCSessionInformation** Sessions, [out] ULONG* SessionsCount); + HRESULT ListSessions([out, size_is(, *SessionsCount)] WSLCSessionListEntry** Sessions, [out] ULONG* SessionsCount); HRESULT OpenSession([in] ULONG Id, [out] IWSLCSession** Session); HRESULT OpenSessionByName([in, unique] LPCWSTR DisplayName, [out] IWSLCSession** Session); } diff --git a/src/windows/wslc/services/SessionService.cpp b/src/windows/wslc/services/SessionService.cpp index a38b5c5b5..a390294bb 100644 --- a/src/windows/wslc/services/SessionService.cpp +++ b/src/windows/wslc/services/SessionService.cpp @@ -149,7 +149,7 @@ std::vector SessionService::List() THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager))); wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get()); - wil::unique_cotaskmem_array_ptr sessions; + wil::unique_cotaskmem_array_ptr sessions; THROW_IF_FAILED(sessionManager->ListSessions(&sessions, sessions.size_address())); for (size_t i = 0; i < sessions.size(); ++i) { diff --git a/src/windows/wslcsession/WSLCContainer.cpp b/src/windows/wslcsession/WSLCContainer.cpp index aaa602b00..2d99a6f88 100644 --- a/src/windows/wslcsession/WSLCContainer.cpp +++ b/src/windows/wslcsession/WSLCContainer.cpp @@ -531,6 +531,7 @@ WSLCPortMapping ContainerPortMapping::Serialize() const WSLCContainerImpl::WSLCContainerImpl( WSLCSession& wslcSession, WSLCVirtualMachine& virtualMachine, + IWSLCPluginNotifier* pluginNotifier, std::string&& Id, std::string&& Name, std::string&& Image, @@ -547,6 +548,7 @@ WSLCContainerImpl::WSLCContainerImpl( WSLCProcessFlags InitProcessFlags, WSLCContainerFlags ContainerFlags) : m_wslcSession(wslcSession), + m_pluginNotifier(pluginNotifier), m_virtualMachine(virtualMachine), m_name(std::move(Name)), m_image(std::move(Image)), @@ -803,6 +805,30 @@ void WSLCContainerImpl::Start(WSLCContainerStartFlags Flags, LPCSTR DetachKeys) } CATCH_AND_THROW_DOCKER_USER_ERROR("Failed to start container '%hs'", m_id.c_str()); + auto inspectJson = InspectLockHeld(); + const auto pluginResult = m_pluginNotifier->OnContainerStarted(inspectJson.c_str()); + if (FAILED(pluginResult)) + { + // Forward the COM error message, if available. + auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo(); + + LOG_HR_MSG(pluginResult, "Plugin rejected start of container '%hs' (0x%x)", m_id.c_str(), pluginResult); + try + { + m_dockerClient.StopContainer(m_id.c_str(), {}, {}); + } + CATCH_LOG(); + + if (comError.has_value() && comError->Message) + { + THROW_HR_WITH_USER_ERROR(pluginResult, comError->Message.get()); + } + else + { + THROW_HR(pluginResult); + } + } + portCleanup.release(); volumeCleanup.release(); @@ -946,6 +972,16 @@ __requires_exclusive_lock_held(m_lock) unique_com_disconnect WSLCContainerImpl:: { unique_com_disconnect comWrapper; + // Notify plugin manager that the container is stopping. Errors are ignored. + if (m_state == WslcContainerStateRunning) + { + try + { + LOG_IF_FAILED(m_pluginNotifier->OnContainerStopping(m_id.c_str())); + } + CATCH_LOG(); + } + ReleaseProcesses(); ReleaseRuntimeResources(); @@ -1298,6 +1334,7 @@ std::unique_ptr WSLCContainerImpl::Create( const std::string& containerName, WSLCSession& wslcSession, WSLCVirtualMachine& virtualMachine, + IWSLCPluginNotifier* pluginNotifier, const std::unordered_map& sessionNetworks, std::function&& OnDeleted, DockerEventTracker& EventTracker, @@ -1651,6 +1688,7 @@ std::unique_ptr WSLCContainerImpl::Create( auto container = std::make_unique( wslcSession, virtualMachine, + pluginNotifier, std::move(result.Id), CleanContainerName(inspectData.Name), std::string(containerOptions.Image), @@ -1675,6 +1713,7 @@ std::unique_ptr WSLCContainerImpl::Open( const common::docker_schema::ContainerInfo& dockerContainer, WSLCSession& wslcSession, WSLCVirtualMachine& virtualMachine, + IWSLCPluginNotifier* pluginNotifier, WSLCVolumes& volumes, std::function&& OnDeleted, DockerEventTracker& EventTracker, @@ -1739,6 +1778,7 @@ std::unique_ptr WSLCContainerImpl::Open( auto container = std::make_unique( wslcSession, virtualMachine, + pluginNotifier, std::string(dockerContainer.Id), std::move(name), std::string(dockerContainer.Image), @@ -1783,19 +1823,23 @@ void WSLCContainerImpl::Inspect(LPSTR* Output) const try { - // Get Docker inspect data - auto dockerInspect = m_dockerClient.InspectContainer(m_id); - - // Convert to WSLC schema - auto wslcInspect = BuildInspectContainer(dockerInspect); - - // Serialize WSLC schema to JSON - std::string wslcJson = wsl::shared::ToJson(wslcInspect); - *Output = wil::make_unique_ansistring(wslcJson.c_str()).release(); + *Output = wil::make_unique_ansistring(InspectLockHeld().c_str()).release(); } CATCH_AND_THROW_DOCKER_USER_ERROR("Failed to inspect container '%hs'", m_id.c_str()); } +std::string WSLCContainerImpl::InspectLockHeld() const +{ + // Get Docker inspect data + auto dockerInspect = m_dockerClient.InspectContainer(m_id); + + // Convert to WSLC schema + auto wslcInspect = BuildInspectContainer(dockerInspect); + + // Serialize WSLC schema to JSON + return wsl::shared::ToJson(wslcInspect); +} + void WSLCContainerImpl::Logs(WSLCLogsFlags Flags, WSLCHandle* Stdout, WSLCHandle* Stderr, ULONGLONG Since, ULONGLONG Until, ULONGLONG Tail) const { auto lock = m_lock.lock_shared(); diff --git a/src/windows/wslcsession/WSLCContainer.h b/src/windows/wslcsession/WSLCContainer.h index 4ebfcd758..f1e433bdd 100644 --- a/src/windows/wslcsession/WSLCContainer.h +++ b/src/windows/wslcsession/WSLCContainer.h @@ -72,6 +72,7 @@ class WSLCContainerImpl WSLCContainerImpl( WSLCSession& wslcSession, WSLCVirtualMachine& virtualMachine, + IWSLCPluginNotifier* pluginNotifier, std::string&& Id, std::string&& Name, std::string&& Image, @@ -130,6 +131,7 @@ class WSLCContainerImpl const std::string& Name, WSLCSession& wslcSession, WSLCVirtualMachine& virtualMachine, + IWSLCPluginNotifier* pluginNotifier, const std::unordered_map& SessionNetworks, std::function&& OnDeleted, DockerEventTracker& EventTracker, @@ -140,6 +142,7 @@ class WSLCContainerImpl const common::docker_schema::ContainerInfo& DockerContainer, WSLCSession& wslcSession, WSLCVirtualMachine& virtualMachine, + IWSLCPluginNotifier* pluginNotifier, WSLCVolumes& Volumes, std::function&& OnDeleted, DockerEventTracker& EventTracker, @@ -169,6 +172,8 @@ class WSLCContainerImpl void MapPorts(); void UnmapPorts(); + __requires_shared_lock_held(m_lock) std::string InspectLockHeld() const; + mutable wil::srwlock m_lock; std::string m_name; std::string m_image; @@ -197,6 +202,7 @@ class WSLCContainerImpl std::uint64_t m_createdAt{}; WSLCContainerState m_state = WslcContainerStateInvalid; WSLCSession& m_wslcSession; + IWSLCPluginNotifier* m_pluginNotifier; WSLCVirtualMachine& m_virtualMachine; std::vector m_mappedPorts; std::vector m_mountedVolumes; diff --git a/src/windows/wslcsession/WSLCSession.cpp b/src/windows/wslcsession/WSLCSession.cpp index bc40f2925..97b27da67 100644 --- a/src/windows/wslcsession/WSLCSession.cpp +++ b/src/windows/wslcsession/WSLCSession.cpp @@ -257,7 +257,7 @@ try } CATCH_RETURN(); -HRESULT WSLCSession::Initialize(_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm) +HRESULT WSLCSession::Initialize(_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm, _In_ IWSLCPluginNotifier* PluginNotifier) try { RETURN_HR_IF(E_POINTER, Settings == nullptr || Vm == nullptr); @@ -267,6 +267,7 @@ try m_id = Settings->SessionId; m_displayName = Settings->DisplayName ? Settings->DisplayName : L""; m_featureFlags = Settings->FeatureFlags; + m_pluginNotifier = PluginNotifier; // Get user token for the current process const auto tokenInfo = wil::get_token_information(GetCurrentProcessToken()); @@ -645,6 +646,20 @@ void WSLCSession::StreamImageOperation(DockerHTTPClient::HTTPRequestContext& req } } +void WSLCSession::OnImageCreated(const std::string& ImageNameOrId) noexcept +try +{ + LOG_IF_FAILED(m_pluginNotifier->OnImageCreated(InspectImageLockHeld(ImageNameOrId).c_str())); +} +CATCH_LOG() + +void WSLCSession::OnImageDeleted(const std::string& ImageId) noexcept +try +{ + LOG_IF_FAILED(m_pluginNotifier->OnImageDeleted(ImageId.c_str())); +} +CATCH_LOG() + HRESULT WSLCSession::PullImage(LPCSTR Image, LPCSTR RegistryAuthenticationInformation, IProgressCallback* ProgressCallback) try { @@ -673,6 +688,8 @@ try auto requestContext = m_dockerClient->PullImage(repo, tagOrDigest, registryAuth); StreamImageOperation(*requestContext, Image, "Pull", ProgressCallback); + OnImageCreated(Image); + return S_OK; } CATCH_RETURN(); @@ -1014,6 +1031,7 @@ try auto requestContext = m_dockerClient->LoadImage(ContentSize); ImportImageImpl(*requestContext, ImageHandle); + return S_OK; } CATCH_RETURN(); @@ -1039,6 +1057,8 @@ try auto requestContext = m_dockerClient->ImportImage(repo, tagOrDigest.value(), ContentSize); ImportImageImpl(*requestContext, ImageHandle); + + OnImageCreated(ImageName); return S_OK; } CATCH_RETURN(); @@ -1095,7 +1115,6 @@ void WSLCSession::ImportImageImpl(DockerHTTPClient::HTTPRequestContext& Request, } else if (parsed.stream.has_value()) { - // TODO: report progress to caller. WSL_LOG("ImageImportProgress", TraceLoggingValue(parsed.stream->c_str(), "Content")); } else @@ -1420,6 +1439,15 @@ try *Count = static_cast(deletedImages.size()); *DeletedImages = output.release(); + // Notify plugin manager of all deleted image IDs. + for (const auto& image : deletedImages) + { + if (!image.Deleted.empty()) + { + OnImageDeleted(image.Deleted); + } + } + return S_OK; } CATCH_RETURN(); @@ -1497,10 +1525,18 @@ try auto lock = m_lock.lock_shared(); RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value()); + *Output = wil::make_unique_ansistring(InspectImageLockHeld(ImageNameOrId).c_str()).release(); + + return S_OK; +} +CATCH_RETURN(); + +std::string WSLCSession::InspectImageLockHeld(const std::string& NameOrId) +{ docker_schema::InspectImage dockerInspect; try { - dockerInspect = m_dockerClient->InspectImage(ImageNameOrId); + dockerInspect = m_dockerClient->InspectImage(NameOrId); } catch (const DockerHTTPException& e) { @@ -1519,12 +1555,8 @@ try auto wslcInspect = ConvertInspectImage(dockerInspect); // Serialize to JSON - std::string wslcJson = wsl::shared::ToJson(wslcInspect); - *Output = wil::make_unique_ansistring(wslcJson.c_str()).release(); - - return S_OK; + return wsl::shared::ToJson(wslcInspect); } -CATCH_RETURN(); HRESULT WSLCSession::Authenticate(_In_ LPCSTR ServerAddress, _In_ LPCSTR Username, _In_ LPCSTR Password, _Out_ LPSTR* IdentityToken) try @@ -1724,6 +1756,7 @@ try containerName, *this, m_virtualMachine.value(), + m_pluginNotifier.get(), m_networks, std::bind(&WSLCSession::OnContainerDeleted, this, std::placeholders::_1), m_eventTracker.value(), @@ -2766,6 +2799,7 @@ void WSLCSession::RecoverExistingContainers() dockerContainer, *this, m_virtualMachine.value(), + m_pluginNotifier.get(), *m_volumes, std::bind(&WSLCSession::OnContainerDeleted, this, std::placeholders::_1), m_eventTracker.value(), diff --git a/src/windows/wslcsession/WSLCSession.h b/src/windows/wslcsession/WSLCSession.h index 572e721c8..c5db04808 100644 --- a/src/windows/wslcsession/WSLCSession.h +++ b/src/windows/wslcsession/WSLCSession.h @@ -85,7 +85,7 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLCSession // IWSLCSession - initialization methods IFACEMETHOD(GetProcessHandle)(_Out_ HANDLE* ProcessHandle) override; - IFACEMETHOD(Initialize)(_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm) override; + IFACEMETHOD(Initialize)(_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm, _In_ IWSLCPluginNotifier* PluginNotifier) override; IFACEMETHOD(GetId)(_Out_ ULONG* Id) override; IFACEMETHOD(GetState)(_Out_ WSLCSessionState* State) override; @@ -177,7 +177,16 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLCSession __requires_lock_held(m_userCOMCallbacksLock) void CancelUserCOMCallbacks(); void ConfigureStorage(const WSLCSessionInitSettings& Settings, PSID UserSid); void Ext4Format(const std::string& Device); + _Requires_shared_lock_held_(m_lock) + std::string InspectImageLockHeld(const std::string& Id); void OnContainerDeleted(const WSLCContainerImpl* Container); + + _Requires_shared_lock_held_(m_lock) + void OnImageCreated(const std::string& ImageNameOrId) noexcept; + + _Requires_shared_lock_held_(m_lock) + void OnImageDeleted(const std::string& ImageId) noexcept; + void OnProcessLog(const gsl::span& Data, PCSTR Source); void OnContainerdExited(); void OnDockerdExited(); @@ -222,6 +231,8 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLCSession std::atomic m_terminating{false}; std::atomic m_terminated{false}; + wil::com_ptr m_pluginNotifier; + // User-provided handles that the session is currently doing IO on. std::mutex m_userHandlesLock; __guarded_by(m_userHandlesLock) std::vector m_userHandles; diff --git a/src/windows/wslcsession/WSLCSessionFactory.cpp b/src/windows/wslcsession/WSLCSessionFactory.cpp index 9b2e287b7..e08f72090 100644 --- a/src/windows/wslcsession/WSLCSessionFactory.cpp +++ b/src/windows/wslcsession/WSLCSessionFactory.cpp @@ -30,7 +30,11 @@ void wslc::WSLCSessionFactory::SetDestructionCallback(std::function&& ca } HRESULT wslc::WSLCSessionFactory::CreateSession( - _In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm, _Out_ IWSLCSession** Session, _Out_ IWSLCSessionReference** ServiceRef) + _In_ const WSLCSessionInitSettings* Settings, + _In_ IWSLCVirtualMachine* Vm, + _In_ IWSLCPluginNotifier* PluginNotifier, + _Out_ IWSLCSession** Session, + _Out_ IWSLCSessionReference** ServiceRef) try { *Session = nullptr; @@ -44,7 +48,7 @@ try session->SetDestructionCallback(std::move(m_destructionCallback)); // Initialize the session with the VM. - RETURN_IF_FAILED(session->Initialize(Settings, Vm)); + RETURN_IF_FAILED(session->Initialize(Settings, Vm, PluginNotifier)); // Create the service session ref. It extracts metadata and a weak reference from the session. auto serviceRef = Microsoft::WRL::Make(session.Get()); diff --git a/src/windows/wslcsession/WSLCSessionFactory.h b/src/windows/wslcsession/WSLCSessionFactory.h index 17382c967..266517d28 100644 --- a/src/windows/wslcsession/WSLCSessionFactory.h +++ b/src/windows/wslcsession/WSLCSessionFactory.h @@ -44,8 +44,11 @@ class DECLSPEC_UUID("9FCD2067-9FC6-4EFA-9EB0-698169EBF7D3") WSLCSessionFactory // IWSLCSessionFactory IFACEMETHOD(CreateSession) - (_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm, _Out_ IWSLCSession** Session, _Out_ IWSLCSessionReference** ServiceRef) - override; + (_In_ const WSLCSessionInitSettings* Settings, + _In_ IWSLCVirtualMachine* Vm, + _In_ IWSLCPluginNotifier* PluginNotifier, + _Out_ IWSLCSession** Session, + _Out_ IWSLCSessionReference** ServiceRef) override; IFACEMETHOD(GetProcessHandle)(_Out_ HANDLE* ProcessHandle) override; diff --git a/test/windows/Common.cpp b/test/windows/Common.cpp index 5fb028f84..d18622dd6 100644 --- a/test/windows/Common.cpp +++ b/test/windows/Common.cpp @@ -2904,6 +2904,19 @@ std::filesystem::path GetTestImagePath(std::string_view imageName) return result; } +void LoadTestImage(IWSLCSession& session, std::string_view imageName) +{ + std::filesystem::path imagePath = GetTestImagePath(imageName); + wil::unique_hfile imageFile{ + CreateFileW(imagePath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr)}; + THROW_LAST_ERROR_IF(!imageFile); + + LARGE_INTEGER fileSize{}; + THROW_LAST_ERROR_IF(!GetFileSizeEx(imageFile.get(), &fileSize)); + + THROW_IF_FAILED(session.LoadImage(wsl::windows::common::wslutil::ToCOMInputHandle(imageFile.get()), nullptr, fileSize.QuadPart)); +} + void ExpectHttpResponse(LPCWSTR Url, std::optional expectedCode, bool retry) { const winrt::Windows::Web::Http::Filters::HttpBaseProtocolFilter filter; @@ -2991,3 +3004,52 @@ void WriteSocket(SOCKET Socket, const void* data, size_t size) data = static_cast(data) + result; } } + +void ValidateCOMErrorMessage(const std::optional& Expected, const std::source_location& Source) +{ + auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo(); + + if (comError.has_value()) + { + if (!Expected.has_value()) + { + LogError("Unexpected COM error: '%ls'. Source: %hs", comError->Message.get(), std::format("{}", Source).c_str()); + VERIFY_FAIL(); + } + + VERIFY_ARE_EQUAL(Expected.value(), comError->Message.get()); + } + else + { + if (Expected.has_value()) + { + LogError("Expected COM error: '%ls' but none was set. Source: %hs", Expected->c_str(), std::format("{}", Source).c_str()); + VERIFY_FAIL(); + } + } +} + +void ValidateCOMErrorMessageContains(const std::wstring& ExpectedSubstring) +{ + auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo(); + + if (comError.has_value()) + { + if (!comError->Message) + { + LogError("Expected COM error containing: '%ls', but COM error message was null", ExpectedSubstring.c_str()); + VERIFY_FAIL(); + } + + if (wcsstr(comError->Message.get(), ExpectedSubstring.c_str()) == nullptr) + { + LogError("Expected COM error containing: '%ls', but got: '%ls'", ExpectedSubstring.c_str(), comError->Message.get()); + VERIFY_FAIL(); + } + } + else + { + LogError("Expected COM error containing: '%ls' but none was set", ExpectedSubstring.c_str()); + VERIFY_FAIL(); + } +} diff --git a/test/windows/Common.h b/test/windows/Common.h index 43cf3a3b7..ad2706ebf 100644 --- a/test/windows/Common.h +++ b/test/windows/Common.h @@ -617,6 +617,8 @@ void VerifyPatternMatch(const std::string& Content, const std::string& Pattern); std::filesystem::path GetTestImagePath(std::string_view imageName); +void LoadTestImage(IWSLCSession& session, std::string_view imageName); + void ExpectHttpResponse(LPCWSTR Url, std::optional expectedCode, bool retry = false); template @@ -678,3 +680,7 @@ void VerifyAreEqualUnordered(const std::vector& expected, const std::vector& Expected, const std::source_location& Source = std::source_location::current()); + +void ValidateCOMErrorMessageContains(const std::wstring& ExpectedSubstring); diff --git a/test/windows/PluginTests.cpp b/test/windows/PluginTests.cpp index 8b42f941e..3f2be8665 100644 --- a/test/windows/PluginTests.cpp +++ b/test/windows/PluginTests.cpp @@ -16,10 +16,16 @@ Module Name: #include "Common.h" #include "registry.hpp" #include "PluginTests.h" +#include "wslc.h" +#include "WSLCContainerLauncher.h" +#include "WSLCProcessLauncher.h" +#include "wslc/e2e/WSLCE2EHelpers.h" using namespace wsl::windows::common::registry; +using WSLCE2ETests::StartLocalRegistry; extern std::wstring g_testDistroPath; +extern std::wstring g_testDataPath; class PluginTests { @@ -592,6 +598,186 @@ class PluginTests StartWsl(0); ValidateLogFile(ExpectedOutput); } + static wil::com_ptr OpenWslcSessionManager() + { + wil::com_ptr sessionManager; + VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager))); + wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get()); + return sessionManager; + } + + static wil::com_ptr CreateWslcSession(LPCWSTR Name, WSLCNetworkingMode NetworkingMode = WSLCNetworkingModeNone) + { + WSLCSessionSettings settings{}; + settings.DisplayName = Name; + settings.CpuCount = 4; + settings.MemoryMb = 4096; + settings.BootTimeoutMs = 30 * 1000; + settings.NetworkingMode = NetworkingMode; + + auto manager = OpenWslcSessionManager(); + wil::com_ptr session; + VERIFY_SUCCEEDED(manager->CreateSession(&settings, WSLCSessionFlagsNone, &session)); + wsl::windows::common::security::ConfigureForCOMImpersonation(session.get()); + + WSLCSessionState state{}; + VERIFY_SUCCEEDED(session->GetState(&state)); + VERIFY_ARE_EQUAL(state, WSLCSessionStateRunning); + + return session; + } + + WSL2_TEST_METHOD(WslcSuccess) + { + ConfigurePlugin(PluginTestType::WslcSuccess); + + { + auto session = CreateWslcSession(L"plugin-wslc-test"); + + LoadTestImage(*session, "debian:latest"); + + // Create a container that will have a stuck process so it's still in a running state when the callback is made. + wsl::windows::common::WSLCContainerLauncher launcher( + "debian:latest", "wslc-plugin-container", {"/bin/sh", "-c", "sleep 120"}); + + auto container = launcher.Launch(*session, WSLCContainerStartFlagsAttach); + VERIFY_SUCCEEDED(container.Get().Stop(WSLCSignalSIGKILL, 0)); + + // Delete the image so we get an ImageDeleted notification before the session goes away. + WSLCDeleteImageOptions options{.Image = "debian:latest", .Flags = WSLCDeleteImageFlagsForce}; + wil::unique_cotaskmem_array_ptr deletedImages; + VERIFY_SUCCEEDED(session->DeleteImage(&options, deletedImages.addressof(), deletedImages.size_address())); + } + + const auto ExpectedOutput = std::format( + LR"(Plugin loaded. TestMode=18 + WSLC Session created, name=plugin-wslc-test, id=*, pid=*, token=set, sid=set + Command: 'echo -n stdout-ok && echo -n stderr-ok >&2', status=0, stdout: stdout-ok, stderr: stderr-ok + Command: 'cat', status=0, stdout: stdin-ok, stderr: + Command: 'exit 12', status=12, stdout: , stderr: + Command: 'echo -n $ENV', status=0, stdout: env-ok, stderr: + WSLCCreateProcess(does-not-exist): {:x}, errno=2 + WSLCProcessGetFd(999): {} + WSLCProcessGetExitCode(): {} + WSLC RW folder mounted at: /mnt/wsl-plugin/plugin-rw-test + Command: 'cat /mnt/wsl-plugin/plugin-rw-test/plugin-test.txt', status=0, stdout: Windows-content, stderr: + WSLC RO folder mounted at: /mnt/wsl-plugin/plugin-ro-test + Command: 'echo fail > /mnt/wsl-plugin/plugin-ro-test/should-not-exist.txt', status=1, stdout: , stderr: * + WSLCMountFolder(nonexistent): {} + WSLCMountFolder(../escape): {} + WSLCMountFolder(): {} + Test completed + WSLC Container started, session=*, id=*, name=wslc-plugin-container, image=debian:latest, state=* + WSLC Container stopping, session=*, id=* + WSLC Image deleted, session=*, id=* + WSLC Session stopping, name=plugin-wslc-test, id=*)", + static_cast(E_FAIL), + E_INVALIDARG, + HRESULT_FROM_WIN32(ERROR_INVALID_STATE), + HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND), + E_INVALIDARG, + E_INVALIDARG); + + ValidateLogFile(ExpectedOutput.c_str()); + } + + WSL2_TEST_METHOD(WslcPullImageNotification) + { + ConfigurePlugin(PluginTestType::WslcImagePull); + + { + auto session = CreateWslcSession(L"plugin-wslc-pull-test", WSLCNetworkingModeVirtioProxy); + + // Load the registry and debian images. + LoadTestImage(*session, "debian:latest"); + + // Start a local registry container. + auto [registryContainer, registryAddress] = StartLocalRegistry(*session); + + // Tag debian:latest for the local registry and push it. + auto registryImage = std::format("{}/debian:latest", registryAddress); + auto registryRepo = std::format("{}/debian", registryAddress); + WSLCTagImageOptions tagOptions{}; + tagOptions.Image = "debian:latest"; + tagOptions.Repo = registryRepo.c_str(); + tagOptions.Tag = "latest"; + VERIFY_SUCCEEDED(session->TagImage(&tagOptions)); + + auto emptyAuth = wsl::windows::common::wslutil::BuildRegistryAuthHeader("", ""); + VERIFY_SUCCEEDED(session->PushImage(registryImage.c_str(), emptyAuth.c_str(), nullptr)); + + // Delete the local tagged copy so PullImage actually downloads it. + WSLCDeleteImageOptions deleteOpts{.Image = registryImage.c_str(), .Flags = WSLCDeleteImageFlagsNone}; + wil::unique_cotaskmem_array_ptr deletedImages; + VERIFY_SUCCEEDED(session->DeleteImage(&deleteOpts, deletedImages.addressof(), deletedImages.size_address())); + + // Pull the image back — this should trigger the ImageCreated plugin callback. + VERIFY_SUCCEEDED(session->PullImage(registryImage.c_str(), nullptr, nullptr)); + } + + constexpr auto ExpectedOutput = + LR"(Plugin loaded. TestMode=21 + WSLC Session created, name=plugin-wslc-pull-test, id=*, pid=*, token=set, sid=set + WSLC Container started, session=*, id=*, name=*, image=wslc-registry:latest, state=running + WSLC Image created, session=*, id=sha256:*, name=127.0.0.1:5000/debian:latest + WSLC Session stopping, name=plugin-wslc-pull-test, id=*)"; + + ValidateLogFile(ExpectedOutput); + } + + WSL2_TEST_METHOD(WslcSessionRejected) + { + ConfigurePlugin(PluginTestType::WslcSessionRejected); + + WSLCSessionSettings settings{}; + settings.DisplayName = L"plugin-wslc-rejected"; + settings.CpuCount = 4; + settings.MemoryMb = 2048; + settings.BootTimeoutMs = 30 * 1000; + settings.MaximumStorageSizeMb = 1024 * 20; + settings.NetworkingMode = WSLCNetworkingModeNone; + + auto manager = OpenWslcSessionManager(); + wil::com_ptr session; + const auto hr = manager->CreateSession(&settings, WSLCSessionFlagsNone, &session); + ValidateCOMErrorMessageContains(L"A fatal error was returned by plugin 'TestPlugin'"); + VERIFY_ARE_EQUAL(hr, HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED)); + + constexpr auto ExpectedOutput = + LR"(Plugin loaded. TestMode=19 + WSLC Session created, name=plugin-wslc-rejected, id=*, pid=*, token=set, sid=set + OnWslcSessionCreated: ERROR_ACCESS_DENIED)"; + + ValidateLogFile(ExpectedOutput); + } + + WSL2_TEST_METHOD(WslcContainerRejected) + { + ConfigurePlugin(PluginTestType::WslcContainerRejected); + + { + auto session = CreateWslcSession(L"plugin-wslc-container-rejected"); + + LoadTestImage(*session, "debian:latest"); + + wsl::windows::common::WSLCContainerLauncher launcher( + "debian:latest", "wslc-plugin-rejected-container", {"/bin/sh", "-c", "echo nope"}); + + auto [hr, container] = launcher.LaunchNoThrow(*session, WSLCContainerStartFlagsAttach); + ValidateCOMErrorMessageContains(L"A fatal error was returned by plugin 'TestPlugin'"); + VERIFY_ARE_EQUAL(hr, HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED)); + } + + constexpr auto ExpectedOutput = + LR"(Plugin loaded. TestMode=20 + WSLC Session created, name=plugin-wslc-container-rejected, id=*, pid=*, token=set, sid=set + WSLC Container started, session=*, id=*, name=*, image=debian:latest, state=* + OnWslcContainerStarted: ERROR_ACCESS_DENIED + WSLC Session stopping, name=plugin-wslc-container-rejected, id=*)"; + + ValidateLogFile(ExpectedOutput); + } + // This test must run last so it doesn't break test cases that depends on plugin signature. WSL2_TEST_METHOD(InvalidPluginSignature) { diff --git a/test/windows/PluginTests.h b/test/windows/PluginTests.h index d29d84e39..c9a779031 100644 --- a/test/windows/PluginTests.h +++ b/test/windows/PluginTests.h @@ -37,7 +37,11 @@ enum class PluginTestType InitPidIsDifferent, FailToRegisterUnregisterDistro, RunDistroCommand, - GetUsername + GetUsername, + WslcSuccess, + WslcSessionRejected, + WslcContainerRejected, + WslcImagePull }; constexpr auto c_testType = L"TestType"; @@ -46,4 +50,4 @@ constexpr auto c_logFile = L"LogFile"; inline wil::unique_hkey OpenTestRegistryKey(REGSAM AccessMask) { return wsl::windows::common::registry::CreateKey(HKEY_LOCAL_MACHINE, c_configKey, AccessMask, nullptr, REG_OPTION_VOLATILE); -} \ No newline at end of file +} diff --git a/test/windows/WSLCTests.cpp b/test/windows/WSLCTests.cpp index cf7511e57..029a20a80 100644 --- a/test/windows/WSLCTests.cpp +++ b/test/windows/WSLCTests.cpp @@ -20,6 +20,7 @@ Module Name: #include "WslCoreFilesystem.h" #include "hcs.hpp" #include "ContainerNameGenerator.h" +#include "wslc/e2e/WSLCE2EHelpers.h" #include using namespace std::literals::chrono_literals; @@ -31,6 +32,7 @@ using wsl::windows::common::WSLCProcessLauncher; using wsl::windows::common::io::OverlappedIOHandle; using wsl::windows::common::io::WriteHandle; using namespace wsl::windows::common::wslutil; +using WSLCE2ETests::StartLocalRegistry; extern std::wstring g_testDataPath; extern bool g_fastTestRun; @@ -45,20 +47,6 @@ class WSLCTests wil::com_ptr m_defaultSession; static inline auto c_testSessionName = L"wslc-test"; - void LoadTestImage(std::string_view imageName, IWSLCSession* session = nullptr) - { - std::filesystem::path imagePath = GetTestImagePath(imageName); - wil::unique_hfile imageFile{ - CreateFileW(imagePath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr)}; - THROW_LAST_ERROR_IF(!imageFile); - - LARGE_INTEGER fileSize{}; - THROW_LAST_ERROR_IF(!GetFileSizeEx(imageFile.get(), &fileSize)); - - THROW_IF_FAILED( - (session ? session : m_defaultSession.get())->LoadImage(ToCOMInputHandle(imageFile.get()), nullptr, fileSize.QuadPart)); - } - TEST_CLASS_SETUP(TestClassSetup) { THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &m_wsadata)); @@ -78,27 +66,27 @@ class WSLCTests if (!hasImage("debian:latest")) { - LoadTestImage("debian:latest"); + LoadTestImage(*m_defaultSession, "debian:latest"); } if (!hasImage("python:3.12-alpine")) { - LoadTestImage("python:3.12-alpine"); + LoadTestImage(*m_defaultSession, "python:3.12-alpine"); } if (!hasImage("hello-world:latest")) { - LoadTestImage("hello-world:latest"); + LoadTestImage(*m_defaultSession, "hello-world:latest"); } if (!hasImage("alpine:latest")) { - LoadTestImage("alpine:latest"); + LoadTestImage(*m_defaultSession, "alpine:latest"); } if (!hasImage("wslc-registry:latest")) { - LoadTestImage("wslc-registry:latest"); + LoadTestImage(*m_defaultSession, "wslc-registry:latest"); } PruneResult result; @@ -207,28 +195,6 @@ class WSLCTests return result; } - std::pair StartLocalRegistry(const std::string& username = {}, const std::string& password = {}, USHORT port = 5000) - { - std::vector env = {std::format("REGISTRY_HTTP_ADDR=0.0.0.0:{}", port)}; - if (!username.empty()) - { - env.push_back(std::format("USERNAME={}", username)); - env.push_back(std::format("PASSWORD={}", password)); - } - - WSLCContainerLauncher launcher("wslc-registry:latest", {}, {}, env); - launcher.SetEntrypoint({"/entrypoint.sh"}); - launcher.AddPort(port, port, AF_INET); - - auto container = launcher.Launch(*m_defaultSession, WSLCContainerStartFlagsNone); - - auto registryAddress = std::format("127.0.0.1:{}", port); - auto registryUrl = std::format(L"http://{}", registryAddress); - ExpectHttpResponse(registryUrl.c_str(), 200, true); - - return {std::move(container), std::move(registryAddress)}; - } - std::string PushImageToRegistry(const std::string& imageName, const std::string& registryAddress, const std::string& registryAuth) { auto [repo, tag] = ParseImage(imageName); @@ -394,7 +360,7 @@ class WSLCTests // Act: list sessions { - wil::unique_cotaskmem_array_ptr sessions; + wil::unique_cotaskmem_array_ptr sessions; VERIFY_SUCCEEDED(sessionManager->ListSessions(&sessions, sessions.size_address())); // Assert @@ -409,7 +375,7 @@ class WSLCTests { auto session2 = CreateSession(GetDefaultSessionSettings(L"wslc-test-list-2")); - wil::unique_cotaskmem_array_ptr sessions; + wil::unique_cotaskmem_array_ptr sessions; VERIFY_SUCCEEDED(sessionManager->ListSessions(&sessions, sessions.size_address())); VERIFY_ARE_EQUAL(sessions.size(), 2); @@ -455,7 +421,7 @@ class WSLCTests // Reject DisplayName at exact boundary (no room for null terminator). { - std::wstring boundaryName(std::size(WSLCSessionInformation{}.DisplayName), L'x'); + std::wstring boundaryName(std::size(WSLCSessionListEntry{}.DisplayName), L'x'); auto settings = GetDefaultSessionSettings(boundaryName.c_str()); wil::com_ptr session; VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_INVALID_SESSION_NAME); @@ -463,7 +429,7 @@ class WSLCTests // Reject too long DisplayName. { - std::wstring longName(std::size(WSLCSessionInformation{}.DisplayName) + 1, L'x'); + std::wstring longName(std::size(WSLCSessionListEntry{}.DisplayName) + 1, L'x'); auto settings = GetDefaultSessionSettings(longName.c_str()); wil::com_ptr session; VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_INVALID_SESSION_NAME); @@ -587,7 +553,7 @@ class WSLCTests { { // Start a local registry without auth and push hello-world:latest to it. - auto [registryContainer, registryAddress] = StartLocalRegistry(); + auto [registryContainer, registryAddress] = StartLocalRegistry(*m_defaultSession); auto image = PushImageToRegistry("hello-world:latest", registryAddress, BuildRegistryAuthHeader("", "")); ExpectImagePresent(*m_defaultSession, image.c_str(), false); @@ -630,7 +596,7 @@ class WSLCTests WSLC_TEST_METHOD(PullImageAdvanced) { // Start a local registry without auth to avoid Docker Hub rate limits. - auto [registryContainer, registryAddress] = StartLocalRegistry(); + auto [registryContainer, registryAddress] = StartLocalRegistry(*m_defaultSession); auto auth = BuildRegistryAuthHeader("", ""); auto validatePull = [&](const std::string& sourceImage) { @@ -740,7 +706,7 @@ class WSLCTests constexpr auto c_username = "wslctest"; constexpr auto c_password = "password"; - auto [registryContainer, registryAddress] = StartLocalRegistry(c_username, c_password); + auto [registryContainer, registryAddress] = StartLocalRegistry(*m_defaultSession, c_username, c_password); wil::unique_cotaskmem_ansistring token; VERIFY_ARE_EQUAL(m_defaultSession->Authenticate(registryAddress.c_str(), c_username, "wrong-password", &token), E_FAIL); @@ -1017,7 +983,7 @@ class WSLCTests LogInfo("Test: Dangling filter"); { // Setup a dangling image - LoadTestImage("alpine:latest"); + LoadTestImage(*m_defaultSession, "alpine:latest"); WSLCTagImageOptions tagOptions{}; tagOptions.Image = "debian:latest"; tagOptions.Repo = "alpine"; @@ -1301,7 +1267,7 @@ class WSLCTests WSLC_TEST_METHOD(DeleteImage) { // Prepare alpine image to delete. - LoadTestImage("alpine:latest"); + LoadTestImage(*m_defaultSession, "alpine:latest"); // Verify that the image is in the list of images. ExpectImagePresent(*m_defaultSession, "alpine:latest"); @@ -1338,55 +1304,6 @@ class WSLCTests } } - void ValidateCOMErrorMessage(const std::optional& Expected, const std::source_location& Source = std::source_location::current()) - { - auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo(); - - if (comError.has_value()) - { - if (!Expected.has_value()) - { - LogError("Unexpected COM error: '%ls'. Source: %hs", comError->Message.get(), std::format("{}", Source).c_str()); - VERIFY_FAIL(); - } - - VERIFY_ARE_EQUAL(Expected.value(), comError->Message.get()); - } - else - { - if (Expected.has_value()) - { - LogError("Expected COM error: '%ls' but none was set. Source: %hs", Expected->c_str(), std::format("{}", Source).c_str()); - VERIFY_FAIL(); - } - } - } - - void ValidateCOMErrorMessageContains(const std::wstring& ExpectedSubstring) - { - auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo(); - - if (comError.has_value()) - { - if (!comError->Message) - { - LogError("Expected COM error containing: '%ls', but COM error message was null", ExpectedSubstring.c_str()); - VERIFY_FAIL(); - } - - if (wcsstr(comError->Message.get(), ExpectedSubstring.c_str()) == nullptr) - { - LogError("Expected COM error containing: '%ls', but got: '%ls'", ExpectedSubstring.c_str(), comError->Message.get()); - VERIFY_FAIL(); - } - } - else - { - LogError("Expected COM error containing: '%ls' but none was set", ExpectedSubstring.c_str()); - VERIFY_FAIL(); - } - } - class CapturingProgressCallback : public Microsoft::WRL::RuntimeClass, IProgressCallback> { @@ -7701,7 +7618,7 @@ class WSLCTests auto manager = OpenSessionManager(); auto expectSessions = [&](const std::vector& expectedSessions) { - wil::unique_cotaskmem_array_ptr sessions; + wil::unique_cotaskmem_array_ptr sessions; VERIFY_SUCCEEDED(manager->ListSessions(&sessions, sessions.size_address())); std::set displayNames; @@ -9104,12 +9021,12 @@ class WSLCTests // Helper to create a dangling image using only test-local tags: // Load alpine and hello-world under unique tags, then overwrite one with the other. auto createDanglingImage = [this]() { - LoadTestImage("alpine:latest"); + LoadTestImage(*m_defaultSession, "alpine:latest"); WSLCTagImageOptions tagA{.Image = "alpine:latest", .Repo = "prune-test-a", .Tag = "v1"}; VERIFY_SUCCEEDED(m_defaultSession->TagImage(&tagA)); DeleteImage("alpine:latest", WSLCDeleteImageFlagsNone); - LoadTestImage("hello-world:latest"); + LoadTestImage(*m_defaultSession, "hello-world:latest"); WSLCTagImageOptions tagB{.Image = "hello-world:latest", .Repo = "prune-test-b", .Tag = "v1"}; VERIFY_SUCCEEDED(m_defaultSession->TagImage(&tagB)); DeleteImage("hello-world:latest", WSLCDeleteImageFlagsNone); @@ -9182,7 +9099,7 @@ class WSLCTests // Validate null Options uses defaults (dangling-only prune). { - LoadTestImage("alpine:latest"); + LoadTestImage(*m_defaultSession, "alpine:latest"); WSLCTagImageOptions renameOptions{.Image = "alpine:latest", .Repo = "prune-test-a", .Tag = "v1"}; VERIFY_SUCCEEDED(m_defaultSession->TagImage(&renameOptions)); DeleteImage("alpine:latest", WSLCDeleteImageFlagsNone); @@ -9319,7 +9236,7 @@ class WSLCTests auto revert = wil::impersonate_token(nonElevatedToken.get()); nonElevatedSession = CreateSession(GetDefaultSessionSettings(L"non-elevated-session"), WSLCSessionFlagsNone); - LoadTestImage("debian:latest", nonElevatedSession.get()); + LoadTestImage(*nonElevatedSession, "debian:latest"); WSLCContainerLauncher launcher("debian:latest", "test-non-elevated-handles-1", {"echo", "OK"}); auto container = launcher.Launch(*nonElevatedSession); diff --git a/test/windows/testplugin/Plugin.cpp b/test/windows/testplugin/Plugin.cpp index b7cdc3815..af5fc5f9c 100644 --- a/test/windows/testplugin/Plugin.cpp +++ b/test/windows/testplugin/Plugin.cpp @@ -14,10 +14,14 @@ Module Name: #include "precomp.h" #include "WslPluginApi.h" +#include "wslc_schema.h" #include "PluginTests.h" using namespace wsl::windows::common::registry; +using namespace wsl::windows::common::relay; +using namespace wsl::shared::string; +using namespace std::chrono_literals; std::ofstream g_logfile; std::optional g_distroGuid; @@ -338,6 +342,215 @@ HRESULT OnDistributionUnregistered(const WSLSessionInformation* Session, const W return S_OK; } +HRESULT OnWslcSessionCreated(const WSLCSessionInformation* Session) +try +{ + g_logfile << "WSLC Session created, name=" << wsl::shared::string::WideToMultiByte(Session->DisplayName) << ", id=" << Session->SessionId + << ", pid=" << Session->ApplicationPid << ", token=" << (Session->UserToken != nullptr ? "set" : "null") + << ", sid=" << (Session->UserSid != nullptr ? "set" : "null") << std::endl; + + if (g_testType == PluginTestType::WslcSessionRejected) + { + g_logfile << "OnWslcSessionCreated: ERROR_ACCESS_DENIED" << std::endl; + return HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED); + } + + if (g_testType == PluginTestType::WslcSuccess) + { + // Helper: run a command in the root namespace and return (status, stdout, stderr). + auto runCommand = [&](const char* cmd, + const std::optional& input = {}, + std::vector env = {}) -> std::tuple { + std::vector arguments = {"/bin/sh", "-c", cmd, nullptr}; + WSLCProcessHandle process = nullptr; + THROW_IF_FAILED(g_api->WSLCCreateProcess( + Session->SessionId, arguments[0], arguments.data(), env.empty() ? nullptr : env.data(), &process, nullptr)); + auto releaseProcess = wil::scope_exit([&]() { g_api->WSLCReleaseProcess(process); }); + + wil::unique_handle stdinHandle; + wil::unique_handle stdoutHandle; + wil::unique_handle stderrHandle; + wil::unique_handle exitEvent; + THROW_IF_FAILED(g_api->WSLCProcessGetFd(process, WSLCProcessFdStdin, &stdinHandle)); + THROW_IF_FAILED(g_api->WSLCProcessGetFd(process, WSLCProcessFdStdout, &stdoutHandle)); + THROW_IF_FAILED(g_api->WSLCProcessGetFd(process, WSLCProcessFdStderr, &stderrHandle)); + THROW_IF_FAILED(g_api->WSLCProcessGetExitEvent(process, &exitEvent)); + + std::string out; + std::string err; + + MultiHandleWait io; + io.AddHandle(std::make_unique( + std::move(stdoutHandle), [&out](const auto& span) { out.append(span.begin(), span.end()); })); + + io.AddHandle(std::make_unique( + std::move(stderrHandle), [&err](const auto& span) { err.append(span.begin(), span.end()); })); + + io.AddHandle(std::make_unique(std::move(exitEvent))); + + if (input.has_value()) + { + io.AddHandle(std::make_unique(std::move(stdinHandle), std::vector(input->begin(), input->end()))); + } + else + { + stdinHandle.reset(); + } + + io.Run(60000ms); + + int status = 0; + THROW_IF_FAILED(g_api->WSLCProcessGetExitCode(process, &status)); + g_logfile << "Command: '" << cmd << "', status=" << status << ", stdout: " << out << ", stderr: " << err << std::endl; + + return {status, out, err}; + }; + + // Test process creation (output & exit code validated by the test code). + { + runCommand("echo -n stdout-ok && echo -n stderr-ok >&2"); + runCommand("cat", "stdin-ok"); + runCommand("exit 12"); + runCommand("echo -n $ENV", {}, {"ENV=env-ok", nullptr}); + } + + // Validate that trying to execute a non-existent file fails with the expected error code. + { + WSLCProcessHandle process = nullptr; + int errnoValue = 0; + std::vector args = {"does-not-exist", nullptr}; + + auto hr = g_api->WSLCCreateProcess(Session->SessionId, args[0], args.data(), nullptr, &process, &errnoValue); + g_logfile << "WSLCCreateProcess(does-not-exist): " << std::hex << hr << ", errno=" << std::dec << errnoValue << std::endl; + } + + // Validate various error paths + { + std::vector args = {"/bin/sh", "-c", "sleep 9999", nullptr}; + WSLCProcessHandle process = nullptr; + THROW_IF_FAILED(g_api->WSLCCreateProcess(Session->SessionId, args[0], args.data(), nullptr, &process, nullptr)); + auto releaseProcess = wil::scope_exit([&]() { g_api->WSLCReleaseProcess(process); }); + + // Validate that getting an fd that doesn't exist fails with the expected error code. + HANDLE dummy = nullptr; + g_logfile << "WSLCProcessGetFd(999): " << g_api->WSLCProcessGetFd(process, static_cast(999), &dummy) << std::endl; + int exitCode = -1; + + g_logfile << "WSLCProcessGetExitCode(): " << g_api->WSLCProcessGetExitCode(process, &exitCode) << std::endl; + } + + const auto testFolder = L"C:\\"; + constexpr auto testFileName = L"plugin-test.txt"; + + // Validate rw mounts. + { + auto rwCleanup = wil::scope_exit_log( + WI_DIAGNOSTICS_INFO, [&]() { std::filesystem::remove(std::wstring(testFolder) + testFileName); }); + + { + std::ofstream file(std::wstring(testFolder) + testFileName); + file << "Windows-content"; + } + + // Mount read-write and verify the file can be read from Linux. + char rwMountpoint[WSLC_MOUNTPOINT_LENGTH] = {}; + THROW_IF_FAILED(g_api->WSLCMountFolder(Session->SessionId, testFolder, false, L"plugin-rw-test", rwMountpoint)); + + g_logfile << "WSLC RW folder mounted at: " << rwMountpoint << std::endl; + + auto readCmd = std::format("cat {}/{}", rwMountpoint, testFileName); + runCommand(readCmd.c_str()); + + THROW_IF_FAILED(g_api->WSLCUnmountFolder(Session->SessionId, rwMountpoint)); + } + + // Validate ro mounts. + { + char roMountpoint[WSLC_MOUNTPOINT_LENGTH] = {}; + THROW_IF_FAILED(g_api->WSLCMountFolder(Session->SessionId, L"C:\\", TRUE, L"plugin-ro-test", roMountpoint)); + + g_logfile << "WSLC RO folder mounted at: " << roMountpoint << std::endl; + + // Attempt to write from Linux — should fail on a read-only mount. + auto writeCmd = std::format("echo fail > {}/should-not-exist.txt", roMountpoint); + runCommand(writeCmd.c_str()); + + THROW_IF_FAILED(g_api->WSLCUnmountFolder(Session->SessionId, roMountpoint)); + } + + // Validate that trying to mount a folder that doesn't exist fails with the expected error code. + { + char mountpoint[WSLC_MOUNTPOINT_LENGTH] = {}; + g_logfile << "WSLCMountFolder(nonexistent): " + << g_api->WSLCMountFolder(Session->SessionId, L"C:\\nonexistent", TRUE, L"plugin-ro-test", mountpoint) << std::endl; + } + + // Validate that trying to escape the /mnt folder fails. + { + char mountpoint[WSLC_MOUNTPOINT_LENGTH] = {}; + g_logfile << "WSLCMountFolder(../escape): " << g_api->WSLCMountFolder(Session->SessionId, L"C:\\", TRUE, L"../escape", mountpoint) + << std::endl; + } + + // Validate that empty names are rejected. + { + char mountpoint[WSLC_MOUNTPOINT_LENGTH] = {}; + g_logfile << "WSLCMountFolder(): " << g_api->WSLCMountFolder(Session->SessionId, L"C:\\", TRUE, L"", mountpoint) << std::endl; + } + + g_logfile << "Test completed" << std::endl; + } + + return S_OK; +} +CATCH_RETURN(); + +HRESULT OnWslcSessionStopping(const WSLCSessionInformation* Session) +{ + g_logfile << "WSLC Session stopping, name=" << wsl::shared::string::WideToMultiByte(Session->DisplayName) + << ", id=" << Session->SessionId << std::endl; + + return S_OK; +} + +HRESULT OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) +try +{ + auto container = wsl::shared::FromJson(InspectJson); + + g_logfile << "WSLC Container started, session=" << Session->SessionId << ", id=" << container.Id + << ", name=" << container.Name << ", image=" << container.Image << ", state=" << container.State.Status << std::endl; + + if (g_testType == PluginTestType::WslcContainerRejected) + { + g_logfile << "OnWslcContainerStarted: ERROR_ACCESS_DENIED" << std::endl; + return HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED); + } + + return S_OK; +} +CATCH_RETURN(); + +HRESULT OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) +{ + g_logfile << "WSLC Container stopping, session=" << Session->SessionId << ", id=" << ContainerId << std::endl; + return S_OK; +} + +HRESULT OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) +{ + auto image = wsl::shared::FromJson(InspectJson); + auto name = (image.RepoTags.has_value() && !image.RepoTags->empty()) ? image.RepoTags->front() : ""; + g_logfile << "WSLC Image created, session=" << Session->SessionId << ", id=" << image.Id << ", name=" << name << std::endl; + return S_OK; +} + +HRESULT OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) +{ + g_logfile << "WSLC Image deleted, session=" << Session->SessionId << ", id=" << ImageId << std::endl; + return S_OK; +} + EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPluginAPIV1* Api, WSLPluginHooksV1* Hooks) { try @@ -349,7 +562,7 @@ EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPlugin THROW_HR_IF(E_UNEXPECTED, !g_logfile); g_testType = static_cast(ReadDword(key.get(), nullptr, c_testType, static_cast(PluginTestType::Invalid))); - THROW_HR_IF(E_INVALIDARG, static_cast(g_testType) <= 0 || static_cast(g_testType) > static_cast(PluginTestType::GetUsername)); + THROW_HR_IF(E_INVALIDARG, static_cast(g_testType) <= 0 || static_cast(g_testType) > static_cast(PluginTestType::WslcImagePull)); g_logfile << "Plugin loaded. TestMode=" << static_cast(g_testType) << std::endl; g_api = Api; @@ -359,6 +572,12 @@ EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPlugin Hooks->OnDistributionStopping = &OnDistroStopping; Hooks->OnDistributionRegistered = &OnDistributionRegistered; Hooks->OnDistributionUnregistered = &OnDistributionUnregistered; + Hooks->OnSessionCreated = &OnWslcSessionCreated; + Hooks->OnSessionStopping = &OnWslcSessionStopping; + Hooks->ContainerStarted = &OnWslcContainerStarted; + Hooks->ContainerStopping = &OnWslcContainerStopping; + Hooks->ImageCreated = &OnWslcImageCreated; + Hooks->ImageDeleted = &OnWslcImageDeleted; if (g_testType == PluginTestType::FailToLoad) { @@ -383,4 +602,4 @@ EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPlugin return error; } return S_OK; -} \ No newline at end of file +} diff --git a/test/windows/wslc/e2e/WSLCE2EHelpers.cpp b/test/windows/wslc/e2e/WSLCE2EHelpers.cpp index dfb844cac..efa53fb4d 100644 --- a/test/windows/wslc/e2e/WSLCE2EHelpers.cpp +++ b/test/windows/wslc/e2e/WSLCE2EHelpers.cpp @@ -440,7 +440,17 @@ wil::com_ptr OpenDefaultElevatedSession() std::pair StartLocalRegistry(IWSLCSession& session, const std::string& username, const std::string& password, USHORT port) { - EnsureImageIsLoaded({L"wslc-registry", L"latest", GetTestImagePath("wslc-registry:latest")}); + // Check if the registry image is already loaded on this session. + wil::unique_cotaskmem_array_ptr images; + THROW_IF_FAILED(session.ListImages(nullptr, &images, images.size_address())); + + bool found = std::ranges::any_of( + std::span{images.get(), images.size()}, [](const auto& e) { return std::strcmp(e.Image, "wslc-registry:latest") == 0; }); + + if (!found) + { + LoadTestImage(session, "wslc-registry:latest"); + } std::vector env = {std::format("REGISTRY_HTTP_ADDR=0.0.0.0:{}", port)};