Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 90 additions & 19 deletions snowflake_utils/models/table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from collections import defaultdict
from functools import partial
from typing import ClassVar

from pydantic import BaseModel, Field
from snowflake.connector.cursor import SnowflakeCursor
Expand All @@ -26,6 +27,12 @@ class Table(BaseModel):
_file_format: FileFormat | None = None
_stage: str | None = None

# Max column tag actions per batched ALTER TABLE statement. Chunking keeps
# each statement well under Snowflake's ~1 MB statement-size limit even for
# very wide tables, while collapsing dozens of per-column statements into a
# handful of multi-action ones.
_MAX_TAG_ACTIONS_PER_STATEMENT: ClassVar[int] = 100

@property
def file_format(self) -> str:
if self._file_format:
Expand Down Expand Up @@ -531,12 +538,25 @@ def current_table_tags(self, cursor: SnowflakeCursor) -> dict[str, str]:
def sync_tags_table(self, cursor: SnowflakeCursor) -> None:
    """Reconcile the table-level tags with ``self.table_structure.tags``.

    Desired tag names are casefolded and compared with the names returned by
    ``current_table_tags``. Tags present on the table but not desired are
    removed with one ``ALTER TABLE ... UNSET TAG`` statement; desired tags
    that are missing are added with one ``ALTER TABLE ... SET TAG``
    statement. A tag whose name already exists is left untouched. Nothing
    is executed when there is nothing to change.

    Args:
        cursor: Cursor used both to read the current tags and to run the
            ALTER statements.
    """
    current_tags = self.current_table_tags(cursor=cursor)
    desired_tags = {k.casefold(): v for k, v in self.table_structure.tags.items()}

    to_unset = [name for name in current_tags if name not in desired_tags]
    to_set = {
        name: value for name, value in desired_tags.items() if name not in current_tags
    }

    # Each change is applied exactly once, via the batched statements below;
    # UNSET runs before SET, as before.
    if to_unset:
        targets = ", ".join(governance_settings.fqn(name) for name in to_unset)
        cursor.execute(f"ALTER TABLE {self.fqn} UNSET TAG {targets}")

    if to_set:
        assignments = []
        for name, value in to_set.items():
            # The value lands inside a single-quoted SQL literal: escape
            # backslashes first, then double any single quote, so the value
            # cannot terminate the literal early.
            literal = str(value).replace("\\", "\\\\").replace("'", "''")
            assignments.append(f"{governance_settings.fqn(name)} = '{literal}'")
        cursor.execute(f"ALTER TABLE {self.fqn} SET TAG {', '.join(assignments)}")

def _unset_table_tag(self, cursor, tag_name):
cursor.execute(
Expand Down Expand Up @@ -572,26 +592,77 @@ def sync_tags_columns(self, cursor: SnowflakeCursor) -> None:
].tags.items()
}

for tag in existing_tags:
if tag not in desired_tags:
column, tag_name, _value = existing_tags[tag]
self._unset_column_tag(cursor, column, tag_name)
to_unset = [
(existing_tags[tag][0], existing_tags[tag][1])
for tag in existing_tags
if tag not in desired_tags
]
to_set = [desired_tags[tag] for tag in desired_tags if tag not in existing_tags]

self._unset_column_tags_batch(cursor, to_unset)
self._set_column_tags_batch(cursor, to_set)

def _set_column_tags_batch(
    self,
    cursor: SnowflakeCursor,
    changes: list[tuple[str, str, str]],
) -> None:
    """Apply many column SET TAG actions in as few ALTER TABLE statements as possible.

    Each change becomes one ``"<COLUMN>" SET TAG <tag fqn> = '<value>'``
    action; the actions are handed to ``_execute_column_tag_actions`` for
    execution. Column names are upper-cased and double-quoted.

    Args:
        cursor: Cursor the statements are executed on.
        changes: List of ``(column, tag_name, tag_value)`` tuples.
    """
    actions = []
    for column, tag_name, tag_value in changes:
        # The value is embedded in a single-quoted SQL literal: escape
        # backslashes first, then double single quotes, so a value such as
        # "o'brien" yields a valid statement instead of a broken one.
        literal = str(tag_value).replace("\\", "\\\\").replace("'", "''")
        quoted_column = f'"{column.upper()}"'
        actions.append(
            f"{quoted_column} SET TAG {governance_settings.fqn(tag_name)} = '{literal}'"
        )
    self._execute_column_tag_actions(cursor, actions)

def _unset_column_tags_batch(
    self,
    cursor: SnowflakeCursor,
    changes: list[tuple[str, str]],
) -> None:
    """Remove column tags, issuing one ``ALTER TABLE ... MODIFY COLUMN`` per column.

    ``changes`` holds ``(column, tag_name)`` pairs. Pairs are grouped by the
    upper-cased column name, and every tag of a column is listed in that
    column's single ``UNSET TAG`` clause. Unlike SET TAG, the UNSET actions
    are not merged across columns, because Snowflake rejects a multi-column
    UNSET TAG.
    """
    # Group the fully qualified tag names per column, keeping first-seen order.
    grouped: dict[str, list[str]] = defaultdict(list)
    for column, tag_name in changes:
        grouped[column.upper()].append(governance_settings.fqn(tag_name))

    for column_name, tag_fqns in grouped.items():
        tag_list = ", ".join(tag_fqns)
        cursor.execute(
            f'ALTER TABLE {self.fqn} MODIFY COLUMN "{column_name}" '
            f"UNSET TAG {tag_list}"
        )

