From 8ffc7d02a59848fb1a2b1c16622ccd3c042a3977 Mon Sep 17 00:00:00 2001 From: Joel Wetzell Date: Sat, 27 Dec 2025 22:59:30 -0600 Subject: [PATCH] stuff values into context --- internal/module/http-client.go | 10 ++++++++-- internal/module/http-server.go | 10 ++++++++-- internal/module/midi-input.go | 10 ++++++++-- internal/module/midi-ouptut.go | 10 ++++++++-- internal/module/module.go | 5 ++--- internal/module/mqtt-client.go | 10 ++++++++-- internal/module/nats-client.go | 10 ++++++++-- internal/module/psn-client.go | 9 +++++++-- internal/module/serial-client.go | 10 ++++++++-- internal/module/sip-call-server.go | 10 ++++++++-- internal/module/sip-dtmf-server.go | 9 +++++++-- internal/module/tcp-client.go | 10 ++++++++-- internal/module/tcp-server.go | 10 ++++++++-- internal/module/time-interval.go | 10 ++++++++-- internal/module/time-timer.go | 10 ++++++++-- internal/module/udp-client.go | 10 ++++++++-- internal/module/udp-multicast.go | 10 ++++++++-- internal/module/udp-server.go | 9 +++++++-- internal/route/route.go | 27 ++++++++++++++++++++------- internal/route/route_test.go | 9 +++++---- router.go | 12 +++++++----- 21 files changed, 167 insertions(+), 53 deletions(-) diff --git a/internal/module/http-client.go b/internal/module/http-client.go index c0522b9..5e1e9e8 100644 --- a/internal/module/http-client.go +++ b/internal/module/http-client.go @@ -22,7 +22,13 @@ type HTTPClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "http.client", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("http.client unable to get router from context") + } return &HTTPClient{config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, @@ -48,7 +54,7 @@ func (hc *HTTPClient) Run() error { return nil } -func (hc *HTTPClient) Output(payload any) error { +func (hc *HTTPClient) Output(ctx context.Context, payload any) error { payloadRequest, ok := payload.(*http.Request) diff --git a/internal/module/http-server.go b/internal/module/http-server.go index ebeae93..e6d2914 100644 --- a/internal/module/http-server.go +++ b/internal/module/http-server.go @@ -28,7 +28,7 @@ type ResponseData struct { func init() { RegisterModule(ModuleRegistration{ Type: "http.server", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] if !ok { @@ -41,6 +41,12 @@ func init() { return nil, errors.New("http.server port must be uint16") } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("http.server unable to get router from context") + } + return &HTTPServer{Port: uint16(portNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -105,6 +111,6 @@ func (hs *HTTPServer) Run() error { return nil } -func (hs *HTTPServer) Output(payload any) error { +func (hs *HTTPServer) Output(ctx context.Context, payload any) error { return errors.New("http.server output is not implemented") } diff --git a/internal/module/midi-input.go b/internal/module/midi-input.go index 93fb2bf..82c07b6 100644 --- a/internal/module/midi-input.go +++ b/internal/module/midi-input.go @@ -26,7 +26,7 @@ type MIDIInput struct { func init() { RegisterModule(ModuleRegistration{ Type: "midi.input", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] @@ -40,6 +40,12 @@ func init() { return nil, errors.New("midi.input port must be a string") } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("midi.input unable to get router from context") + } + return &MIDIInput{config: config, Port: portString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -78,6 +84,6 @@ func (mi *MIDIInput) Run() error { return nil } -func (mi *MIDIInput) Output(payload any) error { +func (mi *MIDIInput) Output(ctx context.Context, payload any) error { return errors.New("midi.input output is not implemented") } diff --git a/internal/module/midi-ouptut.go b/internal/module/midi-ouptut.go index 128b136..d53ab37 100644 --- a/internal/module/midi-ouptut.go +++ b/internal/module/midi-ouptut.go @@ -26,7 +26,7 @@ type MIDIOutput struct { func init() { RegisterModule(ModuleRegistration{ Type: "midi.output", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] @@ -41,6 +41,12 @@ func init() { return nil, errors.New("midi.output port must be a string") } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("midi.output unable to get router from context") + } + return &MIDIOutput{config: config, Port: portString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -75,7 +81,7 @@ func (mo *MIDIOutput) Run() error { return nil } -func (mo *MIDIOutput) Output(payload any) error { +func (mo *MIDIOutput) Output(ctx context.Context, payload any) error { if mo.SendFunc == nil { return errors.New("midi.output output is not setup") } diff --git a/internal/module/module.go b/internal/module/module.go index 5da1d00..8633601 100644 --- a/internal/module/module.go +++ b/internal/module/module.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/jwetzell/showbridge-go/internal/config" - "github.com/jwetzell/showbridge-go/internal/route" ) type ModuleError struct { @@ -20,12 +19,12 @@ type Module interface { Id() string Type() string Run() error - Output(any) error + Output(context.Context, any) error } type ModuleRegistration struct { Type string `json:"type"` - New func(context.Context, config.ModuleConfig, route.RouteIO) (Module, error) + New func(context.Context, config.ModuleConfig) (Module, error) } func RegisterModule(mod ModuleRegistration) { diff --git a/internal/module/mqtt-client.go b/internal/module/mqtt-client.go index 268d327..d169ad8 100644 --- a/internal/module/mqtt-client.go +++ b/internal/module/mqtt-client.go @@ -24,7 +24,7 @@ type MQTTClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "mqtt.client", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params broker, ok := params["broker"] @@ -62,6 +62,12 @@ func init() { return nil, errors.New("mqtt.client clientId must be string") } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("mqtt.client unable to get router from context") + } + return &MQTTClient{config: config, Broker: brokerString, Topic: topicString, ClientID: clientIdString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -104,7 +110,7 @@ func (mc *MQTTClient) Run() error { return nil } -func (mc *MQTTClient) Output(payload any) error { +func (mc *MQTTClient) Output(ctx context.Context, payload any) error { payloadMessage, ok := payload.(mqtt.Message) if !ok { diff --git a/internal/module/nats-client.go b/internal/module/nats-client.go index 76b4e20..ad13e47 100644 --- a/internal/module/nats-client.go +++ b/internal/module/nats-client.go @@ -24,7 +24,7 @@ type NATSClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "nats.client", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params url, ok := params["url"] @@ -50,6 +50,12 @@ func init() { return nil, errors.New("nats.client subject must be string") } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("nats.client unable to get router from context") + } + return &NATSClient{config: config, URL: urlString, Subject: subjectString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -92,7 +98,7 @@ func (nc *NATSClient) Run() error { return nil } -func (nc *NATSClient) Output(payload any) error { +func (nc *NATSClient) Output(ctx context.Context, payload any) error { payloadMessage, ok := payload.(processor.NATSMessage) diff --git a/internal/module/psn-client.go b/internal/module/psn-client.go index cb6b7ad..dec52ff 100644 --- a/internal/module/psn-client.go +++ b/internal/module/psn-client.go @@ -2,6 +2,7 @@ package module import ( "context" + "errors" "fmt" "log/slog" "net" @@ -24,8 +25,12 @@ type PSNClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "psn.client", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + if !ok { + return nil, errors.New("psn.client unable to get router from context") + } return &PSNClient{config: config, decoder: psn.NewDecoder(), ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -92,6 +97,6 @@ func (pc *PSNClient) Run() error { } } -func (pc *PSNClient) Output(payload any) error { +func (pc *PSNClient) Output(ctx context.Context, payload any) error { return fmt.Errorf("psn.client output is not implemented") } diff --git a/internal/module/serial-client.go b/internal/module/serial-client.go index 5084b4c..5ba972a 100644 --- a/internal/module/serial-client.go +++ b/internal/module/serial-client.go @@ -29,7 +29,7 @@ type SerialClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "serial.client", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] @@ -76,6 +76,12 @@ func init() { BaudRate: int(baudRateNum), } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("serial.client unable to get router from context") + } + return &SerialClient{config: config, Port: portString, Framer: framer, Mode: &mode, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -162,7 +168,7 @@ func (sc *SerialClient) Run() error { } } -func (sc *SerialClient) Output(payload any) error { +func (sc *SerialClient) Output(ctx context.Context, payload any) error { payloadBytes, ok := payload.([]byte) diff --git a/internal/module/sip-call-server.go b/internal/module/sip-call-server.go index f6fd385..53b7750 100644 --- a/internal/module/sip-call-server.go +++ b/internal/module/sip-call-server.go @@ -35,7 +35,7 @@ type SIPCallMessage struct { func init() { RegisterModule(ModuleRegistration{ Type: "sip.call.server", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params portNum := 5060 @@ -87,6 +87,12 @@ func init() { } userAgentString = specificTransportString } + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("sip.call.server unable to get router from context") + } return &SIPCallServer{config: config, ctx: ctx, router: router, IP: ipString, Port: int(portNum), Transport: transportString, UserAgent: userAgentString, logger: CreateLogger(config)}, nil }, }) @@ -143,7 +149,7 @@ func (scs *SIPCallServer) HandleCall(inDialog *diago.DialogServerSession) { <-inDialog.Context().Done() } -func (scs *SIPCallServer) Output(payload any) error { +func (scs *SIPCallServer) Output(ctx context.Context, payload any) error { payloadMsg, ok := payload.(string) if !ok { diff --git a/internal/module/sip-dtmf-server.go b/internal/module/sip-dtmf-server.go index 76fe382..c8020f3 100644 --- a/internal/module/sip-dtmf-server.go +++ b/internal/module/sip-dtmf-server.go @@ -35,7 +35,7 @@ type SIPDTMFMessage struct { func init() { RegisterModule(ModuleRegistration{ Type: "sip.dtmf.server", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params portNum := 5060 @@ -91,6 +91,11 @@ func init() { if !strings.ContainsRune("0123456789*#ABCD", rune(separatorString[0])) { return nil, errors.New("sip.dtmf.server separator must be a valid DTMF character") } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("sip.dtmf.server unable to get router from context") + } return &SIPDTMFServer{config: config, ctx: ctx, router: router, IP: ipString, Port: int(portNum), Transport: transportString, Separator: separatorString, logger: CreateLogger(config)}, nil }, }) @@ -159,6 +164,6 @@ func (sds *SIPDTMFServer) HandleCall(inDialog *diago.DialogServerSession) error }, 5*time.Second) } -func (sds *SIPDTMFServer) Output(payload any) error { +func (sds *SIPDTMFServer) Output(ctx context.Context, payload any) error { return errors.New("sip.dtmf.server output is not implemented") } diff --git a/internal/module/tcp-client.go b/internal/module/tcp-client.go index d45e443..e0942eb 100644 --- a/internal/module/tcp-client.go +++ b/internal/module/tcp-client.go @@ -26,7 +26,7 @@ type TCPClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.tcp.client", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params host, ok := params["host"] @@ -75,6 +75,12 @@ func init() { return nil, fmt.Errorf("net.tcp.client unknown framing method: %s", framingMethod) } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("net.tcp.client unable to get router from context") + } + return &TCPClient{framer: framer, Addr: addr, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -155,7 +161,7 @@ func (tc *TCPClient) SetupConn() error { return err } -func (tc *TCPClient) Output(payload any) error { +func (tc *TCPClient) Output(ctx context.Context, payload any) error { // NOTE(jwetzell): not sure how this would occur but if tc.conn == nil { err := tc.SetupConn() diff --git a/internal/module/tcp-server.go b/internal/module/tcp-server.go index b22934e..db791d6 100644 --- a/internal/module/tcp-server.go +++ b/internal/module/tcp-server.go @@ -32,7 +32,7 @@ type TCPServer struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.tcp.server", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] if !ok { @@ -82,6 +82,12 @@ func init() { return nil, err } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("net.tcp.server unable to get router from context") + } + return &TCPServer{Framer: framer, Addr: addr, config: config, quit: make(chan interface{}), ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -197,7 +203,7 @@ AcceptLoop: return nil } -func (ts *TCPServer) Output(payload any) error { +func (ts *TCPServer) Output(ctx context.Context, payload any) error { payloadBytes, ok := payload.([]byte) if !ok { diff --git a/internal/module/time-interval.go b/internal/module/time-interval.go index f46621b..7b982b7 100644 --- a/internal/module/time-interval.go +++ b/internal/module/time-interval.go @@ -22,7 +22,7 @@ type TimeInterval struct { func init() { RegisterModule(ModuleRegistration{ Type: "time.interval", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params duration, ok := params["duration"] @@ -36,6 +36,12 @@ func init() { return nil, errors.New("time.interval duration must be number") } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("time.interval unable to get router from context") + } + return &TimeInterval{Duration: uint32(durationNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -68,7 +74,7 @@ func (i *TimeInterval) Run() error { } -func (i *TimeInterval) Output(payload any) error { +func (i *TimeInterval) Output(ctx context.Context, payload any) error { i.ticker.Reset(time.Millisecond * time.Duration(i.Duration)) return nil } diff --git a/internal/module/time-timer.go b/internal/module/time-timer.go index 398ccb6..087898e 100644 --- a/internal/module/time-timer.go +++ b/internal/module/time-timer.go @@ -22,7 +22,7 @@ type TimeTimer struct { func init() { RegisterModule(ModuleRegistration{ Type: "time.timer", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params duration, ok := params["duration"] @@ -36,6 +36,12 @@ func init() { return nil, errors.New("time.timer duration must be a number") } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("net.tcp.client unable to get router from context") + } + return &TimeTimer{Duration: uint32(durationNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -66,7 +72,7 @@ func (t *TimeTimer) Run() error { } } -func (t *TimeTimer) Output(payload any) error { +func (t *TimeTimer) Output(ctx context.Context, payload any) error { t.timer.Reset(time.Millisecond * time.Duration(t.Duration)) return nil } diff --git a/internal/module/udp-client.go b/internal/module/udp-client.go index b821b88..341e2a8 100644 --- a/internal/module/udp-client.go +++ b/internal/module/udp-client.go @@ -24,7 +24,7 @@ type UDPClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.udp.client", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params host, ok := params["host"] @@ -54,6 +54,12 @@ func init() { return nil, err } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("net.udp.client unable to get router from context") + } + return &UDPClient{Addr: addr, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -88,7 +94,7 @@ func (uc *UDPClient) Run() error { return nil } -func (uc *UDPClient) Output(payload any) error { +func (uc *UDPClient) Output(ctx context.Context, payload any) error { payloadBytes, ok := payload.([]byte) if !ok { diff --git a/internal/module/udp-multicast.go b/internal/module/udp-multicast.go index b5131c5..c5bc834 100644 --- a/internal/module/udp-multicast.go +++ b/internal/module/udp-multicast.go @@ -24,7 +24,7 @@ type UDPMulticast struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.udp.multicast", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params ip, ok := params["ip"] @@ -53,6 +53,12 @@ func init() { if err != nil { return nil, err } + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("net.udp.multicast unable to get router from context") + } return &UDPMulticast{config: config, Addr: addr, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -108,7 +114,7 @@ func (um *UDPMulticast) Run() error { } } -func (um *UDPMulticast) Output(payload any) error { +func (um *UDPMulticast) Output(ctx context.Context, payload any) error { payloadBytes, ok := payload.([]byte) if !ok { diff --git a/internal/module/udp-server.go b/internal/module/udp-server.go index 6c32061..0597bd9 100644 --- a/internal/module/udp-server.go +++ b/internal/module/udp-server.go @@ -25,7 +25,7 @@ type UDPServer struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.udp.server", - New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { + New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] if !ok { @@ -68,6 +68,11 @@ func init() { bufferSizeNum = int(bufferSizeFloat) } + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return nil, errors.New("net.udp.server unable to get router from context") + } return &UDPServer{Addr: addr, BufferSize: bufferSizeNum, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil }, }) @@ -119,6 +124,6 @@ func (us *UDPServer) Run() error { } -func (us *UDPServer) Output(payload any) error { +func (us *UDPServer) Output(ctx context.Context, payload any) error { return errors.New("net.udp.server output is not implemented") } diff --git a/internal/route/route.go b/internal/route/route.go index aebe108..e92b0a2 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -2,12 +2,18 @@ package route import ( "context" + "errors" "fmt" "github.com/jwetzell/showbridge-go/internal/config" "github.com/jwetzell/showbridge-go/internal/processor" ) +type routeContextKey string + +var RouterContextKey routeContextKey = routeContextKey("router") +var SourceContextKey routeContextKey = routeContextKey("source") + type RouteError struct { Index int Config config.RouteConfig @@ -21,13 +27,13 @@ type RouteIOError struct { type RouteIO interface { HandleInput(sourceId string, payload any) []RouteIOError - HandleOutput(sourceId string, destinationId string, payload any) error + HandleOutput(ctx context.Context, destinationId string, payload any) error } type Route interface { Input() string Output() string - HandleInput(ctx context.Context, sourceId string, payload any, router RouteIO) error + HandleInput(ctx context.Context, payload any) error } type ProcessorRoute struct { @@ -65,17 +71,24 @@ func (r *ProcessorRoute) Output() string { return r.output } -func (r *ProcessorRoute) HandleInput(ctx context.Context, sourceId string, payload any, router RouteIO) error { - var err error +func (r *ProcessorRoute) HandleInput(ctx context.Context, payload any) error { + router, ok := ctx.Value(RouterContextKey).(RouteIO) + + if !ok { + return errors.New("unable to get router from context") + } + for _, processor := range r.processors { - payload, err = processor.Process(ctx, payload) + processedPayload, err := processor.Process(ctx, payload) if err != nil { return err } //NOTE(jwetzell) nil payload will result in the route being "terminated" - if payload == nil { + if processedPayload == nil { return nil } + payload = processedPayload } - return router.HandleOutput(sourceId, r.output, payload) + + return router.HandleOutput(ctx, r.output, payload) } diff --git a/internal/route/route_test.go b/internal/route/route_test.go index ec91439..a010fe8 100644 --- a/internal/route/route_test.go +++ b/internal/route/route_test.go @@ -1,6 +1,7 @@ package route_test import ( + "context" "testing" "github.com/jwetzell/showbridge-go/internal/config" @@ -51,7 +52,7 @@ func TestGoodRouteHandleInput(t *testing.T) { } inputData := "test input data" - err = testRoute.HandleInput(t.Context(), "input", inputData, &MockRouter{}) + err = testRoute.HandleInput(context.WithValue(t.Context(), route.RouterContextKey, &MockRouter{}), inputData) if err != nil { t.Fatalf("route HandleOutput returned error: %v", err) } @@ -72,7 +73,7 @@ func TestRouteHandleInputWithProcessorError(t *testing.T) { } inputData := "test input data" - err = testRoute.HandleInput(t.Context(), "input", inputData, &MockRouter{}) + err = testRoute.HandleInput(context.WithValue(t.Context(), route.RouterContextKey, &MockRouter{}), inputData) if err == nil { t.Fatalf("route HandleOutput did not return error for bad processor") } @@ -91,7 +92,7 @@ func TestRouteHandleNilPayload(t *testing.T) { return } - err = testRoute.HandleInput(t.Context(), "input", nil, &MockRouter{}) + err = testRoute.HandleInput(context.WithValue(t.Context(), route.RouterContextKey, &MockRouter{}), nil) if err != nil { t.Fatalf("route HandleOutput returned error for nil payload: %v", err) } @@ -111,7 +112,7 @@ func TestRouteHandleNilPayloadFromProcessor(t *testing.T) { t.Fatalf("route failed to create: %v", err) } - err = testRoute.HandleInput(t.Context(), "input", nil, &MockRouter{}) + err = testRoute.HandleInput(context.WithValue(t.Context(), route.RouterContextKey, &MockRouter{}), nil) if err != nil { t.Fatalf("route HandleOutput returned error for nil payload: %v", err) } diff --git a/router.go b/router.go index aae986c..889b979 100644 --- a/router.go +++ b/router.go @@ -24,14 +24,16 @@ type Router struct { func NewRouter(ctx context.Context, config config.Config) (*Router, []module.ModuleError, []route.RouteError) { routerContext, cancel := context.WithCancel(ctx) + router := Router{ - Context: routerContext, contextCancel: cancel, ModuleInstances: []module.Module{}, RouteInstances: []route.Route{}, logger: slog.Default().With("component", "router"), } + router.Context = context.WithValue(routerContext, route.RouterContextKey, &router) + router.logger.Debug("creating") var moduleErrors []module.ModuleError @@ -68,7 +70,7 @@ func NewRouter(ctx context.Context, config config.Config) (*Router, []module.Mod } if !moduleInstanceExists { - moduleInstance, err := moduleInfo.New(router.Context, moduleDecl, &router) + moduleInstance, err := moduleInfo.New(router.Context, moduleDecl) if err != nil { if moduleErrors == nil { moduleErrors = []module.ModuleError{} @@ -131,7 +133,7 @@ func (r *Router) HandleInput(sourceId string, payload any) []route.RouteIOError var routingErrors []route.RouteIOError for routeIndex, routeInstance := range r.RouteInstances { if routeInstance.Input() == sourceId { - err := routeInstance.HandleInput(r.Context, sourceId, payload, r) + err := routeInstance.HandleInput(context.WithValue(r.Context, route.SourceContextKey, sourceId), payload) if err != nil { if routingErrors == nil { routingErrors = []route.RouteIOError{} @@ -147,10 +149,10 @@ func (r *Router) HandleInput(sourceId string, payload any) []route.RouteIOError return routingErrors } -func (r *Router) HandleOutput(sourceId string, destinationId string, payload any) error { +func (r *Router) HandleOutput(ctx context.Context, destinationId string, payload any) error { for _, moduleInstance := range r.ModuleInstances { if moduleInstance.Id() == destinationId { - return moduleInstance.Output(payload) + return moduleInstance.Output(ctx, payload) } } return fmt.Errorf("router could not find module instance for destination %s", destinationId)