Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ jobs:
python -m pip install --root-user-action=ignore -e . pytest pytest-cov coverage[toml]
mkdir -p /tmp/spine-sitecustomize
printf "import h5py\n" > /tmp/spine-sitecustomize/sitecustomize.py
PYTHONPATH="/tmp/spine-sitecustomize:${PYTHONPATH:-}" pytest test \
NUMBA_DISABLE_JIT=1 PYTHONPATH="/tmp/spine-sitecustomize:${PYTHONPATH:-}" pytest test \
-v \
--cov=spine \
--cov-report=xml:coverage.xml \
Expand Down
1 change: 1 addition & 0 deletions check_coverage.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ docker run --rm \
-w /workspace \
-e TEST_PATH="$test_path" \
-e COVERAGE_TARGET="$coverage_target" \
-e NUMBA_DISABLE_JIT=1 \
--platform linux/amd64 \
ghcr.io/deeplearnphysics/spine:latest \
bash -lc '
Expand Down
4 changes: 2 additions & 2 deletions src/spine/math/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Module with fast, Numba-accelerated, compiles math routines.
"""Fast, Numba-accelerated math routines.

This includes multiple submodules:
- `base.py` includes basic functions, as found in numpy or scipy.special
- `linalg.py` includes linear algebra routines, as found in numpy.linalg
- `distance.py` includes distance functions, as found in scipy.distance
- `graph.py` includes graph routines, as found in scipy.csgraph
- `cluster.py` includes cluster functions, as found in skleran.cluster
- `cluster.py` includes cluster functions, as found in sklearn.cluster
- `metrics.py` includes clustering evaluation metrics
"""

Expand Down
142 changes: 71 additions & 71 deletions src/spine/math/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,35 @@


@nb.njit(cache=True)
def seed(seed: nb.int64) -> None:
def seed(seed_value: int) -> None:
"""Sets the numpy random seed for all Numba jitted functions.

Note that setting the seed using `np.random.seed` outside a Numba jitted
function does *not* set the seed of Numba functions.

Parameters
----------
seed : int
seed_value : int
Random number generator seed
"""
np.random.seed(seed)
np.random.seed(seed_value)


