Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 59 additions & 77 deletions lance_ray/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,33 @@ def read_lance(
)


def _open_lance_version(
uri: str,
storage_options: Optional[dict[str, Any]],
base_store_params_kwargs: dict[str, Any],
) -> Optional[int]:
"""Return the current version of the dataset at ``uri``, or None if it cannot be opened."""
try:
dataset = LanceDataset(
uri, storage_options=storage_options, **base_store_params_kwargs
)
return dataset.version
except Exception:
return None


def _refresh_version(
uri: str,
storage_options: Optional[dict[str, Any]],
base_store_params_kwargs: dict[str, Any],
*,
fallback: Optional[int],
) -> Optional[int]:
"""Re-open ``uri`` to read its version, keeping ``fallback`` if it cannot be opened."""
version = _open_lance_version(uri, storage_options, base_store_params_kwargs)
return version if version is not None else fallback


def write_lance(
ds: Dataset,
uri: Optional[str] = None,
Expand Down Expand Up @@ -274,8 +301,6 @@ def write_lance(
return

# Streaming path: commit one fragment per batch to minimize memory usage.
import lance

if (namespace_impl is not None or namespace_properties is not None) and table_id:
raise ValueError(
"Streaming write with 'namespace_impl' + 'table_id' is not supported; "
Expand All @@ -288,23 +313,14 @@ def write_lance(
)

dest_uri: str = uri
dest_exists = False
dest_version: Optional[int] = None
base_store_params_kwargs = {}
if base_store_params:
base_store_params_kwargs = {"base_store_params": base_store_params}

try:
_dest = lance.LanceDataset(
dest_uri,
storage_options=storage_options,
**base_store_params_kwargs,
)
dest_exists = True
dest_version = _dest.version
except Exception:
dest_exists = False
dest_version = None
dest_version = _open_lance_version(
dest_uri, storage_options, base_store_params_kwargs
)
dest_exists = dest_version is not None

# Enforce mode semantics.
if mode == "create" and dest_exists:
Expand All @@ -315,6 +331,9 @@ def write_lance(
from .fragment import LanceFragmentWriter

effective_batch_size = batch_size if batch_size is not None else 1024
overwrite_initial_bases = (
materialize_initial_bases(initial_bases) if mode == "create" else None
)

rows_seen = 0
first_commit_done = False
Expand Down Expand Up @@ -370,92 +389,55 @@ def write_lance(

# Commit after each batch.
if not first_commit_done:
# First commit: respect mode.
if mode in ("create", "overwrite") or not dest_exists:
op = LanceOperation.Overwrite(
schema_obj,
fragments,
initial_bases=(
materialize_initial_bases(initial_bases)
if mode == "create"
else None
),
)
if mode == "append" and dest_exists:
# First commit appends onto the existing dataset.
LanceDataset.commit(
dest_uri,
op,
read_version=None,
LanceOperation.Append(fragments),
read_version=dest_version,
storage_options=storage_options,
**base_store_params_kwargs,
)
first_commit_done = True
dest_exists = True
try:
_dest = lance.LanceDataset(
dest_uri,
storage_options=storage_options,
**base_store_params_kwargs,
)
dest_version = _dest.version
except Exception:
dest_version = None
elif mode == "append":
op = LanceOperation.Append(fragments)
LanceDataset.commit(
dest_version = _refresh_version(
dest_uri,
op,
read_version=dest_version,
storage_options=storage_options,
**base_store_params_kwargs,
storage_options,
base_store_params_kwargs,
fallback=dest_version,
)
first_commit_done = True
try:
_dest = lance.LanceDataset(
dest_uri,
storage_options=storage_options,
**base_store_params_kwargs,
)
dest_version = _dest.version
except Exception:
pass
else:
# Fallback: overwrite.
op = LanceOperation.Overwrite(
schema_obj,
fragments,
initial_bases=(
materialize_initial_bases(initial_bases)
if mode == "create"
else None
),
)
# create / overwrite, or a destination that does not yet exist.
LanceDataset.commit(
dest_uri,
op,
LanceOperation.Overwrite(
schema_obj,
fragments,
initial_bases=overwrite_initial_bases,
),
read_version=None,
storage_options=storage_options,
**base_store_params_kwargs,
)
first_commit_done = True
dest_exists = True
dest_version = _open_lance_version(
dest_uri, storage_options, base_store_params_kwargs
)
else:
# Subsequent commits always append.
op = LanceOperation.Append(fragments)
LanceDataset.commit(
dest_uri,
op,
LanceOperation.Append(fragments),
read_version=dest_version,
storage_options=storage_options,
**base_store_params_kwargs,
)
try:
_dest = lance.LanceDataset(
dest_uri,
storage_options=storage_options,
**base_store_params_kwargs,
)
dest_version = _dest.version
except Exception:
pass
dest_version = _refresh_version(
dest_uri,
storage_options,
base_store_params_kwargs,
fallback=dest_version,
)

rows_seen += tbl.num_rows

Expand Down
15 changes: 15 additions & 0 deletions tests/test_basic_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ def test_write_lance_invalid_input(self, temp_dir):
with pytest.raises((ValueError, AttributeError, TypeError)):
lr.write_lance(None, str(path)) # type: ignore

def test_write_lance_stream_create_then_append(self, temp_dir):
"""Streaming create followed by streaming append over multiple batches."""
path = Path(temp_dir) / "stream_append.lance"

ds1 = ray.data.from_pandas(pd.DataFrame({"id": [1, 2, 3]}))
lr.write_lance(ds1, str(path), mode="create", stream=True, batch_size=2)
assert lance.dataset(str(path)).count_rows() == 3

ds2 = ray.data.from_pandas(pd.DataFrame({"id": [4, 5, 6]}))
lr.write_lance(ds2, str(path), mode="append", stream=True, batch_size=2)

result = lance.dataset(str(path))
assert result.count_rows() == 6
assert sorted(result.to_table()["id"].to_pylist()) == [1, 2, 3, 4, 5, 6]

def test_write_with_pandas_map_batches(self, temp_dir):
def map_fn(row):
return {
Expand Down
Loading