diff --git a/aws/s3/s3.go b/aws/s3/s3.go index e87669f..ca10fe0 100644 --- a/aws/s3/s3.go +++ b/aws/s3/s3.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "errors" - "fmt" "io" "os" "sync" @@ -509,6 +508,18 @@ func (s *S3) ListAllObjectsConcurrently(bucket string, prefixes []string) ([]typ // PutStream puts the data stream to key in bucket. func (s *S3) PutStream(bucket, key string, reader io.ReadCloser) error { + return s.putStream(context.TODO(), bucket, key, reader) +} + +// PutStreamWithContext is the same as PutStream but uses +// the provided context. +func (s *S3) PutStreamWithContext(ctx context.Context, bucket, key string, reader io.ReadCloser) error { + return s.putStream(ctx, bucket, key, reader) +} + +// putStream is the common code used internally to upload a data stream to +// an S3 bucket using the client's uploader. +func (s *S3) putStream(ctx context.Context, bucket, key string, reader io.ReadCloser) error { defer reader.Close() if s.uploader == nil { @@ -519,9 +530,9 @@ func (s *S3) PutStream(bucket, key string, reader io.ReadCloser) error { Key: aws.String(key), Body: reader, } - _, err := s.uploader.Upload(context.TODO(), &input) + _, err := s.uploader.Upload(ctx, &input) if err != nil { - return fmt.Errorf("error uploading to s3 for key %s, error: %s", key, err.Error()) + return err } return nil } @@ -530,6 +541,26 @@ func (s *S3) PutStream(bucket, key string, reader io.ReadCloser) error { // File is split up into parts and downloaded concurrently into an os.File, // so is useful for getting large files. Returns number of bytes downloaded. func (s *S3) Download(bucket, key string, f *os.File) (int64, error) { + return s.download(context.TODO(), bucket, key, f) +} + +// DownloadWithContext is the same as Download but uses +// the provided context. +func (s *S3) DownloadWithContext( + ctx context.Context, + bucket, key string, + f *os.File, +) (int64, error) { + return s.download(ctx, bucket, key, f) +} + +// download is the common code used internally to download an S3 object +// using the downloader based on the provided input. +func (s *S3) download( + ctx context.Context, + bucket, key string, + f *os.File, +) (int64, error) { if s.downloader == nil { return 0, errors.New("error downloading from S3, downloader not initialised") } @@ -537,7 +568,7 @@ func (s *S3) Download(bucket, key string, f *os.File) (int64, error) { Bucket: aws.String(bucket), Key: aws.String(key), } - numBytes, err := s.downloader.Download(context.TODO(), f, &input) + numBytes, err := s.downloader.Download(ctx, f, &input) if err != nil { return 0, err } diff --git a/aws/s3/s3_integration_test.go b/aws/s3/s3_integration_test.go index 884d17e..b285404 100644 --- a/aws/s3/s3_integration_test.go +++ b/aws/s3/s3_integration_test.go @@ -2,8 +2,10 @@ package s3 import ( "bytes" + "context" "encoding/json" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -456,6 +458,94 @@ func TestS3GetContentSizeTime(t *testing.T) { assert.Equal(t, meta.lastModified, lastModified) } +func TestS3Download(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + awsCmdPopulateBucket() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating s3 client: %v", err)) + + err = client.AddDownloader() + require.Nil(t, err, fmt.Sprintf("error adding s3 downloader: %v", err)) + + dir := t.TempDir() + dataFile, err := os.CreateTemp(dir, "test") + require.Nil(t, err, fmt.Sprintf("error creating file: %v", err)) + + getSize := func(*os.File) int64 { + info, err := dataFile.Stat() + require.Nil(t, err, fmt.Sprintf("error getting file info: %v", err)) + return info.Size() + } + require.Equal(t, int64(0), getSize(dataFile)) + + // ACTION + bytesDownloaded, err := client.Download(testBucket, testObjectKey, dataFile) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, getSize(dataFile), bytesDownloaded) + assert.Equal(t, getSize(dataFile), int64(len(testObjectData))) + + buf := make([]byte, len(testObjectData)) + _, err = dataFile.Read(buf) + require.Nil(t, err, fmt.Sprintf("error reading file: %v", err)) + assert.Equal(t, testObjectData, string(buf)) +} + +func TestS3DownloadWithContext(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + awsCmdPopulateBucket() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating s3 client: %v", err)) + + err = client.AddDownloader() + require.Nil(t, err, fmt.Sprintf("error adding s3 downloader: %v", err)) + + dir := t.TempDir() + dataFile, err := os.CreateTemp(dir, "test") + require.Nil(t, err, fmt.Sprintf("error creating file: %v", err)) + + getSize := func(*os.File) int64 { + info, err := dataFile.Stat() + require.Nil(t, err, fmt.Sprintf("error getting file info: %v", err)) + return info.Size() + } + require.Equal(t, int64(0), getSize(dataFile)) + + // ACTION + ctx := context.Background() + bytesDownloaded, err := client.DownloadWithContext(ctx, testBucket, testObjectKey, dataFile) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, getSize(dataFile), bytesDownloaded) + assert.Equal(t, getSize(dataFile), int64(len(testObjectData))) + + buf := make([]byte, len(testObjectData)) + _, err = dataFile.Read(buf) + require.Nil(t, err, fmt.Sprintf("error reading file: %v", err)) + assert.Equal(t, testObjectData, string(buf)) + + timeoutCtx, cancel := context.WithTimeout( + context.Background(), 1*time.Nanosecond, // Should timeout + ) + defer cancel() + + // ACTION + _, err = client.DownloadWithContext(timeoutCtx, testBucket, testObjectKey, dataFile) + + // ASSERT + assert.ErrorIs(t, err, context.DeadlineExceeded) +} + func TestS3Put(t *testing.T) { // ARRANGE setup() @@ -496,6 +586,102 @@ func TestS3PutWithMetadata(t *testing.T) { assert.Equal(t, testMetaValue, metaData.meta[testMetaKey]) } +func TestS3PutStream(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + awsCmdPopulateBucket() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating s3 client: %v", err)) + + err = client.AddUploader() + require.Nil(t, err, fmt.Sprintf("error adding s3 uploader: %v", err)) + + dir := t.TempDir() + dataFile, err := os.CreateTemp(dir, "test") + require.Nil(t, err, fmt.Sprintf("error creating file: %v", err)) + + bytesInFile, err := dataFile.WriteString(testObjectData) + require.Nil(t, err, fmt.Sprintf("error writing to file: %v", err)) + require.Equal(t, len(testObjectData), bytesInFile) + + _, err = dataFile.Seek(0, io.SeekStart) + require.Nil(t, err, fmt.Sprintf("error setting file offset back to start: %v", err)) + + // ACTION + err = client.PutStream(testBucket, testObjectKey, dataFile) + + // ASSERT + assert.Nil(t, err) + + // Check the file uploaded is the same + buf := bytes.NewBuffer(nil) + err = client.Get(testBucket, testObjectKey, "", buf) + require.Nil(t, err, fmt.Sprintf("error getting object to check: %v", err)) + + contents := buf.Bytes() + assert.Equal(t, bytesInFile, len(contents)) + assert.Equal(t, testObjectData, string(contents)) +} + +func TestS3PutStreamWithContext(t *testing.T) { + + // ARRANGE + setup() + defer teardown() + + awsCmdPopulateBucket() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating s3 client: %v", err)) + + err = client.AddUploader() + require.Nil(t, err, fmt.Sprintf("error adding s3 uploader: %v", err)) + + dir := t.TempDir() + dataFile, err := os.CreateTemp(dir, "test") + require.Nil(t, err, fmt.Sprintf("error creating file: %v", err)) + + bytesInFile, err := dataFile.WriteString(testObjectData) + require.Nil(t, err, fmt.Sprintf("error writing to file: %v", err)) + require.Equal(t, len(testObjectData), bytesInFile) + + _, err = dataFile.Seek(0, io.SeekStart) + require.Nil(t, err, fmt.Sprintf("error setting file offset back to start: %v", err)) + + // ACTION + ctx := context.Background() + err = client.PutStreamWithContext(ctx, testBucket, testObjectKey, dataFile) + + // ASSERT + assert.Nil(t, err) + + // Check the file uploaded is the same + buf := bytes.NewBuffer(nil) + err = client.Get(testBucket, testObjectKey, "", buf) + require.Nil(t, err, fmt.Sprintf("error getting object to check: %v", err)) + + contents := buf.Bytes() + assert.Equal(t, bytesInFile, len(contents)) + assert.Equal(t, testObjectData, string(contents)) + + // ARRANGE + timeoutCtx, cancel := context.WithTimeout( + context.Background(), 1*time.Nanosecond, // Should timeout + ) + defer cancel() + dataFile, err = os.Open(dataFile.Name()) + require.Nil(t, err, fmt.Sprintf("error opening file: %v", err)) + + // ACTION + err = client.PutStreamWithContext(timeoutCtx, testBucket, testObjectKey, dataFile) + + // ASSERT + assert.ErrorIs(t, err, context.DeadlineExceeded) +} + func TestS3Exists(t *testing.T) { // ARRANGE setup() diff --git a/go.mod b/go.mod index 25ebd39..6a593af 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/gocql/gocql v1.6.0 github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da github.com/lib/pq v1.3.0 - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.11.1 golang.org/x/text v0.14.0 google.golang.org/protobuf v1.33.0 ) @@ -35,12 +35,12 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 // indirect - github.com/davecgh/go-spew v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect - gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index a657393..f6465d4 100644 --- a/go.sum +++ b/go.sum @@ -49,8 +49,9 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYE github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU= github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= @@ -81,8 +82,9 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -134,5 +136,6 @@ gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=