2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -27,7 +27,7 @@ jobs:
         # We aim to support the versions on pytorch.org
         # as well as selected previous versions on
         # https://pytorch.org/get-started/previous-versions/
-        torch-version: ["2.6.0", "2.9.1"]
+        torch-version: ["2.6.0", "2.10.0"]
         sklearn-version: ["latest"]
         include:
           # windows test with standard config
3 changes: 2 additions & 1 deletion cebra/integrations/deeplabcut.py
@@ -169,7 +169,8 @@ def load_data(self, pcutoff: float = 0.6) -> npt.NDArray:
         pred_xy = []
         for i, _ in enumerate(self.dlc_df.index):
             data = (self.dlc_df.iloc[i].loc[self.scorer].loc[
-                self.keypoints_list].to_numpy().reshape(-1, len(dlc_df_coords)))
+                self.keypoints_list].to_numpy().copy().reshape(
+                    -1, len(dlc_df_coords)))

             # Handles nan values with interpolation
             if i > 0 and i < len(self.dlc_df) - 1:
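For context on the `.copy()` added above: `torch.from_numpy` warns (it does not fail) when handed a non-writable array, which can happen when pandas' `.to_numpy()` returns a read-only view, e.g. under copy-on-write. A minimal sketch of that behavior, not CEBRA code; the array `arr` and its shape are illustrative:

```python
import warnings

import numpy as np
import torch

# Sketch (not CEBRA code): torch.from_numpy() warns when the source array's
# WRITEABLE flag is off, e.g. for read-only views that pandas' .to_numpy()
# can hand back under copy-on-write.
arr = np.zeros((4, 3), dtype=np.float32)
arr.setflags(write=False)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = torch.from_numpy(arr)  # "The given NumPy array is not writable, ..."
assert any("not writable" in str(w.message) for w in caught)

_ = torch.from_numpy(arr.copy())  # the copy owns writable memory: no warning
```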
2 changes: 2 additions & 0 deletions cebra/integrations/sklearn/cebra.py
@@ -1253,6 +1253,8 @@ def transform(self,

         X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))

+        X = cebra_sklearn_dataset._ensure_writable(X)
+
         if isinstance(X, np.ndarray):
             X = torch.from_numpy(X)

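A hedged usage sketch of the path this guards: memory-mapped arrays opened read-only are a common source of non-writable inputs, and with `_ensure_writable` in place both `fit` and `transform` warn and copy instead of feeding a read-only buffer to `torch.from_numpy`. The file name, array shape, and hyperparameters below are illustrative, not taken from the PR:

```python
import numpy as np

import cebra

# Illustrative only: arrays opened with mmap_mode="r" are non-writable.
np.save("neural.npy", np.random.randn(500, 30).astype("float32"))
X = np.load("neural.npy", mmap_mode="r")
assert not X.flags.writeable

model = cebra.CEBRA(max_iterations=5, batch_size=128, device="cpu")
model.fit(X)                    # warns via _ensure_writable, then copies
embedding = model.transform(X)  # same path as the change above
assert embedding.shape[0] == X.shape[0]
```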
27 changes: 27 additions & 0 deletions cebra/integrations/sklearn/dataset.py
@@ -21,6 +21,8 @@
 #
 """Datasets to be used as part of the sklearn framework."""

+import traceback
+import warnings
 from typing import Iterable, Optional

 import numpy as np
@@ -34,6 +36,28 @@
 import cebra.solver


+def _ensure_writable(array: npt.NDArray) -> npt.NDArray:
+    if not array.flags.writeable:
+        stack = traceback.extract_stack()[-5:-1]
+        stack_str = ''.join(traceback.format_list(stack[-4:]))
+
+        warnings.warn(
+            ("You passed a non-writable NumPy array to CEBRA. PyTorch currently does "
+             "not support non-writable tensors. As a result, CEBRA needs to copy the "
+             "contents of the array, which might yield unnecessary memory overhead. "
+             "Ideally, adapt the code such that the array you pass to CEBRA is writable "
+             "to make your code memory efficient. "
+             "You can find more context and the rationale for this fix here: "
+             "https://github.com/AdaptiveMotorControlLab/CEBRA/pull/289."
+             "\n\n"
+             "Trace:\n" + stack_str),
+            UserWarning,
+            stacklevel=2,
+        )
+        array = array.copy()
+    return array
+
+
 class SklearnDataset(cebra.data.SingleSessionDataset):
     """Dataset for wrapping array-like input/index pairs.

@@ -110,6 +134,7 @@ def _parse_data(self, X: npt.NDArray):
         # one sample is a conservative default here to ensure that sklearn tests
         # passes with the correct error messages.
         X = cebra_sklearn_utils.check_input_array(X, min_samples=2)
+        X = _ensure_writable(X)
         self.neural = torch.from_numpy(X).float().to(self.device)

     def _parse_labels(self, labels: Optional[tuple]):
@@ -143,6 +168,8 @@ def _parse_labels(self, labels: Optional[tuple]):
                     f"or lists that can be converted to arrays, but got {type(y)}"
                 )

+            y = _ensure_writable(y)
+
             # Define the index as either continuous or discrete indices, depending
             # on the dtype in the index array.
             if cebra.helper._is_floating(y):
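The helper's contract, summarized as a small sketch (the helper is a private function, and the variable names here are illustrative): a writable array passes through untouched, a non-writable one triggers the `UserWarning` above and comes back as a writable copy.

```python
import warnings

import numpy as np

from cebra.integrations.sklearn.dataset import _ensure_writable

a = np.zeros((10, 3))
assert _ensure_writable(a) is a  # writable input is returned as-is, no copy

a.setflags(write=False)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    b = _ensure_writable(a)      # warns, then copies
assert any("non-writable" in str(w.message) for w in caught)
assert b is not a and b.flags.writeable
```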
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -39,6 +39,10 @@ markers = [
     "cuda",
 ]
 addopts = "--ignore=cebra/integrations/threejs --ignore=cebra/integrations/streamlit.py --ignore=cebra/datasets"
+# NOTE(stes): See https://github.com/AdaptiveMotorControlLab/CEBRA/pull/289.
+filterwarnings = [
+    "error:The given NumPy array is not writable.*PyTorch does not support non-writable tensors:UserWarning",
+]



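The `filterwarnings` entry promotes PyTorch's non-writable-array `UserWarning` to an error during the test run, so any remaining code path that hands a read-only array directly to `torch.from_numpy` fails in CI rather than passing with a warning. A sketch of the same escalation outside pytest; the filter string mirrors the entry above and assumes a PyTorch version that still emits this warning:

```python
import warnings

import numpy as np
import torch

# Same effect as the pytest filter, applied in plain Python: promote the
# specific PyTorch warning to an error so offending call sites fail loudly.
warnings.filterwarnings(
    "error",
    message="The given NumPy array is not writable",
    category=UserWarning,
)

arr = np.zeros((2, 2))
arr.setflags(write=False)
try:
    torch.from_numpy(arr)  # now raises instead of merely warning
except UserWarning as exc:
    print("escalated:", exc)
```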
17 changes: 17 additions & 0 deletions tests/test_sklearn.py
@@ -1544,3 +1544,20 @@ def test_last_incomplete_batch_smaller_than_offset():
     model.fit(train.neural, train.continuous)

     _ = model.transform(train.neural, batch_size=300)
+
+
+def test_non_writable_array():
+    X = np.random.randn(100, 10)
+    y = np.random.randn(100, 2)
+    X.setflags(write=False)
+    y.setflags(write=False)
+    with pytest.raises(ValueError, match="assignment destination is read-only"):
+        X[:] = 0
+    with pytest.raises(ValueError, match="assignment destination is read-only"):
+        y[:] = 0
+
+    cebra_model = cebra.CEBRA(max_iterations=2, batch_size=32, device="cpu")
+
+    cebra_model.fit(X, y)
+    embedding = cebra_model.transform(X)
+    assert isinstance(embedding, np.ndarray)
+    assert embedding.shape[0] == X.shape[0]