Extensions.cs 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Mono.CecilX;
  namespace Mirror.Weaver
  {
      /// <summary>
      /// Extension helpers over Mono.CecilX metadata types (type, method, field, attribute and
      /// module references) used by the weaver.
      /// </summary>
      public static class Extensions
      {
          /// <summary>
          /// Returns true when <paramref name="td"/> names the same type as <paramref name="t"/>,
          /// comparing full names.
          /// </summary>
          /// <remarks>
          /// For a generic <paramref name="t"/> the element type of <paramref name="td"/> is compared,
          /// so an open definition such as <c>typeof(List&lt;&gt;)</c> matches every instantiation.
          /// A closed <see cref="Type"/> such as <c>typeof(List&lt;int&gt;)</c> has an
          /// assembly-qualified <see cref="Type.FullName"/> and will not match.
          /// </remarks>
          public static bool Is(this TypeReference td, Type t)
          {
              if (t.IsGenericType)
              {
                  return td.GetElementType().FullName == t.FullName;
              }
              return td.FullName == t.FullName;
          }

          /// <summary>Generic shorthand for <see cref="Is(TypeReference, Type)"/>.</summary>
          public static bool Is<T>(this TypeReference td) => Is(td, typeof(T));

          /// <summary>Generic shorthand for <see cref="IsDerivedFrom(TypeReference, Type)"/>.</summary>
          public static bool IsDerivedFrom<T>(this TypeReference tr) => IsDerivedFrom(tr, typeof(T));

          /// <summary>
          /// Returns true when <paramref name="tr"/> is a class that has <paramref name="baseClass"/>
          /// somewhere in its chain of base classes. The type itself is not compared, so
          /// <c>X.IsDerivedFrom&lt;X&gt;()</c> is false.
          /// </summary>
          /// <remarks>
          /// Returns false when <paramref name="tr"/> does not resolve to a definition, is not a
          /// class, or when the walk up the hierarchy reaches a base type that
          /// <see cref="CanBeResolved"/> rejects.
          /// </remarks>
          public static bool IsDerivedFrom(this TypeReference tr, Type baseClass)
          {
              TypeDefinition td = tr.Resolve();

              // Resolve() can return null (type missing from its assembly); without this guard
              // the IsClass access threw a NullReferenceException.
              if (td == null || !td.IsClass)
                  return false;

              // is the direct parent baseClass, or does the parent derive from it?
              TypeReference parent = td.BaseType;
              if (parent == null)
                  return false;

              if (parent.Is(baseClass))
                  return true;

              if (parent.CanBeResolved())
                  return IsDerivedFrom(parent.Resolve(), baseClass);

              return false;
          }

          /// <summary>
          /// Returns the type of the enum's instance field, which holds the enum's underlying value
          /// (the named members of an enum are static fields).
          /// </summary>
          /// <exception cref="ArgumentException">
          /// <paramref name="td"/> has no instance field, i.e. it is not a valid enum.
          /// </exception>
          public static TypeReference GetEnumUnderlyingType(this TypeDefinition td)
          {
              foreach (FieldDefinition field in td.Fields)
              {
                  if (!field.IsStatic)
                      return field.FieldType;
              }
              throw new ArgumentException($"Invalid enum {td.FullName}");
          }

          /// <summary>
          /// Returns true when <paramref name="td"/> or any of its base classes lists
          /// <typeparamref name="TInterface"/> among its interfaces.
          /// </summary>
          /// <remarks>
          /// If a base class's assembly cannot be resolved (this can happen for plugins) the search
          /// stops there and the result is false.
          /// </remarks>
          public static bool ImplementsInterface<TInterface>(this TypeDefinition td)
          {
              TypeDefinition typedef = td;
              while (typedef != null)
              {
                  if (typedef.Interfaces.Any(iface => iface.InterfaceType.Is<TInterface>()))
                      return true;

                  try
                  {
                      TypeReference parent = typedef.BaseType;
                      typedef = parent?.Resolve();
                  }
                  catch (AssemblyResolutionException)
                  {
                      // this can happen for plugins; treat the unresolved remainder as "not found".
                      break;
                  }
              }
              return false;
          }

          /// <summary>
          /// Returns true for array types with more than one dimension (e.g. <c>int[,]</c>).
          /// Jagged arrays such as <c>int[][]</c> have rank 1 and return false.
          /// </summary>
          public static bool IsMultidimensionalArray(this TypeReference tr) =>
              tr is ArrayType arrayType && arrayType.Rank > 1;

          /// <summary>
          /// Does type use netId as backing field?
          /// True for <c>UnityEngine.GameObject</c>, <c>NetworkIdentity</c>, and classes deriving
          /// from <c>NetworkBehaviour</c>.
          /// </summary>
          public static bool IsNetworkIdentityField(this TypeReference tr) =>
              tr.Is<UnityEngine.GameObject>() ||
              tr.Is<NetworkIdentity>() ||
              // NOTE(review): IsDerivedFrom only inspects base classes, so a field typed exactly
              // as NetworkBehaviour is not matched here - confirm whether that is intended.
              tr.IsDerivedFrom<NetworkBehaviour>();

          /// <summary>
          /// Walks <paramref name="parent"/> and its base classes and reports whether the chain can
          /// be resolved by the weaver.
          /// </summary>
          /// <returns>
          /// false if a type in the chain has scope "Windows" or fails to resolve; for a type with
          /// scope "mscorlib" the walk stops there and the result is whether that type resolved;
          /// true if the chain ends (null base type) before either of those happens.
          /// </returns>
          /// <remarks>Resolution failures are reported as false rather than thrown.</remarks>
          public static bool CanBeResolved(this TypeReference parent)
          {
              while (parent != null)
              {
                  if (parent.Scope.Name == "Windows")
                      return false;

                  TypeDefinition resolved;
                  try
                  {
                      resolved = parent.Resolve();
                  }
                  catch
                  {
                      // Deliberate best-effort: any failure to resolve (typically an
                      // AssemblyResolutionException) simply means "cannot be resolved".
                      return false;
                  }

                  // stop at mscorlib: the answer is whether that type itself resolved.
                  if (parent.Scope.Name == "mscorlib")
                      return resolved != null;

                  // Resolve() can return null; handle it explicitly instead of relying on a
                  // NullReferenceException being swallowed.
                  if (resolved == null)
                      return false;

                  parent = resolved.BaseType;
              }
              return true;
          }

          /// <summary>
          /// Makes T =&gt; Variable: instantiates the generic method <paramref name="generic"/> with
          /// the single type argument <paramref name="variableReference"/> and imports the result
          /// into <paramref name="module"/>.
          /// </summary>
          public static MethodReference MakeGeneric(this MethodReference generic, ModuleDefinition module, TypeReference variableReference)
          {
              GenericInstanceMethod instance = new GenericInstanceMethod(generic);
              instance.GenericArguments.Add(variableReference);
              return module.ImportReference(instance);
          }

          /// <summary>
          /// Given a method of a generic class such as <c>ArraySegment&lt;T&gt;.get_Count</c> and a
          /// generic instance such as <c>ArraySegment&lt;int&gt;</c>, creates a reference to the
          /// specialized method <c>ArraySegment&lt;int&gt;.get_Count</c> and imports it into
          /// <paramref name="module"/>.
          /// </summary>
          /// <remarks>
          /// Calling <c>ArraySegment&lt;T&gt;.get_Count</c> directly gives an invalid IL error.
          /// Only parameter types and generic parameter names are copied onto the new reference;
          /// parameter names and attributes are not.
          /// </remarks>
          public static MethodReference MakeHostInstanceGeneric(this MethodReference self, ModuleDefinition module, GenericInstanceType instanceType)
          {
              MethodReference reference = new MethodReference(self.Name, self.ReturnType, instanceType)
              {
                  CallingConvention = self.CallingConvention,
                  HasThis = self.HasThis,
                  ExplicitThis = self.ExplicitThis
              };

              foreach (ParameterDefinition parameter in self.Parameters)
                  reference.Parameters.Add(new ParameterDefinition(parameter.ParameterType));

              foreach (GenericParameter genericParameter in self.GenericParameters)
                  reference.GenericParameters.Add(new GenericParameter(genericParameter.Name, reference));

              return module.ImportReference(reference);
          }

          /// <summary>
          /// Given a field of a generic class such as <c>Writer&lt;T&gt;.write</c> and a generic
          /// instance such as <c>Writer&lt;int&gt;</c>, creates a reference to the specialized field
          /// <c>Writer&lt;int&gt;.write</c> and imports it into <paramref name="module"/>.
          /// This is the field counterpart of <see cref="MakeHostInstanceGeneric"/>.
          /// </summary>
          public static FieldReference SpecializeField(this FieldReference self, ModuleDefinition module, GenericInstanceType instanceType)
          {
              FieldReference reference = new FieldReference(self.Name, self.FieldType, instanceType);
              return module.ImportReference(reference);
          }

          /// <summary>
          /// Returns the first custom attribute of type <typeparamref name="TAttribute"/> on
          /// <paramref name="method"/>, or null if there is none.
          /// </summary>
          public static CustomAttribute GetCustomAttribute<TAttribute>(this ICustomAttributeProvider method)
          {
              return method.CustomAttributes.FirstOrDefault(ca => ca.AttributeType.Is<TAttribute>());
          }

          /// <summary>
          /// Returns true when <paramref name="attributeProvider"/> carries at least one custom
          /// attribute of type <typeparamref name="TAttribute"/>.
          /// </summary>
          public static bool HasCustomAttribute<TAttribute>(this ICustomAttributeProvider attributeProvider)
          {
              return attributeProvider.CustomAttributes.Any(attr => attr.AttributeType.Is<TAttribute>());
          }

          /// <summary>
          /// Returns the value of the named field argument <paramref name="field"/> set on the
          /// attribute usage (e.g. <c>[Attr(Field = value)]</c>), cast to <typeparamref name="T"/>,
          /// or <paramref name="defaultValue"/> when that field was not set.
          /// </summary>
          /// <exception cref="InvalidCastException">The stored value is not a <typeparamref name="T"/>.</exception>
          public static T GetField<T>(this CustomAttribute ca, string field, T defaultValue)
          {
              foreach (CustomAttributeNamedArgument customField in ca.Fields)
              {
                  if (customField.Name == field)
                      return (T)customField.Argument.Value;
              }
              return defaultValue;
          }

          /// <summary>
          /// Returns the first method named <paramref name="methodName"/> declared directly on
          /// <paramref name="td"/>, or null. Base classes are not searched.
          /// </summary>
          public static MethodDefinition GetMethod(this TypeDefinition td, string methodName)
          {
              return td.Methods.FirstOrDefault(method => method.Name == methodName);
          }

          /// <summary>
          /// Returns all methods (all overloads) named <paramref name="methodName"/> declared
          /// directly on <paramref name="td"/>; empty if there are none.
          /// </summary>
          public static List<MethodDefinition> GetMethods(this TypeDefinition td, string methodName)
          {
              return td.Methods.Where(method => method.Name == methodName).ToList();
          }

          /// <summary>
          /// Returns the first method named <paramref name="methodName"/> found on
          /// <paramref name="td"/> or, failing that, on the nearest base class declaring one;
          /// null if none is found.
          /// </summary>
          /// <remarks>
          /// The search stops (returning null) at the first base class whose assembly cannot be
          /// resolved, which can happen for plugins.
          /// </remarks>
          public static MethodDefinition GetMethodInBaseType(this TypeDefinition td, string methodName)
          {
              TypeDefinition typedef = td;
              while (typedef != null)
              {
                  foreach (MethodDefinition md in typedef.Methods)
                  {
                      if (md.Name == methodName)
                          return md;
                  }

                  try
                  {
                      TypeReference parent = typedef.BaseType;
                      typedef = parent?.Resolve();
                  }
                  catch (AssemblyResolutionException)
                  {
                      // this can happen for plugins.
                      break;
                  }
              }
              return null;
          }

          /// <summary>
          /// Resolves <paramref name="variable"/> and returns
          /// <see cref="FindAllPublicFields(TypeDefinition)"/> for it; yields nothing if
          /// <c>Resolve()</c> returns null.
          /// </summary>
          public static IEnumerable<FieldDefinition> FindAllPublicFields(this TypeReference variable)
          {
              return FindAllPublicFields(variable.Resolve());
          }

          /// <summary>
          /// Lazily yields the fields of <paramref name="typeDefinition"/> and then of each base
          /// class in turn, skipping static fields, private fields, and fields marked
          /// <c>[NonSerialized]</c> (<c>IsNotSerialized</c>).
          /// </summary>
          /// <remarks>
          /// Despite the name, internal and protected fields are included: only private ones are
          /// excluded. Iteration stops at the first base class whose assembly cannot be resolved.
          /// </remarks>
          public static IEnumerable<FieldDefinition> FindAllPublicFields(this TypeDefinition typeDefinition)
          {
              while (typeDefinition != null)
              {
                  foreach (FieldDefinition field in typeDefinition.Fields)
                  {
                      if (field.IsStatic || field.IsPrivate)
                          continue;

                      if (field.IsNotSerialized)
                          continue;

                      yield return field;
                  }

                  try
                  {
                      typeDefinition = typeDefinition.BaseType?.Resolve();
                  }
                  catch (AssemblyResolutionException)
                  {
                      break;
                  }
              }
          }

          /// <summary>
          /// Returns true when any type returned by <c>module.GetTypes()</c> has namespace
          /// <paramref name="nameSpace"/> and name <paramref name="className"/>.
          /// </summary>
          public static bool ContainsClass(this ModuleDefinition module, string nameSpace, string className) =>
              module.GetTypes().Any(td => td.Namespace == nameSpace &&
                                          td.Name == className);

          /// <summary>
          /// Returns the first assembly reference of <paramref name="module"/> whose name equals
          /// <paramref name="referenceName"/>, or null if the module does not reference it.
          /// </summary>
          public static AssemblyNameReference FindReference(this ModuleDefinition module, string referenceName)
          {
              return module.AssemblyReferences.FirstOrDefault(reference => reference.Name == referenceName);
          }
      }
  }