using System;
using System.Collections.Generic;
using UnityEngine;

namespace DunGen.Collision
{
    /// <summary>
    /// A generic spatial hash grid that divides space into uniform cells for fast spatial queries.
    /// The grid operates on a 2D plane perpendicular to the specified up axis.
    /// </summary>
    /// <remarks>
    /// Objects are bucketed by the bounds returned from the delegate supplied to the constructor.
    /// <see cref="Remove"/> re-evaluates those bounds, so an object whose bounds changed after
    /// <see cref="Insert"/> may leave stale entries in cells outside its current range.
    /// </remarks>
    /// <typeparam name="T">The type of objects stored in the grid.</typeparam>
    public class SpatialHashGrid<T>
    {
        private readonly Dictionary<long, List<T>> cells;
        private readonly float cellSize;
        private readonly Func<T, Bounds> getBounds;
        private readonly AxisDirection upAxis;
        private readonly (int, int) primaryAxes; // Indices for the two axes that form the 2D plane

        // Scratch set reused by Query to de-duplicate candidates without a per-call allocation.
        // Because it is shared between calls, Query must not be re-entered from getBounds.
        private readonly HashSet<T> queryVisited = new HashSet<T>();

        /// <summary>
        /// Constructs a spatial hash grid.
        /// </summary>
        /// <param name="cellSize">Size of each grid cell. Must be a positive, finite number.</param>
        /// <param name="getBounds">Delegate to extract bounds from objects</param>
        /// <param name="upDirection">The up axis direction (determines the 2D plane)</param>
        /// <exception cref="ArgumentNullException"><paramref name="getBounds"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="cellSize"/> is not a positive, finite number.</exception>
        /// <exception cref="ArgumentException"><paramref name="upDirection"/> is not a recognised axis direction.</exception>
        public SpatialHashGrid(float cellSize, Func<T, Bounds> getBounds, AxisDirection upDirection = AxisDirection.PosY)
        {
            if (getBounds == null)
            {
                throw new ArgumentNullException(nameof(getBounds));
            }

            // Written as !(x > 0) so that NaN is rejected as well as zero and negative sizes;
            // every cell index is computed by dividing by this value.
            if (!(cellSize > 0f) || float.IsInfinity(cellSize))
            {
                throw new ArgumentOutOfRangeException(nameof(cellSize), cellSize, "Cell size must be a positive, finite number.");
            }

            this.cells = new Dictionary<long, List<T>>();
            this.cellSize = cellSize;
            this.getBounds = getBounds;
            this.upAxis = upDirection;

            // Determine which axes to use for the 2D grid based on up direction
            this.primaryAxes = GetPrimaryAxes(upDirection);
        }

        /// <summary>
        /// Maps an up direction to the pair of Vector3 component indices spanning the grid plane.
        /// </summary>
        private static (int, int) GetPrimaryAxes(AxisDirection upDirection)
        {
            switch (upDirection)
            {
                case AxisDirection.PosY:
                case AxisDirection.NegY:
                    return (0, 2); // Use X and Z axes

                case AxisDirection.PosX:
                case AxisDirection.NegX:
                    return (1, 2); // Use Y and Z axes

                case AxisDirection.PosZ:
                case AxisDirection.NegZ:
                    return (0, 1); // Use X and Y axes

                default:
                    throw new ArgumentException("Invalid axis direction", nameof(upDirection));
            }
        }

        /// <summary>
        /// Maps an up direction to the Vector3 component index of the up axis (0 = X, 1 = Y, 2 = Z).
        /// Spelled out explicitly so it does not depend on the numeric values of the enum members.
        /// </summary>
        private static int GetUpAxisIndex(AxisDirection upDirection)
        {
            switch (upDirection)
            {
                case AxisDirection.PosX:
                case AxisDirection.NegX:
                    return 0;

                case AxisDirection.PosY:
                case AxisDirection.NegY:
                    return 1;

                case AxisDirection.PosZ:
                case AxisDirection.NegZ:
                    return 2;

                default:
                    throw new ArgumentException("Invalid axis direction", nameof(upDirection));
            }
        }

