From 54deded929203a64555d97424d5bb4b884fff69f Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Thu, 5 May 2022 14:58:59 -0300
Subject: [PATCH] Fix shared memory leak on Windows (#3319)

* Fix shared memory leak on Windows

* Fix memory leak caused by RO session disposal not decrementing the memory manager ref count

* Fix UnmapViewInternal deadlock

* Was not supposed to add those back
---
 Ryujinx.HLE/HOS/Services/Ro/IRoInterface.cs   |  12 +-
 Ryujinx.Memory/MemoryBlock.cs                 |   4 +-
 Ryujinx.Memory/MemoryManagement.cs            |  44 ++---
 Ryujinx.Memory/MemoryManagementWindows.cs     |  40 +++--
 .../WindowsShared/PlaceholderManager.cs       |  15 +-
 .../WindowsShared/PlaceholderManager4KB.cs    | 170 ++++++++++++++++++
 6 files changed, 226 insertions(+), 59 deletions(-)
 create mode 100644 Ryujinx.Memory/WindowsShared/PlaceholderManager4KB.cs

diff --git a/Ryujinx.HLE/HOS/Services/Ro/IRoInterface.cs b/Ryujinx.HLE/HOS/Services/Ro/IRoInterface.cs
index 0ce65e3a7..d986bc41f 100644
--- a/Ryujinx.HLE/HOS/Services/Ro/IRoInterface.cs
+++ b/Ryujinx.HLE/HOS/Services/Ro/IRoInterface.cs
@@ -30,6 +30,7 @@ namespace Ryujinx.HLE.HOS.Services.Ro
         private List<NroInfo> _nroInfos;
 
         private KProcess _owner;
+        private IVirtualMemoryManager _ownerMm;
 
         private static Random _random = new Random();
 
@@ -38,6 +39,7 @@ namespace Ryujinx.HLE.HOS.Services.Ro
             _nrrInfos = new List<NrrInfo>(MaxNrr);
             _nroInfos = new List<NroInfo>(MaxNro);
             _owner    = null;
+            _ownerMm  = null;
         }
 
         private ResultCode ParseNrr(out NrrInfo nrrInfo, ServiceCtx context, ulong nrrAddress, ulong nrrSize)
@@ -564,10 +566,12 @@ namespace Ryujinx.HLE.HOS.Services.Ro
                 return ResultCode.InvalidSession;
             }
 
-            _owner = context.Process.HandleTable.GetKProcess(context.Request.HandleDesc.ToCopy[0]);
-            context.Device.System.KernelContext.Syscall.CloseHandle(context.Request.HandleDesc.ToCopy[0]);
+            int processHandle = context.Request.HandleDesc.ToCopy[0];
+            _owner = context.Process.HandleTable.GetKProcess(processHandle);
+            _ownerMm = _owner?.CpuMemory;
+            context.Device.System.KernelContext.Syscall.CloseHandle(processHandle);
 
-            if (_owner?.CpuMemory is IRefCounted rc)
+            if (_ownerMm is IRefCounted rc)
             {
                 rc.IncrementReferenceCount();
             }
@@ -586,7 +590,7 @@ namespace Ryujinx.HLE.HOS.Services.Ro
 
                 _nroInfos.Clear();
 
-                if (_owner?.CpuMemory is IRefCounted rc)
+                if (_ownerMm is IRefCounted rc)
                 {
                     rc.DecrementReferenceCount();
                 }
diff --git a/Ryujinx.Memory/MemoryBlock.cs b/Ryujinx.Memory/MemoryBlock.cs
index 82a7d882e..c6b85b582 100644
--- a/Ryujinx.Memory/MemoryBlock.cs
+++ b/Ryujinx.Memory/MemoryBlock.cs
@@ -48,7 +48,7 @@ namespace Ryujinx.Memory
             {
                 _viewCompatible = flags.HasFlag(MemoryAllocationFlags.ViewCompatible);
                 _forceWindows4KBView = flags.HasFlag(MemoryAllocationFlags.ForceWindows4KBViewMapping);
-                _pointer = MemoryManagement.Reserve(size, _viewCompatible);
+                _pointer = MemoryManagement.Reserve(size, _viewCompatible, _forceWindows4KBView);
             }
             else
             {
@@ -404,7 +404,7 @@ namespace Ryujinx.Memory
                 }
                 else
                 {
-                    MemoryManagement.Free(ptr);
+                    MemoryManagement.Free(ptr, Size, _forceWindows4KBView);
                 }
 
                 foreach (MemoryBlock viewStorage in _viewStorages.Keys)
