diff --git a/api.go b/api.go new file mode 100644 index 0000000..a90a520 --- /dev/null +++ b/api.go @@ -0,0 +1,146 @@ +package showbridge + +import ( + "context" + "embed" + _ "embed" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/jwetzell/showbridge-go/internal/config" + "github.com/jwetzell/showbridge-go/internal/module" + "github.com/jwetzell/showbridge-go/internal/route" +) + +func (r *Router) startAPIServer(config config.ApiConfig) { + r.logger.Debug("starting API server", "port", config.Port) + mux := http.NewServeMux() + mux.HandleFunc("/ws", r.handleWebsocket) + mux.HandleFunc("/health", r.handleHealthHTTP) + mux.HandleFunc("/schema/{schema}", r.handleSchemaHTTP) + mux.HandleFunc("/api/v1/config", r.handleConfigHTTP) + + r.apiServerMu.Lock() + defer r.apiServerMu.Unlock() + r.apiServer = &http.Server{ + Addr: fmt.Sprintf(":%d", config.Port), + Handler: mux, + } + + go func() { + r.apiServer.ListenAndServe() + r.apiServerShutdown() + }() +} + +func (r *Router) stopAPIServer() { + r.logger.Debug("stopping API server") + r.apiServerMu.Lock() + defer r.apiServerMu.Unlock() + if r.apiServer != nil { + apiShutdownCtx, apiShutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + r.apiServerShutdown = apiShutdownCancel + r.apiServer.Shutdown(apiShutdownCtx) + <-apiShutdownCtx.Done() + r.apiServer = nil + } +} + +func (r *Router) handleHealthHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodGet: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusOK) + case http.MethodOptions: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusOK) + default: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (r *Router) handleConfigHTTP(w http.ResponseWriter, req *http.Request) { + + switch req.Method { + case http.MethodGet: + configJSON, err := json.Marshal(r.runningConfig) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "application/json") + w.Write(configJSON) + case http.MethodPut: + var newConfig config.Config + err := json.NewDecoder(req.Body).Decode(&newConfig) + if err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + moduleErrors, routeErrors := r.UpdateConfig(newConfig) + if len(moduleErrors) > 0 || len(routeErrors) > 0 { + errorResponse := struct { + ModuleErrors []module.ModuleError `json:"moduleErrors,omitempty"` + RouteErrors []route.RouteError `json:"routeErrors,omitempty"` + }{ + ModuleErrors: moduleErrors, + RouteErrors: routeErrors, + } + errorResponseJSON, err := json.Marshal(errorResponse) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write(errorResponseJSON) + return + } + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusOK) + r.ConfigChange <- newConfig + case http.MethodOptions: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusOK) + default: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusMethodNotAllowed) + } + +} + +//go:embed schema +var schema embed.FS + +func (r *Router) handleSchemaHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodGet: + schemaName := req.PathValue("schema") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "application/json") + configSchema, err := schema.ReadFile(fmt.Sprintf("schema/%s.schema.json", schemaName)) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + w.Write(configSchema) + case http.MethodOptions: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusOK) + default: + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusMethodNotAllowed) + } +} diff --git a/cmd/showbridge/main.go b/cmd/showbridge/main.go index f611590..0091644 100644 --- a/cmd/showbridge/main.go +++ b/cmd/showbridge/main.go @@ -13,6 +13,8 @@ import ( "github.com/jwetzell/showbridge-go" "github.com/jwetzell/showbridge-go/internal/config" + "github.com/jwetzell/showbridge-go/internal/module" + "github.com/jwetzell/showbridge-go/internal/route" "github.com/urfave/cli/v3" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" @@ -109,6 +111,20 @@ func readConfig(configPath string) (config.Config, error) { return cfg, nil } +func writeConfig(configPath string, newConfig config.Config) error { + configBytes, err := yaml.Marshal(newConfig) + if err != nil { + return err + } + + err = os.WriteFile(configPath, configBytes, 0644) + if err != nil { + return err + } + + return nil +} + func run(ctx context.Context, c *cli.Command) error { configPath := c.String("config") if configPath == "" { @@ -173,7 +189,19 @@ func run(ctx context.Context, c *cli.Command) error { routerRunner: &sync.WaitGroup{}, } - router, err := showbridgeApp.getNewRouter() + config, err := readConfig(showbridgeApp.configPath) + if err != nil { + return err + } + + router, moduleErrors, routeErrors := showbridge.NewRouter(config) + + showbridgeApp.logConfigErrors(moduleErrors, routeErrors) + + if moduleErrors != nil || routeErrors != nil { + return fmt.Errorf("errors initializing modules or routes") + } + if err != nil { return fmt.Errorf("failed to initialize router: %w", err) } @@ -185,7 +213,7 @@ func run(ctx context.Context, c *cli.Command) error { }) showbridgeApp.routerMutex.Unlock() - go showbridgeApp.handleHangup() + go showbridgeApp.handleChannels() <-showbridgeApp.ctx.Done() showbridgeApp.logger.Debug("shutting down router") @@ -195,40 +223,37 @@ func run(ctx context.Context, c *cli.Command) error { return nil } -func (app *showbridgeApp) handleHangup() { +func (app *showbridgeApp) handleChannels() { for { select { case <-sigHangup: app.logger.Info("received SIGHUP, reloading configuration") - newRouter, err := app.getNewRouter() + app.routerMutex.Lock() + config, err := readConfig(app.configPath) if err != nil { - app.logger.Error("failed to reload configuration", "error", err) + app.logger.Error("failed to read config file", "error", err) + app.routerMutex.Unlock() continue } - app.routerMutex.Lock() - app.router.Stop() - app.routerRunner.Wait() - app.router = newRouter - app.routerRunner.Go(func() { - app.router.Start(context.Background()) - }) + moduleErrors, routeErrors := app.router.UpdateConfig(config) + app.logConfigErrors(moduleErrors, routeErrors) app.logger.Info("configuration reloaded successfully") app.routerMutex.Unlock() + case config := <-app.router.ConfigChange: + app.logger.Info("router config changed updating config file") + err := writeConfig(app.configPath, config) + if err != nil { + app.logger.Error("failed to write config file", "error", err) + continue + } + app.logger.Info("config file updated successfully") case <-app.ctx.Done(): return } } } -func (app *showbridgeApp) getNewRouter() (*showbridge.Router, error) { - // TODO(jwetzell): what should happen when the config file is unchanged? - config, err := readConfig(app.configPath) - if err != nil { - return nil, err - } - - router, moduleErrors, routeErrors := showbridge.NewRouter(config) - +func (app *showbridgeApp) logConfigErrors(moduleErrors []module.ModuleError, routeErrors []route.RouteError) { for _, moduleError := range moduleErrors { app.logger.Error("problem initializing module", "index", moduleError.Index, "error", moduleError.Error) } @@ -236,12 +261,6 @@ func (app *showbridgeApp) getNewRouter() (*showbridge.Router, error) { for _, routeError := range routeErrors { app.logger.Error("problem initializing route", "index", routeError.Index, "error", routeError.Error) } - - if moduleErrors != nil || routeErrors != nil { - return nil, fmt.Errorf("errors initializing modules or routes") - } - - return router, nil } func newTracerProvider(exp sdktrace.SpanExporter) *sdktrace.TracerProvider { diff --git a/config.yaml b/config.yaml index 52856e0..71a0138 100644 --- a/config.yaml +++ b/config.yaml @@ -1,3 +1,5 @@ +api: + port: 8080 modules: - id: http type: http.server diff --git a/events.go b/events.go new file mode 100644 index 0000000..0d60517 --- /dev/null +++ b/events.go @@ -0,0 +1,64 @@ +package showbridge + +import ( + "encoding/json" + + "github.com/gorilla/websocket" +) + +type Event struct { + Type string `json:"type"` + Data any `json:"data,omitempty"` + Error string `json:"error,omitempty"` +} + +func (e Event) toJSON() ([]byte, error) { + return json.Marshal(e) +} + +func (r *Router) handleEvent(event Event, sender *websocket.Conn) { + switch event.Type { + case "ping": + r.unicastEvent(Event{Type: "pong"}, sender) + default: + r.logger.Warn("unknown event type", "eventType", event.Type) + } +} + +func (r *Router) unicastEvent(event Event, conn *websocket.Conn) { + eventJSON, err := event.toJSON() + if err != nil { + r.logger.Error("failed to marshal event to JSON", "error", err) + return + } + err = conn.WriteMessage(websocket.TextMessage, eventJSON) + if err != nil { + r.logger.Error("failed to write message to websocket connection", "error", err) + } +} + +func (r *Router) broadcastEvent(event Event, excluded ...*websocket.Conn) { + eventJSON, err := event.toJSON() + if err != nil { + r.logger.Error("failed to marshal event to JSON", "error", err) + return + } + r.wsConnsMu.Lock() + defer r.wsConnsMu.Unlock() + for _, conn := range r.wsConns { + exclude := false + for _, excludedConn := range excluded { + if conn == excludedConn { + exclude = true + break + } + } + if exclude { + continue + } + err := conn.WriteMessage(websocket.TextMessage, eventJSON) + if err != nil { + r.logger.Error("failed to write message to websocket connection", "error", err) + } + } +} diff --git a/go.mod b/go.mod index d785765..370fb76 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/emiago/sipgo v1.2.1 github.com/expr-lang/expr v1.17.8 github.com/extism/go-sdk v1.7.1 + github.com/gorilla/websocket v1.5.3 github.com/jwetzell/artnet-go v0.2.1 github.com/jwetzell/free-d-go v0.1.0 github.com/jwetzell/osc-go v0.2.0 @@ -42,7 +43,6 @@ require ( github.com/gobwas/ws v1.4.0 // indirect github.com/google/go-tpm v0.9.8 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/ianlancetaylor/demangle v0.0.0-20240805132620-81f5be970eca // indirect github.com/icholy/digest v1.1.0 // indirect diff --git a/internal/common/routing.go b/internal/common/routing.go index 70bd17e..555350e 100644 --- a/internal/common/routing.go +++ b/internal/common/routing.go @@ -8,8 +8,8 @@ type RouteIO interface { } type RouteIOError struct { - Index int - OutputError error - ProcessError error - InputError error + Index int `json:"index"` + OutputError error `json:"outputError"` + ProcessError error `json:"processError"` + InputError error `json:"inputError"` } diff --git a/internal/config/config.go b/internal/config/config.go index fc1b5eb..6878187 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,14 +1,18 @@ package config type Config struct { + Api ApiConfig `json:"api"` Modules []ModuleConfig `json:"modules"` Routes []RouteConfig `json:"routes"` } +type ApiConfig struct { + Port int `json:"port"` +} type ModuleConfig struct { Id string `json:"id"` Type string `json:"type"` - Params Params `json:"params"` + Params Params `json:"params,omitempty"` } type RouteConfig struct { @@ -18,5 +22,5 @@ type RouteConfig struct { type ProcessorConfig struct { Type string `json:"type"` - Params Params `json:"params"` + Params Params `json:"params,omitempty"` } diff --git a/internal/module/module.go b/internal/module/module.go index 4a82f59..730c84f 100644 --- a/internal/module/module.go +++ b/internal/module/module.go @@ -10,9 +10,9 @@ import ( ) type ModuleError struct { - Index int - Config config.ModuleConfig - Error error + Index int `json:"index"` + Config config.ModuleConfig `json:"config"` + Error string `json:"error"` } type Module interface { diff --git a/internal/route/route.go b/internal/route/route.go index 351448d..cf6be43 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -13,9 +13,9 @@ import ( ) type RouteError struct { - Index int - Config config.RouteConfig - Error error + Index int `json:"index"` + Config config.RouteConfig `json:"config"` + Error string `json:"error"` } type Route struct { input string diff --git a/router.go b/router.go index c4d28f6..5be194a 100644 --- a/router.go +++ b/router.go @@ -4,8 +4,11 @@ import ( "context" "errors" "log/slog" + "net/http" + "reflect" "sync" + "github.com/gorilla/websocket" "github.com/jwetzell/showbridge-go/internal/common" "github.com/jwetzell/showbridge-go/internal/config" "github.com/jwetzell/showbridge-go/internal/module" @@ -18,14 +21,22 @@ import ( ) type Router struct { - contextCancel context.CancelFunc - Context context.Context + contextCancel context.CancelFunc + Context context.Context + // TODO(jwetzell): do these need to be guarded against concurrency? ModuleInstances map[string]module.Module // TODO(jwetzell): change to something easier to lookup - RouteInstances []*route.Route - moduleWait sync.WaitGroup - logger *slog.Logger - runningConfig config.Config + RouteInstances []*route.Route + ConfigChange chan config.Config + moduleWait sync.WaitGroup + logger *slog.Logger + runningConfig config.Config + runningConfigMu sync.Mutex + wsConns []*websocket.Conn + wsConnsMu sync.Mutex + apiServer *http.Server + apiServerMu sync.Mutex + apiServerShutdown context.CancelFunc } func (r *Router) addModule(moduleDecl config.ModuleConfig) error { @@ -102,19 +113,20 @@ func (r *Router) getModule(moduleId string) module.Module { return moduleInstance } -func NewRouter(config config.Config) (*Router, []module.ModuleError, []route.RouteError) { +func NewRouter(routerConfig config.Config) (*Router, []module.ModuleError, []route.RouteError) { router := Router{ ModuleInstances: make(map[string]module.Module), RouteInstances: []*route.Route{}, + ConfigChange: make(chan config.Config, 1), logger: slog.Default().With("component", "router"), - runningConfig: config, + runningConfig: routerConfig, } router.logger.Debug("creating") var moduleErrors []module.ModuleError - for moduleIndex, moduleDecl := range config.Modules { + for moduleIndex, moduleDecl := range routerConfig.Modules { err := router.addModule(moduleDecl) if err != nil { @@ -124,7 +136,7 @@ func NewRouter(config config.Config) (*Router, []module.ModuleError, []route.Rou moduleErrors = append(moduleErrors, module.ModuleError{ Index: moduleIndex, Config: moduleDecl, - Error: err, + Error: err.Error(), }) continue } @@ -132,7 +144,7 @@ func NewRouter(config config.Config) (*Router, []module.ModuleError, []route.Rou } var routeErrors []route.RouteError - for routeIndex, routeDecl := range config.Routes { + for routeIndex, routeDecl := range routerConfig.Routes { err := router.addRoute(routeDecl) if err != nil { if routeErrors == nil { @@ -141,7 +153,7 @@ func NewRouter(config config.Config) (*Router, []module.ModuleError, []route.Rou routeErrors = append(routeErrors, route.RouteError{ Index: routeIndex, Config: routeDecl, - Error: err, + Error: err.Error(), }) continue } @@ -155,16 +167,11 @@ func (r *Router) Start(ctx context.Context) { routerContext, cancel := context.WithCancel(ctx) r.Context = routerContext r.contextCancel = cancel - contextWithRouter := context.WithValue(routerContext, common.RouterContextKey, r) - - for moduleId := range r.ModuleInstances { - // TODO(jwetzell): handle module run errors - err := r.startModule(contextWithRouter, moduleId) - if err != nil { - r.logger.Error("error starting module", "moduleId", moduleId, "error", err) - } - } + r.startModules() + r.startAPIServer(r.runningConfig.Api) <-r.Context.Done() + r.logger.Debug("shutting down api server") + r.stopAPIServer() r.logger.Debug("waiting for modules to exit") r.moduleWait.Wait() r.logger.Info("done") @@ -176,11 +183,21 @@ func (r *Router) Stop() { } func (r *Router) HandleInput(ctx context.Context, sourceId string, payload any) (bool, []common.RouteIOError) { + r.runningConfigMu.Lock() + defer r.runningConfigMu.Unlock() + spanCtx, span := otel.Tracer("router").Start(ctx, "input", trace.WithAttributes(attribute.String("source.id", sourceId)), trace.WithNewRoot()) defer span.End() var routeIOErrors []common.RouteIOError routeFound := false + r.broadcastEvent(Event{ + Type: "input", + Data: map[string]any{ + "source": sourceId, + }, + }) + var routeWaitGroup sync.WaitGroup for routeIndex, routeInstance := range r.RouteInstances { @@ -207,8 +224,21 @@ func (r *Router) HandleInput(ctx context.Context, sourceId string, payload any) Index: routeIndex, ProcessError: err, }) + r.broadcastEvent(Event{ + Type: "route", + Data: map[string]any{ + "index": routeIndex, + }, + Error: err.Error(), + }) return } + r.broadcastEvent(Event{ + Type: "route", + Data: map[string]any{ + "index": routeIndex, + }, + }) routeSpan.End() }) } @@ -220,7 +250,12 @@ func (r *Router) HandleInput(ctx context.Context, sourceId string, payload any) func (r *Router) HandleOutput(ctx context.Context, destinationId string, payload any) error { spanCtx, span := otel.Tracer("router").Start(ctx, "output", trace.WithAttributes(attribute.String("destination.id", destinationId))) defer span.End() - + outputEvent := Event{ + Type: "output", + Data: map[string]any{ + "destination": destinationId, + }, + } destinationModule := r.getModule(destinationId) if destinationModule == nil { @@ -228,6 +263,8 @@ func (r *Router) HandleOutput(ctx context.Context, destinationId string, payload span.SetStatus(codes.Error, err.Error()) span.RecordError(err) r.logger.Error("no module found for destination id", "destinationId", destinationId) + outputEvent.Error = err.Error() + r.broadcastEvent(outputEvent) return err } @@ -238,14 +275,93 @@ func (r *Router) HandleOutput(ctx context.Context, destinationId string, payload moduleOutputSpan.SetStatus(codes.Error, err.Error()) moduleOutputSpan.RecordError(err) r.logger.ErrorContext(moduleOutputCtx, "module output encountered error", "module", destinationModule.Id(), "error", err) + outputEvent.Error = err.Error() + r.broadcastEvent(outputEvent) return err } else { moduleOutputSpan.SetStatus(codes.Ok, "module output successful") } - + r.broadcastEvent(outputEvent) return nil } +func (r *Router) startModules() { + contextWithRouter := context.WithValue(r.Context, common.RouterContextKey, r) + + for moduleId := range r.ModuleInstances { + // TODO(jwetzell): handle module run errors + err := r.startModule(contextWithRouter, moduleId) + if err != nil { + r.logger.Error("error starting module", "moduleId", moduleId, "error", err) + } + } +} + func (r *Router) RunningConfig() config.Config { + r.runningConfigMu.Lock() + defer r.runningConfigMu.Unlock() return r.runningConfig } + +func (r *Router) UpdateConfig(newConfig config.Config) ([]module.ModuleError, []route.RouteError) { + r.runningConfigMu.Lock() + defer r.runningConfigMu.Unlock() + oldConfig := r.runningConfig + r.logger.Debug("received config update", "oldConfig", oldConfig, "newConfig", newConfig) + + if !reflect.DeepEqual(oldConfig.Api, newConfig.Api) { + r.logger.Info("applying new API config") + r.stopAPIServer() + r.startAPIServer(newConfig.Api) + r.runningConfig.Api = newConfig.Api + } + + // TODO(jwetzell): handle config update errors better + for _, moduleInstance := range r.ModuleInstances { + moduleInstance.Stop() + } + r.logger.Debug("waiting for modules to exit") + r.moduleWait.Wait() + + r.ModuleInstances = make(map[string]module.Module) + r.RouteInstances = []*route.Route{} + + var moduleErrors []module.ModuleError + + for moduleIndex, moduleDecl := range newConfig.Modules { + + err := r.addModule(moduleDecl) + if err != nil { + if moduleErrors == nil { + moduleErrors = []module.ModuleError{} + } + moduleErrors = append(moduleErrors, module.ModuleError{ + Index: moduleIndex, + Config: moduleDecl, + Error: err.Error(), + }) + continue + } + + } + + var routeErrors []route.RouteError + for routeIndex, routeDecl := range newConfig.Routes { + err := r.addRoute(routeDecl) + if err != nil { + if routeErrors == nil { + routeErrors = []route.RouteError{} + } + routeErrors = append(routeErrors, route.RouteError{ + Index: routeIndex, + Config: routeDecl, + Error: err.Error(), + }) + continue + } + } + r.runningConfig = newConfig + r.startModules() + + return moduleErrors, routeErrors +} diff --git a/router_test.go b/router_test.go index 441e03c..2359245 100644 --- a/router_test.go +++ b/router_test.go @@ -149,8 +149,8 @@ func TestNewRouterDuplicateModuleId(t *testing.T) { t.Fatalf("router should have returned exactly 1 module error, got: %d", len(moduleErrors)) } - if moduleErrors[0].Error.Error() != "module id already exists" { - t.Fatalf("module error did not match expected, got: %s", moduleErrors[0].Error.Error()) + if moduleErrors[0].Error != "module id already exists" { + t.Fatalf("module error did not match expected, got: %s", moduleErrors[0].Error) } } @@ -184,8 +184,8 @@ func TestNewRouterRouteWithUnknwonProcessor(t *testing.T) { t.Fatalf("router should have returned exactly 1 route error, got: %d", len(routeErrors)) } - if routeErrors[0].Error.Error() != "problem loading processor registration for processor type: asdfasdf" { - t.Fatalf("route error did not match expected, got: %s", routeErrors[0].Error.Error()) + if routeErrors[0].Error != "problem loading processor registration for processor type: asdfasdf" { + t.Fatalf("route error did not match expected, got: %s", routeErrors[0].Error) } } diff --git a/schema/config.schema.json b/schema/config.schema.json index bfc34c3..7046004 100644 --- a/schema/config.schema.json +++ b/schema/config.schema.json @@ -5,6 +5,16 @@ "description": "showbridge configuration", "type": "object", "properties": { + "api": { + "type": "object", + "properties": { + "port": { + "type": "integer", + "description": "Port for the API server to listen on" + } + }, + "required": ["port"] + }, "modules": { "$ref": "https://showbridge.io/modules.schema.json" }, diff --git a/websocket.go b/websocket.go new file mode 100644 index 0000000..c1520f9 --- /dev/null +++ b/websocket.go @@ -0,0 +1,68 @@ +package showbridge + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +func (r *Router) handleWebsocket(w http.ResponseWriter, req *http.Request) { + conn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + r.logger.Error("websocket upgrade error", "error", err) + return + } + defer conn.Close() + + r.wsConnsMu.Lock() + r.wsConns = append(r.wsConns, conn) + r.wsConnsMu.Unlock() +READ_LOOP: + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + _, ok := err.(*websocket.CloseError) + if ok { + break READ_LOOP + } + } + + switch messageType { + case websocket.TextMessage, websocket.BinaryMessage: + event := Event{} + err = json.Unmarshal(message, &event) + if err != nil { + r.logger.Error("websocket message unmarshal error", "error", err) + continue + } + r.handleEvent(event, conn) + case websocket.CloseMessage: + break READ_LOOP + case websocket.PingMessage: + err = conn.WriteMessage(websocket.PongMessage, nil) + if err != nil { + r.logger.Error("websocket pong error", "error", err) + } + default: + r.logger.Warn("unsupported websocket message type", "type", messageType) + continue + } + + } + //NOTE(jwetzell): remove ws connection + r.wsConnsMu.Lock() + for i, c := range r.wsConns { + if c == conn { + r.wsConns = append(r.wsConns[:i], r.wsConns[i+1:]...) + break + } + } + r.wsConnsMu.Unlock() +}