Browse Source

Use Cecil to find preloader patches

ghorsington 5 years ago
parent
commit
3cdf6cc642

+ 63 - 10
BepInEx.Preloader/Patching/AssemblyPatcher.cs

@@ -4,6 +4,7 @@ using System.Diagnostics;
 using System.IO;
 using System.Linq;
 using System.Reflection;
+using BepInEx.Bootstrap;
 using BepInEx.Configuration;
 using BepInEx.Logging;
 using BepInEx.Preloader.RuntimeFixes;
@@ -23,6 +24,8 @@ namespace BepInEx.Preloader.Patching
 	/// </summary>
 	internal static class AssemblyPatcher
 	{
+		private const BindingFlags ALL = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.IgnoreCase;
+
 		public static List<PatcherPlugin> PatcherPlugins { get; } = new List<PatcherPlugin>();
 
 		private static readonly string DumpedAssembliesPath = Path.Combine(Paths.BepInExRootPath, "DumpedAssemblies");
@@ -36,32 +39,83 @@ namespace BepInEx.Preloader.Patching
 			PatcherPlugins.Add(patcher);
 		}
 
+		private static T CreateDelegate<T>(MethodInfo method) where T : class => method != null ? Delegate.CreateDelegate(typeof(T), method) as T : null;
+
 		/// <summary>
 		///     Adds all patchers from all managed assemblies specified in a directory.
 		/// </summary>
 		/// <param name="directory">Directory to search patcher DLLs from.</param>
 		/// <param name="patcherLocator">A function that locates assembly patchers in a given managed assembly.</param>
 		public static void AddPatchersFromDirectory(string directory,
-			Func<Assembly, List<PatcherPlugin>> patcherLocator)
+			Func<TypeDefinition, PatcherPlugin> patcherLocator)
 		{
 			if (!Directory.Exists(directory))
 				return;
 
 			var sortedPatchers = new SortedDictionary<string, PatcherPlugin>();
 
-			foreach (string assemblyPath in Directory.GetFiles(directory, "*.dll", SearchOption.AllDirectories))
-				try
+			var patchers = TypeLoader.FindPluginTypes(directory, patcherLocator);
+
+			foreach (var keyValuePair in patchers)
+			{
+				var assembly = keyValuePair.Key;
+				var patcherCollection = keyValuePair.Value;
+
+				var ass = Assembly.LoadFile(assembly.MainModule.FileName);
+
+				foreach (var patcherPlugin in patcherCollection)
 				{
-					var assembly = Assembly.LoadFrom(assemblyPath);
+					try
+					{
+						var type = ass.GetType(patcherPlugin.Type.FullName);
+
+						var methods = type.GetMethods(ALL);
+
+						patcherPlugin.Initializer = CreateDelegate<Action>(methods.FirstOrDefault(m => m.Name.Equals("Initialize", StringComparison.InvariantCultureIgnoreCase) &&
+																									   m.GetParameters().Length == 0 &&
+																									   m.ReturnType == typeof(void)));
+
+						patcherPlugin.Finalizer = CreateDelegate<Action>(methods.FirstOrDefault(m => m.Name.Equals("Finish", StringComparison.InvariantCultureIgnoreCase) &&
+																									 m.GetParameters().Length == 0 &&
+																									 m.ReturnType == typeof(void)));
+
+						patcherPlugin.TargetDLLs = CreateDelegate<Func<IEnumerable<string>>>(type.GetProperty("TargetDLLs", ALL).GetGetMethod());
 
-					foreach (var patcher in patcherLocator(assembly))
-						sortedPatchers.Add(patcher.Name, patcher);
+						var patcher = methods.FirstOrDefault(m => m.Name.Equals("Patch", StringComparison.CurrentCultureIgnoreCase) &&
+																  m.ReturnType == typeof(void) &&
+																  m.GetParameters().Length == 0 &&
+																  (m.GetParameters()[0].ParameterType == typeof(AssemblyDefinition) ||
+																   m.GetParameters()[0].ParameterType == typeof(AssemblyDefinition).MakeByRefType()));
+
+						patcherPlugin.Patcher = (ref AssemblyDefinition pAss) =>
+						{
+							//we do the array fuckery here to get the ref result out
+							object[] args = { pAss };
+
+							patcher.Invoke(null, args);
+
+							pAss = (AssemblyDefinition)args[0];
+						};
+
+						sortedPatchers.Add($"{ass.GetName().Name}/{type.FullName}", patcherPlugin);
+						patcherPlugin.Type = null;
+					}
+					catch (Exception e)
+					{
+						Logger.LogError($"Failed to load patcher [{patcherPlugin.Type.FullName}]: {e.Message}");
+						if (e is ReflectionTypeLoadException re)
+							Logger.LogDebug(TypeLoader.TypeLoadExceptionToString(re));
+						else
+							Logger.LogDebug(e.ToString());
+					}
 				}
-				catch (BadImageFormatException) { } //unmanaged DLL
-				catch (ReflectionTypeLoadException) { } //invalid references
+			}
 
 			foreach (KeyValuePair<string, PatcherPlugin> patcher in sortedPatchers)
 				AddPatcher(patcher.Value);
+
+			foreach (var assemblyDefinition in patchers.Keys)
+				assemblyDefinition.Dispose();
 		}
 
 		private static void InitializePatchers()