@nb.njit(cache=True)
def unique(x: nb.int64[:]) -> (nb.int64[:], nb.int64[:]):
def unique(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Numba implementation of `np.unique(x, return_counts=True)`.

Parameters
----------
x : np.ndarray
(N) array of values
(N,) array of values

Returns
-------
np.ndarray
(U) array of unique values
(U,) array of unique values
np.ndarray
(U) array of counts of each unique value in the original array
(U,) array of counts of each unique value in the original array
"""
# Nothing to do if the input is empty
uniques = np.empty(len(x), dtype=x.dtype)
Expand Down Expand Up @@ -84,197 +84,197 @@ def unique(x: nb.int64[:]) -> (nb.int64[:], nb.int64[:]):


@nb.njit(cache=True)
def sum(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]:
def sum(x: np.ndarray, axis: int) -> np.ndarray:
"""Numba implementation of `np.sum(x, axis)`.

Parameters
----------
x : np.ndarray
(N,M) array of values
(N, M) array of values
axis : int
Array axis ID

Returns
-------
np.ndarray
(N) or (M) array of `sum` values
(N,) or (M,) array of `sum` values
"""
assert axis == 0 or axis == 1
summ = np.empty(x.shape[1 - axis], dtype=x.dtype)
if axis == 0:
for i in range(len(summ)):
for i in range(x.shape[1]):
summ[i] = np.sum(x[:, i])
else:
for i in range(len(summ)):
summ[i] = np.sum(x[i])
for i, xi in enumerate(x):
summ[i] = np.sum(xi)

return summ


@nb.njit(cache=True)
def mean(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]:
def mean(x: np.ndarray, axis: int) -> np.ndarray:
"""Numba implementation of `np.mean(x, axis)`.

Parameters
----------
x : np.ndarray
(N,M) array of values
(N, M) array of values
axis : int
Array axis ID

Returns
-------
np.ndarray
(N) or (M) array of `mean` values
(N,) or (M,) array of `mean` values
"""
assert axis == 0 or axis == 1
mean = np.empty(x.shape[1 - axis], dtype=x.dtype)
mean_values = np.empty(x.shape[1 - axis], dtype=x.dtype)
if axis == 0:
for i in range(len(mean)):
mean[i] = np.mean(x[:, i])
for i in range(x.shape[1]):
mean_values[i] = np.mean(x[:, i])
else:
for i in range(len(mean)):
mean[i] = np.mean(x[i])
for i, xi in enumerate(x):
mean_values[i] = np.mean(xi)

return mean
return mean_values


@nb.njit(cache=True)
def mode(x: nb.int64[:]) -> nb.int64:
def mode(x: np.ndarray) -> int:
"""Numba implementation of `scipy.stats.mode(x)`.

Parameters
----------
x : np.ndarray
(N) array of values
(N,) array of values

Returns
-------
int
Most-propable value in the array
Most-probable value in the array
"""
values, counts = unique(x)

return values[np.argmax(counts)]


@nb.njit(cache=True)
def argmin(x: nb.float32[:, :], axis: nb.int32) -> nb.int32[:]:
def argmin(x: np.ndarray, axis: int) -> np.ndarray:
"""Numba implementation of `np.argmin(x, axis)`.

Parameters
----------
x : np.ndarray
(N,M) array of values
(N, M) array of values
axis : int
Array axis ID

Returns
-------
np.ndarray
(N) or (M) array of `argmin` values
(N,) or (M,) array of `argmin` values
"""
assert axis == 0 or axis == 1
argmin = np.empty(x.shape[1 - axis], dtype=np.int32)
argmin_values = np.empty(x.shape[1 - axis], dtype=np.int32)
if axis == 0:
for i in range(len(argmin)):
argmin[i] = np.argmin(x[:, i])
for i in range(x.shape[1]):
argmin_values[i] = np.argmin(x[:, i])
else:
for i in range(len(argmin)):
argmin[i] = np.argmin(x[i])
for i, xi in enumerate(x):
argmin_values[i] = np.argmin(xi)

return argmin
return argmin_values


@nb.njit(cache=True)
def argmax(x: nb.float32[:, :], axis: nb.int32) -> nb.int32[:]:
def argmax(x: np.ndarray, axis: int) -> np.ndarray:
"""Numba implementation of `np.argmax(x, axis)`.

Parameters
----------
x : np.ndarray
(N,M) array of values
(N, M) array of values
axis : int
Array axis ID

Returns
-------
np.ndarray
(N) or (M) array of `argmax` values
(N,) or (M,) array of `argmax` values
"""
assert axis == 0 or axis == 1
argmax = np.empty(x.shape[1 - axis], dtype=np.int32)
argmax_values = np.empty(x.shape[1 - axis], dtype=np.int32)
if axis == 0:
for i in range(len(argmax)):
argmax[i] = np.argmax(x[:, i])
for i in range(x.shape[1]):
argmax_values[i] = np.argmax(x[:, i])

else:
for i in range(len(argmax)):
argmax[i] = np.argmax(x[i])
for i, xi in enumerate(x):
argmax_values[i] = np.argmax(xi)

return argmax
return argmax_values


@nb.njit(cache=True)
def amin(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]:
def amin(x: np.ndarray, axis: int) -> np.ndarray:
"""Numba implementation of `np.amin(x, axis)`.

Parameters
----------
x : np.ndarray
(N,M) array of values
(N, M) array of values
axis : int
Array axis ID

Returns
-------
np.ndarray
(N) or (M) array of `min` values
(N,) or (M,) array of `min` values
"""
assert axis == 0 or axis == 1
xmin = np.empty(x.shape[1 - axis], dtype=np.int32)
xmin = np.empty(x.shape[1 - axis], dtype=x.dtype)
if axis == 0:
for i in range(len(xmin)):
for i in range(x.shape[1]):
xmin[i] = np.min(x[:, i])

else:
for i in range(len(xmin)):
xmin[i] = np.min(x[i])
for i, xi in enumerate(x):
xmin[i] = np.min(xi)

return xmin


@nb.njit(cache=True)
def amax(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:]:
def amax(x: np.ndarray, axis: int) -> np.ndarray:
"""Numba implementation of `np.amax(x, axis)`.

Parameters
----------
x : np.ndarray
(N,M) array of values
(N, M) array of values
axis : int
Array axis ID

Returns
-------
np.ndarray
(N) or (M) array of `max` values
(N,) or (M,) array of `max` values
"""
assert axis == 0 or axis == 1
xmax = np.empty(x.shape[1 - axis], dtype=np.int32)
xmax = np.empty(x.shape[1 - axis], dtype=x.dtype)
if axis == 0:
for i in range(len(xmax)):
for i in range(x.shape[1]):
xmax[i] = np.max(x[:, i])

else:
for i in range(len(xmax)):
xmax[i] = np.max(x[i])
for i, xi in enumerate(x):
xmax[i] = np.max(xi)

return xmax


@nb.njit(cache=True)
def all(x: nb.float32[:, :], axis: nb.int32) -> nb.boolean[:]:
def all(x: np.ndarray, axis: int) -> np.ndarray:
"""Numba implementation of `np.all(x, axis)`.

Parameters
Expand All @@ -287,37 +287,37 @@ def all(x: nb.float32[:, :], axis: nb.int32) -> nb.boolean[:]:
Returns
-------
np.ndarray
(N) or (M) array of `all` outputs
(N,) or (M,) array of `all` outputs
"""
assert axis == 0 or axis == 1
all = np.empty(x.shape[1 - axis], dtype=np.bool_)
all_values = np.empty(x.shape[1 - axis], dtype=np.bool_)
if axis == 0:
for i in range(len(all)):
all[i] = np.all(x[:, i])
for i in range(x.shape[1]):
all_values[i] = np.all(x[:, i])

else:
for i in range(len(all)):
all[i] = np.all(x[i])
for i, xi in enumerate(x):
all_values[i] = np.all(xi)

return all
return all_values


@nb.njit(cache=True)
def softmax(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:, :]:
def softmax(x: np.ndarray, axis: int) -> np.ndarray:
"""
Numba implementation of `scipy.special.softmax(x, axis)`.

Parameters
----------
x : np.ndarray
(N,M) array of values
(N, M) array of values
axis : int
Array axis ID

Returns
-------
np.ndarray
(N,M) Array of softmax scores
(N, M) array of softmax scores
"""
assert axis == 0 or axis == 1
if axis == 0:
Expand All @@ -331,15 +331,15 @@ def softmax(x: nb.float32[:, :], axis: nb.int32) -> nb.float32[:, :]:


@nb.njit(cache=True)
def log_loss(label: nb.boolean[:], pred: nb.float32[:]) -> nb.float32:
def log_loss(label: np.ndarray, pred: np.ndarray) -> float:
"""Numba implementation of cross-entropy loss.

Parameters
----------
label : np.ndarray
(N) array of boolean labels (0 or 1)
(N,) array of boolean labels (0 or 1)
pred : np.ndarray
(N) array of float scores (between 0 and 1)
(N,) array of float scores (between 0 and 1)

Returns
-------
Expand Down
Loading
Loading