load router from context in mock module

This commit is contained in:
Joel Wetzell
2026-02-05 20:20:00 -06:00
parent b095419b6e
commit 8f5091cf9b
2 changed files with 23 additions and 16 deletions

View File

@@ -19,7 +19,7 @@ var (
tracer = otel.Tracer("showbridge.test")
)
type MockModule struct {
type MockCounterModule struct {
config config.ModuleConfig
ctx context.Context
outputCount int
@@ -27,31 +27,36 @@ type MockModule struct {
logger *slog.Logger
}
func (m *MockModule) Id() string {
return m.config.Id
func (mcm *MockCounterModule) Id() string {
return mcm.config.Id
}
func (m *MockModule) Output(context.Context, any) error {
m.outputCount += 1
func (mcm *MockCounterModule) Output(context.Context, any) error {
mcm.outputCount += 1
return nil
}
func (m *MockModule) Run(ctx context.Context) error {
m.ctx = ctx
<-m.ctx.Done()
func (mcm *MockCounterModule) Run(ctx context.Context) error {
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return fmt.Errorf("mock.counter could not get router from context")
}
mcm.router = router
mcm.ctx = ctx
<-mcm.ctx.Done()
return nil
}
func (m *MockModule) Type() string {
return m.config.Type
func (mcm *MockCounterModule) Type() string {
return mcm.config.Type
}
func init() {
module.RegisterModule(module.ModuleRegistration{
Type: "mock.counter",
New: func(config config.ModuleConfig) (module.Module, error) {
return &MockModule{config: config, logger: slog.Default()}, nil
return &MockCounterModule{config: config, logger: slog.Default()}, nil
},
})
}
@@ -162,6 +167,7 @@ func TestRouterInputSingleRoute(t *testing.T) {
routerRunner.Go(func() {
router.Run(t.Context())
fmt.Println("router stopped")
})
time.Sleep(time.Second * 1)
@@ -183,7 +189,7 @@ func TestRouterInputSingleRoute(t *testing.T) {
for _, moduleInstance := range router.ModuleInstances {
if moduleInstance.Id() == "mock" {
mockModuleInstance, ok := moduleInstance.(*MockModule)
mockModuleInstance, ok := moduleInstance.(*MockCounterModule)
if !ok {
t.Fatalf("couldn't get mock module")
}
@@ -253,7 +259,7 @@ func TestRouterInputMultipleRoutes(t *testing.T) {
for _, moduleInstance := range router.ModuleInstances {
if moduleInstance.Id() == "mock" {
mockModuleInstance, ok := moduleInstance.(*MockModule)
mockModuleInstance, ok := moduleInstance.(*MockCounterModule)
if !ok {
t.Fatalf("couldn't get mock module")
}
@@ -338,7 +344,7 @@ func TestRouterInputMultipleModules(t *testing.T) {
for _, moduleInstance := range router.ModuleInstances {
if moduleInstance.Id() == "mock1" {
mockModuleInstance, ok := moduleInstance.(*MockModule)
mockModuleInstance, ok := moduleInstance.(*MockCounterModule)
if !ok {
t.Fatalf("couldn't get mock module")
}
@@ -349,7 +355,7 @@ func TestRouterInputMultipleModules(t *testing.T) {
break
}
if moduleInstance.Id() == "mock2" {
mockModuleInstance, ok := moduleInstance.(*MockModule)
mockModuleInstance, ok := moduleInstance.(*MockCounterModule)
if !ok {
t.Fatalf("couldn't get mock module")
}