diff --git a/internal/common/payload.go b/internal/common/payload.go index fc338b3..8f13544 100644 --- a/internal/common/payload.go +++ b/internal/common/payload.go @@ -6,30 +6,35 @@ import ( type WrappedPayload struct { Payload any - Modules any + Modules map[string]Module Sender any Source string End bool } func GetWrappedPayload(ctx context.Context, payload any) WrappedPayload { - templateData := WrappedPayload{ + wrappedPayload := WrappedPayload{ Payload: payload, End: false, } modules := ctx.Value(ModulesContextKey) if modules != nil { - templateData.Modules = modules + moduleMap, ok := modules.(map[string]Module) + if ok { + wrappedPayload.Modules = moduleMap + } else { + wrappedPayload.Modules = make(map[string]Module) + } } sender := ctx.Value(SenderContextKey) if sender != nil { - templateData.Sender = sender + wrappedPayload.Sender = sender } source := ctx.Value(SourceContextKey) if source != nil { - templateData.Source = source.(string) + wrappedPayload.Source = source.(string) } - return templateData + return wrappedPayload } diff --git a/internal/processor/db-query.go b/internal/processor/db-query.go index 08a17cb..4f7f90d 100644 --- a/internal/processor/db-query.go +++ b/internal/processor/db-query.go @@ -18,19 +18,12 @@ type DbQuery struct { } func (dq *DbQuery) Process(ctx context.Context, wrappedPayload common.WrappedPayload) (common.WrappedPayload, error) { - ctxModules := ctx.Value(common.ModulesContextKey) - if ctxModules == nil { + if wrappedPayload.Modules == nil { wrappedPayload.End = true - return wrappedPayload, errors.New("db.query unable to get modules from context") + return wrappedPayload, errors.New("db.query wrapped payload has no modules") } - moduleMap, ok := ctxModules.(map[string]common.Module) - if !ok { - wrappedPayload.End = true - return wrappedPayload, errors.New("db.query modules from context has wrong type") - } - - module, ok := moduleMap[dq.ModuleId] + module, ok := wrappedPayload.Modules[dq.ModuleId] if !ok { wrappedPayload.End = true return wrappedPayload, fmt.Errorf("db.query unable to find module with id: %s", dq.ModuleId) diff --git a/internal/processor/kv-get.go b/internal/processor/kv-get.go index bff2f5d..3ac4d64 100644 --- a/internal/processor/kv-get.go +++ b/internal/processor/kv-get.go @@ -18,19 +18,12 @@ type KVGet struct { } func (kvg *KVGet) Process(ctx context.Context, wrappedPayload common.WrappedPayload) (common.WrappedPayload, error) { - ctxModules := ctx.Value(common.ModulesContextKey) - if ctxModules == nil { + if wrappedPayload.Modules == nil { wrappedPayload.End = true - return wrappedPayload, errors.New("kv.get unable to get modules from context") + return wrappedPayload, errors.New("kv.get wrapped payload has no modules") } - moduleMap, ok := ctxModules.(map[string]common.Module) - if !ok { - wrappedPayload.End = true - return wrappedPayload, errors.New("kv.get modules from context has wrong type") - } - - module, ok := moduleMap[kvg.ModuleId] + module, ok := wrappedPayload.Modules[kvg.ModuleId] if !ok { wrappedPayload.End = true return wrappedPayload, fmt.Errorf("kv.get unable to find module with id: %s", kvg.ModuleId) diff --git a/internal/processor/kv-set.go b/internal/processor/kv-set.go index 56ae79f..43051e3 100644 --- a/internal/processor/kv-set.go +++ b/internal/processor/kv-set.go @@ -20,19 +20,13 @@ type KVSet struct { } func (kvs *KVSet) Process(ctx context.Context, wrappedPayload common.WrappedPayload) (common.WrappedPayload, error) { - ctxModules := ctx.Value(common.ModulesContextKey) - if ctxModules == nil { + + if wrappedPayload.Modules == nil { wrappedPayload.End = true - return wrappedPayload, errors.New("kv.set unable to get modules from context") + return wrappedPayload, errors.New("kv.set wrapped payload has no modules") } - moduleMap, ok := ctxModules.(map[string]common.Module) - if !ok { - wrappedPayload.End = true - return wrappedPayload, errors.New("kv.set modules from context has wrong type") - } - - module, ok := moduleMap[kvs.ModuleId] + module, ok := wrappedPayload.Modules[kvs.ModuleId] if !ok { wrappedPayload.End = true return wrappedPayload, fmt.Errorf("kv.set unable to find module with id: %s", kvs.ModuleId)