Extensions.cs 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Mono.CecilX;
  5. namespace Mirror.Weaver
  6. {
  7. public static class Extensions
  8. {
  9. public static bool Is(this TypeReference td, Type t)
  10. {
  11. if (t.IsGenericType)
  12. {
  13. return td.GetElementType().FullName == t.FullName;
  14. }
  15. return td.FullName == t.FullName;
  16. }
  17. public static bool Is<T>(this TypeReference td) => Is(td, typeof(T));
  18. public static bool IsDerivedFrom<T>(this TypeReference tr) => IsDerivedFrom(tr, typeof(T));
  19. public static bool IsDerivedFrom(this TypeReference tr, Type baseClass)
  20. {
  21. TypeDefinition td = tr.Resolve();
  22. if (!td.IsClass)
  23. return false;
  24. // are ANY parent classes of baseClass?
  25. TypeReference parent = td.BaseType;
  26. if (parent == null)
  27. return false;
  28. if (parent.Is(baseClass))
  29. return true;
  30. if (parent.CanBeResolved())
  31. return IsDerivedFrom(parent.Resolve(), baseClass);
  32. return false;
  33. }
  34. public static TypeReference GetEnumUnderlyingType(this TypeDefinition td)
  35. {
  36. foreach (FieldDefinition field in td.Fields)
  37. {
  38. if (!field.IsStatic)
  39. return field.FieldType;
  40. }
  41. throw new ArgumentException($"Invalid enum {td.FullName}");
  42. }
  43. public static bool ImplementsInterface<TInterface>(this TypeDefinition td)
  44. {
  45. TypeDefinition typedef = td;
  46. while (typedef != null)
  47. {
  48. if (typedef.Interfaces.Any(iface => iface.InterfaceType.Is<TInterface>()))
  49. return true;
  50. try
  51. {
  52. TypeReference parent = typedef.BaseType;
  53. typedef = parent?.Resolve();
  54. }
  55. catch (AssemblyResolutionException)
  56. {
  57. // this can happen for plugins.
  58. //Console.WriteLine("AssemblyResolutionException: "+ ex.ToString());
  59. break;
  60. }
  61. }
  62. return false;
  63. }
  64. public static bool IsMultidimensionalArray(this TypeReference tr)
  65. {
  66. return tr is ArrayType arrayType && arrayType.Rank > 1;
  67. }
  68. /// <summary>
  69. /// Does type use netId as backing field
  70. /// </summary>
  71. public static bool IsNetworkIdentityField(this TypeReference tr)
  72. {
  73. return tr.Is<UnityEngine.GameObject>()
  74. || tr.Is<NetworkIdentity>()
  75. || tr.IsDerivedFrom<NetworkBehaviour>();
  76. }
  77. public static bool CanBeResolved(this TypeReference parent)
  78. {
  79. while (parent != null)
  80. {
  81. if (parent.Scope.Name == "Windows")
  82. {
  83. return false;
  84. }
  85. if (parent.Scope.Name == "mscorlib")
  86. {
  87. TypeDefinition resolved = parent.Resolve();
  88. return resolved != null;
  89. }
  90. try
  91. {
  92. parent = parent.Resolve().BaseType;
  93. }
  94. catch
  95. {
  96. return false;
  97. }
  98. }
  99. return true;
  100. }
  101. /// <summary>
  102. /// Makes T => Variable and imports function
  103. /// </summary>
  104. /// <param name="generic"></param>
  105. /// <param name="variableReference"></param>
  106. /// <returns></returns>
  107. public static MethodReference MakeGeneric(this MethodReference generic, TypeReference variableReference)
  108. {
  109. GenericInstanceMethod instance = new GenericInstanceMethod(generic);
  110. instance.GenericArguments.Add(variableReference);
  111. MethodReference readFunc = Weaver.CurrentAssembly.MainModule.ImportReference(instance);
  112. return readFunc;
  113. }
  114. /// <summary>
  115. /// Given a method of a generic class such as ArraySegment`T.get_Count,
  116. /// and a generic instance such as ArraySegment`int
  117. /// Creates a reference to the specialized method ArraySegment`int`.get_Count
  118. /// <para> Note that calling ArraySegment`T.get_Count directly gives an invalid IL error </para>
  119. /// </summary>
  120. /// <param name="self"></param>
  121. /// <param name="instanceType"></param>
  122. /// <returns></returns>
  123. public static MethodReference MakeHostInstanceGeneric(this MethodReference self, GenericInstanceType instanceType)
  124. {
  125. MethodReference reference = new MethodReference(self.Name, self.ReturnType, instanceType)
  126. {
  127. CallingConvention = self.CallingConvention,
  128. HasThis = self.HasThis,
  129. ExplicitThis = self.ExplicitThis
  130. };
  131. foreach (ParameterDefinition parameter in self.Parameters)
  132. reference.Parameters.Add(new ParameterDefinition(parameter.ParameterType));
  133. foreach (GenericParameter generic_parameter in self.GenericParameters)
  134. reference.GenericParameters.Add(new GenericParameter(generic_parameter.Name, reference));
  135. return Weaver.CurrentAssembly.MainModule.ImportReference(reference);
  136. }
  137. /// <summary>
  138. /// Given a field of a generic class such as Writer<T>.write,
  139. /// and a generic instance such as ArraySegment`int
  140. /// Creates a reference to the specialized method ArraySegment`int`.get_Count
  141. /// <para> Note that calling ArraySegment`T.get_Count directly gives an invalid IL error </para>
  142. /// </summary>
  143. /// <param name="self"></param>
  144. /// <param name="instanceType">Generic Instance e.g. Writer<int></param>
  145. /// <returns></returns>
  146. public static FieldReference SpecializeField(this FieldReference self, GenericInstanceType instanceType)
  147. {
  148. FieldReference reference = new FieldReference(self.Name, self.FieldType, instanceType);
  149. return Weaver.CurrentAssembly.MainModule.ImportReference(reference);
  150. }
  151. public static CustomAttribute GetCustomAttribute<TAttribute>(this ICustomAttributeProvider method)
  152. {
  153. return method.CustomAttributes.FirstOrDefault(ca => ca.AttributeType.Is<TAttribute>());
  154. }
  155. public static bool HasCustomAttribute<TAttribute>(this ICustomAttributeProvider attributeProvider)
  156. {
  157. return attributeProvider.CustomAttributes.Any(attr => attr.AttributeType.Is<TAttribute>());
  158. }
  159. public static T GetField<T>(this CustomAttribute ca, string field, T defaultValue)
  160. {
  161. foreach (CustomAttributeNamedArgument customField in ca.Fields)
  162. if (customField.Name == field)
  163. return (T)customField.Argument.Value;
  164. return defaultValue;
  165. }
  166. public static MethodDefinition GetMethod(this TypeDefinition td, string methodName)
  167. {
  168. return td.Methods.FirstOrDefault(method => method.Name == methodName);
  169. }
  170. public static List<MethodDefinition> GetMethods(this TypeDefinition td, string methodName)
  171. {
  172. return td.Methods.Where(method => method.Name == methodName).ToList();
  173. }
  174. public static MethodDefinition GetMethodInBaseType(this TypeDefinition td, string methodName)
  175. {
  176. TypeDefinition typedef = td;
  177. while (typedef != null)
  178. {
  179. foreach (MethodDefinition md in typedef.Methods)
  180. {
  181. if (md.Name == methodName)
  182. return md;
  183. }
  184. try
  185. {
  186. TypeReference parent = typedef.BaseType;
  187. typedef = parent?.Resolve();
  188. }
  189. catch (AssemblyResolutionException)
  190. {
  191. // this can happen for plugins.
  192. break;
  193. }
  194. }
  195. return null;
  196. }
  197. /// <summary>
  198. /// Finds public fields in type and base type
  199. /// </summary>
  200. /// <param name="variable"></param>
  201. /// <returns></returns>
  202. public static IEnumerable<FieldDefinition> FindAllPublicFields(this TypeReference variable)
  203. {
  204. return FindAllPublicFields(variable.Resolve());
  205. }
  206. /// <summary>
  207. /// Finds public fields in type and base type
  208. /// </summary>
  209. /// <param name="variable"></param>
  210. /// <returns></returns>
  211. public static IEnumerable<FieldDefinition> FindAllPublicFields(this TypeDefinition typeDefinition)
  212. {
  213. while (typeDefinition != null)
  214. {
  215. foreach (FieldDefinition field in typeDefinition.Fields)
  216. {
  217. if (field.IsStatic || field.IsPrivate)
  218. continue;
  219. if (field.IsNotSerialized)
  220. continue;
  221. yield return field;
  222. }
  223. try
  224. {
  225. typeDefinition = typeDefinition.BaseType?.Resolve();
  226. }
  227. catch (AssemblyResolutionException)
  228. {
  229. break;
  230. }
  231. }
  232. }
  233. }
  234. }