diff --git a/Ryujinx.Memory/MemoryManagement.cs b/Ryujinx.Memory/MemoryManagement.cs
index 81262152b..3b8a96649 100644
--- a/Ryujinx.Memory/MemoryManagement.cs
+++ b/Ryujinx.Memory/MemoryManagement.cs
@@ -8,9 +8,7 @@ namespace Ryujinx.Memory
         {
             if (OperatingSystem.IsWindows())
             {
-                IntPtr sizeNint = new IntPtr((long)size);
-
-                return MemoryManagementWindows.Allocate(sizeNint);
+                return MemoryManagementWindows.Allocate((IntPtr)size);
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
             {
@@ -22,13 +20,11 @@ namespace Ryujinx.Memory
             }
         }
 
-        public static IntPtr Reserve(ulong size, bool viewCompatible)
+        public static IntPtr Reserve(ulong size, bool viewCompatible, bool force4KBMap)
         {
             if (OperatingSystem.IsWindows())
             {
-                IntPtr sizeNint = new IntPtr((long)size);
-
-                return MemoryManagementWindows.Reserve(sizeNint, viewCompatible);
+                return MemoryManagementWindows.Reserve((IntPtr)size, viewCompatible, force4KBMap);
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
             {
@@ -44,9 +40,7 @@ namespace Ryujinx.Memory
         {
             if (OperatingSystem.IsWindows())
             {
-                IntPtr sizeNint = new IntPtr((long)size);
-
-                return MemoryManagementWindows.Commit(address, sizeNint);
+                return MemoryManagementWindows.Commit(address, (IntPtr)size);
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
             {
@@ -62,9 +56,7 @@ namespace Ryujinx.Memory
         {
             if (OperatingSystem.IsWindows())
             {
-                IntPtr sizeNint = new IntPtr((long)size);
-
-                return MemoryManagementWindows.Decommit(address, sizeNint);
+                return MemoryManagementWindows.Decommit(address, (IntPtr)size);
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
             {
@@ -80,15 +72,13 @@ namespace Ryujinx.Memory
         {
             if (OperatingSystem.IsWindows())
             {
-                IntPtr sizeNint = new IntPtr((long)size);
-
                 if (force4KBMap)
                 {
-                    MemoryManagementWindows.MapView4KB(sharedMemory, srcOffset, address, sizeNint);
+                    MemoryManagementWindows.MapView4KB(sharedMemory, srcOffset, address, (IntPtr)size);
                 }
                 else
                 {
-                    MemoryManagementWindows.MapView(sharedMemory, srcOffset, address, sizeNint);
+                    MemoryManagementWindows.MapView(sharedMemory, srcOffset, address, (IntPtr)size);
                 }
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
@@ -105,15 +95,13 @@ namespace Ryujinx.Memory
         {
             if (OperatingSystem.IsWindows())
             {
-                IntPtr sizeNint = new IntPtr((long)size);
-
                 if (force4KBMap)
                 {
-                    MemoryManagementWindows.UnmapView4KB(address, sizeNint);
+                    MemoryManagementWindows.UnmapView4KB(address, (IntPtr)size);
                 }
                 else
                 {
-                    MemoryManagementWindows.UnmapView(sharedMemory, address, sizeNint);
+                    MemoryManagementWindows.UnmapView(sharedMemory, address, (IntPtr)size);
                 }
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
@@ -132,15 +120,13 @@ namespace Ryujinx.Memory
 
             if (OperatingSystem.IsWindows())
             {
-                IntPtr sizeNint = new IntPtr((long)size);
-
                 if (forView && force4KBMap)
                 {
-                    result = MemoryManagementWindows.Reprotect4KB(address, sizeNint, permission, forView);
+                    result = MemoryManagementWindows.Reprotect4KB(address, (IntPtr)size, permission, forView);
                 }
                 else
                 {
-                    result = MemoryManagementWindows.Reprotect(address, sizeNint, permission, forView);
+                    result = MemoryManagementWindows.Reprotect(address, (IntPtr)size, permission, forView);
                 }
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
@@ -158,11 +144,11 @@ namespace Ryujinx.Memory
             }
         }
 
-        public static bool Free(IntPtr address)
+        public static bool Free(IntPtr address, ulong size, bool force4KBMap)
         {
             if (OperatingSystem.IsWindows())
             {
-                return MemoryManagementWindows.Free(address);
+                return MemoryManagementWindows.Free(address, (IntPtr)size, force4KBMap);
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
             {
@@ -178,9 +164,7 @@ namespace Ryujinx.Memory
         {
             if (OperatingSystem.IsWindows())
             {
-                IntPtr sizeNint = new IntPtr((long)size);
-
-                return MemoryManagementWindows.CreateSharedMemory(sizeNint, reserve);
+                return MemoryManagementWindows.CreateSharedMemory((IntPtr)size, reserve);
             }
             else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
             {
diff --git a/Ryujinx.Memory/MemoryManagementWindows.cs b/Ryujinx.Memory/MemoryManagementWindows.cs
index 1885376ca..3d51b64f7 100644
--- a/Ryujinx.Memory/MemoryManagementWindows.cs
+++ b/Ryujinx.Memory/MemoryManagementWindows.cs
@@ -7,21 +7,27 @@ namespace Ryujinx.Memory
     [SupportedOSPlatform("windows")]
     static class MemoryManagementWindows
     {
-        private const int PageSize = 0x1000;
+        public const int PageSize = 0x1000;
 
         private static readonly PlaceholderManager _placeholders = new PlaceholderManager();
+        private static readonly PlaceholderManager4KB _placeholders4KB = new PlaceholderManager4KB();
 
         public static IntPtr Allocate(IntPtr size)
         {
             return AllocateInternal(size, AllocationType.Reserve | AllocationType.Commit);
         }
 
-        public static IntPtr Reserve(IntPtr size, bool viewCompatible)
+        public static IntPtr Reserve(IntPtr size, bool viewCompatible, bool force4KBMap)
         {
             if (viewCompatible)
             {
                 IntPtr baseAddress = AllocateInternal2(size, AllocationType.Reserve | AllocationType.ReservePlaceholder);
-                _placeholders.ReserveRange((ulong)baseAddress, (ulong)size);
+
+                if (!force4KBMap)
+                {
+                    _placeholders.ReserveRange((ulong)baseAddress, (ulong)size);
+                }
+
                 return baseAddress;
             }
 
@@ -69,6 +75,8 @@ namespace Ryujinx.Memory
 
         public static void MapView4KB(IntPtr sharedMemory, ulong srcOffset, IntPtr location, IntPtr size)
         {
+            _placeholders4KB.UnmapAndMarkRangeAsMapped(location, size);
+
             ulong uaddress = (ulong)location;
             ulong usize = (ulong)size;
             IntPtr endLocation = (IntPtr)(uaddress + usize);
@@ -105,20 +113,7 @@ namespace Ryujinx.Memory
 
         public static void UnmapView4KB(IntPtr location, IntPtr size)
         {
-            ulong uaddress = (ulong)location;
-            ulong usize = (ulong)size;
-            IntPtr endLocation = (IntPtr)(uaddress + usize);
-
-            while (location != endLocation)
-            {
-                bool result = WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, location, 2);
-                if (!result)
-                {
-                    throw new WindowsApiException("UnmapViewOfFile2");
-                }
-
-                location += PageSize;
-            }
+            _placeholders4KB.UnmapView(location, size);
         }
 
         public static bool Reprotect(IntPtr address, IntPtr size, MemoryPermission permission, bool forView)
@@ -151,8 +146,17 @@ namespace Ryujinx.Memory
             return true;
         }
 
-        public static bool Free(IntPtr address)
+        public static bool Free(IntPtr address, IntPtr size, bool force4KBMap)
         {
+            if (force4KBMap)
+            {
+                _placeholders4KB.UnmapRange(address, size);
+            }
+            else
+            {
+                _placeholders.UnmapView(IntPtr.Zero, address, size);
+            }
+
             return WindowsApi.VirtualFree(address, IntPtr.Zero, AllocationType.Release);
         }
 
diff --git a/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs b/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs
index b0b3bf050..d465f3416 100644
--- a/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs
+++ b/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Diagnostics;
+using System.Runtime.Versioning;
 using System.Threading;
 
 namespace Ryujinx.Memory.WindowsShared
@@ -7,6 +8,7 @@ namespace Ryujinx.Memory.WindowsShared
     /// <summary>
     /// Windows memory placeholder manager.
     /// </summary>
+    [SupportedOSPlatform("windows")]
     class PlaceholderManager
     {
         private const ulong MinimumPageSize = 0x1000;
@@ -203,7 +205,7 @@ namespace Ryujinx.Memory.WindowsShared
             ulong endAddress = startAddress + unmapSize;
 
             var overlaps = Array.Empty<IntervalTreeNode<ulong, ulong>>();
-            int count = 0;
+            int count;
 
             lock (_mappings)
             {
@@ -226,8 +228,11 @@ namespace Ryujinx.Memory.WindowsShared
                     ulong overlapEnd = overlap.End;
                     ulong overlapValue = overlap.Value;
 
-                    _mappings.Remove(overlap);
-                    _mappings.Add(overlapStart, overlapEnd, ulong.MaxValue);
+                    lock (_mappings)
+                    {
+                        _mappings.Remove(overlap);
+                        _mappings.Add(overlapStart, overlapEnd, ulong.MaxValue);
+                    }
 
                     bool overlapStartsBefore = overlapStart < startAddress;
                     bool overlapEndsAfter = overlapEnd > endAddress;
@@ -364,7 +369,7 @@ namespace Ryujinx.Memory.WindowsShared
             ulong endAddress = reprotectAddress + reprotectSize;
 
             var overlaps = Array.Empty<IntervalTreeNode<ulong, ulong>>();
-            int count = 0;
+            int count;
 
             lock (_mappings)
             {
@@ -534,7 +539,7 @@ namespace Ryujinx.Memory.WindowsShared
         {
             ulong endAddress = address + size;
             var overlaps = Array.Empty<IntervalTreeNode<ulong, MemoryPermission>>();
-            int count = 0;
+            int count;
 
             lock (_protections)
             {
diff --git a/Ryujinx.Memory/WindowsShared/PlaceholderManager4KB.cs b/Ryujinx.Memory/WindowsShared/PlaceholderManager4KB.cs
new file mode 100644
index 000000000..fc056a2f7
--- /dev/null
+++ b/Ryujinx.Memory/WindowsShared/PlaceholderManager4KB.cs
@@ -0,0 +1,170 @@
+using System;
+using System.Runtime.Versioning;
+
+namespace Ryujinx.Memory.WindowsShared
+{
+    /// <summary>
+    /// Windows 4KB memory placeholder manager.
+    /// </summary>
+    [SupportedOSPlatform("windows")]
+    class PlaceholderManager4KB
+    {
+        private const int PageSize = MemoryManagementWindows.PageSize;
+
+        private readonly IntervalTree<ulong, byte> _mappings;
+
+        /// <summary>
+        /// Creates a new instance of the Windows 4KB memory placeholder manager.
+        /// </summary>
+        public PlaceholderManager4KB()
+        {
+            _mappings = new IntervalTree<ulong, byte>();
+        }
+
+        /// <summary>
+        /// Unmaps the specified range of memory and marks it as mapped internally.
+        /// </summary>
+        /// <remarks>
+        /// Since this marks the range as mapped, the expectation is that the range will be mapped after calling this method.
+        /// </remarks>
+        /// <param name="location">Memory address to unmap and mark as mapped</param>
+        /// <param name="size">Size of the range in bytes</param>
+        public void UnmapAndMarkRangeAsMapped(IntPtr location, IntPtr size)
+        {
+            ulong startAddress = (ulong)location;
+            ulong unmapSize = (ulong)size;
+            ulong endAddress = startAddress + unmapSize;
+
+            var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
+            int count = 0;
+
+            lock (_mappings)
+            {
+                count = _mappings.Get(startAddress, endAddress, ref overlaps);
+            }
+
+            for (int index = 0; index < count; index++)
+            {
+                var overlap = overlaps[index];
+
+                // Tree operations might modify the node start/end values, so save a copy before we modify the tree.
+                ulong overlapStart = overlap.Start;
+                ulong overlapEnd = overlap.End;
+                ulong overlapValue = overlap.Value;
+
+                _mappings.Remove(overlap);
+
+                ulong unmapStart = Math.Max(overlapStart, startAddress);
+                ulong unmapEnd = Math.Min(overlapEnd, endAddress);
+
+                if (overlapStart < startAddress)
+                {
+                    startAddress = overlapStart;
+                }
+
+                if (overlapEnd > endAddress)
+                {
+                    endAddress = overlapEnd;
+                }
+
+                ulong currentAddress = unmapStart;
+                while (currentAddress < unmapEnd)
+                {
+                    WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
+                    currentAddress += PageSize;
+                }
+            }
+
+            _mappings.Add(startAddress, endAddress, 0);
+        }
+
+        /// <summary>
+        /// Unmaps views at the specified memory range.
+        /// </summary>
+        /// <param name="location">Address of the range</param>
+        /// <param name="size">Size of the range in bytes</param>
+        public void UnmapView(IntPtr location, IntPtr size)
+        {
+            ulong startAddress = (ulong)location;
+            ulong unmapSize = (ulong)size;
+            ulong endAddress = startAddress + unmapSize;
+
+            var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
+            int count = 0;
+
+            lock (_mappings)
+            {
+                count = _mappings.Get(startAddress, endAddress, ref overlaps);
+            }
+
+            for (int index = 0; index < count; index++)
+            {
+                var overlap = overlaps[index];
+
+                // Tree operations might modify the node start/end values, so save a copy before we modify the tree.
+                ulong overlapStart = overlap.Start;
+                ulong overlapEnd = overlap.End;
+
+                _mappings.Remove(overlap);
+
+                if (overlapStart < startAddress)
+                {
+                    _mappings.Add(overlapStart, startAddress, 0);
+                }
+
+                if (overlapEnd > endAddress)
+                {
+                    _mappings.Add(endAddress, overlapEnd, 0);
+                }
+
+                ulong unmapStart = Math.Max(overlapStart, startAddress);
+                ulong unmapEnd = Math.Min(overlapEnd, endAddress);
+
+                ulong currentAddress = unmapStart;
+                while (currentAddress < unmapEnd)
+                {
+                    WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
+                    currentAddress += PageSize;
+                }
+            }
+        }
+
+        /// <summary>
+        /// Unmaps mapped memory at a given range.
+        /// </summary>
+        /// <param name="location">Address of the range</param>
+        /// <param name="size">Size of the range in bytes</param>
+        public void UnmapRange(IntPtr location, IntPtr size)
+        {
+            ulong startAddress = (ulong)location;
+            ulong unmapSize = (ulong)size;
+            ulong endAddress = startAddress + unmapSize;
+
+            var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
+            int count = 0;
+
+            lock (_mappings)
+            {
+                count = _mappings.Get(startAddress, endAddress, ref overlaps);
+            }
+
+            for (int index = 0; index < count; index++)
+            {
+                var overlap = overlaps[index];
+
+                // Tree operations might modify the node start/end values, so save a copy before we modify the tree.
+                ulong unmapStart = Math.Max(overlap.Start, startAddress);
+                ulong unmapEnd = Math.Min(overlap.End, endAddress);
+
+                _mappings.Remove(overlap);
+
+                ulong currentAddress = unmapStart;
+                while (currentAddress < unmapEnd)
+                {
+                    WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
+                    currentAddress += PageSize;
+                }
+            }
+        }
+    }
+}
\ No newline at end of file