diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d668e7c9..e1fb45eb8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -197,7 +197,7 @@ jobs: python -m pip install --root-user-action=ignore -e . pytest pytest-cov coverage[toml] mkdir -p /tmp/spine-sitecustomize printf "import h5py\n" > /tmp/spine-sitecustomize/sitecustomize.py - PYTHONPATH="/tmp/spine-sitecustomize:${PYTHONPATH:-}" pytest test \ + NUMBA_DISABLE_JIT=1 PYTHONPATH="/tmp/spine-sitecustomize:${PYTHONPATH:-}" pytest test \ -v \ --cov=spine \ --cov-report=xml:coverage.xml \ diff --git a/check_coverage.sh b/check_coverage.sh index ec849c875..b39ebdff8 100755 --- a/check_coverage.sh +++ b/check_coverage.sh @@ -10,6 +10,7 @@ docker run --rm \ -w /workspace \ -e TEST_PATH="$test_path" \ -e COVERAGE_TARGET="$coverage_target" \ + -e NUMBA_DISABLE_JIT=1 \ --platform linux/amd64 \ ghcr.io/deeplearnphysics/spine:latest \ bash -lc ' diff --git a/src/spine/math/__init__.py b/src/spine/math/__init__.py index 8b7cc5b51..93eebe48a 100644 --- a/src/spine/math/__init__.py +++ b/src/spine/math/__init__.py @@ -1,11 +1,11 @@ -"""Module with fast, Numba-accelerated, compiles math routines. +"""Fast, Numba-accelerated math routines. This includes multiple submodules: - `base.py` includes basic functions, as found in numpy or scipy.special - `linalg.py` includes linear algebra routines, as found in numpy.linalg - `distance.py` includes distance functions, as found in scipy.distance - `graph.py` includes graph routines, as found in scipy.csgraph -- `cluster.py` includes cluster functions, as found in skleran.cluster +- `cluster.py` includes cluster functions, as found in sklearn.cluster - `metrics.py` includes clustering evaluation metrics """ diff --git a/src/spine/math/base.py b/src/spine/math/base.py index bec73e60f..e4c214bf6 100644 --- a/src/spine/math/base.py +++ b/src/spine/math/base.py @@ -25,7 +25,7 @@ @nb.njit(cache=True) -def seed(seed: nb.int64) -> None: +def seed(seed_value: int) -> None: """Sets the numpy random seed for all Numba jitted functions. Note that setting the seed using `np.random.seed` outside a Numba jitted @@ -33,27 +33,27 @@ def seed(seed: nb.int64) -> None: Parameters ---------- - seed : int + seed_value : int Random number generator seed """ - np.random.seed(seed) + np.random.seed(seed_value) @nb.njit(cache=True) -def unique(x: nb.int64[:]) -> (nb.int64[:], nb.int64[:]): +def unique(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """Numba implementation of `np.unique(x, return_counts=True)`. Parameters ---------- x : np.ndarray - (N) array of values + (N,) array of values Returns ------- np.ndarray - (U) array of unique values + (U,) array of unique values np.ndarray - (U) array of counts of each unique value in the original array + (U,) array of counts of each unique value in the original array """ # Nothing to do if the input is empty uniques = np.empty(len(x), dtype=x.dtype) @@ -84,74 +84,74 @@ def unique(x: nb.int64[:]) -> (nb.int64[:], nb.int64[:]): @nb.njit(cache=True) -def sum(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]: +def sum(x: np.ndarray, axis: int) -> np.ndarray: """Numba implementation of `np.sum(x, axis)`. Parameters ---------- x : np.ndarray - (N,M) array of values + (N, M) array of values axis : int Array axis ID Returns ------- np.ndarray - (N) or (M) array of `sum` values + (N,) or (M,) array of `sum` values """ assert axis == 0 or axis == 1 summ = np.empty(x.shape[1 - axis], dtype=x.dtype) if axis == 0: - for i in range(len(summ)): + for i in range(x.shape[1]): summ[i] = np.sum(x[:, i]) else: - for i in range(len(summ)): - summ[i] = np.sum(x[i]) + for i, xi in enumerate(x): + summ[i] = np.sum(xi) return summ @nb.njit(cache=True) -def mean(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]: +def mean(x: np.ndarray, axis: int) -> np.ndarray: """Numba implementation of `np.mean(x, axis)`. Parameters ---------- x : np.ndarray - (N,M) array of values + (N, M) array of values axis : int Array axis ID Returns ------- np.ndarray - (N) or (M) array of `mean` values + (N,) or (M,) array of `mean` values """ assert axis == 0 or axis == 1 - mean = np.empty(x.shape[1 - axis], dtype=x.dtype) + mean_values = np.empty(x.shape[1 - axis], dtype=x.dtype) if axis == 0: - for i in range(len(mean)): - mean[i] = np.mean(x[:, i]) + for i in range(x.shape[1]): + mean_values[i] = np.mean(x[:, i]) else: - for i in range(len(mean)): - mean[i] = np.mean(x[i]) + for i, xi in enumerate(x): + mean_values[i] = np.mean(xi) - return mean + return mean_values @nb.njit(cache=True) -def mode(x: nb.int64[:]) -> nb.int64: +def mode(x: np.ndarray) -> int: """Numba implementation of `scipy.stats.mode(x)`. Parameters ---------- x : np.ndarray - (N) array of values + (N,) array of values Returns ------- int - Most-propable value in the array + Most-probable value in the array """ values, counts = unique(x) @@ -159,122 +159,122 @@ def mode(x: nb.int64[:]) -> nb.int64: @nb.njit(cache=True) -def argmin(x: nb.float32[:, :], axis: nb.int32) -> nb.int32[:]: +def argmin(x: np.ndarray, axis: int) -> np.ndarray: """Numba implementation of `np.argmin(x, axis)`. Parameters ---------- x : np.ndarray - (N,M) array of values + (N, M) array of values axis : int Array axis ID Returns ------- np.ndarray - (N) or (M) array of `argmin` values + (N,) or (M,) array of `argmin` values """ assert axis == 0 or axis == 1 - argmin = np.empty(x.shape[1 - axis], dtype=np.int32) + argmin_values = np.empty(x.shape[1 - axis], dtype=np.int32) if axis == 0: - for i in range(len(argmin)): - argmin[i] = np.argmin(x[:, i]) + for i in range(x.shape[1]): + argmin_values[i] = np.argmin(x[:, i]) else: - for i in range(len(argmin)): - argmin[i] = np.argmin(x[i]) + for i, xi in enumerate(x): + argmin_values[i] = np.argmin(xi) - return argmin + return argmin_values @nb.njit(cache=True) -def argmax(x: nb.float32[:, :], axis: nb.int32) -> nb.int32[:]: +def argmax(x: np.ndarray, axis: int) -> np.ndarray: """Numba implementation of `np.argmax(x, axis)`. Parameters ---------- x : np.ndarray - (N,M) array of values + (N, M) array of values axis : int Array axis ID Returns ------- np.ndarray - (N) or (M) array of `argmax` values + (N,) or (M,) array of `argmax` values """ assert axis == 0 or axis == 1 - argmax = np.empty(x.shape[1 - axis], dtype=np.int32) + argmax_values = np.empty(x.shape[1 - axis], dtype=np.int32) if axis == 0: - for i in range(len(argmax)): - argmax[i] = np.argmax(x[:, i]) + for i in range(x.shape[1]): + argmax_values[i] = np.argmax(x[:, i]) else: - for i in range(len(argmax)): - argmax[i] = np.argmax(x[i]) + for i, xi in enumerate(x): + argmax_values[i] = np.argmax(xi) - return argmax + return argmax_values @nb.njit(cache=True) -def amin(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]: +def amin(x: np.ndarray, axis: int) -> np.ndarray: """Numba implementation of `np.amin(x, axis)`. Parameters ---------- x : np.ndarray - (N,M) array of values + (N, M) array of values axis : int Array axis ID Returns ------- np.ndarray - (N) or (M) array of `min` values + (N,) or (M,) array of `min` values """ assert axis == 0 or axis == 1 - xmin = np.empty(x.shape[1 - axis], dtype=np.int32) + xmin = np.empty(x.shape[1 - axis], dtype=x.dtype) if axis == 0: - for i in range(len(xmin)): + for i in range(x.shape[1]): xmin[i] = np.min(x[:, i]) else: - for i in range(len(xmin)): - xmin[i] = np.min(x[i]) + for i, xi in enumerate(x): + xmin[i] = np.min(xi) return xmin @nb.njit(cache=True) -def amax(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]: +def amax(x: np.ndarray, axis: int) -> np.ndarray: """Numba implementation of `np.amax(x, axis)`. Parameters ---------- x : np.ndarray - (N,M) array of values + (N, M) array of values axis : int Array axis ID Returns ------- np.ndarray - (N) or (M) array of `max` values + (N,) or (M,) array of `max` values """ assert axis == 0 or axis == 1 - xmax = np.empty(x.shape[1 - axis], dtype=np.int32) + xmax = np.empty(x.shape[1 - axis], dtype=x.dtype) if axis == 0: - for i in range(len(xmax)): + for i in range(x.shape[1]): xmax[i] = np.max(x[:, i]) else: - for i in range(len(xmax)): - xmax[i] = np.max(x[i]) + for i, xi in enumerate(x): + xmax[i] = np.max(xi) return xmax @nb.njit(cache=True) -def all(x: nb.float32[:, :], axis: nb.int32) -> nb.boolean[:]: +def all(x: np.ndarray, axis: int) -> np.ndarray: """Numba implementation of `np.all(x, axis)`. Parameters @@ -287,37 +287,37 @@ def all(x: nb.float32[:, :], axis: nb.int32) -> nb.boolean[:]: Returns ------- np.ndarray - (N) or (M) array of `all` outputs + (N,) or (M,) array of `all` outputs """ assert axis == 0 or axis == 1 - all = np.empty(x.shape[1 - axis], dtype=np.bool_) + all_values = np.empty(x.shape[1 - axis], dtype=np.bool_) if axis == 0: - for i in range(len(all)): - all[i] = np.all(x[:, i]) + for i in range(x.shape[1]): + all_values[i] = np.all(x[:, i]) else: - for i in range(len(all)): - all[i] = np.all(x[i]) + for i, xi in enumerate(x): + all_values[i] = np.all(xi) - return all + return all_values @nb.njit(cache=True) -def softmax(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:, :]: +def softmax(x: np.ndarray, axis: int) -> np.ndarray: """ Numba implementation of `scipy.special.softmax(x, axis)`. Parameters ---------- x : np.ndarray - (N,M) array of values + (N, M) array of values axis : int Array axis ID Returns ------- np.ndarray - (N,M) Array of softmax scores + (N, M) array of softmax scores """ assert axis == 0 or axis == 1 if axis == 0: @@ -331,15 +331,15 @@ def softmax(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:, :]: @nb.njit(cache=True) -def log_loss(label: nb.boolean[:], pred: nb.float32[:]) -> nb.float32: +def log_loss(label: np.ndarray, pred: np.ndarray) -> float: """Numba implementation of cross-entropy loss. Parameters ---------- label : np.ndarray - (N) array of boolean labels (0 or 1) + (N,) array of boolean labels (0 or 1) pred : np.ndarray - (N) array of float scores (between 0 and 1) + (N,) array of float scores (between 0 and 1) Returns ------- diff --git a/src/spine/math/cluster.py b/src/spine/math/cluster.py index a0eaf15a2..755fdf9ea 100644 --- a/src/spine/math/cluster.py +++ b/src/spine/math/cluster.py @@ -16,7 +16,7 @@ ) -@nb.experimental.jitclass(DBSCAN_DTYPE) +@nb.experimental.jitclass(spec=DBSCAN_DTYPE) # type: ignore[call-arg] class DBSCAN: """Class-version of the Numba-accelerate :func:`dbscan` function. @@ -33,11 +33,11 @@ class DBSCAN: def __init__( self, - eps: nb.float32, - min_samples: nb.int64 = 1, - metric: nb.types.string = "euclidean", - p: nb.int64 = 2.0, - ): + eps: float, + min_samples: int = 1, + metric: str = "euclidean", + p: float = 2.0, + ) -> None: """Initialize the DBSCAN parameters. Parameters @@ -52,6 +52,11 @@ def __init__( p : float, default 2. p-norm factor for the Minkowski metric, if used """ + if eps < 0.0: + raise ValueError("Epsilon must be non-negative.") + if min_samples <= 0: + raise ValueError("Minimum number of samples must be positive.") + # For Euclidean, save time by using squared Euclidean if metric == "euclidean": metric = "sqeuclidean" @@ -80,7 +85,7 @@ def fit_predict(self, x): Returns ------- np.ndarray - (N) Group assignments + (N,) Group assignments """ # Produce a radius graph edge_index = radius_graph(x, self.eps, self.metric_id, self.p) @@ -93,12 +98,12 @@ def fit_predict(self, x): @nb.njit(cache=True) def dbscan( - x: nb.float32[:, :], - eps: nb.float32, - min_samples: nb.int64 = 1, - metric_id: nb.int64 = METRICS["euclidean"], - p: nb.float32 = 2.0, -) -> nb.float32[:]: + x: np.ndarray, + eps: float, + min_samples: int = 1, + metric_id: int = METRICS["euclidean"], + p: float = 2.0, +) -> np.ndarray: """Runs DBSCAN on 3D points and returns the group assignments. Parameters @@ -117,8 +122,13 @@ def dbscan( Returns ------- np.ndarray - (N) Group assignments + (N,) Group assignments """ + if eps < 0.0: + raise ValueError("Epsilon must be non-negative.") + if min_samples <= 0: + raise ValueError("Minimum number of samples must be positive.") + # Produce a radius graph edge_index = radius_graph(x, eps, metric_id, p) diff --git a/src/spine/math/decomposition.py b/src/spine/math/decomposition.py index 4e179c3ce..41661a686 100644 --- a/src/spine/math/decomposition.py +++ b/src/spine/math/decomposition.py @@ -8,7 +8,7 @@ PCA_DTYPE = (("n_components", nb.int64),) -@nb.experimental.jitclass(PCA_DTYPE) +@nb.experimental.jitclass(spec=PCA_DTYPE) # type: ignore[call-arg] class PCA: """Class-version of the Numba-accelerate :func:`principal_components` function. @@ -22,7 +22,7 @@ class PCA: (N_c) Variance along each of the principal axes """ - def __init__(self, n_components: nb.int64): + def __init__(self, n_components: int) -> None: """Initialize the PCA parameters. Parameters @@ -50,6 +50,7 @@ def fit(self, x): (N_c) Variance along each of the principal axes """ # Check input + assert len(x) > 1, "Must provide at least two samples." assert x.shape[1] >= self.n_components, ( f"The dimensionality of the data ({x.shape[1]}) is smaller " f"than the number of components ({self.n_components}." @@ -69,7 +70,7 @@ def fit(self, x): @nb.njit(cache=True) -def principal_components(x: nb.float32[:, :]) -> nb.float32[:, :]: +def principal_components(x: np.ndarray) -> np.ndarray: """Computes the principal components of a point cloud by computing the eigenvectors of the centered covariance matrix. @@ -83,6 +84,8 @@ def principal_components(x: nb.float32[:, :]) -> nb.float32[:, :]: np.ndarray (d, d) List of principal components (row-ordered) """ + assert len(x) > 1, "Must provide at least two samples." + # Get covariance matrix A = np.cov(x.T, ddof=len(x) - 1).astype(x.dtype) # Casting needed... diff --git a/src/spine/math/distance.py b/src/spine/math/distance.py index ffecc5ba4..5e418ebaf 100644 --- a/src/spine/math/distance.py +++ b/src/spine/math/distance.py @@ -1,7 +1,7 @@ """Numba JIT compiled implementation of distance computation routines. This module is entirely dedicated to 3D points, which is the core representation -of objects targetted by this software package. +of objects targeted by this software package. """ import numba as nb @@ -21,18 +21,25 @@ "closest_pair", ] -# Available distance metrics (casting is important for numba optimization) +MINKOWSKI = 0 +CITYBLOCK = 1 +EUCLIDEAN = 2 +SQEUCLIDEAN = 3 +CHEBYSHEV = 4 + +# Available distance metrics. Keep the public mapping for callers, while using +# named integer constants internally so Numba sees stable scalar IDs. METRICS = { - "minkowski": np.int64(0), - "cityblock": np.int64(1), - "euclidean": np.int64(2), - "sqeuclidean": np.int64(3), - "chebyshev": np.int64(4), + "minkowski": MINKOWSKI, + "cityblock": CITYBLOCK, + "euclidean": EUCLIDEAN, + "sqeuclidean": SQEUCLIDEAN, + "chebyshev": CHEBYSHEV, } @nb.njit(cache=True) -def get_metric_id(metric: nb.types.string, p: nb.float32) -> nb.int64: +def get_metric_id(metric: str, p: float) -> int: """Checks on the metric name, returns an enumerated form of the metric. Parameters @@ -49,33 +56,33 @@ def get_metric_id(metric: nb.types.string, p: nb.float32) -> nb.int64: """ if metric == "minkowski": if p == 1.0: - return np.int64(1) + return CITYBLOCK elif p == 2.0: - return np.int64(2) + return EUCLIDEAN else: - return np.int64(0) + return MINKOWSKI elif metric == "cityblock": - return np.int64(1) + return CITYBLOCK elif metric == "euclidean": - return np.int64(2) + return EUCLIDEAN elif metric == "sqeuclidean": - return np.int64(3) + return SQEUCLIDEAN elif metric == "chebyshev": - return np.int64(4) + return CHEBYSHEV else: raise ValueError(f"Distance metric not recognized: {metric}") @nb.njit(cache=True) -def cityblock(x: nb.float32[:], y: nb.float32[:]) -> nb.float32: - """Compute the cityblock distance (L1) between to 3D points. +def cityblock(x: np.ndarray, y: np.ndarray) -> float: + """Compute the cityblock distance (L1) between two 3D points. Parameters ---------- x : np.ndarray - (3) Coorinates of the first point + (3,) Coordinates of the first point y : np.ndarray - (3) Coorinates of the second point + (3,) Coordinates of the second point Returns ------- @@ -86,15 +93,15 @@ def cityblock(x: nb.float32[:], y: nb.float32[:]) -> nb.float32: @nb.njit(cache=True) -def euclidean(x: nb.float32[:], y: nb.float32[:]) -> nb.float32: +def euclidean(x: np.ndarray, y: np.ndarray) -> float: """Compute the Euclidean distance (L2) between two 3D points. Parameters ---------- x : np.ndarray - (3) Coorinates of the first point + (3,) Coordinates of the first point y : np.ndarray - (3) Coorinates of the second point + (3,) Coordinates of the second point Returns ------- @@ -105,15 +112,15 @@ def euclidean(x: nb.float32[:], y: nb.float32[:]) -> nb.float32: @nb.njit(cache=True) -def sqeuclidean(x: nb.float32[:], y: nb.float32[:]) -> nb.float32: +def sqeuclidean(x: np.ndarray, y: np.ndarray) -> float: """Compute the squared Euclidean distance (L2) between two 3D points. Parameters ---------- x : np.ndarray - (3) Coorinates of the first point + (3,) Coordinates of the first point y : np.ndarray - (3) Coorinates of the second point + (3,) Coordinates of the second point Returns ------- @@ -124,15 +131,15 @@ def sqeuclidean(x: nb.float32[:], y: nb.float32[:]) -> nb.float32: @nb.njit(cache=True) -def chebyshev(x: nb.float32[:], y: nb.float32[:]) -> nb.float32: - """Compute the Chebyshev distance (Linf) between to 3D points. +def chebyshev(x: np.ndarray, y: np.ndarray) -> float: + """Compute the Chebyshev distance (Linf) between two 3D points. Parameters ---------- x : np.ndarray - (3) Coorinates of the first point + (3,) Coordinates of the first point y : np.ndarray - (3) Coorinates of the second point + (3,) Coordinates of the second point Returns ------- @@ -143,15 +150,15 @@ def chebyshev(x: nb.float32[:], y: nb.float32[:]) -> nb.float32: @nb.njit(cache=True) -def minkowski(x: nb.float32[:], y: nb.float32[:], p: nb.float32) -> nb.float32: +def minkowski(x: np.ndarray, y: np.ndarray, p: float) -> float: """Compute the Minkowski distance (Lp) between two 3D points. Parameters ---------- x : np.ndarray - (3) Coorinates of the first point + (3,) Coordinates of the first point y : np.ndarray - (3) Coorinates of the second point + (3,) Coordinates of the second point Returns ------- @@ -165,8 +172,8 @@ def minkowski(x: nb.float32[:], y: nb.float32[:], p: nb.float32) -> nb.float32: @nb.njit(cache=True) def pdist( - x: nb.float32[:, :], metric_id: nb.int64 = METRICS["euclidean"], p: nb.float32 = 2.0 -) -> nb.float32[:, :]: + x: np.ndarray, metric_id: int = METRICS["euclidean"], p: float = 2.0 +) -> np.ndarray: """Numba implementation of `scipy.spatial.distance.pdist(x, metric=metric, p=p)` in 3D. @@ -187,83 +194,83 @@ def pdist( # Check on the input assert x.shape[1] == 3, "Only supports 3D points for now." - # Dispatch (faster this way than dipatching at each distance call) - if metric_id == np.int64(0): + # Dispatch (faster this way than dispatching at each distance call) + if metric_id == MINKOWSKI: return _pdist_minkowski(x, p) - elif metric_id == np.int64(1): + elif metric_id == CITYBLOCK: return _pdist_cityblock(x) - elif metric_id == np.int64(2): + elif metric_id == EUCLIDEAN: return _pdist_euclidean(x) - elif metric_id == np.int64(3): + elif metric_id == SQEUCLIDEAN: return _pdist_sqeuclidean(x) - elif metric_id == np.int64(4): + elif metric_id == CHEBYSHEV: return _pdist_chebyshev(x) else: raise ValueError("Distance metric not recognized.") @nb.njit(cache=True) -def _pdist_cityblock(x: nb.float32[:, :]) -> nb.float32[:, :]: +def _pdist_cityblock(x: np.ndarray) -> np.ndarray: res = np.empty((len(x), len(x)), dtype=x.dtype) - for i in range(len(x)): + for i, xi in enumerate(x): res[i, i] = 0.0 for j in range(i + 1, len(x)): - res[i, j] = res[j, i] = cityblock(x[i], x[j]) + res[i, j] = res[j, i] = cityblock(xi, x[j]) return res @nb.njit(cache=True) -def _pdist_euclidean(x: nb.float32[:, :]) -> nb.float32[:, :]: +def _pdist_euclidean(x: np.ndarray) -> np.ndarray: res = np.empty((len(x), len(x)), dtype=x.dtype) - for i in range(len(x)): + for i, xi in enumerate(x): res[i, i] = 0.0 for j in range(i + 1, len(x)): - res[i, j] = res[j, i] = euclidean(x[i], x[j]) + res[i, j] = res[j, i] = euclidean(xi, x[j]) return res @nb.njit(cache=True) -def _pdist_sqeuclidean(x: nb.float32[:, :]) -> nb.float32[:, :]: +def _pdist_sqeuclidean(x: np.ndarray) -> np.ndarray: res = np.empty((len(x), len(x)), dtype=x.dtype) - for i in range(len(x)): + for i, xi in enumerate(x): res[i, i] = 0.0 for j in range(i + 1, len(x)): - res[i, j] = res[j, i] = sqeuclidean(x[i], x[j]) + res[i, j] = res[j, i] = sqeuclidean(xi, x[j]) return res @nb.njit(cache=True) -def _pdist_chebyshev(x: nb.float32[:, :]) -> nb.float32[:, :]: +def _pdist_chebyshev(x: np.ndarray) -> np.ndarray: res = np.empty((len(x), len(x)), dtype=x.dtype) - for i in range(len(x)): + for i, xi in enumerate(x): res[i, i] = 0.0 for j in range(i + 1, len(x)): - res[i, j] = res[j, i] = chebyshev(x[i], x[j]) + res[i, j] = res[j, i] = chebyshev(xi, x[j]) return res @nb.njit(cache=True) -def _pdist_minkowski(x: nb.float32[:, :], p: nb.float32) -> nb.float32[:, :]: +def _pdist_minkowski(x: np.ndarray, p: float) -> np.ndarray: res = np.empty((len(x), len(x)), dtype=x.dtype) - for i in range(len(x)): + for i, xi in enumerate(x): res[i, i] = 0.0 for j in range(i + 1, len(x)): - res[i, j] = res[j, i] = minkowski(x[i], x[j], p) + res[i, j] = res[j, i] = minkowski(xi, x[j], p) return res @nb.njit(cache=True) def cdist( - x1: nb.float32[:, :], - x2: nb.float32[:, :], - metric_id: nb.int64 = METRICS["euclidean"], - p: nb.float32 = 2.0, -) -> nb.float32[:, :]: + x1: np.ndarray, + x2: np.ndarray, + metric_id: int = METRICS["euclidean"], + p: float = 2.0, +) -> np.ndarray: """Numba implementation of Euclidean `scipy.spatial.distance.cdist(x, metric=p=2)` in 3D. @@ -286,80 +293,78 @@ def cdist( # Check on the input assert x1.shape[1] == 3 and x2.shape[1] == 3, "Only supports 3D points for now." - # Dispatch (faster this way than dipatching at each distance call) - if metric_id == np.int64(0): + # Dispatch (faster this way than dispatching at each distance call) + if metric_id == MINKOWSKI: return _cdist_minkowski(x1, x2, p) - elif metric_id == np.int64(1): + elif metric_id == CITYBLOCK: return _cdist_cityblock(x1, x2) - elif metric_id == np.int64(2): + elif metric_id == EUCLIDEAN: return _cdist_euclidean(x1, x2) - elif metric_id == np.int64(3): + elif metric_id == SQEUCLIDEAN: return _cdist_sqeuclidean(x1, x2) - elif metric_id == np.int64(4): + elif metric_id == CHEBYSHEV: return _cdist_chebyshev(x1, x2) else: raise ValueError("Distance metric not recognized.") @nb.njit(cache=True) -def _cdist_cityblock(x1: nb.float32[:, :], x2: nb.float32[:, :]) -> nb.float32[:, :]: +def _cdist_cityblock(x1: np.ndarray, x2: np.ndarray) -> np.ndarray: res = np.empty((len(x1), len(x2)), dtype=x1.dtype) - for i1 in range(len(x1)): - for i2 in range(len(x2)): - res[i1, i2] = cityblock(x1[i1], x2[i2]) + for i1, x1i in enumerate(x1): + for i2, x2i in enumerate(x2): + res[i1, i2] = cityblock(x1i, x2i) return res @nb.njit(cache=True) -def _cdist_euclidean(x1: nb.float32[:, :], x2: nb.float32[:, :]) -> nb.float32[:, :]: +def _cdist_euclidean(x1: np.ndarray, x2: np.ndarray) -> np.ndarray: res = np.empty((len(x1), len(x2)), dtype=x1.dtype) - for i1 in range(len(x1)): - for i2 in range(len(x2)): - res[i1, i2] = euclidean(x1[i1], x2[i2]) + for i1, x1i in enumerate(x1): + for i2, x2i in enumerate(x2): + res[i1, i2] = euclidean(x1i, x2i) return res @nb.njit(cache=True) -def _cdist_sqeuclidean(x1: nb.float32[:, :], x2: nb.float32[:, :]) -> nb.float32[:, :]: +def _cdist_sqeuclidean(x1: np.ndarray, x2: np.ndarray) -> np.ndarray: res = np.empty((len(x1), len(x2)), dtype=x1.dtype) - for i1 in range(len(x1)): - for i2 in range(len(x2)): - res[i1, i2] = sqeuclidean(x1[i1], x2[i2]) + for i1, x1i in enumerate(x1): + for i2, x2i in enumerate(x2): + res[i1, i2] = sqeuclidean(x1i, x2i) return res @nb.njit(cache=True) -def _cdist_chebyshev(x1: nb.float32[:, :], x2: nb.float32[:, :]) -> nb.float32[:, :]: +def _cdist_chebyshev(x1: np.ndarray, x2: np.ndarray) -> np.ndarray: res = np.empty((len(x1), len(x2)), dtype=x1.dtype) - for i1 in range(len(x1)): - for i2 in range(len(x2)): - res[i1, i2] = chebyshev(x1[i1], x2[i2]) + for i1, x1i in enumerate(x1): + for i2, x2i in enumerate(x2): + res[i1, i2] = chebyshev(x1i, x2i) return res @nb.njit(cache=True) -def _cdist_minkowski( - x1: nb.float32[:, :], x2: nb.float32[:, :], p: nb.float32 -) -> nb.float32[:, :]: +def _cdist_minkowski(x1: np.ndarray, x2: np.ndarray, p: float) -> np.ndarray: res = np.empty((len(x1), len(x2)), dtype=x1.dtype) - for i1 in range(len(x1)): - for i2 in range(len(x2)): - res[i1, i2] = minkowski(x1[i1], x2[i2], p) + for i1, x1i in enumerate(x1): + for i2, x2i in enumerate(x2): + res[i1, i2] = minkowski(x1i, x2i, p) return res @nb.njit(cache=True) def farthest_pair( - x: nb.float32[:, :], - iterative: nb.boolean = False, - metric_id: nb.int64 = METRICS["euclidean"], - p: nb.float32 = 2.0, -) -> (nb.int64, nb.int64, nb.float32): + x: np.ndarray, + iterative: bool = False, + metric_id: int = METRICS["euclidean"], + p: float = 2.0, +) -> tuple[int, int, float]: """Algorithm which finds the two points which are farthest from each other in a set, in the Euclidean sense. @@ -390,10 +395,10 @@ def farthest_pair( Distance between the two points """ # To save time, if Euclidean distance is used, use its square - euclidean = False - if metric_id == np.int64(2): - euclidean = True - metric_id = np.int64(3) + is_euclidean = False + if metric_id == EUCLIDEAN: + is_euclidean = True + metric_id = SQEUCLIDEAN # Dispatch if not iterative: @@ -417,29 +422,30 @@ def farthest_pair( while dist > tempdist: tempdist = dist dists = cdist(x[idxs[subidx]][None, :], x, metric_id, p).flatten() - idxs[~subidx] = np.argmax(dists) - dist = dists[idxs[~subidx]] - subidx = ~subidx + other_idx = 1 - subidx + idxs[other_idx] = np.argmax(dists) + dist = dists[idxs[other_idx]] + subidx = other_idx # Unroll index i, j = idxs # If needed, take the square root of the distance - if euclidean: + if is_euclidean: dist = np.sqrt(dist) - return i, j, dist + return int(i), int(j), float(dist) @nb.njit(cache=True) def closest_pair( - x1: nb.float32[:, :], - x2: nb.float32[:, :], - iterative: nb.boolean = False, - seed: nb.boolean = True, - metric_id: nb.int64 = METRICS["euclidean"], - p: nb.float32 = 2.0, -) -> (nb.int64, nb.int64, nb.float32): + x1: np.ndarray, + x2: np.ndarray, + iterative: bool = False, + seed: bool = True, + metric_id: int = METRICS["euclidean"], + p: float = 2.0, +) -> tuple[int, int, float]: """Algorithm which finds the two points which are closest to each other from two separate sets. @@ -452,9 +458,9 @@ def closest_pair( Parameters ---------- x1 : np.ndarray - (Nx3) array of point coordinates in the first set + (N, 3) array of point coordinates in the first set x2 : np.ndarray - (Nx3) array of point coordinates in the second set + (M, 3) array of point coordinates in the second set iterative : bool If `True`, uses an iterative, fast approximation seed : bool @@ -475,10 +481,10 @@ def closest_pair( Distance between the two points """ # To save time, if Euclidean distance is used, use its square - euclidean = False - if metric_id == np.int64(2): - euclidean = True - metric_id = np.int64(3) + is_euclidean = False + if metric_id == EUCLIDEAN: + is_euclidean = True + metric_id = SQEUCLIDEAN # Find the two points in two sets of points that are closest to each other if not iterative: @@ -498,35 +504,38 @@ def closest_pair( idxs, set_id, dist, tempdist = [0, 0], 0, 1e9, 1e9 + 1.0 if seed: # Find the end points of the two sets - for i, x in enumerate(xarr): - seed_idxs = np.array(farthest_pair(xarr[i], True)[:2]) - seed_dists = cdist(xarr[i][seed_idxs], xarr[~i], metric_id, p) + for i, xi in enumerate(xarr): + other_id = 1 - i + seed_idxs = np.array(farthest_pair(xi, True)[:2]) + seed_dists = cdist(xi[seed_idxs], xarr[other_id], metric_id, p) seed_argmins = argmin(seed_dists, axis=1) seed_mins = np.array( [seed_dists[0][seed_argmins[0]], seed_dists[1][seed_argmins[1]]] ) if np.min(seed_mins) < dist: - set_id = ~i - seed_choice = np.argmin(seed_mins) - idxs[int(~set_id)] = seed_idxs[seed_choice] - idxs[int(set_id)] = seed_argmins[seed_choice] - dist = seed_mins[seed_choice] + set_id = other_id + seed_choice = int(np.argmin(seed_mins)) + idxs[i] = int(seed_idxs[seed_choice]) + idxs[set_id] = int(seed_argmins[seed_choice]) + dist = float(seed_mins[seed_choice]) # Find the closest point in the other set, repeat until convergence while dist < tempdist: tempdist = dist + other_id = 1 - set_id dists = cdist( - xarr[set_id][idxs[set_id]][None, :], xarr[~set_id], metric_id, p + xarr[set_id][idxs[set_id]][None, :], xarr[other_id], metric_id, p ).flatten() - idxs[~set_id] = np.argmin(dists) - dist = dists[idxs[~set_id]] - subidx = ~set_id + closest_idx = int(np.argmin(dists)) + idxs[other_id] = closest_idx + dist = float(dists[closest_idx]) + set_id = other_id # Unroll index i, j = idxs # If needed, take the square root of the distance - if euclidean: + if is_euclidean: dist = np.sqrt(dist) - return i, j, dist + return int(i), int(j), float(dist) diff --git a/src/spine/math/graph.py b/src/spine/math/graph.py index f7ff038f3..8b641d739 100644 --- a/src/spine/math/graph.py +++ b/src/spine/math/graph.py @@ -7,7 +7,18 @@ import numba as nb import numpy as np -from .distance import METRICS, cdist, chebyshev, cityblock, minkowski, sqeuclidean +from .distance import ( + CHEBYSHEV, + CITYBLOCK, + EUCLIDEAN, + METRICS, + MINKOWSKI, + SQEUCLIDEAN, + chebyshev, + cityblock, + minkowski, + sqeuclidean, +) CSR_DTYPE = ( ("num_nodes", nb.int64), @@ -16,32 +27,30 @@ ) -@nb.experimental.jitclass(CSR_DTYPE) +@nb.experimental.jitclass(spec=CSR_DTYPE) # type: ignore[call-arg] class CSRGraph: """Numba-enabled compressed Sparse Row (CSR) representation of a sparse matrix. Attributes ---------- neighbors : np.ndarray - (E) List of node neighbors in a compressed array + (E,) List of node neighbors in a compressed array offsets : np.ndarray - (N+1) Per-node slicing boundaries to query each node neighborhood + (N + 1,) Per-node slicing boundaries to query each node neighborhood num_nodes : int Number of nodes in the graph, N """ - def __init__( - self, neighbors: nb.int64[:], offsets: nb.int64[:], num_nodes: nb.int64 - ): + def __init__(self, neighbors: np.ndarray, offsets: np.ndarray, num_nodes: int): """Construct the Compressed Sparse Row (CSR) representation of a sparse matrix based on a list of nodes and edges. Parameters ---------- neighbors : np.ndarray - (E) List of node neighbors in a compressed array + (E,) List of node neighbors in a compressed array offsets : np.ndarray - (N+1) Per-node slicing boundaries to query each node neighborhood + (N + 1,) Per-node slicing boundaries to query each node neighborhood num_nodes : int Number of nodes in the graph, N """ @@ -49,7 +58,7 @@ def __init__( self.offsets = offsets self.num_nodes = num_nodes - def __getitem__(self, node_id: nb.int64): + def __getitem__(self, node_id: int) -> np.ndarray: """Get the list of neighbors associated with a node. Parameters @@ -65,7 +74,7 @@ def __getitem__(self, node_id: nb.int64): start, end = self.offsets[node_id], self.offsets[node_id + 1] return self.neighbors[start:end] - def num_neighbors(self, node_id: nb.int64): + def num_neighbors(self, node_id: int) -> int: """Returns the number of neighbors of a node. Parameters @@ -84,8 +93,8 @@ def num_neighbors(self, node_id: nb.int64): @nb.njit def csr_graph( - edge_index: nb.int64[:, :], num_nodes: nb.int64, directed: nb.boolean = True -) -> CSR_DTYPE: + edge_index: np.ndarray, num_nodes: int, directed: bool = True +) -> CSRGraph: """Construct the Compressed Sparse Row (CSR) representation of a sparse matrix based on a list of nodes and edges. @@ -129,11 +138,11 @@ def csr_graph( @nb.njit(cache=True) def connected_components( - edge_index: nb.int64[:, :], - num_nodes: nb.int64, - min_samples: nb.int64 = 1, - directed: nb.boolean = True, -) -> nb.int64[:]: + edge_index: np.ndarray, + num_nodes: int, + min_samples: int = 1, + directed: bool = True, +) -> np.ndarray: """Find connected components. Parameters @@ -148,7 +157,7 @@ def connected_components( Returns ------- np.ndarray - (N) Cluster label associated with each node + (N,) Cluster label associated with each node """ # Initialize the CSR data structure graph = csr_graph(edge_index, num_nodes, directed) @@ -164,7 +173,7 @@ def connected_components( min_neighbors = min_samples - 1 for node in range(graph.num_nodes): if not visited[node]: - if graph.num_neighbors(node) > min_neighbors: + if graph.num_neighbors(node) >= min_neighbors: # Perform DFS and collect all nodes in this connected component comp_idx[0] = 0 dfs_iterative(graph, visited, node, component, comp_idx) @@ -185,34 +194,34 @@ def connected_components( @nb.njit(cache=True) def dfs( - graph: CSR_DTYPE, - visited: nb.boolean[:], - node: nb.int64, - component: nb.int64[:], - comp_idx: nb.int64[:], -): + graph: CSRGraph, + visited: np.ndarray, + node: int, + component: np.ndarray, + comp_idx: np.ndarray, +) -> None: """Does a depth-first search and builds a connected component. Parameters ---------- graph : CSRGraph CSR representation of a graph - visitied : np.ndarray - (N) Boolean array which specified weather a node has been visited or not. + visited : np.ndarray + (N,) Boolean array which specifies whether a node has been visited. node : int Current node index component : np.ndarray - (N) Current component (padded) + (N,) Current component (padded) comp_idx : np.ndarray Current component index (pointer) Notes ----- This implementation is recursive, which is the fastest implementation but - silently throws segementation faults if the maximum recursion depth is + silently throws segmentation faults if the maximum recursion depth is reached. The :func:`dfs_iterative` function is safer, but slightly slower. """ - # Mark the node as visited, incremant pointer + # Mark the node as visited, increment pointer visited[node] = True component[comp_idx[0]] = node comp_idx[0] += 1 @@ -225,24 +234,24 @@ def dfs( @nb.njit(cache=True) def dfs_iterative( - graph: CSR_DTYPE, - visited: nb.boolean[:], - start_node: nb.int64, - component: nb.int64[:], - comp_idx: nb.int64[:], -): + graph: CSRGraph, + visited: np.ndarray, + start_node: int, + component: np.ndarray, + comp_idx: np.ndarray, +) -> None: """Does a depth-first search and builds a connected component. Parameters ---------- graph : CSRGraph CSR representation of a graph - visitied : np.ndarray - (N) Boolean array which specified weather a node has been visited or not. - node : int - Current node index + visited : np.ndarray + (N,) Boolean array which specifies whether a node has been visited. + start_node : int + Starting node index component : np.ndarray - (N) Current component (padded) + (N,) Current component (padded) comp_idx : np.ndarray Current component index (pointer) @@ -276,11 +285,11 @@ def dfs_iterative( @nb.njit(cache=True) def radius_graph( - x: nb.float32[:, :], - radius: nb.float32, - metric_id: nb.int64 = METRICS["euclidean"], - p: nb.float32 = 2.0, -) -> nb.int64[:, :]: + x: np.ndarray, + radius: float, + metric_id: int = METRICS["euclidean"], + p: float = 2.0, +) -> np.ndarray: """Builds an undirected radius graph. This function generates a list of edges in a graph which connects all nodes @@ -304,31 +313,29 @@ def radius_graph( """ # Determine the distance function to use. If the metric is Euclidean, it # is cheaper to square the radius and use the squared Euclidean metric - if metric_id == np.int64(0): + if metric_id == MINKOWSKI: return _radius_graph_minkowski(x, radius, p) - elif metric_id == np.int64(1): + elif metric_id == CITYBLOCK: return _radius_graph_cityblock(x, radius) - elif metric_id == np.int64(2): + elif metric_id == EUCLIDEAN: radius = radius * radius return _radius_graph_sqeuclidean(x, radius) - elif metric_id == np.int64(3): + elif metric_id == SQEUCLIDEAN: return _radius_graph_sqeuclidean(x, radius) - elif metric_id == np.int64(4): + elif metric_id == CHEBYSHEV: return _radius_graph_chebyshev(x, radius) else: raise ValueError("Distance metric not recognized.") @nb.njit(cache=True) -def _radius_graph_minkowski( - x: nb.float32[:, :], radius: nb.float32, p: nb.float32 -) -> nb.float32[:, :]: +def _radius_graph_minkowski(x: np.ndarray, radius: float, p: float) -> np.ndarray: # Initialize a data structure to hold edges num_nodes = len(x) max_edges = num_nodes * (num_nodes - 1) // 2 edge_index = np.empty((max_edges, 2), dtype=np.int64) - # Loop over pairs of nodes, ass edges if the distance fits the bill + # Loop over pairs of nodes, add edges if the distance fits the bill edge_count = 0 for i in range(num_nodes): for j in range(i + 1, num_nodes): @@ -340,15 +347,13 @@ def _radius_graph_minkowski( @nb.njit(cache=True) -def _radius_graph_cityblock( - x: nb.float32[:, :], radius: nb.float32 -) -> nb.float32[:, :]: +def _radius_graph_cityblock(x: np.ndarray, radius: float) -> np.ndarray: # Initialize a data structure to hold edges num_nodes = len(x) max_edges = num_nodes * (num_nodes - 1) // 2 edge_index = np.empty((max_edges, 2), dtype=np.int64) - # Loop over pairs of nodes, ass edges if the distance fits the bill + # Loop over pairs of nodes, add edges if the distance fits the bill edge_count = 0 for i in range(num_nodes): for j in range(i + 1, num_nodes): @@ -360,15 +365,13 @@ def _radius_graph_cityblock( @nb.njit(cache=True) -def _radius_graph_sqeuclidean( - x: nb.float32[:, :], radius: nb.float32 -) -> nb.float32[:, :]: +def _radius_graph_sqeuclidean(x: np.ndarray, radius: float) -> np.ndarray: # Initialize a data structure to hold edges num_nodes = len(x) max_edges = num_nodes * (num_nodes - 1) // 2 edge_index = np.empty((max_edges, 2), dtype=np.int64) - # Loop over pairs of nodes, ass edges if the distance fits the bill + # Loop over pairs of nodes, add edges if the distance fits the bill edge_count = 0 for i in range(num_nodes): for j in range(i + 1, num_nodes): @@ -380,15 +383,13 @@ def _radius_graph_sqeuclidean( @nb.njit(cache=True) -def _radius_graph_chebyshev( - x: nb.float32[:, :], radius: nb.float32 -) -> nb.float32[:, :]: +def _radius_graph_chebyshev(x: np.ndarray, radius: float) -> np.ndarray: # Initialize a data structure to hold edges num_nodes = len(x) max_edges = num_nodes * (num_nodes - 1) // 2 edge_index = np.empty((max_edges, 2), dtype=np.int64) - # Loop over pairs of nodes, ass edges if the distance fits the bill + # Loop over pairs of nodes, add edges if the distance fits the bill edge_count = 0 for i in range(num_nodes): for j in range(i + 1, num_nodes): @@ -399,10 +400,25 @@ def _radius_graph_chebyshev( return edge_index[:edge_count] +@nb.njit(cache=True) +def _find_root(parents: np.ndarray, node: int) -> int: + """Find the root parent of a node with path compression.""" + root = node + while parents[root] != root: + root = parents[root] + + while parents[node] != node: + parent = parents[node] + parents[node] = root + node = parent + + return root + + @nb.njit(cache=True) def union_find( - edge_index: nb.int64[:, :], count: nb.int64, return_inverse: bool = True -) -> nb.int64[:]: + edge_index: np.ndarray, count: int, return_inverse: bool = True +) -> tuple[np.ndarray, dict[int, np.ndarray]]: """Numba implementation of the Union-Find algorithm. This function assigns a group to each node in a graph, provided @@ -420,18 +436,29 @@ def union_find( Returns ------- np.ndarray - (C) Group assignments for each of the nodes in the graph + (C,) Group assignments for each of the nodes in the graph Dict[int, np.ndarray] Dictionary which maps groups to indexes """ - labels = np.arange(count) - groups = {i: np.array([i]) for i in labels} - for e in edge_index: - li, lj = labels[e[0]], labels[e[1]] - if li != lj: - labels[groups[lj]] = li - groups[li] = np.concatenate((groups[li], groups[lj])) - del groups[lj] + if count == 0: + labels = np.empty(0, dtype=np.int64) + groups = {0: labels} + del groups[0] + return labels, groups + + parents = np.arange(count) + for src, dst in edge_index: + src_root = _find_root(parents, int(src)) + dst_root = _find_root(parents, int(dst)) + if src_root != dst_root: + if src_root < dst_root: + parents[dst_root] = src_root + else: + parents[src_root] = dst_root + + labels = np.empty(count, dtype=np.int64) + for node in range(count): + labels[node] = _find_root(parents, node) if return_inverse: mask = np.zeros(count, dtype=np.bool_) @@ -440,4 +467,13 @@ def union_find( mapping[mask] = np.arange(np.sum(mask)) labels = mapping[labels] + groups = {labels[0]: np.array([0])} + for node in range(1, count): + label = labels[node] + node_arr = np.array([node]) + if label in groups: + groups[label] = np.concatenate((groups[label], node_arr)) + else: + groups[label] = node_arr + return labels, groups diff --git a/src/spine/math/linalg.py b/src/spine/math/linalg.py index 4059b9b76..04456f883 100644 --- a/src/spine/math/linalg.py +++ b/src/spine/math/linalg.py @@ -7,7 +7,7 @@ @nb.njit(cache=True) -def norm(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]: +def norm(x: np.ndarray, axis: int) -> np.ndarray: """Compute vector norms along specified axis. This is a Numba-compiled implementation of `np.linalg.norm(x, axis=axis)` @@ -44,19 +44,17 @@ def norm(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]: assert axis == 0 or axis == 1 xnorm = np.empty(x.shape[1 - axis], dtype=x.dtype) if axis == 0: - for i in range(len(xnorm)): + for i in range(x.shape[1]): xnorm[i] = np.linalg.norm(x[:, i]) else: - for i in range(len(xnorm)): - xnorm[i] = np.linalg.norm(x[i]) + for i, xi in enumerate(x): + xnorm[i] = np.linalg.norm(xi) return xnorm @nb.njit(cache=True) -def submatrix( - x: nb.float32[:, :], index1: nb.int32[:], index2: nb.int32[:] -) -> nb.float32[:, :]: +def submatrix(x: np.ndarray, index1: np.ndarray, index2: np.ndarray) -> np.ndarray: """Extract submatrix using row and column indices. This function creates a submatrix by selecting specific rows and columns @@ -104,8 +102,8 @@ def submatrix( @nb.njit(cache=True) def contingency_table( - x: nb.int32[:], y: nb.int32[:], nx: nb.int32 = None, ny: nb.int32 = None -) -> nb.int64[:, :]: + x: np.ndarray, y: np.ndarray, nx: int | None = None, ny: int | None = None +) -> np.ndarray: """Build a contingency table for two sets of labels. A contingency table (also known as a cross-tabulation or crosstab) shows @@ -156,14 +154,22 @@ def contingency_table( [0, 2], [0, 0]]) """ + if len(x) != len(y): + raise ValueError("Label arrays must have the same length.") + # If not provided, assume that the max label is the max of the label array - if not nx: - nx = np.max(x) + 1 if len(x) > 0 else 1 - if not ny: - ny = np.max(y) + 1 if len(y) > 0 else 1 + if nx is None: + nx_val = np.max(x) + 1 if len(x) > 0 else 1 + else: + nx_val = nx + + if ny is None: + ny_val = np.max(y) + 1 if len(y) > 0 else 1 + else: + ny_val = ny # Bin the table - table = np.zeros((nx, ny), dtype=np.int64) + table = np.zeros((nx_val, ny_val), dtype=np.int64) for i, j in zip(x, y): table[i, j] += 1 diff --git a/src/spine/math/metrics.py b/src/spine/math/metrics.py index 2b85928bb..d146bd5fb 100644 --- a/src/spine/math/metrics.py +++ b/src/spine/math/metrics.py @@ -85,6 +85,9 @@ def adjusted_rand_score(labels_true, labels_pred): >>> adjusted_rand_score(labels_true, labels_pred) # doctest: +ELLIPSIS 0.0 """ + if len(labels_true) != len(labels_pred): + raise ValueError("Labels must have the same length") + # Get dimensions for contingency table nx = labels_true.max() + 1 if len(labels_true) > 0 else 1 ny = labels_pred.max() + 1 if len(labels_pred) > 0 else 1 diff --git a/src/spine/math/neighbors.py b/src/spine/math/neighbors.py index 3e97a33a8..98fd43c4b 100644 --- a/src/spine/math/neighbors.py +++ b/src/spine/math/neighbors.py @@ -9,7 +9,7 @@ import numpy as np from .base import mode -from .distance import METRICS, cdist, get_metric_id +from .distance import cdist, get_metric_id __all__ = ["RadiusNeighborsClassifier", "KNeighborsClassifier"] @@ -25,7 +25,7 @@ KNC_DTYPE = (("k", nb.int64), ("metric_id", nb.int64), ("p", nb.float32)) -@nb.experimental.jitclass(RNC_DTYPE) +@nb.experimental.jitclass(spec=RNC_DTYPE) # type: ignore[call-arg] class RadiusNeighborsClassifier: """Class which assigns labels to points based on radial neighborhood majority vote. @@ -54,11 +54,11 @@ class RadiusNeighborsClassifier: def __init__( self, - radius: nb.float32, - metric: nb.types.string = "euclidean", - p: nb.float32 = 2.0, - iterate: nb.boolean = True, - ): + radius: float, + metric: str = "euclidean", + p: float = 2.0, + iterate: bool = True, + ) -> None: """Initialize the RadiusNeighborsClassifier parameters. Parameters @@ -72,6 +72,9 @@ def __init__( iterate : bool, default True Whether to recurse the search until no new labels are assigned """ + if radius < 0.0: + raise ValueError("Radius must be non-negative.") + # For Euclidean, save time by using squared Euclidean if metric == "euclidean": metric = "sqeuclidean" @@ -83,7 +86,9 @@ def __init__( self.p = p self.iterate = iterate - def fit_predict(self, X: nb.float32[:, :], y: nb.float32[:], Xq: nb.float32[:, :]): + def fit_predict( + self, X: np.ndarray, y: np.ndarray, Xq: np.ndarray + ) -> tuple[np.ndarray, np.ndarray]: """Assign labels to a set of points given a set of reference points. Parameters @@ -91,16 +96,16 @@ def fit_predict(self, X: nb.float32[:, :], y: nb.float32[:], Xq: nb.float32[:, : X : np.ndarray (N, 3) Set of reference points y : np.ndarray - (N) Labels of reference points + (N,) Labels of reference points Xq : nb.ndarray (M, 3) Set of query points Returns ------- np.ndarray - (M) Labels assigned to the query points + (M,) Labels assigned to the query points np.ndarray - Index of points which have not been sucessfully assigned + Index of points which have not been successfully assigned """ # Loop over query points until no new labels can be assigned num_query = len(Xq) @@ -150,7 +155,7 @@ def fit_predict(self, X: nb.float32[:, :], y: nb.float32[:], Xq: nb.float32[:, : return labels, orphan_index -@nb.experimental.jitclass(KNC_DTYPE) +@nb.experimental.jitclass(spec=KNC_DTYPE) # type: ignore[call-arg] class KNeighborsClassifier: """Class which assigns labels to points based on a nearest neighbor majority vote. @@ -175,9 +180,7 @@ class KNeighborsClassifier: p-norm factor for the Minkowski metric, if used """ - def __init__( - self, k: nb.int64, metric: nb.types.string = "euclidean", p: nb.float32 = 2.0 - ): + def __init__(self, k: int, metric: str = "euclidean", p: float = 2.0) -> None: """Initialize the RadiusNeighborsClassifier parameters. Parameters @@ -189,6 +192,9 @@ def __init__( p : float, default 2. p-norm factor for the Minkowski metric, if used """ + if k <= 0: + raise ValueError("Number of neighbors must be positive.") + # For Euclidean, save time by using squared Euclidean if metric == "euclidean": metric = "sqeuclidean" @@ -198,7 +204,9 @@ def __init__( self.metric_id = get_metric_id(metric, p) self.p = p - def fit_predict(self, X: nb.float32[:, :], y: nb.float32[:], Xq: nb.float32[:, :]): + def fit_predict( + self, X: np.ndarray, y: np.ndarray, Xq: np.ndarray + ) -> tuple[np.ndarray, np.ndarray]: """Assign labels to a set of points given a set of reference points. Parameters @@ -206,16 +214,16 @@ def fit_predict(self, X: nb.float32[:, :], y: nb.float32[:], Xq: nb.float32[:, : X : np.ndarray (N, 3) Set of reference points y : np.ndarray - (N) Labels of reference points + (N,) Labels of reference points Xq : nb.ndarray (M, 3) Set of query points Returns ------- np.ndarray - (M) Labels assigned to the query points + (M,) Labels assigned to the query points np.ndarray - Index of points which have not been sucessfully assigned + Index of points which have not been successfully assigned """ # If there are no labeled points provided, nothing to do if len(X) == 0: @@ -227,7 +235,7 @@ def fit_predict(self, X: nb.float32[:, :], y: nb.float32[:], Xq: nb.float32[:, : # Start by computing the distance between the query and reference dists = cdist(Xq, X, metric_id=self.metric_id, p=self.p) - # Loop over query poins + # Loop over query points labels = np.empty(len(Xq), dtype=np.int64) for i in range(len(Xq)): # Find the list k closest labels diff --git a/src/spine/post/reco/cathode_cross.py b/src/spine/post/reco/cathode_cross.py index d0afe587b..257e349e3 100644 --- a/src/spine/post/reco/cathode_cross.py +++ b/src/spine/post/reco/cathode_cross.py @@ -392,7 +392,7 @@ def get_cathode_offset(self, particle): # Get the end points of the track segment index = self.geo.get_volume_index(self.get_sources(particle), module, tpc) points = self.get_points(particle)[index] - idx0, idx1, _ = farthest_pair(points, "recursive") + idx0, idx1, _ = farthest_pair(points, iterative=True) end_points = points[[idx0, idx1]] # Find the point closest to the cathode diff --git a/src/spine/vis/network.py b/src/spine/vis/network.py index 16d4965e8..b59d84a8c 100644 --- a/src/spine/vis/network.py +++ b/src/spine/vis/network.py @@ -96,7 +96,7 @@ def network_topology( # For scatter and hull, join closest point to closest point for i, j in edge_index: vi, vj = points[clusts[i]], points[clusts[j]] - i1, i2, _ = closest_pair(vi, vj, "recursive") + i1, i2, _ = closest_pair(vi, vj, iterative=True) edge_vertices.extend([vi[i1], vj[i2], [None, None, None]]) else: diff --git a/test/test_math/test_base_regression.py b/test/test_math/test_base_regression.py new file mode 100644 index 000000000..04cda936d --- /dev/null +++ b/test/test_math/test_base_regression.py @@ -0,0 +1,66 @@ +"""Regression tests for base math helpers.""" + +import numpy as np +import pytest + +from spine.math.base import ( + all, + amax, + amin, + argmax, + log_loss, + mean, + softmax, + sum, + unique, +) + + +def test_min_max_preserve_float_dtype_and_values(): + """Min/max reductions should not truncate floating point values.""" + x = np.array([[1.25, -2.5], [3.75, 4.5]], dtype=np.float32) + + mins = amin(x, 0) + maxs = amax(x, 1) + + assert mins.dtype == x.dtype + assert maxs.dtype == x.dtype + np.testing.assert_allclose(mins, [1.25, -2.5]) + np.testing.assert_allclose(maxs, [1.25, 4.5]) + + +def test_axis_one_reductions_match_numpy(): + """Axis-one reduction branches should match numpy.""" + x = np.array([[1.25, -2.5], [3.75, 4.5]], dtype=np.float32) + + np.testing.assert_allclose(sum(x, 1), np.sum(x, axis=1)) + np.testing.assert_allclose(mean(x, 1), np.mean(x, axis=1)) + np.testing.assert_array_equal(argmax(x, 1), np.argmax(x, axis=1)) + np.testing.assert_allclose(amin(x, 1), np.min(x, axis=1)) + + +def test_axis_zero_all_matches_numpy(): + """Axis-zero all branch should match numpy.""" + x = np.array([[True, True], [True, False]]) + + np.testing.assert_array_equal(all(x, 0), np.all(x, axis=0)) + + +def test_axis_reductions_cover_axis_one_and_invalid_axis(): + """Base reductions should support axis 1 and reject other axes.""" + x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + + np.testing.assert_allclose(softmax(x, 1).sum(axis=1), [1.0, 1.0], atol=1e-6) + np.testing.assert_array_equal(all(x > 0.0, 1), [True, True]) + + with pytest.raises(AssertionError): + softmax(x, 2) + + +def test_unique_empty_and_log_loss_empty_inputs(): + """Helpers should handle empty arrays.""" + values, counts = unique(np.empty(0, dtype=np.int64)) + + assert len(values) == 0 + assert len(counts) == 0 + assert log_loss(np.empty(0, dtype=np.bool_), np.empty(0, dtype=np.float32)) == 0.0 diff --git a/test/test_math/test_cluster.py b/test/test_math/test_cluster.py index 50de17393..665b60332 100644 --- a/test/test_math/test_cluster.py +++ b/test/test_math/test_cluster.py @@ -76,10 +76,10 @@ def test_dbscan_class(self): [[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [5.0, 5.0, 5.0]], dtype=np.float32 ) - # Test fit method if available - if hasattr(clusterer, "fit"): - labels = clusterer.fit(points) - assert len(labels) == len(points) + labels = clusterer.fit_predict(points) + assert len(labels) == len(points) + assert labels[0] == labels[1] + assert labels[2] != labels[0] except (ImportError, TypeError, AttributeError): pytest.skip("DBSCAN class not available") @@ -118,6 +118,28 @@ def test_dbscan_parameters(self): except (ImportError, TypeError, AttributeError): pytest.skip("DBSCAN parameters test not available") + def test_dbscan_rejects_invalid_parameters(self): + """DBSCAN should reject invalid density parameters.""" + try: + from spine.math.cluster import DBSCAN, dbscan + + points = np.zeros((1, 3), dtype=np.float32) + + with pytest.raises(ValueError, match="non-negative"): + DBSCAN(eps=-1.0) + + with pytest.raises(ValueError, match="positive"): + DBSCAN(eps=1.0, min_samples=0) + + with pytest.raises(ValueError, match="non-negative"): + dbscan(points, eps=-1.0) + + with pytest.raises(ValueError, match="positive"): + dbscan(points, eps=1.0, min_samples=0) + + except (ImportError, TypeError, AttributeError): + pytest.skip("DBSCAN invalid parameter test not available") + def test_dbscan_noise_detection(self): """Test DBSCAN noise detection.""" try: diff --git a/test/test_math/test_decomposition.py b/test/test_math/test_decomposition.py new file mode 100644 index 000000000..0e2610de7 --- /dev/null +++ b/test/test_math/test_decomposition.py @@ -0,0 +1,56 @@ +"""Tests for decomposition helpers.""" + +import numpy as np +import pytest + +from spine.math.decomposition import PCA, principal_components + + +def test_principal_components_are_orthonormal(): + """Principal component vectors should form an orthonormal basis.""" + x = np.array( + [[-2.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]], + dtype=np.float32, + ) + + components = principal_components(x) + + np.testing.assert_allclose(components @ components.T, np.eye(3), atol=1e-6) + np.testing.assert_allclose(np.abs(components[0]), [1.0, 0.0, 0.0], atol=1e-6) + + +def test_pca_fit_returns_requested_components(): + """PCA jitclass should return the requested number of components.""" + x = np.array( + [[-2.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]], + dtype=np.float32, + ) + pca = PCA(2) + + components, variance = pca.fit(x) + + assert components.shape == (2, 3) + assert variance.shape == (2,) + np.testing.assert_allclose(np.abs(components[0]), [1.0, 0.0, 0.0], atol=1e-6) + assert variance[0] > variance[1] + + +def test_pca_rejects_invalid_component_count(): + """PCA should reject empty or overcomplete component requests.""" + with pytest.raises(AssertionError, match="one component"): + PCA(0) + + pca = PCA(4) + with pytest.raises(AssertionError, match="dimensionality"): + pca.fit(np.ones((3, 3), dtype=np.float32)) + + +def test_pca_rejects_undersampled_inputs(): + """PCA requires at least two samples to produce meaningful variance.""" + x = np.ones((1, 3), dtype=np.float32) + + with pytest.raises(AssertionError, match="two samples"): + principal_components(x) + + with pytest.raises(AssertionError, match="two samples"): + PCA(1).fit(x) diff --git a/test/test_math/test_distance.py b/test/test_math/test_distance.py index 9ad43255b..e1ab6b1c6 100644 --- a/test/test_math/test_distance.py +++ b/test/test_math/test_distance.py @@ -231,7 +231,7 @@ def test_farthest_pair(self): dtype=np.float32, ) - i, j, distance = farthest_pair(points, METRICS["euclidean"]) + i, j, distance = farthest_pair(points, metric_id=METRICS["euclidean"]) # Should find the farthest pair (points 0 and 2) assert set([i, j]) == {0, 2} diff --git a/test/test_math/test_distance_regression.py b/test/test_math/test_distance_regression.py new file mode 100644 index 000000000..c864471dc --- /dev/null +++ b/test/test_math/test_distance_regression.py @@ -0,0 +1,100 @@ +"""Regression tests for distance helpers.""" + +import numpy as np +import pytest + +from spine.math.distance import ( + METRICS, + cdist, + closest_pair, + farthest_pair, + get_metric_id, + pdist, +) + +POINTS = np.array( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]], + dtype=np.float32, +) + + +def test_get_metric_id_dispatches_minkowski_aliases_and_errors(): + """Metric dispatch should special-case p=1 and p=2 Minkowski.""" + assert get_metric_id("minkowski", 1.0) == METRICS["cityblock"] + assert get_metric_id("minkowski", 2.0) == METRICS["euclidean"] + assert get_metric_id("minkowski", 3.0) == METRICS["minkowski"] + assert get_metric_id("cityblock", 2.0) == METRICS["cityblock"] + assert get_metric_id("euclidean", 2.0) == METRICS["euclidean"] + assert get_metric_id("sqeuclidean", 2.0) == METRICS["sqeuclidean"] + assert get_metric_id("chebyshev", 2.0) == METRICS["chebyshev"] + + with pytest.raises(ValueError, match="not recognized"): + get_metric_id("bad", 2.0) + + +def test_pdist_dispatches_all_metrics_and_errors(): + """Pairwise distance matrix should support every metric enumerator.""" + expected = { + METRICS["minkowski"]: [ + [0.0, 1.0, 2.0], + [1.0, 0.0, 2.0800838], + [2.0, 2.0800838, 0.0], + ], + METRICS["cityblock"]: [[0.0, 1.0, 2.0], [1.0, 0.0, 3.0], [2.0, 3.0, 0.0]], + METRICS["euclidean"]: [ + [0.0, 1.0, 2.0], + [1.0, 0.0, np.sqrt(5.0)], + [2.0, np.sqrt(5.0), 0.0], + ], + METRICS["sqeuclidean"]: [[0.0, 1.0, 4.0], [1.0, 0.0, 5.0], [4.0, 5.0, 0.0]], + METRICS["chebyshev"]: [[0.0, 1.0, 2.0], [1.0, 0.0, 2.0], [2.0, 2.0, 0.0]], + } + + for metric, matrix in expected.items(): + np.testing.assert_allclose(pdist(POINTS, metric, p=3.0), matrix, atol=1e-5) + + with pytest.raises(ValueError, match="Distance metric"): + pdist(POINTS, np.int64(99)) + + +def test_cdist_dispatches_all_metrics_and_errors(): + """Cross-distance matrix should support every metric enumerator.""" + x1 = POINTS[:2] + x2 = POINTS[1:] + + for metric in ( + METRICS["minkowski"], + METRICS["cityblock"], + METRICS["euclidean"], + METRICS["sqeuclidean"], + METRICS["chebyshev"], + ): + distances = cdist(x1, x2, metric, p=3.0) + assert distances.shape == (2, 2) + assert distances[1, 0] == 0.0 + + with pytest.raises(ValueError, match="Distance metric"): + cdist(x1, x2, np.int64(99)) + + +def test_pair_helpers_support_iterative_paths(): + """Closest/farthest pair helpers should cover iterative variants.""" + i, j, dist = farthest_pair(POINTS, iterative=True) + assert {i, j} == {1, 2} + assert np.isclose(dist, np.sqrt(5.0)) + + x2 = np.array([[10.0, 0.0, 0.0], [0.1, 0.0, 0.0]], dtype=np.float32) + i, j, dist = closest_pair(POINTS, x2, iterative=True, seed=False) + assert (i, j) == (0, 1) + assert np.isclose(dist, 0.1) + + i, j, dist = closest_pair(POINTS, x2, iterative=True, seed=True) + assert np.isclose(dist, 0.1) + + +def test_farthest_pair_brute_with_non_euclidean_metric(): + """Brute farthest pair should support non-Euclidean metric branches.""" + i, j, dist = farthest_pair(POINTS, iterative=False, metric_id=METRICS["cityblock"]) + + assert {i, j} == {1, 2} + assert np.isclose(dist, 3.0) diff --git a/test/test_math/test_graph.py b/test/test_math/test_graph.py new file mode 100644 index 000000000..b728542a3 --- /dev/null +++ b/test/test_math/test_graph.py @@ -0,0 +1,124 @@ +"""Tests for graph helpers.""" + +import numpy as np +import pytest + +from spine.math.distance import METRICS +from spine.math.graph import ( + connected_components, + csr_graph, + dfs, + dfs_iterative, + radius_graph, + union_find, +) + + +def sorted_edges(edge_index): + """Return lexicographically sorted edge tuples.""" + return sorted(map(tuple, np.asarray(edge_index))) + + +def test_csr_graph_directed_and_undirected_neighbors(): + """CSR graph should expose directed and undirected neighborhoods.""" + edges = np.array([[0, 1], [1, 2]], dtype=np.int64) + + directed = csr_graph(edges, 3, directed=True) + np.testing.assert_array_equal(directed[0], [1]) + np.testing.assert_array_equal(directed[2], []) + assert directed.num_neighbors(1) == 1 + + undirected = csr_graph(edges, 3, directed=False) + np.testing.assert_array_equal(np.sort(undirected[1]), [0, 2]) + assert undirected.num_neighbors(1) == 2 + + +def test_connected_components_and_dfs_variants(): + """Connected-component helpers should traverse equivalent components.""" + edges = np.array([[0, 1], [1, 2], [3, 4]], dtype=np.int64) + + labels = connected_components(edges, 6, directed=False) + np.testing.assert_array_equal(labels, [0, 0, 0, 1, 1, 2]) + + graph = csr_graph(edges, 6, directed=False) + for search in (dfs, dfs_iterative): + visited = np.zeros(6, dtype=np.bool_) + component = np.empty(6, dtype=np.int64) + comp_idx = np.zeros(1, dtype=np.int64) + search(graph, visited, 0, component, comp_idx) + np.testing.assert_array_equal(np.sort(component[: comp_idx[0]]), [0, 1, 2]) + + +def test_connected_components_respects_min_samples(): + """Nodes below the neighbor threshold should not expand components.""" + edges = np.array([[0, 1], [1, 2]], dtype=np.int64) + + labels = connected_components(edges, 3, min_samples=4, directed=False) + + np.testing.assert_array_equal(labels, [0, 1, 2]) + + +def test_radius_graph_supports_all_metrics(): + """Radius graph should dispatch all supported distance metrics.""" + points = np.array( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]], + dtype=np.float32, + ) + + for metric in ( + METRICS["minkowski"], + METRICS["cityblock"], + METRICS["euclidean"], + METRICS["sqeuclidean"], + METRICS["chebyshev"], + ): + edges = radius_graph(points, 1.1, metric_id=metric, p=2.0) + assert sorted_edges(edges) == [(0, 1)] + + with pytest.raises(ValueError, match="Distance metric"): + radius_graph(points, 1.0, metric_id=np.int64(99)) + + +def test_union_find_returns_labels_and_groups(): + """Union-find should merge connected nodes and optionally keep raw labels.""" + edges = np.array([[0, 1], [2, 3]], dtype=np.int64) + + labels, groups = union_find(edges, 5) + np.testing.assert_array_equal(labels, [0, 0, 1, 1, 2]) + np.testing.assert_array_equal(np.sort(groups[0]), [0, 1]) + np.testing.assert_array_equal(np.sort(groups[1]), [2, 3]) + np.testing.assert_array_equal(np.sort(groups[2]), [4]) + + raw_labels, _ = union_find(edges, 5, return_inverse=False) + np.testing.assert_array_equal(raw_labels, [0, 0, 2, 2, 4]) + + +def test_union_find_group_keys_match_returned_labels(): + """Group dictionary keys should use the same label space as labels.""" + edges = np.array([[1, 2], [0, 1]], dtype=np.int64) + + labels, groups = union_find(edges, 5) + + np.testing.assert_array_equal(labels, [0, 0, 0, 1, 2]) + assert set(groups.keys()) == set(labels) + np.testing.assert_array_equal(np.sort(groups[0]), [0, 1, 2]) + np.testing.assert_array_equal(groups[1], [3]) + np.testing.assert_array_equal(groups[2], [4]) + + +def test_union_find_merges_into_lower_root(): + """Union-find should use a stable low-root representative.""" + edges = np.array([[2, 1]], dtype=np.int64) + + labels, groups = union_find(edges, 3, return_inverse=False) + + np.testing.assert_array_equal(labels, [0, 1, 1]) + np.testing.assert_array_equal(groups[1], [1, 2]) + + +def test_union_find_handles_empty_graph(): + """Union-find should handle a graph with no nodes.""" + labels, groups = union_find(np.empty((0, 2), dtype=np.int64), 0) + + np.testing.assert_array_equal(labels, []) + assert len(groups) == 0 diff --git a/test/test_math/test_linalg.py b/test/test_math/test_linalg.py new file mode 100644 index 000000000..53871d522 --- /dev/null +++ b/test/test_math/test_linalg.py @@ -0,0 +1,57 @@ +"""Tests for linear algebra helpers.""" + +import numpy as np +import pytest + +from spine.math.linalg import contingency_table, norm, submatrix + + +def test_norm_matches_numpy_by_axis(): + """Norm should match numpy along both supported axes.""" + x = np.array([[3.0, 4.0], [0.0, 5.0]], dtype=np.float32) + + np.testing.assert_allclose(norm(x, 0), np.linalg.norm(x, axis=0)) + np.testing.assert_allclose(norm(x, 1), np.linalg.norm(x, axis=1)) + + +def test_submatrix_extracts_row_column_product(): + """Submatrix should select the Cartesian product of row/column indexes.""" + x = np.arange(12, dtype=np.float32).reshape(3, 4) + rows = np.array([2, 0], dtype=np.int32) + cols = np.array([3, 1], dtype=np.int32) + + np.testing.assert_array_equal(submatrix(x, rows, cols), x[np.ix_(rows, cols)]) + + +def test_contingency_table_infers_and_accepts_shape(): + """Contingency table should infer dimensions or use explicit dimensions.""" + x = np.array([0, 0, 1, 2], dtype=np.int32) + y = np.array([1, 1, 0, 1], dtype=np.int32) + + np.testing.assert_array_equal( + contingency_table(x, y), + [[0, 2], [1, 0], [0, 1]], + ) + np.testing.assert_array_equal( + contingency_table(x, y, nx=4, ny=3), + [[0, 2, 0], [1, 0, 0], [0, 1, 0], [0, 0, 0]], + ) + + +def test_contingency_table_handles_empty_inputs(): + """Empty label inputs should produce a one-cell empty table.""" + table = contingency_table( + np.empty(0, dtype=np.int32), + np.empty(0, dtype=np.int32), + ) + + np.testing.assert_array_equal(table, [[0]]) + + +def test_contingency_table_rejects_length_mismatch(): + """Label arrays must describe the same samples.""" + with pytest.raises(ValueError, match="same length"): + contingency_table( + np.array([0, 1], dtype=np.int32), + np.array([0], dtype=np.int32), + ) diff --git a/test/test_math/test_metrics.py b/test/test_math/test_metrics.py new file mode 100644 index 000000000..b9bf9811e --- /dev/null +++ b/test/test_math/test_metrics.py @@ -0,0 +1,77 @@ +"""Tests for clustering metrics.""" + +import numpy as np +import pytest + +from spine.math.metrics import _entropy, adjusted_mutual_info_score, adjusted_rand_score + + +def test_adjusted_rand_score_handles_perfect_random_and_empty_cases(): + """ARI should cover perfect, random-like and degenerate inputs.""" + perfect = np.array([0, 0, 1, 1], dtype=np.int32) + crossed = np.array([0, 1, 0, 1], dtype=np.int32) + one_cluster = np.zeros(4, dtype=np.int32) + + assert adjusted_rand_score(perfect, perfect) == 1.0 + assert adjusted_rand_score(crossed, perfect) <= 0.0 + assert adjusted_rand_score(one_cluster, one_cluster) == 1.0 + assert ( + adjusted_rand_score(np.empty(0, dtype=np.int32), np.empty(0, dtype=np.int32)) + == 1.0 + ) + + +def test_adjusted_mutual_info_score_handles_common_cases(): + """AMI should cover perfect, one-cluster and empty inputs.""" + perfect = np.array([0, 0, 1, 1], dtype=np.int32) + crossed = np.array([0, 1, 0, 1], dtype=np.int32) + one_cluster = np.zeros(4, dtype=np.int32) + + assert adjusted_mutual_info_score(perfect, perfect) == 1.0 + assert adjusted_mutual_info_score(one_cluster, one_cluster) == 1.0 + assert adjusted_mutual_info_score(perfect, one_cluster) == 0.0 + assert ( + adjusted_mutual_info_score( + np.array([0], dtype=np.int32), + np.array([0], dtype=np.int32), + ) + == 1.0 + ) + assert adjusted_mutual_info_score(crossed, perfect) <= 1.0 + assert ( + adjusted_mutual_info_score( + np.array([0, 1], dtype=np.int32), + np.array([0, 1], dtype=np.int32), + ) + == 1.0 + ) + assert ( + adjusted_mutual_info_score( + np.empty(0, dtype=np.int32), + np.empty(0, dtype=np.int32), + ) + == 1.0 + ) + + +def test_entropy_handles_singleton_input(): + """Private entropy helper should handle singleton labels.""" + assert _entropy(np.array([0], dtype=np.int32)) == 0.0 + + +def test_adjusted_mutual_info_score_rejects_length_mismatch(): + """AMI inputs must have matching lengths.""" + with pytest.raises(ValueError, match="same length"): + adjusted_mutual_info_score( + np.array([0, 1], dtype=np.int32), + np.array([0], dtype=np.int32), + ) + + +def test_adjusted_rand_score_rejects_length_mismatch(): + """ARI inputs must have matching lengths.""" + with pytest.raises(ValueError, match="same length"): + adjusted_rand_score( + np.array([0, 1], dtype=np.int32), + np.array([0], dtype=np.int32), + ) diff --git a/test/test_math/test_neighbors.py b/test/test_math/test_neighbors.py new file mode 100644 index 000000000..69b34adf3 --- /dev/null +++ b/test/test_math/test_neighbors.py @@ -0,0 +1,85 @@ +"""Tests for neighbor classifiers.""" + +import numpy as np +import pytest + +from spine.math.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier + + +def test_radius_neighbors_assigns_and_reports_orphans(): + """Radius classifier should assign nearby labels and report unassigned queries.""" + x = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]], dtype=np.float32) + y = np.array([1, 2], dtype=np.int64) + xq = np.array([[0.2, 0.0, 0.0], [50.0, 0.0, 0.0]], dtype=np.float32) + + clf = RadiusNeighborsClassifier(radius=0.5, iterate=False) + labels, orphan_index = clf.fit_predict(x, y, xq) + + np.testing.assert_array_equal(labels, [1, -1]) + np.testing.assert_array_equal(orphan_index, [1]) + + +def test_radius_neighbors_iterates_over_new_labels(): + """Iterative radius classifier should use newly assigned labels.""" + x = np.array([[0.0, 0.0, 0.0]], dtype=np.float32) + y = np.array([7], dtype=np.int64) + xq = np.array([[0.4, 0.0, 0.0], [0.8, 0.0, 0.0]], dtype=np.float32) + + clf = RadiusNeighborsClassifier(radius=0.5, iterate=True) + labels, orphan_index = clf.fit_predict(x, y, xq) + + np.testing.assert_array_equal(labels, [7, 7]) + np.testing.assert_array_equal(orphan_index, []) + + +def test_radius_neighbors_stops_when_no_assignments_change(): + """Radius classifier should stop when every query remains orphaned.""" + x = np.array([[0.0, 0.0, 0.0]], dtype=np.float32) + y = np.array([1], dtype=np.int64) + xq = np.array([[10.0, 0.0, 0.0], [20.0, 0.0, 0.0]], dtype=np.float32) + + clf = RadiusNeighborsClassifier(radius=0.5, iterate=True) + labels, orphan_index = clf.fit_predict(x, y, xq) + + np.testing.assert_array_equal(labels, [-1, -1]) + np.testing.assert_array_equal(orphan_index, [0, 1]) + + +def test_k_neighbors_assigns_mode_labels(): + """KNN classifier should assign majority labels.""" + x = np.array( + [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [5.0, 0.0, 0.0]], + dtype=np.float32, + ) + y = np.array([1, 1, 2], dtype=np.int64) + xq = np.array([[0.05, 0.0, 0.0]], dtype=np.float32) + + clf = KNeighborsClassifier(k=2) + labels, orphan_index = clf.fit_predict(x, y, xq) + + np.testing.assert_array_equal(labels, [1]) + np.testing.assert_array_equal(orphan_index, []) + + +def test_k_neighbors_handles_empty_reference_set(): + """KNN classifier should mark every query orphaned without references.""" + xq = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32) + clf = KNeighborsClassifier(k=1) + + labels, orphan_index = clf.fit_predict( + np.empty((0, 3), dtype=np.float32), + np.empty(0, dtype=np.int64), + xq, + ) + + np.testing.assert_array_equal(labels, [-1, -1]) + np.testing.assert_array_equal(orphan_index, [0, 1]) + + +def test_neighbor_classifiers_reject_invalid_configuration(): + """Invalid neighborhood parameters should fail at construction.""" + with pytest.raises(ValueError, match="non-negative"): + RadiusNeighborsClassifier(radius=-1.0) + + with pytest.raises(ValueError, match="positive"): + KNeighborsClassifier(k=0)