@@ -123,7 +177,7 @@ namespace BepInEx.Preloader.Patching
 			// Then, perform the actual patching
 			var patchedAssemblies = new HashSet<string>();
 			foreach (var assemblyPatcher in PatcherPlugins)
-				foreach (string targetDll in assemblyPatcher.TargetDLLs)
+				foreach (string targetDll in assemblyPatcher.TargetDLLs())
 					if (assemblies.TryGetValue(targetDll, out var assembly))
 					{
 						Logger.LogInfo($"Patching [{assembly.Name.Name}] with [{assemblyPatcher.Name}]");
@@ -135,7 +189,6 @@ namespace BepInEx.Preloader.Patching
 
 
 			// Finally, load patched assemblies into memory
-
 			if (ConfigDumpAssemblies.Value || ConfigLoadDumpedAssemblies.Value)
 			{
 				if (!Directory.Exists(DumpedAssembliesPath))

+ 4 - 1
BepInEx.Preloader/Patching/PatcherPlugin.cs

@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using Mono.Cecil;
 
 namespace BepInEx.Preloader.Patching
 {
@@ -8,10 +9,12 @@ namespace BepInEx.Preloader.Patching
 	/// </summary>
 	internal class PatcherPlugin
 	{
+		public TypeDefinition Type { get; set; }
+
 		/// <summary>
 		///     Target assemblies to patch.
 		/// </summary>
-		public IEnumerable<string> TargetDLLs { get; set; } = null;
+		public Func<IEnumerable<string>> TargetDLLs { get; set; } = null;
 
 		/// <summary>
 		///     Initializer method that is run before any patching occurs.

+ 29 - 89
BepInEx.Preloader/Preloader.cs

@@ -41,7 +41,7 @@ namespace BepInEx.Preloader
 
 				Logger.Listeners.Add(PreloaderLog);
 
-				
+
 				string consoleTile = $"BepInEx {typeof(Paths).Assembly.GetName().Version} - {Process.GetCurrentProcess().ProcessName}";
 
 				ConsoleWindow.Title = consoleTile;
@@ -69,12 +69,13 @@ namespace BepInEx.Preloader
 
 
 				AssemblyPatcher.AddPatcher(new PatcherPlugin
-					{ TargetDLLs = new[] { ConfigEntrypointAssembly.Value },
-						Patcher = PatchEntrypoint,
-						Name = "BepInEx.Chainloader"
-					});
+				{
+					TargetDLLs = () => new[] { ConfigEntrypointAssembly.Value },
+					Patcher = PatchEntrypoint,
+					Name = "BepInEx.Chainloader"
+				});
 
-				AssemblyPatcher.AddPatchersFromDirectory(Paths.PatcherPluginPath, GetPatcherMethods);
+				AssemblyPatcher.AddPatchersFromDirectory(Paths.PatcherPluginPath, ToPatcherPlugin);
 
 				Logger.LogInfo($"{AssemblyPatcher.PatcherPlugins.Count} patcher plugin(s) loaded");
 
@@ -122,95 +123,34 @@ namespace BepInEx.Preloader
 			}
 		}
 
-		/// <summary>
-		///     Scans the assembly for classes that use the patcher contract, and returns a list of valid patchers.
-		/// </summary>
-		/// <param name="assembly">The assembly to scan.</param>
-		/// <returns>A list of assembly patchers that were found in the assembly.</returns>
-		public static List<PatcherPlugin> GetPatcherMethods(Assembly assembly)
+		public static PatcherPlugin ToPatcherPlugin(TypeDefinition type)
 		{
-			var patcherMethods = new List<PatcherPlugin>();
-			var flags = BindingFlags.Public | BindingFlags.Static | BindingFlags.IgnoreCase;
-
-			foreach (var type in assembly.GetExportedTypes())
-				try
-				{
-					if (type.IsInterface)
-						continue;
-
-					var targetsProperty = type.GetProperty("TargetDLLs",
-						flags,
-						null,
-						typeof(IEnumerable<string>),
-						Type.EmptyTypes,
-						null);
-
-					//first try get the ref patcher method
-					var patcher = type.GetMethod("Patch",
-						flags,
-						null,
-						CallingConventions.Any,
-						new[] { typeof(AssemblyDefinition).MakeByRefType() },
-						null);
-
-					if (patcher == null) //otherwise try getting the non-ref patcher method
-						patcher = type.GetMethod("Patch",
-							flags,
-							null,
-							CallingConventions.Any,
-							new[] { typeof(AssemblyDefinition) },
-							null);
-
-					if (targetsProperty == null || !targetsProperty.CanRead || patcher == null)
-						continue;
-
-					var assemblyPatcher = new PatcherPlugin();
-
-					assemblyPatcher.Name = $"{assembly.GetName().Name}/{type.FullName}";
-					assemblyPatcher.Patcher = (ref AssemblyDefinition ass) =>
-					{
-						//we do the array fuckery here to get the ref result out
-						object[] args = { ass };
-
-						patcher.Invoke(null, args);
+			if (type.IsInterface || type.IsAbstract)
+				return null;
 
-						ass = (AssemblyDefinition)args[0];
-					};
+			var targetDlls = type.Methods.FirstOrDefault(m => m.Name.Equals("get_TargetDLLs", StringComparison.InvariantCultureIgnoreCase) &&
+															  m.IsPublic &&
+															  m.IsStatic);
 
-					assemblyPatcher.TargetDLLs = (IEnumerable<string>)targetsProperty.GetValue(null, null);
+			if (targetDlls == null ||
+				targetDlls.ReturnType.FullName != "System.Collections.Generic.IEnumerable`1<System.String>")
+				return null;
 
-					var initMethod = type.GetMethod("Initialize",
-						flags,
-						null,
-						CallingConventions.Any,
-						Type.EmptyTypes,
-						null);
+			var patch = type.Methods.FirstOrDefault(m => m.Name.Equals("Patch") &&
+														 m.IsPublic &&
+														 m.IsStatic &&
+														 m.ReturnType.FullName == "System.Void" &&
+														 m.Parameters.Count == 1 &&
+														 (m.Parameters[0].ParameterType.FullName == "Mono.Cecil.AssemblyDefinition&" ||
+														  m.Parameters[0].ParameterType.FullName == "Mono.Cecil.AssemblyDefinition"));
 
-					if (initMethod != null)
-						assemblyPatcher.Initializer = () => initMethod.Invoke(null, null);
+			if (patch == null)
+				return null;
 
-					var finalizeMethod = type.GetMethod("Finish",
-						flags,
-						null,
-						CallingConventions.Any,
-						Type.EmptyTypes,
-						null);
-
-					if (finalizeMethod != null)
-						assemblyPatcher.Finalizer = () => finalizeMethod.Invoke(null, null);
-
-					patcherMethods.Add(assemblyPatcher);
-				}
-				catch (Exception ex)
-				{
-					Logger.LogWarning($"Could not load patcher methods from {assembly.GetName().Name}");
-					Logger.LogWarning(ex);
-				}
-
-			Logger.Log(patcherMethods.Count > 0 ? LogLevel.Info : LogLevel.Debug,
-				$"Loaded {patcherMethods.Count} patcher methods from {assembly.GetName().Name}");
-
-			return patcherMethods;
+			return new PatcherPlugin
+			{
+				Type = type
+			};
 		}
 
 		/// <summary>