Browse Source

Perform assembly enumeration during patching

Bepis 6 years ago
parent
commit
4c15a6b563
2 changed files with 43 additions and 49 deletions
  1. 30 16
      BepInEx/Bootstrap/AssemblyPatcher.cs
  2. 13 33
      BepInEx/Bootstrap/Preloader.cs

+ 30 - 16
BepInEx/Bootstrap/AssemblyPatcher.cs

@@ -13,7 +13,7 @@ namespace BepInEx.Bootstrap
     {
         private static bool DumpingEnabled => bool.TryParse(Config.GetEntry("preloader-dumpassemblies", "false"), out bool result) ? result : false;
 
-        public static void PatchAll(string directory, Dictionary<string, IList<AssemblyPatcherDelegate>> patcherMethodDictionary)
+        public static void PatchAll(string directory, Dictionary<AssemblyPatcherDelegate, IEnumerable<string>> patcherMethodDictionary)
         {
             //load all the requested assemblies
             List<AssemblyDefinition> assemblies = new List<AssemblyDefinition>();
@@ -59,17 +59,27 @@ namespace BepInEx.Bootstrap
             //sort the assemblies so load the assemblies that are dependant upon first
             AssemblyDefinition[] sortedAssemblies = Utility.TopologicalSort(assemblies, x => assemblyDependencyDict[x]).ToArray();
 
+	        List<string> sortedAssemblyFilenames = sortedAssemblies.Select(x => assemblyFilenames[x]).ToList();
+
             //call the patchers on the assemblies
+	        foreach (var patcherMethod in patcherMethodDictionary)
+	        {
+		        foreach (string assemblyFilename in patcherMethod.Value)
+		        {
+			        int index = sortedAssemblyFilenames.FindIndex(x => x == assemblyFilename);
+
+			        if (index < 0)
+				        continue;
+
+					Patch(ref sortedAssemblies[index], patcherMethod.Key);
+		        }
+	        }
+
+
 			for (int i = 0; i < sortedAssemblies.Length; i++)
 			{
                 string filename = Path.GetFileName(assemblyFilenames[sortedAssemblies[i]]);
 
-                //skip if we aren't patching it
-                if (!patcherMethodDictionary.TryGetValue(filename, out IList<AssemblyPatcherDelegate> patcherMethods))
-                    continue;
-
-                Patch(ref sortedAssemblies[i], patcherMethods);
-
                 if (DumpingEnabled)
                 {
                     using (MemoryStream mem = new MemoryStream())
@@ -83,22 +93,26 @@ namespace BepInEx.Bootstrap
                         File.WriteAllBytes(Path.Combine(dirPath, filename), mem.ToArray());
                     }
                 }
+
+				Load(sortedAssemblies[i]);
 #if CECIL_10
 				sortedAssemblies[i].Dispose();
 #endif
             }
         }
 
-        public static void Patch(ref AssemblyDefinition assembly, IEnumerable<AssemblyPatcherDelegate> patcherMethods)
+        public static void Patch(ref AssemblyDefinition assembly, AssemblyPatcherDelegate patcherMethod)
         {
-            using (MemoryStream assemblyStream = new MemoryStream())
-            {
-                foreach (AssemblyPatcherDelegate method in patcherMethods)
-                    method.Invoke(ref assembly);
-
-                assembly.Write(assemblyStream);
-                Assembly.Load(assemblyStream.ToArray());
-            }
+	        patcherMethod.Invoke(ref assembly);
         }
+
+	    public static void Load(AssemblyDefinition assembly)
+	    {
+		    using (MemoryStream assemblyStream = new MemoryStream())
+		    {
+			    assembly.Write(assemblyStream);
+			    Assembly.Load(assemblyStream.ToArray());
+		    }
+	    }
     }
 }

+ 13 - 33
BepInEx/Bootstrap/Preloader.cs

@@ -38,24 +38,15 @@ namespace BepInEx.Bootstrap
 
         public static PreloaderLogWriter PreloaderLog { get; private set; }
 
