Extensions.cs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Mono.CecilX;
  namespace Mirror.Weaver
  {
      // Extension helpers used throughout the weaver to inspect and build
      // Mono.CecilX type, method, field and attribute references.
      public static class Extensions
      {
          // check if 'td' is exactly 'type'.
          // for generic System.Types only the element type of 'td' is compared
          // (td.GetElementType()), so a generic instance is matched against the
          // open generic definition's full name.
          public static bool Is(this TypeReference td, Type type) =>
              type.IsGenericType
                  ? td.GetElementType().FullName == type.FullName
                  : td.FullName == type.FullName;

          // check if 'td' is exactly of type T.
          // it does not check if any base type is of <T>, only the specific type.
          // for example:
          // NetworkConnection Is NetworkConnection: true
          // NetworkConnectionToClient Is NetworkConnection: false
          public static bool Is<T>(this TypeReference td) => Is(td, typeof(T));

          // check if 'tr' is derived from T.
          // it does not check if 'tr' is exactly T.
          // for example:
          // NetworkConnection IsDerivedFrom<NetworkConnection>: false
          // NetworkConnectionToClient IsDerivedFrom<NetworkConnection>: true
          public static bool IsDerivedFrom<T>(this TypeReference tr) => IsDerivedFrom(tr, typeof(T));

          // check if 'tr' is a class that derives (directly or indirectly) from 'baseClass'.
          // returns false when 'tr' cannot be resolved, is not a class, has no base
          // class, or when the inheritance chain cannot be resolved (see CanBeResolved).
          public static bool IsDerivedFrom(this TypeReference tr, Type baseClass)
          {
              TypeDefinition td = tr.Resolve();
              // Resolve() can return null when the type cannot be found;
              // treat that as "not derived" instead of dereferencing null.
              if (td == null || !td.IsClass)
                  return false;

              // are ANY parent classes of baseClass?
              TypeReference parent = td.BaseType;
              if (parent == null)
                  return false;

              if (parent.Is(baseClass))
                  return true;

              // only recurse when the rest of the chain is resolvable,
              // otherwise Resolve() further up could fail.
              if (parent.CanBeResolved())
                  return IsDerivedFrom(parent.Resolve(), baseClass);

              return false;
          }

          // returns the underlying type of an enum definition: the type of its
          // first non-static field (an enum's single instance field holds its value).
          // throws ArgumentException if 'td' has no instance field.
          public static TypeReference GetEnumUnderlyingType(this TypeDefinition td)
          {
              FieldDefinition valueField = td.Fields.FirstOrDefault(field => !field.IsStatic);
              if (valueField == null)
                  throw new ArgumentException($"Invalid enum {td.FullName}");
              return valueField.FieldType;
          }

          // check if 'td' or any resolvable base class lists TInterface in its
          // Interfaces collection. the walk stops (returning false) when a base
          // class's assembly cannot be resolved, which can happen for plugins.
          public static bool ImplementsInterface<TInterface>(this TypeDefinition td) =>
              SelfAndBaseTypes(td).Any(typedef =>
                  typedef.Interfaces.Any(iface => iface.InterfaceType.Is<TInterface>()));

          // true for arrays with more than one dimension, e.g. int[,]
          public static bool IsMultidimensionalArray(this TypeReference tr) =>
              tr is ArrayType arrayType && arrayType.Rank > 1;

          // Does type use netId as backing field
          public static bool IsNetworkIdentityField(this TypeReference tr) =>
              tr.Is<UnityEngine.GameObject>() ||
              tr.Is<NetworkIdentity>() ||
              // handle both NetworkBehaviour and inheritors.
              // fixes: https://github.com/MirrorNetworking/Mirror/issues/2939
              tr.IsDerivedFrom<NetworkBehaviour>() ||
              tr.Is<NetworkBehaviour>();

          // check if 'parent' and its base-class chain can be resolved.
          // - types whose scope is named "Windows" are never resolvable.
          // - once the chain reaches an "mscorlib" type, only that type is resolved
          //   and the walk ends there.
          // any failure while resolving is treated as "cannot be resolved".
          public static bool CanBeResolved(this TypeReference parent)
          {
              while (parent != null)
              {
                  if (parent.Scope.Name == "Windows")
                      return false;

                  if (parent.Scope.Name == "mscorlib")
                  {
                      TypeDefinition resolved = parent.Resolve();
                      return resolved != null;
                  }

                  TypeDefinition definition;
                  try
                  {
                      definition = parent.Resolve();
                  }
                  catch
                  {
                      // deliberate best-effort: any resolution failure means "no".
                      return false;
                  }

                  if (definition == null)
                      return false;

                  parent = definition.BaseType;
              }
              return true;
          }

          // Makes T => Variable and imports function
          // e.g. turns Read<T> into Read<variableReference> and imports it into 'module'.
          public static MethodReference MakeGeneric(this MethodReference generic, ModuleDefinition module, TypeReference variableReference)
          {
              GenericInstanceMethod instance = new GenericInstanceMethod(generic);
              instance.GenericArguments.Add(variableReference);

              MethodReference readFunc = module.ImportReference(instance);
              return readFunc;
          }

          // Given a method of a generic class such as ArraySegment`T.get_Count,
          // and a generic instance such as ArraySegment`int
          // Creates a reference to the specialized method ArraySegment`int`.get_Count
          // Note that calling ArraySegment`T.get_Count directly gives an invalid IL error
          public static MethodReference MakeHostInstanceGeneric(this MethodReference self, ModuleDefinition module, GenericInstanceType instanceType)
          {
              MethodReference reference = new MethodReference(self.Name, self.ReturnType, instanceType)
              {
                  CallingConvention = self.CallingConvention,
                  HasThis = self.HasThis,
                  ExplicitThis = self.ExplicitThis
              };

              foreach (ParameterDefinition parameter in self.Parameters)
                  reference.Parameters.Add(new ParameterDefinition(parameter.ParameterType));

              foreach (GenericParameter generic_parameter in self.GenericParameters)
                  reference.GenericParameters.Add(new GenericParameter(generic_parameter.Name, reference));

              return module.ImportReference(reference);
          }

          // needed for NetworkBehaviour<T> support
          // https://github.com/vis2k/Mirror/pull/3073/
          // re-targets 'self' onto a generic instance of its declaring type whose
          // arguments are that type's own generic parameters (e.g. Foo`1 -> Foo<T>).
          public static FieldReference MakeHostInstanceGeneric(this FieldReference self)
          {
              var declaringType = new GenericInstanceType(self.DeclaringType);
              foreach (var parameter in self.DeclaringType.GenericParameters)
              {
                  declaringType.GenericArguments.Add(parameter);
              }
              return new FieldReference(self.Name, self.FieldType, declaringType);
          }

          // Given a field of a generic class such as Writer<T>.write,
          // and a generic instance such as Writer<int>,
          // creates (and imports) a reference to the specialized field Writer<int>.write.
          // Note that accessing Writer<T>.write directly gives an invalid IL error
          public static FieldReference SpecializeField(this FieldReference self, ModuleDefinition module, GenericInstanceType instanceType)
          {
              FieldReference reference = new FieldReference(self.Name, self.FieldType, instanceType);
              return module.ImportReference(reference);
          }

          // returns the first custom attribute of exact type TAttribute, or null if none.
          public static CustomAttribute GetCustomAttribute<TAttribute>(this ICustomAttributeProvider method)
          {
              return method.CustomAttributes.FirstOrDefault(ca => ca.AttributeType.Is<TAttribute>());
          }

          // true if any custom attribute is of exact type TAttribute.
          public static bool HasCustomAttribute<TAttribute>(this ICustomAttributeProvider attributeProvider)
          {
              return attributeProvider.CustomAttributes.Any(attr => attr.AttributeType.Is<TAttribute>());
          }

          // returns the value of the named *field* argument 'field' of the attribute
          // (only ca.Fields is searched, not named properties), cast to T,
          // or 'defaultValue' if the attribute usage does not set it.
          public static T GetField<T>(this CustomAttribute ca, string field, T defaultValue)
          {
              foreach (CustomAttributeNamedArgument customField in ca.Fields)
              {
                  if (customField.Name == field)
                      return (T)customField.Argument.Value;
              }
              return defaultValue;
          }

          // first method declared directly on 'td' with the given name, or null.
          public static MethodDefinition GetMethod(this TypeDefinition td, string methodName)
          {
              return td.Methods.FirstOrDefault(method => method.Name == methodName);
          }

          // all methods declared directly on 'td' with the given name (overloads).
          public static List<MethodDefinition> GetMethods(this TypeDefinition td, string methodName)
          {
              return td.Methods.Where(method => method.Name == methodName).ToList();
          }

          // first method with the given name on 'td' or, failing that, on the
          // nearest resolvable base class. returns null if none is found or the
          // base-class walk hits an unresolvable assembly (can happen for plugins).
          public static MethodDefinition GetMethodInBaseType(this TypeDefinition td, string methodName) =>
              SelfAndBaseTypes(td)
                  .SelectMany(typedef => typedef.Methods)
                  .FirstOrDefault(md => md.Name == methodName);

          // Finds public fields in type and base type
          public static IEnumerable<FieldDefinition> FindAllPublicFields(this TypeReference variable)
          {
              return FindAllPublicFields(variable.Resolve());
          }

          // Finds public fields in type and base type.
          // lazily yields, for 'typeDefinition' and then each resolvable base class,
          // every field that is non-static, public and not marked [NonSerialized].
          public static IEnumerable<FieldDefinition> FindAllPublicFields(this TypeDefinition typeDefinition)
          {
              foreach (TypeDefinition td in SelfAndBaseTypes(typeDefinition))
              {
                  foreach (FieldDefinition field in td.Fields)
                  {
                      if (field.IsStatic)
                          continue;

                      // ignore every non-public field: private and protected
                      // (fixes: https://github.com/MirrorNetworking/Mirror/issues/3485,
                      // credit: James Frowen), internal, and also the combined
                      // 'protected internal' / 'private protected' access levels.
                      // we dont want to create different writers for this type if they are in current dll or another dll
                      // so we have to ignore internal in all cases
                      if (!field.IsPublic)
                          continue;

                      if (field.IsNotSerialized)
                          continue;

                      yield return field;
                  }
              }
          }

          // true if the module declares a type (including nested types) with
          // exactly this namespace and name.
          public static bool ContainsClass(this ModuleDefinition module, string nameSpace, string className) =>
              module.GetTypes().Any(td => td.Namespace == nameSpace &&
                                          td.Name == className);

          // the module's assembly reference with the given simple name, or null.
          public static AssemblyNameReference FindReference(this ModuleDefinition module, string referenceName) =>
              module.AssemblyReferences.FirstOrDefault(reference => reference.Name == referenceName);

          // Takes generic arguments from child class and applies them to parent reference, if possible
          // eg makes `Base<T>` in Child<int> : Base<int> have `int` instead of `T`
          // Originally by James-Frowen under MIT
          // https://github.com/MirageNet/Mirage/commit/cf91e1d54796866d2cf87f8e919bb5c681977e45
          public static TypeReference ApplyGenericParameters(this TypeReference parentReference,
              TypeReference childReference)
          {
              // If the parent is not generic, we got nothing to apply
              if (!parentReference.IsGenericInstance)
                  return parentReference;

              GenericInstanceType parentGeneric = (GenericInstanceType)parentReference;
              // make new type so we can replace the args on it
              // resolve it so we have non-generic instance (eg just instance with <T> instead of <int>)
              // if we don't cecil will make it double generic (eg INVALID IL)
              GenericInstanceType generic = new GenericInstanceType(parentReference.Resolve());
              foreach (TypeReference arg in parentGeneric.GenericArguments)
                  generic.GenericArguments.Add(arg);

              for (int i = 0; i < generic.GenericArguments.Count; i++)
              {
                  // if arg is not generic
                  // eg List<int> would be int so not generic.
                  // But List<T> would be T so is generic
                  if (!generic.GenericArguments[i].IsGenericParameter)
                      continue;

                  // get the generic name, eg T
                  string name = generic.GenericArguments[i].Name;
                  // find what type T is, eg turn it into `int` if `List<int>`
                  TypeReference arg = FindMatchingGenericArgument(childReference, name);

                  // import just to be safe
                  TypeReference imported = parentReference.Module.ImportReference(arg);
                  // set arg on generic, parent ref will be Base<int> instead of just Base<T>
                  generic.GenericArguments[i] = imported;
              }

              return generic;
          }

          // Finds the type reference for a generic parameter with the provided name in the child reference
          // Originally by James-Frowen under MIT
          // https://github.com/MirageNet/Mirage/commit/cf91e1d54796866d2cf87f8e919bb5c681977e45
          static TypeReference FindMatchingGenericArgument(TypeReference childReference, string paramName)
          {
              TypeDefinition def = childReference.Resolve();
              // child class must be generic if we are in this part of the code
              // eg Child<T> : Base<T>  <--- child must have generic if Base has T
              // vs Child : Base<int> <--- wont be here if Base has int (we check if T exists before calling this)
              if (!def.HasGenericParameters)
                  throw new InvalidOperationException(
                      "Base class had generic parameters, but could not find them in child class");

              // go through parameters in child class, and find the generic that matches the name
              for (int i = 0; i < def.GenericParameters.Count; i++)
              {
                  GenericParameter param = def.GenericParameters[i];
                  if (param.Name == paramName)
                  {
                      // a generic instance (eg Child<int>) supplies the concrete
                      // argument at the same index as the parameter.
                      if (childReference is GenericInstanceType generic)
                          return generic.GenericArguments[i];

                      // an open generic definition (eg Child<T> itself) has no
                      // concrete arguments: its own parameter is the match.
                      // (an unconditional cast here threw InvalidCastException.)
                      return param;
                  }
              }

              // this should never happen, if it does it means that this code is bugged
              throw new InvalidOperationException("Did not find matching generic");
          }

          // lazily yields 'td' followed by each base class that can be resolved.
          // the walk ends at a type with no base class, a base class that resolves
          // to null, or one whose assembly cannot be resolved (can happen for plugins).
          static IEnumerable<TypeDefinition> SelfAndBaseTypes(TypeDefinition td)
          {
              TypeDefinition current = td;
              while (current != null)
              {
                  yield return current;
                  current = TryResolveBaseType(current);
              }
          }

          // resolves the base class of 'td'; returns null when there is none,
          // when it does not resolve, or when its assembly cannot be resolved.
          static TypeDefinition TryResolveBaseType(TypeDefinition td)
          {
              try
              {
                  return td.BaseType?.Resolve();
              }
              catch (AssemblyResolutionException)
              {
                  return null;
              }
          }
      }
  }