diff --git a/tcp-server.go b/tcp-server.go index 78591d4..89e0a4a 100644 --- a/tcp-server.go +++ b/tcp-server.go @@ -1,10 +1,13 @@ package showbridge import ( + "errors" "fmt" "log/slog" "net" + "slices" "sync" + "syscall" "time" "github.com/jwetzell/showbridge-go/internal/framing" @@ -18,6 +21,8 @@ type TCPServer struct { router *Router quit chan interface{} wg sync.WaitGroup + connections []net.Conn + connectionsMu sync.RWMutex } func init() { @@ -78,7 +83,10 @@ func (ts *TCPServer) RegisterRouter(router *Router) { } func (ts *TCPServer) handleClient(client net.Conn) { - slog.Debug("connection accepted", "id", ts.config.Id, "remoteAddr", client.RemoteAddr().String()) + ts.connectionsMu.Lock() + ts.connections = append(ts.connections, client) + ts.connectionsMu.Unlock() + slog.Debug("net.tcp.server connection accepted", "id", ts.config.Id, "remoteAddr", client.RemoteAddr().String()) defer client.Close() var framer framing.Framer @@ -104,12 +112,34 @@ ClientRead: byteCount, err := client.Read(buffer) if err != nil { - //NOTE(jwetzell) we hit deadline - if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { - continue ClientRead + if opErr, ok := err.(*net.OpError); ok { + //NOTE(jwetzell) we hit deadline + if opErr.Timeout() { + continue ClientRead + } + if errors.Is(opErr, syscall.ECONNRESET) { + ts.connectionsMu.Lock() + for i := 0; i < len(ts.connections); i++ { + if ts.connections[i] == client { + ts.connections = slices.Delete(ts.connections, i, i+1) + break + } + } + slog.Debug("net.tcp.server connection reset", "id", ts.config.Id, "remoteAddr", client.RemoteAddr().String()) + ts.connectionsMu.Unlock() + } } + if err.Error() == "EOF" { - slog.Debug("connection closed", "id", ts.config.Id, "remoteAddr", client.RemoteAddr().String()) + ts.connectionsMu.Lock() + for i := 0; i < len(ts.connections); i++ { + if ts.connections[i] == client { + ts.connections = slices.Delete(ts.connections, i, i+1) + break + } + } + slog.Debug("net.tcp.server stream ended", "id", ts.config.Id, "remoteAddr", client.RemoteAddr().String()) + ts.connectionsMu.Unlock() } return } @@ -168,5 +198,24 @@ AcceptLoop: } func (ts *TCPServer) Output(payload any) error { - return fmt.Errorf("net.tcp.server output is not implemented") + payloadBytes, ok := payload.([]byte) + + if !ok { + return fmt.Errorf("net.tcp.server is only able to output bytes") + } + ts.connectionsMu.Lock() + errorString := "" + + for _, connection := range ts.connections { + _, err := connection.Write(payloadBytes) + if err != nil { + errorString += fmt.Sprintf("%s\n", err.Error()) + } + } + ts.connectionsMu.Unlock() + + if errorString == "" { + return nil + } + return fmt.Errorf("%s", errorString) }