Client LuaCsForBarotrauma
LuaCsHookCompat.cs
1 using System;
2 using System.Linq;
3 using System.Reflection;
4 using HarmonyLib;
5 using System.Collections.Generic;
6 using MoonSharp.Interpreter;
7 using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch;
8 
namespace Barotrauma
{
    // XXX: this can't be renamed because of backward compatibility with C# mods
    /// <summary>
    /// Legacy patch callback registered through <c>LuaCsHook.HookMethod</c>.
    /// </summary>
    /// <param name="self">The instance the hooked method was invoked on (Harmony's <c>__instance</c>).</param>
    /// <param name="args">The hooked call's arguments, keyed by parameter name.</param>
    /// <returns>
    /// <c>null</c> (or a nil MoonSharp <c>DynValue</c>) to leave the call alone. Any other value:
    /// for a prefix hook, skips the original method and, for value-returning methods, becomes its result;
    /// for a postfix hook on a value-returning method, replaces the result.
    /// </returns>
    public delegate object LuaCsPatch(object self, Dictionary<string, object> args);

    partial class LuaCsHook
    {
        // Legacy HookMethod registries, keyed by the hooked method's function pointer.
        // Each entry is (identifier, patch, owning mod or null); an empty identifier marks an
        // anonymous patch that is never replaced by a later registration.
        private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> compatHookPrefixMethods = new Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>>();
        private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> compatHookPostfixMethods = new Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>>();

        /// <summary>Returns the legacy registry that holds hooks of the given kind.</summary>
        /// <exception cref="ArgumentException"><paramref name="hookType"/> is neither Before nor After.</exception>
        private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> GetCompatHookMethods(HookMethodType hookType)
        {
            switch (hookType)
            {
                case HookMethodType.Before:
                    return compatHookPrefixMethods;
                case HookMethodType.After:
                    return compatHookPostfixMethods;
                default:
                    throw new ArgumentException($"Invalid {nameof(HookMethodType)} enum value.", nameof(hookType));
            }
        }

        /// <summary>
        /// Shared body of the Harmony trampolines: runs every legacy patch registered for
        /// <paramref name="__originalMethod"/> and reports the last value a patch returned.
        /// </summary>
        /// <param name="result">
        /// The converted return value of the last patch that produced one, or <c>null</c> if none did.
        /// </param>
        /// <remarks>
        /// Exceptions (including an invalid <paramref name="hookType"/>) are logged and swallowed so
        /// they never propagate into the patched method. Patches owned by a disposed mod are not run
        /// and are removed from the registry.
        /// </remarks>
        private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookType)
        {
            result = null;

            try
            {
                var funcAddr = (long)__originalMethod.MethodHandle.GetFunctionPointer();
                if (!instance.GetCompatHookMethods(hookType).TryGetValue(funcAddr, out var methodSet) || methodSet == null)
                {
                    return;
                }

                // Expose the call's arguments to the patches by parameter name.
                var @params = __originalMethod.GetParameters();
                var args = new Dictionary<string, object>(@params.Length);
                for (int i = 0; i < @params.Length; i++)
                {
                    args.Add(@params[i].Name, __args[i]);
                }

                // Iterate over a snapshot: a patch may call HookMethod/UnhookMethod for this same
                // method, which mutates methodSet and would otherwise abort the enumeration with
                // an InvalidOperationException.
                var outOfScope = new List<(string, LuaCsCompatPatchFunc, ACsMod)>();
                foreach (var entry in methodSet.ToArray())
                {
                    var (_, patch, owner) = entry;
                    if (owner != null && owner.IsDisposed)
                    {
                        outOfScope.Add(entry);
                        continue;
                    }

                    if (TryConvertPatchResult(__originalMethod, patch(__instance, args), out object converted))
                    {
                        result = converted;
                    }
                }

                foreach (var entry in outOfScope)
                {
                    methodSet.Remove(entry);
                }
            }
            catch (Exception ex)
            {
                LuaCsLogger.LogError($"Error in {__originalMethod.Name}:", LuaCsMessageOrigin.Unknown);
                LuaCsLogger.HandleException(ex, LuaCsMessageOrigin.Unknown);
            }
        }

