Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sentry.api.base import cell_silo_endpoint
from sentry.api.bases.organization_events import OrganizationEventsEndpointBase
from sentry.api.endpoints.organization_events_spans_performance import EventID, get_span_description
from sentry.api.serializers.rest_framework.project import ProjectField
from sentry.api.utils import handle_query_errors
from sentry.search.events.builder.discover import DiscoverQueryBuilder
from sentry.search.events.types import QueryBuilderConfig
Expand Down Expand Up @@ -35,7 +36,7 @@

class RootCauseAnalysisQuerySerializer(serializers.Serializer):
transaction = serializers.CharField(max_length=200)
project = serializers.IntegerField()
project = ProjectField(scope="project:read", id_allowed=True)
breakpoint = serializers.CharField()
per_page = serializers.IntegerField(min_value=1, max_value=MAX_LIMIT, default=DEFAULT_LIMIT)
span_score_threshold = serializers.IntegerField(
Expand Down Expand Up @@ -203,13 +204,16 @@ class OrganizationEventsRootCauseAnalysisEndpoint(OrganizationEventsEndpointBase
}

def get(self, request, organization):
serializer = RootCauseAnalysisQuerySerializer(data=request.GET)
serializer = RootCauseAnalysisQuerySerializer(
data=request.GET,
context={"access": request.access, "organization": organization},
)
if not serializer.is_valid():
return Response(serializer.errors, status=400)

validated = serializer.validated_data
transaction_name = validated["transaction"]
project_id = validated["project"]
project_id = validated["project"].id
regression_breakpoint = validated["breakpoint"]
limit = validated["per_page"]
span_score_threshold = validated["span_score_threshold"]
Expand Down
4 changes: 3 additions & 1 deletion src/sentry/replays/endpoints/organization_replay_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def get(self, request: Request, organization: Organization) -> Response[_ListRep
except NoProjects:
return Response({"data": []}, status=200)

result = ReplayValidator(data=request.GET)
query_params = self.get_query_params_with_project_slug_precedence(request)

result = ReplayValidator(data=query_params)
if not result.is_valid():
raise ParseError(result.errors)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def get(self, request: Request, organization: Organization) -> Response[ReplaySe
except NoProjects:
return Response({"data": []}, status=200)

result = ReplaySelectorValidator(data=request.GET)
query_params = self.get_query_params_with_project_slug_precedence(request)

result = ReplaySelectorValidator(data=query_params)
if not result.is_valid():
raise ParseError(result.errors)

Expand Down
10 changes: 6 additions & 4 deletions src/sentry/replays/validators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from rest_framework import serializers

from sentry.api.helpers.projects import ProjectIdOrSlugField

VALID_FIELD_SET = (
"activity",
"browser",
Expand Down Expand Up @@ -66,8 +68,8 @@ class ReplayValidator(serializers.Serializer):
)
project = serializers.ListField(
required=False,
help_text="The ID of the projects to filter by.",
child=serializers.IntegerField(),
help_text="A list of project IDs or slugs to filter by.",
child=ProjectIdOrSlugField(),
)
projectSlug = serializers.ListField(
required=False,
Expand Down Expand Up @@ -115,8 +117,8 @@ class ReplaySelectorValidator(serializers.Serializer):
)
project = serializers.ListField(
required=False,
help_text="The ID of the projects to filter by.",
child=serializers.IntegerField(),
help_text="A list of project IDs or slugs to filter by.",
child=ProjectIdOrSlugField(),
)
projectSlug = serializers.ListField(
required=False,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Any
from unittest.mock import MagicMock

from sentry.api.endpoints.organization_events_root_cause_analysis import (
RootCauseAnalysisQuerySerializer,
)
from sentry.testutils.cases import APITestCase


class RootCauseAnalysisQuerySerializerTest(APITestCase):
def setUp(self) -> None:
super().setUp()
self.project = self.create_project(organization=self.organization)
self.access = MagicMock()
self.access.has_any_project_scope.return_value = True

def _data(self, project: str) -> dict[str, str]:
return {
"transaction": "GET /api/0/issues/",
"project": project,
"breakpoint": "2024-01-01T00:00:00Z",
}

def _context(self) -> dict[str, Any]:
return {"access": self.access, "organization": self.organization}

def test_accepts_project_id(self) -> None:
serializer = RootCauseAnalysisQuerySerializer(
data=self._data(str(self.project.id)), context=self._context()
)

assert serializer.is_valid(), serializer.errors
assert serializer.validated_data["project"] == self.project

def test_accepts_project_slug(self) -> None:
serializer = RootCauseAnalysisQuerySerializer(
data=self._data(self.project.slug), context=self._context()
)

assert serializer.is_valid(), serializer.errors
assert serializer.validated_data["project"] == self.project

def test_rejects_project_id_from_another_organization(self) -> None:
other_project = self.create_project(organization=self.create_organization())
serializer = RootCauseAnalysisQuerySerializer(
data=self._data(str(other_project.id)), context=self._context()
)

assert not serializer.is_valid()
assert str(serializer.errors["project"][0]) == "Invalid project"

def test_rejects_project_id_without_scope(self) -> None:
self.access.has_any_project_scope.return_value = False
serializer = RootCauseAnalysisQuerySerializer(
data=self._data(str(self.project.id)), context=self._context()
)

assert not serializer.is_valid()
assert str(serializer.errors["project"][0]) == "Insufficient access to project"
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
mock_replay_tap,
mock_replay_viewed,
)
from sentry.replays.usecases.query import QueryResponse
from sentry.replays.validators import ReplaySelectorValidator, ReplayValidator
from sentry.testutils.cases import APITestCase, ReplaysSnubaTestCase
from sentry.utils.cursors import Cursor
from sentry.utils.snuba import QueryMemoryLimitExceeded
Expand All @@ -29,6 +31,29 @@ def setUp(self) -> None:
def features(self) -> dict[str, bool]:
return {"organizations:session-replay": True}

def test_replay_validators_accept_project_slugs(self) -> None:
replay_validator = ReplayValidator(data={"project": ["my-project"]})
selector_validator = ReplaySelectorValidator(data={"project": ["my-project"]})

assert replay_validator.is_valid(), replay_validator.errors
assert selector_validator.is_valid(), selector_validator.errors

def test_get_replays_empty_project_uses_project_slug_filter(self) -> None:
project = self.create_project(teams=[self.team], slug="replay-project")

with self.feature(self.features):
with mock.patch(
"sentry.replays.endpoints.organization_replay_index.query_replays_collection_paginated",
return_value=QueryResponse(response=[], has_more=False, source="mock"),
) as mock_query:
response = self.client.get(
self.url,
{"projectSlug": project.slug, "project": ""},
)

assert response.status_code == 200
assert mock_query.call_args.kwargs["project_ids"] == [project.id]

def test_feature_flag_disabled(self) -> None:
"""Test replays can be disabled."""
response = self.client.get(self.url)
Expand Down
Loading