using System;
using System.Collections.Generic;
using UnityEngine;
namespace DunGen.Collision
{
/// <summary>
/// A generic spatial hash grid that divides space into uniform cells for fast spatial queries.
/// The grid operates on a 2D plane perpendicular to the specified up axis: each object is stored
/// in every cell that its bounds overlap when projected onto that plane.
/// </summary>
/// <typeparam name="T">The type of objects stored in the grid.</typeparam>
public class SpatialHashGrid<T>
{
    // Packed cell key (see GetCellKey) -> objects whose projected bounds overlap that cell.
    private readonly Dictionary<long, List<T>> cells;
    private readonly float cellSize;
    private readonly Func<T, Bounds> getBounds;
    private readonly AxisDirection upAxis;
    private readonly (int, int) primaryAxes; // Indices for the two axes that form the 2D plane

    /// <summary>
    /// Constructs a spatial hash grid.
    /// </summary>
    /// <param name="cellSize">Size of each grid cell. Must be a positive number.</param>
    /// <param name="getBounds">Delegate to extract bounds from objects.</param>
    /// <param name="upDirection">The up axis direction (determines the 2D plane).</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="cellSize"/> is zero, negative, or NaN.</exception>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="getBounds"/> is null.</exception>
    /// <exception cref="ArgumentException">Thrown if <paramref name="upDirection"/> is not a recognised axis direction.</exception>
    public SpatialHashGrid(float cellSize, Func<T, Bounds> getBounds, AxisDirection upDirection = AxisDirection.PosY)
    {
        // Every cell coordinate is computed as position / cellSize; a non-positive or NaN size
        // would turn those into +/-Infinity or NaN, so reject it up front.
        if (float.IsNaN(cellSize) || cellSize <= 0f)
        {
            throw new ArgumentOutOfRangeException(nameof(cellSize), cellSize, "Cell size must be a positive number.");
        }
        if (getBounds == null)
        {
            throw new ArgumentNullException(nameof(getBounds));
        }

        this.cells = new Dictionary<long, List<T>>();
        this.cellSize = cellSize;
        this.getBounds = getBounds;
        this.upAxis = upDirection;
        // Determine which axes to use for the 2D grid based on up direction
        this.primaryAxes = GetPrimaryAxes(upDirection);
    }

    /// <summary>
    /// Returns the component indices (0 = X, 1 = Y, 2 = Z) of the two axes spanning the plane
    /// perpendicular to <paramref name="upDirection"/>.
    /// </summary>
    private static (int, int) GetPrimaryAxes(AxisDirection upDirection)
    {
        switch (upDirection)
        {
            case AxisDirection.PosY:
            case AxisDirection.NegY:
                return (0, 2); // Use X and Z axes
            case AxisDirection.PosX:
            case AxisDirection.NegX:
                return (1, 2); // Use Y and Z axes
            case AxisDirection.PosZ:
            case AxisDirection.NegZ:
                return (0, 1); // Use X and Y axes
            default:
                throw new ArgumentException("Invalid axis direction", nameof(upDirection));
        }
    }

    /// <summary>
    /// Returns the component index (0 = X, 1 = Y, 2 = Z) of the axis that <paramref name="direction"/> lies along.
    /// Mapped explicitly (mirroring GetPrimaryAxes) rather than relying on the enum's numeric values.
    /// </summary>
    private static int GetAxisIndex(AxisDirection direction)
    {
        switch (direction)
        {
            case AxisDirection.PosX:
            case AxisDirection.NegX:
                return 0;
            case AxisDirection.PosY:
            case AxisDirection.NegY:
                return 1;
            case AxisDirection.PosZ:
            case AxisDirection.NegZ:
                return 2;
            default:
                throw new ArgumentException("Invalid axis direction", nameof(direction));
        }
    }

    /// <summary>
    /// Projects a world position onto the grid plane by keeping only the two primary-axis components.
    /// </summary>
    private Vector2 GetGridPosition(Vector3 worldPos)
    {
        return new Vector2(worldPos[primaryAxes.Item1], worldPos[primaryAxes.Item2]);
    }

    /// <summary>
    /// Packs 2D cell coordinates into a single 64-bit key: x in the high 32 bits, y in the low 32 bits.
    /// </summary>
    private static long GetCellKey(int x, int y)
    {
        // Mask y so its sign extension does not overwrite the high (x) half.
        return ((long)x << 32) | (y & 0xffffffffL);
    }

    /// <summary>
    /// Computes the inclusive range of cell coordinates that <paramref name="bounds"/> covers on the grid plane.
    /// </summary>
    private void GetCellRange(Bounds bounds, out int minX, out int minY, out int maxX, out int maxY)
    {
        Vector2 min = GetGridPosition(bounds.min);
        Vector2 max = GetGridPosition(bounds.max);
        minX = Mathf.FloorToInt(min.x / cellSize);
        minY = Mathf.FloorToInt(min.y / cellSize);
        maxX = Mathf.FloorToInt(max.x / cellSize);
        maxY = Mathf.FloorToInt(max.y / cellSize);
    }

    /// <summary>
    /// Inserts an object into every cell overlapped by its current bounds.
    /// Inserting the same object twice stores it twice in each overlapped cell.
    /// </summary>
    public void Insert(T obj)
    {
        GetCellRange(getBounds(obj), out int minX, out int minY, out int maxX, out int maxY);

        for (int y = minY; y <= maxY; y++)
        {
            for (int x = minX; x <= maxX; x++)
            {
                long key = GetCellKey(x, y);
                if (!cells.TryGetValue(key, out List<T> cell))
                {
                    cell = new List<T>();
                    cells[key] = cell;
                }
                cell.Add(obj);
            }
        }
    }

