stuff values into context

This commit is contained in:
Joel Wetzell
2025-12-27 22:59:30 -06:00
parent 3458b52206
commit 8ffc7d02a5
21 changed files with 167 additions and 53 deletions

View File

@@ -22,7 +22,13 @@ type HTTPClient struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "http.client", Type: "http.client",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("http.client unable to get router from context")
}
return &HTTPClient{config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &HTTPClient{config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
@@ -48,7 +54,7 @@ func (hc *HTTPClient) Run() error {
return nil return nil
} }
func (hc *HTTPClient) Output(payload any) error { func (hc *HTTPClient) Output(ctx context.Context, payload any) error {
payloadRequest, ok := payload.(*http.Request) payloadRequest, ok := payload.(*http.Request)

View File

@@ -28,7 +28,7 @@ type ResponseData struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "http.server", Type: "http.server",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
port, ok := params["port"] port, ok := params["port"]
if !ok { if !ok {
@@ -41,6 +41,12 @@ func init() {
return nil, errors.New("http.server port must be uint16") return nil, errors.New("http.server port must be uint16")
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("http.server unable to get router from context")
}
return &HTTPServer{Port: uint16(portNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &HTTPServer{Port: uint16(portNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -105,6 +111,6 @@ func (hs *HTTPServer) Run() error {
return nil return nil
} }
func (hs *HTTPServer) Output(payload any) error { func (hs *HTTPServer) Output(ctx context.Context, payload any) error {
return errors.New("http.server output is not implemented") return errors.New("http.server output is not implemented")
} }

View File

@@ -26,7 +26,7 @@ type MIDIInput struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "midi.input", Type: "midi.input",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
port, ok := params["port"] port, ok := params["port"]
@@ -40,6 +40,12 @@ func init() {
return nil, errors.New("midi.input port must be a string") return nil, errors.New("midi.input port must be a string")
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("midi.input unable to get router from context")
}
return &MIDIInput{config: config, Port: portString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &MIDIInput{config: config, Port: portString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -78,6 +84,6 @@ func (mi *MIDIInput) Run() error {
return nil return nil
} }
func (mi *MIDIInput) Output(payload any) error { func (mi *MIDIInput) Output(ctx context.Context, payload any) error {
return errors.New("midi.input output is not implemented") return errors.New("midi.input output is not implemented")
} }

View File

@@ -26,7 +26,7 @@ type MIDIOutput struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "midi.output", Type: "midi.output",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
port, ok := params["port"] port, ok := params["port"]
@@ -41,6 +41,12 @@ func init() {
return nil, errors.New("midi.output port must be a string") return nil, errors.New("midi.output port must be a string")
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("midi.output unable to get router from context")
}
return &MIDIOutput{config: config, Port: portString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &MIDIOutput{config: config, Port: portString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -75,7 +81,7 @@ func (mo *MIDIOutput) Run() error {
return nil return nil
} }
func (mo *MIDIOutput) Output(payload any) error { func (mo *MIDIOutput) Output(ctx context.Context, payload any) error {
if mo.SendFunc == nil { if mo.SendFunc == nil {
return errors.New("midi.output output is not setup") return errors.New("midi.output output is not setup")
} }

View File

@@ -7,7 +7,6 @@ import (
"sync" "sync"
"github.com/jwetzell/showbridge-go/internal/config" "github.com/jwetzell/showbridge-go/internal/config"
"github.com/jwetzell/showbridge-go/internal/route"
) )
type ModuleError struct { type ModuleError struct {
@@ -20,12 +19,12 @@ type Module interface {
Id() string Id() string
Type() string Type() string
Run() error Run() error
Output(any) error Output(context.Context, any) error
} }
type ModuleRegistration struct { type ModuleRegistration struct {
Type string `json:"type"` Type string `json:"type"`
New func(context.Context, config.ModuleConfig, route.RouteIO) (Module, error) New func(context.Context, config.ModuleConfig) (Module, error)
} }
func RegisterModule(mod ModuleRegistration) { func RegisterModule(mod ModuleRegistration) {

View File

@@ -24,7 +24,7 @@ type MQTTClient struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "mqtt.client", Type: "mqtt.client",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
broker, ok := params["broker"] broker, ok := params["broker"]
@@ -62,6 +62,12 @@ func init() {
return nil, errors.New("mqtt.client clientId must be string") return nil, errors.New("mqtt.client clientId must be string")
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("mqtt.client unable to get router from context")
}
return &MQTTClient{config: config, Broker: brokerString, Topic: topicString, ClientID: clientIdString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &MQTTClient{config: config, Broker: brokerString, Topic: topicString, ClientID: clientIdString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -104,7 +110,7 @@ func (mc *MQTTClient) Run() error {
return nil return nil
} }
func (mc *MQTTClient) Output(payload any) error { func (mc *MQTTClient) Output(ctx context.Context, payload any) error {
payloadMessage, ok := payload.(mqtt.Message) payloadMessage, ok := payload.(mqtt.Message)
if !ok { if !ok {

View File

@@ -24,7 +24,7 @@ type NATSClient struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "nats.client", Type: "nats.client",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
url, ok := params["url"] url, ok := params["url"]
@@ -50,6 +50,12 @@ func init() {
return nil, errors.New("nats.client subject must be string") return nil, errors.New("nats.client subject must be string")
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("nats.client unable to get router from context")
}
return &NATSClient{config: config, URL: urlString, Subject: subjectString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &NATSClient{config: config, URL: urlString, Subject: subjectString, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -92,7 +98,7 @@ func (nc *NATSClient) Run() error {
return nil return nil
} }
func (nc *NATSClient) Output(payload any) error { func (nc *NATSClient) Output(ctx context.Context, payload any) error {
payloadMessage, ok := payload.(processor.NATSMessage) payloadMessage, ok := payload.(processor.NATSMessage)

View File

@@ -2,6 +2,7 @@ package module
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
@@ -24,8 +25,12 @@ type PSNClient struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "psn.client", Type: "psn.client",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("psn.client unable to get router from context")
}
return &PSNClient{config: config, decoder: psn.NewDecoder(), ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &PSNClient{config: config, decoder: psn.NewDecoder(), ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -92,6 +97,6 @@ func (pc *PSNClient) Run() error {
} }
} }
func (pc *PSNClient) Output(payload any) error { func (pc *PSNClient) Output(ctx context.Context, payload any) error {
return fmt.Errorf("psn.client output is not implemented") return fmt.Errorf("psn.client output is not implemented")
} }

View File

@@ -29,7 +29,7 @@ type SerialClient struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "serial.client", Type: "serial.client",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
port, ok := params["port"] port, ok := params["port"]
@@ -76,6 +76,12 @@ func init() {
BaudRate: int(baudRateNum), BaudRate: int(baudRateNum),
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("serial.client unable to get router from context")
}
return &SerialClient{config: config, Port: portString, Framer: framer, Mode: &mode, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &SerialClient{config: config, Port: portString, Framer: framer, Mode: &mode, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -162,7 +168,7 @@ func (sc *SerialClient) Run() error {
} }
} }
func (sc *SerialClient) Output(payload any) error { func (sc *SerialClient) Output(ctx context.Context, payload any) error {
payloadBytes, ok := payload.([]byte) payloadBytes, ok := payload.([]byte)

View File

@@ -35,7 +35,7 @@ type SIPCallMessage struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "sip.call.server", Type: "sip.call.server",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
portNum := 5060 portNum := 5060
@@ -87,6 +87,12 @@ func init() {
} }
userAgentString = specificTransportString userAgentString = specificTransportString
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("sip.call.server unable to get router from context")
}
return &SIPCallServer{config: config, ctx: ctx, router: router, IP: ipString, Port: int(portNum), Transport: transportString, UserAgent: userAgentString, logger: CreateLogger(config)}, nil return &SIPCallServer{config: config, ctx: ctx, router: router, IP: ipString, Port: int(portNum), Transport: transportString, UserAgent: userAgentString, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -143,7 +149,7 @@ func (scs *SIPCallServer) HandleCall(inDialog *diago.DialogServerSession) {
<-inDialog.Context().Done() <-inDialog.Context().Done()
} }
func (scs *SIPCallServer) Output(payload any) error { func (scs *SIPCallServer) Output(ctx context.Context, payload any) error {
payloadMsg, ok := payload.(string) payloadMsg, ok := payload.(string)
if !ok { if !ok {

View File

@@ -35,7 +35,7 @@ type SIPDTMFMessage struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "sip.dtmf.server", Type: "sip.dtmf.server",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
portNum := 5060 portNum := 5060
@@ -91,6 +91,11 @@ func init() {
if !strings.ContainsRune("0123456789*#ABCD", rune(separatorString[0])) { if !strings.ContainsRune("0123456789*#ABCD", rune(separatorString[0])) {
return nil, errors.New("sip.dtmf.server separator must be a valid DTMF character") return nil, errors.New("sip.dtmf.server separator must be a valid DTMF character")
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("sip.dtmf.server unable to get router from context")
}
return &SIPDTMFServer{config: config, ctx: ctx, router: router, IP: ipString, Port: int(portNum), Transport: transportString, Separator: separatorString, logger: CreateLogger(config)}, nil return &SIPDTMFServer{config: config, ctx: ctx, router: router, IP: ipString, Port: int(portNum), Transport: transportString, Separator: separatorString, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -159,6 +164,6 @@ func (sds *SIPDTMFServer) HandleCall(inDialog *diago.DialogServerSession) error
}, 5*time.Second) }, 5*time.Second)
} }
func (sds *SIPDTMFServer) Output(payload any) error { func (sds *SIPDTMFServer) Output(ctx context.Context, payload any) error {
return errors.New("sip.dtmf.server output is not implemented") return errors.New("sip.dtmf.server output is not implemented")
} }

View File

@@ -26,7 +26,7 @@ type TCPClient struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "net.tcp.client", Type: "net.tcp.client",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
host, ok := params["host"] host, ok := params["host"]
@@ -75,6 +75,12 @@ func init() {
return nil, fmt.Errorf("net.tcp.client unknown framing method: %s", framingMethod) return nil, fmt.Errorf("net.tcp.client unknown framing method: %s", framingMethod)
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("net.tcp.client unable to get router from context")
}
return &TCPClient{framer: framer, Addr: addr, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &TCPClient{framer: framer, Addr: addr, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -155,7 +161,7 @@ func (tc *TCPClient) SetupConn() error {
return err return err
} }
func (tc *TCPClient) Output(payload any) error { func (tc *TCPClient) Output(ctx context.Context, payload any) error {
// NOTE(jwetzell): not sure how this would occur but // NOTE(jwetzell): not sure how this would occur but
if tc.conn == nil { if tc.conn == nil {
err := tc.SetupConn() err := tc.SetupConn()

View File

@@ -32,7 +32,7 @@ type TCPServer struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "net.tcp.server", Type: "net.tcp.server",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
port, ok := params["port"] port, ok := params["port"]
if !ok { if !ok {
@@ -82,6 +82,12 @@ func init() {
return nil, err return nil, err
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("net.tcp.server unable to get router from context")
}
return &TCPServer{Framer: framer, Addr: addr, config: config, quit: make(chan interface{}), ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &TCPServer{Framer: framer, Addr: addr, config: config, quit: make(chan interface{}), ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -197,7 +203,7 @@ AcceptLoop:
return nil return nil
} }
func (ts *TCPServer) Output(payload any) error { func (ts *TCPServer) Output(ctx context.Context, payload any) error {
payloadBytes, ok := payload.([]byte) payloadBytes, ok := payload.([]byte)
if !ok { if !ok {

View File

@@ -22,7 +22,7 @@ type TimeInterval struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "time.interval", Type: "time.interval",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
duration, ok := params["duration"] duration, ok := params["duration"]
@@ -36,6 +36,12 @@ func init() {
return nil, errors.New("time.interval duration must be number") return nil, errors.New("time.interval duration must be number")
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("time.interval unable to get router from context")
}
return &TimeInterval{Duration: uint32(durationNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &TimeInterval{Duration: uint32(durationNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -68,7 +74,7 @@ func (i *TimeInterval) Run() error {
} }
func (i *TimeInterval) Output(payload any) error { func (i *TimeInterval) Output(ctx context.Context, payload any) error {
i.ticker.Reset(time.Millisecond * time.Duration(i.Duration)) i.ticker.Reset(time.Millisecond * time.Duration(i.Duration))
return nil return nil
} }

View File

@@ -22,7 +22,7 @@ type TimeTimer struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "time.timer", Type: "time.timer",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
duration, ok := params["duration"] duration, ok := params["duration"]
@@ -36,6 +36,12 @@ func init() {
return nil, errors.New("time.timer duration must be a number") return nil, errors.New("time.timer duration must be a number")
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("net.tcp.client unable to get router from context")
}
return &TimeTimer{Duration: uint32(durationNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &TimeTimer{Duration: uint32(durationNum), config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -66,7 +72,7 @@ func (t *TimeTimer) Run() error {
} }
} }
func (t *TimeTimer) Output(payload any) error { func (t *TimeTimer) Output(ctx context.Context, payload any) error {
t.timer.Reset(time.Millisecond * time.Duration(t.Duration)) t.timer.Reset(time.Millisecond * time.Duration(t.Duration))
return nil return nil
} }

View File

@@ -24,7 +24,7 @@ type UDPClient struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "net.udp.client", Type: "net.udp.client",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
host, ok := params["host"] host, ok := params["host"]
@@ -54,6 +54,12 @@ func init() {
return nil, err return nil, err
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("net.udp.client unable to get router from context")
}
return &UDPClient{Addr: addr, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &UDPClient{Addr: addr, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -88,7 +94,7 @@ func (uc *UDPClient) Run() error {
return nil return nil
} }
func (uc *UDPClient) Output(payload any) error { func (uc *UDPClient) Output(ctx context.Context, payload any) error {
payloadBytes, ok := payload.([]byte) payloadBytes, ok := payload.([]byte)
if !ok { if !ok {

View File

@@ -24,7 +24,7 @@ type UDPMulticast struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "net.udp.multicast", Type: "net.udp.multicast",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
ip, ok := params["ip"] ip, ok := params["ip"]
@@ -53,6 +53,12 @@ func init() {
if err != nil { if err != nil {
return nil, err return nil, err
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("net.udp.multicast unable to get router from context")
}
return &UDPMulticast{config: config, Addr: addr, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &UDPMulticast{config: config, Addr: addr, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -108,7 +114,7 @@ func (um *UDPMulticast) Run() error {
} }
} }
func (um *UDPMulticast) Output(payload any) error { func (um *UDPMulticast) Output(ctx context.Context, payload any) error {
payloadBytes, ok := payload.([]byte) payloadBytes, ok := payload.([]byte)
if !ok { if !ok {

View File

@@ -25,7 +25,7 @@ type UDPServer struct {
func init() { func init() {
RegisterModule(ModuleRegistration{ RegisterModule(ModuleRegistration{
Type: "net.udp.server", Type: "net.udp.server",
New: func(ctx context.Context, config config.ModuleConfig, router route.RouteIO) (Module, error) { New: func(ctx context.Context, config config.ModuleConfig) (Module, error) {
params := config.Params params := config.Params
port, ok := params["port"] port, ok := params["port"]
if !ok { if !ok {
@@ -68,6 +68,11 @@ func init() {
bufferSizeNum = int(bufferSizeFloat) bufferSizeNum = int(bufferSizeFloat)
} }
router, ok := ctx.Value(route.RouterContextKey).(route.RouteIO)
if !ok {
return nil, errors.New("net.udp.server unable to get router from context")
}
return &UDPServer{Addr: addr, BufferSize: bufferSizeNum, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil return &UDPServer{Addr: addr, BufferSize: bufferSizeNum, config: config, ctx: ctx, router: router, logger: CreateLogger(config)}, nil
}, },
}) })
@@ -119,6 +124,6 @@ func (us *UDPServer) Run() error {
} }
func (us *UDPServer) Output(payload any) error { func (us *UDPServer) Output(ctx context.Context, payload any) error {
return errors.New("net.udp.server output is not implemented") return errors.New("net.udp.server output is not implemented")
} }

View File

@@ -2,12 +2,18 @@ package route
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/jwetzell/showbridge-go/internal/config" "github.com/jwetzell/showbridge-go/internal/config"
"github.com/jwetzell/showbridge-go/internal/processor" "github.com/jwetzell/showbridge-go/internal/processor"
) )
type routeContextKey string
var RouterContextKey routeContextKey = routeContextKey("router")
var SourceContextKey routeContextKey = routeContextKey("source")
type RouteError struct { type RouteError struct {
Index int Index int
Config config.RouteConfig Config config.RouteConfig
@@ -21,13 +27,13 @@ type RouteIOError struct {
type RouteIO interface { type RouteIO interface {
HandleInput(sourceId string, payload any) []RouteIOError HandleInput(sourceId string, payload any) []RouteIOError
HandleOutput(sourceId string, destinationId string, payload any) error HandleOutput(ctx context.Context, destinationId string, payload any) error
} }
type Route interface { type Route interface {
Input() string Input() string
Output() string Output() string
HandleInput(ctx context.Context, sourceId string, payload any, router RouteIO) error HandleInput(ctx context.Context, payload any) error
} }
type ProcessorRoute struct { type ProcessorRoute struct {
@@ -65,17 +71,24 @@ func (r *ProcessorRoute) Output() string {
return r.output return r.output
} }
func (r *ProcessorRoute) HandleInput(ctx context.Context, sourceId string, payload any, router RouteIO) error { func (r *ProcessorRoute) HandleInput(ctx context.Context, payload any) error {
var err error router, ok := ctx.Value(RouterContextKey).(RouteIO)
if !ok {
return errors.New("unable to get router from context")
}
for _, processor := range r.processors { for _, processor := range r.processors {
payload, err = processor.Process(ctx, payload) processedPayload, err := processor.Process(ctx, payload)
if err != nil { if err != nil {
return err return err
} }
//NOTE(jwetzell) nil payload will result in the route being "terminated" //NOTE(jwetzell) nil payload will result in the route being "terminated"
if payload == nil { if processedPayload == nil {
return nil return nil
} }
payload = processedPayload
} }
return router.HandleOutput(sourceId, r.output, payload)
return router.HandleOutput(ctx, r.output, payload)
} }

View File

@@ -1,6 +1,7 @@
package route_test package route_test
import ( import (
"context"
"testing" "testing"
"github.com/jwetzell/showbridge-go/internal/config" "github.com/jwetzell/showbridge-go/internal/config"
@@ -51,7 +52,7 @@ func TestGoodRouteHandleInput(t *testing.T) {
} }
inputData := "test input data" inputData := "test input data"
err = testRoute.HandleInput(t.Context(), "input", inputData, &MockRouter{}) err = testRoute.HandleInput(context.WithValue(t.Context(), route.RouterContextKey, &MockRouter{}), inputData)
if err != nil { if err != nil {
t.Fatalf("route HandleOutput returned error: %v", err) t.Fatalf("route HandleOutput returned error: %v", err)
} }
@@ -72,7 +73,7 @@ func TestRouteHandleInputWithProcessorError(t *testing.T) {
} }
inputData := "test input data" inputData := "test input data"
err = testRoute.HandleInput(t.Context(), "input", inputData, &MockRouter{}) err = testRoute.HandleInput(context.WithValue(t.Context(), route.RouterContextKey, &MockRouter{}), inputData)
if err == nil { if err == nil {
t.Fatalf("route HandleOutput did not return error for bad processor") t.Fatalf("route HandleOutput did not return error for bad processor")
} }
@@ -91,7 +92,7 @@ func TestRouteHandleNilPayload(t *testing.T) {
return return
} }
err = testRoute.HandleInput(t.Context(), "input", nil, &MockRouter{}) err = testRoute.HandleInput(context.WithValue(t.Context(), route.RouterContextKey, &MockRouter{}), nil)
if err != nil { if err != nil {
t.Fatalf("route HandleOutput returned error for nil payload: %v", err) t.Fatalf("route HandleOutput returned error for nil payload: %v", err)
} }
@@ -111,7 +112,7 @@ func TestRouteHandleNilPayloadFromProcessor(t *testing.T) {
t.Fatalf("route failed to create: %v", err) t.Fatalf("route failed to create: %v", err)
} }
err = testRoute.HandleInput(t.Context(), "input", nil, &MockRouter{}) err = testRoute.HandleInput(context.WithValue(t.Context(), route.RouterContextKey, &MockRouter{}), nil)
if err != nil { if err != nil {
t.Fatalf("route HandleOutput returned error for nil payload: %v", err) t.Fatalf("route HandleOutput returned error for nil payload: %v", err)
} }

View File

@@ -24,14 +24,16 @@ type Router struct {
func NewRouter(ctx context.Context, config config.Config) (*Router, []module.ModuleError, []route.RouteError) { func NewRouter(ctx context.Context, config config.Config) (*Router, []module.ModuleError, []route.RouteError) {
routerContext, cancel := context.WithCancel(ctx) routerContext, cancel := context.WithCancel(ctx)
router := Router{ router := Router{
Context: routerContext,
contextCancel: cancel, contextCancel: cancel,
ModuleInstances: []module.Module{}, ModuleInstances: []module.Module{},
RouteInstances: []route.Route{}, RouteInstances: []route.Route{},
logger: slog.Default().With("component", "router"), logger: slog.Default().With("component", "router"),
} }
router.Context = context.WithValue(routerContext, route.RouterContextKey, &router)
router.logger.Debug("creating") router.logger.Debug("creating")
var moduleErrors []module.ModuleError var moduleErrors []module.ModuleError
@@ -68,7 +70,7 @@ func NewRouter(ctx context.Context, config config.Config) (*Router, []module.Mod
} }
if !moduleInstanceExists { if !moduleInstanceExists {
moduleInstance, err := moduleInfo.New(router.Context, moduleDecl, &router) moduleInstance, err := moduleInfo.New(router.Context, moduleDecl)
if err != nil { if err != nil {
if moduleErrors == nil { if moduleErrors == nil {
moduleErrors = []module.ModuleError{} moduleErrors = []module.ModuleError{}
@@ -131,7 +133,7 @@ func (r *Router) HandleInput(sourceId string, payload any) []route.RouteIOError
var routingErrors []route.RouteIOError var routingErrors []route.RouteIOError
for routeIndex, routeInstance := range r.RouteInstances { for routeIndex, routeInstance := range r.RouteInstances {
if routeInstance.Input() == sourceId { if routeInstance.Input() == sourceId {
err := routeInstance.HandleInput(r.Context, sourceId, payload, r) err := routeInstance.HandleInput(context.WithValue(r.Context, route.SourceContextKey, sourceId), payload)
if err != nil { if err != nil {
if routingErrors == nil { if routingErrors == nil {
routingErrors = []route.RouteIOError{} routingErrors = []route.RouteIOError{}
@@ -147,10 +149,10 @@ func (r *Router) HandleInput(sourceId string, payload any) []route.RouteIOError
return routingErrors return routingErrors
} }
func (r *Router) HandleOutput(sourceId string, destinationId string, payload any) error { func (r *Router) HandleOutput(ctx context.Context, destinationId string, payload any) error {
for _, moduleInstance := range r.ModuleInstances { for _, moduleInstance := range r.ModuleInstances {
if moduleInstance.Id() == destinationId { if moduleInstance.Id() == destinationId {
return moduleInstance.Output(payload) return moduleInstance.Output(ctx, payload)
} }
} }
return fmt.Errorf("router could not find module instance for destination %s", destinationId) return fmt.Errorf("router could not find module instance for destination %s", destinationId)