mirror of
https://github.com/jwetzell/showbridge-go.git
synced 2026-04-26 21:05:30 +00:00
load router from context in mock module
This commit is contained in:
@@ -19,7 +19,7 @@ var (
|
||||
tracer = otel.Tracer("showbridge.test")
|
||||
)
|
||||
|
||||
type MockModule struct {
|
||||
type MockCounterModule struct {
|
||||
config config.ModuleConfig
|
||||
ctx context.Context
|
||||
outputCount int
|
||||
@@ -27,31 +27,36 @@ type MockModule struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (m *MockModule) Id() string {
|
||||
return m.config.Id
|
||||
func (mcm *MockCounterModule) Id() string {
|
||||
return mcm.config.Id
|
||||
}
|
||||
|
||||
func (m *MockModule) Output(context.Context, any) error {
|
||||
m.outputCount += 1
|
||||
func (mcm *MockCounterModule) Output(context.Context, any) error {
|
||||
mcm.outputCount += 1
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockModule) Run(ctx context.Context) error {
|
||||
m.ctx = ctx
|
||||
<-m.ctx.Done()
|
||||
func (mcm *MockCounterModule) Run(ctx context.Context) error {
|
||||
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("mock.counter could not get router from context")
|
||||
}
|
||||
mcm.router = router
|
||||
mcm.ctx = ctx
|
||||
<-mcm.ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockModule) Type() string {
|
||||
return m.config.Type
|
||||
func (mcm *MockCounterModule) Type() string {
|
||||
return mcm.config.Type
|
||||
}
|
||||
|
||||
func init() {
|
||||
module.RegisterModule(module.ModuleRegistration{
|
||||
Type: "mock.counter",
|
||||
New: func(config config.ModuleConfig) (module.Module, error) {
|
||||
|
||||
return &MockModule{config: config, logger: slog.Default()}, nil
|
||||
return &MockCounterModule{config: config, logger: slog.Default()}, nil
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -162,6 +167,7 @@ func TestRouterInputSingleRoute(t *testing.T) {
|
||||
|
||||
routerRunner.Go(func() {
|
||||
router.Run(t.Context())
|
||||
fmt.Println("router stopped")
|
||||
})
|
||||
|
||||
time.Sleep(time.Second * 1)
|
||||
@@ -183,7 +189,7 @@ func TestRouterInputSingleRoute(t *testing.T) {
|
||||
|
||||
for _, moduleInstance := range router.ModuleInstances {
|
||||
if moduleInstance.Id() == "mock" {
|
||||
mockModuleInstance, ok := moduleInstance.(*MockModule)
|
||||
mockModuleInstance, ok := moduleInstance.(*MockCounterModule)
|
||||
if !ok {
|
||||
t.Fatalf("couldn't get mock module")
|
||||
}
|
||||
@@ -253,7 +259,7 @@ func TestRouterInputMultipleRoutes(t *testing.T) {
|
||||
|
||||
for _, moduleInstance := range router.ModuleInstances {
|
||||
if moduleInstance.Id() == "mock" {
|
||||
mockModuleInstance, ok := moduleInstance.(*MockModule)
|
||||
mockModuleInstance, ok := moduleInstance.(*MockCounterModule)
|
||||
if !ok {
|
||||
t.Fatalf("couldn't get mock module")
|
||||
}
|
||||
@@ -338,7 +344,7 @@ func TestRouterInputMultipleModules(t *testing.T) {
|
||||
|
||||
for _, moduleInstance := range router.ModuleInstances {
|
||||
if moduleInstance.Id() == "mock1" {
|
||||
mockModuleInstance, ok := moduleInstance.(*MockModule)
|
||||
mockModuleInstance, ok := moduleInstance.(*MockCounterModule)
|
||||
if !ok {
|
||||
t.Fatalf("couldn't get mock module")
|
||||
}
|
||||
@@ -349,7 +355,7 @@ func TestRouterInputMultipleModules(t *testing.T) {
|
||||
break
|
||||
}
|
||||
if moduleInstance.Id() == "mock2" {
|
||||
mockModuleInstance, ok := moduleInstance.(*MockModule)
|
||||
mockModuleInstance, ok := moduleInstance.(*MockCounterModule)
|
||||
if !ok {
|
||||
t.Fatalf("couldn't get mock module")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user