"""
Optimized collision detection system using NumPy for vectorized operations.

This module provides efficient collision detection for games with many
entities (200+). It uses AABB (axis-aligned bounding box) collision
detection with NumPy vectorization.

HYBRID APPROACH:
- Collision queries with few candidate units (fewer than 10) use a plain
  Python loop, which has low overhead.
- Queries with 10 or more candidates use NumPy vectorization, which scales
  better.

Performance improvements:
- Grid-based spatial hashing limits each collision query to units that share
  a grid cell, instead of testing every pair of units (O(n²)).
- Vectorized AABB checks for large candidate counts.
- Minimal overhead for small candidate counts.
"""

from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import numpy as np


# Unit-count threshold for switching to NumPy mode.
# NOTE(review): CollisionSystem in this module does not read this constant;
# its per-query switch between the plain-Python and NumPy paths uses its own
# candidate-count threshold. Confirm whether external code relies on it.
NUMPY_THRESHOLD = 50


@dataclass
class CollisionLayer:
    """
    Integer IDs of the collision layers.

    The values are used as row/column indices into
    ``CollisionSystem.collision_matrix`` and are stored in that class's int8
    ``layers`` array. Which layers actually interact is decided by the
    collision matrix, not by this class.

    NOTE(review): the attributes are unannotated class-level ints, so
    ``@dataclass`` generates no fields here; presumably the class is only
    used as a namespace of constants.
    """

    RAT = 0
    BOMB = 1
    GAS = 2
    MINE = 3
    POINT = 4
    EXPLOSION = 5


