diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index cbc4b97..bd07201 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -28,7 +28,7 @@ func (s String) Get(r *record.Record) (string, error) { return resolveMacro, nil } - return evaluateContext(resolveMacro, r) + return evaluateMeta(resolveMacro, r) } @@ -63,7 +63,7 @@ func Load(configFile string, obj interface{}) error { "env": getEnvironmentVariable, // returns environment variable "macro": setMacroPlaceholder, // set placeholder string for macro replacement "secret": getSecret, // we use this template function to inject secrets from parameter store - "context": setContextPlaceholder, // set placeholder string for context replacement + "context": SetMetaPlaceholder, // set placeholder string for context replacement. Maintaining "context" as a template function name for backward compatibility // indent: add `n` spaces after every newline in the value (useful when // injecting multiline values into YAML block scalars) "indent": func(n int, v string) string { diff --git a/internal/pkg/config/context.go b/internal/pkg/config/context.go index e8a61b1..3957d81 100644 --- a/internal/pkg/config/context.go +++ b/internal/pkg/config/context.go @@ -19,11 +19,11 @@ var ( ) // return placeholder string for context key -func setContextPlaceholder(key string) (string, error) { +func SetMetaPlaceholder(key string) (string, error) { return fmt.Sprintf(contextPlaceholderString, key), nil } -func evaluateContext(data string, record *record.Record) (string, error) { +func evaluateMeta(data string, record *record.Record) (string, error) { // Find all context template patterns matches := contextTemplateRegex.FindAllStringSubmatch(data, -1) @@ -45,7 +45,7 @@ func evaluateContext(data string, record *record.Record) (string, error) { key := match[1] // Get the context value - value, ok := record.GetContextValue(key) + value, ok := record.GetMetaValue(key) if !ok { missingKeys = append(missingKeys, key) continue diff --git a/internal/pkg/pipeline/pipeline.go b/internal/pkg/pipeline/pipeline.go index 8540082..4665c06 100644 --- a/internal/pkg/pipeline/pipeline.go +++ b/internal/pkg/pipeline/pipeline.go @@ -1,6 +1,7 @@ package pipeline import ( + "context" "fmt" "sync" @@ -22,6 +23,8 @@ type Pipeline struct { wg *sync.WaitGroup locker *sync.Mutex errors map[string]error + ctx context.Context + cancel context.CancelFunc } func (p *Pipeline) Init() error { @@ -61,6 +64,9 @@ func (p *Pipeline) Run() error { p.ChannelSize = defaultChannelSize } + // Create pipeline-level context for cancellation + p.ctx, p.cancel = context.WithCancel(context.Background()) + // sync if p.DAG == nil { // data streams @@ -237,10 +243,10 @@ func (p *Pipeline) runTaskConcurrently(t task.Task, input <-chan *record.Record, taskWg.Add(concurrency) for i := 0; i < concurrency; i++ { - go func(task task.Task, in <-chan *record.Record, out chan<- *record.Record) { + go func(ctx context.Context, task task.Task, in <-chan *record.Record, out chan<- *record.Record) { defer taskWg.Done() - if err := task.Run(in, out); err != nil { + if err := task.Run(ctx, in, out); err != nil { fmt.Printf("error in %s: %s\n", task.GetName(), err) if task.GetFailOnError() { p.locker.Lock() @@ -248,7 +254,7 @@ func (p *Pipeline) runTaskConcurrently(t task.Task, input <-chan *record.Record, p.locker.Unlock() } } - }(t, input, output) + }(p.ctx, t, input, output) } go func(wg *sync.WaitGroup, out chan<- *record.Record) { diff --git a/internal/pkg/pipeline/record/context.go b/internal/pkg/pipeline/record/context.go deleted file mode 100644 index f1f01bf..0000000 --- a/internal/pkg/pipeline/record/context.go +++ /dev/null @@ -1,27 +0,0 @@ -package record - -import ( - "context" -) - -type contextKey string - -func (r *Record) SetContextValue(key string, value string) { - if r.Context == nil { - r.Context = context.Background() - } - r.Context = context.WithValue(r.Context, contextKey(key), string(value)) -} - -func (r *Record) GetContextValue(key string) (string, bool) { - - if ctx := r.Context; ctx != nil { - if v := ctx.Value(contextKey(key)); v != nil { - vString, ok := v.(string) - return vString, ok - } - } - - return ``, false - -} diff --git a/internal/pkg/pipeline/record/meta.go b/internal/pkg/pipeline/record/meta.go new file mode 100644 index 0000000..d8906a2 --- /dev/null +++ b/internal/pkg/pipeline/record/meta.go @@ -0,0 +1,16 @@ +package record + +func (r *Record) SetMetaValue(key string, value string) { + if r.Meta == nil { + r.Meta = make(map[string]string) + } + r.Meta[key] = value +} + +func (r *Record) GetMetaValue(key string) (string, bool) { + if r.Meta == nil { + return "", false + } + v, ok := r.Meta[key] + return v, ok +} diff --git a/internal/pkg/pipeline/record/record.go b/internal/pkg/pipeline/record/record.go index 0c158c4..58dc451 100644 --- a/internal/pkg/pipeline/record/record.go +++ b/internal/pkg/pipeline/record/record.go @@ -1,15 +1,14 @@ package record import ( - "context" "encoding/json" ) type Record struct { - ID int `yaml:"id,omitempty" json:"id,omitempty"` - Origin string `yaml:"origin,omitempty" json:"origin,omitempty"` - Data []byte `yaml:"data,omitempty" json:"data,omitempty"` - Context context.Context `yaml:"-" json:"-"` + ID int `yaml:"id,omitempty" json:"id,omitempty"` + Origin string `yaml:"origin,omitempty" json:"origin,omitempty"` + Data []byte `yaml:"data,omitempty" json:"data,omitempty"` + Meta map[string]string `yaml:"meta,omitempty" json:"meta,omitempty"` } func (m Record) MarshalJSON() ([]byte, error) { diff --git a/internal/pkg/pipeline/task/archive/archive.go b/internal/pkg/pipeline/task/archive/archive.go index f093e58..ba87ea1 100644 --- a/internal/pkg/pipeline/task/archive/archive.go +++ b/internal/pkg/pipeline/task/archive/archive.go @@ -1,6 +1,7 @@ package archive import ( + "context" "fmt" "github.com/patterninc/caterpillar/internal/pkg/pipeline/record" @@ -83,7 +84,7 @@ func (c *core) UnmarshalYAML(unmarshal func(interface{}) error) error { return nil } -func (c *core) Run(input <-chan *record.Record, output chan<- *record.Record) (err error) { +func (c *core) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) (err error) { if input == nil { return task.ErrNilInput diff --git a/internal/pkg/pipeline/task/archive/tar.go b/internal/pkg/pipeline/task/archive/tar.go index 2e9a194..55b721c 100644 --- a/internal/pkg/pipeline/task/archive/tar.go +++ b/internal/pkg/pipeline/task/archive/tar.go @@ -48,8 +48,8 @@ func (t *tarArchive) Read() { if _, err := io.ReadFull(r, buf); err != nil && err != io.EOF { log.Fatal(err) } - rc.SetContextValue(string(task.CtxKeyArchiveFileNameWrite), filepath.Base(header.Name)) - t.SendData(rc.Context, buf, t.OutputChan) + rc.SetMetaValue(task.MetaKeyArchiveFileNameWrite, filepath.Base(header.Name)) + t.SendData(rc.Meta, buf, t.OutputChan) } } @@ -73,7 +73,7 @@ func (t *tarArchive) Write() { continue } - filePath, found := rec.GetContextValue(string(task.CtxKeyFileNameWrite)) + filePath, found := rec.GetMetaValue(task.MetaKeyFileNameWrite) if !found { log.Fatal("filepath not set in context") } @@ -97,12 +97,12 @@ func (t *tarArchive) Write() { log.Fatal(err) } - rc.Context = rec.Context + rc.Meta = rec.Meta } if err := tw.Close(); err != nil { log.Fatal(err) } - t.SendData(rc.Context, buf.Bytes(), t.OutputChan) + t.SendData(rc.Meta, buf.Bytes(), t.OutputChan) } diff --git a/internal/pkg/pipeline/task/archive/zip.go b/internal/pkg/pipeline/task/archive/zip.go index 53b10bd..8c1b3a9 100644 --- a/internal/pkg/pipeline/task/archive/zip.go +++ b/internal/pkg/pipeline/task/archive/zip.go @@ -39,7 +39,7 @@ func (z *zipArchive) Read() { // check the file type is regular file if f.FileInfo().Mode().IsRegular() { - rc.SetContextValue(string(task.CtxKeyArchiveFileNameWrite), filepath.Base(f.Name)) + rc.SetMetaValue(task.MetaKeyArchiveFileNameWrite, filepath.Base(f.Name)) fs, err := f.Open() if err != nil { @@ -55,7 +55,7 @@ func (z *zipArchive) Read() { fs.Close() - z.SendData(rc.Context, buf, z.OutputChan) + z.SendData(rc.Meta, buf, z.OutputChan) } } } @@ -73,7 +73,7 @@ func (z *zipArchive) Write() { break } - filePath, found := rec.GetContextValue(string(task.CtxKeyFileNameWrite)) + filePath, found := rec.GetMetaValue(task.MetaKeyFileNameWrite) if !found { log.Fatal("filepath not set in context") } @@ -93,7 +93,7 @@ func (z *zipArchive) Write() { log.Fatal(err) } - rc.Context = rec.Context + rc.Meta = rec.Meta } if err := zipWriter.Close(); err != nil { @@ -101,6 +101,6 @@ func (z *zipArchive) Write() { } // Send the complete ZIP archive - z.SendData(rc.Context, zipBuf.Bytes(), z.OutputChan) + z.SendData(rc.Meta, zipBuf.Bytes(), z.OutputChan) } diff --git a/internal/pkg/pipeline/task/aws/parameter_store/parameter_store.go b/internal/pkg/pipeline/task/aws/parameter_store/parameter_store.go index 365998b..9b79d8a 100644 --- a/internal/pkg/pipeline/task/aws/parameter_store/parameter_store.go +++ b/internal/pkg/pipeline/task/aws/parameter_store/parameter_store.go @@ -53,7 +53,7 @@ func (p *parameterStore) Init() error { return nil } -func (p *parameterStore) Run(input <-chan *record.Record, output chan<- *record.Record) (err error) { +func (p *parameterStore) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) (err error) { for { r, ok := p.GetRecord(input) diff --git a/internal/pkg/pipeline/task/compress/compress.go b/internal/pkg/pipeline/task/compress/compress.go index 3577af5..fc197de 100644 --- a/internal/pkg/pipeline/task/compress/compress.go +++ b/internal/pkg/pipeline/task/compress/compress.go @@ -2,6 +2,7 @@ package compress import ( "bytes" + "context" "fmt" "io" @@ -47,7 +48,7 @@ func (c *core) UnmarshalYAML(unmarshal func(interface{}) error) error { return nil } -func (c *core) Run(input <-chan *record.Record, output chan<- *record.Record) (err error) { +func (c *core) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) (err error) { if input == nil { return task.ErrNilInput @@ -83,7 +84,7 @@ func (c *core) Run(input <-chan *record.Record, output chan<- *record.Record) (e } if output != nil { - c.SendData(r.Context, transformedData, output) + c.SendData(r.Meta, transformedData, output) } } diff --git a/internal/pkg/pipeline/task/converter/converter.go b/internal/pkg/pipeline/task/converter/converter.go index 8a774bc..0a80bdc 100644 --- a/internal/pkg/pipeline/task/converter/converter.go +++ b/internal/pkg/pipeline/task/converter/converter.go @@ -1,6 +1,7 @@ package converter import ( + "context" "fmt" "github.com/patterninc/caterpillar/internal/pkg/pipeline/record" @@ -66,7 +67,7 @@ func (c *core) UnmarshalYAML(unmarshal func(interface{}) error) error { } -func (c *core) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (c *core) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { for { r, ok := c.GetRecord(input) @@ -83,10 +84,10 @@ func (c *core) Run(input <-chan *record.Record, output chan<- *record.Record) er if out.Data != nil { // Add metadata to context for k, v := range out.Metadata { - r.SetContextValue(k, v) + r.SetMetaValue(k, v) } - c.SendData(r.Context, out.Data, output) + c.SendData(r.Meta, out.Data, output) } } } diff --git a/internal/pkg/pipeline/task/delay/delay.go b/internal/pkg/pipeline/task/delay/delay.go index 0bf7927..5beb9b6 100644 --- a/internal/pkg/pipeline/task/delay/delay.go +++ b/internal/pkg/pipeline/task/delay/delay.go @@ -1,6 +1,7 @@ package delay import ( + "context" "time" "github.com/patterninc/caterpillar/internal/pkg/duration" @@ -23,7 +24,7 @@ func New() (task.Task, error) { }, nil } -func (d *delay) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (d *delay) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { for { r, ok := d.GetRecord(input) diff --git a/internal/pkg/pipeline/task/echo/echo.go b/internal/pkg/pipeline/task/echo/echo.go index 2370ffe..2f520da 100644 --- a/internal/pkg/pipeline/task/echo/echo.go +++ b/internal/pkg/pipeline/task/echo/echo.go @@ -1,6 +1,7 @@ package echo import ( + "context" "fmt" "time" @@ -21,7 +22,7 @@ func New() (task.Task, error) { return &echo{}, nil } -func (e *echo) Run(input <-chan *record.Record, output chan<- *record.Record) (err error) { +func (e *echo) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) (err error) { for { r, ok := e.GetRecord(input) diff --git a/internal/pkg/pipeline/task/file/file.go b/internal/pkg/pipeline/task/file/file.go index 722e315..dc5b64f 100644 --- a/internal/pkg/pipeline/task/file/file.go +++ b/internal/pkg/pipeline/task/file/file.go @@ -57,7 +57,7 @@ func New() (task.Task, error) { }, nil } -func (f *file) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (f *file) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { // let's check if we read file or we write file... if input != nil && output != nil { @@ -134,12 +134,12 @@ func (f *file) readFile(output chan<- *record.Record) error { return err } - // Create a default record with context - rc := &record.Record{Context: ctx} - rc.SetContextValue(string(task.CtxKeyFileNameWrite), filepath.Base(path)) + // create a default record with file name in context + rc := &record.Record{} + rc.SetMetaValue(task.MetaKeyFileNameWrite, filepath.Base(path)) // let's write content to output channel - f.SendData(rc.Context, content, output) + f.SendData(rc.Meta, content, output) } @@ -174,7 +174,7 @@ func (f *file) writeFile(input <-chan *record.Record) error { var fs file fs = *f - filePath, found := rc.GetContextValue(string(task.CtxKeyArchiveFileNameWrite)) + filePath, found := rc.GetMetaValue(task.MetaKeyArchiveFileNameWrite) if found { if filePath == "" { log.Fatal("required file path") diff --git a/internal/pkg/pipeline/task/flatten/flatten.go b/internal/pkg/pipeline/task/flatten/flatten.go index c36ab46..8817601 100644 --- a/internal/pkg/pipeline/task/flatten/flatten.go +++ b/internal/pkg/pipeline/task/flatten/flatten.go @@ -1,6 +1,7 @@ package flatten import ( + "context" "encoding/json" "github.com/patterninc/caterpillar/internal/pkg/pipeline/record" @@ -16,7 +17,7 @@ func New() (task.Task, error) { return &flatten{}, nil } -func (f *flatten) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (f *flatten) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { for { r, ok := f.GetRecord(input) @@ -41,7 +42,7 @@ func (f *flatten) Run(input <-chan *record.Record, output chan<- *record.Record) return err } - f.SendData(r.Context, flatJson, output) + f.SendData(r.Meta, flatJson, output) } return nil diff --git a/internal/pkg/pipeline/task/heimdall/heimdall.go b/internal/pkg/pipeline/task/heimdall/heimdall.go index 1f54852..9e2b846 100644 --- a/internal/pkg/pipeline/task/heimdall/heimdall.go +++ b/internal/pkg/pipeline/task/heimdall/heimdall.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "encoding/json" "fmt" "net/http" @@ -53,7 +54,7 @@ func New() (task.Task, error) { return h, nil } -func (h *heimdall) Run(input <-chan *record.Record, output chan<- *record.Record) (err error) { +func (h *heimdall) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) (err error) { // If input is provided, override the job request context if input != nil { diff --git a/internal/pkg/pipeline/task/heimdall/result.go b/internal/pkg/pipeline/task/heimdall/result.go index 52484bd..bb5bb23 100644 --- a/internal/pkg/pipeline/task/heimdall/result.go +++ b/internal/pkg/pipeline/task/heimdall/result.go @@ -54,7 +54,7 @@ func (h *heimdall) sendToOutput(result *result, output chan<- *record.Record) er } for _, item := range items { - h.SendData(ctx, item, output) + h.SendData(nil, item, output) } return nil diff --git a/internal/pkg/pipeline/task/http/http.go b/internal/pkg/pipeline/task/http/http.go index a923c72..9bdbfc4 100644 --- a/internal/pkg/pipeline/task/http/http.go +++ b/internal/pkg/pipeline/task/http/http.go @@ -124,7 +124,7 @@ func (h *httpCore) newFromInput(data []byte) (*httpCore, error) { } -func (h *httpCore) Run(input <-chan *record.Record, output chan<- *record.Record) (err error) { +func (h *httpCore) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) (err error) { // if we have input, treat each value as a URL and try to get data from it... if input != nil { @@ -160,9 +160,9 @@ func (h *httpCore) processItem(rc *record.Record, output chan<- *record.Record) return nil } - // create a default record context if none provided + // create a default record if none provided if rc == nil { - rc = &record.Record{Context: context.Background()} + rc = &record.Record{} } // TODO: perhaps expose the starting page number as a parameter for the task @@ -181,10 +181,10 @@ func (h *httpCore) processItem(rc *record.Record, output chan<- *record.Record) // Header names are preserved as-is and stored with the http-header- prefix for headerName, headerValues := range result.Headers { contextKey := fmt.Sprintf(headerContextPrefix, headerName) - rc.SetContextValue(contextKey, strings.Join(headerValues, "; ")) + rc.SetMetaValue(contextKey, strings.Join(headerValues, "; ")) } - h.SendData(rc.Context, []byte(result.Data), output) + h.SendData(rc.Meta, []byte(result.Data), output) } // if we do not have a way to define the next page, we bail... diff --git a/internal/pkg/pipeline/task/http/server/handler.go b/internal/pkg/pipeline/task/http/server/handler.go index f3f5cfb..4eec4ad 100644 --- a/internal/pkg/pipeline/task/http/server/handler.go +++ b/internal/pkg/pipeline/task/http/server/handler.go @@ -66,7 +66,7 @@ func (s *server) createPathHandler(pathConfig pathConfig, output chan<- *record. } // create a new record with the request information - s.SendData(ctx, jsonData, output) + s.SendData(nil, jsonData, output) w.Header().Set(contentTypeKey, contentTypeJson) w.WriteHeader(http.StatusOK) diff --git a/internal/pkg/pipeline/task/http/server/server.go b/internal/pkg/pipeline/task/http/server/server.go index 8e15442..38a6d60 100644 --- a/internal/pkg/pipeline/task/http/server/server.go +++ b/internal/pkg/pipeline/task/http/server/server.go @@ -58,7 +58,7 @@ func (s *server) GetTaskConcurrency() int { return 1 } -func (s *server) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (s *server) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { // input channel must be nil if input != nil { diff --git a/internal/pkg/pipeline/task/join/join.go b/internal/pkg/pipeline/task/join/join.go index 89e3700..994f2e2 100644 --- a/internal/pkg/pipeline/task/join/join.go +++ b/internal/pkg/pipeline/task/join/join.go @@ -40,7 +40,7 @@ func New() (task.Task, error) { }, nil } -func (j *join) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (j *join) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { if input == nil || output == nil { return ErrIncorrectInputOutput @@ -103,6 +103,6 @@ func (j *join) sendJoinedRecords(output chan<- *record.Record) { joinedData.Write(r.Data) } - j.SendData(ctx, []byte(joinedData.String()), output) + j.SendData(nil, []byte(joinedData.String()), output) } diff --git a/internal/pkg/pipeline/task/jq/jq.go b/internal/pkg/pipeline/task/jq/jq.go index bb3c9d1..d06f2c9 100644 --- a/internal/pkg/pipeline/task/jq/jq.go +++ b/internal/pkg/pipeline/task/jq/jq.go @@ -1,6 +1,7 @@ package jq import ( + "context" "encoding/json" "fmt" @@ -20,7 +21,7 @@ func New() (task.Task, error) { return &jq{}, nil } -func (j *jq) Run(input <-chan *record.Record, output chan<- *record.Record) (err error) { +func (j *jq) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) (err error) { if input != nil && output != nil { for { @@ -46,24 +47,24 @@ func (j *jq) Run(input <-chan *record.Record, output chan<- *record.Record) (err if splitItems, ok := items.([]any); j.Explode && ok { for _, splitItem := range splitItems { if j.AsRaw { - j.SendData(r.Context, fmt.Appendf(nil, "%v", splitItem), output) + j.SendData(r.Meta, fmt.Appendf(nil, "%v", splitItem), output) } else { jsonItem, err := json.Marshal(splitItem) if err != nil { return err } - j.SendData(r.Context, jsonItem, output) + j.SendData(r.Meta, jsonItem, output) } } } else { if j.AsRaw { - j.SendData(r.Context, fmt.Appendf(nil, "%v", items), output) + j.SendData(r.Meta, fmt.Appendf(nil, "%v", items), output) } else { jsonItem, err := json.Marshal(items) if err != nil { return err } - j.SendData(r.Context, jsonItem, output) + j.SendData(r.Meta, jsonItem, output) } } } diff --git a/internal/pkg/pipeline/task/kafka/kafka.go b/internal/pkg/pipeline/task/kafka/kafka.go index 2686398..2940349 100644 --- a/internal/pkg/pipeline/task/kafka/kafka.go +++ b/internal/pkg/pipeline/task/kafka/kafka.go @@ -105,7 +105,7 @@ func (k *kafka) Init() error { return nil } -func (k *kafka) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (k *kafka) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { if input != nil && output != nil { return task.ErrPresentInputOutput } @@ -189,7 +189,7 @@ func (k *kafka) read(output chan<- *record.Record) error { k.emptyReadRetries = 0 // process the message - k.SendData(k.ctx, m.Value, output) + k.SendData(nil, m.Value, output) if k.GroupID == "" { // if not using consumer group, no need to commit messages diff --git a/internal/pkg/pipeline/task/replace/replace.go b/internal/pkg/pipeline/task/replace/replace.go index e7d724a..bd2dabb 100644 --- a/internal/pkg/pipeline/task/replace/replace.go +++ b/internal/pkg/pipeline/task/replace/replace.go @@ -1,6 +1,7 @@ package replace import ( + "context" "regexp" "github.com/patterninc/caterpillar/internal/pkg/pipeline/record" @@ -17,7 +18,7 @@ func New() (task.Task, error) { return &replace{}, nil } -func (r *replace) Run(input <-chan *record.Record, output chan<- *record.Record) (err error) { +func (r *replace) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) (err error) { rx, err := regexp.Compile(r.Expression) if err != nil { @@ -30,7 +31,7 @@ func (r *replace) Run(input <-chan *record.Record, output chan<- *record.Record) if !ok { break } - r.SendData(record.Context, []byte(rx.ReplaceAllString(string(record.Data), r.Replacement)), output) + r.SendData(record.Meta, []byte(rx.ReplaceAllString(string(record.Data), r.Replacement)), output) } } diff --git a/internal/pkg/pipeline/task/sample/sample.go b/internal/pkg/pipeline/task/sample/sample.go index 80d6fd2..4b37e17 100644 --- a/internal/pkg/pipeline/task/sample/sample.go +++ b/internal/pkg/pipeline/task/sample/sample.go @@ -1,6 +1,7 @@ package sample import ( + "context" "fmt" "github.com/patterninc/caterpillar/internal/pkg/pipeline/record" @@ -48,7 +49,7 @@ func New() (task.Task, error) { }, nil } -func (s *sample) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (s *sample) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { // if this task is first or last in the pipeline, let's bail... if input == nil || output == nil { diff --git a/internal/pkg/pipeline/task/sns/sns.go b/internal/pkg/pipeline/task/sns/sns.go index cec5c03..675407d 100644 --- a/internal/pkg/pipeline/task/sns/sns.go +++ b/internal/pkg/pipeline/task/sns/sns.go @@ -64,7 +64,7 @@ func (s *snsTask) Init() error { return nil } -func (s *snsTask) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (s *snsTask) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { if input == nil { return task.ErrNilInput } @@ -101,7 +101,7 @@ func (s *snsTask) Run(input <-chan *record.Record, output chan<- *record.Record) } } - _, err := s.client.Publish(r.Context, publishInput) + _, err := s.client.Publish(ctx, publishInput) if err != nil { return fmt.Errorf("failed to publish to SNS topic %s: %w", s.TopicArn, err) } diff --git a/internal/pkg/pipeline/task/split/split.go b/internal/pkg/pipeline/task/split/split.go index 51029a7..d29dd87 100644 --- a/internal/pkg/pipeline/task/split/split.go +++ b/internal/pkg/pipeline/task/split/split.go @@ -1,6 +1,7 @@ package split import ( + "context" "strings" "github.com/patterninc/caterpillar/internal/pkg/pipeline/record" @@ -21,7 +22,7 @@ func New() (task.Task, error) { Delimiter: defaultDelimiter, }, nil } -func (s *split) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (s *split) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { for { r, ok := s.GetRecord(input) @@ -30,7 +31,7 @@ func (s *split) Run(input <-chan *record.Record, output chan<- *record.Record) e } lines := strings.Split(strings.TrimSuffix(string(r.Data), s.Delimiter), s.Delimiter) for _, line := range lines { - s.SendData(r.Context, []byte(line), output) + s.SendData(r.Meta, []byte(line), output) } } diff --git a/internal/pkg/pipeline/task/sqs/sqs.go b/internal/pkg/pipeline/task/sqs/sqs.go index 1dcf05b..b754ba3 100644 --- a/internal/pkg/pipeline/task/sqs/sqs.go +++ b/internal/pkg/pipeline/task/sqs/sqs.go @@ -86,7 +86,7 @@ func (s *sqs) extractRegionFromQueueURL() string { return defaultRegion } -func (s *sqs) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (s *sqs) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { // Client is already initialized in RunPreHook - just use it if input != nil { @@ -159,7 +159,7 @@ func (s *sqs) getMessages(ctx context.Context, output chan<- *record.Record, rec // create new record and send it downstream if output != nil { - s.SendData(ctx, []byte(*m.Body), output) + s.SendData(nil, []byte(*m.Body), output) } // send receipt to receipts channel for deletion diff --git a/internal/pkg/pipeline/task/task.go b/internal/pkg/pipeline/task/task.go index 749c36b..781ef07 100644 --- a/internal/pkg/pipeline/task/task.go +++ b/internal/pkg/pipeline/task/task.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "maps" "sync" "github.com/patterninc/caterpillar/internal/pkg/jq" @@ -12,6 +13,9 @@ import ( const ( ErrUnsupportedFieldValue = `invalid value for field %s: %s` + + MetaKeyFileNameWrite = "CATERPILLAR_FILE_NAME_WRITE" + MetaKeyArchiveFileNameWrite = "CATERPILLAR_ARCHIVE_FILE_NAME_WRITE" ) var ( @@ -21,15 +25,8 @@ var ( ErrPresentInputOutput = fmt.Errorf(`either input or output must be set, not both`) ) -type contextKeyFile string - -const ( - CtxKeyFileNameWrite contextKeyFile = "CATERPILLAR_FILE_NAME_WRITE" - CtxKeyArchiveFileNameWrite contextKeyFile = "CATERPILLAR_ARCHIVE_FILE_NAME_WRITE" -) - type Task interface { - Run(<-chan *record.Record, chan<- *record.Record) error + Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error GetName() string GetFailOnError() bool GetTaskConcurrency() int @@ -79,7 +76,7 @@ func (b *Base) GetRecord(input <-chan *record.Record) (*record.Record, bool) { } -func (b *Base) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (b *Base) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { for r := range input { b.SendRecord(r, output) @@ -89,7 +86,7 @@ func (b *Base) Run(input <-chan *record.Record, output chan<- *record.Record) er } -func (b *Base) SendData(ctx context.Context, data []byte, output chan<- *record.Record) /* we should return error here */ { +func (b *Base) SendData(meta map[string]string, data []byte, output chan<- *record.Record) /* we should return error here */ { b.Lock() defer b.Unlock() @@ -97,10 +94,15 @@ func (b *Base) SendData(ctx context.Context, data []byte, output chan<- *record. b.recordIndex++ record := &record.Record{ - ID: b.recordIndex, - Origin: b.Name, - Data: data, - Context: ctx, + ID: b.recordIndex, + Origin: b.Name, + Data: data, + } + + // Copy meta map if provided + if meta != nil { + record.Meta = make(map[string]string, len(meta)) + maps.Copy(record.Meta, meta) } b.SendRecord(record, output) @@ -139,7 +141,7 @@ func (b *Base) SendRecord(r *record.Record, output chan<- *record.Record) /* we fmt.Println(`ERROR (result):`, err) return } - r.SetContextValue(name, string(contextValueJson)) + r.SetMetaValue(name, string(contextValueJson)) } } diff --git a/internal/pkg/pipeline/task/xpath/xpath.go b/internal/pkg/pipeline/task/xpath/xpath.go index 29b6f2f..5fd990c 100644 --- a/internal/pkg/pipeline/task/xpath/xpath.go +++ b/internal/pkg/pipeline/task/xpath/xpath.go @@ -2,6 +2,7 @@ package xpath import ( "bytes" + "context" "encoding/json" "fmt" "strings" @@ -27,7 +28,7 @@ func New() (task.Task, error) { return &xpath{IgnoreMissing: true}, nil } -func (x *xpath) Run(input <-chan *record.Record, output chan<- *record.Record) error { +func (x *xpath) Run(ctx context.Context, input <-chan *record.Record, output chan<- *record.Record) error { for { r, ok := x.GetRecord(input) @@ -60,8 +61,8 @@ func (x *xpath) Run(input <-chan *record.Record, output chan<- *record.Record) e if len(data) != 0 { index := fmt.Sprintf("%d", i+1) - r.SetContextValue(nodeIndexKey, index) - x.SendData(r.Context, data, output) + r.SetMetaValue(nodeIndexKey, index) + x.SendData(r.Meta, data, output) } } }