diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in
index 6dfec147d..439779e41 100644
--- a/msipackage/package.wix.in
+++ b/msipackage/package.wix.in
@@ -353,6 +353,14 @@
+
+
+
+
+
+
+
+
diff --git a/src/windows/inc/WslPluginApi.h b/src/windows/inc/WslPluginApi.h
index 86372d326..c969c080b 100644
--- a/src/windows/inc/WslPluginApi.h
+++ b/src/windows/inc/WslPluginApi.h
@@ -26,6 +26,9 @@ extern "C" {
#define WSLPLUGINAPI_ENTRYPOINTV1 WSLPluginAPIV1_EntryPoint
#define WSL_E_PLUGIN_REQUIRES_UPDATE MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, 0x032A)
+// Maximum size for mount points returned by WSLCPluginAPI_MountFolder. This includes the null terminator.
+#define WSLC_MOUNTPOINT_LENGTH 256
+
#define WSL_PLUGIN_REQUIRE_VERSION(_Major, _Minor, _Revision, Api) \
if (Api->Version.Major < (_Major) || (Api->Version.Major == (_Major) && Api->Version.Minor < (_Minor)) || \
(Api->Version.Major == (_Major) && Api->Version.Minor == (_Minor) && Api->Version.Revision < (_Revision))) \
@@ -85,6 +88,30 @@ struct WslOfflineDistributionInformation
LPCWSTR Version; // Distribution version. Introduced in 2.4.4
};
+// Identifies a WSLC session inside the WSLC plugin API. Distinct from WSLSessionId.
+typedef DWORD WSLCSessionId;
+
+// Information about a WSLC session passed to plugin notifications.
+struct WSLCSessionInformation
+{
+ WSLCSessionId SessionId;
+ LPCWSTR DisplayName;
+ DWORD ApplicationPid;
+ HANDLE UserToken;
+ PSID UserSid;
+};
+
+// Opaque handle to a WSLC process created via WSLCPluginAPI_CreateProcess.
+// Must be released with WSLCPluginAPI_ReleaseProcess.
+typedef void* WSLCProcessHandle;
+
+typedef enum _WSLCProcessFd
+{
+ WSLCProcessFdStdin = 0,
+ WSLCProcessFdStdout = 1,
+ WSLCProcessFdStderr = 2
+} WSLCProcessFd;
+
// Create plan9 mount between Windows & Linux
typedef HRESULT (*WSLPluginAPI_MountFolder)(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name);
@@ -92,6 +119,63 @@ typedef HRESULT (*WSLPluginAPI_MountFolder)(WSLSessionId Session, LPCWSTR Window
// On success, 'Socket' is connected to stdin & stdout (stderr goes to dmesg) // 'Arguments' is expected to be NULL terminated
typedef HRESULT (*WSLPluginAPI_ExecuteBinary)(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket);
+//
+// WSLC plugin hooks.
+//
+
+// Called when a WSLC session is created. Returning an error prevents the session creation.
+typedef HRESULT (*WSLPluginAPI_OnSessionCreated)(const struct WSLCSessionInformation* Session);
+
+// Called when a WSLC session is about to stop. Errors are ignored.
+typedef HRESULT (*WSLPluginAPI_OnSessionStopping)(const struct WSLCSessionInformation* Session);
+
+// Called when a container starts. Returning an error prevents the container creation.
+// 'InspectContainer' is a JSON document that follows the wslc_schema::InspectContainer format.
+typedef HRESULT (*WSLPluginAPI_ContainerStarted)(const struct WSLCSessionInformation* Session, LPCSTR InspectContainer);
+
+// Called when a container is about to stop. 'ContainerId' is the container identifier. Errors are ignored.
+typedef HRESULT (*WSLPluginAPI_ContainerStopping)(const struct WSLCSessionInformation* Session, LPCSTR ContainerId);
+
+// Called when an image is created (either pulled, or imported). Errors are ignored.
+// 'InspectImage' is a JSON document that follows the wslc_schema::InspectImage format.
+// N.B. This callback is currently only invoked when images are pulled or imported. Images created via load or build are not reported.
+typedef HRESULT (*WSLPluginAPI_ImageCreated)(const struct WSLCSessionInformation* Session, LPCSTR InspectImage);
+
+// Called when an image is deleted. 'ImageId' is the deleted image identifier. Errors are ignored.
+typedef HRESULT (*WSLPluginAPI_ImageDeleted)(const struct WSLCSessionInformation* Session, LPCSTR ImageId);
+
+//
+// WSLC plugin API calls.
+//
+
+// Mount a Windows folder into the WSLC session VM. The mount path is returned via 'Mountpoint'.
+// 'Mountpoint' must point to a buffer of at least WSLC_MOUNTPOINT_LENGTH chars, including the null terminator.
+typedef HRESULT (*WSLCPluginAPI_MountFolder)(WSLCSessionId Session, LPCWSTR WindowsPath, BOOL ReadOnly, LPCWSTR Name, LPSTR Mountpoint);
+
+// Unmount a folder previously mounted via WSLCPluginAPI_MountFolder.
+typedef HRESULT (*WSLCPluginAPI_UnmountFolder)(WSLCSessionId Session, LPCSTR Mountpoint);
+
+// Create a process in the WSLC session's root namespace.
+// 'Arguments' and 'Env' are NULL-terminated arrays. 'Env' may be NULL.
+// 'Errno' is optional and receives the errno value if the process creation fails.
+// On success, 'Process' receives an opaque handle that must be released with WSLCPluginAPI_ReleaseProcess.
+typedef HRESULT (*WSLCPluginAPI_CreateProcess)(
+ WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno);
+
+// Get a stdio handle from a WSLC process. The caller takes ownership and must close it with CloseHandle().
+typedef HRESULT (*WSLCPluginAPI_ProcessGetFd)(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle);
+
+// Get the exit event for a WSLC process. Signaled when the process exits.
+// The caller takes ownership and must close it with CloseHandle().
+typedef HRESULT (*WSLCPluginAPI_ProcessGetExitEvent)(WSLCProcessHandle Process, HANDLE* ExitEvent);
+
+// Get the exit code of a WSLC process. The process must have exited.
+typedef HRESULT (*WSLCPluginAPI_ProcessGetExitCode)(WSLCProcessHandle Process, int* ExitCode);
+
+// Release a WSLC process handle. All outstanding handles obtained via
+// WSLCPluginAPI_ProcessGetFd/GetExitEvent must be closed before calling this.
+typedef void (*WSLCPluginAPI_ReleaseProcess)(WSLCProcessHandle Process);
+
// Execute a program in a user distribution
// On success, 'Socket' is connected to stdin & stdout (stderr goes to dmesg) // 'Arguments' is expected to be NULL terminated
typedef HRESULT (*WSLPluginAPI_ExecuteBinaryInDistribution)(WSLSessionId Session, const GUID* Distribution, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket);
@@ -132,6 +216,14 @@ struct WSLPluginHooksV1
WSLPluginAPI_OnDistributionStopping OnDistributionStopping;
WSLPluginAPI_OnDistributionRegistered OnDistributionRegistered; // Introduced in 2.1.2
WSLPluginAPI_OnDistributionRegistered OnDistributionUnregistered; // Introduced in 2.1.2
+
+ // WSLC hooks. Plugins compiled against older headers leave these zero-initialized.
+ WSLPluginAPI_OnSessionCreated OnSessionCreated;
+ WSLPluginAPI_OnSessionStopping OnSessionStopping;
+ WSLPluginAPI_ContainerStarted ContainerStarted;
+ WSLPluginAPI_ContainerStopping ContainerStopping;
+ WSLPluginAPI_ImageCreated ImageCreated;
+ WSLPluginAPI_ImageDeleted ImageDeleted;
};
struct WSLPluginAPIV1
@@ -141,10 +233,19 @@ struct WSLPluginAPIV1
WSLPluginAPI_ExecuteBinary ExecuteBinary;
WSLPluginAPI_PluginError PluginError;
WSLPluginAPI_ExecuteBinaryInDistribution ExecuteBinaryInDistribution; // Introduced in 2.1.2
+
+ // WSLC API calls.
+ WSLCPluginAPI_MountFolder WSLCMountFolder; // Introduced in 2.9.0
+ WSLCPluginAPI_UnmountFolder WSLCUnmountFolder; // Introduced in 2.9.0
+ WSLCPluginAPI_CreateProcess WSLCCreateProcess; // Introduced in 2.9.0
+ WSLCPluginAPI_ProcessGetFd WSLCProcessGetFd; // Introduced in 2.9.0
+ WSLCPluginAPI_ProcessGetExitEvent WSLCProcessGetExitEvent; // Introduced in 2.9.0
+ WSLCPluginAPI_ProcessGetExitCode WSLCProcessGetExitCode; // Introduced in 2.9.0
+ WSLCPluginAPI_ReleaseProcess WSLCReleaseProcess; // Introduced in 2.9.0
};
typedef HRESULT (*WSLPluginAPI_EntryPointV1)(const struct WSLPluginAPIV1* Api, struct WSLPluginHooksV1* Hooks);
#ifdef __cplusplus
}
-#endif
\ No newline at end of file
+#endif
diff --git a/src/windows/service/exe/CMakeLists.txt b/src/windows/service/exe/CMakeLists.txt
index 7482eacf9..c7b4ebb0d 100644
--- a/src/windows/service/exe/CMakeLists.txt
+++ b/src/windows/service/exe/CMakeLists.txt
@@ -23,6 +23,7 @@ set(SOURCES
HcsVirtualMachine.cpp
WSLCSessionManager.cpp
WSLCSessionManagerFactory.cpp
+ WSLCPluginNotifier.cpp
main.rc
${CMAKE_CURRENT_BINARY_DIR}/../mc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wsleventschema.rc
application.manifest)
@@ -53,7 +54,8 @@ set(HEADERS
WslCoreVm.h
HcsVirtualMachine.h
WSLCSessionManager.h
- WSLCSessionManagerFactory.h)
+ WSLCSessionManagerFactory.h
+ WSLCPluginNotifier.h)
add_executable(wslservice ${SOURCES} ${HEADERS})
add_dependencies(wslservice wslserviceidl wslservicemc)
diff --git a/src/windows/service/exe/LxssUserSessionFactory.cpp b/src/windows/service/exe/LxssUserSessionFactory.cpp
index ae40cec10..33e1fe7dc 100644
--- a/src/windows/service/exe/LxssUserSessionFactory.cpp
+++ b/src/windows/service/exe/LxssUserSessionFactory.cpp
@@ -30,7 +30,7 @@ srwlock g_sessionLock;
std::optional>> g_sessions =
std::make_optional>>();
-std::optional g_pluginManager;
+extern wsl::windows::service::PluginManager g_pluginManager;
extern unique_event g_networkingReady;
extern bool g_lxcoreInitialized;
@@ -53,9 +53,6 @@ void ClearSessionsAndBlockNewInstancesLockHeld(std::optional>>();
}
-
- if (!g_pluginManager.has_value())
- {
- g_pluginManager.emplace();
- g_pluginManager->LoadPlugins();
- }
}
else
{
@@ -236,7 +227,7 @@ std::weak_ptr CreateInstanceForCurrentUser()
if (!userSession)
{
- userSession.reset(new LxssUserSessionImpl(tokenInfo->User.Sid, sessionId, *g_pluginManager));
+ userSession.reset(new LxssUserSessionImpl(tokenInfo->User.Sid, sessionId, g_pluginManager));
g_sessions->emplace_back(userSession);
}
}
diff --git a/src/windows/service/exe/PluginManager.cpp b/src/windows/service/exe/PluginManager.cpp
index e4d23f226..38c1bfbdf 100644
--- a/src/windows/service/exe/PluginManager.cpp
+++ b/src/windows/service/exe/PluginManager.cpp
@@ -17,6 +17,7 @@ Module Name:
#include "PluginManager.h"
#include "WslPluginApi.h"
#include "LxssUserSessionFactory.h"
+#include "WSLCSessionManager.h"
using wsl::windows::common::Context;
using wsl::windows::common::ExecutionContext;
@@ -99,7 +100,264 @@ try
CATCH_RETURN();
}
-static constexpr WSLPluginAPIV1 ApiV1 = {Version, &MountFolder, &ExecuteBinary, &PluginError, &ExecuteBinaryInDistribution};
+namespace {
+
+// Opaque wrapper around IWSLCProcess, handed out as WSLCProcessHandle to plugins.
+struct WslcProcessWrapper
+{
+ wil::com_ptr Process;
+};
+
+wil::com_ptr ResolveWslcSession(WSLCSessionId Session)
+{
+ auto* mgr = wsl::windows::service::wslc::WSLCSessionManagerImpl::Instance();
+ THROW_HR_IF(RPC_E_DISCONNECTED, mgr == nullptr);
+
+ return mgr->FindSession(static_cast(Session));
+}
+
+} // namespace
+
+extern "C" {
+
+HRESULT WSLCMountFolder(WSLCSessionId Session, LPCWSTR WindowsPath, BOOL ReadOnly, LPCWSTR Name, LPSTR Mountpoint)
+try
+{
+ RETURN_HR_IF(E_POINTER, WindowsPath == nullptr || Name == nullptr || Mountpoint == nullptr);
+ auto nameLength = wcslen(Name);
+
+ RETURN_HR_IF_MSG(
+ E_INVALIDARG,
+ nameLength == 0 ||
+ !std::ranges::all_of(Name, Name + nameLength, [&](auto c) { return c == '-' || c == '_' || iswalnum(c); }),
+ "Invalid mount name: %ls",
+ Name);
+
+ auto session = ResolveWslcSession(Session);
+
+ // Mount the folder under /mnt/wsl-plugin/. Convert Name to UTF-8 for the Linux path.
+ const auto linuxPath = std::format("/mnt/wsl-plugin/{}", Name);
+
+ THROW_HR_IF_MSG(E_INVALIDARG, linuxPath.length() >= WSLC_MOUNTPOINT_LENGTH, "Mountpoint too long: %hs", linuxPath.c_str());
+
+ auto result = session->MountWindowsFolder(WindowsPath, linuxPath.c_str(), ReadOnly);
+
+ WSL_LOG(
+ "WslcPluginMountFolderCall",
+ TraceLoggingValue(Session, "SessionId"),
+ TraceLoggingValue(WindowsPath, "WindowsPath"),
+ TraceLoggingValue(linuxPath.c_str(), "LinuxPath"),
+ TraceLoggingValue(ReadOnly, "ReadOnly"),
+ TraceLoggingValue(Name, "Name"),
+ TraceLoggingValue(result, "Result"));
+
+ if (SUCCEEDED(result))
+ {
+ THROW_HR_IF(E_UNEXPECTED, strcpy_s(Mountpoint, WSLC_MOUNTPOINT_LENGTH, linuxPath.c_str()) != 0);
+ }
+
+ return result;
+}
+CATCH_RETURN();
+
+HRESULT WSLCUnmountFolder(WSLCSessionId Session, LPCSTR Mountpoint)
+try
+{
+ RETURN_HR_IF(E_POINTER, Mountpoint == nullptr);
+
+ auto session = ResolveWslcSession(Session);
+
+ auto result = session->UnmountWindowsFolder(Mountpoint);
+
+ WSL_LOG(
+ "WslcPluginUnmountFolderCall",
+ TraceLoggingValue(Session, "SessionId"),
+ TraceLoggingValue(Mountpoint, "Mountpoint"),
+ TraceLoggingValue(result, "Result"));
+
+ return result;
+}
+CATCH_RETURN();
+
+HRESULT WSLCCreateProcess(WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno)
+try
+{
+ RETURN_HR_IF(E_POINTER, Executable == nullptr || Process == nullptr);
+
+ *Process = nullptr;
+ if (Errno != nullptr)
+ {
+ *Errno = 0;
+ }
+
+ auto session = ResolveWslcSession(Session);
+
+ // Count NULL-terminated arrays.
+ auto countArray = [](LPCSTR* arr) -> ULONG {
+ if (arr == nullptr)
+ {
+ return 0;
+ }
+ ULONG count = 0;
+ while (arr[count] != nullptr)
+ {
+ ++count;
+ }
+ return count;
+ };
+
+ WSLCProcessOptions options{};
+ options.CommandLine.Values = Arguments;
+ options.CommandLine.Count = countArray(Arguments);
+ options.Environment.Values = Env;
+ options.Environment.Count = countArray(Env);
+ options.Flags = WSLCProcessFlagsStdin;
+
+ wil::com_ptr process;
+ int errnoValue = 0;
+ auto result = session->CreateRootNamespaceProcess(Executable, &options, &process, &errnoValue);
+
+ if (Errno != nullptr)
+ {
+ *Errno = errnoValue;
+ }
+
+ if (FAILED(result))
+ {
+ WSL_LOG(
+ "WslcPluginCreateProcessCall",
+ TraceLoggingValue(Session, "SessionId"),
+ TraceLoggingValue(Executable, "Executable"),
+ TraceLoggingValue(result, "Result"),
+ TraceLoggingValue(errnoValue, "Errno"));
+ return result;
+ }
+
+ auto wrapper = std::make_unique();
+ wrapper->Process = std::move(process);
+ *Process = wrapper.release();
+
+ WSL_LOG(
+ "WslcPluginCreateProcessCall",
+ TraceLoggingValue(Session, "SessionId"),
+ TraceLoggingValue(Executable, "Executable"),
+ TraceLoggingValue(*Process, "Process"),
+ TraceLoggingValue(S_OK, "Result"));
+
+ return S_OK;
+}
+CATCH_RETURN();
+
+HRESULT WSLCProcessGetFd(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle)
+try
+{
+ RETURN_HR_IF(E_POINTER, Process == nullptr || Handle == nullptr);
+
+ *Handle = nullptr;
+
+ auto* wrapper = static_cast(Process);
+
+ WSLCFD wslcFd{};
+ switch (Fd)
+ {
+ case WSLCProcessFdStdin:
+ wslcFd = WSLCFDStdin;
+ break;
+ case WSLCProcessFdStdout:
+ wslcFd = WSLCFDStdout;
+ break;
+ case WSLCProcessFdStderr:
+ wslcFd = WSLCFDStderr;
+ break;
+ default:
+ WSL_LOG(
+ "WslcPluginProcessGetFd", TraceLoggingValue(static_cast(Fd), "Fd"), TraceLoggingValue(E_INVALIDARG, "Result"));
+ return E_INVALIDARG;
+ }
+
+ WSLCHandle handle{};
+ auto result = wrapper->Process->GetStdHandle(wslcFd, &handle);
+
+ WSL_LOG(
+ "WslcPluginProcessGetFd",
+ TraceLoggingValue(static_cast(Fd), "Fd"),
+ TraceLoggingValue(handle.Handle.Socket, "Handle"),
+ TraceLoggingValue(result, "Result"));
+
+ RETURN_IF_FAILED(result);
+ WI_ASSERT(handle.Type == WSLCHandleTypeSocket);
+
+ *Handle = handle.Handle.Socket;
+ return S_OK;
+}
+CATCH_RETURN();
+
+HRESULT WSLCProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent)
+try
+{
+ RETURN_HR_IF(E_POINTER, Process == nullptr || ExitEvent == nullptr);
+
+ *ExitEvent = nullptr;
+
+ auto* wrapper = static_cast(Process);
+ auto result = wrapper->Process->GetExitEvent(ExitEvent);
+
+ WSL_LOG("WslcPluginProcessGetExitEvent", TraceLoggingValue(*ExitEvent, "ExitEvent"), TraceLoggingValue(result, "Result"));
+
+ return result;
+}
+CATCH_RETURN();
+
+HRESULT WSLCProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode)
+try
+{
+ RETURN_HR_IF(E_POINTER, Process == nullptr || ExitCode == nullptr);
+
+ *ExitCode = -1;
+ auto* wrapper = static_cast(Process);
+
+ WSLCProcessState state{};
+ auto result = wrapper->Process->GetState(&state, ExitCode);
+
+ if (SUCCEEDED(result) && state != WslcProcessStateExited && state != WslcProcessStateSignalled)
+ {
+ result = HRESULT_FROM_WIN32(ERROR_INVALID_STATE);
+ }
+
+ WSL_LOG(
+ "WslcPluginProcessGetExitCode",
+ TraceLoggingValue(*ExitCode, "ExitCode"),
+ TraceLoggingValue(static_cast(state), "State"),
+ TraceLoggingValue(result, "Result"));
+
+ return result;
+}
+CATCH_RETURN();
+
+void WSLCReleaseProcess(WSLCProcessHandle Process)
+{
+ if (Process != nullptr)
+ {
+ WSL_LOG("WslcPluginReleaseProcess", TraceLoggingValue(Process, "Process"));
+ delete static_cast(Process);
+ }
+}
+
+} // extern "C"
+
+static constexpr WSLPluginAPIV1 ApiV1 = {
+ Version,
+ &MountFolder,
+ &ExecuteBinary,
+ &PluginError,
+ &ExecuteBinaryInDistribution,
+ &WSLCMountFolder,
+ &WSLCUnmountFolder,
+ &WSLCCreateProcess,
+ &WSLCProcessGetFd,
+ &WSLCProcessGetExitEvent,
+ &WSLCProcessGetExitCode,
+ &WSLCReleaseProcess};
void PluginManager::LoadPlugins()
{
@@ -181,7 +439,7 @@ void PluginManager::OnVmStarted(const WSLSessionInformation* Session, const WSLV
WSL_LOG(
"PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid"));
- ThrowIfPluginError(e.hooks.OnVMStarted(Session, Settings), Session->SessionId, e.name.c_str());
+ ThrowIfPluginError(e.hooks.OnVMStarted(Session, Settings), e.name.c_str());
}
}
}
@@ -217,7 +475,7 @@ void PluginManager::OnDistributionStarted(const WSLSessionInformation* Session,
TraceLoggingValue(Session->UserSid, "Sid"),
TraceLoggingValue(Distribution->Id, "DistributionId"));
- ThrowIfPluginError(e.hooks.OnDistributionStarted(Session, Distribution), Session->SessionId, e.name.c_str());
+ ThrowIfPluginError(e.hooks.OnDistributionStarted(Session, Distribution), e.name.c_str());
}
}
}
@@ -282,7 +540,7 @@ void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Sess
}
}
-void PluginManager::ThrowIfPluginError(HRESULT Result, WSLSessionId Session, LPCWSTR Plugin)
+void PluginManager::ThrowIfPluginError(HRESULT Result, LPCWSTR Plugin)
{
const auto message = std::move(g_pluginErrorMessage);
g_pluginErrorMessage.reset(); // std::move() doesn't clear the previous std::optional
@@ -322,3 +580,130 @@ void PluginManager::ThrowIfFatalPluginError() const
THROW_HR_WITH_USER_ERROR(m_pluginError->error, wsl::shared::Localization::MessageFatalPluginError(m_pluginError->plugin));
}
}
+
+void PluginManager::OnWslcSessionCreated(const WSLCSessionInformation* Session)
+{
+ ExecutionContext context(Context::Plugin);
+
+ for (const auto& e : m_plugins)
+ {
+ if (e.hooks.OnSessionCreated != nullptr)
+ {
+ auto result = e.hooks.OnSessionCreated(Session);
+ WSL_LOG(
+ "PluginOnWslcSessionCreatedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(Session->DisplayName, "DisplayName"),
+ TraceLoggingValue(result, "Result"));
+
+ ThrowIfPluginError(result, e.name.c_str());
+ }
+ }
+}
+
+void PluginManager::OnWslcSessionStopping(const WSLCSessionInformation* Session) const
+{
+ ExecutionContext context(Context::Plugin);
+
+ for (const auto& e : m_plugins)
+ {
+ if (e.hooks.OnSessionStopping != nullptr)
+ {
+ const auto result = e.hooks.OnSessionStopping(Session);
+ WSL_LOG(
+ "PluginOnWslcSessionStoppingCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(result, "Result"));
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ }
+ }
+}
+
+HRESULT PluginManager::OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) const
+try
+{
+ ExecutionContext context(Context::Plugin);
+
+ for (const auto& e : m_plugins)
+ {
+ if (e.hooks.ContainerStarted != nullptr)
+ {
+ // Failure here aborts the container creation. Surface the first error.
+ const auto result = e.hooks.ContainerStarted(Session, InspectJson);
+ WSL_LOG(
+ "PluginOnWslcContainerStartedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(result, "Result"));
+
+ ThrowIfPluginError(result, e.name.c_str());
+ }
+ }
+ return S_OK;
+}
+CATCH_RETURN()
+
+void PluginManager::OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) const
+{
+ ExecutionContext context(Context::Plugin);
+
+ for (const auto& e : m_plugins)
+ {
+ if (e.hooks.ContainerStopping != nullptr)
+ {
+
+ const auto result = e.hooks.ContainerStopping(Session, ContainerId);
+ WSL_LOG(
+ "PluginOnWslcContainerStoppingCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(ContainerId, "ContainerId"),
+ TraceLoggingValue(result, "Result"));
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ }
+ }
+}
+
+void PluginManager::OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) const
+{
+ ExecutionContext context(Context::Plugin);
+
+ for (const auto& e : m_plugins)
+ {
+ if (e.hooks.ImageCreated != nullptr)
+ {
+ const auto result = e.hooks.ImageCreated(Session, InspectJson);
+ WSL_LOG(
+ "PluginOnWslcImageCreatedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(result, "Result"));
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ }
+ }
+}
+
+void PluginManager::OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) const
+{
+ ExecutionContext context(Context::Plugin);
+
+ for (const auto& e : m_plugins)
+ {
+ if (e.hooks.ImageDeleted != nullptr)
+ {
+ const auto result = e.hooks.ImageDeleted(Session, ImageId);
+ WSL_LOG(
+ "PluginOnWslcImageDeletedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(ImageId, "ImageId"),
+ TraceLoggingValue(result, "Result"));
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ }
+ }
+}
diff --git a/src/windows/service/exe/PluginManager.h b/src/windows/service/exe/PluginManager.h
index 746a11fa4..a99a3e332 100644
--- a/src/windows/service/exe/PluginManager.h
+++ b/src/windows/service/exe/PluginManager.h
@@ -43,11 +43,21 @@ class PluginManager
void OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* distro) const;
void OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const;
void OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const;
+
+ // WSLC notifications. Returning failure from OnSessionCreated/OnContainerStarted causes the
+ // corresponding operation to be aborted. Other notifications log errors and continue.
+ void OnWslcSessionCreated(const WSLCSessionInformation* Session);
+ void OnWslcSessionStopping(const WSLCSessionInformation* Session) const;
+ HRESULT OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) const;
+ void OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) const;
+ void OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) const;
+ void OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) const;
+
void ThrowIfFatalPluginError() const;
private:
void LoadPlugin(LPCWSTR Name, LPCWSTR Path);
- static void ThrowIfPluginError(HRESULT Result, WSLSessionId session, LPCWSTR Plugin);
+ static void ThrowIfPluginError(HRESULT Result, LPCWSTR Plugin);
struct LoadedPlugin
{
@@ -60,4 +70,4 @@ class PluginManager
std::optional m_pluginError;
};
-} // namespace wsl::windows::service
\ No newline at end of file
+} // namespace wsl::windows::service
diff --git a/src/windows/service/exe/ServiceMain.cpp b/src/windows/service/exe/ServiceMain.cpp
index 059f7d599..9cad83bf0 100644
--- a/src/windows/service/exe/ServiceMain.cpp
+++ b/src/windows/service/exe/ServiceMain.cpp
@@ -29,6 +29,8 @@ using namespace wsl::windows::policies;
bool g_lxcoreInitialized{false};
wil::unique_event g_networkingReady{wil::EventOptions::ManualReset};
+wsl::windows::service::PluginManager g_pluginManager;
+
// Declare the LxssUserSession COM class.
CoCreatableClassWrlCreatorMapInclude(LxssUserSession);
@@ -178,6 +180,9 @@ try
WSADATA Data;
THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &Data));
+ // Load plugins.
+ g_pluginManager.LoadPlugins();
+
// Check if WSL is disabled via policy and set up a registry watcher to watch for changes.
//
// N.B. The registry watcher must be created before checking the policy to avoid missing notifications.
diff --git a/src/windows/service/exe/WSLCPluginNotifier.cpp b/src/windows/service/exe/WSLCPluginNotifier.cpp
new file mode 100644
index 000000000..7daab35bb
--- /dev/null
+++ b/src/windows/service/exe/WSLCPluginNotifier.cpp
@@ -0,0 +1,66 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#include "precomp.h"
+#include "WSLCPluginNotifier.h"
+
+using wsl::windows::common::COMServiceExecutionContext;
+using wsl::windows::service::wslc::WSLCPluginNotifier;
+
+WSLCPluginNotifier::WSLCPluginNotifier(
+ wsl::windows::service::PluginManager& Plugins,
+ ULONG SessionId,
+ DWORD CreatorPid,
+ std::wstring DisplayName,
+ wil::shared_handle UserToken,
+ std::vector&& UserSid) :
+ m_plugins(Plugins), m_displayName(std::move(DisplayName)), m_userToken(std::move(UserToken)), m_userSid(std::move(UserSid))
+{
+ m_sessionInfo.SessionId = static_cast(SessionId);
+ m_sessionInfo.DisplayName = m_displayName.c_str();
+ m_sessionInfo.ApplicationPid = CreatorPid;
+ m_sessionInfo.UserToken = m_userToken.get();
+ m_sessionInfo.UserSid = m_userSid.empty() ? nullptr : reinterpret_cast(m_userSid.data());
+}
+
+HRESULT WSLCPluginNotifier::OnContainerStarted(LPCSTR InspectJson)
+try
+{
+ COMServiceExecutionContext context;
+
+ RETURN_HR_IF(E_POINTER, InspectJson == nullptr);
+ return m_plugins.OnWslcContainerStarted(&m_sessionInfo, InspectJson);
+}
+CATCH_RETURN();
+
+HRESULT WSLCPluginNotifier::OnContainerStopping(LPCSTR ContainerId)
+try
+{
+ COMServiceExecutionContext context;
+
+ RETURN_HR_IF(E_POINTER, ContainerId == nullptr);
+ m_plugins.OnWslcContainerStopping(&m_sessionInfo, ContainerId);
+ return S_OK;
+}
+CATCH_RETURN();
+
+HRESULT WSLCPluginNotifier::OnImageCreated(LPCSTR InspectJson)
+try
+{
+ COMServiceExecutionContext context;
+
+ RETURN_HR_IF(E_POINTER, InspectJson == nullptr);
+ m_plugins.OnWslcImageCreated(&m_sessionInfo, InspectJson);
+ return S_OK;
+}
+CATCH_RETURN();
+
+HRESULT WSLCPluginNotifier::OnImageDeleted(LPCSTR ImageId)
+try
+{
+ COMServiceExecutionContext context;
+
+ RETURN_HR_IF(E_POINTER, ImageId == nullptr);
+ m_plugins.OnWslcImageDeleted(&m_sessionInfo, ImageId);
+ return S_OK;
+}
+CATCH_RETURN();
diff --git a/src/windows/service/exe/WSLCPluginNotifier.h b/src/windows/service/exe/WSLCPluginNotifier.h
new file mode 100644
index 000000000..b8ba4e11e
--- /dev/null
+++ b/src/windows/service/exe/WSLCPluginNotifier.h
@@ -0,0 +1,46 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#pragma once
+
+#include "wslc.h"
+#include "PluginManager.h"
+#include
+#include
+
+namespace wsl::windows::service::wslc {
+
+//
+// WSLCPluginNotifier - SYSTEM service implementation of IWSLCPluginNotifier.
+// Lives in the SYSTEM service and is passed (via COM marshalling) as a top-level
+// parameter to the per-user WSLC session process via IWSLCSessionFactory::CreateSession.
+// The per-user process invokes the On* methods, which dispatch to PluginManager.
+//
+class DECLSPEC_UUID("E29B0F1A-4E18-4F09-83A2-2D6B1B9F8C4D") WSLCPluginNotifier
+ : public Microsoft::WRL::RuntimeClass, IWSLCPluginNotifier, IFastRundown>
+{
+public:
+ NON_COPYABLE(WSLCPluginNotifier);
+ NON_MOVABLE(WSLCPluginNotifier);
+
+ WSLCPluginNotifier(
+ wsl::windows::service::PluginManager& Plugins,
+ ULONG SessionId,
+ DWORD CreatorPid,
+ std::wstring DisplayName,
+ wil::shared_handle UserToken,
+ std::vector&& UserSid);
+
+ IFACEMETHOD(OnContainerStarted)(_In_ LPCSTR InspectJson) override;
+ IFACEMETHOD(OnContainerStopping)(_In_ LPCSTR ContainerId) override;
+ IFACEMETHOD(OnImageCreated)(_In_ LPCSTR InspectJson) override;
+ IFACEMETHOD(OnImageDeleted)(_In_ LPCSTR ImageId) override;
+
+private:
+ wsl::windows::service::PluginManager& m_plugins;
+ std::wstring m_displayName;
+ wil::shared_handle m_userToken;
+ std::vector m_userSid;
+ WSLCSessionInformation m_sessionInfo{};
+};
+
+} // namespace wsl::windows::service::wslc
diff --git a/src/windows/service/exe/WSLCSessionManager.cpp b/src/windows/service/exe/WSLCSessionManager.cpp
index 318626cbe..fc3543882 100644
--- a/src/windows/service/exe/WSLCSessionManager.cpp
+++ b/src/windows/service/exe/WSLCSessionManager.cpp
@@ -31,17 +31,26 @@ Module Name:
#include "HcsVirtualMachine.h"
#include "WSLCUserSettings.h"
#include "WSLCSessionDefaults.h"
+#include "WSLCPluginNotifier.h"
+#include "PluginManager.h"
+#include "ExecutionContext.h"
#include "wslutil.h"
#include "filesystem.hpp"
+extern wsl::windows::service::PluginManager g_pluginManager;
+
+using wsl::windows::common::COMServiceExecutionContext;
using wsl::windows::service::wslc::CallingProcessTokenInfo;
using wsl::windows::service::wslc::HcsVirtualMachine;
+using wsl::windows::service::wslc::WSLCPluginNotifier;
using wsl::windows::service::wslc::WSLCSessionManagerImpl;
namespace wslutil = wsl::windows::common::wslutil;
namespace settings = wsl::windows::wslc::settings;
namespace {
+std::atomic g_managerInstance{nullptr};
+
// Session settings built server-side from the caller's settings.yaml.
struct SessionSettings
{
@@ -114,8 +123,15 @@ struct SessionSettings
} // namespace
+WSLCSessionManagerImpl::WSLCSessionManagerImpl()
+{
+ g_managerInstance.store(this);
+}
+
WSLCSessionManagerImpl::~WSLCSessionManagerImpl()
{
+ g_managerInstance.store(nullptr);
+
// Terminate all sessions on shutdown.
// Call Terminate() directly rather than going through ForEachSession(),
// which would needlessly resolve weak references and call GetState().
@@ -123,10 +139,30 @@ WSLCSessionManagerImpl::~WSLCSessionManagerImpl()
std::lock_guard lock(m_wslcSessionsLock);
for (auto& entry : m_sessions)
{
+ NotifySessionStoppingLockHeld(entry);
LOG_IF_FAILED(entry.Ref->Terminate());
}
}
+void WSLCSessionManagerImpl::NotifySessionStoppingLockHeld(SessionEntry& entry) noexcept
+try
+{
+ if (entry.StoppingNotified)
+ {
+ return;
+ }
+
+ entry.StoppingNotified = true;
+ WSLCSessionInformation info{};
+ info.SessionId = static_cast(entry.SessionId);
+ info.DisplayName = entry.DisplayName.c_str();
+ info.ApplicationPid = entry.CreatorPid;
+ info.UserToken = entry.UserToken.get();
+ info.UserSid = entry.UserSid.data();
+ g_pluginManager.OnWslcSessionStopping(&info);
+}
+CATCH_LOG()
+
void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, WSLCSessionFlags Flags, IWSLCSession** WslcSession)
{
auto tokenInfo = GetCallingProcessTokenInfo();
@@ -144,7 +180,7 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings,
{
THROW_HR_IF(WSLC_E_INVALID_SESSION_NAME, Settings->DisplayName == nullptr || wcslen(Settings->DisplayName) == 0);
THROW_HR_IF(E_INVALIDARG, Settings->StoragePath != nullptr && wcslen(Settings->StoragePath) == 0);
- THROW_HR_IF(WSLC_E_INVALID_SESSION_NAME, wcslen(Settings->DisplayName) >= std::size(WSLCSessionInformation{}.DisplayName));
+ THROW_HR_IF(WSLC_E_INVALID_SESSION_NAME, wcslen(Settings->DisplayName) >= std::size(WSLCSessionListEntry{}.DisplayName));
THROW_HR_IF_MSG(
E_INVALIDARG,
WI_IsAnyFlagSet(Settings->StorageFlags, ~WSLCSessionStorageFlagsValid),
@@ -208,6 +244,23 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings,
const auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);
+ // Capture a duplicated user token + raw SID so PluginManager can build
+ // WSLCSessionInformation later (e.g. on shutdown) without re-impersonating.
+ // The token is shared between the SessionEntry and the WSLCPluginNotifier.
+ wil::unique_handle dupToken;
+ THROW_IF_WIN32_BOOL_FALSE(DuplicateTokenEx(
+ userToken.get(), TOKEN_QUERY | TOKEN_DUPLICATE, nullptr, SecurityImpersonation, TokenImpersonation, &dupToken));
+ wil::shared_handle sharedToken{dupToken.release()};
+
+ const DWORD sidLen = GetLengthSid(tokenInfo.TokenInfo->User.Sid);
+ std::vector storedSid(sidLen);
+ THROW_IF_WIN32_BOOL_FALSE(CopySid(sidLen, storedSid.data(), tokenInfo.TokenInfo->User.Sid));
+
+ // Build the plugin notifier service-side. Lifetime tracked via the SessionEntry.
+ Microsoft::WRL::ComPtr notifier;
+ notifier = wil::MakeOrThrow(
+ g_pluginManager, sessionId, creatorPid, std::wstring(resolvedDisplayName), wil::shared_handle(sharedToken), std::vector(storedSid));
+
// Create the VM in the SYSTEM service (privileged).
auto vm = Microsoft::WRL::Make(Settings);
@@ -215,14 +268,14 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings,
auto factory = wslutil::CreateComServerAsUser(__uuidof(WSLCSessionFactory), userToken.get());
AddSessionProcessToJobObject(factory.get());
- // Create the session via the factory.
- const auto sessionSettings = CreateSessionSettings(sessionId, creatorPid, Settings, resolvedDisplayName.c_str());
+ auto sessionSettings = CreateSessionSettings(sessionId, creatorPid, Settings, resolvedDisplayName.c_str());
wil::com_ptr session;
wil::com_ptr serviceRef;
- THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), &session, &serviceRef));
+ THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), notifier.Get(), &session, &serviceRef));
// Track the session via its service ref, along with metadata and security info.
- m_sessions.push_back({std::move(serviceRef), sessionId, creatorPid, resolvedDisplayName, std::move(tokenInfo)});
+ m_sessions.push_back(SessionEntry{
+ std::move(serviceRef), sessionId, creatorPid, resolvedDisplayName, std::move(tokenInfo), notifier, false, sharedToken, std::move(storedSid)});
// For persistent sessions, also hold a strong reference to keep them alive.
const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent);
@@ -231,6 +284,33 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings,
m_persistentSessions.emplace_back(sessionId, session);
}
+ // Notify plugins that the session was created. A failure here aborts session creation.
+ try
+ {
+ auto& entry = m_sessions.back();
+ WSLCSessionInformation info{};
+ info.SessionId = static_cast(entry.SessionId);
+ info.DisplayName = entry.DisplayName.c_str();
+ info.ApplicationPid = entry.CreatorPid;
+ info.UserToken = entry.UserToken.get();
+ info.UserSid = entry.UserSid.data();
+ g_pluginManager.OnWslcSessionCreated(&info);
+ }
+ catch (...)
+ {
+ const auto error = wil::ResultFromCaughtException();
+
+ // Plugin rejected the session: tear it down before propagating.
+ m_sessions.back().StoppingNotified = true; // Don't fire stopping for a session that never started successfully.
+ LOG_IF_FAILED(m_sessions.back().Ref->Terminate());
+ m_sessions.pop_back();
+
+ auto remove = std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return e.first == sessionId; });
+ m_persistentSessions.erase(remove.begin(), remove.end());
+
+ THROW_HR(error);
+ }
+
*WslcSession = session.detach();
});
@@ -297,9 +377,9 @@ void WSLCSessionManagerImpl::OpenSessionByName(LPCWSTR DisplayName, IWSLCSession
THROW_IF_FAILED_MSG(result.value_or(HRESULT_FROM_WIN32(ERROR_NOT_FOUND)), "Session '%ls' not found", DisplayName);
}
-void WSLCSessionManagerImpl::ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount)
+void WSLCSessionManagerImpl::ListSessions(_Out_ WSLCSessionListEntry** Sessions, _Out_ ULONG* SessionsCount)
{
- std::vector sessionInfo;
+ std::vector sessionInfo;
ForEachSession([&](auto& entry, const auto&) noexcept {
try
@@ -307,15 +387,15 @@ void WSLCSessionManagerImpl::ListSessions(_Out_ WSLCSessionInformation** Session
wil::unique_hlocal_string sidString;
THROW_IF_WIN32_BOOL_FALSE(ConvertSidToStringSidW(entry.Owner.TokenInfo->User.Sid, &sidString));
- auto& it = sessionInfo.emplace_back(WSLCSessionInformation{.SessionId = entry.SessionId, .CreatorPid = entry.CreatorPid});
+ auto& it = sessionInfo.emplace_back(WSLCSessionListEntry{.SessionId = entry.SessionId, .CreatorPid = entry.CreatorPid});
wcscpy_s(it.Sid, _countof(it.Sid), sidString.get());
wcscpy_s(it.DisplayName, _countof(it.DisplayName), entry.DisplayName.c_str());
}
CATCH_LOG()
});
- auto output = wil::make_unique_cotaskmem(sessionInfo.size());
- memcpy(output.get(), sessionInfo.data(), sessionInfo.size() * sizeof(WSLCSessionInformation));
+ auto output = wil::make_unique_cotaskmem(sessionInfo.size());
+ memcpy(output.get(), sessionInfo.data(), sessionInfo.size() * sizeof(WSLCSessionListEntry));
*Sessions = output.release();
*SessionsCount = static_cast(sessionInfo.size());
@@ -478,26 +558,65 @@ try
CATCH_RETURN();
HRESULT WSLCSessionManager::CreateSession(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession)
+try
{
+ COMServiceExecutionContext context;
+
return CallImpl(&WSLCSessionManagerImpl::CreateSession, WslcSessionSettings, Flags, WslcSession);
}
+CATCH_RETURN();
HRESULT WSLCSessionManager::EnterSession(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession)
{
+ COMServiceExecutionContext context;
+
return CallImpl(&WSLCSessionManagerImpl::EnterSession, DisplayName, StoragePath, WslcSession);
}
-HRESULT WSLCSessionManager::ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount)
+HRESULT WSLCSessionManager::ListSessions(_Out_ WSLCSessionListEntry** Sessions, _Out_ ULONG* SessionsCount)
{
+ COMServiceExecutionContext context;
+
return CallImpl(&WSLCSessionManagerImpl::ListSessions, Sessions, SessionsCount);
}
HRESULT WSLCSessionManager::OpenSession(_In_ ULONG Id, _Out_ IWSLCSession** Session)
{
+ COMServiceExecutionContext context;
+
return CallImpl(&WSLCSessionManagerImpl::OpenSession, Id, Session);
}
HRESULT WSLCSessionManager::OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session)
{
+ COMServiceExecutionContext context;
+
return CallImpl(&WSLCSessionManagerImpl::OpenSessionByName, DisplayName, Session);
}
+
+namespace wsl::windows::service::wslc {
+
+WSLCSessionManagerImpl* WSLCSessionManagerImpl::Instance() noexcept
+{
+ return g_managerInstance.load();
+}
+
+wil::com_ptr WSLCSessionManagerImpl::FindSession(ULONG Id)
+{
+ wil::com_ptr result;
+
+ ForEachSession([&](SessionEntry& entry, const wil::com_ptr& session) noexcept -> std::optional {
+ if (entry.SessionId != Id)
+ {
+ return std::nullopt;
+ }
+
+ result = session;
+ return S_OK;
+ });
+
+ THROW_HR_IF_MSG(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), !result, "WSLC session %lu not found", Id);
+ return result;
+}
+
+} // namespace wsl::windows::service::wslc
diff --git a/src/windows/service/exe/WSLCSessionManager.h b/src/windows/service/exe/WSLCSessionManager.h
index 3ef386cdd..ce5d31eab 100644
--- a/src/windows/service/exe/WSLCSessionManager.h
+++ b/src/windows/service/exe/WSLCSessionManager.h
@@ -60,6 +60,14 @@ struct SessionEntry
DWORD CreatorPid = 0;
std::wstring DisplayName;
CallingProcessTokenInfo Owner;
+
+ Microsoft::WRL::ComPtr PluginNotifier;
+
+ // Whether OnSessionStopping has been fired already; ensures it is fired exactly once.
+ bool StoppingNotified = false;
+
+ wil::shared_handle UserToken;
+ std::vector UserSid;
};
class WSLCSessionManagerImpl
@@ -68,15 +76,20 @@ class WSLCSessionManagerImpl
NON_COPYABLE(WSLCSessionManagerImpl);
NON_MOVABLE(WSLCSessionManagerImpl);
- WSLCSessionManagerImpl() = default;
+ WSLCSessionManagerImpl();
~WSLCSessionManagerImpl();
void CreateSession(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession);
void EnterSession(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession);
- void ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount);
+ void ListSessions(_Out_ WSLCSessionListEntry** Sessions, _Out_ ULONG* SessionsCount);
void OpenSession(_In_ ULONG Id, _Out_ IWSLCSession** Session);
void OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session);
+ // Resolves a session by ID for plugin->API calls. Throws ERROR_NOT_FOUND if no session matches.
+ wil::com_ptr FindSession(ULONG Id);
+
+ static WSLCSessionManagerImpl* Instance() noexcept;
+
private:
// Resolves the default session name for a caller: appends the username
// from the token SID so different users don't collide.
@@ -109,7 +122,9 @@ class WSLCSessionManagerImpl
wil::com_ptr lockedSession;
if (FAILED_LOG(entry.Ref->OpenSession(&lockedSession)))
{
- // Session is gone, drop the persistent reference if any.
+ // Session is gone: notify plugins (if not already), then drop persistent reference if any.
+ NotifySessionStoppingLockHeld(entry);
+
auto remove =
std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return e.first == entry.SessionId; });
m_persistentSessions.erase(remove.begin(), remove.end());
@@ -151,6 +166,8 @@ class WSLCSessionManagerImpl
static CallingProcessTokenInfo GetCallingProcessTokenInfo();
static HRESULT CheckTokenAccess(const SessionEntry& Entry, const CallingProcessTokenInfo& TokenInfo);
+ void NotifySessionStoppingLockHeld(SessionEntry& entry) noexcept;
+
std::atomic m_nextSessionId{1};
std::recursive_mutex m_wslcSessionsLock;
@@ -183,7 +200,7 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLCSessionManager
IFACEMETHOD(IsClientVersionSupported)(_In_ const WSLCVersion* ClientVersion, _Out_ BOOL* IsSupported) override;
IFACEMETHOD(CreateSession)(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession) override;
IFACEMETHOD(EnterSession)(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession) override;
- IFACEMETHOD(ListSessions)(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount) override;
+ IFACEMETHOD(ListSessions)(_Out_ WSLCSessionListEntry** Sessions, _Out_ ULONG* SessionsCount) override;
IFACEMETHOD(OpenSession)(_In_ ULONG Id, _Out_ IWSLCSession** Session) override;
IFACEMETHOD(OpenSessionByName)(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session) override;
};
diff --git a/src/windows/service/exe/WSLCSessionManagerFactory.cpp b/src/windows/service/exe/WSLCSessionManagerFactory.cpp
index 6929bffba..09f2ffb20 100644
--- a/src/windows/service/exe/WSLCSessionManagerFactory.cpp
+++ b/src/windows/service/exe/WSLCSessionManagerFactory.cpp
@@ -82,4 +82,4 @@ void wsl::windows::service::wslc::ClearWslcSessionsAndBlockNewInstances()
}
g_sessionManagerImpl.reset();
-}
\ No newline at end of file
+}
diff --git a/src/windows/service/inc/wslc.idl b/src/windows/service/inc/wslc.idl
index 2ad3216f8..f2ebc1a81 100644
--- a/src/windows/service/inc/wslc.idl
+++ b/src/windows/service/inc/wslc.idl
@@ -117,6 +117,27 @@ interface IProgressCallback : IUnknown
HRESULT OnProgress(LPCSTR Status, LPCSTR Id, ULONGLONG Current, ULONGLONG Total);
};
+[
+ uuid(F3E6D5B2-1D40-4E8B-9C39-7A45D1C0F8A2),
+ pointer_default(unique),
+ object
+]
+interface IWSLCPluginNotifier : IUnknown
+{
+ // 'InspectJson' follows the wslc_schema::InspectContainer format.
+ // Returning failure prevents the container creation.
+ HRESULT OnContainerStarted([in] LPCSTR InspectJson);
+
+ // Called when a container is about to stop. 'ContainerId' is the container identifier. Errors are logged but ignored.
+ HRESULT OnContainerStopping([in] LPCSTR ContainerId);
+
+ // 'InspectJson' follows the wslc_schema::InspectImage format. Errors are logged but ignored.
+ HRESULT OnImageCreated([in] LPCSTR InspectJson);
+
+ // Called when an image is deleted. 'ImageId' is the image identifier. Errors are logged but ignored.
+ HRESULT OnImageDeleted([in] LPCSTR ImageId);
+};
+
typedef struct _WSLCImageInformation
{
char Image[WSLC_MAX_IMAGE_NAME_LENGTH + 1];
@@ -773,7 +794,8 @@ interface IWSLCSession : IUnknown
// Initializes the session with a pre-created VM.
HRESULT Initialize(
[in] const WSLCSessionInitSettings* Settings,
- [in] IWSLCVirtualMachine* Vm);
+ [in] IWSLCVirtualMachine* Vm,
+ [in] IWSLCPluginNotifier* PluginNotifier);
// Volume management.
HRESULT CreateVolume([in] const WSLCVolumeOptions* Options, [out] WSLCVolumeInformation* VolumeInfo);
@@ -827,6 +849,7 @@ interface IWSLCSessionFactory : IUnknown
HRESULT CreateSession(
[in] const WSLCSessionInitSettings* Settings,
[in] IWSLCVirtualMachine* Vm,
+ [in] IWSLCPluginNotifier* PluginNotifier,
[out] IWSLCSession** Session,
[out] IWSLCSessionReference** ServiceRef);
@@ -834,13 +857,13 @@ interface IWSLCSessionFactory : IUnknown
HRESULT GetProcessHandle([out, system_handle(sh_process)] HANDLE* ProcessHandle);
}
-typedef struct _WSLCSessionInformation
+typedef struct _WSLCSessionListEntry
{
ULONG SessionId;
DWORD CreatorPid;
wchar_t DisplayName[256];
wchar_t Sid[256 + 1]; // MAX_SID_SIZE = 256
-} WSLCSessionInformation;
+} WSLCSessionListEntry;
typedef enum _WSLCSessionFlags
{
@@ -865,7 +888,7 @@ interface IWSLCSessionManager : IUnknown
// Session management.
HRESULT CreateSession([in, unique] const WSLCSessionSettings* Settings, WSLCSessionFlags Flags, [out] IWSLCSession** Session);
HRESULT EnterSession([in, ref] LPCWSTR DisplayName, [in, ref] LPCWSTR StoragePath, [out] IWSLCSession** Session);
- HRESULT ListSessions([out, size_is(, *SessionsCount)] WSLCSessionInformation** Sessions, [out] ULONG* SessionsCount);
+ HRESULT ListSessions([out, size_is(, *SessionsCount)] WSLCSessionListEntry** Sessions, [out] ULONG* SessionsCount);
HRESULT OpenSession([in] ULONG Id, [out] IWSLCSession** Session);
HRESULT OpenSessionByName([in, unique] LPCWSTR DisplayName, [out] IWSLCSession** Session);
}
diff --git a/src/windows/wslc/services/SessionService.cpp b/src/windows/wslc/services/SessionService.cpp
index a38b5c5b5..a390294bb 100644
--- a/src/windows/wslc/services/SessionService.cpp
+++ b/src/windows/wslc/services/SessionService.cpp
@@ -149,7 +149,7 @@ std::vector SessionService::List()
THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager)));
wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get());
- wil::unique_cotaskmem_array_ptr sessions;
+ wil::unique_cotaskmem_array_ptr sessions;
THROW_IF_FAILED(sessionManager->ListSessions(&sessions, sessions.size_address()));
for (size_t i = 0; i < sessions.size(); ++i)
{
diff --git a/src/windows/wslcsession/WSLCContainer.cpp b/src/windows/wslcsession/WSLCContainer.cpp
index aaa602b00..2d99a6f88 100644
--- a/src/windows/wslcsession/WSLCContainer.cpp
+++ b/src/windows/wslcsession/WSLCContainer.cpp
@@ -531,6 +531,7 @@ WSLCPortMapping ContainerPortMapping::Serialize() const
WSLCContainerImpl::WSLCContainerImpl(
WSLCSession& wslcSession,
WSLCVirtualMachine& virtualMachine,
+ IWSLCPluginNotifier* pluginNotifier,
std::string&& Id,
std::string&& Name,
std::string&& Image,
@@ -547,6 +548,7 @@ WSLCContainerImpl::WSLCContainerImpl(
WSLCProcessFlags InitProcessFlags,
WSLCContainerFlags ContainerFlags) :
m_wslcSession(wslcSession),
+ m_pluginNotifier(pluginNotifier),
m_virtualMachine(virtualMachine),
m_name(std::move(Name)),
m_image(std::move(Image)),
@@ -803,6 +805,30 @@ void WSLCContainerImpl::Start(WSLCContainerStartFlags Flags, LPCSTR DetachKeys)
}
CATCH_AND_THROW_DOCKER_USER_ERROR("Failed to start container '%hs'", m_id.c_str());
+ auto inspectJson = InspectLockHeld();
+ const auto pluginResult = m_pluginNotifier->OnContainerStarted(inspectJson.c_str());
+ if (FAILED(pluginResult))
+ {
+ // Forward the COM error message, if available.
+ auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo();
+
+ LOG_HR_MSG(pluginResult, "Plugin rejected start of container '%hs' (0x%x)", m_id.c_str(), pluginResult);
+ try
+ {
+ m_dockerClient.StopContainer(m_id.c_str(), {}, {});
+ }
+ CATCH_LOG();
+
+ if (comError.has_value() && comError->Message)
+ {
+ THROW_HR_WITH_USER_ERROR(pluginResult, comError->Message.get());
+ }
+ else
+ {
+ THROW_HR(pluginResult);
+ }
+ }
+
portCleanup.release();
volumeCleanup.release();
@@ -946,6 +972,16 @@ __requires_exclusive_lock_held(m_lock) unique_com_disconnect WSLCContainerImpl::
{
unique_com_disconnect comWrapper;
+ // Notify plugin manager that the container is stopping. Errors are ignored.
+ if (m_state == WslcContainerStateRunning)
+ {
+ try
+ {
+ LOG_IF_FAILED(m_pluginNotifier->OnContainerStopping(m_id.c_str()));
+ }
+ CATCH_LOG();
+ }
+
ReleaseProcesses();
ReleaseRuntimeResources();
@@ -1298,6 +1334,7 @@ std::unique_ptr WSLCContainerImpl::Create(
const std::string& containerName,
WSLCSession& wslcSession,
WSLCVirtualMachine& virtualMachine,
+ IWSLCPluginNotifier* pluginNotifier,
const std::unordered_map& sessionNetworks,
std::function&& OnDeleted,
DockerEventTracker& EventTracker,
@@ -1651,6 +1688,7 @@ std::unique_ptr WSLCContainerImpl::Create(
auto container = std::make_unique(
wslcSession,
virtualMachine,
+ pluginNotifier,
std::move(result.Id),
CleanContainerName(inspectData.Name),
std::string(containerOptions.Image),
@@ -1675,6 +1713,7 @@ std::unique_ptr WSLCContainerImpl::Open(
const common::docker_schema::ContainerInfo& dockerContainer,
WSLCSession& wslcSession,
WSLCVirtualMachine& virtualMachine,
+ IWSLCPluginNotifier* pluginNotifier,
WSLCVolumes& volumes,
std::function&& OnDeleted,
DockerEventTracker& EventTracker,
@@ -1739,6 +1778,7 @@ std::unique_ptr WSLCContainerImpl::Open(
auto container = std::make_unique(
wslcSession,
virtualMachine,
+ pluginNotifier,
std::string(dockerContainer.Id),
std::move(name),
std::string(dockerContainer.Image),
@@ -1783,19 +1823,23 @@ void WSLCContainerImpl::Inspect(LPSTR* Output) const
try
{
- // Get Docker inspect data
- auto dockerInspect = m_dockerClient.InspectContainer(m_id);
-
- // Convert to WSLC schema
- auto wslcInspect = BuildInspectContainer(dockerInspect);
-
- // Serialize WSLC schema to JSON
- std::string wslcJson = wsl::shared::ToJson(wslcInspect);
- *Output = wil::make_unique_ansistring(wslcJson.c_str()).release();
+ *Output = wil::make_unique_ansistring(InspectLockHeld().c_str()).release();
}
CATCH_AND_THROW_DOCKER_USER_ERROR("Failed to inspect container '%hs'", m_id.c_str());
}
+std::string WSLCContainerImpl::InspectLockHeld() const
+{
+ // Get Docker inspect data
+ auto dockerInspect = m_dockerClient.InspectContainer(m_id);
+
+ // Convert to WSLC schema
+ auto wslcInspect = BuildInspectContainer(dockerInspect);
+
+ // Serialize WSLC schema to JSON
+ return wsl::shared::ToJson(wslcInspect);
+}
+
void WSLCContainerImpl::Logs(WSLCLogsFlags Flags, WSLCHandle* Stdout, WSLCHandle* Stderr, ULONGLONG Since, ULONGLONG Until, ULONGLONG Tail) const
{
auto lock = m_lock.lock_shared();
diff --git a/src/windows/wslcsession/WSLCContainer.h b/src/windows/wslcsession/WSLCContainer.h
index 4ebfcd758..f1e433bdd 100644
--- a/src/windows/wslcsession/WSLCContainer.h
+++ b/src/windows/wslcsession/WSLCContainer.h
@@ -72,6 +72,7 @@ class WSLCContainerImpl
WSLCContainerImpl(
WSLCSession& wslcSession,
WSLCVirtualMachine& virtualMachine,
+ IWSLCPluginNotifier* pluginNotifier,
std::string&& Id,
std::string&& Name,
std::string&& Image,
@@ -130,6 +131,7 @@ class WSLCContainerImpl
const std::string& Name,
WSLCSession& wslcSession,
WSLCVirtualMachine& virtualMachine,
+ IWSLCPluginNotifier* pluginNotifier,
const std::unordered_map& SessionNetworks,
std::function&& OnDeleted,
DockerEventTracker& EventTracker,
@@ -140,6 +142,7 @@ class WSLCContainerImpl
const common::docker_schema::ContainerInfo& DockerContainer,
WSLCSession& wslcSession,
WSLCVirtualMachine& virtualMachine,
+ IWSLCPluginNotifier* pluginNotifier,
WSLCVolumes& Volumes,
std::function&& OnDeleted,
DockerEventTracker& EventTracker,
@@ -169,6 +172,8 @@ class WSLCContainerImpl
void MapPorts();
void UnmapPorts();
+ __requires_shared_lock_held(m_lock) std::string InspectLockHeld() const;
+
mutable wil::srwlock m_lock;
std::string m_name;
std::string m_image;
@@ -197,6 +202,7 @@ class WSLCContainerImpl
std::uint64_t m_createdAt{};
WSLCContainerState m_state = WslcContainerStateInvalid;
WSLCSession& m_wslcSession;
+ IWSLCPluginNotifier* m_pluginNotifier;
WSLCVirtualMachine& m_virtualMachine;
std::vector m_mappedPorts;
std::vector m_mountedVolumes;
diff --git a/src/windows/wslcsession/WSLCSession.cpp b/src/windows/wslcsession/WSLCSession.cpp
index bc40f2925..97b27da67 100644
--- a/src/windows/wslcsession/WSLCSession.cpp
+++ b/src/windows/wslcsession/WSLCSession.cpp
@@ -257,7 +257,7 @@ try
}
CATCH_RETURN();
-HRESULT WSLCSession::Initialize(_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm)
+HRESULT WSLCSession::Initialize(_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm, _In_ IWSLCPluginNotifier* PluginNotifier)
try
{
RETURN_HR_IF(E_POINTER, Settings == nullptr || Vm == nullptr);
@@ -267,6 +267,7 @@ try
m_id = Settings->SessionId;
m_displayName = Settings->DisplayName ? Settings->DisplayName : L"";
m_featureFlags = Settings->FeatureFlags;
+ m_pluginNotifier = PluginNotifier;
// Get user token for the current process
const auto tokenInfo = wil::get_token_information(GetCurrentProcessToken());
@@ -645,6 +646,20 @@ void WSLCSession::StreamImageOperation(DockerHTTPClient::HTTPRequestContext& req
}
}
+void WSLCSession::OnImageCreated(const std::string& ImageNameOrId) noexcept
+try
+{
+ LOG_IF_FAILED(m_pluginNotifier->OnImageCreated(InspectImageLockHeld(ImageNameOrId).c_str()));
+}
+CATCH_LOG()
+
+void WSLCSession::OnImageDeleted(const std::string& ImageId) noexcept
+try
+{
+ LOG_IF_FAILED(m_pluginNotifier->OnImageDeleted(ImageId.c_str()));
+}
+CATCH_LOG()
+
HRESULT WSLCSession::PullImage(LPCSTR Image, LPCSTR RegistryAuthenticationInformation, IProgressCallback* ProgressCallback)
try
{
@@ -673,6 +688,8 @@ try
auto requestContext = m_dockerClient->PullImage(repo, tagOrDigest, registryAuth);
StreamImageOperation(*requestContext, Image, "Pull", ProgressCallback);
+ OnImageCreated(Image);
+
return S_OK;
}
CATCH_RETURN();
@@ -1014,6 +1031,7 @@ try
auto requestContext = m_dockerClient->LoadImage(ContentSize);
ImportImageImpl(*requestContext, ImageHandle);
+
return S_OK;
}
CATCH_RETURN();
@@ -1039,6 +1057,8 @@ try
auto requestContext = m_dockerClient->ImportImage(repo, tagOrDigest.value(), ContentSize);
ImportImageImpl(*requestContext, ImageHandle);
+
+ OnImageCreated(ImageName);
return S_OK;
}
CATCH_RETURN();
@@ -1095,7 +1115,6 @@ void WSLCSession::ImportImageImpl(DockerHTTPClient::HTTPRequestContext& Request,
}
else if (parsed.stream.has_value())
{
- // TODO: report progress to caller.
WSL_LOG("ImageImportProgress", TraceLoggingValue(parsed.stream->c_str(), "Content"));
}
else
@@ -1420,6 +1439,15 @@ try
*Count = static_cast(deletedImages.size());
*DeletedImages = output.release();
+ // Notify plugin manager of all deleted image IDs.
+ for (const auto& image : deletedImages)
+ {
+ if (!image.Deleted.empty())
+ {
+ OnImageDeleted(image.Deleted);
+ }
+ }
+
return S_OK;
}
CATCH_RETURN();
@@ -1497,10 +1525,18 @@ try
auto lock = m_lock.lock_shared();
RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value());
+ *Output = wil::make_unique_ansistring(InspectImageLockHeld(ImageNameOrId).c_str()).release();
+
+ return S_OK;
+}
+CATCH_RETURN();
+
+std::string WSLCSession::InspectImageLockHeld(const std::string& NameOrId)
+{
docker_schema::InspectImage dockerInspect;
try
{
- dockerInspect = m_dockerClient->InspectImage(ImageNameOrId);
+ dockerInspect = m_dockerClient->InspectImage(NameOrId);
}
catch (const DockerHTTPException& e)
{
@@ -1519,12 +1555,8 @@ try
auto wslcInspect = ConvertInspectImage(dockerInspect);
// Serialize to JSON
- std::string wslcJson = wsl::shared::ToJson(wslcInspect);
- *Output = wil::make_unique_ansistring(wslcJson.c_str()).release();
-
- return S_OK;
+ return wsl::shared::ToJson(wslcInspect);
}
-CATCH_RETURN();
HRESULT WSLCSession::Authenticate(_In_ LPCSTR ServerAddress, _In_ LPCSTR Username, _In_ LPCSTR Password, _Out_ LPSTR* IdentityToken)
try
@@ -1724,6 +1756,7 @@ try
containerName,
*this,
m_virtualMachine.value(),
+ m_pluginNotifier.get(),
m_networks,
std::bind(&WSLCSession::OnContainerDeleted, this, std::placeholders::_1),
m_eventTracker.value(),
@@ -2766,6 +2799,7 @@ void WSLCSession::RecoverExistingContainers()
dockerContainer,
*this,
m_virtualMachine.value(),
+ m_pluginNotifier.get(),
*m_volumes,
std::bind(&WSLCSession::OnContainerDeleted, this, std::placeholders::_1),
m_eventTracker.value(),
diff --git a/src/windows/wslcsession/WSLCSession.h b/src/windows/wslcsession/WSLCSession.h
index 572e721c8..c5db04808 100644
--- a/src/windows/wslcsession/WSLCSession.h
+++ b/src/windows/wslcsession/WSLCSession.h
@@ -85,7 +85,7 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLCSession
// IWSLCSession - initialization methods
IFACEMETHOD(GetProcessHandle)(_Out_ HANDLE* ProcessHandle) override;
- IFACEMETHOD(Initialize)(_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm) override;
+ IFACEMETHOD(Initialize)(_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm, _In_ IWSLCPluginNotifier* PluginNotifier) override;
IFACEMETHOD(GetId)(_Out_ ULONG* Id) override;
IFACEMETHOD(GetState)(_Out_ WSLCSessionState* State) override;
@@ -177,7 +177,16 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLCSession
__requires_lock_held(m_userCOMCallbacksLock) void CancelUserCOMCallbacks();
void ConfigureStorage(const WSLCSessionInitSettings& Settings, PSID UserSid);
void Ext4Format(const std::string& Device);
+ _Requires_shared_lock_held_(m_lock)
+ std::string InspectImageLockHeld(const std::string& Id);
void OnContainerDeleted(const WSLCContainerImpl* Container);
+
+ _Requires_shared_lock_held_(m_lock)
+ void OnImageCreated(const std::string& ImageNameOrId) noexcept;
+
+ _Requires_shared_lock_held_(m_lock)
+ void OnImageDeleted(const std::string& ImageId) noexcept;
+
void OnProcessLog(const gsl::span& Data, PCSTR Source);
void OnContainerdExited();
void OnDockerdExited();
@@ -222,6 +231,8 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLCSession
std::atomic m_terminating{false};
std::atomic m_terminated{false};
+ wil::com_ptr m_pluginNotifier;
+
// User-provided handles that the session is currently doing IO on.
std::mutex m_userHandlesLock;
__guarded_by(m_userHandlesLock) std::vector m_userHandles;
diff --git a/src/windows/wslcsession/WSLCSessionFactory.cpp b/src/windows/wslcsession/WSLCSessionFactory.cpp
index 9b2e287b7..e08f72090 100644
--- a/src/windows/wslcsession/WSLCSessionFactory.cpp
+++ b/src/windows/wslcsession/WSLCSessionFactory.cpp
@@ -30,7 +30,11 @@ void wslc::WSLCSessionFactory::SetDestructionCallback(std::function&& ca
}
HRESULT wslc::WSLCSessionFactory::CreateSession(
- _In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm, _Out_ IWSLCSession** Session, _Out_ IWSLCSessionReference** ServiceRef)
+ _In_ const WSLCSessionInitSettings* Settings,
+ _In_ IWSLCVirtualMachine* Vm,
+ _In_ IWSLCPluginNotifier* PluginNotifier,
+ _Out_ IWSLCSession** Session,
+ _Out_ IWSLCSessionReference** ServiceRef)
try
{
*Session = nullptr;
@@ -44,7 +48,7 @@ try
session->SetDestructionCallback(std::move(m_destructionCallback));
// Initialize the session with the VM.
- RETURN_IF_FAILED(session->Initialize(Settings, Vm));
+ RETURN_IF_FAILED(session->Initialize(Settings, Vm, PluginNotifier));
// Create the service session ref. It extracts metadata and a weak reference from the session.
auto serviceRef = Microsoft::WRL::Make(session.Get());
diff --git a/src/windows/wslcsession/WSLCSessionFactory.h b/src/windows/wslcsession/WSLCSessionFactory.h
index 17382c967..266517d28 100644
--- a/src/windows/wslcsession/WSLCSessionFactory.h
+++ b/src/windows/wslcsession/WSLCSessionFactory.h
@@ -44,8 +44,11 @@ class DECLSPEC_UUID("9FCD2067-9FC6-4EFA-9EB0-698169EBF7D3") WSLCSessionFactory
// IWSLCSessionFactory
IFACEMETHOD(CreateSession)
- (_In_ const WSLCSessionInitSettings* Settings, _In_ IWSLCVirtualMachine* Vm, _Out_ IWSLCSession** Session, _Out_ IWSLCSessionReference** ServiceRef)
- override;
+ (_In_ const WSLCSessionInitSettings* Settings,
+ _In_ IWSLCVirtualMachine* Vm,
+ _In_ IWSLCPluginNotifier* PluginNotifier,
+ _Out_ IWSLCSession** Session,
+ _Out_ IWSLCSessionReference** ServiceRef) override;
IFACEMETHOD(GetProcessHandle)(_Out_ HANDLE* ProcessHandle) override;
diff --git a/test/windows/Common.cpp b/test/windows/Common.cpp
index 5fb028f84..d18622dd6 100644
--- a/test/windows/Common.cpp
+++ b/test/windows/Common.cpp
@@ -2904,6 +2904,19 @@ std::filesystem::path GetTestImagePath(std::string_view imageName)
return result;
}
+void LoadTestImage(IWSLCSession& session, std::string_view imageName)
+{
+ std::filesystem::path imagePath = GetTestImagePath(imageName);
+ wil::unique_hfile imageFile{
+ CreateFileW(imagePath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr)};
+ THROW_LAST_ERROR_IF(!imageFile);
+
+ LARGE_INTEGER fileSize{};
+ THROW_LAST_ERROR_IF(!GetFileSizeEx(imageFile.get(), &fileSize));
+
+ THROW_IF_FAILED(session.LoadImage(wsl::windows::common::wslutil::ToCOMInputHandle(imageFile.get()), nullptr, fileSize.QuadPart));
+}
+
void ExpectHttpResponse(LPCWSTR Url, std::optional expectedCode, bool retry)
{
const winrt::Windows::Web::Http::Filters::HttpBaseProtocolFilter filter;
@@ -2991,3 +3004,52 @@ void WriteSocket(SOCKET Socket, const void* data, size_t size)
data = static_cast(data) + result;
}
}
+
+void ValidateCOMErrorMessage(const std::optional& Expected, const std::source_location& Source)
+{
+ auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo();
+
+ if (comError.has_value())
+ {
+ if (!Expected.has_value())
+ {
+ LogError("Unexpected COM error: '%ls'. Source: %hs", comError->Message.get(), std::format("{}", Source).c_str());
+ VERIFY_FAIL();
+ }
+
+ VERIFY_ARE_EQUAL(Expected.value(), comError->Message.get());
+ }
+ else
+ {
+ if (Expected.has_value())
+ {
+ LogError("Expected COM error: '%ls' but none was set. Source: %hs", Expected->c_str(), std::format("{}", Source).c_str());
+ VERIFY_FAIL();
+ }
+ }
+}
+
+void ValidateCOMErrorMessageContains(const std::wstring& ExpectedSubstring)
+{
+ auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo();
+
+ if (comError.has_value())
+ {
+ if (!comError->Message)
+ {
+ LogError("Expected COM error containing: '%ls', but COM error message was null", ExpectedSubstring.c_str());
+ VERIFY_FAIL();
+ }
+
+ if (wcsstr(comError->Message.get(), ExpectedSubstring.c_str()) == nullptr)
+ {
+ LogError("Expected COM error containing: '%ls', but got: '%ls'", ExpectedSubstring.c_str(), comError->Message.get());
+ VERIFY_FAIL();
+ }
+ }
+ else
+ {
+ LogError("Expected COM error containing: '%ls' but none was set", ExpectedSubstring.c_str());
+ VERIFY_FAIL();
+ }
+}
diff --git a/test/windows/Common.h b/test/windows/Common.h
index 43cf3a3b7..ad2706ebf 100644
--- a/test/windows/Common.h
+++ b/test/windows/Common.h
@@ -617,6 +617,8 @@ void VerifyPatternMatch(const std::string& Content, const std::string& Pattern);
std::filesystem::path GetTestImagePath(std::string_view imageName);
+void LoadTestImage(IWSLCSession& session, std::string_view imageName);
+
void ExpectHttpResponse(LPCWSTR Url, std::optional expectedCode, bool retry = false);
template
@@ -678,3 +680,7 @@ void VerifyAreEqualUnordered(const std::vector& expected, const std::vector& Expected, const std::source_location& Source = std::source_location::current());
+
+void ValidateCOMErrorMessageContains(const std::wstring& ExpectedSubstring);
diff --git a/test/windows/PluginTests.cpp b/test/windows/PluginTests.cpp
index 8b42f941e..3f2be8665 100644
--- a/test/windows/PluginTests.cpp
+++ b/test/windows/PluginTests.cpp
@@ -16,10 +16,16 @@ Module Name:
#include "Common.h"
#include "registry.hpp"
#include "PluginTests.h"
+#include "wslc.h"
+#include "WSLCContainerLauncher.h"
+#include "WSLCProcessLauncher.h"
+#include "wslc/e2e/WSLCE2EHelpers.h"
using namespace wsl::windows::common::registry;
+using WSLCE2ETests::StartLocalRegistry;
extern std::wstring g_testDistroPath;
+extern std::wstring g_testDataPath;
class PluginTests
{
@@ -592,6 +598,186 @@ class PluginTests
StartWsl(0);
ValidateLogFile(ExpectedOutput);
}
+ static wil::com_ptr OpenWslcSessionManager()
+ {
+ wil::com_ptr sessionManager;
+ VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager)));
+ wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get());
+ return sessionManager;
+ }
+
+ static wil::com_ptr CreateWslcSession(LPCWSTR Name, WSLCNetworkingMode NetworkingMode = WSLCNetworkingModeNone)
+ {
+ WSLCSessionSettings settings{};
+ settings.DisplayName = Name;
+ settings.CpuCount = 4;
+ settings.MemoryMb = 4096;
+ settings.BootTimeoutMs = 30 * 1000;
+ settings.NetworkingMode = NetworkingMode;
+
+ auto manager = OpenWslcSessionManager();
+ wil::com_ptr session;
+ VERIFY_SUCCEEDED(manager->CreateSession(&settings, WSLCSessionFlagsNone, &session));
+ wsl::windows::common::security::ConfigureForCOMImpersonation(session.get());
+
+ WSLCSessionState state{};
+ VERIFY_SUCCEEDED(session->GetState(&state));
+ VERIFY_ARE_EQUAL(state, WSLCSessionStateRunning);
+
+ return session;
+ }
+
+ WSL2_TEST_METHOD(WslcSuccess)
+ {
+ ConfigurePlugin(PluginTestType::WslcSuccess);
+
+ {
+ auto session = CreateWslcSession(L"plugin-wslc-test");
+
+ LoadTestImage(*session, "debian:latest");
+
+ // Create a container that will have a stuck process so it's still in a running state when the callback is made.
+ wsl::windows::common::WSLCContainerLauncher launcher(
+ "debian:latest", "wslc-plugin-container", {"/bin/sh", "-c", "sleep 120"});
+
+ auto container = launcher.Launch(*session, WSLCContainerStartFlagsAttach);
+ VERIFY_SUCCEEDED(container.Get().Stop(WSLCSignalSIGKILL, 0));
+
+ // Delete the image so we get an ImageDeleted notification before the session goes away.
+ WSLCDeleteImageOptions options{.Image = "debian:latest", .Flags = WSLCDeleteImageFlagsForce};
+ wil::unique_cotaskmem_array_ptr deletedImages;
+ VERIFY_SUCCEEDED(session->DeleteImage(&options, deletedImages.addressof(), deletedImages.size_address()));
+ }
+
+ const auto ExpectedOutput = std::format(
+ LR"(Plugin loaded. TestMode=18
+ WSLC Session created, name=plugin-wslc-test, id=*, pid=*, token=set, sid=set
+ Command: 'echo -n stdout-ok && echo -n stderr-ok >&2', status=0, stdout: stdout-ok, stderr: stderr-ok
+ Command: 'cat', status=0, stdout: stdin-ok, stderr:
+ Command: 'exit 12', status=12, stdout: , stderr:
+ Command: 'echo -n $ENV', status=0, stdout: env-ok, stderr:
+ WSLCCreateProcess(does-not-exist): {:x}, errno=2
+ WSLCProcessGetFd(999): {}
+ WSLCProcessGetExitCode(): {}
+ WSLC RW folder mounted at: /mnt/wsl-plugin/plugin-rw-test
+ Command: 'cat /mnt/wsl-plugin/plugin-rw-test/plugin-test.txt', status=0, stdout: Windows-content, stderr:
+ WSLC RO folder mounted at: /mnt/wsl-plugin/plugin-ro-test
+ Command: 'echo fail > /mnt/wsl-plugin/plugin-ro-test/should-not-exist.txt', status=1, stdout: , stderr: *
+ WSLCMountFolder(nonexistent): {}
+ WSLCMountFolder(../escape): {}
+ WSLCMountFolder(): {}
+ Test completed
+ WSLC Container started, session=*, id=*, name=wslc-plugin-container, image=debian:latest, state=*
+ WSLC Container stopping, session=*, id=*
+ WSLC Image deleted, session=*, id=*
+ WSLC Session stopping, name=plugin-wslc-test, id=*)",
+ static_cast(E_FAIL),
+ E_INVALIDARG,
+ HRESULT_FROM_WIN32(ERROR_INVALID_STATE),
+ HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND),
+ E_INVALIDARG,
+ E_INVALIDARG);
+
+ ValidateLogFile(ExpectedOutput.c_str());
+ }
+
+ WSL2_TEST_METHOD(WslcPullImageNotification)
+ {
+ ConfigurePlugin(PluginTestType::WslcImagePull);
+
+ {
+ auto session = CreateWslcSession(L"plugin-wslc-pull-test", WSLCNetworkingModeVirtioProxy);
+
+ // Load the registry and debian images.
+ LoadTestImage(*session, "debian:latest");
+
+ // Start a local registry container.
+ auto [registryContainer, registryAddress] = StartLocalRegistry(*session);
+
+ // Tag debian:latest for the local registry and push it.
+ auto registryImage = std::format("{}/debian:latest", registryAddress);
+ auto registryRepo = std::format("{}/debian", registryAddress);
+ WSLCTagImageOptions tagOptions{};
+ tagOptions.Image = "debian:latest";
+ tagOptions.Repo = registryRepo.c_str();
+ tagOptions.Tag = "latest";
+ VERIFY_SUCCEEDED(session->TagImage(&tagOptions));
+
+ auto emptyAuth = wsl::windows::common::wslutil::BuildRegistryAuthHeader("", "");
+ VERIFY_SUCCEEDED(session->PushImage(registryImage.c_str(), emptyAuth.c_str(), nullptr));
+
+ // Delete the local tagged copy so PullImage actually downloads it.
+ WSLCDeleteImageOptions deleteOpts{.Image = registryImage.c_str(), .Flags = WSLCDeleteImageFlagsNone};
+ wil::unique_cotaskmem_array_ptr deletedImages;
+ VERIFY_SUCCEEDED(session->DeleteImage(&deleteOpts, deletedImages.addressof(), deletedImages.size_address()));
+
+ // Pull the image back — this should trigger the ImageCreated plugin callback.
+ VERIFY_SUCCEEDED(session->PullImage(registryImage.c_str(), nullptr, nullptr));
+ }
+
+ constexpr auto ExpectedOutput =
+ LR"(Plugin loaded. TestMode=21
+ WSLC Session created, name=plugin-wslc-pull-test, id=*, pid=*, token=set, sid=set
+ WSLC Container started, session=*, id=*, name=*, image=wslc-registry:latest, state=running
+ WSLC Image created, session=*, id=sha256:*, name=127.0.0.1:5000/debian:latest
+ WSLC Session stopping, name=plugin-wslc-pull-test, id=*)";
+
+ ValidateLogFile(ExpectedOutput);
+ }
+
+ WSL2_TEST_METHOD(WslcSessionRejected)
+ {
+ ConfigurePlugin(PluginTestType::WslcSessionRejected);
+
+ WSLCSessionSettings settings{};
+ settings.DisplayName = L"plugin-wslc-rejected";
+ settings.CpuCount = 4;
+ settings.MemoryMb = 2048;
+ settings.BootTimeoutMs = 30 * 1000;
+ settings.MaximumStorageSizeMb = 1024 * 20;
+ settings.NetworkingMode = WSLCNetworkingModeNone;
+
+ auto manager = OpenWslcSessionManager();
+ wil::com_ptr session;
+ const auto hr = manager->CreateSession(&settings, WSLCSessionFlagsNone, &session);
+ ValidateCOMErrorMessageContains(L"A fatal error was returned by plugin 'TestPlugin'");
+ VERIFY_ARE_EQUAL(hr, HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED));
+
+ constexpr auto ExpectedOutput =
+ LR"(Plugin loaded. TestMode=19
+ WSLC Session created, name=plugin-wslc-rejected, id=*, pid=*, token=set, sid=set
+ OnWslcSessionCreated: ERROR_ACCESS_DENIED)";
+
+ ValidateLogFile(ExpectedOutput);
+ }
+
+ WSL2_TEST_METHOD(WslcContainerRejected)
+ {
+ ConfigurePlugin(PluginTestType::WslcContainerRejected);
+
+ {
+ auto session = CreateWslcSession(L"plugin-wslc-container-rejected");
+
+ LoadTestImage(*session, "debian:latest");
+
+ wsl::windows::common::WSLCContainerLauncher launcher(
+ "debian:latest", "wslc-plugin-rejected-container", {"/bin/sh", "-c", "echo nope"});
+
+ auto [hr, container] = launcher.LaunchNoThrow(*session, WSLCContainerStartFlagsAttach);
+ ValidateCOMErrorMessageContains(L"A fatal error was returned by plugin 'TestPlugin'");
+ VERIFY_ARE_EQUAL(hr, HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED));
+ }
+
+ constexpr auto ExpectedOutput =
+ LR"(Plugin loaded. TestMode=20
+ WSLC Session created, name=plugin-wslc-container-rejected, id=*, pid=*, token=set, sid=set
+ WSLC Container started, session=*, id=*, name=*, image=debian:latest, state=*
+ OnWslcContainerStarted: ERROR_ACCESS_DENIED
+ WSLC Session stopping, name=plugin-wslc-container-rejected, id=*)";
+
+ ValidateLogFile(ExpectedOutput);
+ }
+
// This test must run last so it doesn't break test cases that depends on plugin signature.
WSL2_TEST_METHOD(InvalidPluginSignature)
{
diff --git a/test/windows/PluginTests.h b/test/windows/PluginTests.h
index d29d84e39..c9a779031 100644
--- a/test/windows/PluginTests.h
+++ b/test/windows/PluginTests.h
@@ -37,7 +37,11 @@ enum class PluginTestType
InitPidIsDifferent,
FailToRegisterUnregisterDistro,
RunDistroCommand,
- GetUsername
+ GetUsername,
+ WslcSuccess,
+ WslcSessionRejected,
+ WslcContainerRejected,
+ WslcImagePull
};
constexpr auto c_testType = L"TestType";
@@ -46,4 +50,4 @@ constexpr auto c_logFile = L"LogFile";
inline wil::unique_hkey OpenTestRegistryKey(REGSAM AccessMask)
{
return wsl::windows::common::registry::CreateKey(HKEY_LOCAL_MACHINE, c_configKey, AccessMask, nullptr, REG_OPTION_VOLATILE);
-}
\ No newline at end of file
+}
diff --git a/test/windows/WSLCTests.cpp b/test/windows/WSLCTests.cpp
index cf7511e57..029a20a80 100644
--- a/test/windows/WSLCTests.cpp
+++ b/test/windows/WSLCTests.cpp
@@ -20,6 +20,7 @@ Module Name:
#include "WslCoreFilesystem.h"
#include "hcs.hpp"
#include "ContainerNameGenerator.h"
+#include "wslc/e2e/WSLCE2EHelpers.h"
#include
using namespace std::literals::chrono_literals;
@@ -31,6 +32,7 @@ using wsl::windows::common::WSLCProcessLauncher;
using wsl::windows::common::io::OverlappedIOHandle;
using wsl::windows::common::io::WriteHandle;
using namespace wsl::windows::common::wslutil;
+using WSLCE2ETests::StartLocalRegistry;
extern std::wstring g_testDataPath;
extern bool g_fastTestRun;
@@ -45,20 +47,6 @@ class WSLCTests
wil::com_ptr m_defaultSession;
static inline auto c_testSessionName = L"wslc-test";
- void LoadTestImage(std::string_view imageName, IWSLCSession* session = nullptr)
- {
- std::filesystem::path imagePath = GetTestImagePath(imageName);
- wil::unique_hfile imageFile{
- CreateFileW(imagePath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr)};
- THROW_LAST_ERROR_IF(!imageFile);
-
- LARGE_INTEGER fileSize{};
- THROW_LAST_ERROR_IF(!GetFileSizeEx(imageFile.get(), &fileSize));
-
- THROW_IF_FAILED(
- (session ? session : m_defaultSession.get())->LoadImage(ToCOMInputHandle(imageFile.get()), nullptr, fileSize.QuadPart));
- }
-
TEST_CLASS_SETUP(TestClassSetup)
{
THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &m_wsadata));
@@ -78,27 +66,27 @@ class WSLCTests
if (!hasImage("debian:latest"))
{
- LoadTestImage("debian:latest");
+ LoadTestImage(*m_defaultSession, "debian:latest");
}
if (!hasImage("python:3.12-alpine"))
{
- LoadTestImage("python:3.12-alpine");
+ LoadTestImage(*m_defaultSession, "python:3.12-alpine");
}
if (!hasImage("hello-world:latest"))
{
- LoadTestImage("hello-world:latest");
+ LoadTestImage(*m_defaultSession, "hello-world:latest");
}
if (!hasImage("alpine:latest"))
{
- LoadTestImage("alpine:latest");
+ LoadTestImage(*m_defaultSession, "alpine:latest");
}
if (!hasImage("wslc-registry:latest"))
{
- LoadTestImage("wslc-registry:latest");
+ LoadTestImage(*m_defaultSession, "wslc-registry:latest");
}
PruneResult result;
@@ -207,28 +195,6 @@ class WSLCTests
return result;
}
- std::pair StartLocalRegistry(const std::string& username = {}, const std::string& password = {}, USHORT port = 5000)
- {
- std::vector env = {std::format("REGISTRY_HTTP_ADDR=0.0.0.0:{}", port)};
- if (!username.empty())
- {
- env.push_back(std::format("USERNAME={}", username));
- env.push_back(std::format("PASSWORD={}", password));
- }
-
- WSLCContainerLauncher launcher("wslc-registry:latest", {}, {}, env);
- launcher.SetEntrypoint({"/entrypoint.sh"});
- launcher.AddPort(port, port, AF_INET);
-
- auto container = launcher.Launch(*m_defaultSession, WSLCContainerStartFlagsNone);
-
- auto registryAddress = std::format("127.0.0.1:{}", port);
- auto registryUrl = std::format(L"http://{}", registryAddress);
- ExpectHttpResponse(registryUrl.c_str(), 200, true);
-
- return {std::move(container), std::move(registryAddress)};
- }
-
std::string PushImageToRegistry(const std::string& imageName, const std::string& registryAddress, const std::string& registryAuth)
{
auto [repo, tag] = ParseImage(imageName);
@@ -394,7 +360,7 @@ class WSLCTests
// Act: list sessions
{
- wil::unique_cotaskmem_array_ptr sessions;
+ wil::unique_cotaskmem_array_ptr sessions;
VERIFY_SUCCEEDED(sessionManager->ListSessions(&sessions, sessions.size_address()));
// Assert
@@ -409,7 +375,7 @@ class WSLCTests
{
auto session2 = CreateSession(GetDefaultSessionSettings(L"wslc-test-list-2"));
- wil::unique_cotaskmem_array_ptr sessions;
+ wil::unique_cotaskmem_array_ptr sessions;
VERIFY_SUCCEEDED(sessionManager->ListSessions(&sessions, sessions.size_address()));
VERIFY_ARE_EQUAL(sessions.size(), 2);
@@ -455,7 +421,7 @@ class WSLCTests
// Reject DisplayName at exact boundary (no room for null terminator).
{
- std::wstring boundaryName(std::size(WSLCSessionInformation{}.DisplayName), L'x');
+ std::wstring boundaryName(std::size(WSLCSessionListEntry{}.DisplayName), L'x');
auto settings = GetDefaultSessionSettings(boundaryName.c_str());
wil::com_ptr session;
VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_INVALID_SESSION_NAME);
@@ -463,7 +429,7 @@ class WSLCTests
// Reject too long DisplayName.
{
- std::wstring longName(std::size(WSLCSessionInformation{}.DisplayName) + 1, L'x');
+ std::wstring longName(std::size(WSLCSessionListEntry{}.DisplayName) + 1, L'x');
auto settings = GetDefaultSessionSettings(longName.c_str());
wil::com_ptr session;
VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_INVALID_SESSION_NAME);
@@ -587,7 +553,7 @@ class WSLCTests
{
{
// Start a local registry without auth and push hello-world:latest to it.
- auto [registryContainer, registryAddress] = StartLocalRegistry();
+ auto [registryContainer, registryAddress] = StartLocalRegistry(*m_defaultSession);
auto image = PushImageToRegistry("hello-world:latest", registryAddress, BuildRegistryAuthHeader("", ""));
ExpectImagePresent(*m_defaultSession, image.c_str(), false);
@@ -630,7 +596,7 @@ class WSLCTests
WSLC_TEST_METHOD(PullImageAdvanced)
{
// Start a local registry without auth to avoid Docker Hub rate limits.
- auto [registryContainer, registryAddress] = StartLocalRegistry();
+ auto [registryContainer, registryAddress] = StartLocalRegistry(*m_defaultSession);
auto auth = BuildRegistryAuthHeader("", "");
auto validatePull = [&](const std::string& sourceImage) {
@@ -740,7 +706,7 @@ class WSLCTests
constexpr auto c_username = "wslctest";
constexpr auto c_password = "password";
- auto [registryContainer, registryAddress] = StartLocalRegistry(c_username, c_password);
+ auto [registryContainer, registryAddress] = StartLocalRegistry(*m_defaultSession, c_username, c_password);
wil::unique_cotaskmem_ansistring token;
VERIFY_ARE_EQUAL(m_defaultSession->Authenticate(registryAddress.c_str(), c_username, "wrong-password", &token), E_FAIL);
@@ -1017,7 +983,7 @@ class WSLCTests
LogInfo("Test: Dangling filter");
{
// Setup a dangling image
- LoadTestImage("alpine:latest");
+ LoadTestImage(*m_defaultSession, "alpine:latest");
WSLCTagImageOptions tagOptions{};
tagOptions.Image = "debian:latest";
tagOptions.Repo = "alpine";
@@ -1301,7 +1267,7 @@ class WSLCTests
WSLC_TEST_METHOD(DeleteImage)
{
// Prepare alpine image to delete.
- LoadTestImage("alpine:latest");
+ LoadTestImage(*m_defaultSession, "alpine:latest");
// Verify that the image is in the list of images.
ExpectImagePresent(*m_defaultSession, "alpine:latest");
@@ -1338,55 +1304,6 @@ class WSLCTests
}
}
- void ValidateCOMErrorMessage(const std::optional& Expected, const std::source_location& Source = std::source_location::current())
- {
- auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo();
-
- if (comError.has_value())
- {
- if (!Expected.has_value())
- {
- LogError("Unexpected COM error: '%ls'. Source: %hs", comError->Message.get(), std::format("{}", Source).c_str());
- VERIFY_FAIL();
- }
-
- VERIFY_ARE_EQUAL(Expected.value(), comError->Message.get());
- }
- else
- {
- if (Expected.has_value())
- {
- LogError("Expected COM error: '%ls' but none was set. Source: %hs", Expected->c_str(), std::format("{}", Source).c_str());
- VERIFY_FAIL();
- }
- }
- }
-
- void ValidateCOMErrorMessageContains(const std::wstring& ExpectedSubstring)
- {
- auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo();
-
- if (comError.has_value())
- {
- if (!comError->Message)
- {
- LogError("Expected COM error containing: '%ls', but COM error message was null", ExpectedSubstring.c_str());
- VERIFY_FAIL();
- }
-
- if (wcsstr(comError->Message.get(), ExpectedSubstring.c_str()) == nullptr)
- {
- LogError("Expected COM error containing: '%ls', but got: '%ls'", ExpectedSubstring.c_str(), comError->Message.get());
- VERIFY_FAIL();
- }
- }
- else
- {
- LogError("Expected COM error containing: '%ls' but none was set", ExpectedSubstring.c_str());
- VERIFY_FAIL();
- }
- }
-
class CapturingProgressCallback
: public Microsoft::WRL::RuntimeClass, IProgressCallback>
{
@@ -7701,7 +7618,7 @@ class WSLCTests
auto manager = OpenSessionManager();
auto expectSessions = [&](const std::vector& expectedSessions) {
- wil::unique_cotaskmem_array_ptr sessions;
+ wil::unique_cotaskmem_array_ptr sessions;
VERIFY_SUCCEEDED(manager->ListSessions(&sessions, sessions.size_address()));
std::set displayNames;
@@ -9104,12 +9021,12 @@ class WSLCTests
// Helper to create a dangling image using only test-local tags:
// Load alpine and hello-world under unique tags, then overwrite one with the other.
auto createDanglingImage = [this]() {
- LoadTestImage("alpine:latest");
+ LoadTestImage(*m_defaultSession, "alpine:latest");
WSLCTagImageOptions tagA{.Image = "alpine:latest", .Repo = "prune-test-a", .Tag = "v1"};
VERIFY_SUCCEEDED(m_defaultSession->TagImage(&tagA));
DeleteImage("alpine:latest", WSLCDeleteImageFlagsNone);
- LoadTestImage("hello-world:latest");
+ LoadTestImage(*m_defaultSession, "hello-world:latest");
WSLCTagImageOptions tagB{.Image = "hello-world:latest", .Repo = "prune-test-b", .Tag = "v1"};
VERIFY_SUCCEEDED(m_defaultSession->TagImage(&tagB));
DeleteImage("hello-world:latest", WSLCDeleteImageFlagsNone);
@@ -9182,7 +9099,7 @@ class WSLCTests
// Validate null Options uses defaults (dangling-only prune).
{
- LoadTestImage("alpine:latest");
+ LoadTestImage(*m_defaultSession, "alpine:latest");
WSLCTagImageOptions renameOptions{.Image = "alpine:latest", .Repo = "prune-test-a", .Tag = "v1"};
VERIFY_SUCCEEDED(m_defaultSession->TagImage(&renameOptions));
DeleteImage("alpine:latest", WSLCDeleteImageFlagsNone);
@@ -9319,7 +9236,7 @@ class WSLCTests
auto revert = wil::impersonate_token(nonElevatedToken.get());
nonElevatedSession = CreateSession(GetDefaultSessionSettings(L"non-elevated-session"), WSLCSessionFlagsNone);
- LoadTestImage("debian:latest", nonElevatedSession.get());
+ LoadTestImage(*nonElevatedSession, "debian:latest");
WSLCContainerLauncher launcher("debian:latest", "test-non-elevated-handles-1", {"echo", "OK"});
auto container = launcher.Launch(*nonElevatedSession);
diff --git a/test/windows/testplugin/Plugin.cpp b/test/windows/testplugin/Plugin.cpp
index b7cdc3815..af5fc5f9c 100644
--- a/test/windows/testplugin/Plugin.cpp
+++ b/test/windows/testplugin/Plugin.cpp
@@ -14,10 +14,14 @@ Module Name:
#include "precomp.h"
#include "WslPluginApi.h"
+#include "wslc_schema.h"
#include "PluginTests.h"
using namespace wsl::windows::common::registry;
+using namespace wsl::windows::common::relay;
+using namespace wsl::shared::string;
+using namespace std::chrono_literals;
std::ofstream g_logfile;
std::optional g_distroGuid;
@@ -338,6 +342,215 @@ HRESULT OnDistributionUnregistered(const WSLSessionInformation* Session, const W
return S_OK;
}
+HRESULT OnWslcSessionCreated(const WSLCSessionInformation* Session)
+try
+{
+ g_logfile << "WSLC Session created, name=" << wsl::shared::string::WideToMultiByte(Session->DisplayName) << ", id=" << Session->SessionId
+ << ", pid=" << Session->ApplicationPid << ", token=" << (Session->UserToken != nullptr ? "set" : "null")
+ << ", sid=" << (Session->UserSid != nullptr ? "set" : "null") << std::endl;
+
+ if (g_testType == PluginTestType::WslcSessionRejected)
+ {
+ g_logfile << "OnWslcSessionCreated: ERROR_ACCESS_DENIED" << std::endl;
+ return HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED);
+ }
+
+ if (g_testType == PluginTestType::WslcSuccess)
+ {
+ // Helper: run a command in the root namespace and return (status, stdout, stderr).
+ auto runCommand = [&](const char* cmd,
+ const std::optional& input = {},
+ std::vector env = {}) -> std::tuple {
+ std::vector arguments = {"/bin/sh", "-c", cmd, nullptr};
+ WSLCProcessHandle process = nullptr;
+ THROW_IF_FAILED(g_api->WSLCCreateProcess(
+ Session->SessionId, arguments[0], arguments.data(), env.empty() ? nullptr : env.data(), &process, nullptr));
+ auto releaseProcess = wil::scope_exit([&]() { g_api->WSLCReleaseProcess(process); });
+
+ wil::unique_handle stdinHandle;
+ wil::unique_handle stdoutHandle;
+ wil::unique_handle stderrHandle;
+ wil::unique_handle exitEvent;
+ THROW_IF_FAILED(g_api->WSLCProcessGetFd(process, WSLCProcessFdStdin, &stdinHandle));
+ THROW_IF_FAILED(g_api->WSLCProcessGetFd(process, WSLCProcessFdStdout, &stdoutHandle));
+ THROW_IF_FAILED(g_api->WSLCProcessGetFd(process, WSLCProcessFdStderr, &stderrHandle));
+ THROW_IF_FAILED(g_api->WSLCProcessGetExitEvent(process, &exitEvent));
+
+ std::string out;
+ std::string err;
+
+ MultiHandleWait io;
+ io.AddHandle(std::make_unique(
+ std::move(stdoutHandle), [&out](const auto& span) { out.append(span.begin(), span.end()); }));
+
+ io.AddHandle(std::make_unique(
+ std::move(stderrHandle), [&err](const auto& span) { err.append(span.begin(), span.end()); }));
+
+ io.AddHandle(std::make_unique(std::move(exitEvent)));
+
+ if (input.has_value())
+ {
+ io.AddHandle(std::make_unique(std::move(stdinHandle), std::vector(input->begin(), input->end())));
+ }
+ else
+ {
+ stdinHandle.reset();
+ }
+
+ io.Run(60000ms);
+
+ int status = 0;
+ THROW_IF_FAILED(g_api->WSLCProcessGetExitCode(process, &status));
+ g_logfile << "Command: '" << cmd << "', status=" << status << ", stdout: " << out << ", stderr: " << err << std::endl;
+
+ return {status, out, err};
+ };
+
+ // Test process creation (output & exit code validated by the test code).
+ {
+ runCommand("echo -n stdout-ok && echo -n stderr-ok >&2");
+ runCommand("cat", "stdin-ok");
+ runCommand("exit 12");
+ runCommand("echo -n $ENV", {}, {"ENV=env-ok", nullptr});
+ }
+
+ // Validate that trying to execute a non-existent file fails with the expected error code.
+ {
+ WSLCProcessHandle process = nullptr;
+ int errnoValue = 0;
+ std::vector args = {"does-not-exist", nullptr};
+
+ auto hr = g_api->WSLCCreateProcess(Session->SessionId, args[0], args.data(), nullptr, &process, &errnoValue);
+ g_logfile << "WSLCCreateProcess(does-not-exist): " << std::hex << hr << ", errno=" << std::dec << errnoValue << std::endl;
+ }
+
+ // Validate various error paths
+ {
+ std::vector args = {"/bin/sh", "-c", "sleep 9999", nullptr};
+ WSLCProcessHandle process = nullptr;
+ THROW_IF_FAILED(g_api->WSLCCreateProcess(Session->SessionId, args[0], args.data(), nullptr, &process, nullptr));
+ auto releaseProcess = wil::scope_exit([&]() { g_api->WSLCReleaseProcess(process); });
+
+ // Validate that getting an fd that doesn't exist fails with the expected error code.
+ HANDLE dummy = nullptr;
+ g_logfile << "WSLCProcessGetFd(999): " << g_api->WSLCProcessGetFd(process, static_cast(999), &dummy) << std::endl;
+ int exitCode = -1;
+
+ g_logfile << "WSLCProcessGetExitCode(): " << g_api->WSLCProcessGetExitCode(process, &exitCode) << std::endl;
+ }
+
+ const auto testFolder = L"C:\\";
+ constexpr auto testFileName = L"plugin-test.txt";
+
+ // Validate rw mounts.
+ {
+ auto rwCleanup = wil::scope_exit_log(
+ WI_DIAGNOSTICS_INFO, [&]() { std::filesystem::remove(std::wstring(testFolder) + testFileName); });
+
+ {
+ std::ofstream file(std::wstring(testFolder) + testFileName);
+ file << "Windows-content";
+ }
+
+ // Mount read-write and verify the file can be read from Linux.
+ char rwMountpoint[WSLC_MOUNTPOINT_LENGTH] = {};
+ THROW_IF_FAILED(g_api->WSLCMountFolder(Session->SessionId, testFolder, false, L"plugin-rw-test", rwMountpoint));
+
+ g_logfile << "WSLC RW folder mounted at: " << rwMountpoint << std::endl;
+
+ auto readCmd = std::format("cat {}/{}", rwMountpoint, testFileName);
+ runCommand(readCmd.c_str());
+
+ THROW_IF_FAILED(g_api->WSLCUnmountFolder(Session->SessionId, rwMountpoint));
+ }
+
+ // Validate ro mounts.
+ {
+ char roMountpoint[WSLC_MOUNTPOINT_LENGTH] = {};
+ THROW_IF_FAILED(g_api->WSLCMountFolder(Session->SessionId, L"C:\\", TRUE, L"plugin-ro-test", roMountpoint));
+
+ g_logfile << "WSLC RO folder mounted at: " << roMountpoint << std::endl;
+
+ // Attempt to write from Linux — should fail on a read-only mount.
+ auto writeCmd = std::format("echo fail > {}/should-not-exist.txt", roMountpoint);
+ runCommand(writeCmd.c_str());
+
+ THROW_IF_FAILED(g_api->WSLCUnmountFolder(Session->SessionId, roMountpoint));
+ }
+
+ // Validate that trying to mount a folder that doesn't exist fails with the expected error code.
+ {
+ char mountpoint[WSLC_MOUNTPOINT_LENGTH] = {};
+ g_logfile << "WSLCMountFolder(nonexistent): "
+ << g_api->WSLCMountFolder(Session->SessionId, L"C:\\nonexistent", TRUE, L"plugin-ro-test", mountpoint) << std::endl;
+ }
+
+ // Validate that trying to escape the /mnt folder fails.
+ {
+ char mountpoint[WSLC_MOUNTPOINT_LENGTH] = {};
+ g_logfile << "WSLCMountFolder(../escape): " << g_api->WSLCMountFolder(Session->SessionId, L"C:\\", TRUE, L"../escape", mountpoint)
+ << std::endl;
+ }
+
+ // Validate that empty names are rejected.
+ {
+ char mountpoint[WSLC_MOUNTPOINT_LENGTH] = {};
+ g_logfile << "WSLCMountFolder(): " << g_api->WSLCMountFolder(Session->SessionId, L"C:\\", TRUE, L"", mountpoint) << std::endl;
+ }
+
+ g_logfile << "Test completed" << std::endl;
+ }
+
+ return S_OK;
+}
+CATCH_RETURN();
+
+HRESULT OnWslcSessionStopping(const WSLCSessionInformation* Session)
+{
+ g_logfile << "WSLC Session stopping, name=" << wsl::shared::string::WideToMultiByte(Session->DisplayName)
+ << ", id=" << Session->SessionId << std::endl;
+
+ return S_OK;
+}
+
+HRESULT OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson)
+try
+{
+ auto container = wsl::shared::FromJson(InspectJson);
+
+ g_logfile << "WSLC Container started, session=" << Session->SessionId << ", id=" << container.Id
+ << ", name=" << container.Name << ", image=" << container.Image << ", state=" << container.State.Status << std::endl;
+
+ if (g_testType == PluginTestType::WslcContainerRejected)
+ {
+ g_logfile << "OnWslcContainerStarted: ERROR_ACCESS_DENIED" << std::endl;
+ return HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED);
+ }
+
+ return S_OK;
+}
+CATCH_RETURN();
+
+HRESULT OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId)
+{
+ g_logfile << "WSLC Container stopping, session=" << Session->SessionId << ", id=" << ContainerId << std::endl;
+ return S_OK;
+}
+
+HRESULT OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson)
+{
+ auto image = wsl::shared::FromJson(InspectJson);
+ auto name = (image.RepoTags.has_value() && !image.RepoTags->empty()) ? image.RepoTags->front() : "";
+ g_logfile << "WSLC Image created, session=" << Session->SessionId << ", id=" << image.Id << ", name=" << name << std::endl;
+ return S_OK;
+}
+
+HRESULT OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId)
+{
+ g_logfile << "WSLC Image deleted, session=" << Session->SessionId << ", id=" << ImageId << std::endl;
+ return S_OK;
+}
+
EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPluginAPIV1* Api, WSLPluginHooksV1* Hooks)
{
try
@@ -349,7 +562,7 @@ EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPlugin
THROW_HR_IF(E_UNEXPECTED, !g_logfile);
g_testType = static_cast(ReadDword(key.get(), nullptr, c_testType, static_cast(PluginTestType::Invalid)));
- THROW_HR_IF(E_INVALIDARG, static_cast(g_testType) <= 0 || static_cast(g_testType) > static_cast(PluginTestType::GetUsername));
+ THROW_HR_IF(E_INVALIDARG, static_cast(g_testType) <= 0 || static_cast(g_testType) > static_cast(PluginTestType::WslcImagePull));
g_logfile << "Plugin loaded. TestMode=" << static_cast(g_testType) << std::endl;
g_api = Api;
@@ -359,6 +572,12 @@ EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPlugin
Hooks->OnDistributionStopping = &OnDistroStopping;
Hooks->OnDistributionRegistered = &OnDistributionRegistered;
Hooks->OnDistributionUnregistered = &OnDistributionUnregistered;
+ Hooks->OnSessionCreated = &OnWslcSessionCreated;
+ Hooks->OnSessionStopping = &OnWslcSessionStopping;
+ Hooks->ContainerStarted = &OnWslcContainerStarted;
+ Hooks->ContainerStopping = &OnWslcContainerStopping;
+ Hooks->ImageCreated = &OnWslcImageCreated;
+ Hooks->ImageDeleted = &OnWslcImageDeleted;
if (g_testType == PluginTestType::FailToLoad)
{
@@ -383,4 +602,4 @@ EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPlugin
return error;
}
return S_OK;
-}
\ No newline at end of file
+}
diff --git a/test/windows/wslc/e2e/WSLCE2EHelpers.cpp b/test/windows/wslc/e2e/WSLCE2EHelpers.cpp
index dfb844cac..efa53fb4d 100644
--- a/test/windows/wslc/e2e/WSLCE2EHelpers.cpp
+++ b/test/windows/wslc/e2e/WSLCE2EHelpers.cpp
@@ -440,7 +440,17 @@ wil::com_ptr OpenDefaultElevatedSession()
std::pair StartLocalRegistry(IWSLCSession& session, const std::string& username, const std::string& password, USHORT port)
{
- EnsureImageIsLoaded({L"wslc-registry", L"latest", GetTestImagePath("wslc-registry:latest")});
+ // Check if the registry image is already loaded on this session.
+ wil::unique_cotaskmem_array_ptr images;
+ THROW_IF_FAILED(session.ListImages(nullptr, &images, images.size_address()));
+
+ bool found = std::ranges::any_of(
+ std::span{images.get(), images.size()}, [](const auto& e) { return std::strcmp(e.Image, "wslc-registry:latest") == 0; });
+
+ if (!found)
+ {
+ LoadTestImage(session, "wslc-registry:latest");
+ }
std::vector env = {std::format("REGISTRY_HTTP_ADDR=0.0.0.0:{}", port)};