        /// <summary>
        /// Interprets a patch's return value. Returns <c>false</c> when the patch produced nothing
        /// (<c>null</c> or a nil <c>DynValue</c>). Otherwise returns <c>true</c> and sets
        /// <paramref name="converted"/>; a <c>DynValue</c> is converted to the hooked method's
        /// return type when it has one.
        /// </summary>
        private static bool TryConvertPatchResult(MethodBase originalMethod, object patchResult, out object converted)
        {
            converted = null;
            if (patchResult == null)
            {
                return false;
            }

            if (patchResult is DynValue dyn)
            {
                if (dyn.IsNil())
                {
                    return false;
                }

                converted = originalMethod is MethodInfo mi && mi.ReturnType != typeof(void)
                    ? dyn.ToObject(mi.ReturnType)
                    : dyn.ToObject();
                return true;
            }

            converted = patchResult;
            return true;
        }

        // Harmony trampolines. Their names are resolved by reflection below and their
        // double-underscore parameter names are Harmony injection names: rename neither.

        /// <summary>Prefix for void methods and constructors; skips the original when a patch returned a value.</summary>
        private static bool HookLuaCsPatchPrefix(MethodBase __originalMethod, object[] __args, object __instance)
        {
            _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before);
            return result == null;
        }

        /// <summary>Postfix for void methods and constructors; patch return values are ignored.</summary>
        private static void HookLuaCsPatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) =>
            _hookLuaCsPatch(__originalMethod, __args, __instance, out object _, HookMethodType.After);

        /// <summary>Prefix for value-returning methods; a patch's value becomes the result and skips the original.</summary>
        private static bool HookLuaCsPatchRetPrefix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance)
        {
            _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before);
            if (result == null)
            {
                return true;
            }

            __result = result;
            return false;
        }

        /// <summary>Postfix for value-returning methods; a patch's value replaces the result.</summary>
        private static void HookLuaCsPatchRetPostfix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance)
        {
            _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.After);
            if (result != null)
            {
                __result = result;
            }
        }

        private static readonly MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaCsHook).GetMethod(nameof(HookLuaCsPatchPrefix), BindingFlags.NonPublic | BindingFlags.Static);
        private static readonly MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaCsHook).GetMethod(nameof(HookLuaCsPatchPostfix), BindingFlags.NonPublic | BindingFlags.Static);
        private static readonly MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaCsHook).GetMethod(nameof(HookLuaCsPatchRetPrefix), BindingFlags.NonPublic | BindingFlags.Static);
        private static readonly MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod(nameof(HookLuaCsPatchRetPostfix), BindingFlags.NonPublic | BindingFlags.Static);

        /// <summary>
        /// Attaches the matching trampoline to <paramref name="method"/> unless it is already attached.
        /// Value-returning methods get the variant that can override <c>__result</c>.
        /// </summary>
        /// <param name="hookType">Must already be validated: anything other than Before is treated as After.</param>
        private void EnsureCompatHarmonyPatch(MethodBase method, HookMethodType hookType)
        {
            bool returnsValue = method is MethodInfo mi && mi.ReturnType != typeof(void);
            var patchInfo = Harmony.GetPatchInfo(method);

            if (hookType == HookMethodType.Before)
            {
                var trampoline = returnsValue ? _miHookLuaCsPatchRetPrefix : _miHookLuaCsPatchPrefix;
                if (patchInfo?.Prefixes == null || !patchInfo.Prefixes.Any(p => p.PatchMethod == trampoline))
                {
                    harmony.Patch(method, prefix: new HarmonyMethod(trampoline));
                }
            }
            else
            {
                var trampoline = returnsValue ? _miHookLuaCsPatchRetPostfix : _miHookLuaCsPatchPostfix;
                if (patchInfo?.Postfixes == null || !patchInfo.Postfixes.Any(p => p.PatchMethod == trampoline))
                {
                    harmony.Patch(method, postfix: new HarmonyMethod(trampoline));
                }
            }
        }

        /// <summary>
        /// Adds a legacy patch to <paramref name="methods"/> under <paramref name="funcAddr"/>.
        /// A non-empty <paramref name="identifier"/> replaces any patch already registered with it.
        /// </summary>
        private static void RegisterCompatPatch(Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> methods, long funcAddr, string identifier, LuaCsCompatPatchFunc patch, ACsMod owner)
        {
            if (!methods.TryGetValue(funcAddr, out var methodSet) || methodSet == null)
            {
                methodSet = new HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>();
                methods[funcAddr] = methodSet;
            }

            if (identifier != "")
            {
                methodSet.RemoveWhere(tuple => tuple.Item1 == identifier);
            }

            methodSet.Add((identifier, patch, owner));
        }

        // TODO: deprecate this
        /// <summary>
        /// Registers a legacy patch on <paramref name="method"/>, attaching the Harmony trampoline on first use.
        /// </summary>
        /// <param name="identifier">Patch id; an empty string registers an anonymous patch.</param>
        /// <param name="method">The method or constructor to hook.</param>
        /// <param name="patch">Callback to run; see <see cref="LuaCsPatch"/> for return-value semantics.</param>
        /// <param name="hookType">Run before (prefix) or after (postfix) the original.</param>
        /// <param name="owner">Owning mod; its patches are dropped once it is disposed.</param>
        /// <remarks>
        /// Null arguments and an invalid <paramref name="hookType"/> are reported through
        /// <c>LuaCsLogger.HandleException</c> and the call returns without hooking anything.
        /// </remarks>
        public void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null)
        {
            if (identifier == null || method == null || patch == null)
            {
                string nullParam = identifier == null ? nameof(identifier)
                    : method == null ? nameof(method)
                    : nameof(patch);
                LuaCsLogger.HandleException(new ArgumentNullException(nullParam, "Identifier, Method and Patch arguments must not be null."), LuaCsMessageOrigin.Unknown);
                return;
            }

            if (hookType != HookMethodType.Before && hookType != HookMethodType.After)
            {
                LuaCsLogger.HandleException(new ArgumentException($"Invalid {nameof(HookMethodType)} enum value.", nameof(hookType)), LuaCsMessageOrigin.Unknown);
                return;
            }

            ValidatePatchTarget(method);

            // NOTE(review): the registry key is read before Harmony patches the method, as it always
            // has been; confirm GetFunctionPointer is stable across patching before reordering.
            var funcAddr = (long)method.MethodHandle.GetFunctionPointer();
            EnsureCompatHarmonyPatch(method, hookType);
            RegisterCompatPatch(GetCompatHookMethods(hookType), funcAddr, identifier, patch, owner);
        }

        /// <summary>
        /// Resolves <paramref name="className"/>.<paramref name="methodName"/> and hooks it.
        /// Does nothing if the method cannot be resolved.
        /// </summary>
        /// <exception cref="InvalidOperationException">The method has ByRef parameters.</exception>
        protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before)
        {
            var method = ResolveMethod(className, methodName, parameterNames);
            if (method == null) return;
            if (method.GetParameters().Any(x => x.ParameterType.IsByRef))
            {
                throw new InvalidOperationException($"{nameof(HookMethod)} doesn't support ByRef parameters; use {nameof(Patch)} instead.");
            }
            HookMethod(identifier, method, patch, hookMethodType);
        }
        protected void HookMethod(string identifier, string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) =>
            HookMethod(identifier, className, methodName, null, patch, hookMethodType);
        protected void HookMethod(string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) =>
            HookMethod("", className, methodName, null, patch, hookMethodType);
        protected void HookMethod(string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) =>
            HookMethod("", className, methodName, parameterNames, patch, hookMethodType);

        /// <summary>
        /// Removes every legacy patch registered with <paramref name="identifier"/> on <paramref name="method"/>.
        /// The Harmony trampoline stays attached.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="method"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="hookType"/> is neither Before nor After.</exception>
        public void UnhookMethod(string identifier, MethodBase method, HookMethodType hookType = HookMethodType.Before)
        {
            if (method == null) throw new ArgumentNullException(nameof(method));

            var funcAddr = (long)method.MethodHandle.GetFunctionPointer();
            if (GetCompatHookMethods(hookType).TryGetValue(funcAddr, out var methodSet))
            {
                methodSet?.RemoveWhere(t => t.Item1 == identifier);
            }
        }

        /// <summary>Resolves the method by name and unhooks it; does nothing if it cannot be resolved.</summary>
        protected void UnhookMethod(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before)
        {
            var method = ResolveMethod(className, methodName, parameterNames);
            if (method == null) return;
            UnhookMethod(identifier, method, hookType);
        }
    }
}
void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType=HookMethodType.Before)
void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType=HookMethodType.Before, ACsMod owner=null)
void HookMethod(string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType=HookMethodType.Before)
void HookMethod(string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType=HookMethodType.Before)
void HookMethod(string identifier, string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType=HookMethodType.Before)
void UnhookMethod(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType=HookMethodType.Before)
void UnhookMethod(string identifier, MethodBase method, HookMethodType hookType=HookMethodType.Before)
static void HandleException(Exception ex, LuaCsMessageOrigin origin)
delegate object LuaCsPatch(object self, Dictionary< string, object > args)