AssemblyPatcher.cs 9.1 KB


  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Reflection;
  5. using BepInEx.Harmony;
  6. using BepInEx.Logging;
  7. using Harmony;
  8. using Mono.Cecil;
  9. namespace BepInEx.Bootstrap
  10. {
  /// <summary>
  /// Callback signature for a patcher's main patching routine.
  /// </summary>
  /// <param name="assembly">
  /// Cecil definition of the assembly being patched. It is passed by reference so the callback
  /// may substitute a different definition; the caller keeps whatever value is left in it.
  /// </param>
  internal delegate void AssemblyPatcherDelegate(ref AssemblyDefinition assembly);

  /// <summary>
  /// Describes one assembly patcher: the DLLs it targets, plus the hooks that run before,
  /// during and after patching.
  /// </summary>
  internal class PatcherPlugin
  {
      /// <summary>
      /// File names (without directory) of the assemblies this patcher applies to.
      /// Unset (<c>null</c>) by default.
      /// </summary>
      public IEnumerable<string> TargetDLLs { get; set; }

      /// <summary>
      /// Optional hook run once before any patching takes place. Unset (<c>null</c>) by default.
      /// </summary>
      public Action Initializer { get; set; }

      /// <summary>
      /// Optional hook run once after all patching is complete. Unset (<c>null</c>) by default.
      /// </summary>
      public Action Finalizer { get; set; }

      /// <summary>
      /// Patching routine invoked for every DLL listed in <see cref="TargetDLLs"/>.
      /// Unset (<c>null</c>) by default.
      /// </summary>
      public AssemblyPatcherDelegate Patcher { get; set; }

      /// <summary>
      /// Display name of the patcher; <see cref="string.Empty"/> unless assigned.
      /// </summary>
      public string Name { get; set; } = string.Empty;
  }
  /// <summary>
  /// Worker class which is used for loading and patching entire folders of assemblies, or alternatively patching and loading assemblies one at a time.
  /// </summary>
  internal static class AssemblyPatcher
  {
      /// <summary>
      /// Registered patchers. Their initializers, patch routines and finalizers run in this list's order.
      /// </summary>
      private static readonly List<PatcherPlugin> patchers = new List<PatcherPlugin>();

      /// <summary>
      /// Configuration value of whether assembly dumping is enabled or not.
      /// </summary>
      private static bool DumpingEnabled => Utility.SafeParseBool(Config.GetEntry("dump-assemblies", "false", "Preloader"));

      /// <summary>
      /// Adds a single assembly patcher to the pool of applicable patches.
      /// </summary>
      /// <param name="patcher">Patcher to apply.</param>
      /// <exception cref="ArgumentNullException"><paramref name="patcher"/> is <c>null</c>.</exception>
      public static void AddPatcher(PatcherPlugin patcher)
      {
          // Fail here rather than with an obscure NullReferenceException when the patchers are initialized.
          if (patcher == null)
              throw new ArgumentNullException(nameof(patcher));

          patchers.Add(patcher);
      }

      /// <summary>
      /// Adds all patchers from all managed assemblies specified in a directory.
      /// Patchers are registered in name-sorted order; a patcher whose name is already taken is skipped with a warning.
      /// Does nothing if <paramref name="directory"/> does not exist.
      /// </summary>
      /// <param name="directory">Directory to search patcher DLLs from.</param>
      /// <param name="patcherLocator">A function that locates assembly patchers in a given managed assembly.</param>
      public static void AddPatchersFromDirectory(string directory, Func<Assembly, List<PatcherPlugin>> patcherLocator)
      {
          if (!Directory.Exists(directory))
              return;

          // Keyed by name so registration order does not depend on file-system enumeration order.
          var sortedPatchers = new SortedDictionary<string, PatcherPlugin>();

          foreach (string assemblyPath in Directory.GetFiles(directory, "*.dll"))
          {
              try
              {
                  var assembly = Assembly.LoadFrom(assemblyPath);

                  foreach (var patcher in patcherLocator(assembly))
                  {
                      // SortedDictionary.Add would throw on a duplicate key and abort the whole scan.
                      if (sortedPatchers.ContainsKey(patcher.Name))
                      {
                          Logger.Log(LogLevel.Warning, $"Skipping patcher \"{patcher.Name}\" from {Path.GetFileName(assemblyPath)}: a patcher with the same name is already registered.");
                          continue;
                      }

                      sortedPatchers.Add(patcher.Name, patcher);
                  }
              }
              catch (BadImageFormatException)
              {
                  // Unmanaged DLL; it cannot contain patchers.
              }
              catch (ReflectionTypeLoadException)
              {
                  // The assembly references something that cannot be resolved; report it instead of dropping it silently.
                  Logger.Log(LogLevel.Warning, $"Could not load patchers from {Path.GetFileName(assemblyPath)}: it has invalid references.");
              }
          }

          foreach (var patcher in sortedPatchers)
              AddPatcher(patcher.Value);
      }

      /// <summary>
      /// Runs every registered patcher's <see cref="PatcherPlugin.Initializer"/>, if set.
      /// </summary>
      private static void InitializePatchers()
      {
          foreach (var assemblyPatcher in patchers)
              assemblyPatcher.Initializer?.Invoke();
      }

      /// <summary>
      /// Runs every registered patcher's <see cref="PatcherPlugin.Finalizer"/>, if set.
      /// </summary>
      private static void FinalizePatching()
      {
          foreach (var assemblyPatcher in patchers)
              assemblyPatcher.Finalizer?.Invoke();
      }

      /// <summary>
      /// Releases all patchers to let them be collected by GC.
      /// </summary>
      public static void DisposePatchers()
      {
          patchers.Clear();
      }

      /// <summary>
      /// Applies patchers to all assemblies in the given directory and loads patched assemblies into memory.
      /// </summary>
      /// <param name="directory">Directory to load CLR assemblies from.</param>
      public static void PatchAndLoad(string directory)
      {
          // First, load patchable assemblies into Cecil
          Dictionary<string, AssemblyDefinition> assemblies = ReadPatchableAssemblies(directory);

          // Next, initialize all the patchers
          InitializePatchers();

          // Then, perform the actual patching
          HashSet<string> patchedAssemblies = ApplyPatchers(assemblies);

          // Finally, load all assemblies into memory. The dump setting is read once rather than per assembly.
          bool dumpingEnabled = DumpingEnabled;

          foreach (var kv in assemblies)
          {
              string filename = kv.Key;
              var assembly = kv.Value;

              if (dumpingEnabled && patchedAssemblies.Contains(filename))
                  DumpAssembly(assembly, filename);

              Load(assembly);
              assembly.Dispose();
          }

          // Apply assembly location resolver patch
          PatchedAssemblyResolver.ApplyPatch();

          // Run all finalizers
          FinalizePatching();
      }

      /// <summary>
      /// Reads every DLL in <paramref name="directory"/> into Cecil and records each accepted assembly's on-disk path
      /// in <see cref="PatchedAssemblyResolver.AssemblyLocations"/>.
      /// </summary>
      /// <remarks>
      /// The following are skipped:
      /// <list type="bullet">
      /// <item>Files Cecil rejects as not being managed images.</item>
      /// <item>System and mscorlib.</item>
      /// <item>Assemblies whose full name is already registered.</item>
      /// </list>
      /// </remarks>
      /// <param name="directory">Directory to read DLLs from.</param>
      /// <returns>The accepted assembly definitions, keyed by file name (without directory).</returns>
      private static Dictionary<string, AssemblyDefinition> ReadPatchableAssemblies(string directory)
      {
          var assemblies = new Dictionary<string, AssemblyDefinition>();

          foreach (string assemblyPath in Directory.GetFiles(directory, "*.dll"))
          {
              AssemblyDefinition assembly;
              try
              {
                  assembly = AssemblyDefinition.ReadAssembly(assemblyPath);
              }
              catch (BadImageFormatException)
              {
                  // Not a managed assembly, so there is nothing to patch; skip it instead of aborting the preloader.
                  Logger.Log(LogLevel.Warning, $"Skipping {Path.GetFileName(assemblyPath)}: it is not a managed assembly.");
                  continue;
              }

              //NOTE: this is special cased here because the dependency handling for System.dll is a bit wonky
              //System has an assembly reference to itself, and it also has a reference to Mono.Security causing a circular dependency
              //It's also generally dangerous to change system.dll since so many things rely on it,
              // and it's already loaded into the appdomain since this loader references it, so we might as well skip it
              if (assembly.Name.Name == "System"
                  || assembly.Name.Name == "mscorlib") //mscorlib is already loaded into the appdomain so it can't be patched
              {
                  assembly.Dispose();
                  continue;
              }

              if (PatchedAssemblyResolver.AssemblyLocations.ContainsKey(assembly.FullName))
              {
                  Logger.Log(LogLevel.Warning, $"Tried to load duplicate assembly {Path.GetFileName(assemblyPath)} from Managed folder! Skipping...");
                  // Skipped definitions must be disposed too, exactly like the System/mscorlib case.
                  assembly.Dispose();
                  continue;
              }

              assemblies.Add(Path.GetFileName(assemblyPath), assembly);
              PatchedAssemblyResolver.AssemblyLocations.Add(assembly.FullName, Path.GetFullPath(assemblyPath));
          }

          return assemblies;
      }

      /// <summary>
      /// Runs every registered patcher against each of its target DLLs present in <paramref name="assemblies"/>.
      /// </summary>
      /// <param name="assemblies">Assemblies keyed by file name; entries are replaced by whatever the patcher leaves in its <c>ref</c> argument.</param>
      /// <returns>File names of the assemblies that were targeted by at least one patcher.</returns>
      private static HashSet<string> ApplyPatchers(Dictionary<string, AssemblyDefinition> assemblies)
      {
          var patchedAssemblies = new HashSet<string>();

          foreach (var assemblyPatcher in patchers)
          {
              // A patcher without targets has nothing to do; iterating null would throw.
              if (assemblyPatcher.TargetDLLs == null)
                  continue;

              foreach (string targetDll in assemblyPatcher.TargetDLLs)
              {
                  if (assemblies.TryGetValue(targetDll, out var assembly))
                  {
                      // The patcher may swap in a different definition via the ref parameter, so store it back.
                      assemblyPatcher.Patcher?.Invoke(ref assembly);
                      assemblies[targetDll] = assembly;
                      patchedAssemblies.Add(targetDll);
                  }
              }
          }

          return patchedAssemblies;
      }

      /// <summary>
      /// Writes <paramref name="assembly"/> to the "DumpedAssemblies" folder under <see cref="Paths.PluginPath"/>.
      /// </summary>
      /// <param name="assembly">Assembly definition to serialize.</param>
      /// <param name="filename">File name to write it under.</param>
      private static void DumpAssembly(AssemblyDefinition assembly, string filename)
      {
          string dirPath = Path.Combine(Paths.PluginPath, "DumpedAssemblies");

          // CreateDirectory does nothing if the directory already exists.
          Directory.CreateDirectory(dirPath);

          using (MemoryStream mem = new MemoryStream())
          {
              assembly.Write(mem);
              File.WriteAllBytes(Path.Combine(dirPath, filename), mem.ToArray());
          }
      }

      /// <summary>
      /// Loads an individual assembly definition into the CLR.
      /// </summary>
      /// <param name="assembly">The assembly to load.</param>
      public static void Load(AssemblyDefinition assembly)
      {
          using (MemoryStream assemblyStream = new MemoryStream())
          {
              assembly.Write(assemblyStream);
              Assembly.Load(assemblyStream.ToArray());
          }
      }
  }
  /// <summary>
  /// Makes assemblies that were loaded from memory report the on-disk path they were read from.
  /// It does this by postfix-patching the <see cref="Assembly.Location"/> and <see cref="Assembly.CodeBase"/> getters.
  /// </summary>
  internal static class PatchedAssemblyResolver
  {
      public static HarmonyInstance HarmonyInstance { get; } = HarmonyInstance.Create("com.bepis.bepinex.asmlocationfix");

      /// <summary>
      /// Maps an assembly full name (compared case-insensitively) to the full path of the file it was read from.
      /// </summary>
      public static Dictionary<string, string> AssemblyLocations { get; } = new Dictionary<string, string>(StringComparer.InvariantCultureIgnoreCase);

      /// <summary>
      /// Applies the Harmony patches declared in this class.
      /// </summary>
      public static void ApplyPatch()
      {
          HarmonyWrapper.PatchAll(typeof(PatchedAssemblyResolver), HarmonyInstance);
      }

      /// <summary>
      /// Postfix for <see cref="Assembly.Location"/>: substitutes the recorded path when one exists.
      /// </summary>
      [HarmonyPostfix, HarmonyPatch(typeof(Assembly), nameof(Assembly.Location), MethodType.Getter)]
      public static void GetLocation(ref string __result, Assembly __instance)
      {
          if (AssemblyLocations.TryGetValue(__instance.FullName, out string location))
              __result = location;
      }

      /// <summary>
      /// Postfix for <see cref="Assembly.CodeBase"/>: substitutes a file URI built from the recorded path.
      /// </summary>
      [HarmonyPostfix, HarmonyPatch(typeof(Assembly), nameof(Assembly.CodeBase), MethodType.Getter)]
      public static void GetCodeBase(ref string __result, Assembly __instance)
      {
          if (!AssemblyLocations.TryGetValue(__instance.FullName, out string location))
              return;

          string path = location.Replace('\\', '/');

          // Plain "file://" + path only works for paths that already start with '/'. A drive-rooted path would
          // otherwise become "file://C:/..." (drive parsed as host) instead of the standard "file:///C:/...".
          if (path.StartsWith("//", StringComparison.Ordinal))
              __result = "file:" + path;      // UNC: \\server\share\a.dll -> file://server/share/a.dll
          else if (path.StartsWith("/", StringComparison.Ordinal))
              __result = "file://" + path;    // Unix: /opt/a.dll -> file:///opt/a.dll
          else
              __result = "file:///" + path;   // Windows drive: C:/a.dll -> file:///C:/a.dll
      }
  }
  202. }