diff --git a/query/delete.go b/query/delete.go new file mode 100644 index 0000000..a11a211 --- /dev/null +++ b/query/delete.go @@ -0,0 +1,42 @@ +package query + +import "strings" + +// DeleteQuery represents an immutable DELETE query builder. +// Each method returns a new instance without modifying the original. +type DeleteQuery struct { + table string + where whereClauses +} + +// Delete creates a new [DeleteQuery] for the given table. +func Delete(table string) DeleteQuery { + return DeleteQuery{table: table} +} + +// Where appends an AND WHERE condition. +func (query DeleteQuery) Where(sql string, args ...any) DeleteQuery { + query.where = query.where.add("AND", sql, args) + + return query +} + +// OrWhere appends an OR WHERE condition. +func (query DeleteQuery) OrWhere(sql string, args ...any) DeleteQuery { + query.where = query.where.add("OR", sql, args) + + return query +} + +// Build produces the SQL query string and its positional arguments. +func (query DeleteQuery) Build() (string, []any) { + var builder strings.Builder + var args []any + + builder.WriteString("DELETE FROM ") + builder.WriteString(query.table) + + query.where.render(&builder, &args) + + return builder.String(), args +} diff --git a/query/delete_test.go b/query/delete_test.go new file mode 100644 index 0000000..563ff23 --- /dev/null +++ b/query/delete_test.go @@ -0,0 +1,75 @@ +package query_test + +import ( + "testing" + + "github.com/studiolambda/cosmos/query" +) + +func TestDeleteBasic(t *testing.T) { + t.Parallel() + + sql, args := query.Delete("users"). + Where("id = ?", 1). + Build() + + if sql != "DELETE FROM users WHERE id = ?" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 1 || args[0] != 1 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestDeleteOrWhere(t *testing.T) { + t.Parallel() + + sql, args := query.Delete("sessions"). + Where("expired = ?", true). + OrWhere("revoked = ?", true). + Build() + + expected := "DELETE FROM sessions WHERE expired = ? OR revoked = ?" + + if sql != expected { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 2 || args[0] != true || args[1] != true { + t.Errorf("unexpected args: %v", args) + } +} + +func TestDeleteNoWhere(t *testing.T) { + t.Parallel() + + sql, args := query.Delete("logs").Build() + + if sql != "DELETE FROM logs" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 0 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestDeleteImmutability(t *testing.T) { + t.Parallel() + + base := query.Delete("users") + first := base.Where("id = ?", 1) + second := base.Where("id = ?", 2) + + _, argsFirst := first.Build() + _, argsSecond := second.Build() + + if argsFirst[0] != 1 { + t.Errorf("unexpected first args: %v", argsFirst) + } + + if argsSecond[0] != 2 { + t.Errorf("unexpected second args: %v", argsSecond) + } +} diff --git a/query/doc.go b/query/doc.go new file mode 100644 index 0000000..d92594c --- /dev/null +++ b/query/doc.go @@ -0,0 +1,17 @@ +// Package query provides an immutable, copy-on-write SQL query builder +// that produces parameterized query strings with positional placeholders (?). +// +// The builder supports SELECT, INSERT, UPDATE, and DELETE statements +// with support for raw SQL expressions via [Raw] and [Expr]. +// +// Each method returns a new value without mutating the original, +// making it safe to derive queries from shared base instances: +// +// base := query.Select("users").Columns("id", "name").Where("active = ?", true) +// admins := base.Where("role = ?", "admin") +// editors := base.Where("role = ?", "editor") +// +// The produced queries use ? as the placeholder character. Use your +// database driver's rebind functionality if your database requires +// a different placeholder style (e.g., $1 for PostgreSQL). +package query diff --git a/query/go.mod b/query/go.mod new file mode 100644 index 0000000..f3999f9 --- /dev/null +++ b/query/go.mod @@ -0,0 +1,3 @@ +module github.com/studiolambda/cosmos/query + +go 1.25.0 diff --git a/query/insert.go b/query/insert.go new file mode 100644 index 0000000..0aadcf4 --- /dev/null +++ b/query/insert.go @@ -0,0 +1,65 @@ +package query + +import ( + "slices" + "strings" +) + +// insertPair holds a column name and its associated value. +type insertPair struct { + column string + value any +} + +// InsertQuery represents an immutable INSERT query builder. +// Each method returns a new instance without modifying the original. +type InsertQuery struct { + table string + pairs []insertPair +} + +// Insert creates a new [InsertQuery] for the given table. +func Insert(table string) InsertQuery { + return InsertQuery{table: table} +} + +// Set adds a column-value pair to the INSERT statement. The value can +// be a plain value (emits ?), [Raw] (emits literal SQL), or [Expr] +// (emits SQL with embedded placeholders). +func (query InsertQuery) Set(column string, value any) InsertQuery { + query.pairs = append(slices.Clone(query.pairs), insertPair{column: column, value: value}) + + return query +} + +// Build produces the SQL query string and its positional arguments. +func (query InsertQuery) Build() (string, []any) { + var builder strings.Builder + var args []any + + builder.WriteString("INSERT INTO ") + builder.WriteString(query.table) + builder.WriteString(" (") + + for i, pair := range query.pairs { + if i > 0 { + builder.WriteString(", ") + } + + builder.WriteString(pair.column) + } + + builder.WriteString(") VALUES (") + + for i, pair := range query.pairs { + if i > 0 { + builder.WriteString(", ") + } + + renderValue(&builder, &args, pair.value) + } + + builder.WriteByte(')') + + return builder.String(), args +} diff --git a/query/insert_test.go b/query/insert_test.go new file mode 100644 index 0000000..db88229 --- /dev/null +++ b/query/insert_test.go @@ -0,0 +1,77 @@ +package query_test + +import ( + "testing" + + "github.com/studiolambda/cosmos/query" +) + +func TestInsertBasic(t *testing.T) { + t.Parallel() + + sql, args := query.Insert("users"). + Set("name", "Erik"). + Set("email", "erik@example.com"). + Build() + + if sql != "INSERT INTO users (name, email) VALUES (?, ?)" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 2 || args[0] != "Erik" || args[1] != "erik@example.com" { + t.Errorf("unexpected args: %v", args) + } +} + +func TestInsertRaw(t *testing.T) { + t.Parallel() + + sql, args := query.Insert("users"). + Set("name", "Erik"). + Set("created_at", query.Raw("NOW()")). + Build() + + if sql != "INSERT INTO users (name, created_at) VALUES (?, NOW())" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 1 || args[0] != "Erik" { + t.Errorf("unexpected args: %v", args) + } +} + +func TestInsertExpr(t *testing.T) { + t.Parallel() + + sql, args := query.Insert("events"). + Set("name", "deploy"). + Set("scheduled_at", query.Expr{SQL: "NOW() + INTERVAL ? HOUR", Args: []any{2}}). + Build() + + if sql != "INSERT INTO events (name, scheduled_at) VALUES (?, NOW() + INTERVAL ? HOUR)" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 2 || args[0] != "deploy" || args[1] != 2 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestInsertImmutability(t *testing.T) { + t.Parallel() + + base := query.Insert("users").Set("role", "user") + admin := base.Set("name", "Admin") + editor := base.Set("name", "Editor") + + sqlAdmin, argsAdmin := admin.Build() + sqlEditor, argsEditor := editor.Build() + + if argsAdmin[1] != "Admin" { + t.Errorf("unexpected admin args: %v (sql: %s)", argsAdmin, sqlAdmin) + } + + if argsEditor[1] != "Editor" { + t.Errorf("unexpected editor args: %v (sql: %s)", argsEditor, sqlEditor) + } +} diff --git a/query/raw.go b/query/raw.go new file mode 100644 index 0000000..38ea5f8 --- /dev/null +++ b/query/raw.go @@ -0,0 +1,49 @@ +package query + +import "strings" + +// Raw represents a raw SQL fragment that is embedded directly into +// the query without a placeholder. Use this for SQL functions and +// expressions like NOW(), COUNT(*), etc. +type Raw string + +// Expr represents a raw SQL fragment with embedded placeholders +// and their corresponding arguments. Use this for expressions that +// require parameterized values within raw SQL. +type Expr struct { + SQL string + Args []any +} + +// renderValue writes a single value to the builder. If the value is +// a [Raw], its SQL is written directly. If it is an [Expr], its SQL +// is written and its args are appended. Otherwise, a ? placeholder +// is written and the value is appended to args. +func renderValue(builder *strings.Builder, args *[]any, value any) { + switch v := value.(type) { + case Raw: + builder.WriteString(string(v)) + case Expr: + builder.WriteString(v.SQL) + *args = append(*args, v.Args...) + default: + builder.WriteByte('?') + *args = append(*args, value) + } +} + +// renderColumn writes a column reference to the builder. If the value +// is a [Raw], its SQL is written directly. If it is an [Expr], its SQL +// is written and its args are appended. Otherwise, the value is written +// as a plain string column name. +func renderColumn(builder *strings.Builder, args *[]any, column any) { + switch v := column.(type) { + case Raw: + builder.WriteString(string(v)) + case Expr: + builder.WriteString(v.SQL) + *args = append(*args, v.Args...) + case string: + builder.WriteString(v) + } +} diff --git a/query/select.go b/query/select.go new file mode 100644 index 0000000..22c9a6c --- /dev/null +++ b/query/select.go @@ -0,0 +1,145 @@ +package query + +import ( + "slices" + "strconv" + "strings" +) + +// SelectQuery represents an immutable SELECT query builder. +// Each method returns a new instance without modifying the original. +type SelectQuery struct { + table string + columns []any + joins []string + where whereClauses + groupBy []string + orderBy []string + limit int + offset int +} + +// Select creates a new [SelectQuery] for the given table. +func Select(table string) SelectQuery { + return SelectQuery{table: table} +} + +// Columns sets the columns to select. Values can be strings or [Raw]/[Expr] +// for raw SQL expressions. +func (query SelectQuery) Columns(columns ...any) SelectQuery { + query.columns = slices.Clone(columns) + + return query +} + +// Column appends a single column to the select list. The value can be +// a string or [Raw]/[Expr] for raw SQL expressions. +func (query SelectQuery) Column(column any) SelectQuery { + query.columns = append(slices.Clone(query.columns), column) + + return query +} + +// Join appends a JOIN clause to the query. The clause should be a +// complete join expression (e.g., "posts ON posts.user_id = users.id"). +func (query SelectQuery) Join(join string) SelectQuery { + query.joins = append(slices.Clone(query.joins), join) + + return query +} + +// Where appends an AND WHERE condition. +func (query SelectQuery) Where(sql string, args ...any) SelectQuery { + query.where = query.where.add("AND", sql, args) + + return query +} + +// OrWhere appends an OR WHERE condition. +func (query SelectQuery) OrWhere(sql string, args ...any) SelectQuery { + query.where = query.where.add("OR", sql, args) + + return query +} + +// GroupBy sets the GROUP BY columns. +func (query SelectQuery) GroupBy(columns ...string) SelectQuery { + query.groupBy = slices.Clone(columns) + + return query +} + +// OrderBy sets the ORDER BY clauses. +func (query SelectQuery) OrderBy(clauses ...string) SelectQuery { + query.orderBy = slices.Clone(clauses) + + return query +} + +// Limit sets the maximum number of rows to return. +// A value of 0 means no limit. +func (query SelectQuery) Limit(limit int) SelectQuery { + query.limit = limit + + return query +} + +// Offset sets the number of rows to skip. +// A value of 0 means no offset. +func (query SelectQuery) Offset(offset int) SelectQuery { + query.offset = offset + + return query +} + +// Build produces the SQL query string and its positional arguments. +func (query SelectQuery) Build() (string, []any) { + var builder strings.Builder + var args []any + + builder.WriteString("SELECT ") + + if len(query.columns) == 0 { + builder.WriteByte('*') + } else { + for i, col := range query.columns { + if i > 0 { + builder.WriteString(", ") + } + + renderColumn(&builder, &args, col) + } + } + + builder.WriteString(" FROM ") + builder.WriteString(query.table) + + for _, join := range query.joins { + builder.WriteString(" JOIN ") + builder.WriteString(join) + } + + query.where.render(&builder, &args) + + if len(query.groupBy) > 0 { + builder.WriteString(" GROUP BY ") + builder.WriteString(strings.Join(query.groupBy, ", ")) + } + + if len(query.orderBy) > 0 { + builder.WriteString(" ORDER BY ") + builder.WriteString(strings.Join(query.orderBy, ", ")) + } + + if query.limit > 0 { + builder.WriteString(" LIMIT ") + builder.WriteString(strconv.Itoa(query.limit)) + } + + if query.offset > 0 { + builder.WriteString(" OFFSET ") + builder.WriteString(strconv.Itoa(query.offset)) + } + + return builder.String(), args +} diff --git a/query/select_test.go b/query/select_test.go new file mode 100644 index 0000000..3ed3140 --- /dev/null +++ b/query/select_test.go @@ -0,0 +1,224 @@ +package query_test + +import ( + "testing" + + "github.com/studiolambda/cosmos/query" +) + +func TestSelectAll(t *testing.T) { + t.Parallel() + + sql, args := query.Select("users").Build() + + if sql != "SELECT * FROM users" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 0 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestSelectColumns(t *testing.T) { + t.Parallel() + + sql, args := query.Select("users"). + Columns("id", "name", "email"). + Build() + + if sql != "SELECT id, name, email FROM users" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 0 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestSelectColumnAppend(t *testing.T) { + t.Parallel() + + sql, _ := query.Select("users"). + Columns("id"). + Column("name"). + Build() + + if sql != "SELECT id, name FROM users" { + t.Errorf("unexpected sql: %s", sql) + } +} + +func TestSelectColumnRaw(t *testing.T) { + t.Parallel() + + sql, args := query.Select("orders"). + Columns("user_id", query.Raw("COUNT(*) AS total")). + Build() + + if sql != "SELECT user_id, COUNT(*) AS total FROM orders" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 0 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestSelectColumnExpr(t *testing.T) { + t.Parallel() + + sql, args := query.Select("orders"). + Column(query.Expr{SQL: "DATE_TRUNC(?, created_at) AS period", Args: []any{"month"}}). + Build() + + if sql != "SELECT DATE_TRUNC(?, created_at) AS period FROM orders" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 1 || args[0] != "month" { + t.Errorf("unexpected args: %v", args) + } +} + +func TestSelectWhere(t *testing.T) { + t.Parallel() + + sql, args := query.Select("users"). + Columns("id", "name"). + Where("active = ?", true). + Where("age > ?", 18). + Build() + + if sql != "SELECT id, name FROM users WHERE active = ? AND age > ?" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 2 || args[0] != true || args[1] != 18 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestSelectOrWhere(t *testing.T) { + t.Parallel() + + sql, args := query.Select("users"). + Where("role = ?", "admin"). + OrWhere("role = ?", "editor"). + Build() + + if sql != "SELECT * FROM users WHERE role = ? OR role = ?" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 2 || args[0] != "admin" || args[1] != "editor" { + t.Errorf("unexpected args: %v", args) + } +} + +func TestSelectJoin(t *testing.T) { + t.Parallel() + + sql, _ := query.Select("users"). + Columns("users.id", "posts.title"). + Join("posts ON posts.user_id = users.id"). + Build() + + expected := "SELECT users.id, posts.title FROM users JOIN posts ON posts.user_id = users.id" + + if sql != expected { + t.Errorf("unexpected sql: %s", sql) + } +} + +func TestSelectGroupBy(t *testing.T) { + t.Parallel() + + sql, _ := query.Select("orders"). + Columns("user_id", query.Raw("SUM(amount) AS total")). + GroupBy("user_id"). + Build() + + expected := "SELECT user_id, SUM(amount) AS total FROM orders GROUP BY user_id" + + if sql != expected { + t.Errorf("unexpected sql: %s", sql) + } +} + +func TestSelectOrderBy(t *testing.T) { + t.Parallel() + + sql, _ := query.Select("users"). + OrderBy("name ASC", "id DESC"). + Build() + + expected := "SELECT * FROM users ORDER BY name ASC, id DESC" + + if sql != expected { + t.Errorf("unexpected sql: %s", sql) + } +} + +func TestSelectLimitOffset(t *testing.T) { + t.Parallel() + + sql, _ := query.Select("users"). + Limit(10). + Offset(20). + Build() + + if sql != "SELECT * FROM users LIMIT 10 OFFSET 20" { + t.Errorf("unexpected sql: %s", sql) + } +} + +func TestSelectImmutability(t *testing.T) { + t.Parallel() + + base := query.Select("users").Columns("id", "name").Where("active = ?", true) + admins := base.Where("role = ?", "admin") + editors := base.Where("role = ?", "editor") + + sqlAdmins, argsAdmins := admins.Build() + sqlEditors, argsEditors := editors.Build() + + if sqlAdmins != "SELECT id, name FROM users WHERE active = ? AND role = ?" { + t.Errorf("unexpected admin sql: %s", sqlAdmins) + } + + if sqlEditors != "SELECT id, name FROM users WHERE active = ? AND role = ?" { + t.Errorf("unexpected editor sql: %s", sqlEditors) + } + + if argsAdmins[1] != "admin" { + t.Errorf("unexpected admin args: %v", argsAdmins) + } + + if argsEditors[1] != "editor" { + t.Errorf("unexpected editor args: %v", argsEditors) + } +} + +func TestSelectFull(t *testing.T) { + t.Parallel() + + sql, args := query.Select("users"). + Columns("users.id", "users.name", query.Raw("COUNT(posts.id) AS post_count")). + Join("posts ON posts.user_id = users.id"). + Where("users.active = ?", true). + GroupBy("users.id", "users.name"). + OrderBy("post_count DESC"). + Limit(5). + Build() + + expected := "SELECT users.id, users.name, COUNT(posts.id) AS post_count FROM users JOIN posts ON posts.user_id = users.id WHERE users.active = ? GROUP BY users.id, users.name ORDER BY post_count DESC LIMIT 5" + + if sql != expected { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 1 || args[0] != true { + t.Errorf("unexpected args: %v", args) + } +} diff --git a/query/update.go b/query/update.go new file mode 100644 index 0000000..a15f001 --- /dev/null +++ b/query/update.go @@ -0,0 +1,72 @@ +package query + +import ( + "slices" + "strings" +) + +// updatePair holds a column name and its associated value. +type updatePair struct { + column string + value any +} + +// UpdateQuery represents an immutable UPDATE query builder. +// Each method returns a new instance without modifying the original. +type UpdateQuery struct { + table string + pairs []updatePair + where whereClauses +} + +// Update creates a new [UpdateQuery] for the given table. +func Update(table string) UpdateQuery { + return UpdateQuery{table: table} +} + +// Set adds a column-value pair to the SET clause. The value can +// be a plain value (emits ?), [Raw] (emits literal SQL), or [Expr] +// (emits SQL with embedded placeholders). +func (query UpdateQuery) Set(column string, value any) UpdateQuery { + query.pairs = append(slices.Clone(query.pairs), updatePair{column: column, value: value}) + + return query +} + +// Where appends an AND WHERE condition. +func (query UpdateQuery) Where(sql string, args ...any) UpdateQuery { + query.where = query.where.add("AND", sql, args) + + return query +} + +// OrWhere appends an OR WHERE condition. +func (query UpdateQuery) OrWhere(sql string, args ...any) UpdateQuery { + query.where = query.where.add("OR", sql, args) + + return query +} + +// Build produces the SQL query string and its positional arguments. +func (query UpdateQuery) Build() (string, []any) { + var builder strings.Builder + var args []any + + builder.WriteString("UPDATE ") + builder.WriteString(query.table) + builder.WriteString(" SET ") + + for i, pair := range query.pairs { + if i > 0 { + builder.WriteString(", ") + } + + builder.WriteString(pair.column) + builder.WriteString(" = ") + renderValue(&builder, &args, pair.value) + } + + query.where.render(&builder, &args) + + return builder.String(), args +} diff --git a/query/update_test.go b/query/update_test.go new file mode 100644 index 0000000..ee51dfc --- /dev/null +++ b/query/update_test.go @@ -0,0 +1,81 @@ +package query_test + +import ( + "testing" + + "github.com/studiolambda/cosmos/query" +) + +func TestUpdateBasic(t *testing.T) { + t.Parallel() + + sql, args := query.Update("users"). + Set("name", "Erik"). + Where("id = ?", 1). + Build() + + if sql != "UPDATE users SET name = ? WHERE id = ?" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 2 || args[0] != "Erik" || args[1] != 1 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestUpdateRaw(t *testing.T) { + t.Parallel() + + sql, args := query.Update("users"). + Set("name", "Erik"). + Set("updated_at", query.Raw("NOW()")). + Where("id = ?", 1). + Build() + + if sql != "UPDATE users SET name = ?, updated_at = NOW() WHERE id = ?" { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 2 || args[0] != "Erik" || args[1] != 1 { + t.Errorf("unexpected args: %v", args) + } +} + +func TestUpdateOrWhere(t *testing.T) { + t.Parallel() + + sql, args := query.Update("users"). + Set("active", false). + Where("banned = ?", true). + OrWhere("expired = ?", true). + Build() + + expected := "UPDATE users SET active = ? WHERE banned = ? OR expired = ?" + + if sql != expected { + t.Errorf("unexpected sql: %s", sql) + } + + if len(args) != 3 || args[0] != false || args[1] != true || args[2] != true { + t.Errorf("unexpected args: %v", args) + } +} + +func TestUpdateImmutability(t *testing.T) { + t.Parallel() + + base := query.Update("users").Set("updated_at", query.Raw("NOW()")) + first := base.Set("name", "A").Where("id = ?", 1) + second := base.Set("name", "B").Where("id = ?", 2) + + _, argsFirst := first.Build() + _, argsSecond := second.Build() + + if argsFirst[0] != "A" || argsFirst[1] != 1 { + t.Errorf("unexpected first args: %v", argsFirst) + } + + if argsSecond[0] != "B" || argsSecond[1] != 2 { + t.Errorf("unexpected second args: %v", argsSecond) + } +} diff --git a/query/where.go b/query/where.go new file mode 100644 index 0000000..4eea010 --- /dev/null +++ b/query/where.go @@ -0,0 +1,52 @@ +package query + +import ( + "slices" + "strings" +) + +// clause represents a single WHERE condition with its connector (AND/OR). +type clause struct { + connector string + sql string + args []any +} + +// whereClauses is the shared WHERE clause state used by SELECT, UPDATE, and DELETE. +type whereClauses []clause + +// clone returns a copy of the where clauses. +func (clauses whereClauses) clone() whereClauses { + return slices.Clone(clauses) +} + +// add appends a new clause with the given connector. +func (clauses whereClauses) add(connector string, sql string, args []any) whereClauses { + result := clauses.clone() + + return append(result, clause{ + connector: connector, + sql: sql, + args: args, + }) +} + +// render writes the WHERE clause to the builder and appends args. +func (clauses whereClauses) render(builder *strings.Builder, args *[]any) { + if len(clauses) == 0 { + return + } + + builder.WriteString(" WHERE ") + + for i, c := range clauses { + if i > 0 { + builder.WriteByte(' ') + builder.WriteString(c.connector) + builder.WriteByte(' ') + } + + builder.WriteString(c.sql) + *args = append(*args, c.args...) + } +}