def _execute_column_tag_actions(
    self, cursor: SnowflakeCursor, actions: list[str]
) -> None:
    """Emit batched ``ALTER TABLE ... MODIFY COLUMN <action>, <action>, ...`` statements.

    Used for SET TAG actions, which Snowflake allows to span multiple columns
    in one statement. No statement is executed when ``actions`` is empty.
    Actions are chunked so no single statement carries more than
    ``_MAX_TAG_ACTIONS_PER_STATEMENT`` actions, which bounds the statement
    size for very wide tables.

    Args:
        cursor: Cursor the statements are executed on.
        actions: Already-rendered column actions, one per list element.

    Raises:
        ValueError: If ``_MAX_TAG_ACTIONS_PER_STATEMENT`` is smaller than 1.
    """
    chunk_size = self._MAX_TAG_ACTIONS_PER_STATEMENT
    # A negative step would make range() yield nothing and silently drop every
    # action; zero would fail with an unrelated-looking range() error.
    if chunk_size < 1:
        raise ValueError(
            f"_MAX_TAG_ACTIONS_PER_STATEMENT must be >= 1, got {chunk_size}"
        )

    for start in range(0, len(actions), chunk_size):
        chunk = actions[start : start + chunk_size]
        cursor.execute(f"ALTER TABLE {self.fqn} MODIFY COLUMN {', '.join(chunk)}")

def _set_column_tag(
    self, cursor: SnowflakeCursor, column: str, tag_name: str, tag_value: str
) -> None:
    """Set a single tag on a single column.

    Thin wrapper that forwards a one-element change list to
    ``_set_column_tags_batch``. It does not execute a statement of its own,
    so the tag is applied once.

    Args:
        cursor: Cursor the statement is executed on.
        column: Column name.
        tag_name: Tag name.
        tag_value: Value to assign to the tag.
    """
    self._set_column_tags_batch(cursor, [(column, tag_name, tag_value)])

def _unset_column_tag(self, cursor: SnowflakeCursor, column: str, tag: str) -> None:
    """Remove a single tag from a single column.

    Thin wrapper that forwards a one-element change list to
    ``_unset_column_tags_batch``. It does not execute a statement of its
    own, so the tag is removed once.

    Args:
        cursor: Cursor the statement is executed on.
        column: Column name.
        tag: Name of the tag to remove.
    """
    self._unset_column_tags_batch(cursor, [(column, tag)])