class CollisionSystem:
    """
    Manage per-frame collision detection for game units.

    Each frame, units are registered with :meth:`register_unit`. Queries such
    as :meth:`get_collisions_for_unit` then look up candidate units that share
    a grid cell (current or previous position) and run AABB overlap tests on
    them, switching to NumPy vectorization when there are many candidates.
    :meth:`clear` discards the frame's data but keeps the allocated array
    capacity for reuse by the next frame.

    Per-unit data is stored row-wise in pre-allocated NumPy arrays
    (``bboxes``, ``positions``, ``positions_before``, ``layers``); only the
    first ``_size`` rows are valid, the remaining rows are spare capacity.

    Attributes
    ----------
    cell_size : int
        Size of each grid cell in pixels
    grid_width : int
        Number of cells in grid width
    grid_height : int
        Number of cells in grid height
    collision_matrix : ndarray of bool, shape (6, 6)
        ``collision_matrix[a, b]`` is True when layer ``a`` interacts with
        layer ``b``; made symmetric in ``_setup_collision_matrix``.

    NOTE(review): cell_size, grid_width and grid_height are stored but not
    read by any method of this class.
    """

    # Number of collision layers (CollisionLayer values 0..5).
    _NUM_LAYERS = 6

    # Row capacity allocated on first registration; doubled whenever full.
    _INITIAL_CAPACITY = 100

    # Queries with fewer candidates than this use the plain-Python path;
    # larger candidate sets use the NumPy path.
    _NUMPY_MIN_CANDIDATES = 10

    def __init__(self, cell_size: int, grid_width: int, grid_height: int):
        """
        Parameters
        ----------
        cell_size : int
            Size of each grid cell in pixels
        grid_width, grid_height : int
            Number of cells along each grid axis
        """
        self.cell_size = cell_size
        self.grid_width = grid_width
        self.grid_height = grid_height

        self._init_storage()

        # Collision matrix: which layers collide with which
        self.collision_matrix = np.zeros(
            (self._NUM_LAYERS, self._NUM_LAYERS), dtype=bool
        )
        self._setup_collision_matrix()

    def _init_storage(self) -> None:
        """Create empty spatial grids, id bookkeeping and zero-capacity arrays."""
        # Spatial grids: grid cell (x, y) -> indices of the units in that
        # cell, keyed by current and by previous position respectively.
        self.spatial_grid: Dict[Tuple[int, int], List[int]] = {}
        self.spatial_grid_before: Dict[Tuple[int, int], List[int]] = {}

        # unit_ids[i] is the id of the unit stored in row i of the arrays.
        self.unit_ids: List[Any] = []
        # unit id -> row index, giving O(1) lookups instead of list.index().
        # For a repeated id the first registration wins, like list.index().
        self._id_to_index: Dict[Any, int] = {}

        # Arrays for vectorized operations (rows >= _size are spare capacity)
        self.bboxes = np.empty((0, 4), dtype=np.float32)  # (x1, y1, x2, y2)
        self.positions = np.empty((0, 2), dtype=np.int32)  # (x, y)
        self.positions_before = np.empty((0, 2), dtype=np.int32)
        self.layers = np.empty(0, dtype=np.int8)

        # Pre-allocation tracking: allocated rows and rows in use
        self._capacity = 0
        self._size = 0

    def _setup_collision_matrix(self):
        """Define which collision layers interact with each other."""
        L = CollisionLayer

        # Rats collide with: Rats, Gas, Mines, Points, Explosions
        self.collision_matrix[L.RAT, L.RAT] = True
        self.collision_matrix[L.RAT, L.BOMB] = False  # Bombs don't kill on contact
        self.collision_matrix[L.RAT, L.GAS] = True
        self.collision_matrix[L.RAT, L.MINE] = True
        self.collision_matrix[L.RAT, L.POINT] = True
        self.collision_matrix[L.RAT, L.EXPLOSION] = True

        # Gas affects rats
        self.collision_matrix[L.GAS, L.RAT] = True

        # Mines trigger on rats
        self.collision_matrix[L.MINE, L.RAT] = True

        # Points collected by rats (handled in point logic)
        self.collision_matrix[L.POINT, L.RAT] = True

        # Explosions kill rats
        self.collision_matrix[L.EXPLOSION, L.RAT] = True

        # Make matrix symmetric
        self.collision_matrix = np.logical_or(self.collision_matrix,
                                              self.collision_matrix.T)

    def clear(self):
        """
        Clear all collision data for a new frame.

        The per-unit arrays are not reallocated; their capacity is reused by
        the next frame's registrations.
        """
        self.spatial_grid.clear()
        self.spatial_grid_before.clear()
        self.unit_ids = []
        self._id_to_index.clear()
        self._size = 0

    def register_unit(self, unit_id, bbox: Tuple[float, float, float, float],
                      position: Tuple[int, int], position_before: Tuple[int, int],
                      layer: int):
        """
        Register a unit for collision detection this frame.

        Parameters
        ----------
        unit_id : UUID
            Unique (hashable) identifier for the unit
        bbox : tuple
            Bounding box (x1, y1, x2, y2)
        position : tuple
            Current grid position (x, y)
        position_before : tuple
            Previous grid position (x, y)
        layer : int
            Collision layer (a CollisionLayer value)
        """
        self._ensure_capacity()
        idx = self._size

        self.unit_ids.append(unit_id)
        self._id_to_index.setdefault(unit_id, idx)

        self.bboxes[idx] = bbox
        self.positions[idx] = position
        self.positions_before[idx] = position_before
        self.layers[idx] = layer
        self._size += 1

        # Index the unit under both its current and previous grid cell
        self.spatial_grid.setdefault(position, []).append(idx)
        self.spatial_grid_before.setdefault(position_before, []).append(idx)

    def _ensure_capacity(self) -> None:
        """Grow the per-unit arrays (doubling) so at least one free row exists."""
        if self._size < self._capacity:
            return
        new_capacity = max(self._INITIAL_CAPACITY, 2 * self._capacity)
        self.bboxes = self._grow(self.bboxes, new_capacity)
        self.positions = self._grow(self.positions, new_capacity)
        self.positions_before = self._grow(self.positions_before, new_capacity)
        self.layers = self._grow(self.layers, new_capacity)
        self._capacity = new_capacity

    def _grow(self, array: np.ndarray, capacity: int) -> np.ndarray:
        """Return a copy of ``array`` with ``capacity`` rows, keeping the valid rows."""
        grown = np.empty((capacity,) + array.shape[1:], dtype=array.dtype)
        grown[:self._size] = array[:self._size]
        return grown

    def check_aabb_collision(self, idx1: int, idx2: int, tolerance: int = 0) -> bool:
        """
        Check AABB collision between two units.

        Parameters
        ----------
        idx1, idx2 : int
            Row indices in the per-unit arrays
        tolerance : int
            Overlap tolerance in pixels (shrinks the detection zone)

        Returns
        -------
        bool
            True if the bounding boxes overlap by more than ``tolerance``
        """
        bbox1 = self.bboxes[idx1]
        bbox2 = self.bboxes[idx2]

        # bool() so callers get a Python bool rather than numpy.bool_
        return bool(
            bbox1[0] < bbox2[2] - tolerance
            and bbox1[2] > bbox2[0] + tolerance
            and bbox1[1] < bbox2[3] - tolerance
            and bbox1[3] > bbox2[1] + tolerance
        )

    def check_aabb_collision_vectorized(self, idx: int, indices: np.ndarray,
                                        tolerance: int = 0) -> np.ndarray:
        """
        Vectorized AABB collision check between one unit and many others.

        Parameters
        ----------
        idx : int
            Row index of the unit to check
        indices : ndarray
            Row indices to check against
        tolerance : int
            Overlap tolerance in pixels

        Returns
        -------
        ndarray
            Boolean array, one entry per element of ``indices``
        """
        if len(indices) == 0:
            return np.array([], dtype=bool)

        bbox = self.bboxes[idx]
        # Fancy indexing gathers only the requested rows (never spare capacity)
        other_bboxes = self.bboxes[indices]

        return (
            (bbox[0] < other_bboxes[:, 2] - tolerance)
            & (bbox[2] > other_bboxes[:, 0] + tolerance)
            & (bbox[1] < other_bboxes[:, 3] - tolerance)
            & (bbox[3] > other_bboxes[:, 1] + tolerance)
        )

    def _candidates_near(self, idx: int) -> List[int]:
        """Return sorted indices of other units sharing a current/previous cell with ``idx``."""
        # .tolist() yields plain Python ints matching the grid's tuple keys
        cells = {
            tuple(self.positions[idx].tolist()),
            tuple(self.positions_before[idx].tolist()),
        }
        found: Set[int] = set()
        for cell in cells:
            found.update(self.spatial_grid.get(cell, ()))
            found.update(self.spatial_grid_before.get(cell, ()))
        found.discard(idx)
        return sorted(found)

    def get_collisions_for_unit(self, unit_id, layer: int,
                                tolerance: int = 0) -> List[Tuple[int, Any]]:
        """
        Get all units colliding with the specified unit.

        Candidates are the units sharing the unit's current or previous grid
        cell. Few candidates are checked with a plain loop, many with NumPy.

        Parameters
        ----------
        unit_id : UUID
            ID of the unit to check
        layer : int
            Collision layer of the unit
        tolerance : int
            Overlap tolerance

        Returns
        -------
        list
            ``(index, unit_id)`` tuples for colliding units, ordered by
            registration index; empty if ``unit_id`` is not registered
        """
        idx = self._id_to_index.get(unit_id)
        if idx is None:
            return []

        candidates = self._candidates_near(idx)
        if not candidates:
            return []

        if len(candidates) < self._NUMPY_MIN_CANDIDATES:
            return self._simple_collision_check(idx, candidates, layer, tolerance)

        # NumPy path: filter by the collision matrix, then test AABBs at once
        candidates_array = np.array(candidates, dtype=np.intp)
        can_collide = self.collision_matrix[layer, self.layers[candidates_array]]
        valid_candidates = candidates_array[can_collide]
        if valid_candidates.size == 0:
            return []

        collisions = self.check_aabb_collision_vectorized(
            idx, valid_candidates, tolerance
        )
        return [(i, self.unit_ids[i]) for i in valid_candidates[collisions].tolist()]

    def _simple_collision_check(self, idx: int, candidates: Iterable[int],
                                layer: int, tolerance: int) -> List[Tuple[int, Any]]:
        """
        Collision check without NumPy batching, used for few candidates.

        Returns ``(index, unit_id)`` tuples in the iteration order of
        ``candidates``.
        """
        allowed = self.collision_matrix[layer]
        return [
            (int(other), self.unit_ids[other])
            for other in candidates
            if allowed[self.layers[other]]
            and self.check_aabb_collision(idx, other, tolerance)
        ]

    def get_units_in_cell(self, position: Tuple[int, int],
                          use_before: bool = False) -> List[Any]:
        """
        Get all unit IDs in a specific grid cell.

        Parameters
        ----------
        position : tuple
            Grid position (x, y)
        use_before : bool
            If True, look up units by their previous position instead

        Returns
        -------
        list
            Unit IDs in that cell, in registration order
        """
        grid = self.spatial_grid_before if use_before else self.spatial_grid
        return [self.unit_ids[i] for i in grid.get(position, ())]

    def get_units_in_area(self, positions: List[Tuple[int, int]],
                          layer_filter: Optional[int] = None) -> Set[Any]:
        """
        Get all units in multiple grid cells (useful for explosions).

        A unit counts as inside a cell if either its current or its previous
        position is that cell.

        Parameters
        ----------
        positions : list
            Grid positions to check
        layer_filter : int, optional
            If provided, only return units of this layer

        Returns
        -------
        set
            Unique unit IDs in the area
        """
        unit_set: Set[Any] = set()
        for pos in positions:
            for grid in (self.spatial_grid, self.spatial_grid_before):
                for i in grid.get(pos, ()):
                    if layer_filter is None or self.layers[i] == layer_filter:
                        unit_set.add(self.unit_ids[i])
        return unit_set

    def check_partial_move_collision(self, unit_id, partial_move: float,
                                     threshold: float = 0.5) -> List[Any]:
        """
        List units sharing a cell with a unit, given its movement progress.

        If ``partial_move >= threshold`` the unit's current cell is used,
        otherwise its previous cell. Units whose current or previous position
        is that cell are returned. Bounding boxes and collision layers are NOT
        consulted.

        Parameters
        ----------
        unit_id : UUID
            Unit to check
        partial_move : float
            Movement progress (0.0 to 1.0)
        threshold : float
            Progress at which the unit counts as being in its current cell

        Returns
        -------
        list
            Unit IDs (excluding the unit itself), ordered by registration
            index; empty if ``unit_id`` is not registered
        """
        idx = self._id_to_index.get(unit_id)
        if idx is None:
            return []

        source = self.positions if partial_move >= threshold else self.positions_before
        cell = tuple(source[idx].tolist())

        found = set(self.spatial_grid.get(cell, ()))
        found.update(self.spatial_grid_before.get(cell, ()))
        found.discard(idx)
        return [self.unit_ids[i] for i in sorted(found)]