Skip to content

Commit b23839b

Browse files
committed
Enhance DECIMAL/Numeric support: update README, add precision/scale validation, and modify example to handle price as string
1 parent 5cfa141 commit b23839b

3 files changed

Lines changed: 103 additions & 14 deletions

File tree

README.md

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,36 @@ func main() {
9696
| `bool` | BIT |
9797
| `int8`, `int16`, `int32`, `int64` | TINYINT, SMALLINT, INTEGER, BIGINT |
9898
| `float32`, `float64` | REAL, DOUBLE |
99-
| `string` | CHAR, VARCHAR, TEXT |
99+
| `string` | CHAR, VARCHAR, TEXT, DECIMAL, NUMERIC |
100100
| `[]byte` | BINARY, VARBINARY, BLOB |
101101
| `time.Time` | DATE, TIME, TIMESTAMP |
102102

103+
## Decimal Precision
104+
105+
DECIMAL and NUMERIC columns are returned as `string` to preserve full precision (avoiding float64 rounding errors). Use the `DecimalSize()` method on column types to get precision and scale metadata:
106+
107+
```go
108+
rows, _ := db.Query("SELECT price FROM products")
109+
defer rows.Close()
110+
111+
cols, _ := rows.ColumnTypes()
112+
for _, col := range cols {
113+
if prec, scale, ok := col.DecimalSize(); ok {
114+
fmt.Printf("Column %s: DECIMAL(%d,%d)\n", col.Name(), prec, scale)
115+
}
116+
}
117+
```
118+
119+
For arbitrary-precision arithmetic, use a decimal library like [shopspring/decimal](https://github.com/shopspring/decimal):
120+
121+
```go
122+
import "github.com/shopspring/decimal"
123+
124+
var priceStr string
125+
rows.Scan(&priceStr)
126+
price, _ := decimal.NewFromString(priceStr)
127+
```
128+
103129
## Transactions
104130

105131
```go
@@ -155,10 +181,11 @@ go build ./examples/basic/
155181
### What It Tests
156182

157183
- Connection and ping
158-
- Table creation with various data types (INTEGER, VARCHAR, FLOAT, BOOLEAN/BIT, TIMESTAMP, BINARY)
184+
- Table creation with various data types (INTEGER, VARCHAR, FLOAT, BOOLEAN/BIT, TIMESTAMP, BINARY, DECIMAL)
159185
- Prepared statement parameter binding
160186
- Data insertion and retrieval
161187
- Value equality validation (verifies inserted values match retrieved values)
188+
- Decimal precision/scale metadata (validates `ColumnTypePrecisionScale` returns correct values)
162189
- Transaction rollback (inserts row, rolls back, verifies not persisted)
163190
- Transaction commit (inserts row, commits, verifies persisted)
164191
- Table cleanup

examples/basic/main.go

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"log"
1010
"os"
11+
"strconv"
1112
"strings"
1213
"time"
1314

@@ -94,6 +95,7 @@ func getDDLTemplates(dbType DBType, tableName string) DDLTemplates {
9495
active BIT,
9596
created_at DATETIME2,
9697
data VARBINARY(100),
98+
price DECIMAL(10,2),
9799
PRIMARY KEY (id)
98100
)`, tableName),
99101
DropTable: fmt.Sprintf("DROP TABLE %s", tableName),
@@ -109,6 +111,7 @@ func getDDLTemplates(dbType DBType, tableName string) DDLTemplates {
109111
active BOOLEAN,
110112
created_at TIMESTAMP,
111113
data BYTEA,
114+
price DECIMAL(10,2),
112115
PRIMARY KEY (id)
113116
)`, tableName),
114117
DropTable: fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName),
@@ -124,6 +127,7 @@ func getDDLTemplates(dbType DBType, tableName string) DDLTemplates {
124127
active TINYINT(1),
125128
created_at DATETIME(3),
126129
data VARBINARY(100),
130+
price DECIMAL(10,2),
127131
PRIMARY KEY (id)
128132
)`, tableName),
129133
DropTable: fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName),
@@ -139,6 +143,7 @@ func getDDLTemplates(dbType DBType, tableName string) DDLTemplates {
139143
active INTEGER,
140144
created_at TEXT,
141145
data BLOB,
146+
price DECIMAL(10,2),
142147
PRIMARY KEY (id)
143148
)`, tableName),
144149
DropTable: fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName),
@@ -154,6 +159,7 @@ func getDDLTemplates(dbType DBType, tableName string) DDLTemplates {
154159
active NUMBER(1),
155160
created_at TIMESTAMP,
156161
data RAW(100),
162+
price NUMBER(10,2),
157163
PRIMARY KEY (id)
158164
)`, tableName),
159165
DropTable: fmt.Sprintf("DROP TABLE %s", tableName),
@@ -170,6 +176,7 @@ func getDDLTemplates(dbType DBType, tableName string) DDLTemplates {
170176
active SMALLINT,
171177
created_at TIMESTAMP,
172178
data VARBINARY(100),
179+
price DECIMAL(10,2),
173180
PRIMARY KEY (id)
174181
)`, tableName),
175182
DropTable: fmt.Sprintf("DROP TABLE %s", tableName),
@@ -251,21 +258,22 @@ func runTest(db *sql.DB, _ DBType, tableName string, ddl DDLTemplates) error {
251258
active bool
252259
createdAt time.Time
253260
data []byte
261+
price string // DECIMAL as string to preserve precision
254262
}{
255-
{1, "Alice", 123.45, true, time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), []byte{0x01, 0x02, 0x03}},
256-
{2, "Bob", 678.90, false, time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC), []byte{0x04, 0x05, 0x06}},
257-
{3, "Charlie", 0.0, true, time.Date(2024, 3, 25, 9, 0, 0, 0, time.UTC), nil},
263+
{1, "Alice", 123.45, true, time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), []byte{0x01, 0x02, 0x03}, "12345.67"},
264+
{2, "Bob", 678.90, false, time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC), []byte{0x04, 0x05, 0x06}, "9999.99"},
265+
{3, "Charlie", 0.0, true, time.Date(2024, 3, 25, 9, 0, 0, 0, time.UTC), nil, "0.01"},
258266
}
259267

260-
insertSQL := fmt.Sprintf("INSERT INTO %s (id, name, value, active, created_at, data) VALUES (?, ?, ?, ?, ?, ?)", tableName)
268+
insertSQL := fmt.Sprintf("INSERT INTO %s (id, name, value, active, created_at, data, price) VALUES (?, ?, ?, ?, ?, ?, ?)", tableName)
261269
stmt, err := db.Prepare(insertSQL)
262270
if err != nil {
263271
return fmt.Errorf("failed to prepare insert: %w", err)
264272
}
265273
defer stmt.Close()
266274

267275
for _, row := range testRows {
268-
result, err := stmt.Exec(row.id, row.name, row.value, row.active, row.createdAt, row.data)
276+
result, err := stmt.Exec(row.id, row.name, row.value, row.active, row.createdAt, row.data, row.price)
269277
if err != nil {
270278
return fmt.Errorf("failed to insert row %d: %w", row.id, err)
271279
}
@@ -275,7 +283,7 @@ func runTest(db *sql.DB, _ DBType, tableName string, ddl DDLTemplates) error {
275283

276284
// Select and validate rows
277285
log.Println("Selecting rows back...")
278-
selectSQL := fmt.Sprintf("SELECT id, name, value, active, created_at, data FROM %s ORDER BY id", tableName)
286+
selectSQL := fmt.Sprintf("SELECT id, name, value, active, created_at, data, price FROM %s ORDER BY id", tableName)
279287
rows, err := db.Query(selectSQL)
280288
if err != nil {
281289
return fmt.Errorf("failed to query rows: %w", err)
@@ -289,6 +297,28 @@ func runTest(db *sql.DB, _ DBType, tableName string, ddl DDLTemplates) error {
289297
}
290298
log.Printf("Columns: %v", columns)
291299

300+
// Validate DECIMAL precision/scale metadata
301+
colTypes, err := rows.ColumnTypes()
302+
if err != nil {
303+
return fmt.Errorf("failed to get column types: %w", err)
304+
}
305+
for _, col := range colTypes {
306+
if col.Name() == "price" {
307+
prec, scale, ok := col.DecimalSize()
308+
if !ok {
309+
return fmt.Errorf("price column: DecimalSize() returned ok=false, expected ok=true")
310+
}
311+
log.Printf(" Column 'price': DECIMAL(%d,%d)", prec, scale)
312+
if prec != 10 {
313+
return fmt.Errorf("price column: expected precision=10, got %d", prec)
314+
}
315+
if scale != 2 {
316+
return fmt.Errorf("price column: expected scale=2, got %d", scale)
317+
}
318+
log.Println(" ✓ DECIMAL precision/scale metadata validated")
319+
}
320+
}
321+
292322
// Fetch and validate rows
293323
rowCount := 0
294324
for rows.Next() {
@@ -298,13 +328,14 @@ func runTest(db *sql.DB, _ DBType, tableName string, ddl DDLTemplates) error {
298328
var active sql.NullBool
299329
var createdAt sql.NullTime
300330
var data []byte
331+
var price sql.NullString // DECIMAL scanned as string to preserve precision
301332

302-
if err := rows.Scan(&id, &name, &value, &active, &createdAt, &data); err != nil {
333+
if err := rows.Scan(&id, &name, &value, &active, &createdAt, &data, &price); err != nil {
303334
return fmt.Errorf("failed to scan row: %w", err)
304335
}
305336

306-
log.Printf(" Row %d: id=%d, name=%v, value=%v, active=%v, created_at=%v, data=%v",
307-
rowCount+1, id, name.String, value.Float64, active.Bool, createdAt.Time, data)
337+
log.Printf(" Row %d: id=%d, name=%v, value=%v, active=%v, created_at=%v, data=%v, price=%v",
338+
rowCount+1, id, name.String, value.Float64, active.Bool, createdAt.Time, data, price.String)
308339

309340
// Validate against expected values
310341
expected := testRows[rowCount]
@@ -343,6 +374,13 @@ func runTest(db *sql.DB, _ DBType, tableName string, ddl DDLTemplates) error {
343374
}
344375
}
345376
}
377+
// Compare DECIMAL price - parse as float for comparison since string format may vary
378+
// (e.g., ".01" vs "0.01" depending on database)
379+
expectedPrice, _ := strconv.ParseFloat(expected.price, 64)
380+
gotPrice, _ := strconv.ParseFloat(price.String, 64)
381+
if expectedPrice != gotPrice {
382+
return fmt.Errorf("row %d: expected price=%q (%v), got %q (%v)", rowCount+1, expected.price, expectedPrice, price.String, gotPrice)
383+
}
346384

347385
log.Printf(" Row %d: ✓ all values match expected", rowCount+1)
348386
rowCount++

rows.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type Rows struct {
1414
columns []string
1515
colTypes []SQLSMALLINT
1616
colSizes []SQLULEN
17+
decDigits []SQLSMALLINT // decimal digits (scale) for NUMERIC/DECIMAL types
1718
nullable []SQLSMALLINT
1819
closed bool
1920
closeStmt bool // Whether to close the statement when rows are closed
@@ -39,18 +40,20 @@ func newRows(stmt *Stmt, closeStmt bool) (*Rows, error) {
3940
columns := make([]string, numCols)
4041
colTypes := make([]SQLSMALLINT, numCols)
4142
colSizes := make([]SQLULEN, numCols)
43+
decDigits := make([]SQLSMALLINT, numCols)
4244
nullable := make([]SQLSMALLINT, numCols)
4345

4446
colName := make([]byte, 256)
4547
for i := SQLUSMALLINT(1); i <= SQLUSMALLINT(numCols); i++ {
46-
nameLen, dataType, colSize, _, nullableVal, ret := DescribeCol(stmt.stmt, i, colName)
48+
nameLen, dataType, colSize, decDigitsVal, nullableVal, ret := DescribeCol(stmt.stmt, i, colName)
4749
if !IsSuccess(ret) {
4850
return nil, NewError(SQL_HANDLE_STMT, SQLHANDLE(stmt.stmt))
4951
}
5052

5153
columns[i-1] = string(colName[:nameLen])
5254
colTypes[i-1] = dataType
5355
colSizes[i-1] = colSize
56+
decDigits[i-1] = decDigitsVal
5457
nullable[i-1] = nullableVal
5558
}
5659

@@ -59,6 +62,7 @@ func newRows(stmt *Stmt, closeStmt bool) (*Rows, error) {
5962
columns: columns,
6063
colTypes: colTypes,
6164
colSizes: colSizes,
65+
decDigits: decDigits,
6266
nullable: nullable,
6367
closeStmt: closeStmt,
6468
}, nil
@@ -423,8 +427,10 @@ func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
423427
return reflect.TypeOf(int64(0))
424428
case SQL_REAL:
425429
return reflect.TypeOf(float32(0))
426-
case SQL_FLOAT, SQL_DOUBLE, SQL_NUMERIC, SQL_DECIMAL:
430+
case SQL_FLOAT, SQL_DOUBLE:
427431
return reflect.TypeOf(float64(0))
432+
case SQL_NUMERIC, SQL_DECIMAL:
433+
return reflect.TypeOf("") // String preserves decimal precision
428434
case SQL_CHAR, SQL_VARCHAR, SQL_LONGVARCHAR, SQL_WCHAR, SQL_WVARCHAR, SQL_WLONGVARCHAR:
429435
return reflect.TypeOf("")
430436
case SQL_BINARY, SQL_VARBINARY, SQL_LONGVARBINARY:
@@ -523,6 +529,20 @@ func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
523529
}
524530
}
525531

532+
// ColumnTypePrecisionScale returns the precision and scale for NUMERIC/DECIMAL types
533+
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
534+
if index < 0 || index >= len(r.colTypes) {
535+
return 0, 0, false
536+
}
537+
switch r.colTypes[index] {
538+
case SQL_NUMERIC, SQL_DECIMAL:
539+
// colSize = precision (total digits), decDigits = scale (digits after decimal)
540+
return int64(r.colSizes[index]), int64(r.decDigits[index]), true
541+
default:
542+
return 0, 0, false
543+
}
544+
}
545+
526546
// HasNextResultSet checks if there are more result sets
527547
func (r *Rows) HasNextResultSet() bool {
528548
return MoreResults(r.stmt.stmt) == SQL_SUCCESS
@@ -548,24 +568,27 @@ func (r *Rows) NextResultSet() error {
548568
columns := make([]string, numCols)
549569
colTypes := make([]SQLSMALLINT, numCols)
550570
colSizes := make([]SQLULEN, numCols)
571+
decDigits := make([]SQLSMALLINT, numCols)
551572
nullable := make([]SQLSMALLINT, numCols)
552573

553574
colName := make([]byte, 256)
554575
for i := SQLUSMALLINT(1); i <= SQLUSMALLINT(numCols); i++ {
555-
nameLen, dataType, colSize, _, nullableVal, ret := DescribeCol(r.stmt.stmt, i, colName)
576+
nameLen, dataType, colSize, decDigitsVal, nullableVal, ret := DescribeCol(r.stmt.stmt, i, colName)
556577
if !IsSuccess(ret) {
557578
return NewError(SQL_HANDLE_STMT, SQLHANDLE(r.stmt.stmt))
558579
}
559580

560581
columns[i-1] = string(colName[:nameLen])
561582
colTypes[i-1] = dataType
562583
colSizes[i-1] = colSize
584+
decDigits[i-1] = decDigitsVal
563585
nullable[i-1] = nullableVal
564586
}
565587

566588
r.columns = columns
567589
r.colTypes = colTypes
568590
r.colSizes = colSizes
591+
r.decDigits = decDigits
569592
r.nullable = nullable
570593

571594
return nil
@@ -578,5 +601,6 @@ var (
578601
_ driver.RowsColumnTypeDatabaseTypeName = (*Rows)(nil)
579602
_ driver.RowsColumnTypeLength = (*Rows)(nil)
580603
_ driver.RowsColumnTypeNullable = (*Rows)(nil)
604+
_ driver.RowsColumnTypePrecisionScale = (*Rows)(nil)
581605
_ driver.RowsNextResultSet = (*Rows)(nil)
582606
)

0 commit comments

Comments
 (0)