Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/firebase_functions/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ def _required_apis(self) -> list[_manifest.ManifestRequiredApi]:


@_dataclasses.dataclass(frozen=True, kw_only=True)
class StorageOptions(RuntimeOptions):
class StorageOptions(EventHandlerOptions):
"""
Options specific to Cloud Storage function types.
Internal use only.
Expand Down Expand Up @@ -937,19 +937,19 @@ def _endpoint(
}
event_trigger = _manifest.EventTrigger(
eventType=kwargs["event_type"],
retry=False,
retry=self.retry if self.retry is not None else False,
eventFilters=event_filters,
)

kwargs_merged = {
**_dataclasses.asdict(super()._endpoint(**kwargs)),
**_dataclasses.asdict(RuntimeOptions._endpoint(self, **kwargs)),
"eventTrigger": event_trigger,
}
return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged))
Comment on lines 938 to 948

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since StorageOptions now inherits from EventHandlerOptions, you can simplify _endpoint by calling super()._endpoint directly. This avoids manually constructing the EventTrigger and merging dictionaries, leveraging the base class implementation as intended.

        return super()._endpoint(**kwargs, event_filters=event_filters)



@_dataclasses.dataclass(frozen=True, kw_only=True)
class DatabaseOptions(RuntimeOptions):
class DatabaseOptions(EventHandlerOptions):
"""
Options specific to Realtime Database function types.
Internal use only.
Expand Down Expand Up @@ -990,13 +990,13 @@ def _endpoint(

event_trigger = _manifest.EventTrigger(
eventType=kwargs["event_type"],
retry=False,
retry=self.retry if self.retry is not None else False,
eventFilters=event_filters,
eventFilterPathPatterns=event_filters_path_patterns,
)

kwargs_merged = {
**_dataclasses.asdict(super()._endpoint(**kwargs)),
**_dataclasses.asdict(RuntimeOptions._endpoint(self, **kwargs)),
"eventTrigger": event_trigger,
}
return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged))
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def _required_apis(self) -> list[_manifest.ManifestRequiredApi]:


@_dataclasses.dataclass(frozen=True, kw_only=True)
class FirestoreOptions(RuntimeOptions):
class FirestoreOptions(EventHandlerOptions):
"""
Options specific to Firestore function types.
Internal use only.
Expand Down Expand Up @@ -1097,13 +1097,13 @@ def _endpoint(
event_filters["document"] = event_filter_document
event_trigger = _manifest.EventTrigger(
eventType=kwargs["event_type"],
retry=False,
retry=self.retry if self.retry is not None else False,
eventFilters=event_filters,
eventFilterPathPatterns=event_filters_path_patterns,
)

kwargs_merged = {
**_dataclasses.asdict(super()._endpoint(**kwargs)),
**_dataclasses.asdict(RuntimeOptions._endpoint(self, **kwargs)),
"eventTrigger": event_trigger,
}
return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged))
Expand Down
18 changes: 18 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ class TestDb(unittest.TestCase):
Tests for the db module.
"""

def test_database_decorator_retry_option(self):
func = mock.Mock(__name__="example_func")
decorated_func = db_fn.on_value_written(reference="/items/{itemId}", retry=True)(func)

endpoint = decorated_func.__firebase_endpoint__

self.assertIsNotNone(endpoint.eventTrigger)
self.assertTrue(endpoint.eventTrigger["retry"])

def test_database_decorator_retry_defaults_false(self):
func = mock.Mock(__name__="example_func")
decorated_func = db_fn.on_value_written(reference="/items/{itemId}")(func)

endpoint = decorated_func.__firebase_endpoint__

self.assertIsNotNone(endpoint.eventTrigger)
self.assertFalse(endpoint.eventTrigger["retry"])

def test_calls_init_function(self):
hello = None

Expand Down
27 changes: 27 additions & 0 deletions tests/test_firestore_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,33 @@ class TestFirestore(TestCase):
firestore_fn tests.
"""

def test_firestore_decorator_retry_option(self):
with patch.dict("sys.modules", mocked_modules):
from firebase_functions import firestore_fn

func = Mock(__name__="example_func")
decorated_func = firestore_fn.on_document_created(
document="/foo/{bar}",
retry=True,
)(func)

endpoint = decorated_func.__firebase_endpoint__

self.assertIsNotNone(endpoint.eventTrigger)
self.assertTrue(endpoint.eventTrigger["retry"])

def test_firestore_decorator_retry_defaults_false(self):
with patch.dict("sys.modules", mocked_modules):
from firebase_functions import firestore_fn

func = Mock(__name__="example_func")
decorated_func = firestore_fn.on_document_created(document="/foo/{bar}")(func)

endpoint = decorated_func.__firebase_endpoint__

self.assertIsNotNone(endpoint.eventTrigger)
self.assertFalse(endpoint.eventTrigger["retry"])

def test_firestore_endpoint_handler_calls_function_with_correct_args(self):
with patch.dict("sys.modules", mocked_modules):
from cloudevents.http import CloudEvent
Expand Down
18 changes: 18 additions & 0 deletions tests/test_storage_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@ class TestStorage(unittest.TestCase):
Storage function tests.
"""

def test_storage_decorator_retry_option(self):
func = Mock(__name__="example_func")
decorated_func = storage_fn.on_object_finalized(bucket="bucket", retry=True)(func)

endpoint = decorated_func.__firebase_endpoint__

self.assertIsNotNone(endpoint.eventTrigger)
self.assertTrue(endpoint.eventTrigger["retry"])

def test_storage_decorator_retry_defaults_false(self):
func = Mock(__name__="example_func")
decorated_func = storage_fn.on_object_finalized(bucket="bucket")(func)

endpoint = decorated_func.__firebase_endpoint__

self.assertIsNotNone(endpoint.eventTrigger)
self.assertFalse(endpoint.eventTrigger["retry"])

def test_calls_init(self):
hello = None

Expand Down
Loading