diff --git a/router.go b/router.go index 8a03d3b..2789411 100644 --- a/router.go +++ b/router.go @@ -127,6 +127,7 @@ func (r *Router) Run(ctx context.Context) { r.moduleWait.Go(func() { err := moduleInstance.Run(contextWithRouter) if err != nil { + // TODO(jwetzell): handle module run errors better r.logger.Error("error encountered running module", "error", err) } }) diff --git a/router_test.go b/router_test.go index b10aeda..9af83b4 100644 --- a/router_test.go +++ b/router_test.go @@ -19,7 +19,7 @@ var ( tracer = otel.Tracer("showbridge.test") ) -type MockModule struct { +type MockCounterModule struct { config config.ModuleConfig ctx context.Context outputCount int @@ -27,31 +27,36 @@ type MockModule struct { logger *slog.Logger } -func (m *MockModule) Id() string { - return m.config.Id +func (mcm *MockCounterModule) Id() string { + return mcm.config.Id } -func (m *MockModule) Output(context.Context, any) error { - m.outputCount += 1 +func (mcm *MockCounterModule) Output(context.Context, any) error { + mcm.outputCount += 1 return nil } -func (m *MockModule) Run(ctx context.Context) error { - m.ctx = ctx - <-m.ctx.Done() +func (mcm *MockCounterModule) Run(ctx context.Context) error { + router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO) + + if !ok { + return fmt.Errorf("mock.counter could not get router from context") + } + mcm.router = router + mcm.ctx = ctx + <-mcm.ctx.Done() return nil } -func (m *MockModule) Type() string { - return m.config.Type +func (mcm *MockCounterModule) Type() string { + return mcm.config.Type } func init() { module.RegisterModule(module.ModuleRegistration{ Type: "mock.counter", New: func(config config.ModuleConfig) (module.Module, error) { - - return &MockModule{config: config, logger: slog.Default()}, nil + return &MockCounterModule{config: config, logger: slog.Default()}, nil }, }) } @@ -162,6 +167,7 @@ func TestRouterInputSingleRoute(t *testing.T) { routerRunner.Go(func() { router.Run(t.Context()) + fmt.Println("router stopped") }) time.Sleep(time.Second * 1) @@ -183,7 +189,7 @@ func TestRouterInputSingleRoute(t *testing.T) { for _, moduleInstance := range router.ModuleInstances { if moduleInstance.Id() == "mock" { - mockModuleInstance, ok := moduleInstance.(*MockModule) + mockModuleInstance, ok := moduleInstance.(*MockCounterModule) if !ok { t.Fatalf("couldn't get mock module") } @@ -253,7 +259,7 @@ func TestRouterInputMultipleRoutes(t *testing.T) { for _, moduleInstance := range router.ModuleInstances { if moduleInstance.Id() == "mock" { - mockModuleInstance, ok := moduleInstance.(*MockModule) + mockModuleInstance, ok := moduleInstance.(*MockCounterModule) if !ok { t.Fatalf("couldn't get mock module") } @@ -338,7 +344,7 @@ func TestRouterInputMultipleModules(t *testing.T) { for _, moduleInstance := range router.ModuleInstances { if moduleInstance.Id() == "mock1" { - mockModuleInstance, ok := moduleInstance.(*MockModule) + mockModuleInstance, ok := moduleInstance.(*MockCounterModule) if !ok { t.Fatalf("couldn't get mock module") } @@ -349,7 +355,7 @@ func TestRouterInputMultipleModules(t *testing.T) { break } if moduleInstance.Id() == "mock2" { - mockModuleInstance, ok := moduleInstance.(*MockModule) + mockModuleInstance, ok := moduleInstance.(*MockCounterModule) if !ok { t.Fatalf("couldn't get mock module") }