Skip to content

Commit 24222a3

Browse files
authored
Merge pull request #985 from philipglazman/copy-return-command-tag
Return rows copied for COPY command
2 parents bd4ec23 + ad163b1 commit 24222a3

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

copy.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ type copyin struct {
4949
buffer []byte
5050
rowData chan []byte
5151
done chan bool
52+
driver.Result
5253

5354
closed bool
5455

@@ -151,6 +152,8 @@ func (ci *copyin) resploop() {
151152
switch t {
152153
case 'C':
153154
// complete
155+
res, _ := ci.cn.parseComplete(r.string())
156+
ci.setResult(res)
154157
case 'N':
155158
if n := ci.cn.noticeHandler; n != nil {
156159
n(parseError(&r))
@@ -201,6 +204,22 @@ func (ci *copyin) setError(err error) {
201204
ci.Unlock()
202205
}
203206

207+
func (ci *copyin) setResult(result driver.Result) {
208+
ci.Lock()
209+
ci.Result = result
210+
ci.Unlock()
211+
}
212+
213+
func (ci *copyin) getResult() driver.Result {
214+
ci.Lock()
215+
result := ci.Result
216+
if result == nil {
217+
return driver.RowsAffected(0)
218+
}
219+
ci.Unlock()
220+
return result
221+
}
222+
204223
func (ci *copyin) NumInput() int {
205224
return -1
206225
}
@@ -231,7 +250,11 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
231250
}
232251

233252
if len(v) == 0 {
234-
return driver.RowsAffected(0), ci.Close()
253+
if err := ci.Close(); err != nil {
254+
return driver.RowsAffected(0), err
255+
}
256+
257+
return ci.getResult(), nil
235258
}
236259

237260
numValues := len(v)

copy_test.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,20 @@ func TestCopyInMultipleValues(t *testing.T) {
7373
}
7474
}
7575

76-
_, err = stmt.Exec()
76+
result, err := stmt.Exec()
7777
if err != nil {
7878
t.Fatal(err)
7979
}
8080

81+
rowsAffected, err := result.RowsAffected()
82+
if err != nil {
83+
t.Fatal(err)
84+
}
85+
86+
if rowsAffected != 500 {
87+
t.Fatalf("expected 500 rows affected, not %d", rowsAffected)
88+
}
89+
8190
err = stmt.Close()
8291
if err != nil {
8392
t.Fatal(err)

0 commit comments

Comments
 (0)