// SpatialHashGrid.cs
  using System;
  using System.Collections.Generic;
  using UnityEngine;
  namespace DunGen.Collision
  {
      /// <summary>
      /// A generic spatial hash grid that divides space into uniform cells for fast spatial queries.
      /// The grid operates on a 2D plane perpendicular to the specified up axis: each object is stored
      /// in every cell that its bounds overlap when projected onto that plane.
      /// </summary>
      /// <remarks>
      /// Bounds are re-read through the <c>getBounds</c> delegate on every insert, remove and query, and
      /// <see cref="Remove"/> only searches the cells covered by an object's <em>current</em> bounds.
      /// An object whose bounds change while it is stored should therefore be removed before the change
      /// and re-inserted afterwards, otherwise stale entries remain in its old cells.
      /// </remarks>
      /// <typeparam name="T">The type of objects stored in the grid.</typeparam>
      public class SpatialHashGrid<T>
      {
          private readonly Dictionary<long, List<T>> cells;
          private readonly float cellSize;
          private readonly Func<T, Bounds> getBounds;
          private readonly (int, int) primaryAxes; // Indices (0=X, 1=Y, 2=Z) of the two axes that form the 2D plane
          private readonly int upAxisIndex; // Index of the remaining axis, perpendicular to the grid plane

          /// <summary>
          /// Constructs a spatial hash grid.
          /// </summary>
          /// <param name="cellSize">Size of each grid cell; must be positive and finite</param>
          /// <param name="getBounds">Delegate to extract bounds from objects</param>
          /// <param name="upDirection">The up axis direction (determines the 2D plane)</param>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="cellSize"/> is zero, negative, NaN or infinite.</exception>
          /// <exception cref="ArgumentNullException"><paramref name="getBounds"/> is null.</exception>
          /// <exception cref="ArgumentException"><paramref name="upDirection"/> is not a recognised axis direction.</exception>
          public SpatialHashGrid(float cellSize, Func<T, Bounds> getBounds, AxisDirection upDirection = AxisDirection.PosY)
          {
              // Cell coordinates are floor(position / cellSize). A zero, negative or non-finite size yields
              // infinite, NaN or inverted coordinate ranges, which silently loses objects or iterates over a
              // meaningless range, so reject it up front. (!(x > 0) also catches NaN.)
              if (!(cellSize > 0.0f) || float.IsInfinity(cellSize))
              {
                  throw new ArgumentOutOfRangeException(nameof(cellSize), cellSize, "Cell size must be a positive, finite number.");
              }

              if (getBounds == null)
              {
                  throw new ArgumentNullException(nameof(getBounds));
              }

              this.cells = new Dictionary<long, List<T>>();
              this.cellSize = cellSize;
              this.getBounds = getBounds;

              // Determine which axes to use for the 2D grid based on up direction
              this.primaryAxes = GetPrimaryAxes(upDirection);

              // Axis indices are 0, 1 and 2, so the one not used by the plane is 3 minus the other two.
              // Deriving it here avoids depending on the declaration order of AxisDirection.
              this.upAxisIndex = 3 - primaryAxes.Item1 - primaryAxes.Item2;
          }

          /// <summary>
          /// Maps an up direction to the indices of the two world axes that span the grid plane.
          /// </summary>
          /// <exception cref="ArgumentException">Thrown for an unrecognised axis direction.</exception>
          private static (int, int) GetPrimaryAxes(AxisDirection upDirection)
          {
              switch (upDirection)
              {
                  case AxisDirection.PosY:
                  case AxisDirection.NegY:
                      return (0, 2); // Use X and Z axes
                  case AxisDirection.PosX:
                  case AxisDirection.NegX:
                      return (1, 2); // Use Y and Z axes
                  case AxisDirection.PosZ:
                  case AxisDirection.NegZ:
                      return (0, 1); // Use X and Y axes
                  default:
                      throw new ArgumentException("Invalid axis direction", nameof(upDirection));
              }
          }

          /// <summary>
          /// Projects a world position onto the grid plane (the two primary-axis coordinates).
          /// </summary>
          private Vector2 GetGridPosition(Vector3 worldPos)
          {
              float x = worldPos[primaryAxes.Item1];
              float y = worldPos[primaryAxes.Item2];
              return new Vector2(x, y);
          }

          /// <summary>
          /// Packs a 2D cell coordinate into a single dictionary key (x in the high 32 bits, y in the low 32 bits).
          /// </summary>
          private static long GetCellKey(int x, int y)
          {
              return ((long)x << 32) | (y & 0xffffffffL);
          }

          /// <summary>
          /// Inverse of <see cref="GetCellKey"/>: unpacks a dictionary key into its cell coordinate.
          /// </summary>
          private static void GetCellCoords(long key, out int x, out int y)
          {
              x = (int)(key >> 32);
              // Truncate to the low 32 bits explicitly so this also works in a checked build.
              y = unchecked((int)key);
          }

          /// <summary>
          /// Computes the inclusive range of cell coordinates that <paramref name="bounds"/> covers on the grid plane.
          /// </summary>
          private void GetCellRange(Bounds bounds, out int minX, out int minY, out int maxX, out int maxY)
          {
              Vector2 min = GetGridPosition(bounds.min);
              Vector2 max = GetGridPosition(bounds.max);
              minX = Mathf.FloorToInt(min.x / cellSize);
              minY = Mathf.FloorToInt(min.y / cellSize);
              maxX = Mathf.FloorToInt(max.x / cellSize);
              maxY = Mathf.FloorToInt(max.y / cellSize);
          }

          /// <summary>
          /// Inserts an object into the spatial hash grid.
          /// The object is added to every cell its bounds overlap. Inserting the same object twice stores it twice.
          /// </summary>
          /// <param name="obj">The object to insert.</param>
          public void Insert(T obj)
          {
              GetCellRange(getBounds(obj), out int minX, out int minY, out int maxX, out int maxY);

              for (int y = minY; y <= maxY; y++)
              {
                  for (int x = minX; x <= maxX; x++)
                  {
                      long key = GetCellKey(x, y);
                      if (!cells.TryGetValue(key, out List<T> cell))
                      {
                          cell = new List<T>();
                          cells[key] = cell;
                      }
                      cell.Add(obj);
                  }
              }
          }

          /// <summary>
          /// Removes an object from the spatial hash grid.
          /// </summary>
          /// <param name="obj">The object to remove.</param>
          /// <returns>True if the object was removed from at least one cell.</returns>
          /// <remarks>
          /// Only the cells covered by the object's current bounds are searched, so its bounds must match
          /// those it had when inserted. Cells left empty are dropped from the grid.
          /// </remarks>
          public bool Remove(T obj)
          {
              bool removed = false;
              GetCellRange(getBounds(obj), out int minX, out int minY, out int maxX, out int maxY);

              for (int y = minY; y <= maxY; y++)
              {
                  for (int x = minX; x <= maxX; x++)
                  {
                      long key = GetCellKey(x, y);
                      if (cells.TryGetValue(key, out List<T> cell) && cell.Remove(obj))
                      {
                          removed = true;

                          // Drop empty cells so the dictionary doesn't accumulate dead buckets
                          if (cell.Count == 0)
                          {
                              cells.Remove(key);
                          }
                      }
                  }
              }

              return removed;
          }

          /// <summary>
          /// Queries the spatial hash grid for objects whose bounds intersect the specified bounds
          /// (touching counts as intersecting).
          /// </summary>
          /// <param name="queryBounds">The region to test against.</param>
          /// <param name="results">
          /// List that matches are appended to; it is not cleared first, and an object already present in
          /// the list is not added again. If null, a new list is allocated and assigned.
          /// </param>
          public void Query(Bounds queryBounds, ref List<T> results)
          {
              if (results == null)
              {
                  results = new List<T>();
              }

              Vector3 queryMin = queryBounds.min;
              Vector3 queryMax = queryBounds.max;
              GetCellRange(queryBounds, out int minX, out int minY, out int maxX, out int maxY);

              for (int y = minY; y <= maxY; y++)
              {
                  for (int x = minX; x <= maxX; x++)
                  {
                      if (!cells.TryGetValue(GetCellKey(x, y), out List<T> cell))
                      {
                          continue;
                      }

                      foreach (T obj in cell)
                      {
                          Bounds objBounds = getBounds(obj);
                          Vector3 objMin = objBounds.min;
                          Vector3 objMax = objBounds.max;

                          // Manual AABB test (inclusive on all three axes) to avoid the overhead of Bounds.Intersects
                          bool intersects = objMin.x <= queryMax.x && objMax.x >= queryMin.x &&
                                            objMin.y <= queryMax.y && objMax.y >= queryMin.y &&
                                            objMin.z <= queryMax.z && objMax.z >= queryMin.z;

                          // Objects spanning several cells are visited once per cell, so skip ones already reported.
                          // The list itself is used for this check (linear in its length) so a query doesn't allocate.
                          if (intersects && !results.Contains(obj))
                          {
                              results.Add(obj);
                          }
                      }
                  }
              }
          }

          /// <summary>
          /// Clears all objects from the grid.
          /// </summary>
          public void Clear()
          {
              cells.Clear();
          }

          /// <summary>
          /// Draws debug visualization of the grid and contained objects.
          /// Occupied cells are outlined in white on the plane at 0 along the up axis. Each stored object's
          /// bounds are outlined in green at the bounds' minimum coordinate along the up axis.
          /// </summary>
          /// <param name="duration">How long the debug lines should remain visible</param>
          public void DrawDebug(float duration = 0.0f)
          {
              // Dictionary keys are unique, so each occupied cell is drawn exactly once
              foreach (long key in cells.Keys)
              {
                  GetCellCoords(key, out int x, out int y);
                  DrawPlaneRect(x * cellSize, y * cellSize, (x + 1) * cellSize, (y + 1) * cellSize, 0.0f, Color.white, duration);
              }

              // An object spanning several cells is stored in each of them, so only draw it once
              var drawnObjects = new HashSet<T>();
              foreach (List<T> cellObjects in cells.Values)
              {
                  foreach (T obj in cellObjects)
                  {
                      if (!drawnObjects.Add(obj))
                      {
                          continue;
                      }

                      Bounds bounds = getBounds(obj);
                      Vector3 min = bounds.min;
                      Vector3 max = bounds.max;
                      DrawPlaneRect(
                          min[primaryAxes.Item1], min[primaryAxes.Item2],
                          max[primaryAxes.Item1], max[primaryAxes.Item2],
                          min[upAxisIndex], Color.green, duration);
                  }
              }
          }

          /// <summary>
          /// Draws a rectangle lying in the grid plane, spanning (minA, minB)-(maxA, maxB) on the primary axes,
          /// with every corner placed at <paramref name="upCoord"/> along the up axis.
          /// </summary>
          private void DrawPlaneRect(float minA, float minB, float maxA, float maxB, float upCoord, Color color, float duration)
          {
              Vector3 p1 = PlanePoint(minA, minB, upCoord);
              Vector3 p2 = PlanePoint(maxA, minB, upCoord);
              Vector3 p3 = PlanePoint(maxA, maxB, upCoord);
              Vector3 p4 = PlanePoint(minA, maxB, upCoord);

              Debug.DrawLine(p1, p2, color, duration);
              Debug.DrawLine(p2, p3, color, duration);
              Debug.DrawLine(p3, p4, color, duration);
              Debug.DrawLine(p4, p1, color, duration);
          }

          /// <summary>
          /// Builds a world position from plane coordinates (a, b) on the primary axes and a coordinate on the up axis.
          /// </summary>
          private Vector3 PlanePoint(float a, float b, float upCoord)
          {
              Vector3 point = Vector3.zero;
              point[primaryAxes.Item1] = a;
              point[primaryAxes.Item2] = b;
              point[upAxisIndex] = upCoord;
              return point;
          }
      }
  }