Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions pkg/frontend/authenticate2.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,37 +331,70 @@ func canWriteProtectedDatabase(ses *Session) bool {
}

func normalizeProtectedDatabaseName(ses *Session, dbName string) string {
dbName = strings.TrimSpace(dbName)
if protectedDatabaseNamesAreLowerCased(ses) {
dbName = strings.ToLower(dbName)
}
return dbName
}

func resolveProtectedDatabaseRuntimeTarget(ses *Session, dbName string) string {
dbName = strings.TrimSpace(dbName)
if dbName == "" && ses != nil {
dbName = ses.GetDatabaseName()
}
if protectedDatabaseNamesAreLowerCased(ses) {
dbName = strings.ToLower(dbName)
}
return dbName
}

func getProtectedDatabaseSet(ses *Session) map[string]struct{} {
func protectedDatabaseNamesAreLowerCased(ses *Session) bool {
if ses == nil {
return nil
return false
}
value, err := ses.GetGlobalSysVar(ProtectedDatabases)
value, err := ses.GetSessionSysVar("lower_case_table_names")
if err != nil {
return nil
return true
}
raw, ok := value.(string)
if !ok || strings.TrimSpace(raw) == "" {
lowerCaseTableNames, ok := value.(int64)
return ok && lowerCaseTableNames != 0
}

func protectedDatabaseSetFromString(ses *Session, raw string) map[string]struct{} {
if strings.TrimSpace(raw) == "" {
return nil
}
protected := make(map[string]struct{})
for _, part := range strings.Split(raw, ",") {
dbName := strings.TrimSpace(part)
dbName := normalizeProtectedDatabaseName(ses, part)
if dbName != "" {
protected[dbName] = struct{}{}
}
}
if len(protected) == 0 {
return nil
}
return protected
}

func getProtectedDatabaseSet(ses *Session) map[string]struct{} {
if ses == nil {
return nil
}
value, err := ses.GetGlobalSysVar(ProtectedDatabases)
if err != nil {
return nil
}
raw, ok := value.(string)
if !ok || strings.TrimSpace(raw) == "" {
return nil
}
return protectedDatabaseSetFromString(ses, raw)
}

func isProtectedDatabase(ses *Session, dbName string) bool {
dbName = normalizeProtectedDatabaseName(ses, dbName)
dbName = resolveProtectedDatabaseRuntimeTarget(ses, dbName)
if dbName == "" {
return false
}
Expand Down Expand Up @@ -393,7 +426,7 @@ func checkProtectedDatabaseWriteWithSet(ctx context.Context, ses *Session, prote
return true
}
for _, dbName := range dbNames {
dbName = normalizeProtectedDatabaseName(ses, dbName)
dbName = resolveProtectedDatabaseRuntimeTarget(ses, dbName)
if dbName == "" {
continue
}
Expand Down
97 changes: 89 additions & 8 deletions pkg/frontend/protected_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"context"
"testing"

"github.com/matrixorigin/matrixone/pkg/sql/parsers"
"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect"
"github.com/matrixorigin/matrixone/pkg/sql/parsers/tree"
"github.com/stretchr/testify/require"
)
Expand All @@ -30,6 +32,37 @@ func protectedTargetsFromStatement(stmt tree.Statement) []string {
return determinePrivilegeSetOfStatement(stmt).writeDatabaseTargets
}

func newProtectedDatabaseTestSessionWithCurrentDB(dbName string) *Session {
proto := &MysqlProtocolImpl{}
ses := &Session{feSessionImpl: feSessionImpl{respr: NewMysqlResp(proto)}}
ses.respr.SetStr(DBNAME, dbName)
return ses
}

func TestProtectedDatabaseSetFromString(t *testing.T) {
require.Nil(t, protectedDatabaseSetFromString(nil, ""))
require.Nil(t, protectedDatabaseSetFromString(nil, " , , "))
require.Equal(t, map[string]struct{}{"db1": {}, "CamelDB": {}}, protectedDatabaseSetFromString(nil, " db1, CamelDB "))
require.Equal(t, map[string]struct{}{"db1": {}, "cameldb": {}}, protectedDatabaseSetFromString(&Session{}, " db1, CamelDB "))

ses := newProtectedDatabaseTestSessionWithCurrentDB("current_db")
require.Equal(t, "", normalizeProtectedDatabaseName(ses, ""))
require.Nil(t, protectedDatabaseSetFromString(ses, ","))
require.Equal(t, map[string]struct{}{"db1": {}}, protectedDatabaseSetFromString(ses, "db1,"))
}

func TestProtectedDatabaseWriteTargetsFromDataBranchKeepsPrivileges(t *testing.T) {
tableStmt := &tree.DataBranchCreateTable{}
tableStmt.CreateTable.Table = testTableName("dst_db", "dst_tbl")
tablePriv := determinePrivilegeSetOfStatement(tableStmt)
require.Equal(t, []string{"dst_db"}, tablePriv.writeDatabaseTargets)

databaseStmt := tree.NewDataBranchCreateDatabase()
databaseStmt.DstDatabase = tree.Identifier("dst_db")
databasePriv := determinePrivilegeSetOfStatement(databaseStmt)
require.Equal(t, []string{"dst_db"}, databasePriv.writeDatabaseTargets)
}

func TestProtectedDatabaseWriteTargetsFromClone(t *testing.T) {
cloneTable := tree.NewCloneTable()
cloneTable.CreateTable.Table = testTableName("dst_db", "dst_tbl")
Expand Down Expand Up @@ -78,23 +111,71 @@ func TestPrivilegeTipWritesDatabase(t *testing.T) {
require.True(t, privilegeTipWritesDatabase(privilegeTips{typ: PrivilegeTypeDelete}))
}

func TestProtectedDatabaseNameKeepsOriginalCase(t *testing.T) {
protectedDatabases := map[string]struct{}{"CamelDB": {}}
func TestProtectedDatabaseNameFollowsLowerCaseTableNames(t *testing.T) {
ses := &Session{}
protectedDatabases := protectedDatabaseSetFromString(ses, "CamelDB")

require.False(t, checkProtectedDatabaseWriteWithSet(context.Background(), ses, protectedDatabases, "CamelDB"))
require.False(t, checkProtectedDatabaseWriteWithSet(context.Background(), ses, protectedDatabases, "cameldb"))

require.False(t, checkProtectedDatabaseWriteWithSet(context.Background(), nil, protectedDatabases, "CamelDB"))
require.True(t, checkProtectedDatabaseWriteWithSet(context.Background(), nil, protectedDatabases, "camedb"))
caseSensitiveSes := &Session{
feSessionImpl: feSessionImpl{
sesSysVars: &SystemVariables{mp: map[string]interface{}{"lower_case_table_names": int64(0)}},
},
}
caseSensitiveProtectedDatabases := protectedDatabaseSetFromString(caseSensitiveSes, "CamelDB")
require.False(t, checkProtectedDatabaseWriteWithSet(context.Background(), caseSensitiveSes, caseSensitiveProtectedDatabases, "CamelDB"))
require.True(t, checkProtectedDatabaseWriteWithSet(context.Background(), caseSensitiveSes, caseSensitiveProtectedDatabases, "cameldb"))

caseInsensitivePreserveNameSes := &Session{
feSessionImpl: feSessionImpl{
sesSysVars: &SystemVariables{mp: map[string]interface{}{"lower_case_table_names": int64(2)}},
},
}
caseInsensitivePreserveNameProtectedDatabases := protectedDatabaseSetFromString(caseInsensitivePreserveNameSes, "CamelDB")
require.False(t, checkProtectedDatabaseWriteWithSet(context.Background(), caseInsensitivePreserveNameSes, caseInsensitivePreserveNameProtectedDatabases, "CamelDB"))
require.False(t, checkProtectedDatabaseWriteWithSet(context.Background(), caseInsensitivePreserveNameSes, caseInsensitivePreserveNameProtectedDatabases, "cameldb"))
}

func TestCheckProtectedDatabaseWriteWithSet(t *testing.T) {
ctx := context.Background()
protectedDatabases := map[string]struct{}{"protected_db": {}, "CamelDB": {}}
ses := &Session{}
protectedDatabases := protectedDatabaseSetFromString(ses, "protected_db,CamelDB")

require.True(t, checkProtectedDatabaseWriteWithSet(ctx, nil, nil, "protected_db"))
require.True(t, checkProtectedDatabaseWriteWithSet(ctx, nil, protectedDatabases))
require.True(t, checkProtectedDatabaseWriteWithSet(ctx, nil, protectedDatabases, "normal_db"))
require.False(t, checkProtectedDatabaseWriteWithSet(ctx, nil, protectedDatabases, "normal_db", "protected_db"))
require.False(t, checkProtectedDatabaseWriteWithSet(ctx, nil, protectedDatabases, " CamelDB "))
require.True(t, checkProtectedDatabaseWriteWithSet(ctx, nil, protectedDatabases, "cameldb"))
require.False(t, checkProtectedDatabaseWriteWithSet(ctx, ses, protectedDatabases, "normal_db", "protected_db"))
require.False(t, checkProtectedDatabaseWriteWithSet(ctx, ses, protectedDatabases, " CamelDB "))
require.False(t, checkProtectedDatabaseWriteWithSet(ctx, ses, protectedDatabases, "cameldb"))
}

func TestCheckProtectedDatabaseWriteWithSetUsesCurrentDatabaseForEmptyTarget(t *testing.T) {
ctx := context.Background()
ses := newProtectedDatabaseTestSessionWithCurrentDB("protected_db")
protectedDatabases := protectedDatabaseSetFromString(ses, "protected_db")

require.False(t, checkProtectedDatabaseWriteWithSet(ctx, ses, protectedDatabases, ""))
}

func TestUnqualifiedProtectedDatabaseWritesUseCurrentDatabase(t *testing.T) {
ctx := context.Background()
ses := newProtectedDatabaseTestSessionWithCurrentDB("protected_db")
protectedDatabases := protectedDatabaseSetFromString(ses, "protected_db")

testCases := []string{
"create table t (a int)",
"drop table t",
"create table t as select 1",
}

for _, sql := range testCases {
t.Run(sql, func(t *testing.T) {
stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, sql, 1)
require.NoError(t, err)
require.False(t, checkProtectedDatabaseWriteWithSet(ctx, ses, protectedDatabases, protectedTargetsFromStatement(stmt)...))
})
}
}

func TestCanWriteProtectedDatabase(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/frontend/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -1370,9 +1370,9 @@ func (ses *Session) SetGlobalSysVar(ctx context.Context, name string, val interf
}

if name == ProtectedDatabases {
if newValue, ok := val.(string); ok && strings.TrimSpace(newValue) == "" {
if newValue, ok := val.(string); ok && len(protectedDatabaseSetFromString(ses, newValue)) == 0 {
oldValue, _ := ses.GetGlobalSysVar(name)
if oldString, ok := oldValue.(string); ok && strings.TrimSpace(oldString) != "" {
if oldString, ok := oldValue.(string); ok && len(protectedDatabaseSetFromString(ses, oldString)) != 0 {
return moerr.NewInternalErrorNoCtx("protected_databases cannot be cleared directly")
}
}
Expand Down
Loading
Loading