Skip to content

Commit 1dc804a

Browse files
[storage][azblob]DownloadFile: Download large files serially (#21259)
* Move Buffer manager to shared * Serialize downloading to a file * Move DownloadFile to new func * Add special handling for small files * Fix build * Fix build * Lint error * Fix tab spaces * Fix lint again :( * Address comments * Update comment * Doc comment for default concurrency * Fix formatting * Fix formatting * Fix file read performUploadAndDownloadFileTest() method Seek to zero in performUploadAndDownloadFileTest before reading. * Fix testcase TestBasicDoBatchTransfer * Fix testcase
1 parent 9d6efae commit 1dc804a

File tree

10 files changed

+288
-100
lines changed

10 files changed

+288
-100
lines changed

sdk/storage/azblob/blob/client.go

Lines changed: 169 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ package blob
88

99
import (
1010
"context"
11-
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
12-
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
1311
"io"
1412
"os"
1513
"sync"
1614
"time"
1715

1816
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
17+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
1918
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
2019
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
20+
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
2121
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/base"
2222
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/exported"
2323
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/generated"
@@ -324,8 +324,8 @@ func (b *Client) GetSASURL(permissions sas.BlobPermissions, expiry time.Time, o
324324

325325
// Concurrent Download Functions -----------------------------------------------------------------------------------------
326326

327-
// download downloads an Azure blob to a WriterAt in parallel.
328-
func (b *Client) download(ctx context.Context, writer io.WriterAt, o downloadOptions) (int64, error) {
327+
// downloadBuffer downloads an Azure blob to a WriterAt in parallel.
328+
func (b *Client) downloadBuffer(ctx context.Context, writer io.WriterAt, o downloadOptions) (int64, error) {
329329
if o.BlockSize == 0 {
330330
o.BlockSize = DefaultDownloadBlockSize
331331
}
@@ -353,6 +353,7 @@ func (b *Client) download(ctx context.Context, writer io.WriterAt, o downloadOpt
353353
OperationName: "downloadBlobToWriterAt",
354354
TransferSize: count,
355355
ChunkSize: o.BlockSize,
356+
NumChunks: uint16(((count - 1) / o.BlockSize) + 1),
356357
Concurrency: o.Concurrency,
357358
Operation: func(ctx context.Context, chunkStart int64, count int64) error {
358359
downloadBlobOptions := o.getDownloadBlobOptions(HTTPRange{
@@ -391,6 +392,168 @@ func (b *Client) download(ctx context.Context, writer io.WriterAt, o downloadOpt
391392
return count, nil
392393
}
393394

395+
// downloadFile downloads an Azure blob to a Writer. The blocks are downloaded parallely,
396+
// but written to file serially
397+
func (b *Client) downloadFile(ctx context.Context, writer io.Writer, o downloadOptions) (int64, error) {
398+
ctx, cancel := context.WithCancel(ctx)
399+
defer cancel()
400+
if o.BlockSize == 0 {
401+
o.BlockSize = DefaultDownloadBlockSize
402+
}
403+
404+
if o.Concurrency == 0 {
405+
o.Concurrency = DefaultConcurrency
406+
}
407+
408+
count := o.Range.Count
409+
if count == CountToEnd { //Calculate size if not specified
410+
gr, err := b.GetProperties(ctx, o.getBlobPropertiesOptions())
411+
if err != nil {
412+
return 0, err
413+
}
414+
count = *gr.ContentLength - o.Range.Offset
415+
}
416+
417+
if count <= 0 {
418+
// The file is empty, there is nothing to download.
419+
return 0, nil
420+
}
421+
422+
progress := int64(0)
423+
progressLock := &sync.Mutex{}
424+
425+
// helper routine to get body
426+
getBodyForRange := func(ctx context.Context, chunkStart, size int64) (io.ReadCloser, error) {
427+
downloadBlobOptions := o.getDownloadBlobOptions(HTTPRange{
428+
Offset: chunkStart + o.Range.Offset,
429+
Count: size,
430+
}, nil)
431+
dr, err := b.DownloadStream(ctx, downloadBlobOptions)
432+
if err != nil {
433+
return nil, err
434+
}
435+
436+
var body io.ReadCloser = dr.NewRetryReader(ctx, &o.RetryReaderOptionsPerBlock)
437+
if o.Progress != nil {
438+
rangeProgress := int64(0)
439+
body = streaming.NewResponseProgress(
440+
body,
441+
func(bytesTransferred int64) {
442+
diff := bytesTransferred - rangeProgress
443+
rangeProgress = bytesTransferred
444+
progressLock.Lock()
445+
progress += diff
446+
o.Progress(progress)
447+
progressLock.Unlock()
448+
})
449+
}
450+
451+
return body, nil
452+
}
453+
454+
// if file fits in a single buffer, we'll download here.
455+
if count <= o.BlockSize {
456+
body, err := getBodyForRange(ctx, int64(0), count)
457+
if err != nil {
458+
return 0, err
459+
}
460+
defer body.Close()
461+
462+
return io.Copy(writer, body)
463+
}
464+
465+
buffers := shared.NewMMBPool(int(o.Concurrency), o.BlockSize)
466+
defer buffers.Free()
467+
aquireBuffer := func() ([]byte, error) {
468+
select {
469+
case b := <-buffers.Acquire():
470+
// got a buffer
471+
return b, nil
472+
default:
473+
// no buffer available; allocate a new buffer if possible
474+
if _, err := buffers.Grow(); err != nil {
475+
return nil, err
476+
}
477+
478+
// either grab the newly allocated buffer or wait for one to become available
479+
return <-buffers.Acquire(), nil
480+
}
481+
}
482+
483+
numChunks := uint16((count-1)/o.BlockSize) + 1
484+
blocks := make([]chan []byte, numChunks)
485+
for b := range blocks {
486+
blocks[b] = make(chan []byte)
487+
}
488+
489+
/*
490+
* We have created as many channels as the number of chunks we have.
491+
* Each downloaded block will be sent to the channel matching its
492+
* sequece number, i.e. 0th block is sent to 0th channel, 1st block
493+
* to 1st channel and likewise. The blocks are then read and written
494+
* to the file serially by below goroutine. Do note that the blocks
495+
* blocks are still downloaded parallelly from n/w, only serailized
496+
* and written to file here.
497+
*/
498+
writerError := make(chan error)
499+
go func(ch chan error) {
500+
for _, block := range blocks {
501+
select {
502+
case <-ctx.Done():
503+
return
504+
case block := <-block:
505+
_, err := writer.Write(block)
506+
buffers.Release(block)
507+
if err != nil {
508+
ch <- err
509+
return
510+
}
511+
}
512+
}
513+
ch <- nil
514+
}(writerError)
515+
516+
// Prepare and do parallel download.
517+
err := shared.DoBatchTransfer(ctx, &shared.BatchTransferOptions{
518+
OperationName: "downloadBlobToWriterAt",
519+
TransferSize: count,
520+
ChunkSize: o.BlockSize,
521+
NumChunks: numChunks,
522+
Concurrency: o.Concurrency,
523+
Operation: func(ctx context.Context, chunkStart int64, count int64) error {
524+
buff, err := aquireBuffer()
525+
if err != nil {
526+
return err
527+
}
528+
529+
body, err := getBodyForRange(ctx, chunkStart, count)
530+
if err != nil {
531+
buffers.Release(buff)
532+
return nil
533+
}
534+
535+
_, err = io.ReadFull(body, buff[:count])
536+
body.Close()
537+
if err != nil {
538+
return err
539+
}
540+
541+
blockIndex := (chunkStart / o.BlockSize)
542+
blocks[blockIndex] <- buff
543+
return nil
544+
},
545+
})
546+
547+
if err != nil {
548+
return 0, err
549+
}
550+
// error from writer thread.
551+
if err = <-writerError; err != nil {
552+
return 0, err
553+
}
554+
return count, nil
555+
}
556+
394557
// DownloadStream reads a range of bytes from a blob. The response also includes the blob's properties and metadata.
395558
// For more information, see https://docs.microsoft.com/rest/api/storageservices/get-blob.
396559
func (b *Client) DownloadStream(ctx context.Context, o *DownloadStreamOptions) (DownloadStreamResponse, error) {
@@ -419,7 +582,7 @@ func (b *Client) DownloadBuffer(ctx context.Context, buffer []byte, o *DownloadB
419582
if o == nil {
420583
o = &DownloadBufferOptions{}
421584
}
422-
return b.download(ctx, shared.NewBytesWriter(buffer), (downloadOptions)(*o))
585+
return b.downloadBuffer(ctx, shared.NewBytesWriter(buffer), (downloadOptions)(*o))
423586
}
424587

425588
// DownloadFile downloads an Azure blob to a local file.
@@ -458,7 +621,7 @@ func (b *Client) DownloadFile(ctx context.Context, file *os.File, o *DownloadFil
458621
}
459622

460623
if size > 0 {
461-
return b.download(ctx, file, *do)
624+
return b.downloadFile(ctx, file, *do)
462625
} else { // if the blob's size is 0, there is no need in downloading it
463626
return 0, nil
464627
}

sdk/storage/azblob/blob/constants.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package blob
99
import (
1010
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/exported"
1111
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/generated"
12+
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/shared"
1213
)
1314

1415
const (
@@ -18,6 +19,9 @@ const (
1819

1920
// DefaultDownloadBlockSize is default block size
2021
DefaultDownloadBlockSize = int64(4 * 1024 * 1024) // 4MB
22+
23+
// DefaultConcurrency is the default number of blocks downloaded or uploaded in parallel
24+
DefaultConcurrency = shared.DefaultConcurrency
2125
)
2226

2327
// BlobType defines values for BlobType

sdk/storage/azblob/blockblob/chunkwriting.go

Lines changed: 2 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
2020
"github.com/Azure/azure-sdk-for-go/sdk/internal/uuid"
21+
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/shared"
2122
)
2223

2324
// blockWriter provides methods to upload blocks that represent a file to a server and commit them.
@@ -28,27 +29,8 @@ type blockWriter interface {
2829
CommitBlockList(context.Context, []string, *CommitBlockListOptions) (CommitBlockListResponse, error)
2930
}
3031

31-
// bufferManager provides an abstraction for the management of buffers.
32-
// this is mostly for testing purposes, but does allow for different implementations without changing the algorithm.
33-
type bufferManager[T ~[]byte] interface {
34-
// Acquire returns the channel that contains the pool of buffers.
35-
Acquire() <-chan T
36-
37-
// Release releases the buffer back to the pool for reuse/cleanup.
38-
Release(T)
39-
40-
// Grow grows the number of buffers, up to the predefined max.
41-
// It returns the total number of buffers or an error.
42-
// No error is returned if the number of buffers has reached max.
43-
// This is called only from the reading goroutine.
44-
Grow() (int, error)
45-
46-
// Free cleans up all buffers.
47-
Free()
48-
}
49-
5032
// copyFromReader copies a source io.Reader to blob storage using concurrent uploads.
51-
func copyFromReader[T ~[]byte](ctx context.Context, src io.Reader, dst blockWriter, options UploadStreamOptions, getBufferManager func(maxBuffers int, bufferSize int64) bufferManager[T]) (CommitBlockListResponse, error) {
33+
func copyFromReader[T ~[]byte](ctx context.Context, src io.Reader, dst blockWriter, options UploadStreamOptions, getBufferManager func(maxBuffers int, bufferSize int64) shared.BufferManager[T]) (CommitBlockListResponse, error) {
5234
options.setDefaults()
5335

5436
wg := sync.WaitGroup{} // Used to know when all outgoing blocks have finished processing
@@ -265,49 +247,3 @@ func (ubi uuidBlockID) WithBlockNumber(blockNumber uint32) uuidBlockID {
265247
func (ubi uuidBlockID) ToBase64() string {
266248
return blockID(ubi).ToBase64()
267249
}
268-
269-
// mmbPool implements the bufferManager interface.
270-
// it uses anonymous memory mapped files for buffers.
271-
// don't use this type directly, use newMMBPool() instead.
272-
type mmbPool struct {
273-
buffers chan mmb
274-
count int
275-
max int
276-
size int64
277-
}
278-
279-
func newMMBPool(maxBuffers int, bufferSize int64) bufferManager[mmb] {
280-
return &mmbPool{
281-
buffers: make(chan mmb, maxBuffers),
282-
max: maxBuffers,
283-
size: bufferSize,
284-
}
285-
}
286-
287-
func (pool *mmbPool) Acquire() <-chan mmb {
288-
return pool.buffers
289-
}
290-
291-
func (pool *mmbPool) Grow() (int, error) {
292-
if pool.count < pool.max {
293-
buffer, err := newMMB(pool.size)
294-
if err != nil {
295-
return 0, err
296-
}
297-
pool.buffers <- buffer
298-
pool.count++
299-
}
300-
return pool.count, nil
301-
}
302-
303-
func (pool *mmbPool) Release(buffer mmb) {
304-
pool.buffers <- buffer
305-
}
306-
307-
func (pool *mmbPool) Free() {
308-
for i := 0; i < pool.count; i++ {
309-
buffer := <-pool.buffers
310-
buffer.delete()
311-
}
312-
pool.count = 0
313-
}

sdk/storage/azblob/blockblob/chunkwriting_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"testing"
1717
"time"
1818

19+
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/shared"
1920
"github.com/stretchr/testify/assert"
2021
"github.com/stretchr/testify/require"
2122
)
@@ -115,19 +116,19 @@ func calcMD5(data []byte) string {
115116

116117
// used to track proper acquisition and closing of buffers
117118
type bufMgrTracker struct {
118-
inner bufferManager[mmb]
119+
inner shared.BufferManager[shared.Mmb]
119120

120121
Count int // total count of allocated buffers
121122
Freed bool // buffers were freed
122123
}
123124

124125
func newBufMgrTracker(maxBuffers int, bufferSize int64) *bufMgrTracker {
125126
return &bufMgrTracker{
126-
inner: newMMBPool(maxBuffers, bufferSize),
127+
inner: shared.NewMMBPool(maxBuffers, bufferSize),
127128
}
128129
}
129130

130-
func (pool *bufMgrTracker) Acquire() <-chan mmb {
131+
func (pool *bufMgrTracker) Acquire() <-chan shared.Mmb {
131132
return pool.inner.Acquire()
132133
}
133134

@@ -140,7 +141,7 @@ func (pool *bufMgrTracker) Grow() (int, error) {
140141
return n, nil
141142
}
142143

143-
func (pool *bufMgrTracker) Release(buffer mmb) {
144+
func (pool *bufMgrTracker) Release(buffer shared.Mmb) {
144145
pool.inner.Release(buffer)
145146
}
146147

@@ -161,7 +162,7 @@ func TestSlowDestCopyFrom(t *testing.T) {
161162

162163
errs := make(chan error, 1)
163164
go func() {
164-
_, err := copyFromReader(context.Background(), bytes.NewReader(bigSrc), fakeBB, UploadStreamOptions{}, func(maxBuffers int, bufferSize int64) bufferManager[mmb] {
165+
_, err := copyFromReader(context.Background(), bytes.NewReader(bigSrc), fakeBB, UploadStreamOptions{}, func(maxBuffers int, bufferSize int64) shared.BufferManager[shared.Mmb] {
165166
tracker = newBufMgrTracker(maxBuffers, bufferSize)
166167
return tracker
167168
})
@@ -270,7 +271,7 @@ func TestCopyFromReader(t *testing.T) {
270271

271272
var tracker *bufMgrTracker
272273

273-
_, err := copyFromReader(test.ctx, bytes.NewReader(from), fakeBB, test.o, func(maxBuffers int, bufferSize int64) bufferManager[mmb] {
274+
_, err := copyFromReader(test.ctx, bytes.NewReader(from), fakeBB, test.o, func(maxBuffers int, bufferSize int64) shared.BufferManager[shared.Mmb] {
274275
tracker = newBufMgrTracker(maxBuffers, bufferSize)
275276
return tracker
276277
})
@@ -322,7 +323,7 @@ func TestCopyFromReaderReadError(t *testing.T) {
322323
reader: bytes.NewReader(make([]byte, 5*_1MiB)),
323324
failOn: 2,
324325
}
325-
_, err := copyFromReader(context.Background(), &rf, fakeBB, UploadStreamOptions{}, func(maxBuffers int, bufferSize int64) bufferManager[mmb] {
326+
_, err := copyFromReader(context.Background(), &rf, fakeBB, UploadStreamOptions{}, func(maxBuffers int, bufferSize int64) shared.BufferManager[shared.Mmb] {
326327
tracker = newBufMgrTracker(maxBuffers, bufferSize)
327328
return tracker
328329
})

0 commit comments

Comments
 (0)