def copy_custom(
self,
Expand Down
152 changes: 152 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,3 +1238,155 @@ def test_merge_statement_without_database():
assert "merge into PUBLIC.MAIN as dest" in result
assert "using PUBLIC.TEMP tmp" in result
assert 'ON dest."ID" = tmp."ID"' in result


def _alter_calls(mock_cursor):
    """Return the SQL text of every recorded ``execute`` call that is an ALTER TABLE.

    Calls recorded without positional arguments are ignored. The match is
    done on the first positional argument, ignoring leading whitespace and
    letter case.
    """
    statements = []
    for recorded in mock_cursor.execute.call_args_list:
        positional = recorded.args
        if not positional:
            continue
        sql = positional[0]
        if sql.lstrip().upper().startswith("ALTER TABLE"):
            statements.append(sql)
    return statements


def _make_tag_table(columns: dict[str, dict[str, str]], table_tags=None) -> Table:
    """Build a ``PUBLIC.PYTEST`` table whose text columns carry the given tags.

    ``columns`` maps a column name to that column's tags; ``table_tags`` are
    the table-level tags and default to an empty dict.
    """
    column_models = {}
    for col, tags in columns.items():
        column_models[col] = Column(name=col, data_type="text", tags=tags)

    structure = TableStructure(columns=column_models, tags=table_tags or {})
    return Table(name="PYTEST", schema_name="PUBLIC", table_structure=structure)


def test_sync_tags_columns_batches_sets_into_single_alter():
    """Several column SET TAG changes are folded into a single ALTER TABLE."""
    desired = {
        "id": {"pii": "personal"},
        "name": {"pii": "personal"},
        "email": {"pii": "contact"},
    }
    table = _make_tag_table(desired)
    # An empty fetch result: no tags exist on the columns yet.
    cursor = make_mock_cursor(fetchall_return=[])

    table.sync_tags_columns(cursor)

    alters = _alter_calls(cursor)
    assert len(alters) == 1, alters
    statement = alters[0]
    assert statement.count("SET TAG") == 3
    # One MODIFY COLUMN keyword, followed by comma-separated per-column clauses.
    assert statement.count("MODIFY COLUMN") == 1
    assert "MODIFY COLUMN " in statement
    expected_fragments = (
        "\"ID\" SET TAG governance.public.pii = 'personal'",
        "\"EMAIL\" SET TAG governance.public.pii = 'contact'",
    )
    for fragment in expected_fragments:
        assert fragment in statement


def test_sync_tags_columns_separates_set_and_unset():
    """One SET and one UNSET change produce two separate ALTER statements."""
    # Desired state: only "id" is tagged.
    table = _make_tag_table({"id": {"pii": "personal"}})
    # Existing state: only "name" is tagged, so that tag must be removed.
    existing_rows = [("name", "pii", "personal")]
    cursor = make_mock_cursor(fetchall_return=existing_rows)

    table.sync_tags_columns(cursor)

    alters = _alter_calls(cursor)
    assert len(alters) == 2, alters

    unset_stmt = next(sql for sql in alters if "UNSET TAG" in sql)
    assert unset_stmt.count("MODIFY COLUMN") == 1
    assert 'MODIFY COLUMN "NAME" UNSET TAG governance.public.pii' in unset_stmt

    set_stmt = next(sql for sql in alters if "SET TAG" in sql and "UNSET" not in sql)
    assert set_stmt.count("MODIFY COLUMN") == 1
    assert "MODIFY COLUMN \"ID\" SET TAG governance.public.pii = 'personal'" in set_stmt


def test_sync_tags_columns_noop_when_no_changes():
    """When desired and existing tags match, no ALTER statement is issued."""
    table = _make_tag_table({"id": {"pii": "personal"}})
    existing_rows = [("id", "pii", "personal")]
    cursor = make_mock_cursor(fetchall_return=existing_rows)

    table.sync_tags_columns(cursor)

    alters = _alter_calls(cursor)
    assert alters == []


def test_sync_tags_columns_chunks_wide_tables(monkeypatch):
    """More actions than the per-statement cap are split across several ALTERs."""
    monkeypatch.setattr(Table, "_MAX_TAG_ACTIONS_PER_STATEMENT", 2)
    column_names = ("c1", "c2", "c3", "c4", "c5")
    table = _make_tag_table({name: {"pii": "personal"} for name in column_names})
    cursor = make_mock_cursor(fetchall_return=[])

    table.sync_tags_columns(cursor)

    alters = _alter_calls(cursor)
    # Five actions with a cap of two: statements of 2, 2 and 1 actions.
    assert len(alters) == 3
    total_set_actions = sum(sql.count("SET TAG") for sql in alters)
    assert total_set_actions == 5


def test_sync_tags_columns_unset_one_statement_per_column():
    """UNSET across several columns emits one ALTER per column (Snowflake rejects multi-column UNSET)."""
    # None of the columns has a desired tag.
    table = _make_tag_table({"id": {}, "name": {}, "email": {}})
    existing_rows = [
        ("id", "pii", "personal"),
        ("name", "pii", "personal"),
        ("email", "pii", "contact"),
    ]
    cursor = make_mock_cursor(fetchall_return=existing_rows)

    table.sync_tags_columns(cursor)

    alters = _alter_calls(cursor)
    assert len(alters) == 3, alters
    assert all("UNSET TAG" in sql for sql in alters)
    assert all(sql.count("MODIFY COLUMN") == 1 for sql in alters)

    expected_clauses = (
        'MODIFY COLUMN "ID" UNSET TAG governance.public.pii',
        'MODIFY COLUMN "NAME" UNSET TAG governance.public.pii',
        'MODIFY COLUMN "EMAIL" UNSET TAG governance.public.pii',
    )
    for clause in expected_clauses:
        assert any(clause in sql for sql in alters)


def test_sync_tags_columns_unset_combines_multiple_tags_per_column():
    """Several tags on the same column share that column's single UNSET statement."""
    # The column has no desired tags, so both existing tags must go.
    table = _make_tag_table({"id": {}})
    existing_rows = [
        ("id", "pii", "personal"),
        ("id", "domain", "finance"),
    ]
    cursor = make_mock_cursor(fetchall_return=existing_rows)

    table.sync_tags_columns(cursor)

    alters = _alter_calls(cursor)
    assert len(alters) == 1, alters
    stmt = alters[0]
    assert stmt.count("MODIFY COLUMN") == 1
    # A single UNSET TAG keyword followed by comma-separated tag names.
    assert stmt.count("UNSET TAG") == 1
    assert 'MODIFY COLUMN "ID" UNSET TAG ' in stmt
    for tag_fqn in ("governance.public.pii", "governance.public.domain"):
        assert tag_fqn in stmt


def test_sync_tags_table_batches_multiple_tags():
    """Several table-level SET TAG changes are folded into a single ALTER TABLE."""
    desired_table_tags = {"pii": "foo", "domain": "finance"}
    table = _make_tag_table({}, table_tags=desired_table_tags)
    # An empty fetch result: the table has no tags yet.
    cursor = make_mock_cursor(fetchall_return=[])

    table.sync_tags_table(cursor)

    alters = _alter_calls(cursor)
    assert len(alters) == 1, alters
    statement = alters[0]
    # A single SET TAG keyword followed by comma-separated assignments.
    assert statement.count("SET TAG") == 1
    assert "governance.public.pii = 'foo'" in statement
    assert "governance.public.domain = 'finance'" in statement
Loading