-        public static Dictionary<string, IList<AssemblyPatcherDelegate>> PatcherDictionary = new Dictionary<string, IList<AssemblyPatcherDelegate>>(StringComparer.OrdinalIgnoreCase);
+        public static Dictionary<AssemblyPatcherDelegate, IEnumerable<string>> PatcherDictionary = new Dictionary<AssemblyPatcherDelegate, IEnumerable<string>>();
 
 
-        public static void AddPatcher(string dllName, AssemblyPatcherDelegate patcher)
+        public static void AddPatcher(IEnumerable<string> dllNames, AssemblyPatcherDelegate patcher)
         {
-            if (PatcherDictionary.TryGetValue(dllName, out IList<AssemblyPatcherDelegate> patcherList))
-                patcherList.Add(patcher);
-            else
-            {
-                patcherList = new List<AssemblyPatcherDelegate>();
-
-                patcherList.Add(patcher);
-
-                PatcherDictionary[dllName] = patcherList;
-            }
+	        PatcherDictionary[patcher] = dllNames;
         }
 
-        private static bool TryGetConfigBool(string key, string defaultValue)
+        private static bool SafeGetConfigBool(string key, string defaultValue)
         {
             try
             {
@@ -71,8 +62,8 @@ namespace BepInEx.Bootstrap
 
         internal static void AllocateConsole()
         {
-            bool console = TryGetConfigBool("console", "false");
-            bool shiftjis = TryGetConfigBool("console-shiftjis", "false");
+            bool console = SafeGetConfigBool("console", "false");
+            bool shiftjis = SafeGetConfigBool("console-shiftjis", "false");
 
             if (console)
             {
@@ -105,7 +96,7 @@ namespace BepInEx.Bootstrap
 
                 AllocateConsole();
 
-                PreloaderLog = new PreloaderLogWriter(TryGetConfigBool("preloader-logconsole", "false"));
+                PreloaderLog = new PreloaderLogWriter(SafeGetConfigBool("preloader-logconsole", "false"));
                 PreloaderLog.Enabled = true;
 
                 string consoleTile = $"BepInEx {Assembly.GetExecutingAssembly().GetName().Version} - {Process.GetCurrentProcess().ProcessName}";
@@ -118,7 +109,7 @@ namespace BepInEx.Bootstrap
                 Logger.Log(LogLevel.Message, "Preloader started");
 
 
-                AddPatcher("UnityEngine.dll", PatchEntrypoint);
+                AddPatcher(new [] { "UnityEngine.dll" }, PatchEntrypoint);
 
                 if (Directory.Exists(PatcherPluginPath))
                     foreach (string assemblyPath in Directory.GetFiles(PatcherPluginPath, "*.dll"))
@@ -128,8 +119,7 @@ namespace BepInEx.Bootstrap
                             var assembly = Assembly.LoadFrom(assemblyPath);
 
                             foreach (var kv in GetPatcherMethods(assembly))
-                                foreach (var patcher in kv.Value)
-                                    AddPatcher(kv.Key, patcher);
+                                AddPatcher(kv.Value, kv.Key);
                         }
                         catch (BadImageFormatException) { } //unmanaged DLL
                         catch (ReflectionTypeLoadException) { } //invalid references
@@ -159,9 +149,9 @@ namespace BepInEx.Bootstrap
             }
         }
 
-        internal static IDictionary<string, IList<AssemblyPatcherDelegate>> GetPatcherMethods(Assembly assembly)
+        internal static Dictionary<AssemblyPatcherDelegate, IEnumerable<string>> GetPatcherMethods(Assembly assembly)
         {
-            var patcherMethods = new Dictionary<string, IList<AssemblyPatcherDelegate>>(StringComparer.OrdinalIgnoreCase);
+            var patcherMethods = new Dictionary<AssemblyPatcherDelegate, IEnumerable<string>>();
 
             foreach (var type in assembly.GetExportedTypes())
             {
@@ -213,18 +203,8 @@ namespace BepInEx.Bootstrap
 
                     IEnumerable<string> targets = (IEnumerable<string>)targetsProperty.GetValue(null, null);
 
-                    foreach (string target in targets)
-                    {
-                        if (patcherMethods.TryGetValue(target, out IList<AssemblyPatcherDelegate> patchers))
-                            patchers.Add(patchDelegate);
-                        else
-                        {
-                            patchers = new List<AssemblyPatcherDelegate>{ patchDelegate };
-
-                            patcherMethods[target] = patchers;
-                        }
-                    }
-                }
+		            patcherMethods[patchDelegate] = targets;
+	            }
                 catch (Exception ex)
                 {
                     Logger.Log(LogLevel.Warning, $"Could not load patcher methods from {assembly.GetName().Name}");