From 0c777a88744a5ab36a5d729a8b4ec72650ecad08 Mon Sep 17 00:00:00 2001 From: Joel Wetzell Date: Fri, 16 Jan 2026 18:50:40 -0600 Subject: [PATCH] move router context to run methods --- cmd/showbridge/main.go | 4 ++-- internal/module/http-client.go | 19 ++++++++------- internal/module/http-server.go | 20 +++++++++------- internal/module/midi-input.go | 19 ++++++++------- internal/module/midi-output.go | 19 ++++++++------- internal/module/module.go | 4 ++-- internal/module/mqtt-client.go | 20 +++++++++------- internal/module/nats-client.go | 21 ++++++++++------- internal/module/psn-client.go | 17 +++++++------ internal/module/serial-client.go | 20 +++++++++------- internal/module/sip-call-server.go | 19 ++++++++------- internal/module/sip-dtmf-server.go | 19 ++++++++------- internal/module/tcp-client.go | 21 +++++++++-------- internal/module/tcp-server.go | 22 +++++++++-------- internal/module/time-interval.go | 22 +++++++++-------- internal/module/time-timer.go | 21 ++++++++++------- internal/module/udp-client.go | 21 +++++++++-------- internal/module/udp-multicast.go | 20 +++++++++------- internal/module/udp-server.go | 20 +++++++++------- router.go | 20 ++++++++-------- router_test.go | 38 +++++++++++++++--------------- 21 files changed, 220 insertions(+), 186 deletions(-) diff --git a/cmd/showbridge/main.go b/cmd/showbridge/main.go index 84bdebe..2c726ed 100644 --- a/cmd/showbridge/main.go +++ b/cmd/showbridge/main.go @@ -137,7 +137,7 @@ func run(ctx context.Context, c *cli.Command) error { commandLogger := slog.Default().With("component", "cmd") - router, moduleErrors, routeErrors := showbridge.NewRouter(context.Background(), config) + router, moduleErrors, routeErrors := showbridge.NewRouter(config) for _, moduleError := range moduleErrors { commandLogger.Error("problem initializing module", "index", moduleError.Index, "error", moduleError.Error) @@ -150,7 +150,7 @@ func run(ctx context.Context, c *cli.Command) error { routerRunner := sync.WaitGroup{} routerRunner.Go(func() { - router.Run() + router.Run(context.Background()) }) <-ctx.Done() diff --git a/internal/module/http-client.go b/internal/module/http-client.go index 3df0071..e2ba7b9 100644 --- a/internal/module/http-client.go +++ b/internal/module/http-client.go @@ -22,15 +22,9 @@ type HTTPClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "http.client", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("http.client unable to get router from context") - } - - return &HTTPClient{config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &HTTPClient{config: config, logger: CreateLogger(config)}, nil }, }) } @@ -43,7 +37,14 @@ func (hc *HTTPClient) Type() string { return hc.config.Type } -func (hc *HTTPClient) Run() error { +func (hc *HTTPClient) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("http.client unable to get router from context") + } + hc.router = router + hc.ctx = ctx hc.client = &http.Client{ Timeout: 10 * time.Second, diff --git a/internal/module/http-server.go b/internal/module/http-server.go index 01ac186..b90a544 100644 --- a/internal/module/http-server.go +++ b/internal/module/http-server.go @@ -54,7 +54,7 @@ func (hsrw *HTTPServerResponseWriter) Write(data []byte) (int, error) { func init() { RegisterModule(ModuleRegistration{ Type: "http.server", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] if !ok { @@ -67,13 +67,7 @@ func init() { return nil, errors.New("http.server port must be uint16") } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("http.server unable to get router from context") - } - - return &HTTPServer{Port: uint16(portNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &HTTPServer{Port: uint16(portNum), config: config, logger: CreateLogger(config)}, nil }, }) } @@ -157,7 +151,15 @@ func (hs *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (hs *HTTPServer) Run() error { +func (hs *HTTPServer) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("http.server unable to get router from context") + } + hs.router = router + hs.ctx = ctx + httpServer := &http.Server{ Addr: fmt.Sprintf(":%d", hs.Port), Handler: hs, diff --git a/internal/module/midi-input.go b/internal/module/midi-input.go index ca0f384..f952dc5 100644 --- a/internal/module/midi-input.go +++ b/internal/module/midi-input.go @@ -26,7 +26,7 @@ type MIDIInput struct { func init() { RegisterModule(ModuleRegistration{ Type: "midi.input", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] @@ -40,13 +40,7 @@ func init() { return nil, errors.New("midi.input port must be a string") } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("midi.input unable to get router from context") - } - - return &MIDIInput{config: config, Port: portString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &MIDIInput{config: config, Port: portString, logger: CreateLogger(config)}, nil }, }) } @@ -59,8 +53,15 @@ func (mi *MIDIInput) Type() string { return mi.config.Type } -func (mi *MIDIInput) Run() error { +func (mi *MIDIInput) Run(ctx context.Context) error { defer midi.CloseDriver() + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("midi.input unable to get router from context") + } + mi.router = router + mi.ctx = ctx in, err := midi.FindInPort(mi.Port) if err != nil { diff --git a/internal/module/midi-output.go b/internal/module/midi-output.go index d53ab37..cb1207e 100644 --- a/internal/module/midi-output.go +++ b/internal/module/midi-output.go @@ -26,7 +26,7 @@ type MIDIOutput struct { func init() { RegisterModule(ModuleRegistration{ Type: "midi.output", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] @@ -41,13 +41,7 @@ func init() { return nil, errors.New("midi.output port must be a string") } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("midi.output unable to get router from context") - } - - return &MIDIOutput{config: config, Port: portString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &MIDIOutput{config: config, Port: portString, logger: CreateLogger(config)}, nil }, }) } @@ -60,8 +54,15 @@ func (mo *MIDIOutput) Type() string { return mo.config.Type } -func (mo *MIDIOutput) Run() error { +func (mo *MIDIOutput) Run(ctx context.Context) error { defer midi.CloseDriver() + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("midi.output unable to get router from context") + } + mo.router = router + mo.ctx = ctx out, err := midi.FindOutPort(mo.Port) diff --git a/internal/module/module.go b/internal/module/module.go index 8633601..002b34a 100644 --- a/internal/module/module.go +++ b/internal/module/module.go @@ -18,13 +18,13 @@ type ModuleError struct { type Module interface { Id() string Type() string - Run() error + Run(context.Context) error Output(context.Context, any) error } type ModuleRegistration struct { Type string `json:"type"` - New func(context.Context, config.ModuleConfig) (Module, error) + New func(config.ModuleConfig) (Module, error) } func RegisterModule(mod ModuleRegistration) { diff --git a/internal/module/mqtt-client.go b/internal/module/mqtt-client.go index b6225ec..65953eb 100644 --- a/internal/module/mqtt-client.go +++ b/internal/module/mqtt-client.go @@ -24,7 +24,7 @@ type MQTTClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "mqtt.client", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params broker, ok := params["broker"] @@ -62,13 +62,7 @@ func init() { return nil, errors.New("mqtt.client clientId must be string") } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("mqtt.client unable to get router from context") - } - - return &MQTTClient{config: config, Broker: brokerString, Topic: topicString, ClientID: clientIdString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &MQTTClient{config: config, Broker: brokerString, Topic: topicString, ClientID: clientIdString, logger: CreateLogger(config)}, nil }, }) } @@ -81,7 +75,15 @@ func (mc *MQTTClient) Type() string { return mc.config.Type } -func (mc *MQTTClient) Run() error { +func (mc *MQTTClient) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("mqtt.client unable to get router from context") + } + mc.router = router + mc.ctx = ctx + opts := mqtt.NewClientOptions() opts.AddBroker(mc.Broker) opts.SetClientID(mc.ClientID) diff --git a/internal/module/nats-client.go b/internal/module/nats-client.go index bd5ad5c..b9ee35a 100644 --- a/internal/module/nats-client.go +++ b/internal/module/nats-client.go @@ -24,7 +24,7 @@ type NATSClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "nats.client", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params url, ok := params["url"] @@ -50,13 +50,7 @@ func init() { return nil, errors.New("nats.client subject must be string") } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("nats.client unable to get router from context") - } - - return &NATSClient{config: config, URL: urlString, Subject: subjectString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &NATSClient{config: config, URL: urlString, Subject: subjectString, logger: CreateLogger(config)}, nil }, }) } @@ -69,7 +63,16 @@ func (nc *NATSClient) Type() string { return nc.config.Type } -func (nc *NATSClient) Run() error { +func (nc *NATSClient) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("nats.client unable to get router from context") + } + + nc.router = router + nc.ctx = ctx + client, err := nats.Connect(nc.URL, nats.RetryOnFailedConnect(true)) if err != nil { diff --git a/internal/module/psn-client.go b/internal/module/psn-client.go index 0a8beb5..d63733c 100644 --- a/internal/module/psn-client.go +++ b/internal/module/psn-client.go @@ -25,13 +25,9 @@ type PSNClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "psn.client", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + New: func(config config.ModuleConfig) (Module, error) { - if !ok { - return nil, errors.New("psn.client unable to get router from context") - } - return &PSNClient{config: config, decoder: psn.NewDecoder(), ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &PSNClient{config: config, decoder: psn.NewDecoder(), logger: CreateLogger(config)}, nil }, }) } @@ -44,7 +40,14 @@ func (pc *PSNClient) Type() string { return pc.config.Type } -func (pc *PSNClient) Run() error { +func (pc *PSNClient) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("psn.client unable to get router from context") + } + pc.router = router + pc.ctx = ctx addr, err := net.ResolveUDPAddr("udp", "236.10.10.10:56565") if err != nil { diff --git a/internal/module/serial-client.go b/internal/module/serial-client.go index a7f049c..83a828d 100644 --- a/internal/module/serial-client.go +++ b/internal/module/serial-client.go @@ -29,7 +29,7 @@ type SerialClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "serial.client", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] @@ -76,13 +76,7 @@ func init() { BaudRate: int(baudRateNum), } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("serial.client unable to get router from context") - } - - return &SerialClient{config: config, Port: portString, Framer: framer, Mode: &mode, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &SerialClient{config: config, Port: portString, Framer: framer, Mode: &mode, logger: CreateLogger(config)}, nil }, }) } @@ -107,7 +101,15 @@ func (sc *SerialClient) SetupPort() error { return nil } -func (sc *SerialClient) Run() error { +func (sc *SerialClient) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("serial.client unable to get router from context") + } + + sc.router = router + sc.ctx = ctx // TODO(jwetzell): shutdown with router.Context properly go func() { diff --git a/internal/module/sip-call-server.go b/internal/module/sip-call-server.go index af6dd4c..9786a12 100644 --- a/internal/module/sip-call-server.go +++ b/internal/module/sip-call-server.go @@ -45,7 +45,7 @@ type sipCallContextKey string func init() { RegisterModule(ModuleRegistration{ Type: "sip.call.server", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params portNum := 5060 @@ -98,12 +98,7 @@ func init() { userAgentString = specificTransportString } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("sip.call.server unable to get router from context") - } - return &SIPCallServer{config: config, ctx: ctx, router: router, IP: ipString, Port: int(portNum), Transport: transportString, UserAgent: userAgentString, logger: CreateLogger(config)}, nil + return &SIPCallServer{config: config, IP: ipString, Port: int(portNum), Transport: transportString, UserAgent: userAgentString, logger: CreateLogger(config)}, nil }, }) } @@ -116,7 +111,15 @@ func (scs *SIPCallServer) Type() string { return scs.config.Type } -func (scs *SIPCallServer) Run() error { +func (scs *SIPCallServer) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("sip.call.server unable to get router from context") + } + scs.router = router + scs.ctx = ctx + diagoLogger := slog.New(slog.NewJSONHandler(io.Discard, nil)) ua, _ := sipgo.NewUA( diff --git a/internal/module/sip-dtmf-server.go b/internal/module/sip-dtmf-server.go index 3758332..4504e78 100644 --- a/internal/module/sip-dtmf-server.go +++ b/internal/module/sip-dtmf-server.go @@ -44,7 +44,7 @@ type SIPDTMFCall struct { func init() { RegisterModule(ModuleRegistration{ Type: "sip.dtmf.server", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params portNum := 5060 @@ -100,12 +100,7 @@ func init() { if !strings.ContainsRune("0123456789*#ABCD", rune(separatorString[0])) { return nil, errors.New("sip.dtmf.server separator must be a valid DTMF character") } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("sip.dtmf.server unable to get router from context") - } - return &SIPDTMFServer{config: config, ctx: ctx, router: router, IP: ipString, Port: int(portNum), Transport: transportString, Separator: separatorString, logger: CreateLogger(config)}, nil + return &SIPDTMFServer{config: config, IP: ipString, Port: int(portNum), Transport: transportString, Separator: separatorString, logger: CreateLogger(config)}, nil }, }) } @@ -118,7 +113,15 @@ func (sds *SIPDTMFServer) Type() string { return sds.config.Type } -func (sds *SIPDTMFServer) Run() error { +func (sds *SIPDTMFServer) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("sip.dtmf.server unable to get router from context") + } + sds.router = router + sds.ctx = ctx + diagoLogger := slog.New(slog.NewJSONHandler(io.Discard, nil)) ua, _ := sipgo.NewUA( diff --git a/internal/module/tcp-client.go b/internal/module/tcp-client.go index c12a5e1..e762474 100644 --- a/internal/module/tcp-client.go +++ b/internal/module/tcp-client.go @@ -26,7 +26,7 @@ type TCPClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.tcp.client", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params host, ok := params["host"] @@ -74,14 +74,7 @@ func init() { if framer == nil { return nil, fmt.Errorf("net.tcp.client unknown framing method: %s", framingMethod) } - - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("net.tcp.client unable to get router from context") - } - - return &TCPClient{framer: framer, Addr: addr, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &TCPClient{framer: framer, Addr: addr, config: config, logger: CreateLogger(config)}, nil }, }) } @@ -94,7 +87,15 @@ func (tc *TCPClient) Type() string { return tc.config.Type } -func (tc *TCPClient) Run() error { +func (tc *TCPClient) Run(ctx context.Context) error { + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("net.tcp.client unable to get router from context") + } + tc.router = router + tc.ctx = ctx // TODO(jwetzell): shutdown with router.Context properly go func() { diff --git a/internal/module/tcp-server.go b/internal/module/tcp-server.go index 3419475..43ef7a7 100644 --- a/internal/module/tcp-server.go +++ b/internal/module/tcp-server.go @@ -32,7 +32,7 @@ type TCPServer struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.tcp.server", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] if !ok { @@ -81,14 +81,7 @@ func init() { if err != nil { return nil, err } - - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("net.tcp.server unable to get router from context") - } - - return &TCPServer{Framer: framer, Addr: addr, config: config, quit: make(chan interface{}), ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &TCPServer{Framer: framer, Addr: addr, config: config, quit: make(chan interface{}), logger: CreateLogger(config)}, nil }, }) } @@ -166,7 +159,16 @@ ClientRead: } } -func (ts *TCPServer) Run() error { +func (ts *TCPServer) Run(ctx context.Context) error { + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("net.tcp.server unable to get router from context") + } + ts.router = router + ts.ctx = ctx + listener, err := net.ListenTCP("tcp", ts.Addr) if err != nil { return err diff --git a/internal/module/time-interval.go b/internal/module/time-interval.go index 11178f7..af5dbfe 100644 --- a/internal/module/time-interval.go +++ b/internal/module/time-interval.go @@ -22,7 +22,7 @@ type TimeInterval struct { func init() { RegisterModule(ModuleRegistration{ Type: "time.interval", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params duration, ok := params["duration"] @@ -35,14 +35,7 @@ func init() { if !ok { return nil, errors.New("time.interval duration must be number") } - - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("time.interval unable to get router from context") - } - - return &TimeInterval{Duration: uint32(durationNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &TimeInterval{Duration: uint32(durationNum), config: config, logger: CreateLogger(config)}, nil }, }) } @@ -55,7 +48,16 @@ func (i *TimeInterval) Type() string { return i.config.Type } -func (i *TimeInterval) Run() error { +func (i *TimeInterval) Run(ctx context.Context) error { + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("time.interval unable to get router from context") + } + i.router = router + i.ctx = ctx + ticker := time.NewTicker(time.Millisecond * time.Duration(i.Duration)) i.ticker = ticker defer ticker.Stop() diff --git a/internal/module/time-timer.go b/internal/module/time-timer.go index ea2c2e4..06828f6 100644 --- a/internal/module/time-timer.go +++ b/internal/module/time-timer.go @@ -22,7 +22,7 @@ type TimeTimer struct { func init() { RegisterModule(ModuleRegistration{ Type: "time.timer", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params duration, ok := params["duration"] @@ -36,13 +36,7 @@ func init() { return nil, errors.New("time.timer duration must be a number") } - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("net.tcp.client unable to get router from context") - } - - return &TimeTimer{Duration: uint32(durationNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &TimeTimer{Duration: uint32(durationNum), config: config, logger: CreateLogger(config)}, nil }, }) } @@ -55,7 +49,16 @@ func (t *TimeTimer) Type() string { return t.config.Type } -func (t *TimeTimer) Run() error { +func (t *TimeTimer) Run(ctx context.Context) error { + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("net.tcp.client unable to get router from context") + } + t.router = router + t.ctx = ctx + t.timer = time.NewTimer(time.Millisecond * time.Duration(t.Duration)) defer t.timer.Stop() for { diff --git a/internal/module/udp-client.go b/internal/module/udp-client.go index 341e2a8..b3b2511 100644 --- a/internal/module/udp-client.go +++ b/internal/module/udp-client.go @@ -24,7 +24,7 @@ type UDPClient struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.udp.client", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params host, ok := params["host"] @@ -53,14 +53,7 @@ func init() { if err != nil { return nil, err } - - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("net.udp.client unable to get router from context") - } - - return &UDPClient{Addr: addr, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &UDPClient{Addr: addr, config: config, logger: CreateLogger(config)}, nil }, }) } @@ -79,7 +72,15 @@ func (uc *UDPClient) SetupConn() error { return err } -func (uc *UDPClient) Run() error { +func (uc *UDPClient) Run(ctx context.Context) error { + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("net.udp.client unable to get router from context") + } + uc.router = router + uc.ctx = ctx err := uc.SetupConn() if err != nil { diff --git a/internal/module/udp-multicast.go b/internal/module/udp-multicast.go index 3840056..92c6d8b 100644 --- a/internal/module/udp-multicast.go +++ b/internal/module/udp-multicast.go @@ -24,7 +24,7 @@ type UDPMulticast struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.udp.multicast", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params ip, ok := params["ip"] @@ -53,13 +53,7 @@ func init() { if err != nil { return nil, err } - - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("net.udp.multicast unable to get router from context") - } - return &UDPMulticast{config: config, Addr: addr, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &UDPMulticast{config: config, Addr: addr, logger: CreateLogger(config)}, nil }, }) } @@ -72,7 +66,15 @@ func (um *UDPMulticast) Type() string { return um.config.Type } -func (um *UDPMulticast) Run() error { +func (um *UDPMulticast) Run(ctx context.Context) error { + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("net.udp.multicast unable to get router from context") + } + um.router = router + um.ctx = ctx client, err := net.ListenMulticastUDP("udp", nil, um.Addr) if err != nil { diff --git a/internal/module/udp-server.go b/internal/module/udp-server.go index f551605..c749153 100644 --- a/internal/module/udp-server.go +++ b/internal/module/udp-server.go @@ -25,7 +25,7 @@ type UDPServer struct { func init() { RegisterModule(ModuleRegistration{ Type: "net.udp.server", - New: func(ctx context.Context, config config.ModuleConfig) (Module, error) { + New: func(config config.ModuleConfig) (Module, error) { params := config.Params port, ok := params["port"] if !ok { @@ -67,13 +67,7 @@ func init() { } bufferSizeNum = int(bufferSizeFloat) } - - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("net.udp.server unable to get router from context") - } - return &UDPServer{Addr: addr, BufferSize: bufferSizeNum, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil + return &UDPServer{Addr: addr, BufferSize: bufferSizeNum, config: config, logger: CreateLogger(config)}, nil }, }) } @@ -86,7 +80,15 @@ func (us *UDPServer) Type() string { return us.config.Id } -func (us *UDPServer) Run() error { +func (us *UDPServer) Run(ctx context.Context) error { + + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return errors.New("net.udp.server unable to get router from context") + } + us.router = router + us.ctx = ctx listener, err := net.ListenUDP("udp", us.Addr) if err != nil { diff --git a/router.go b/router.go index 574cd7a..4dbd5ea 100644 --- a/router.go +++ b/router.go @@ -20,19 +20,14 @@ type Router struct { logger *slog.Logger } -func NewRouter(ctx context.Context, config config.Config) (*Router, []module.ModuleError, []route.RouteError) { - - routerContext, cancel := context.WithCancel(ctx) +func NewRouter(config config.Config) (*Router, []module.ModuleError, []route.RouteError) { router := Router{ - contextCancel: cancel, ModuleInstances: []module.Module{}, RouteInstances: []route.Route{}, logger: slog.Default().With("component", "router"), } - router.Context = context.WithValue(routerContext, route.RouterContextKey, &router) - router.logger.Debug("creating") var moduleErrors []module.ModuleError @@ -69,7 +64,7 @@ func NewRouter(ctx context.Context, config config.Config) (*Router, []module.Mod } if !moduleInstanceExists { - moduleInstance, err := moduleInfo.New(router.Context, moduleDecl) + moduleInstance, err := moduleInfo.New(moduleDecl) if err != nil { if moduleErrors == nil { moduleErrors = []module.ModuleError{} @@ -107,11 +102,16 @@ func NewRouter(ctx context.Context, config config.Config) (*Router, []module.Mod return &router, moduleErrors, routeErrors } -func (r *Router) Run() { +func (r *Router) Run(ctx context.Context) { r.logger.Info("running") + routerContext, cancel := context.WithCancel(ctx) + r.Context = routerContext + r.contextCancel = cancel + contextWithRouter := context.WithValue(routerContext, route.RouterContextKey, r) + for _, moduleInstance := range r.ModuleInstances { r.moduleWait.Go(func() { - err := moduleInstance.Run() + err := moduleInstance.Run(contextWithRouter) if err != nil { r.logger.Error("error encountered running module", "error", err) } @@ -124,7 +124,7 @@ func (r *Router) Run() { } func (r *Router) Stop() { - r.logger.Debug("stopping") + r.logger.Info("stopping") r.contextCancel() } diff --git a/router_test.go b/router_test.go index ecca63d..7125c6b 100644 --- a/router_test.go +++ b/router_test.go @@ -2,11 +2,11 @@ package showbridge_test import ( "context" - "errors" "fmt" "log/slog" "sync" "testing" + "time" "github.com/jwetzell/showbridge-go" "github.com/jwetzell/showbridge-go/internal/config" @@ -31,7 +31,8 @@ func (m *MockModule) Output(context.Context, any) error { return nil } -func (m *MockModule) Run() error { +func (m *MockModule) Run(ctx context.Context) error { + m.ctx = ctx <-m.ctx.Done() return nil } @@ -43,15 +44,9 @@ func (m *MockModule) Type() string { func init() { module.RegisterModule(module.ModuleRegistration{ Type: "mock.counter", - New: func(ctx context.Context, config config.ModuleConfig) (module.Module, error) { + New: func(config config.ModuleConfig) (module.Module, error) { - router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) - - if !ok { - return nil, errors.New("mock.counter unable to get router from context") - } - - return &MockModule{config: config, ctx: ctx, router: router, logger: slog.Default()}, nil + return &MockModule{config: config, logger: slog.Default()}, nil }, }) } @@ -66,7 +61,7 @@ func TestNewRouter(t *testing.T) { }, } - _, moduleErrors, routeErrors := showbridge.NewRouter(t.Context(), routerConfig) + _, moduleErrors, routeErrors := showbridge.NewRouter(routerConfig) if moduleErrors != nil { t.Fatalf("router should not have returned any module errors: %v", moduleErrors) @@ -87,7 +82,7 @@ func TestNewRouterUnknownModuleType(t *testing.T) { }, } - _, moduleErrors, _ := showbridge.NewRouter(t.Context(), routerConfig) + _, moduleErrors, _ := showbridge.NewRouter(routerConfig) if moduleErrors == nil { t.Fatalf("router should have returned 'unknown module' module errors") @@ -108,7 +103,7 @@ func TestNewRouterDuplicateModuleId(t *testing.T) { }, } - _, moduleErrors, _ := showbridge.NewRouter(t.Context(), routerConfig) + _, moduleErrors, _ := showbridge.NewRouter(routerConfig) if moduleErrors == nil { t.Fatalf("router should have returned 'duplicate id' module error") @@ -131,7 +126,7 @@ func TestRouterInputSingleRoute(t *testing.T) { }, } - router, moduleErrors, routeErrors := showbridge.NewRouter(t.Context(), routerConfig) + router, moduleErrors, routeErrors := showbridge.NewRouter(routerConfig) if moduleErrors != nil { t.Fatalf("router should not have returned any module errors: %v", moduleErrors) @@ -144,9 +139,11 @@ func TestRouterInputSingleRoute(t *testing.T) { routerRunner := sync.WaitGroup{} routerRunner.Go(func() { - router.Run() + router.Run(t.Context()) }) + time.Sleep(time.Second * 1) + defer router.Stop() mockModuleInputCount := 3 @@ -200,7 +197,7 @@ func TestRouterInputMultipleRoutes(t *testing.T) { }, } - router, moduleErrors, routeErrors := showbridge.NewRouter(t.Context(), routerConfig) + router, moduleErrors, routeErrors := showbridge.NewRouter(routerConfig) if moduleErrors != nil { t.Fatalf("router should not have returned any module errors: %v", moduleErrors) @@ -213,8 +210,9 @@ func TestRouterInputMultipleRoutes(t *testing.T) { routerRunner := sync.WaitGroup{} routerRunner.Go(func() { - router.Run() + router.Run(t.Context()) }) + time.Sleep(time.Second * 1) defer router.Stop() @@ -270,7 +268,7 @@ func TestRouterInputMultipleModules(t *testing.T) { }, } - router, moduleErrors, routeErrors := showbridge.NewRouter(t.Context(), routerConfig) + router, moduleErrors, routeErrors := showbridge.NewRouter(routerConfig) if moduleErrors != nil { t.Fatalf("router should not have returned any module errors: %v", moduleErrors) @@ -283,9 +281,11 @@ func TestRouterInputMultipleModules(t *testing.T) { routerRunner := sync.WaitGroup{} routerRunner.Go(func() { - router.Run() + router.Run(t.Context()) }) + time.Sleep(time.Second * 1) + defer router.Stop() mock1ModuleInputCount := 3