Browse Source

Add ref patcher method

Bepis 6 years ago
parent
commit
910aece254
2 changed files with 65 additions and 47 deletions
  1. 22 24
      BepInEx/Bootstrap/AssemblyPatcher.cs
  2. 43 23
      BepInEx/Bootstrap/Preloader.cs

+ 22 - 24
BepInEx/Bootstrap/AssemblyPatcher.cs

@@ -7,7 +7,7 @@ using Mono.Cecil;
 
 namespace BepInEx.Bootstrap
 {
-    public delegate void AssemblyPatcherDelegate(AssemblyDefinition assembly);
+    public delegate void AssemblyPatcherDelegate(ref AssemblyDefinition assembly);
 
     public static class AssemblyPatcher
     {
@@ -57,46 +57,44 @@ namespace BepInEx.Bootstrap
             }
 
             //sort the assemblies so load the assemblies that are dependant upon first
-            IEnumerable<AssemblyDefinition> sortedAssemblies = Utility.TopologicalSort(assemblies, x => assemblyDependencyDict[x]);
+            AssemblyDefinition[] sortedAssemblies = Utility.TopologicalSort(assemblies, x => assemblyDependencyDict[x]).ToArray();
 
             //call the patchers on the assemblies
-            foreach (var assembly in sortedAssemblies)
-            {
-#if CECIL_10
-                using (assembly)
-#endif
-                {
-                    string filename = Path.GetFileName(assemblyFilenames[assembly]);
+			for (int i = 0; i < sortedAssemblies.Length; i++)
+			{
+                string filename = Path.GetFileName(assemblyFilenames[sortedAssemblies[i]]);
 
-                    //skip if we aren't patching it
-                    if (!patcherMethodDictionary.TryGetValue(filename, out IList<AssemblyPatcherDelegate> patcherMethods))
-                        continue;
+                //skip if we aren't patching it
+                if (!patcherMethodDictionary.TryGetValue(filename, out IList<AssemblyPatcherDelegate> patcherMethods))
+                    continue;
 
-                    Patch(assembly, patcherMethods);
+                Patch(ref sortedAssemblies[i], patcherMethods);
 
-                    if (DumpingEnabled)
+                if (DumpingEnabled)
+                {
+                    using (MemoryStream mem = new MemoryStream())
                     {
-                        using (MemoryStream mem = new MemoryStream())
-                        {
-                            string dirPath = Path.Combine(Preloader.PluginPath, "DumpedAssemblies");
+                        string dirPath = Path.Combine(Preloader.PluginPath, "DumpedAssemblies");
 
-                            if (!Directory.Exists(dirPath))
-                                Directory.CreateDirectory(dirPath);
+                        if (!Directory.Exists(dirPath))
+                            Directory.CreateDirectory(dirPath);
                             
-                            assembly.Write(mem);
-                            File.WriteAllBytes(Path.Combine(dirPath, filename), mem.ToArray());
-                        }
+	                    sortedAssemblies[i].Write(mem);
+                        File.WriteAllBytes(Path.Combine(dirPath, filename), mem.ToArray());
                     }
                 }
+#if CECIL_10
+				sortedAssemblies[i].Dispose();
+#endif
             }
         }
 
-        public static void Patch(AssemblyDefinition assembly, IEnumerable<AssemblyPatcherDelegate> patcherMethods)
+        public static void Patch(ref AssemblyDefinition assembly, IEnumerable<AssemblyPatcherDelegate> patcherMethods)
         {
             using (MemoryStream assemblyStream = new MemoryStream())
             {
                 foreach (AssemblyPatcherDelegate method in patcherMethods)
-                    method.Invoke(assembly);
+                    method.Invoke(ref assembly);
 
                 assembly.Write(assemblyStream);
                 Assembly.Load(assemblyStream.ToArray());

+ 43 - 23
BepInEx/Bootstrap/Preloader.cs

@@ -165,31 +165,51 @@ namespace BepInEx.Bootstrap
 
             foreach (var type in assembly.GetExportedTypes())
             {
-                try
-                {
-                    if (type.IsInterface)
+	            try
+	            {
+		            if (type.IsInterface)
+			            continue;
+
+		            PropertyInfo targetsProperty = type.GetProperty(
+			            "TargetDLLs",
+			            BindingFlags.Public | BindingFlags.Static | BindingFlags.IgnoreCase,
+			            null,
+			            typeof(IEnumerable<string>),
+			            Type.EmptyTypes,
+			            null);
+
+					//first try get the ref patcher method
+		            MethodInfo patcher = type.GetMethod(
+			            "Patch",
+			            BindingFlags.Public | BindingFlags.Static | BindingFlags.IgnoreCase,
+			            null,
+			            CallingConventions.Any,
+			            new[] {typeof(AssemblyDefinition).MakeByRefType()},
+			            null);
+
+		            if (patcher == null) //otherwise try getting the non-ref patcher method
+		            {
+			            patcher = type.GetMethod(
+				            "Patch",
+				            BindingFlags.Public | BindingFlags.Static | BindingFlags.IgnoreCase,
+				            null,
+				            CallingConventions.Any,
+				            new[] {typeof(AssemblyDefinition)},
+				            null);
+		            }
+
+		            if (targetsProperty == null || !targetsProperty.CanRead || patcher == null)
                         continue;
 
-                    PropertyInfo targetsProperty = type.GetProperty(
-                        "TargetDLLs", 
-                        BindingFlags.Public | BindingFlags.Static | BindingFlags.IgnoreCase,
-                        null,
-                        typeof(IEnumerable<string>),
-                        Type.EmptyTypes,
-                        null);
-
-                    MethodInfo patcher = type.GetMethod(
-                        "Patch", 
-                        BindingFlags.Public | BindingFlags.Static | BindingFlags.IgnoreCase,
-                        null,
-                        CallingConventions.Any,
-                        new[] { typeof(AssemblyDefinition) },
-                        null);
-
-                    if (targetsProperty == null || !targetsProperty.CanRead || patcher == null)
-                        continue;
+                    AssemblyPatcherDelegate patchDelegate = (ref AssemblyDefinition ass) =>
+                    {
+						//we do the array fuckery here to get the ref result out
+	                    object[] args = { ass };
+
+	                    patcher.Invoke(null, args);
 
-                    AssemblyPatcherDelegate patchDelegate = (ass) => { patcher.Invoke(null, new object[] {ass}); };
+	                    ass = (AssemblyDefinition)args[0];
+                    };
 
                     IEnumerable<string> targets = (IEnumerable<string>)targetsProperty.GetValue(null, null);
 
@@ -217,7 +237,7 @@ namespace BepInEx.Bootstrap
             return patcherMethods;
         }
 
-        internal static void PatchEntrypoint(AssemblyDefinition assembly)
+        internal static void PatchEntrypoint(ref AssemblyDefinition assembly)
         {
             if (assembly.Name.Name == "UnityEngine")
             {