diff --git a/api.go b/api.go index 7b408eb..8770895 100644 --- a/api.go +++ b/api.go @@ -1,46 +1,145 @@ package showbridge import ( + "context" "embed" _ "embed" "encoding/json" "fmt" "net/http" + "time" + + "github.com/jwetzell/showbridge-go/internal/config" + "github.com/jwetzell/showbridge-go/internal/module" + "github.com/jwetzell/showbridge-go/internal/route" ) +func (r *Router) startAPIServer(config config.ApiConfig) { + r.logger.Debug("starting API server", "port", config.Port) + mux := http.NewServeMux() + mux.HandleFunc("/ws", r.handleWebsocket) + mux.HandleFunc("/health", r.handleHealthHTTP) + mux.HandleFunc("/schema/{schema}", r.handleSchemaHTTP) + mux.HandleFunc("/api/v1/config", r.handleConfigHTTP) + + r.apiServerMu.Lock() + defer r.apiServerMu.Unlock() + r.apiServer = &http.Server{ + Addr: fmt.Sprintf(":%d", config.Port), + Handler: mux, + } + + go func() { + r.apiServer.ListenAndServe() + r.apiServerShutdown() + }() +} + +func (r *Router) stopAPIServer() { + r.logger.Debug("stopping API server") + r.apiServerMu.Lock() + defer r.apiServerMu.Unlock() + if r.apiServer != nil { + apiShutdownCtx, apiShutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + r.apiServerShutdown = apiShutdownCancel + r.apiServer.Shutdown(apiShutdownCtx) + <-apiShutdownCtx.Done() + r.apiServer = nil + } +} + +func (r *Router) handleHealthHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodGet: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusOK) + case http.MethodOptions: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusOK) + default: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + func (r *Router) handleConfigHTTP(w http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return + + switch req.Method { + case http.MethodGet: + configJSON, err := json.Marshal(r.runningConfig) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "application/json") + w.Write(configJSON) + case http.MethodPut: + var newConfig config.Config + err := json.NewDecoder(req.Body).Decode(&newConfig) + if err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + moduleErrors, routeErrors := r.UpdateConfig(newConfig) + if len(moduleErrors) > 0 || len(routeErrors) > 0 { + errorResponse := struct { + ModuleErrors []module.ModuleError `json:"moduleErrors,omitempty"` + RouteErrors []route.RouteError `json:"routeErrors,omitempty"` + }{ + ModuleErrors: moduleErrors, + RouteErrors: routeErrors, + } + errorResponseJSON, err := json.Marshal(errorResponse) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write(errorResponseJSON) + return + } + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusOK) + case http.MethodOptions: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusOK) + default: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusMethodNotAllowed) } - configJSON, err := json.Marshal(r.runningConfig) - if err != nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Content-Type", "application/json") - w.Write(configJSON) } //go:embed schema var schema embed.FS func (r *Router) handleSchemaHTTP(w http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return + switch req.Method { + case http.MethodGet: + schemaName := req.PathValue("schema") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "application/json") + configSchema, err := schema.ReadFile(fmt.Sprintf("schema/%s.schema.json", schemaName)) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + w.Write(configSchema) + case http.MethodOptions: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusOK) + default: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusMethodNotAllowed) } - - schemaName := req.PathValue("schema") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Content-Type", "application/json") - configSchema, err := schema.ReadFile(fmt.Sprintf("schema/%s.schema.json", schemaName)) - if err != nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - w.Write(configSchema) } diff --git a/cmd/showbridge/main.go b/cmd/showbridge/main.go index f611590..a2f3860 100644 --- a/cmd/showbridge/main.go +++ b/cmd/showbridge/main.go @@ -13,6 +13,8 @@ import ( "github.com/jwetzell/showbridge-go" "github.com/jwetzell/showbridge-go/internal/config" + "github.com/jwetzell/showbridge-go/internal/module" + "github.com/jwetzell/showbridge-go/internal/route" "github.com/urfave/cli/v3" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" @@ -173,7 +175,19 @@ func run(ctx context.Context, c *cli.Command) error { routerRunner: &sync.WaitGroup{}, } - router, err := showbridgeApp.getNewRouter() + config, err := readConfig(showbridgeApp.configPath) + if err != nil { + return err + } + + router, moduleErrors, routeErrors := showbridge.NewRouter(config) + + showbridgeApp.logConfigErrors(moduleErrors, routeErrors) + + if moduleErrors != nil || routeErrors != nil { + return fmt.Errorf("errors initializing modules or routes") + } + if err != nil { return fmt.Errorf("failed to initialize router: %w", err) } @@ -200,18 +214,15 @@ func (app *showbridgeApp) handleHangup() { select { case <-sigHangup: app.logger.Info("received SIGHUP, reloading configuration") - newRouter, err := app.getNewRouter() + app.routerMutex.Lock() + config, err := readConfig(app.configPath) if err != nil { - app.logger.Error("failed to reload configuration", "error", err) + app.logger.Error("failed to read config file", "error", err) + app.routerMutex.Unlock() continue } - app.routerMutex.Lock() - app.router.Stop() - app.routerRunner.Wait() - app.router = newRouter - app.routerRunner.Go(func() { - app.router.Start(context.Background()) - }) + moduleErrors, routeErrors := app.router.UpdateConfig(config) + app.logConfigErrors(moduleErrors, routeErrors) app.logger.Info("configuration reloaded successfully") app.routerMutex.Unlock() case <-app.ctx.Done(): @@ -220,15 +231,7 @@ func (app *showbridgeApp) handleHangup() { } } -func (app *showbridgeApp) getNewRouter() (*showbridge.Router, error) { - // TODO(jwetzell): what should happen when the config file is unchanged? - config, err := readConfig(app.configPath) - if err != nil { - return nil, err - } - - router, moduleErrors, routeErrors := showbridge.NewRouter(config) - +func (app *showbridgeApp) logConfigErrors(moduleErrors []module.ModuleError, routeErrors []route.RouteError) { for _, moduleError := range moduleErrors { app.logger.Error("problem initializing module", "index", moduleError.Index, "error", moduleError.Error) } @@ -236,12 +239,6 @@ func (app *showbridgeApp) getNewRouter() (*showbridge.Router, error) { for _, routeError := range routeErrors { app.logger.Error("problem initializing route", "index", routeError.Index, "error", routeError.Error) } - - if moduleErrors != nil || routeErrors != nil { - return nil, fmt.Errorf("errors initializing modules or routes") - } - - return router, nil } func newTracerProvider(exp sdktrace.SpanExporter) *sdktrace.TracerProvider { diff --git a/events.go b/events.go index 24a497a..0d60517 100644 --- a/events.go +++ b/events.go @@ -16,16 +16,28 @@ func (e Event) toJSON() ([]byte, error) { return json.Marshal(e) } -func (r *Router) handleEvent(event Event) { +func (r *Router) handleEvent(event Event, sender *websocket.Conn) { switch event.Type { case "ping": - r.broadcastEvent(Event{Type: "pong"}) + r.unicastEvent(Event{Type: "pong"}, sender) default: r.logger.Warn("unknown event type", "eventType", event.Type) } } -func (r *Router) broadcastEvent(event Event) { +func (r *Router) unicastEvent(event Event, conn *websocket.Conn) { + eventJSON, err := event.toJSON() + if err != nil { + r.logger.Error("failed to marshal event to JSON", "error", err) + return + } + err = conn.WriteMessage(websocket.TextMessage, eventJSON) + if err != nil { + r.logger.Error("failed to write message to websocket connection", "error", err) + } +} + +func (r *Router) broadcastEvent(event Event, excluded ...*websocket.Conn) { eventJSON, err := event.toJSON() if err != nil { r.logger.Error("failed to marshal event to JSON", "error", err) @@ -34,6 +46,16 @@ func (r *Router) broadcastEvent(event Event) { r.wsConnsMu.Lock() defer r.wsConnsMu.Unlock() for _, conn := range r.wsConns { + exclude := false + for _, excludedConn := range excluded { + if conn == excludedConn { + exclude = true + break + } + } + if exclude { + continue + } err := conn.WriteMessage(websocket.TextMessage, eventJSON) if err != nil { r.logger.Error("failed to write message to websocket connection", "error", err) diff --git a/router.go b/router.go index a7e247c..1716e86 100644 --- a/router.go +++ b/router.go @@ -3,11 +3,10 @@ package showbridge import ( "context" "errors" - "fmt" "log/slog" "net/http" + "reflect" "sync" - "time" "github.com/gorilla/websocket" "github.com/jwetzell/showbridge-go/internal/common" @@ -22,17 +21,21 @@ import ( ) type Router struct { - contextCancel context.CancelFunc - Context context.Context + contextCancel context.CancelFunc + Context context.Context + // TODO(jwetzell): do these need to be guarded against concurrency? ModuleInstances map[string]module.Module // TODO(jwetzell): change to something easier to lookup - RouteInstances []*route.Route - moduleWait sync.WaitGroup - logger *slog.Logger - runningConfig config.Config - wsConns []*websocket.Conn - wsConnsMu sync.Mutex - apiServer *http.Server + RouteInstances []*route.Route + moduleWait sync.WaitGroup + logger *slog.Logger + runningConfig config.Config + runningConfigMu sync.Mutex + wsConns []*websocket.Conn + wsConnsMu sync.Mutex + apiServer *http.Server + apiServerMu sync.Mutex + apiServerShutdown context.CancelFunc } func (r *Router) addModule(moduleDecl config.ModuleConfig) error { @@ -162,32 +165,11 @@ func (r *Router) Start(ctx context.Context) { routerContext, cancel := context.WithCancel(ctx) r.Context = routerContext r.contextCancel = cancel - contextWithRouter := context.WithValue(routerContext, common.RouterContextKey, r) - - for moduleId := range r.ModuleInstances { - // TODO(jwetzell): handle module run errors - err := r.startModule(contextWithRouter, moduleId) - if err != nil { - r.logger.Error("error starting module", "moduleId", moduleId, "error", err) - } - } - apiShutdownCtx, apiShutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) - - go func() { - r.apiServer = &http.Server{ - Addr: fmt.Sprintf(":%d", r.runningConfig.Api.Port), - } - http.HandleFunc("/ws", r.handleWebsocket) - http.HandleFunc("/api/v1/config", r.handleConfigHTTP) - http.HandleFunc("/api/v1/schema/{schema}", r.handleSchemaHTTP) - r.logger.Debug("starting api server", "port", r.runningConfig.Api.Port) - r.apiServer.ListenAndServe() - apiShutdownCancel() - }() + r.startModules() + r.startAPIServer(r.runningConfig.Api) <-r.Context.Done() r.logger.Debug("shutting down api server") - r.apiServer.Shutdown(apiShutdownCtx) - <-apiShutdownCtx.Done() + r.stopAPIServer() r.logger.Debug("waiting for modules to exit") r.moduleWait.Wait() r.logger.Info("done") @@ -199,6 +181,9 @@ func (r *Router) Stop() { } func (r *Router) HandleInput(ctx context.Context, sourceId string, payload any) (bool, []common.RouteIOError) { + r.runningConfigMu.Lock() + defer r.runningConfigMu.Unlock() + spanCtx, span := otel.Tracer("router").Start(ctx, "input", trace.WithAttributes(attribute.String("source.id", sourceId)), trace.WithNewRoot()) defer span.End() var routeIOErrors []common.RouteIOError @@ -298,6 +283,83 @@ func (r *Router) HandleOutput(ctx context.Context, destinationId string, payload return nil } +func (r *Router) startModules() { + contextWithRouter := context.WithValue(r.Context, common.RouterContextKey, r) + + for moduleId := range r.ModuleInstances { + // TODO(jwetzell): handle module run errors + err := r.startModule(contextWithRouter, moduleId) + if err != nil { + r.logger.Error("error starting module", "moduleId", moduleId, "error", err) + } + } +} + func (r *Router) RunningConfig() config.Config { + r.runningConfigMu.Lock() + defer r.runningConfigMu.Unlock() return r.runningConfig } + +func (r *Router) UpdateConfig(newConfig config.Config) ([]module.ModuleError, []route.RouteError) { + r.runningConfigMu.Lock() + defer r.runningConfigMu.Unlock() + oldConfig := r.runningConfig + r.logger.Debug("received config update", "oldConfig", oldConfig, "newConfig", newConfig) + + if !reflect.DeepEqual(oldConfig.Api, newConfig.Api) { + r.logger.Info("applying new API config") + r.stopAPIServer() + r.startAPIServer(newConfig.Api) + r.runningConfig.Api = newConfig.Api + } + + // TODO(jwetzell): handle config update errors better + for _, moduleInstance := range r.ModuleInstances { + moduleInstance.Stop() + } + r.logger.Debug("waiting for modules to exit") + r.moduleWait.Wait() + + r.ModuleInstances = make(map[string]module.Module) + r.RouteInstances = []*route.Route{} + + var moduleErrors []module.ModuleError + + for moduleIndex, moduleDecl := range newConfig.Modules { + + err := r.addModule(moduleDecl) + if err != nil { + if moduleErrors == nil { + moduleErrors = []module.ModuleError{} + } + moduleErrors = append(moduleErrors, module.ModuleError{ + Index: moduleIndex, + Config: moduleDecl, + Error: err.Error(), + }) + continue + } + + } + + var routeErrors []route.RouteError + for routeIndex, routeDecl := range newConfig.Routes { + err := r.addRoute(routeDecl) + if err != nil { + if routeErrors == nil { + routeErrors = []route.RouteError{} + } + routeErrors = append(routeErrors, route.RouteError{ + Index: routeIndex, + Config: routeDecl, + Error: err.Error(), + }) + continue + } + } + r.runningConfig = newConfig + r.startModules() + + return moduleErrors, routeErrors +} diff --git a/websocket.go b/websocket.go index 215f932..c1520f9 100644 --- a/websocket.go +++ b/websocket.go @@ -24,21 +24,39 @@ func (r *Router) handleWebsocket(w http.ResponseWriter, req *http.Request) { r.wsConnsMu.Lock() r.wsConns = append(r.wsConns, conn) r.wsConnsMu.Unlock() +READ_LOOP: for { - _, message, err := conn.ReadMessage() + messageType, message, err := conn.ReadMessage() if err != nil { - r.logger.Error("websocket read error", "error", err) - break + _, ok := err.(*websocket.CloseError) + if ok { + break READ_LOOP + } } - event := Event{} - err = json.Unmarshal(message, &event) - if err != nil { - r.logger.Error("websocket message unmarshal error", "error", err) + switch messageType { + case websocket.TextMessage, websocket.BinaryMessage: + event := Event{} + err = json.Unmarshal(message, &event) + if err != nil { + r.logger.Error("websocket message unmarshal error", "error", err) + continue + } + r.handleEvent(event, conn) + case websocket.CloseMessage: + break READ_LOOP + case websocket.PingMessage: + err = conn.WriteMessage(websocket.PongMessage, nil) + if err != nil { + r.logger.Error("websocket pong error", "error", err) + } + default: + r.logger.Warn("unsupported websocket message type", "type", messageType) continue } - r.handleEvent(event) + } + //NOTE(jwetzell): remove ws connection r.wsConnsMu.Lock() for i, c := range r.wsConns { if c == conn {