        /// <summary>
        /// Projects a world-space position onto the grid plane.
        /// </summary>
        private Vector2 GetGridPosition(Vector3 worldPos)
        {
            // Extract the two coordinates for our 2D plane based on the up axis
            float x = worldPos[primaryAxes.Item1];
            float y = worldPos[primaryAxes.Item2];

            return new Vector2(x, y);
        }

        /// <summary>
        /// Packs two signed cell coordinates into one dictionary key:
        /// x in the high 32 bits, y in the low 32 bits.
        /// </summary>
        private static long GetCellKey(int x, int y)
        {
            // Masking y stops its sign extension from overwriting the x half of the key.
            return ((long)x << 32) | (y & 0xffffffffL);
        }

        /// <summary>
        /// Computes the inclusive range of cell coordinates overlapped by the given bounds.
        /// </summary>
        private (int minX, int minY, int maxX, int maxY) GetCellRange(Bounds bounds)
        {
            Vector2 min = GetGridPosition(bounds.min);
            Vector2 max = GetGridPosition(bounds.max);

            return (
                Mathf.FloorToInt(min.x / cellSize),
                Mathf.FloorToInt(min.y / cellSize),
                Mathf.FloorToInt(max.x / cellSize),
                Mathf.FloorToInt(max.y / cellSize));
        }

        /// <summary>
        /// Inserts an object into the spatial hash grid.
        /// The object is added to every cell its bounds overlap; no duplicate check is performed.
        /// </summary>
        /// <param name="obj">The object to insert.</param>
        public void Insert(T obj)
        {
            var range = GetCellRange(getBounds(obj));

            for (int y = range.minY; y <= range.maxY; y++)
            {
                for (int x = range.minX; x <= range.maxX; x++)
                {
                    long key = GetCellKey(x, y);

                    if (!cells.TryGetValue(key, out List<T> cell))
                    {
                        cell = new List<T>();
                        cells[key] = cell;
                    }

                    cell.Add(obj);
                }
            }
        }

        /// <summary>
        /// Removes an object from the spatial hash grid.
        /// Only the cells overlapped by the object's current bounds are visited.
        /// </summary>
        /// <param name="obj">The object to remove.</param>
        /// <returns>True if the object was removed from at least one cell.</returns>
        public bool Remove(T obj)
        {
            bool removed = false;
            var range = GetCellRange(getBounds(obj));

            for (int y = range.minY; y <= range.maxY; y++)
            {
                for (int x = range.minX; x <= range.maxX; x++)
                {
                    long key = GetCellKey(x, y);

                    if (!cells.TryGetValue(key, out List<T> cell))
                    {
                        continue;
                    }

                    if (cell.Remove(obj))
                    {
                        removed = true;

                        // Drop empty buckets so the dictionary only holds occupied cells.
                        if (cell.Count == 0)
                        {
                            cells.Remove(key);
                        }
                    }
                }
            }

            return removed;
        }

        /// <summary>
        /// Queries the spatial hash grid for objects that might intersect with the specified bounds.
        /// </summary>
        /// <param name="queryBounds">The bounds to test against.</param>
        /// <param name="results">
        /// List that matching objects are appended to. Existing entries are kept and are not
        /// duplicated. If null, a new list is created and assigned.
        /// </param>
        public void Query(Bounds queryBounds, ref List<T> results)
        {
            if (results == null)
            {
                results = new List<T>();
            }

            Vector3 queryBoundsMin = queryBounds.min;
            Vector3 queryBoundsMax = queryBounds.max;
            var range = GetCellRange(queryBounds);

            // Seed the visited set with whatever the caller already collected so that
            // pre-existing entries are never added a second time.
            queryVisited.Clear();
            foreach (T existing in results)
            {
                queryVisited.Add(existing);
            }

            for (int y = range.minY; y <= range.maxY; y++)
            {
                for (int x = range.minX; x <= range.maxX; x++)
                {
                    long key = GetCellKey(x, y);

                    if (!cells.TryGetValue(key, out List<T> cell))
                    {
                        continue;
                    }

                    foreach (T obj in cell)
                    {
                        // An object spanning several cells is encountered once per cell;
                        // test it only the first time.
                        if (!queryVisited.Add(obj))
                        {
                            continue;
                        }

                        Bounds objBounds = getBounds(obj);
                        Vector3 objBoundsMin = objBounds.min;
                        Vector3 objBoundsMax = objBounds.max;

                        // Manual intersection test to avoid the overhead of Bounds.Intersects
                        bool intersects =
                            objBoundsMin.x <= queryBoundsMax.x && objBoundsMax.x >= queryBoundsMin.x &&
                            objBoundsMin.y <= queryBoundsMax.y && objBoundsMax.y >= queryBoundsMin.y &&
                            objBoundsMin.z <= queryBoundsMax.z && objBoundsMax.z >= queryBoundsMin.z;

                        if (intersects)
                        {
                            results.Add(obj);
                        }
                    }
                }
            }

            // Do not keep references to stored objects alive between queries.
            queryVisited.Clear();
        }