    /// <summary>
    /// Removes an object from the spatial hash grid.
    /// The cells searched are derived from the object's <em>current</em> bounds, so if its bounds
    /// changed after it was inserted, copies in cells it no longer overlaps are not removed.
    /// </summary>
    /// <returns>True if the object was removed from at least one cell.</returns>
    public bool Remove(T obj)
    {
        bool removed = false;
        GetCellRange(getBounds(obj), out int minX, out int minY, out int maxX, out int maxY);

        for (int y = minY; y <= maxY; y++)
        {
            for (int x = minX; x <= maxX; x++)
            {
                long key = GetCellKey(x, y);
                if (cells.TryGetValue(key, out List<T> cell) && cell.Remove(obj))
                {
                    removed = true;
                    // Drop empty cells so the dictionary doesn't grow without bound.
                    if (cell.Count == 0)
                    {
                        cells.Remove(key);
                    }
                }
            }
        }

        return removed;
    }

    /// <summary>
    /// Queries the spatial hash grid for objects whose bounds intersect the specified bounds.
    /// Matches are appended to <paramref name="results"/>; objects already present in the list are not added again.
    /// </summary>
    /// <param name="queryBounds">The bounds to test against (all three axes are tested).</param>
    /// <param name="results">List that receives the matches. If null, a new list is allocated and assigned.</param>
    public void Query(Bounds queryBounds, ref List<T> results)
    {
        // The list is passed by ref, so allocate one for callers that pass null instead of
        // failing with a NullReferenceException on the first hit.
        if (results == null)
        {
            results = new List<T>();
        }

        Vector3 queryBoundsMin = queryBounds.min;
        Vector3 queryBoundsMax = queryBounds.max;
        GetCellRange(queryBounds, out int minX, out int minY, out int maxX, out int maxY);

        for (int y = minY; y <= maxY; y++)
        {
            for (int x = minX; x <= maxX; x++)
            {
                if (!cells.TryGetValue(GetCellKey(x, y), out List<T> cell))
                {
                    continue;
                }

                foreach (T obj in cell)
                {
                    Bounds objBounds = getBounds(obj);
                    Vector3 objBoundsMin = objBounds.min;
                    Vector3 objBoundsMax = objBounds.max;

                    // Manual intersection test to avoid the overhead of Bounds.Intersects
                    bool intersects = objBoundsMin.x <= queryBoundsMax.x &&
                                      objBoundsMax.x >= queryBoundsMin.x &&
                                      objBoundsMin.y <= queryBoundsMax.y &&
                                      objBoundsMax.y >= queryBoundsMin.y &&
                                      objBoundsMin.z <= queryBoundsMax.z &&
                                      objBoundsMax.z >= queryBoundsMin.z;

                    // An object spanning several cells is visited once per cell; only report it once.
                    if (intersects && !results.Contains(obj))
                    {
                        results.Add(obj);
                    }
                }
            }
        }
    }

    /// <summary>
    /// Clears all objects from the grid.
    /// </summary>
    public void Clear()
    {
        cells.Clear();
    }

    /// <summary>
    /// Draws debug visualization of the occupied grid cells (white) and contained objects (green).
    /// </summary>
    /// <param name="duration">How long the debug lines should remain visible.</param>
    public void DrawDebug(float duration = 0.0f)
    {
        int axisA = primaryAxes.Item1;
        int axisB = primaryAxes.Item2;

        // Draw occupied grid cells, in the plane where the up-axis coordinate is 0.
        // Dictionary keys are unique, so each cell is drawn exactly once.
        foreach (long key in cells.Keys)
        {
            // Unpack the coordinates packed by GetCellKey; the low half is truncated explicitly.
            int cellX = (int)(key >> 32);
            int cellY = unchecked((int)(key & 0xffffffffL));

            Vector3 min = Vector3.zero;
            Vector3 max = Vector3.zero;
            min[axisA] = cellX * cellSize;
            min[axisB] = cellY * cellSize;
            max[axisA] = (cellX + 1) * cellSize;
            max[axisB] = (cellY + 1) * cellSize;

            Vector3 p1 = min;
            Vector3 p2 = min;
            p2[axisA] = max[axisA];
            Vector3 p3 = max;
            Vector3 p4 = max;
            p4[axisA] = min[axisA];

            Debug.DrawLine(p1, p2, Color.white, duration);
            Debug.DrawLine(p2, p3, Color.white, duration);
            Debug.DrawLine(p3, p4, Color.white, duration);
            Debug.DrawLine(p4, p1, Color.white, duration);
        }

        // Draw object bounds as a rectangle on the bottom face (minimum along the up axis).
        int upAxisIndex = GetAxisIndex(upAxis);
        var drawnObjects = new HashSet<T>();
        foreach (List<T> cellObjects in cells.Values)
        {
            foreach (T obj in cellObjects)
            {
                // Objects spanning several cells appear in each of them; draw each only once.
                if (!drawnObjects.Add(obj))
                {
                    continue;
                }

                Bounds bounds = getBounds(obj);
                Vector3 min = bounds.min;
                Vector3 max = bounds.max;

                // Every corner keeps min along the up axis; only the two primary axes vary.
                Vector3 p1 = min;
                Vector3 p2 = min;
                p2[axisA] = max[axisA];
                Vector3 p3 = max;
                p3[upAxisIndex] = min[upAxisIndex];
                Vector3 p4 = min;
                p4[axisB] = max[axisB];

                Debug.DrawLine(p1, p2, Color.green, duration);
                Debug.DrawLine(p2, p3, Color.green, duration);
                Debug.DrawLine(p3, p4, Color.green, duration);
                Debug.DrawLine(p4, p1, Color.green, duration);
            }
        }
    }
}
}