diff --git a/.github/workflows/ci-main.yml b/.github/workflows/ci-main.yml index e3493de604e..b468c81c271 100644 --- a/.github/workflows/ci-main.yml +++ b/.github/workflows/ci-main.yml @@ -54,7 +54,7 @@ jobs: # Service containers to run with `code-test` services: # Etcd service. - # docker run -d --name etcd -p 2379:2379 -e ALLOW_NONE_AUTHENTICATION=yes bitnamilegacy/etcd:3.4.24 + # docker run -p 2379:2379 -e ALLOW_NONE_AUTHENTICATION=yes bitnamilegacy/etcd:3.4.24 etcd: image: bitnamilegacy/etcd:3.4.24 env: @@ -75,7 +75,7 @@ jobs: - 6379:6379 # MySQL backend server. - # docker run -d --name mysql \ + # docker run \ # -p 3306:3306 \ # -e MYSQL_DATABASE=test \ # -e MYSQL_ROOT_PASSWORD=12345678 \ @@ -89,7 +89,7 @@ jobs: - 3306:3306 # MariaDb backend server. - # docker run -d --name mariadb \ + # docker run \ # -p 3307:3306 \ # -e MYSQL_DATABASE=test \ # -e MYSQL_ROOT_PASSWORD=12345678 \ @@ -103,7 +103,7 @@ jobs: - 3307:3306 # PostgreSQL backend server. - # docker run -d --name postgres \ + # docker run \ # -p 5432:5432 \ # -e POSTGRES_PASSWORD=12345678 \ # -e POSTGRES_USER=postgres \ @@ -150,7 +150,7 @@ jobs: --health-retries 10 # ClickHouse backend server. - # docker run -d --name clickhouse \ + # docker run \ # -p 9000:9000 -p 8123:8123 -p 9001:9001 \ # clickhouse/clickhouse-server:24.11.1.2557-alpine clickhouse-server: @@ -161,7 +161,7 @@ jobs: - 9001:9001 # Polaris backend server. - # docker run -d --name polaris \ + # docker run \ # -p 8090:8090 -p 8091:8091 -p 8093:8093 -p 9090:9090 -p 9091:9091 \ # polarismesh/polaris-standalone:v1.17.2 polaris: diff --git a/cmd/gf/internal/cmd/gendao/gendao.go b/cmd/gf/internal/cmd/gendao/gendao.go index 54204755e0e..bc6ad943ac5 100644 --- a/cmd/gf/internal/cmd/gendao/gendao.go +++ b/cmd/gf/internal/cmd/gendao/gendao.go @@ -104,6 +104,10 @@ var ( "smallmoney": { Type: "float64", }, + "uuid": { + Type: "uuid.UUID", + Import: "github.com/google/uuid", + }, } // tablewriter Options diff --git a/contrib/drivers/README.MD b/contrib/drivers/README.MD index c10d58d66e1..1adc943c0a7 100644 --- a/contrib/drivers/README.MD +++ b/contrib/drivers/README.MD @@ -9,7 +9,7 @@ Let's take `mysql` for example. ```shell go get github.com/gogf/gf/contrib/drivers/mysql/v2@latest -# Easy to copy +# Easy for copying: go get github.com/gogf/gf/contrib/drivers/clickhouse/v2@latest go get github.com/gogf/gf/contrib/drivers/dm/v2@latest go get github.com/gogf/gf/contrib/drivers/mssql/v2@latest @@ -57,7 +57,7 @@ import _ "github.com/gogf/gf/contrib/drivers/sqlite/v2" #### cgo version -When the target is a 32-bit Windows system, the cgo version needs to be used. +When the target is a `32-bit` Windows system, the `cgo` version needs to be used. ```go import _ "github.com/gogf/gf/contrib/drivers/sqlitecgo/v2" @@ -69,10 +69,6 @@ import _ "github.com/gogf/gf/contrib/drivers/sqlitecgo/v2" import _ "github.com/gogf/gf/contrib/drivers/pgsql/v2" ``` -Note: - -- It does not support `Replace` features. - ### SQL Server ```go @@ -81,9 +77,10 @@ import _ "github.com/gogf/gf/contrib/drivers/mssql/v2" Note: -- It does not support `Replace` features. +- `InsertIgnore` returns error if there is no primary key or unique index submitted with record. - It supports server version >= `SQL Server2005` -- It ONLY supports datetime2 and datetimeoffset types for auto handling created_at/updated_at/deleted_at columns, because datetime type does not support microseconds precision when column value is passed as string. +- It ONLY supports `datetime2` and `datetimeoffset` types for auto handling created_at/updated_at/deleted_at columns, + because datetime type does not support microseconds precision when column value is passed as string. ### Oracle @@ -93,8 +90,8 @@ import _ "github.com/gogf/gf/contrib/drivers/oracle/v2" Note: -- It does not support `Replace` features. - It does not support `LastInsertId`. +- `InsertIgnore` returns error if there is no primary key or unique index submitted with record. ### ClickHouse @@ -104,7 +101,7 @@ import _ "github.com/gogf/gf/contrib/drivers/clickhouse/v2" Note: -- It does not support `InsertIgnore/InsertGetId` features. +- It does not support `InsertIgnore/InsertAndGetId` features. - It does not support `Save/Replace` features. - It does not support `Transaction` feature. - It does not support `RowsAffected` feature. @@ -115,6 +112,10 @@ Note: import _ "github.com/gogf/gf/contrib/drivers/dm/v2" ``` +Note: + +- `InsertIgnore` returns error if there is no primary key or unique index submitted with record. + ## Custom Drivers It's quick and easy, please refer to current driver source. diff --git a/contrib/drivers/clickhouse/clickhouse_do_insert.go b/contrib/drivers/clickhouse/clickhouse_do_insert.go index a6c397ae31f..a7127691395 100644 --- a/contrib/drivers/clickhouse/clickhouse_do_insert.go +++ b/contrib/drivers/clickhouse/clickhouse_do_insert.go @@ -16,6 +16,7 @@ import ( ) // DoInsert inserts or updates data for given table. +// The list parameter must contain at least one record, which was previously validated. func (d *Driver) DoInsert( ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { diff --git a/contrib/drivers/dm/dm_do_insert.go b/contrib/drivers/dm/dm_do_insert.go index 538340ffaaa..72fc540b4f9 100644 --- a/contrib/drivers/dm/dm_do_insert.go +++ b/contrib/drivers/dm/dm_do_insert.go @@ -20,6 +20,7 @@ import ( ) // DoInsert inserts or updates data for given table. +// The list parameter must contain at least one record, which was previously validated. func (d *Driver) DoInsert( ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { @@ -36,6 +37,12 @@ func (d *Driver) DoInsert( return d.doInsertIgnore(ctx, link, table, list, option) default: + // DM database supports IDENTITY auto-increment columns natively. + // The driver automatically returns LastInsertId through sql.Result. + // + // Note: DM IDENTITY columns cannot accept explicit ID values unless + // IDENTITY_INSERT is enabled. When using tables with IDENTITY columns, + // avoid providing explicit ID values in the data. return d.Core.DoInsert(ctx, link, table, list, option) } } @@ -60,16 +67,12 @@ func (d *Driver) doInsertIgnore(ctx context.Context, // When withUpdate is false, it performs insert ignore (insert only when no conflict). func (d *Driver) doMergeInsert( ctx context.Context, - link gdb.Link, - table string, - list gdb.List, - option gdb.DoInsertOption, - withUpdate bool, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, withUpdate bool, ) (result sql.Result, err error) { // If OnConflict is not specified, automatically get the primary key of the table conflictKeys := option.OnConflict if len(conflictKeys) == 0 { - conflictKeys, err = d.getPrimaryKeys(ctx, table) + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) if err != nil { return nil, gerror.WrapCode( gcode.CodeInternalError, @@ -77,29 +80,34 @@ func (d *Driver) doMergeInsert( `failed to get primary keys for table`, ) } - if len(conflictKeys) == 0 { - return nil, gerror.NewCode( + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + for dataKey := range list[0] { + if strings.EqualFold(dataKey, primaryKey) { + foundPrimaryKey = true + break + } + } + if foundPrimaryKey { + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( gcode.CodeMissingParameter, - `Please specify conflict columns or ensure the table has a primary key`, + `Replace/Save/InsertIgnore operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, ) } - } - - if len(list) == 0 { - opName := "Save" - if !withUpdate { - opName = "InsertIgnore" - } - return nil, gerror.NewCodef( - gcode.CodeInvalidRequest, `%s operation list is empty by dm driver`, opName, - ) + // TODO consider composite primary keys. + conflictKeys = primaryKeys } var ( - one = list[0] - oneLen = len(one) - charL, charR = d.GetChars() - + one = list[0] + oneLen = len(one) + charL, charR = d.GetChars() conflictKeySet = gset.New(false) // queryHolders: Handle data with Holder that need to be merged @@ -155,24 +163,6 @@ func (d *Driver) doMergeInsert( return batchResult, nil } -// getPrimaryKeys retrieves the primary key field names of the table as a slice of strings. -// This method extracts primary key information from TableFields. -func (d *Driver) getPrimaryKeys(ctx context.Context, table string) ([]string, error) { - tableFields, err := d.TableFields(ctx, table) - if err != nil { - return nil, err - } - - var primaryKeys []string - for _, field := range tableFields { - if field.Key == "PRI" { - primaryKeys = append(primaryKeys, field.Name) - } - } - - return primaryKeys, nil -} - // parseSqlForMerge generates MERGE statement for DM database. // When updateValues is empty, it only inserts (INSERT IGNORE behavior). // When updateValues is provided, it performs upsert (INSERT or UPDATE). diff --git a/contrib/drivers/dm/dm_z_unit_basic_test.go b/contrib/drivers/dm/dm_z_unit_basic_test.go index da9ab215b88..be8a5394578 100644 --- a/contrib/drivers/dm/dm_z_unit_basic_test.go +++ b/contrib/drivers/dm/dm_z_unit_basic_test.go @@ -7,7 +7,6 @@ package dm_test import ( - "database/sql" "fmt" "strings" "testing" @@ -509,124 +508,3 @@ func Test_Empty_Slice_Argument(t *testing.T) { t.Assert(len(result), 0) }) } - -func TestModelSave(t *testing.T) { - table := createTable() - defer dropTable(table) - gtest.C(t, func(t *gtest.T) { - type User struct { - Id int - AccountName string - AttrIndex int - } - var ( - user User - count int - result sql.Result - err error - ) - - result, err = db.Model(table).Data(g.Map{ - "id": 1, - "accountName": "ac1", - "attrIndex": 100, - }).OnConflict("id").Save() - - t.AssertNil(err) - n, _ := result.RowsAffected() - t.Assert(n, 1) - - err = db.Model(table).Scan(&user) - t.AssertNil(err) - t.Assert(user.Id, 1) - t.Assert(user.AccountName, "ac1") - t.Assert(user.AttrIndex, 100) - - _, err = db.Model(table).Data(g.Map{ - "id": 1, - "accountName": "ac2", - "attrIndex": 200, - }).OnConflict("id").Save() - t.AssertNil(err) - - err = db.Model(table).Scan(&user) - t.AssertNil(err) - t.Assert(user.AccountName, "ac2") - t.Assert(user.AttrIndex, 200) - - count, err = db.Model(table).Count() - t.AssertNil(err) - t.Assert(count, 1) - }) -} - -func TestModelInsert(t *testing.T) { - // g.Model.insert not lost default not null column - table := "A_tables" - createInitTable(table) - gtest.C(t, func(t *gtest.T) { - i := 200 - data := User{ - ID: int64(i), - AccountName: fmt.Sprintf(`A%dtwo`, i), - PwdReset: 0, - AttrIndex: 99, - CreatedTime: time.Now(), - UpdatedTime: time.Now(), - } - // _, err := db.Schema(TestDBName).Model(table).Data(data).Insert() - _, err := db.Model(table).Insert(&data) - gtest.AssertNil(err) - }) - - gtest.C(t, func(t *gtest.T) { - i := 201 - data := User{ - ID: int64(i), - AccountName: fmt.Sprintf(`A%dtwoONE`, i), - PwdReset: 1, - CreatedTime: time.Now(), - AttrIndex: 98, - UpdatedTime: time.Now(), - } - // _, err := db.Schema(TestDBName).Model(table).Data(data).Insert() - _, err := db.Model(table).Data(&data).Insert() - gtest.AssertNil(err) - }) -} - -func Test_Model_InsertIgnore(t *testing.T) { - table := createInitTable() - defer dropTable(table) - - // db.SetDebug(true) - - gtest.C(t, func(t *gtest.T) { - data := User{ - ID: int64(666), - AccountName: fmt.Sprintf(`name_%d`, 666), - PwdReset: 0, - AttrIndex: 99, - CreatedTime: time.Now(), - UpdatedTime: time.Now(), - } - _, err := db.Model(table).Data(data).Insert() - t.AssertNil(err) - }) - gtest.C(t, func(t *gtest.T) { - data := User{ - ID: int64(666), - AccountName: fmt.Sprintf(`name_%d`, 777), - PwdReset: 0, - AttrIndex: 99, - CreatedTime: time.Now(), - UpdatedTime: time.Now(), - } - _, err := db.Model(table).Data(data).InsertIgnore() - t.AssertNil(err) - - one, err := db.Model(table).Where("id", 666).One() - t.AssertNil(err) - t.Assert(one["ACCOUNT_NAME"].String(), "name_666") - }) -} diff --git a/contrib/drivers/dm/dm_z_unit_init_test.go b/contrib/drivers/dm/dm_z_unit_init_test.go index 100329a7044..30c8aca8589 100644 --- a/contrib/drivers/dm/dm_z_unit_init_test.go +++ b/contrib/drivers/dm/dm_z_unit_init_test.go @@ -220,3 +220,33 @@ func createInitTables(len int) []string { } return tables } + +// createTableWithIdentity creates a table with IDENTITY column for LastInsertId testing +func createTableWithIdentity(table ...string) (name string) { + if len(table) > 0 { + name = table[0] + } else { + name = fmt.Sprintf("random_%d", gtime.Timestamp()) + } + + dropTable(name) + + if _, err := db.Exec(ctx, fmt.Sprintf(` + CREATE TABLE "%s" +( +"ID" BIGINT IDENTITY(1, 1) NOT NULL, +"ACCOUNT_NAME" VARCHAR(128) DEFAULT '' NOT NULL COMMENT 'Account Name', +"PWD_RESET" TINYINT DEFAULT 0 NOT NULL, +"ENABLED" INT DEFAULT 1 NOT NULL, +"DELETED" INT DEFAULT 0 NOT NULL, +"ATTR_INDEX" INT DEFAULT 0 , +"CREATED_BY" VARCHAR(32) DEFAULT '' NOT NULL, +"CREATED_TIME" TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP() NOT NULL, +"UPDATED_BY" VARCHAR(32) DEFAULT '' NOT NULL, +"UPDATED_TIME" TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP() NOT NULL, +NOT CLUSTER PRIMARY KEY("ID")) STORAGE(ON "MAIN", CLUSTERBTR) ; + `, name)); err != nil { + gtest.Fatal(err) + } + return +} diff --git a/contrib/drivers/dm/dm_z_unit_model_test.go b/contrib/drivers/dm/dm_z_unit_model_test.go new file mode 100644 index 00000000000..80f04010b78 --- /dev/null +++ b/contrib/drivers/dm/dm_z_unit_model_test.go @@ -0,0 +1,185 @@ +// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package dm_test + +import ( + "database/sql" + "fmt" + "testing" + "time" + + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/test/gtest" +) + +func Test_Model_Save(t *testing.T) { + table := createTableWithIdentity() + defer dropTable(table) + gtest.C(t, func(t *gtest.T) { + type User struct { + Id int + AccountName string + AttrIndex int + } + var ( + user User + count int + result sql.Result + err error + ) + + // First insert: let IDENTITY auto-generate ID - use Insert() instead of Save() + // because Save() requires a primary key in the data for conflict detection + result, err = db.Model(table).Data(g.Map{ + "accountName": "ac1", + "attrIndex": 100, + }).Insert() + + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.AssertGT(user.Id, 0) // ID should be auto-generated + t.Assert(user.AccountName, "ac1") + t.Assert(user.AttrIndex, 100) + + // Second save: update the existing record using the generated ID + _, err = db.Model(table).Data(g.Map{ + "id": user.Id, + "accountName": "ac2", + "attrIndex": 200, + }).OnConflict("id").Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.AccountName, "ac2") + t.Assert(user.AttrIndex, 200) + + _, err = db.Model(table).Data(g.Map{ + "id": user.Id, + "accountName": "ac2", + "attrIndex": 2000, + }).Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.AccountName, "ac2") + t.Assert(user.AttrIndex, 2000) + + count, err = db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) + }) +} + +func Test_Model_Insert(t *testing.T) { + // g.Model.insert not lost default not null column + table := "A_tables" + createInitTable(table) + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + i := 200 + data := User{ + ID: int64(i), + AccountName: fmt.Sprintf(`A%dtwo`, i), + PwdReset: 0, + AttrIndex: 99, + CreatedTime: time.Now(), + UpdatedTime: time.Now(), + } + result, err := db.Model(table).Insert(&data) + gtest.AssertNil(err) + n, err := result.RowsAffected() + gtest.AssertNil(err) + gtest.Assert(n, 1) + }) + + gtest.C(t, func(t *gtest.T) { + i := 201 + data := User{ + ID: int64(i), + AccountName: fmt.Sprintf(`A%dtwoONE`, i), + PwdReset: 1, + CreatedTime: time.Now(), + AttrIndex: 98, + UpdatedTime: time.Now(), + } + result, err := db.Model(table).Data(&data).Insert() + gtest.AssertNil(err) + n, err := result.RowsAffected() + gtest.AssertNil(err) + gtest.Assert(n, 1) + }) +} + +func Test_Model_InsertIgnore(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // db.SetDebug(true) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "account_name": fmt.Sprintf(`name_%d`, 777), + "pwd_reset": 0, + "attr_index": 777, + "created_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNil(err) + + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["ACCOUNT_NAME"].String(), "name_1") + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + // "id": 1, + "account_name": fmt.Sprintf(`name_%d`, 777), + "pwd_reset": 0, + "attr_index": 777, + "created_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNE(err, nil) + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) +} + +func Test_Model_InsertAndGetId(t *testing.T) { + table := createTableWithIdentity() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + // "id": 1, + "account_name": fmt.Sprintf(`name_%d`, 1), + "pwd_reset": 0, + "attr_index": 1, + "created_time": gtime.Now(), + } + lastId, err := db.Model(table).Data(data).InsertAndGetId() + t.AssertNil(err) + t.AssertGT(lastId, 0) + }) + +} diff --git a/contrib/drivers/mssql/mssql.go b/contrib/drivers/mssql/mssql.go index 5be217ce002..a8f443e7da8 100644 --- a/contrib/drivers/mssql/mssql.go +++ b/contrib/drivers/mssql/mssql.go @@ -4,11 +4,7 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. -// Package mssql implements gdb.Driver, which supports operations for database MSSql. -// -// Note: -// 1. It does not support Replace features. -// 2. It does not support LastInsertId. +// Package mssql implements gdb.Driver, which supports operations for MSSQL. package mssql import ( diff --git a/contrib/drivers/mssql/mssql_do_exec.go b/contrib/drivers/mssql/mssql_do_exec.go index b8b0142447a..8cf90c4841b 100644 --- a/contrib/drivers/mssql/mssql_do_exec.go +++ b/contrib/drivers/mssql/mssql_do_exec.go @@ -1,3 +1,9 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + package mssql import ( @@ -87,12 +93,16 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args IsTransaction: link.IsTransaction(), }) if err != nil { - return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err + return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err } stdSqlResult := out.Records if len(stdSqlResult) == 0 { - err = gerror.WrapCode(gcode.CodeDbOperationError, gerror.New("affectcount is zero"), `sql.Result.RowsAffected failed`) - return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err + err = gerror.WrapCode( + gcode.CodeDbOperationError, + gerror.New("affected count is zero"), + `sql.Result.RowsAffected failed`, + ) + return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err } // For batch insert, OUTPUT clause returns one row per inserted row. // So the rowsAffected should be the count of returned records. @@ -100,7 +110,7 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args // get last_insert_id from the first returned row lastInsertId := stdSqlResult[0].GMap().GetVar(lastInsertIdFieldAlias).Int64() - return &InsertResult{lastInsertId: lastInsertId, rowsAffected: rowsAffected}, err + return &Result{lastInsertId: lastInsertId, rowsAffected: rowsAffected}, err } // GetTableNameFromSql get table name from sql statement @@ -111,17 +121,19 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args // "user as u". func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) { // INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?) - leftChars, rightChars := d.GetChars() - trimStr := leftChars + rightChars + "[] " - pattern := "INTO(.+?)\\(" - regCompile := regexp.MustCompile(pattern) - tableInfo := regCompile.FindStringSubmatch(sqlStr) + var ( + leftChars, rightChars = d.GetChars() + trimStr = leftChars + rightChars + "[] " + pattern = "INTO(.+?)\\(" + regCompile = regexp.MustCompile(pattern) + tableInfo = regCompile.FindStringSubmatch(sqlStr) + ) // get the first one. after the first it may be content of the value, it's not table name. table = tableInfo[1] table = strings.Trim(table, " ") if strings.Contains(table, ".") { tmpAry := strings.Split(table, ".") - // the last one is tablename + // the last one is table name table = tmpAry[len(tmpAry)-1] } else if strings.Contains(table, "as") || strings.Contains(table, " ") { tmpAry := strings.Split(table, "as") @@ -151,24 +163,9 @@ func (l *txLinkMssql) IsOnMaster() bool { return true } -// InsertResult instance of sql.Result -type InsertResult struct { - lastInsertId int64 - rowsAffected int64 - err error -} - -func (r *InsertResult) LastInsertId() (int64, error) { - return r.lastInsertId, r.err -} - -func (r *InsertResult) RowsAffected() (int64, error) { - return r.rowsAffected, r.err -} - // GetInsertOutputSql gen get last_insert_id code -func (m *Driver) GetInsertOutputSql(ctx context.Context, table string) string { - fds, errFd := m.GetDB().TableFields(ctx, table) +func (d *Driver) GetInsertOutputSql(ctx context.Context, table string) string { + fds, errFd := d.GetDB().TableFields(ctx, table) if errFd != nil { return "" } diff --git a/contrib/drivers/mssql/mssql_do_insert.go b/contrib/drivers/mssql/mssql_do_insert.go index 5e467d73069..93bc17cfa81 100644 --- a/contrib/drivers/mssql/mssql_do_insert.go +++ b/contrib/drivers/mssql/mssql_do_insert.go @@ -20,51 +20,95 @@ import ( ) // DoInsert inserts or updates data for given table. -func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { +// The list parameter must contain at least one record, which was previously validated. +func (d *Driver) DoInsert( + ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { switch option.InsertOption { case gdb.InsertOptionSave: return d.doSave(ctx, link, table, list, option) case gdb.InsertOptionReplace: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Replace operation is not supported by mssql driver`, - ) + // MSSQL does not support REPLACE INTO syntax, use SAVE instead. + return d.doSave(ctx, link, table, list, option) + + case gdb.InsertOptionIgnore: + // MSSQL does not support INSERT IGNORE syntax, use MERGE instead. + return d.doInsertIgnore(ctx, link, table, list, option) default: return d.Core.DoInsert(ctx, link, table, list, option) } } -// doSave support upsert for SQL server +// doSave support upsert for MSSQL func (d *Driver) doSave(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { - if len(option.OnConflict) == 0 { - return nil, gerror.NewCode( - gcode.CodeMissingParameter, `Please specify conflict columns`, - ) - } + return d.doMergeInsert(ctx, link, table, list, option, true) +} - if len(list) == 0 { - return nil, gerror.NewCode( - gcode.CodeInvalidRequest, `Save operation list is empty by mssql driver`, - ) +// doInsertIgnore implements INSERT IGNORE operation using MERGE statement for MSSQL database. +// It only inserts records when there's no conflict on primary/unique keys. +func (d *Driver) doInsertIgnore(ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { + return d.doMergeInsert(ctx, link, table, list, option, false) +} + +// doMergeInsert implements MERGE-based insert operations for MSSQL database. +// When withUpdate is true, it performs upsert (insert or update). +// When withUpdate is false, it performs insert ignore (insert only when no conflict). +func (d *Driver) doMergeInsert( + ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, withUpdate bool, +) (result sql.Result, err error) { + // If OnConflict is not specified, automatically get the primary key of the table + conflictKeys := option.OnConflict + if len(conflictKeys) == 0 { + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) + if err != nil { + return nil, gerror.WrapCode( + gcode.CodeInternalError, + err, + `failed to get primary keys for table`, + ) + } + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + for dataKey := range list[0] { + if strings.EqualFold(dataKey, primaryKey) { + foundPrimaryKey = true + break + } + } + if foundPrimaryKey { + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( + gcode.CodeMissingParameter, + `Replace/Save/InsertIgnore operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, + ) + } + // TODO consider composite primary keys. + conflictKeys = primaryKeys } var ( - one = list[0] - oneLen = len(one) - charL, charR = d.GetChars() - - conflictKeys = option.OnConflict + one = list[0] + oneLen = len(one) + charL, charR = d.GetChars() conflictKeySet = gset.New(false) - // queryHolders: Handle data with Holder that need to be upsert - // queryValues: Handle data that need to be upsert + // queryHolders: Handle data with Holder that need to be merged + // queryValues: Handle data that need to be merged // insertKeys: Handle valid keys that need to be inserted // insertValues: Handle values that need to be inserted - // updateValues: Handle values that need to be updated + // updateValues: Handle values that need to be updated (only when withUpdate=true) queryHolders = make([]string, oneLen) queryValues = make([]any, oneLen) insertKeys = make([]string, oneLen) @@ -84,9 +128,9 @@ func (d *Driver) doSave(ctx context.Context, insertKeys[index] = charL + key + charR insertValues[index] = "T2." + charL + key + charR - // filter conflict keys in updateValues. - // And the key is not a soft created field. - if !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { + // Build updateValues only when withUpdate is true + // Filter conflict keys and soft created fields from updateValues + if withUpdate && !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { updateValues = append( updateValues, fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR), @@ -95,8 +139,10 @@ func (d *Driver) doSave(ctx context.Context, index++ } - batchResult := new(gdb.SqlResult) - sqlStr := parseSqlForUpsert(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + var ( + batchResult = new(gdb.SqlResult) + sqlStr = parseSqlForMerge(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + ) r, err := d.DoExec(ctx, link, sqlStr, queryValues...) if err != nil { return r, err @@ -110,41 +156,48 @@ func (d *Driver) doSave(ctx context.Context, return batchResult, nil } -// parseSqlForUpsert -// MERGE INTO {{table}} T1 -// USING ( VALUES( {{queryHolders}}) T2 ({{insertKeyStr}}) -// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...) -// WHEN NOT MATCHED THEN -// INSERT {{insertKeys}} VALUES {{insertValues}} -// WHEN MATCHED THEN -// UPDATE SET {{updateValues}} -func parseSqlForUpsert(table string, +// parseSqlForMerge generates MERGE statement for MSSQL database. +// When updateValues is empty, it only inserts (INSERT IGNORE behavior). +// When updateValues is provided, it performs upsert (INSERT or UPDATE). +// Examples: +// - INSERT IGNORE: MERGE INTO table T1 USING (...) T2 ON (...) WHEN NOT MATCHED THEN INSERT(...) VALUES (...) +// - UPSERT: MERGE INTO table T1 USING (...) T2 ON (...) WHEN NOT MATCHED THEN INSERT(...) VALUES (...) WHEN MATCHED THEN UPDATE SET ... +func parseSqlForMerge(table string, queryHolders, insertKeys, insertValues, updateValues, duplicateKey []string, ) (sqlStr string) { var ( queryHolderStr = strings.Join(queryHolders, ",") insertKeyStr = strings.Join(insertKeys, ",") insertValueStr = strings.Join(insertValues, ",") - updateValueStr = strings.Join(updateValues, ",") duplicateKeyStr string - pattern = gstr.Trim(`MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`) ) + // Build ON condition for index, keys := range duplicateKey { if index != 0 { duplicateKeyStr += " AND " } - duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys) - duplicateKeyStr += duplicateTmp + duplicateKeyStr += fmt.Sprintf("T1.%s = T2.%s", keys, keys) } - return fmt.Sprintf(pattern, - table, - queryHolderStr, - insertKeyStr, - duplicateKeyStr, - insertKeyStr, - insertValueStr, - updateValueStr, + // Build SQL based on whether UPDATE is needed + pattern := gstr.Trim( + `MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s)`, ) + if len(updateValues) > 0 { + // Upsert: INSERT or UPDATE + pattern += gstr.Trim(` WHEN MATCHED THEN UPDATE SET %s`) + return fmt.Sprintf( + pattern+";", + table, + queryHolderStr, + insertKeyStr, + duplicateKeyStr, + insertKeyStr, + insertValueStr, + strings.Join(updateValues, ","), + ) + } + // Insert Ignore: INSERT only + return fmt.Sprintf(pattern+";", table, queryHolderStr, insertKeyStr, duplicateKeyStr, insertKeyStr, insertValueStr) } diff --git a/contrib/drivers/mssql/mssql_result.go b/contrib/drivers/mssql/mssql_result.go new file mode 100644 index 00000000000..57f3f41a9ab --- /dev/null +++ b/contrib/drivers/mssql/mssql_result.go @@ -0,0 +1,22 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package mssql + +// Result instance of sql.Result +type Result struct { + lastInsertId int64 + rowsAffected int64 + err error +} + +func (r *Result) LastInsertId() (int64, error) { + return r.lastInsertId, r.err +} + +func (r *Result) RowsAffected() (int64, error) { + return r.rowsAffected, r.err +} diff --git a/contrib/drivers/mssql/mssql_z_unit_basic_test.go b/contrib/drivers/mssql/mssql_z_unit_basic_test.go index 0999a1eeee8..49635776df6 100644 --- a/contrib/drivers/mssql/mssql_z_unit_basic_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_basic_test.go @@ -138,15 +138,17 @@ func TestDoInsert(t *testing.T) { i := 10 data := g.Map{ - "id": i, + // "id": i, "passport": fmt.Sprintf(`t%d`, i), "password": fmt.Sprintf(`p%d`, i), "nickname": fmt.Sprintf(`T%d`, i), "create_time": gtime.Now(), } + // Save without OnConflict should fail (missing conflict columns) _, err := db.Save(context.Background(), "t_user", data, 10) gtest.AssertNE(err, nil) + // Replace should fail because primary key 'id' is not in the data _, err = db.Replace(context.Background(), "t_user", data, 10) gtest.AssertNE(err, nil) }) diff --git a/contrib/drivers/mssql/mssql_z_unit_model_test.go b/contrib/drivers/mssql/mssql_z_unit_model_test.go index b3a1daa81db..1e49e5f5c76 100644 --- a/contrib/drivers/mssql/mssql_z_unit_model_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_model_test.go @@ -117,6 +117,48 @@ func Test_Model_Insert(t *testing.T) { }) } +func Test_Model_InsertIgnore(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // db.SetDebug(true) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": fmt.Sprintf(`t%d`, 777), + "password": fmt.Sprintf(`p%d`, 777), + "nickname": fmt.Sprintf(`T%d`, 777), + "create_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNil(err) + + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "user_1") + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "passport": fmt.Sprintf(`t%d`, 777), + "password": fmt.Sprintf(`p%d`, 777), + "nickname": fmt.Sprintf(`T%d`, 777), + "create_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNE(err, nil) + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) +} + func Test_Model_Insert_KeyFieldNameMapping(t *testing.T) { table := createTable() defer dropTable(table) @@ -2658,14 +2700,53 @@ func Test_Model_Replace(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + // Insert initial record + result, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t1", + "password": "pass1", + "nickname": "T1", + "create_time": "2018-10-24 10:00:00", + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + // Replace with new data (should update existing record using MERGE) + result, err = db.Model(table).Data(g.Map{ "id": 1, "passport": "t11", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T11", "create_time": "2018-10-24 10:00:00", }).Replace() - t.Assert(err, "Replace operation is not supported by mssql driver") + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify the data was replaced + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t11") + t.Assert(one["NICKNAME"].String(), "T11") + + // Replace with non-existing record (should insert new record) + result, err = db.Model(table).Data(g.Map{ + "id": 2, + "passport": "t222", + "password": "pass2", + "nickname": "T222", + "create_time": "2018-10-24 11:00:00", + }).Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) // MERGE reports: 1 for insert + + // Verify the new record was inserted + one, err = db.Model(table).WherePri(2).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t222") + t.Assert(one["NICKNAME"].String(), "T222") }) } diff --git a/contrib/drivers/oracle/oracle.go b/contrib/drivers/oracle/oracle.go index ae6a1c5ddc1..51aa4df2515 100644 --- a/contrib/drivers/oracle/oracle.go +++ b/contrib/drivers/oracle/oracle.go @@ -5,10 +5,6 @@ // You can obtain one at https://github.com/gogf/gf. // Package oracle implements gdb.Driver, which supports operations for database Oracle. -// -// Note: -// 1. It does not support Save/Replace features. -// 2. It does not support LastInsertId. package oracle import ( diff --git a/contrib/drivers/oracle/oracle_do_exec.go b/contrib/drivers/oracle/oracle_do_exec.go new file mode 100644 index 00000000000..d7fe4a39dd4 --- /dev/null +++ b/contrib/drivers/oracle/oracle_do_exec.go @@ -0,0 +1,120 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package oracle + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" +) + +const ( + returningClause = " RETURNING %s INTO ?" +) + +// DoExec commits the sql string and its arguments to underlying driver +// through given link object and returns the execution result. +// It handles INSERT statements specially to support LastInsertId. +func (d *Driver) DoExec( + ctx context.Context, link gdb.Link, sql string, args ...interface{}, +) (result sql.Result, err error) { + var ( + isUseCoreDoExec = true + primaryKey string + pkField gdb.TableField + ) + + // Transaction checks. + if link == nil { + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + link = tx + } else if link, err = d.MasterLink(); err != nil { + return nil, err + } + } else if !link.IsTransaction() { + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + link = tx + } + } + + // Check if it is an insert operation with primary key from context. + if value := ctx.Value(internalPrimaryKeyInCtx); value != nil { + if field, ok := value.(gdb.TableField); ok { + pkField = field + isUseCoreDoExec = false + } + } + + // Check if it is an INSERT statement with primary key. + if !isUseCoreDoExec && pkField.Name != "" && strings.Contains(strings.ToUpper(sql), "INSERT INTO") { + primaryKey = pkField.Name + // Oracle supports RETURNING clause to get the last inserted id + sql += fmt.Sprintf(returningClause, d.QuoteWord(primaryKey)) + } else { + // Use default DoExec for non-INSERT or no primary key scenarios + return d.Core.DoExec(ctx, link, sql, args...) + } + + // Only the insert operation with primary key can execute the following code + + // SQL filtering. + sql, args = d.FormatSqlBeforeExecuting(sql, args) + sql, args, err = d.DoFilter(ctx, link, sql, args) + if err != nil { + return nil, err + } + + // Prepare output variable for RETURNING clause + var lastInsertId int64 + // Append the output parameter for the RETURNING clause + args = append(args, &lastInsertId) + + // Link execution. + _, err = d.DoCommit(ctx, gdb.DoCommitInput{ + Link: link, + Sql: sql, + Args: args, + Stmt: nil, + Type: gdb.SqlTypeExecContext, + IsTransaction: link.IsTransaction(), + }) + + if err != nil { + return &Result{ + lastInsertId: 0, + rowsAffected: 0, + lastInsertIdError: err, + }, err + } + + // Get rows affected from the result + // For single insert with RETURNING clause, affected is always 1 + var affected int64 = 1 + + // Check if the primary key field type supports LastInsertId + if !strings.Contains(strings.ToLower(pkField.Type), "int") { + return &Result{ + lastInsertId: 0, + rowsAffected: affected, + lastInsertIdError: gerror.NewCodef( + gcode.CodeNotSupported, + "LastInsertId is not supported by primary key type: %s", + pkField.Type, + ), + }, nil + } + + return &Result{ + lastInsertId: lastInsertId, + rowsAffected: affected, + }, nil +} diff --git a/contrib/drivers/oracle/oracle_do_insert.go b/contrib/drivers/oracle/oracle_do_insert.go index d59bdf95b26..82f8373d5e4 100644 --- a/contrib/drivers/oracle/oracle_do_insert.go +++ b/contrib/drivers/oracle/oracle_do_insert.go @@ -16,11 +16,17 @@ import ( "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" ) +const ( + internalPrimaryKeyInCtx gctx.StrKey = "primary_key_field" +) + // DoInsert inserts or updates data for given table. +// The list parameter must contain at least one record, which was previously validated. func (d *Driver) DoInsert( ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { @@ -29,10 +35,39 @@ func (d *Driver) DoInsert( return d.doSave(ctx, link, table, list, option) case gdb.InsertOptionReplace: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Replace operation is not supported by oracle driver`, - ) + // Oracle does not support REPLACE INTO syntax, use SAVE instead. + return d.doSave(ctx, link, table, list, option) + + case gdb.InsertOptionIgnore: + // Oracle does not support INSERT IGNORE syntax, use MERGE instead. + return d.doInsertIgnore(ctx, link, table, list, option) + + case gdb.InsertOptionDefault: + // For default insert, set primary key field in context to support LastInsertId. + // Only set it when the primary key is not provided in the data, for performance reason. + tableFields, err := d.GetCore().GetDB().TableFields(ctx, table) + if err == nil && len(list) > 0 { + for _, field := range tableFields { + if strings.EqualFold(field.Key, "pri") { + // Check if primary key is provided in the data. + pkProvided := false + for key := range list[0] { + if strings.EqualFold(key, field.Name) { + pkProvided = true + break + } + } + // Only use RETURNING when primary key is not provided, for performance reason. + if !pkProvided { + pkField := *field + ctx = context.WithValue(ctx, internalPrimaryKeyInCtx, pkField) + } + break + } + } + } + + default: } var ( keys []string @@ -55,8 +90,8 @@ func (d *Driver) DoInsert( valueHolderStr = strings.Join(valueHolder, ",") ) // Format "INSERT...INTO..." statement. - intoStrArray := make([]string, 0) - for i := 0; i < len(list); i++ { + // Note: Use standard INSERT INTO syntax instead of INSERT ALL to ensure triggers fire + for i := 0; i < listLength; i++ { for _, k := range keys { if s, ok := list[i][k].(gdb.Raw); ok { params = append(params, gconv.String(s)) @@ -65,30 +100,22 @@ func (d *Driver) DoInsert( } } values = append(values, valueHolderStr) - intoStrArray = append( - intoStrArray, - fmt.Sprintf( - "INTO %s(%s) VALUES(%s)", - table, keyStr, valueHolderStr, - ), - ) - if len(intoStrArray) == option.BatchCount || (i == listLength-1 && len(valueHolder) > 0) { - r, err := d.DoExec(ctx, link, fmt.Sprintf( - "INSERT ALL %s SELECT * FROM DUAL", - strings.Join(intoStrArray, " "), - ), params...) - if err != nil { - return r, err - } - if n, err := r.RowsAffected(); err != nil { - return r, err - } else { - batchResult.Result = r - batchResult.Affected += n - } - params = params[:0] - intoStrArray = intoStrArray[:0] + + // Execute individual INSERT for each record to trigger row-level triggers + r, err := d.DoExec(ctx, link, fmt.Sprintf( + "INSERT INTO %s(%s) VALUES(%s)", + table, keyStr, valueHolderStr, + ), params...) + if err != nil { + return r, err + } + if n, err := r.RowsAffected(); err != nil { + return r, err + } else { + batchResult.Result = r + batchResult.Affected += n } + params = params[:0] } return batchResult, nil } @@ -97,24 +124,63 @@ func (d *Driver) DoInsert( func (d *Driver) doSave(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { - if len(option.OnConflict) == 0 { - return nil, gerror.NewCode( - gcode.CodeMissingParameter, `Please specify conflict columns`, - ) - } + return d.doMergeInsert(ctx, link, table, list, option, true) +} - if len(list) == 0 { - return nil, gerror.NewCode( - gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`, - ) +// doInsertIgnore implements INSERT IGNORE operation using MERGE statement for Oracle database. +// It only inserts records when there's no conflict on primary/unique keys. +func (d *Driver) doInsertIgnore(ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { + return d.doMergeInsert(ctx, link, table, list, option, false) +} + +// doMergeInsert implements MERGE-based insert operations for Oracle database. +// When withUpdate is true, it performs upsert (insert or update). +// When withUpdate is false, it performs insert ignore (insert only when no conflict). +func (d *Driver) doMergeInsert( + ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, withUpdate bool, +) (result sql.Result, err error) { + // If OnConflict is not specified, automatically get the primary key of the table + conflictKeys := option.OnConflict + if len(conflictKeys) == 0 { + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) + if err != nil { + return nil, gerror.WrapCode( + gcode.CodeInternalError, + err, + `failed to get primary keys for table`, + ) + } + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + for dataKey := range list[0] { + if strings.EqualFold(dataKey, primaryKey) { + foundPrimaryKey = true + break + } + } + if foundPrimaryKey { + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( + gcode.CodeMissingParameter, + `Replace/Save/InsertIgnore operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, + ) + } + // TODO consider composite primary keys. + conflictKeys = primaryKeys } var ( - one = list[0] - oneLen = len(one) - charL, charR = d.GetChars() - - conflictKeys = option.OnConflict + one = list[0] + oneLen = len(one) + charL, charR = d.GetChars() conflictKeySet = gset.New(false) // queryHolders: Handle data with Holder that need to be upsert @@ -142,9 +208,9 @@ func (d *Driver) doSave(ctx context.Context, insertKeys[index] = keyWithChar insertValues[index] = fmt.Sprintf("T2.%s", keyWithChar) - // filter conflict keys in updateValues. - // And the key is not a soft created field. - if !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { + // Build updateValues only when withUpdate is true + // Filter conflict keys and soft created fields from updateValues + if withUpdate && !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { updateValues = append( updateValues, fmt.Sprintf(`T1.%s = T2.%s`, keyWithChar, keyWithChar), @@ -153,8 +219,10 @@ func (d *Driver) doSave(ctx context.Context, index++ } - batchResult := new(gdb.SqlResult) - sqlStr := parseSqlForUpsert(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + var ( + batchResult = new(gdb.SqlResult) + sqlStr = parseSqlForMerge(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + ) r, err := d.DoExec(ctx, link, sqlStr, queryValues...) if err != nil { return r, err @@ -168,40 +236,43 @@ func (d *Driver) doSave(ctx context.Context, return batchResult, nil } -// parseSqlForUpsert -// MERGE INTO {{table}} T1 -// USING ( SELECT {{queryHolders}} FROM DUAL T2 -// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...) -// WHEN NOT MATCHED THEN -// INSERT {{insertKeys}} VALUES {{insertValues}} -// WHEN MATCHED THEN -// UPDATE SET {{updateValues}} -func parseSqlForUpsert(table string, +// parseSqlForMerge generates MERGE statement for Oracle database. +// When updateValues is empty, it only inserts (INSERT IGNORE behavior). +// When updateValues is provided, it performs upsert (INSERT or UPDATE). +// Examples: +// - INSERT IGNORE: MERGE INTO table T1 USING (...) T2 ON (...) WHEN NOT MATCHED THEN INSERT(...) VALUES (...) +// - UPSERT: MERGE INTO table T1 USING (...) T2 ON (...) WHEN NOT MATCHED THEN INSERT(...) VALUES (...) WHEN MATCHED THEN UPDATE SET ... +func parseSqlForMerge(table string, queryHolders, insertKeys, insertValues, updateValues, duplicateKey []string, ) (sqlStr string) { var ( queryHolderStr = strings.Join(queryHolders, ",") insertKeyStr = strings.Join(insertKeys, ",") insertValueStr = strings.Join(insertValues, ",") - updateValueStr = strings.Join(updateValues, ",") duplicateKeyStr string - pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s`) ) + // Build ON condition for index, keys := range duplicateKey { if index != 0 { duplicateKeyStr += " AND " } - duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys) - duplicateKeyStr += duplicateTmp + duplicateKeyStr += fmt.Sprintf("T1.%s = T2.%s", keys, keys) } - return fmt.Sprintf(pattern, - table, - queryHolderStr, - duplicateKeyStr, - insertKeyStr, - insertValueStr, - updateValueStr, + // Build SQL based on whether UPDATE is needed + pattern := gstr.Trim( + `MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN ` + + `NOT MATCHED THEN INSERT(%s) VALUES (%s)`, ) + if len(updateValues) > 0 { + // Upsert: INSERT or UPDATE + pattern += gstr.Trim(` WHEN MATCHED THEN UPDATE SET %s`) + return fmt.Sprintf( + pattern, table, queryHolderStr, duplicateKeyStr, insertKeyStr, insertValueStr, + strings.Join(updateValues, ","), + ) + } + // Insert Ignore: INSERT only + return fmt.Sprintf(pattern, table, queryHolderStr, duplicateKeyStr, insertKeyStr, insertValueStr) } diff --git a/contrib/drivers/oracle/oracle_result.go b/contrib/drivers/oracle/oracle_result.go new file mode 100644 index 00000000000..a4795530bf2 --- /dev/null +++ b/contrib/drivers/oracle/oracle_result.go @@ -0,0 +1,24 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package oracle + +// Result implements sql.Result interface for Oracle database. +type Result struct { + lastInsertId int64 + rowsAffected int64 + lastInsertIdError error +} + +// LastInsertId returns the last insert id. +func (r *Result) LastInsertId() (int64, error) { + return r.lastInsertId, r.lastInsertIdError +} + +// RowsAffected returns the rows affected. +func (r *Result) RowsAffected() (int64, error) { + return r.rowsAffected, nil +} diff --git a/contrib/drivers/oracle/oracle_table_fields.go b/contrib/drivers/oracle/oracle_table_fields.go index aa20858dcd0..8efb9c1102b 100644 --- a/contrib/drivers/oracle/oracle_table_fields.go +++ b/contrib/drivers/oracle/oracle_table_fields.go @@ -18,13 +18,23 @@ import ( var ( tableFieldsSqlTmp = ` SELECT - COLUMN_NAME AS FIELD, + c.COLUMN_NAME AS FIELD, CASE - WHEN (DATA_TYPE='NUMBER' AND NVL(DATA_SCALE,0)=0) THEN 'INT'||'('||DATA_PRECISION||','||DATA_SCALE||')' - WHEN (DATA_TYPE='NUMBER' AND NVL(DATA_SCALE,0)>0) THEN 'FLOAT'||'('||DATA_PRECISION||','||DATA_SCALE||')' - WHEN DATA_TYPE='FLOAT' THEN DATA_TYPE||'('||DATA_PRECISION||','||DATA_SCALE||')' - ELSE DATA_TYPE||'('||DATA_LENGTH||')' END AS TYPE,NULLABLE -FROM USER_TAB_COLUMNS WHERE TABLE_NAME = '%s' ORDER BY COLUMN_ID + WHEN (c.DATA_TYPE='NUMBER' AND NVL(c.DATA_SCALE,0)=0) THEN 'INT'||'('||c.DATA_PRECISION||','||c.DATA_SCALE||')' + WHEN (c.DATA_TYPE='NUMBER' AND NVL(c.DATA_SCALE,0)>0) THEN 'FLOAT'||'('||c.DATA_PRECISION||','||c.DATA_SCALE||')' + WHEN c.DATA_TYPE='FLOAT' THEN c.DATA_TYPE||'('||c.DATA_PRECISION||','||c.DATA_SCALE||')' + ELSE c.DATA_TYPE||'('||c.DATA_LENGTH||')' END AS TYPE, + c.NULLABLE, + CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 'PRI' ELSE '' END AS KEY +FROM USER_TAB_COLUMNS c +LEFT JOIN ( + SELECT cols.COLUMN_NAME + FROM USER_CONSTRAINTS cons + JOIN USER_CONS_COLUMNS cols ON cons.CONSTRAINT_NAME = cols.CONSTRAINT_NAME + WHERE cons.TABLE_NAME = '%s' AND cons.CONSTRAINT_TYPE = 'P' +) pk ON c.COLUMN_NAME = pk.COLUMN_NAME +WHERE c.TABLE_NAME = '%s' +ORDER BY c.COLUMN_ID ` ) @@ -44,7 +54,8 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string result gdb.Result link gdb.Link usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) - structureSql = fmt.Sprintf(tableFieldsSqlTmp, strings.ToUpper(table)) + upperTable = strings.ToUpper(table) + structureSql = fmt.Sprintf(tableFieldsSqlTmp, upperTable, upperTable) ) if link, err = d.SlaveLink(usedSchema); err != nil { return nil, err @@ -53,6 +64,7 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string if err != nil { return nil, err } + fields = make(map[string]*gdb.TableField) for i, m := range result { isNull := false @@ -65,6 +77,7 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string Name: m["FIELD"].String(), Type: m["TYPE"].String(), Null: isNull, + Key: m["KEY"].String(), } } return fields, nil diff --git a/contrib/drivers/oracle/oracle_z_unit_basic_test.go b/contrib/drivers/oracle/oracle_z_unit_basic_test.go index 7a0af262455..24836455acf 100644 --- a/contrib/drivers/oracle/oracle_z_unit_basic_test.go +++ b/contrib/drivers/oracle/oracle_z_unit_basic_test.go @@ -139,10 +139,10 @@ func Test_Do_Insert(t *testing.T) { "CREATE_TIME": gtime.Now().String(), } _, err := db.Save(ctx, "t_user", data, 10) - gtest.AssertNE(err, nil) + gtest.AssertNil(err) _, err = db.Replace(ctx, "t_user", data, 10) - gtest.AssertNE(err, nil) + gtest.AssertNil(err) }) } @@ -185,6 +185,7 @@ func Test_DB_Insert(t *testing.T) { table := createTable() defer dropTable(table) + // db.SetDebug(true) gtest.C(t, func(t *gtest.T) { _, err := db.Insert(ctx, table, g.Map{ "ID": 1, @@ -233,7 +234,7 @@ func Test_DB_Insert(t *testing.T) { one, err := db.Model(table).Where("ID", 3).One() t.AssertNil(err) - fmt.Println(one) + // fmt.Println(one) t.Assert(one["ID"].Int(), 3) t.Assert(one["PASSPORT"].String(), "user_3") t.Assert(one["PASSWORD"].String(), "25d55ad283aa400af464c76d713c07ad") diff --git a/contrib/drivers/oracle/oracle_z_unit_init_test.go b/contrib/drivers/oracle/oracle_z_unit_init_test.go index 3e549d4b265..74bf9f26e8f 100644 --- a/contrib/drivers/oracle/oracle_z_unit_init_test.go +++ b/contrib/drivers/oracle/oracle_z_unit_init_test.go @@ -113,16 +113,48 @@ func createTable(table ...string) (name string) { dropTable(name) - if _, err := db.Exec(ctx, fmt.Sprintf(` - CREATE TABLE %s ( - ID NUMBER(10) NOT NULL, - PASSPORT VARCHAR(45) NOT NULL, - PASSWORD CHAR(32) NOT NULL, - NICKNAME VARCHAR(45) NOT NULL, - CREATE_TIME varchar(45), - SALARY NUMBER(18,2), - PRIMARY KEY (ID)) - `, name)); err != nil { + // Step 1: Create table + createTableSQL := fmt.Sprintf(` + CREATE TABLE %s ( + ID NUMBER(10) NOT NULL, + PASSPORT VARCHAR(45) NOT NULL, + PASSWORD CHAR(32) NOT NULL, + NICKNAME VARCHAR(45) NOT NULL, + CREATE_TIME VARCHAR(45), + SALARY NUMBER(18,2), + PRIMARY KEY (ID) + )`, name) + + if _, err := db.Exec(ctx, createTableSQL); err != nil { + gtest.Fatal(err) + } + + // Step 2: Create sequence + createSeqSQL := fmt.Sprintf(` + CREATE SEQUENCE %s_ID_SEQ + START WITH 1 + INCREMENT BY 1 + MINVALUE 1 + MAXVALUE 9999999999 + NOCYCLE + NOCACHE`, name) + + if _, err := db.Exec(ctx, createSeqSQL); err != nil { + gtest.Fatal(err) + } + + // Step 3: Create trigger - only set ID from sequence when it's NULL + createTriggerSQL := fmt.Sprintf(` +CREATE OR REPLACE TRIGGER %s_ID_TRG +BEFORE INSERT ON %s +FOR EACH ROW +BEGIN + IF :NEW.ID IS NULL THEN + :NEW.ID := %s_ID_SEQ.NEXTVAL; + END IF; +END;`, name, name, name) + + if _, err := db.Exec(ctx, createTriggerSQL); err != nil { gtest.Fatal(err) } @@ -160,7 +192,15 @@ func dropTable(table string) { if count == 0 { return } + + // Drop table if _, err = db.Exec(ctx, fmt.Sprintf("DROP TABLE %s", table)); err != nil { gtest.Fatal(err) } + + // Drop sequence if exists + seqCount, err := db.GetCount(ctx, "SELECT COUNT(*) FROM USER_SEQUENCES WHERE SEQUENCE_NAME = ?", strings.ToUpper(table+"_ID_SEQ")) + if err == nil && seqCount > 0 { + db.Exec(ctx, fmt.Sprintf("DROP SEQUENCE %s_ID_SEQ", table)) + } } diff --git a/contrib/drivers/oracle/oracle_z_unit_model_test.go b/contrib/drivers/oracle/oracle_z_unit_model_test.go index 26031615e03..185446b2442 100644 --- a/contrib/drivers/oracle/oracle_z_unit_model_test.go +++ b/contrib/drivers/oracle/oracle_z_unit_model_test.go @@ -233,6 +233,67 @@ func Test_Model_Insert(t *testing.T) { }) } +func Test_Model_InsertIgnore(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // db.SetDebug(true) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": fmt.Sprintf(`t%d`, 777), + "password": fmt.Sprintf(`p%d`, 777), + "nickname": fmt.Sprintf(`T%d`, 777), + "create_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNil(err) + + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "user_1") + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "passport": fmt.Sprintf(`t%d`, 777), + "password": fmt.Sprintf(`p%d`, 777), + "nickname": fmt.Sprintf(`T%d`, 777), + "create_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNE(err, nil) + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) +} + +func Test_Model_InsertAndGetId(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + // "id": 1, + "passport": fmt.Sprintf(`t%d`, 1), + "password": fmt.Sprintf(`p%d`, 1), + "nickname": fmt.Sprintf(`T%d`, 1), + "create_time": gtime.Now(), + } + lastId, err := db.Model(table).Data(data).InsertAndGetId() + t.AssertNil(err) + t.AssertGT(lastId, 0) + }) + +} + // https://github.com/gogf/gf/issues/3286 func Test_Model_Insert_Raw(t *testing.T) { table := createTable() @@ -1179,14 +1240,73 @@ func Test_Model_Replace(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + // Insert initial record + result, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t1", + "password": "pass1", + "nickname": "T1", + "create_time": "2018-10-24 10:00:00", + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + // Replace with new data (should update existing record using MERGE) + result, err = db.Model(table).Data(g.Map{ "id": 1, "passport": "t11", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T11", "create_time": "2018-10-24 10:00:00", + }).OnConflict("id").Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify the data was replaced + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t11") + t.Assert(one["PASSWORD"].String(), "25d55ad283aa400af464c76d713c07ad") + t.Assert(one["NICKNAME"].String(), "T11") + + // Replace with new ID (insert new record) + result, err = db.Model(table).Data(g.Map{ + "id": 2, + "passport": "t222", + "password": "pass2", + "nickname": "T222", + "create_time": "2018-10-24 11:00:00", + }).OnConflict("id").Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify new record was inserted + one, err = db.Model(table).Where("id", 2).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t222") + t.Assert(one["NICKNAME"].String(), "T222") + + // Replace without OnConflict (primary key auto-detection is implemented) + _, err = db.Model(table).Data(g.Map{ + "id": 3, + "passport": "t3", + "password": "pass3", + "nickname": "T3", + "create_time": "2018-10-24 12:00:00", + }).Replace() + t.AssertNil(err) + + _, err = db.Model(table).Data(g.Map{ + // "id": 3, + "passport": "t3", + "password": "pass3", + "nickname": "T3", + "create_time": "2018-10-24 12:00:00", }).Replace() - t.Assert(err, "Replace operation is not supported by oracle driver") + t.AssertNE(err, nil) }) } diff --git a/contrib/drivers/pgsql/pgsql.go b/contrib/drivers/pgsql/pgsql.go index b7a18a810f2..3c3272240cb 100644 --- a/contrib/drivers/pgsql/pgsql.go +++ b/contrib/drivers/pgsql/pgsql.go @@ -5,10 +5,6 @@ // You can obtain one at https://github.com/gogf/gf. // Package pgsql implements gdb.Driver, which supports operations for database PostgreSQL. -// -// Note: -// 1. It does not support Replace features. -// 2. It does not support Insert Ignore features. package pgsql import ( diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index ec24edde302..6bd0ba14284 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -9,6 +9,7 @@ package pgsql import ( "context" "database/sql" + "strings" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" @@ -16,25 +17,68 @@ import ( ) // DoInsert inserts or updates data for given table. -func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { +// The list parameter must contain at least one record, which was previously validated. +func (d *Driver) DoInsert( + ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { switch option.InsertOption { - case gdb.InsertOptionReplace: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Replace operation is not supported by pgsql driver`, - ) + case + gdb.InsertOptionSave, + gdb.InsertOptionReplace: + // PostgreSQL does not support REPLACE INTO syntax, use Save (ON CONFLICT ... DO UPDATE) instead. + // Automatically detect primary keys if OnConflict is not specified. + if len(option.OnConflict) == 0 { + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) + if err != nil { + return nil, gerror.WrapCode( + gcode.CodeInternalError, + err, + `failed to get primary keys for Save/Replace operation`, + ) + } + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + for dataKey := range list[0] { + if strings.EqualFold(dataKey, primaryKey) { + foundPrimaryKey = true + break + } + } + if foundPrimaryKey { + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( + gcode.CodeMissingParameter, + `Replace/Save operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, + ) + } + // TODO consider composite primary keys. + option.OnConflict = primaryKeys + } + // Treat Replace as Save operation + option.InsertOption = gdb.InsertOptionSave - case gdb.InsertOptionDefault: + // pgsql support InsertIgnore natively, so no need to set primary key in context. + case gdb.InsertOptionIgnore, gdb.InsertOptionDefault: + // Get table fields to retrieve the primary key TableField object (not just the name) + // because DoExec needs the `TableField.Type` to determine if LastInsertId is supported. tableFields, err := d.GetCore().GetDB().TableFields(ctx, table) if err == nil { for _, field := range tableFields { - if field.Key == "pri" { + if strings.EqualFold(field.Key, "pri") { pkField := *field ctx = context.WithValue(ctx, internalPrimaryKeyInCtx, pkField) break } } } + + default: } return d.Core.DoInsert(ctx, link, table, list, option) } diff --git a/contrib/drivers/pgsql/pgsql_table_fields.go b/contrib/drivers/pgsql/pgsql_table_fields.go index 07f3a4e43ae..8573648f6f3 100644 --- a/contrib/drivers/pgsql/pgsql_table_fields.go +++ b/contrib/drivers/pgsql/pgsql_table_fields.go @@ -80,10 +80,22 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string } continue } + + var ( + fieldType string + dataType = m["type"].String() + dataLength = m["length"].Int() + ) + if dataLength > 0 { + fieldType = fmt.Sprintf("%s(%d)", dataType, dataLength) + } else { + fieldType = dataType + } + fields[name] = &gdb.TableField{ Index: index, Name: name, - Type: m["type"].String(), + Type: fieldType, Null: !m["null"].Bool(), Key: m["key"].String(), Default: m["default_value"].Val(), diff --git a/contrib/drivers/pgsql/pgsql_z_unit_db_test.go b/contrib/drivers/pgsql/pgsql_z_unit_db_test.go index 67b9d7978ad..a83cb38bf23 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_db_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_db_test.go @@ -90,7 +90,7 @@ func Test_DB_Save(t *testing.T) { "create_time": gtime.Now().String(), } _, err := db.Save(ctx, "t_user", data, 10) - gtest.AssertNE(err, nil) + gtest.AssertNil(err) }) } @@ -99,6 +99,7 @@ func Test_DB_Replace(t *testing.T) { createTable("t_user") defer dropTable("t_user") + // Insert initial record i := 10 data := g.Map{ "id": i, @@ -107,8 +108,26 @@ func Test_DB_Replace(t *testing.T) { "nickname": fmt.Sprintf(`T%d`, i), "create_time": gtime.Now().String(), } - _, err := db.Replace(ctx, "t_user", data, 10) - gtest.AssertNE(err, nil) + _, err := db.Insert(ctx, "t_user", data) + gtest.AssertNil(err) + + // Replace with new data + data2 := g.Map{ + "id": i, + "passport": fmt.Sprintf(`t%d_new`, i), + "password": fmt.Sprintf(`p%d_new`, i), + "nickname": fmt.Sprintf(`T%d_new`, i), + "create_time": gtime.Now().String(), + } + _, err = db.Replace(ctx, "t_user", data2) + gtest.AssertNil(err) + + // Verify the data was replaced + one, err := db.GetOne(ctx, fmt.Sprintf("SELECT * FROM t_user WHERE id=?"), i) + gtest.AssertNil(err) + gtest.Assert(one["passport"].String(), fmt.Sprintf(`t%d_new`, i)) + gtest.Assert(one["password"].String(), fmt.Sprintf(`p%d_new`, i)) + gtest.Assert(one["nickname"].String(), fmt.Sprintf(`T%d_new`, i)) }) } @@ -304,10 +323,10 @@ func Test_DB_TableFields(t *testing.T) { var expect = map[string][]any{ // []string: Index Type Null Key Default Comment // id is bigserial so the default is a pgsql function - "id": {0, "int8", false, "pri", fmt.Sprintf("nextval('%s_id_seq'::regclass)", table), ""}, - "passport": {1, "varchar", false, "", nil, ""}, - "password": {2, "varchar", false, "", nil, ""}, - "nickname": {3, "varchar", false, "", nil, ""}, + "id": {0, "int8(64)", false, "pri", fmt.Sprintf("nextval('%s_id_seq'::regclass)", table), ""}, + "passport": {1, "varchar(45)", false, "", nil, ""}, + "password": {2, "varchar(32)", false, "", nil, ""}, + "nickname": {3, "varchar(45)", false, "", nil, ""}, "create_time": {4, "timestamp", false, "", nil, ""}, } @@ -410,13 +429,13 @@ func Test_DB_TableFields_DuplicateConstraints(t *testing.T) { t.AssertNE(fields["id"], nil) t.Assert(fields["id"].Key, "pri") t.Assert(fields["id"].Name, "id") - t.Assert(fields["id"].Type, "int8") + t.Assert(fields["id"].Type, "int8(64)") // Verify email field has unique constraint t.AssertNE(fields["email"], nil) t.Assert(fields["email"].Key, "uni") t.Assert(fields["email"].Name, "email") - t.Assert(fields["email"].Type, "varchar") + t.Assert(fields["email"].Type, "varchar(100)") // Verify username field has no constraint t.AssertNE(fields["username"], nil) diff --git a/contrib/drivers/pgsql/pgsql_z_unit_field_test.go b/contrib/drivers/pgsql/pgsql_z_unit_field_test.go index 7c3df4ab038..22a8f25a28b 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_field_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_field_test.go @@ -73,18 +73,18 @@ func Test_TableFields_Types(t *testing.T) { t.AssertNil(err) // Test integer type names - t.Assert(fields["col_int2"].Type, "int2") - t.Assert(fields["col_int4"].Type, "int4") - t.Assert(fields["col_int8"].Type, "int8") + t.Assert(fields["col_int2"].Type, "int2(16)") + t.Assert(fields["col_int4"].Type, "int4(32)") + t.Assert(fields["col_int8"].Type, "int8(64)") // Test float type names - t.Assert(fields["col_float4"].Type, "float4") - t.Assert(fields["col_float8"].Type, "float8") - t.Assert(fields["col_numeric"].Type, "numeric") + t.Assert(fields["col_float4"].Type, "float4(24)") + t.Assert(fields["col_float8"].Type, "float8(53)") + t.Assert(fields["col_numeric"].Type, "numeric(10)") // Test character type names - t.Assert(fields["col_char"].Type, "bpchar") - t.Assert(fields["col_varchar"].Type, "varchar") + t.Assert(fields["col_char"].Type, "bpchar(10)") + t.Assert(fields["col_varchar"].Type, "varchar(100)") t.Assert(fields["col_text"].Type, "text") // Test boolean type name diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index 56db22a85b1..4ca72989137 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -334,14 +334,53 @@ func Test_Model_Replace(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + // Insert initial record + result, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t1", + "password": "pass1", + "nickname": "T1", + "create_time": "2018-10-24 10:00:00", + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + // Replace with new data + result, err = db.Model(table).Data(g.Map{ "id": 1, "passport": "t11", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T11", "create_time": "2018-10-24 10:00:00", }).Replace() - t.Assert(err, "Replace operation is not supported by pgsql driver") + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify the data was replaced + one, err := db.Model(table).Where("id", 1).One() + t.AssertNil(err) + t.Assert(one["passport"].String(), "t11") + t.Assert(one["password"].String(), "25d55ad283aa400af464c76d713c07ad") + t.Assert(one["nickname"].String(), "T11") + + // Replace with new ID (insert new record) + result, err = db.Model(table).Data(g.Map{ + "id": 2, + "passport": "t22", + "password": "pass22", + "nickname": "T22", + "create_time": "2018-10-24 11:00:00", + }).Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify new record was inserted + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 2) }) } @@ -757,3 +796,69 @@ func Test_ConvertSliceFloat64(t *testing.T) { }) } } + +func Test_Model_InsertIgnore(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + user := db.Model(table) + result, err := user.Data(g.Map{ + "id": 1, + "uid": 1, + "passport": "t1", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "name_1", + "create_time": gtime.Now().String(), + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + result, err = db.Model(table).Data(g.Map{ + "id": 1, + "uid": 1, + "passport": "t1", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "name_1", + "create_time": gtime.Now().String(), + }).Insert() + t.AssertNE(err, nil) + + result, err = db.Model(table).Data(g.Map{ + "id": 1, + "uid": 1, + "passport": "t2", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "name_2", + "create_time": gtime.Now().String(), + }).InsertIgnore() + t.AssertNil(err) + + n, _ = result.RowsAffected() + t.Assert(n, 0) + + value, err := db.Model(table).Fields("passport").WherePri(1).Value() + t.AssertNil(err) + t.Assert(value.String(), "t1") + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) + + // pgsql support ignore without primary key + result, err = db.Model(table).Data(g.Map{ + // "id": 1, + "uid": 1, + "passport": "t2", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "name_2", + "create_time": gtime.Now().String(), + }).InsertIgnore() + t.AssertNil(err) + + count, err = db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) + }) +} diff --git a/contrib/drivers/pgsql/pgsql_z_unit_upsert_test.go b/contrib/drivers/pgsql/pgsql_z_unit_upsert_test.go index a93017e97bd..edefab63225 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_upsert_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_upsert_test.go @@ -219,10 +219,10 @@ func Test_FormatUpsert_NoOnConflict(t *testing.T) { }).Insert() t.AssertNil(err) - // Try Save without OnConflict - should fail for pgsql - // PostgreSQL requires OnConflict() for Save() operations, unlike MySQL + // Try Save without OnConflict and without primary key in data - should fail + // because driver cannot auto-detect conflict columns when primary key is missing _, err = db.Model(table).Data(g.Map{ - "id": 1, + // "id": 1, "passport": "no_conflict_user", "password": "newpwd", "nickname": "newnick", diff --git a/contrib/drivers/sqlitecgo/sqlite_format_upsert.go b/contrib/drivers/sqlitecgo/sqlitecgo_format_upsert.go similarity index 100% rename from contrib/drivers/sqlitecgo/sqlite_format_upsert.go rename to contrib/drivers/sqlitecgo/sqlitecgo_format_upsert.go diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go index 4872f13a14e..b97d7431e85 100644 --- a/database/gdb/gdb_core_utility.go +++ b/database/gdb/gdb_core_utility.go @@ -10,6 +10,7 @@ package gdb import ( "context" "fmt" + "strings" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" @@ -251,3 +252,22 @@ func (c *Core) guessPrimaryTableName(tableStr string) string { } return guessedTableName } + +// GetPrimaryKeys retrieves and returns the primary key field names of the specified table. +// This method extracts primary key information from TableFields. +// The parameter `schema` is optional, if not specified it uses the default schema. +func (c *Core) GetPrimaryKeys(ctx context.Context, table string, schema ...string) ([]string, error) { + tableFields, err := c.db.TableFields(ctx, table, schema...) + if err != nil { + return nil, err + } + + var primaryKeys []string + for _, field := range tableFields { + if strings.EqualFold(field.Key, "pri") { + primaryKeys = append(primaryKeys, field.Name) + } + } + + return primaryKeys, nil +} diff --git a/database/gdb/gdb_driver_wrapper_db.go b/database/gdb/gdb_driver_wrapper_db.go index 7dbc1d0cee5..81c5b729c40 100644 --- a/database/gdb/gdb_driver_wrapper_db.go +++ b/database/gdb/gdb_driver_wrapper_db.go @@ -109,7 +109,17 @@ func (d *DriverWrapperDB) TableFields( // InsertOptionReplace: if there's unique/primary key in the data, it deletes it from table and inserts a new one; // InsertOptionSave: if there's unique/primary key in the data, it updates it or else inserts a new one; // InsertOptionIgnore: if there's unique/primary key in the data, it ignores the inserting; -func (d *DriverWrapperDB) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) { +func (d *DriverWrapperDB) DoInsert( + ctx context.Context, link Link, table string, list List, option DoInsertOption, +) (result sql.Result, err error) { + if len(list) == 0 { + return nil, gerror.NewCodef( + gcode.CodeInvalidRequest, + `data list is empty for %s operation`, + GetInsertOperationByOption(option.InsertOption), + ) + } + // Convert data type before commit it to underlying db driver. for i, item := range list { list[i], err = d.GetCore().ConvertDataForRecord(ctx, item, table)