        /// <summary>
        /// Clears all objects from the grid.
        /// </summary>
        public void Clear()
        {
            cells.Clear();
        }

        /// <summary>
        /// Draws debug visualization of the grid and contained objects.
        /// Occupied cells are outlined in white at up-coordinate zero; each object's bounds are
        /// outlined in green at the minimum of its bounds along the up axis.
        /// </summary>
        /// <param name="duration">How long the debug lines should remain visible</param>
        public void DrawDebug(float duration = 0.0f)
        {
            int upAxisIndex = GetUpAxisIndex(upAxis);

            // Draw grid cells. Dictionary keys are already unique, one per occupied cell.
            foreach (long key in cells.Keys)
            {
                // Reverse of GetCellKey: high 32 bits are x, low 32 bits are y.
                int x = (int)(key >> 32);
                int y = unchecked((int)key);

                Vector2 cellMin = new Vector2(x * cellSize, y * cellSize);
                Vector2 cellMax = new Vector2((x + 1) * cellSize, (y + 1) * cellSize);

                DrawPlaneRect(cellMin, cellMax, upAxisIndex, 0f, Color.white, duration);
            }

            // Draw object bounds
            var drawnObjects = new HashSet<T>();

            foreach (List<T> cellObjects in cells.Values)
            {
                foreach (T obj in cellObjects)
                {
                    // Only draw each object once
                    if (!drawnObjects.Add(obj))
                    {
                        continue;
                    }

                    Bounds bounds = getBounds(obj);
                    Vector3 min = bounds.min;
                    Vector3 max = bounds.max;

                    // All four corners share the same height (using min for the up axis)
                    DrawPlaneRect(GetGridPosition(min), GetGridPosition(max), upAxisIndex, min[upAxisIndex], Color.green, duration);
                }
            }
        }

        /// <summary>
        /// Draws an axis-aligned rectangle lying in the grid plane at the given up-axis coordinate.
        /// </summary>
        private void DrawPlaneRect(Vector2 min, Vector2 max, int upAxisIndex, float upCoord, Color color, float duration)
        {
            Vector3 p1 = ToWorldPosition(min.x, min.y, upAxisIndex, upCoord);
            Vector3 p2 = ToWorldPosition(max.x, min.y, upAxisIndex, upCoord);
            Vector3 p3 = ToWorldPosition(max.x, max.y, upAxisIndex, upCoord);
            Vector3 p4 = ToWorldPosition(min.x, max.y, upAxisIndex, upCoord);

            Debug.DrawLine(p1, p2, color, duration);
            Debug.DrawLine(p2, p3, color, duration);
            Debug.DrawLine(p3, p4, color, duration);
            Debug.DrawLine(p4, p1, color, duration);
        }

        /// <summary>
        /// Builds a world-space point from two grid-plane coordinates and an up-axis coordinate.
        /// </summary>
        private Vector3 ToWorldPosition(float planeX, float planeY, int upAxisIndex, float upCoord)
        {
            Vector3 point = Vector3.zero;
            point[primaryAxes.Item1] = planeX;
            point[primaryAxes.Item2] = planeY;
            point[upAxisIndex] = upCoord;

            return point;
        }
    }
}