From 51e656313ce09162e02d4bb2651e1a0564847135 Mon Sep 17 00:00:00 2001 From: Joel Wetzell Date: Sat, 22 Nov 2025 11:31:09 -0600 Subject: [PATCH] do some decent context reworking --- cmd/cli/main.go | 6 +++++- interval.go | 7 +++---- module.go | 3 +-- router.go | 20 ++++++++++++++++++-- tcp-client.go | 25 +++++++++++++++++++------ tcp-server.go | 14 ++++++++++---- timer.go | 2 +- udp-client.go | 7 +++---- udp-server.go | 7 ++++--- 9 files changed, 64 insertions(+), 27 deletions(-) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index b933bac..c4b5b25 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "os" + "os/signal" "github.com/jwetzell/showbridge-go" "github.com/urfave/cli/v3" @@ -42,7 +44,9 @@ func main() { }, } - err := cmd.Run(context.Background(), os.Args) + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Interrupt) + defer cancel() + err := cmd.Run(ctx, os.Args) if err != nil { panic(err) diff --git a/interval.go b/interval.go index ba06d70..369b9f8 100644 --- a/interval.go +++ b/interval.go @@ -1,7 +1,6 @@ package showbridge import ( - "context" "fmt" "log/slog" "time" @@ -44,17 +43,17 @@ func (i *Interval) Type() string { } func (i *Interval) RegisterRouter(router *Router) { - slog.Debug("registering router", "id", i.config.Id) i.router = router } -func (i *Interval) Run(ctx context.Context) error { +func (i *Interval) Run() error { ticker := time.NewTicker(time.Millisecond * time.Duration(i.Duration)) defer ticker.Stop() for { select { - case <-ctx.Done(): + case <-i.router.Context.Done(): ticker.Stop() + slog.Debug("router context done in module", "id", i.config.Id) return nil case t := <-ticker.C: if i.router != nil { diff --git a/module.go b/module.go index c8d4d52..0b09aa8 100644 --- a/module.go +++ b/module.go @@ -1,7 +1,6 @@ package showbridge import ( - "context" "fmt" "sync" ) @@ -16,7 +15,7 @@ type Module interface { Id() string Type() string RegisterRouter(*Router) - Run(context.Context) error + Run() error Output(any) error } diff --git a/router.go b/router.go index f82ad66..47f52b5 100644 --- a/router.go +++ b/router.go @@ -5,12 +5,15 @@ import ( "fmt" "log/slog" "os" + "sync" ) type Router struct { + contextCancel context.CancelFunc Context context.Context ModuleInstances []Module RouteInstances []*Route + moduleWait sync.WaitGroup } func NewRouter(ctx context.Context, config Config) (*Router, []ModuleError, []RouteError) { @@ -23,8 +26,10 @@ func NewRouter(ctx context.Context, config Config) (*Router, []ModuleError, []Ro slog.Debug("creating router") + routerContext, cancel := context.WithCancel(ctx) router := Router{ - Context: ctx, + Context: routerContext, + contextCancel: cancel, ModuleInstances: []Module{}, RouteInstances: []*Route{}, } @@ -108,9 +113,20 @@ func NewRouter(ctx context.Context, config Config) (*Router, []ModuleError, []Ro func (r *Router) Run() { for _, moduleInstance := range r.ModuleInstances { - go moduleInstance.Run(r.Context) + moduleInstance.RegisterRouter(r) + r.moduleWait.Add(1) + go func() { + moduleInstance.Run() + r.moduleWait.Done() + }() } <-r.Context.Done() + r.moduleWait.Wait() + slog.Info("router context done") +} + +func (r *Router) Stop() { + r.contextCancel() } func (r *Router) HandleInput(sourceId string, payload any) { diff --git a/tcp-client.go b/tcp-client.go index 38e95c3..df6c0b1 100644 --- a/tcp-client.go +++ b/tcp-client.go @@ -1,7 +1,6 @@ package showbridge import ( - "context" "fmt" "log/slog" "net" @@ -87,18 +86,31 @@ func (tc *TCPClient) Type() string { } func (tc *TCPClient) RegisterRouter(router *Router) { - slog.Debug("registering router", "id", tc.config.Id) tc.router = router } -func (tc *TCPClient) Run(ctx context.Context) error { +func (tc *TCPClient) Run() error { addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", tc.Host, tc.Port)) if err != nil { return err } + + // TODO(jwetzell): shutdown with router.Context properly + go func() { + <-tc.router.Context.Done() + slog.Debug("router context done in module", "id", tc.config.Id) + if tc.conn != nil { + tc.conn.Close() + } + }() + for { client, err := net.DialTCP("tcp", nil, addr) if err != nil { + if tc.router.Context.Err() != nil { + slog.Debug("router context done in module", "id", tc.config.Id) + return nil + } slog.Error(err.Error()) time.Sleep(time.Second * 2) continue @@ -108,19 +120,20 @@ func (tc *TCPClient) Run(ctx context.Context) error { buffer := make([]byte, 1024) select { - case <-ctx.Done(): + case <-tc.router.Context.Done(): + slog.Debug("router context done in module", "id", tc.config.Id) return nil default: READ: for { select { - case <-ctx.Done(): + case <-tc.router.Context.Done(): + slog.Debug("router context done in module", "id", tc.config.Id) return nil default: byteCount, err := client.Read(buffer) if err != nil { - slog.Debug("connection closed") tc.framer.Clear() break READ } diff --git a/tcp-server.go b/tcp-server.go index 9ae933f..8c0833d 100644 --- a/tcp-server.go +++ b/tcp-server.go @@ -57,7 +57,6 @@ func (ts *TCPServer) Type() string { } func (ts *TCPServer) RegisterRouter(router *Router) { - slog.Debug("registering router", "id", ts.config.Id) ts.router = router } @@ -108,22 +107,29 @@ func (ts *TCPServer) HandleClient(ctx context.Context, client net.Conn) { } } -func (ts TCPServer) Run(ctx context.Context) error { +func (ts TCPServer) Run() error { listener, err := net.Listen("tcp", fmt.Sprintf(":%d", ts.Port)) if err != nil { return err } + // TODO(jwetzell): shutdown with router.Context properly + go func() { + <-ts.router.Context.Done() + slog.Debug("router context done in module", "id", ts.config.Id) + listener.Close() + }() + for { select { - case <-ctx.Done(): + case <-ts.router.Context.Done(): return nil default: client, err := listener.Accept() if err != nil { return err } - go ts.HandleClient(ctx, client) + go ts.HandleClient(ts.router.Context, client) } } } diff --git a/timer.go b/timer.go index b59b9d8..f858a47 100644 --- a/timer.go +++ b/timer.go @@ -45,7 +45,6 @@ func (t *Timer) Type() string { } func (t *Timer) RegisterRouter(router *Router) { - slog.Debug("registering router", "id", t.config.Id) t.router = router } @@ -56,6 +55,7 @@ func (t *Timer) Run(ctx context.Context) error { select { case <-ctx.Done(): t.timer.Stop() + slog.Debug("router context done in module", "id", t.config.Id) return nil case time := <-t.timer.C: if t.router != nil { diff --git a/udp-client.go b/udp-client.go index dfc2ab5..ba45c36 100644 --- a/udp-client.go +++ b/udp-client.go @@ -1,7 +1,6 @@ package showbridge import ( - "context" "fmt" "log/slog" "net" @@ -57,11 +56,10 @@ func (uc *UDPClient) Type() string { } func (uc *UDPClient) RegisterRouter(router *Router) { - slog.Debug("registering router", "id", uc.config.Id) uc.router = router } -func (uc *UDPClient) Run(ctx context.Context) error { +func (uc *UDPClient) Run() error { addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", uc.Host, uc.Port)) if err != nil { return err @@ -72,7 +70,8 @@ func (uc *UDPClient) Run(ctx context.Context) error { } uc.conn = client - <-ctx.Done() + <-uc.router.Context.Done() + slog.Debug("router context done in module", "id", uc.config.Id) return nil } diff --git a/udp-server.go b/udp-server.go index 5631652..7ee1cb9 100644 --- a/udp-server.go +++ b/udp-server.go @@ -1,7 +1,6 @@ package showbridge import ( - "context" "fmt" "log" "log/slog" @@ -47,7 +46,7 @@ func (us *UDPServer) RegisterRouter(router *Router) { us.router = router } -func (us *UDPServer) Run(ctx context.Context) error { +func (us *UDPServer) Run() error { addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", us.Port)) if err != nil { @@ -64,7 +63,9 @@ func (us *UDPServer) Run(ctx context.Context) error { buffer := make([]byte, 1024) for { select { - case <-ctx.Done(): + case <-us.router.Context.Done(): + // TODO(jwetzell): cleanup? + slog.Debug("router context done in module", "id", us.config.Id) return nil default: numBytes, _, err := listener.ReadFromUDP(buffer)