Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1029646: Add WithFileGetStream that supports downloading a file into stream #1192

Merged
merged 10 commits into from
Aug 27, 2024
Merged
34 changes: 24 additions & 10 deletions azure_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
DownloadStream(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error)
GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
}

Expand Down Expand Up @@ -276,16 +277,29 @@
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
if err != nil {
return err
}
defer f.Close()
_, err = blobClient.DownloadFile(
context.Background(), f, &azblob.DownloadFileOptions{
Concurrency: uint16(maxConcurrency)})
if err != nil {
return err
if meta.options.getFileToStream {
blobDownloadResponse, err := blobClient.DownloadStream(context.Background(), &azblob.DownloadStreamOptions{})
if err != nil {
return err
}

Check warning on line 284 in azure_storage_client.go

View check run for this annotation

Codecov / codecov/patch

azure_storage_client.go#L283-L284

Added lines #L283 - L284 were not covered by tests
retryReader := blobDownloadResponse.NewRetryReader(context.Background(), &azblob.RetryReaderOptions{})
defer retryReader.Close()
_, err = meta.dstStream.ReadFrom(retryReader)
if err != nil {
return err
}

Check warning on line 290 in azure_storage_client.go

View check run for this annotation

Codecov / codecov/patch

azure_storage_client.go#L289-L290

Added lines #L289 - L290 were not covered by tests
} else {
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
if err != nil {
return err
}

Check warning on line 295 in azure_storage_client.go

View check run for this annotation

Codecov / codecov/patch

azure_storage_client.go#L294-L295

Added lines #L294 - L295 were not covered by tests
defer f.Close()
_, err = blobClient.DownloadFile(
context.Background(), f, &azblob.DownloadFileOptions{
Concurrency: uint16(maxConcurrency)})
if err != nil {
return err
}
}
meta.resStatus = downloaded
return nil
Expand Down
13 changes: 9 additions & 4 deletions azure_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ func TestUnitDetectAzureTokenExpireError(t *testing.T) {
}

// azureObjectAPIMock is a test double for the Azure blob client. Every
// operation is backed by a function-valued field so that an individual test
// can plug in the behaviour it needs; for example, the DownloadStream method
// calls DownloadStreamFunc.
type azureObjectAPIMock struct {
	// UploadStreamFunc stubs uploading from an io.Reader.
	UploadStreamFunc func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
	// UploadFileFunc stubs uploading from an open file.
	UploadFileFunc func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
	// DownloadFileFunc stubs downloading into an open file.
	DownloadFileFunc func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
	// DownloadStreamFunc stubs downloading as a stream response.
	DownloadStreamFunc func(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error)
	// GetPropertiesFunc stubs fetching blob properties.
	GetPropertiesFunc func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
}

func (c *azureObjectAPIMock) UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) {
Expand All @@ -131,6 +132,10 @@ func (c *azureObjectAPIMock) DownloadFile(ctx context.Context, file *os.File, o
return c.DownloadFileFunc(ctx, file, o)
}

// DownloadStream forwards the call, unchanged, to the DownloadStreamFunc stub
// configured on the mock and returns whatever the stub returns.
func (c *azureObjectAPIMock) DownloadStream(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error) {
	stub := c.DownloadStreamFunc
	return stub(ctx, o)
}

func TestUploadFileWithAzureUploadFailedError(t *testing.T) {
info := execResponseStageInfo{
Location: "azblob/storage/users/456/",
Expand Down
28 changes: 24 additions & 4 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -88,10 +89,11 @@
isInternal bool) (
*execResponse, error) {
sfa := snowflakeFileTransferAgent{
sc: sc,
data: &data.Data,
command: query,
options: new(SnowflakeFileTransferOptions),
sc: sc,
data: &data.Data,
command: query,
options: new(SnowflakeFileTransferOptions),
streamBuffer: new(bytes.Buffer),
}
if fs := getFileStream(ctx); fs != nil {
sfa.sourceStream = fs
Expand All @@ -112,6 +114,11 @@
if err != nil {
return nil, err
}
if sfa.options.getFileToStream {
if err := writeFileStream(ctx, sfa.streamBuffer); err != nil {
return nil, err
}

Check warning on line 120 in connection_util.go

View check run for this annotation

Codecov / codecov/patch

connection_util.go#L119-L120

Added lines #L119 - L120 were not covered by tests
}
return data, nil
}

Expand All @@ -138,6 +145,19 @@
return o
}

func writeFileStream(ctx context.Context, streamBuf *bytes.Buffer) error {
s := ctx.Value(fileGetStream)
w, ok := s.(io.Writer)
if !ok {
return errors.New("expected an io.Writer")
}

Check warning on line 153 in connection_util.go

View check run for this annotation

Codecov / codecov/patch

connection_util.go#L152-L153

Added lines #L152 - L153 were not covered by tests
_, err := streamBuf.WriteTo(w)
if err != nil {
return err
}

Check warning on line 157 in connection_util.go

View check run for this annotation

Codecov / codecov/patch

connection_util.go#L156-L157

Added lines #L156 - L157 were not covered by tests
return nil
}

func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParameter) {
// other session parameters (not all)
logger.WithContext(sc.ctx).Infof("params: %#v", parameters)
Expand Down
13 changes: 13 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,19 @@ an absolute path rather than a relative path. For example:

db.Query("GET @~ file:///tmp/my_data_file auto_compress=false overwrite=false")

To download a file into an in-memory stream (rather than a file), use code similar to the code below.

var streamBuf bytes.Buffer
ctx := WithFileTransferOptions(context.Background(), &SnowflakeFileTransferOptions{getFileToStream: true})
ctx = WithFileGetStream(ctx, &streamBuf)

sql := "get @~/data1.txt.gz file:///tmp/testData"
dbt.mustExecContext(ctx, sql)
// streamBuf is now filled with the stream. Use bytes.NewReader(streamBuf.Bytes()) to read an uncompressed stream, or
// use gzip.NewReader(&streamBuf) to read a compressed stream.

Note: GET statements are not supported for multi-statement queries.

Specifying temporary directory for encryption and compression:

Putting and getting requires compression and/or encryption, which is done in the OS temporary directory.
Expand Down
72 changes: 51 additions & 21 deletions encrypt_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,46 +190,51 @@
return meta, tmpOutputFile.Name(), nil
}

func decryptFile(
func decryptFileKey(
metadata *encryptMetadata,
sfe *snowflakeFileEncryption,
filename string,
chunkSize int,
tmpDir string) (
string, error) {
if chunkSize == 0 {
chunkSize = aes.BlockSize * 4 * 1024
}
sfe *snowflakeFileEncryption) ([]byte, []byte, error) {
decodedKey, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey)
if err != nil {
return "", err
return nil, nil, err

Check warning on line 198 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L198

Added line #L198 was not covered by tests
}
keyBytes, err := base64.StdEncoding.DecodeString(metadata.key) // encrypted file key
if err != nil {
return "", err
return nil, nil, err

Check warning on line 202 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L202

Added line #L202 was not covered by tests
}
ivBytes, err := base64.StdEncoding.DecodeString(metadata.iv)
if err != nil {
return "", err
return nil, nil, err

Check warning on line 206 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L206

Added line #L206 was not covered by tests
}

// decrypt file key
decryptedKey := make([]byte, len(keyBytes))
if err = decryptECB(decryptedKey, keyBytes, decodedKey); err != nil {
return "", err
return nil, nil, err

Check warning on line 212 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L212

Added line #L212 was not covered by tests
}
decryptedKey, err = paddingTrim(decryptedKey)
if err != nil {
return "", err
return nil, nil, err

Check warning on line 216 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L216

Added line #L216 was not covered by tests
}

// decrypt file
return decryptedKey, ivBytes, err
}

// invalidIVLengthError reports an initialization vector whose length does not
// match the AES block size.
type invalidIVLengthError struct{}

// Error implements the error interface.
func (invalidIVLengthError) Error() string {
	return "initialization vector length must equal the AES block size"
}

// initCBC builds an AES-CBC decrypter from a decrypted file key and an
// initialization vector.
//
// decryptedKey must be a valid AES key (aes.NewCipher rejects other sizes) and
// ivBytes must be exactly one cipher block long. Both conditions are reported
// as errors; the IV is checked explicitly because cipher.NewCBCDecrypter
// panics on a wrong-sized IV, and the IV originates from decoded file
// metadata rather than from trusted code.
//
// On error the returned cipher.BlockMode is nil.
func initCBC(decryptedKey []byte, ivBytes []byte) (cipher.BlockMode, error) {
	block, err := aes.NewCipher(decryptedKey)
	if err != nil {
		return nil, err
	}
	if len(ivBytes) != block.BlockSize() {
		return nil, invalidIVLengthError{}
	}
	return cipher.NewCBCDecrypter(block, ivBytes), nil
}

func decryptFile(
metadata *encryptMetadata,
sfe *snowflakeFileEncryption,
filename string,
chunkSize int,
tmpDir string) (string, error) {
tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#")
if err != nil {
return "", err
Expand All @@ -240,11 +245,37 @@
return "", err
}
defer infile.Close()
totalFileSize, err := decryptStream(metadata, sfe, chunkSize, infile, tmpOutputFile)
if err != nil {
return "", err
}

Check warning on line 251 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L250-L251

Added lines #L250 - L251 were not covered by tests
tmpOutputFile.Truncate(int64(totalFileSize))
return tmpOutputFile.Name(), nil
}

func decryptStream(
metadata *encryptMetadata,
sfe *snowflakeFileEncryption,
chunkSize int,
src io.Reader,
out io.Writer) (int, error) {
if chunkSize == 0 {
chunkSize = aes.BlockSize * 4 * 1024
}
decryptedKey, ivBytes, err := decryptFileKey(metadata, sfe)
if err != nil {
return 0, err
}

Check warning on line 268 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L267-L268

Added lines #L267 - L268 were not covered by tests
mode, err := initCBC(decryptedKey, ivBytes)
if err != nil {
return 0, err
}

Check warning on line 272 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L271-L272

Added lines #L271 - L272 were not covered by tests

var totalFileSize int
var prevChunk []byte
for {
chunk := make([]byte, chunkSize)
n, err := infile.Read(chunk)
n, err := src.Read(chunk)
if n == 0 || err != nil {
break
} else if n%aes.BlockSize != 0 {
Expand All @@ -255,17 +286,16 @@
totalFileSize += n
chunk = chunk[:n]
mode.CryptBlocks(chunk, chunk)
tmpOutputFile.Write(chunk)
out.Write(chunk)
prevChunk = chunk
}
if err != nil {
return "", err
return 0, err

Check warning on line 293 in encrypt_util.go

View check run for this annotation

Codecov / codecov/patch

encrypt_util.go#L293

Added line #L293 was not covered by tests
}
if prevChunk != nil {
totalFileSize -= paddingOffset(prevChunk)
}
tmpOutputFile.Truncate(int64(totalFileSize))
return tmpOutputFile.Name(), nil
return totalFileSize, err
}

type materialDescriptor struct {
Expand Down
5 changes: 5 additions & 0 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ type SnowflakeFileTransferOptions struct {
/* streaming PUT */
compressSourceFromStream bool

/* streaming GET */
getFileToStream bool

/* PUT */
putCallback *snowflakeProgressPercentage
putAzureCallback *snowflakeProgressPercentage
Expand Down Expand Up @@ -124,6 +127,7 @@ type snowflakeFileTransferAgent struct {
useAccelerateEndpoint bool
presignedURLs []string
options *SnowflakeFileTransferOptions
streamBuffer *bytes.Buffer
}

func (sfa *snowflakeFileTransferAgent) execute() error {
Expand Down Expand Up @@ -411,6 +415,7 @@ func (sfa *snowflakeFileTransferAgent) initFileMetadata() error {
name: baseName(fileName),
srcFileName: fileName,
dstFileName: dstFileName,
dstStream: new(bytes.Buffer),
stageLocationType: sfa.stageLocationType,
stageInfo: sfa.stageInfo,
localLocation: sfa.localLocation,
Expand Down
3 changes: 3 additions & 0 deletions file_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ type fileMetadata struct {
srcStream *bytes.Buffer
realSrcStream *bytes.Buffer

/* streaming GET */
dstStream *bytes.Buffer

/* GCS */
presignedURL *url.URL
gcsFileHeaderDigest string
Expand Down
31 changes: 18 additions & 13 deletions gcs_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,24 @@
return meta.lastError
}

f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
if err != nil {
return err
}
defer f.Close()
if _, err = io.Copy(f, resp.Body); err != nil {
return err
if meta.options.getFileToStream {
if _, err := io.Copy(meta.dstStream, resp.Body); err != nil {
return err
}

Check warning on line 328 in gcs_storage_client.go

View check run for this annotation

Codecov / codecov/patch

gcs_storage_client.go#L327-L328

Added lines #L327 - L328 were not covered by tests
} else {
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
if err != nil {
return err
}

Check warning on line 333 in gcs_storage_client.go

View check run for this annotation

Codecov / codecov/patch

gcs_storage_client.go#L332-L333

Added lines #L332 - L333 were not covered by tests
defer f.Close()
if _, err = io.Copy(f, resp.Body); err != nil {
return err
}

Check warning on line 337 in gcs_storage_client.go

View check run for this annotation

Codecov / codecov/patch

gcs_storage_client.go#L336-L337

Added lines #L336 - L337 were not covered by tests
fi, err := os.Stat(fullDstFileName)
if err != nil {
return err
}

Check warning on line 341 in gcs_storage_client.go

View check run for this annotation

Codecov / codecov/patch

gcs_storage_client.go#L340-L341

Added lines #L340 - L341 were not covered by tests
meta.srcFileSize = fi.Size()
}

var encryptMeta encryptMetadata
Expand All @@ -348,12 +359,6 @@
}
}
}

fi, err := os.Stat(fullDstFileName)
if err != nil {
return err
}
meta.srcFileSize = fi.Size()
meta.resStatus = downloaded
meta.gcsFileHeaderDigest = resp.Header.Get(gcsMetadataSfcDigest)
meta.gcsFileHeaderContentLength = resp.ContentLength
Expand Down
Loading
Loading