Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions aws/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"sync"
Expand Down Expand Up @@ -509,6 +508,18 @@ func (s *S3) ListAllObjectsConcurrently(bucket string, prefixes []string) ([]typ

// PutStream puts the data stream to key in bucket.
func (s *S3) PutStream(bucket, key string, reader io.ReadCloser) error {
return s.putStream(context.TODO(), bucket, key, reader)
}

// PutStreamWithContext is the same as PutStream but uses
// the provided context.
func (s *S3) PutStreamWithContext(ctx context.Context, bucket, key string, reader io.ReadCloser) error {
return s.putStream(ctx, bucket, key, reader)
}

// putStream is the common code used internally to upload a data stream to
// an S3 bucket using the client's uploader.
func (s *S3) putStream(ctx context.Context, bucket, key string, reader io.ReadCloser) error {
defer reader.Close()

if s.uploader == nil {
Expand All @@ -519,9 +530,9 @@ func (s *S3) PutStream(bucket, key string, reader io.ReadCloser) error {
Key: aws.String(key),
Body: reader,
}
_, err := s.uploader.Upload(context.TODO(), &input)
_, err := s.uploader.Upload(ctx, &input)
if err != nil {
return fmt.Errorf("error uploading to s3 for key %s, error: %s", key, err.Error())
return err
}
return nil
}
Expand All @@ -530,14 +541,34 @@ func (s *S3) PutStream(bucket, key string, reader io.ReadCloser) error {
// File is split up into parts and downloaded concurrently into an os.File,
// so is useful for getting large files. Returns number of bytes downloaded.
func (s *S3) Download(bucket, key string, f *os.File) (int64, error) {
return s.download(context.TODO(), bucket, key, f)
}

// DownloadWithContext is the same as Download but uses
// the provided context.
func (s *S3) DownloadWithContext(
ctx context.Context,
bucket, key string,
f *os.File,
) (int64, error) {
return s.download(ctx, bucket, key, f)
}

// download is the common code used internally to download an S3 object
// using the downloader based on the provided input.
func (s *S3) download(
ctx context.Context,
bucket, key string,
f *os.File,
) (int64, error) {
if s.downloader == nil {
return 0, errors.New("error downloading from S3, downloader not initialised")
}
input := s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
}
numBytes, err := s.downloader.Download(context.TODO(), f, &input)
numBytes, err := s.downloader.Download(ctx, f, &input)
if err != nil {
return 0, err
}
Expand Down
186 changes: 186 additions & 0 deletions aws/s3/s3_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package s3

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
Expand Down Expand Up @@ -456,6 +458,94 @@ func TestS3GetContentSizeTime(t *testing.T) {
assert.Equal(t, meta.lastModified, lastModified)
}

func TestS3Download(t *testing.T) {
// ARRANGE
setup()
defer teardown()

awsCmdPopulateBucket()

client, err := New()
require.Nil(t, err, fmt.Sprintf("error creating s3 client: %v", err))

err = client.AddDownloader()
require.Nil(t, err, fmt.Sprintf("error adding s3 downloader: %v", err))

dir := t.TempDir()
dataFile, err := os.CreateTemp(dir, "test")
require.Nil(t, err, fmt.Sprintf("error creating file: %v", err))

getSize := func(*os.File) int64 {
info, err := dataFile.Stat()
require.Nil(t, err, fmt.Sprintf("error getting file info: %v", err))
return info.Size()
}
require.Equal(t, int64(0), getSize(dataFile))

// ACTION
bytesDownloaded, err := client.Download(testBucket, testObjectKey, dataFile)

// ASSERT
assert.Nil(t, err)
assert.Equal(t, getSize(dataFile), bytesDownloaded)
assert.Equal(t, getSize(dataFile), int64(len(testObjectData)))

buf := make([]byte, len(testObjectData))
_, err = dataFile.Read(buf)
require.Nil(t, err, fmt.Sprintf("error reading file: %v", err))
assert.Equal(t, testObjectData, string(buf))
}

func TestS3DownloadWithContext(t *testing.T) {
// ARRANGE
setup()
defer teardown()

awsCmdPopulateBucket()

client, err := New()
require.Nil(t, err, fmt.Sprintf("error creating s3 client: %v", err))

err = client.AddDownloader()
require.Nil(t, err, fmt.Sprintf("error adding s3 downloader: %v", err))

dir := t.TempDir()
dataFile, err := os.CreateTemp(dir, "test")
require.Nil(t, err, fmt.Sprintf("error creating file: %v", err))

getSize := func(*os.File) int64 {
info, err := dataFile.Stat()
require.Nil(t, err, fmt.Sprintf("error getting file info: %v", err))
return info.Size()
}
require.Equal(t, int64(0), getSize(dataFile))

// ACTION
ctx := context.Background()
bytesDownloaded, err := client.DownloadWithContext(ctx, testBucket, testObjectKey, dataFile)

// ASSERT
assert.Nil(t, err)
assert.Equal(t, getSize(dataFile), bytesDownloaded)
assert.Equal(t, getSize(dataFile), int64(len(testObjectData)))

buf := make([]byte, len(testObjectData))
_, err = dataFile.Read(buf)
require.Nil(t, err, fmt.Sprintf("error reading file: %v", err))
assert.Equal(t, testObjectData, string(buf))

timeoutCtx, cancel := context.WithTimeout(
context.Background(), 1*time.Nanosecond, // Should timeout
)
defer cancel()

// ACTION
_, err = client.DownloadWithContext(timeoutCtx, testBucket, testObjectKey, dataFile)

// ASSERT
assert.ErrorIs(t, err, context.DeadlineExceeded)
}

func TestS3Put(t *testing.T) {
// ARRANGE
setup()
Expand Down Expand Up @@ -496,6 +586,102 @@ func TestS3PutWithMetadata(t *testing.T) {
assert.Equal(t, testMetaValue, metaData.meta[testMetaKey])
}

func TestS3PutStream(t *testing.T) {
// ARRANGE
setup()
defer teardown()

awsCmdPopulateBucket()

client, err := New()
require.Nil(t, err, fmt.Sprintf("error creating s3 client: %v", err))

err = client.AddUploader()
require.Nil(t, err, fmt.Sprintf("error adding s3 uploader: %v", err))

dir := t.TempDir()
dataFile, err := os.CreateTemp(dir, "test")
require.Nil(t, err, fmt.Sprintf("error creating file: %v", err))

bytesInFile, err := dataFile.WriteString(testObjectData)
require.Nil(t, err, fmt.Sprintf("error writing to file: %v", err))
require.Equal(t, len(testObjectData), bytesInFile)

_, err = dataFile.Seek(0, io.SeekStart)
require.Nil(t, err, fmt.Sprintf("error setting file offset back to start: %v", err))

// ACTION
err = client.PutStream(testBucket, testObjectKey, dataFile)

// ASSERT
assert.Nil(t, err)

// Check the file uploaded is the same
buf := bytes.NewBuffer(nil)
err = client.Get(testBucket, testObjectKey, "", buf)
require.Nil(t, err, fmt.Sprintf("error getting object to check: %v", err))

contents := buf.Bytes()
assert.Equal(t, bytesInFile, len(contents))
assert.Equal(t, testObjectData, string(contents))
}

func TestS3PutStreamWithContext(t *testing.T) {

// ARRANGE
setup()
defer teardown()

awsCmdPopulateBucket()

client, err := New()
require.Nil(t, err, fmt.Sprintf("error creating s3 client: %v", err))

err = client.AddUploader()
require.Nil(t, err, fmt.Sprintf("error adding s3 uploader: %v", err))

dir := t.TempDir()
dataFile, err := os.CreateTemp(dir, "test")
require.Nil(t, err, fmt.Sprintf("error creating file: %v", err))

bytesInFile, err := dataFile.WriteString(testObjectData)
require.Nil(t, err, fmt.Sprintf("error writing to file: %v", err))
require.Equal(t, len(testObjectData), bytesInFile)

_, err = dataFile.Seek(0, io.SeekStart)
require.Nil(t, err, fmt.Sprintf("error setting file offset back to start: %v", err))

// ACTION
ctx := context.Background()
err = client.PutStreamWithContext(ctx, testBucket, testObjectKey, dataFile)

// ASSERT
assert.Nil(t, err)

// Check the file uploaded is the same
buf := bytes.NewBuffer(nil)
err = client.Get(testBucket, testObjectKey, "", buf)
require.Nil(t, err, fmt.Sprintf("error getting object to check: %v", err))

contents := buf.Bytes()
assert.Equal(t, bytesInFile, len(contents))
assert.Equal(t, testObjectData, string(contents))

// ARRANGE
timeoutCtx, cancel := context.WithTimeout(
context.Background(), 1*time.Nanosecond, // Should timeout
)
defer cancel()
dataFile, err = os.Open(dataFile.Name())
require.Nil(t, err, fmt.Sprintf("error opening file: %v", err))

// ACTION
err = client.PutStreamWithContext(timeoutCtx, testBucket, testObjectKey, dataFile)

// ASSERT
assert.ErrorIs(t, err, context.DeadlineExceeded)
}

func TestS3Exists(t *testing.T) {
// ARRANGE
setup()
Expand Down
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ require (
github.com/gocql/gocql v1.6.0
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
github.com/lib/pq v1.3.0
github.com/stretchr/testify v1.6.1
github.com/stretchr/testify v1.11.1
golang.org/x/text v0.14.0
google.golang.org/protobuf v1.33.0
)
Expand All @@ -35,12 +35,12 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 // indirect
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
9 changes: 6 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYE
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU=
github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8=
Expand Down Expand Up @@ -81,8 +82,9 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
Expand Down Expand Up @@ -134,5 +136,6 @@ gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
Loading