Skip to content

Commit

Permalink
fix minor bugs, add close handler
Browse files Browse the repository at this point in the history
This change fixes a few minor bugs, they are:
- Properly wait for WebSocket goroutine to close when a client is
  closed, this was preventing a test from passing in managed-service
  due to the stray goroutine.
- Use the correct timeout for keep alive timeouts. This was causing
  clients initialized without a timeout to constantly timeout
  even though heartbeats were working, due to default timeout not
  being used correctly.

This change also adds a close handler to the server. This allows the
server to get statistics about the duration and bytes transferred
of a connection. This can then be exposed as metrics from the caller.
  • Loading branch information
1lann committed Nov 10, 2023
1 parent 9751ff3 commit 8b3e927
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 6 deletions.
3 changes: 3 additions & 0 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ type Buffer struct {
hasActiveReader bool
flowID uuid.UUID
logger *zap.SugaredLogger

bytesWritten int64
}

// NewBuffer constructs a new buffer with the given size and logger. See Buffer for more information.
Expand Down Expand Up @@ -204,6 +206,7 @@ func (b *Buffer) Write(p []byte) (n int, err error) {
}

b.buffer = append(b.buffer, p[:availableCapacity]...)
b.bytesWritten += availableCapacity

b.cond.Broadcast()

Expand Down
30 changes: 28 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (c *Client) dialWebsocket(ctx context.Context) (net.Conn, error) {
}

lastResp := lastRespPtr.Load()
if time.Since(*lastResp) > 2*c.KeepAlive {
if time.Since(*lastResp) > 2*keepAlive {
logger.Warnf("keep alive timeout, closing websocket connection")
return
}
Expand All @@ -141,6 +141,26 @@ func (c *Client) dialWebsocket(ctx context.Context) (net.Conn, error) {

var errWebsocketDial = errors.New("websocket dial error")

// dialedConn is a wrapper around net.Conn that correctly performs Close signaling for the websocket manager
// in the Dial method.
type dialedConn struct {
net.Conn
closed chan struct{}
}

// Close implements Close in net.Conn.
func (d *dialedConn) Close() error {
err := d.Conn.Close()
if err != nil {
// We return early if the connection fails to close. Although this should never happen, this might leak a
// goroutine, but it's better than something going wrong and this goroutine stalling.
return err
}

<-d.closed
return nil
}

// Dial forms a tunnel to the backend TCP endpoint and returns a net.Conn.
//
// Callers are responsible for closing the returned connection.
Expand All @@ -149,8 +169,14 @@ func (c *Client) Dial(ctx context.Context) (net.Conn, error) {

firstAttempt := true
stream := NewStream(maxBufferSize, minBufferBehindSize, c.Logger)
conn := &dialedConn{
Conn: stream,
closed: make(chan struct{}),
}

go func() {
defer close(conn.closed)

var streamID uuid.UUID

for !stream.IsClosed() {
Expand Down Expand Up @@ -239,5 +265,5 @@ func (c *Client) Dial(ctx context.Context) (net.Conn, error) {
return nil, err
}

return stream, nil
return conn, nil
}
14 changes: 14 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type Server struct {
isClosed bool
done chan struct{}
janitorDone chan struct{}

onClose func(streamID uuid.UUID, startTime time.Time, bytesRead, bytesWritten int64)
}

// ErrorCode represents an error code in a handshake.
Expand Down Expand Up @@ -122,6 +124,12 @@ func NewServer(dst string, timeout time.Duration, logger *zap.SugaredLogger) *Se
return s
}

// OnStreamClose sets a callback handler for when a stream closes. The callback should never block as it is not
// called in a separate goroutine.
func (s *Server) OnStreamClose(f func(streamID uuid.UUID, startTime time.Time, bytesRead, bytesWritten int64)) {
s.onClose = f
}

// Close shuts down the server, closing existing streams and rejects new connections.
func (s *Server) Close() {
s.mu.Lock()
Expand Down Expand Up @@ -269,6 +277,12 @@ func (s *Server) handleHandshake(remoteAddr string, downstream net.Conn, receive
zap.String("remote_ip", remoteAddr),
))

if s.onClose != nil {
stream.OnClose(func(startTime time.Time, bytesRead, bytesWritten int64) {
s.onClose(id, startTime, bytesRead, bytesWritten)
})
}

s.streams[id] = stream
s.mu.Unlock()

Expand Down
16 changes: 14 additions & 2 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ type Stream struct {
cond *sync.Cond
closeOnce sync.Once

logger *zap.SugaredLogger
err error
logger *zap.SugaredLogger
err error
startTime time.Time

onClose func(startTime time.Time, bytesRead, bytesWritten int64)
}

// NewStream creates a new stream.
func NewStream(maxTotal, minBehind int64, logger *zap.SugaredLogger) *Stream {
streamsCount.Inc()
return &Stream{
startTime: time.Now(),
writeBuffer: NewBuffer(maxTotal, minBehind, logger),
wg: new(sync.WaitGroup),
wgMu: new(sync.Mutex),
Expand All @@ -57,6 +61,10 @@ func NewStream(maxTotal, minBehind int64, logger *zap.SugaredLogger) *Stream {
}
}

func (s *Stream) OnClose(f func(startTime time.Time, bytesRead, bytesWritten int64)) {

Check failure on line 64 in stream.go

View workflow job for this annotation

GitHub Actions / Code Quality

exported method Stream.OnClose should have comment or be unexported
s.onClose = f
}

// Flow is an active "instance" of a stream, which represents an unreliable connection such as a WebSocket
// connection.
type Flow struct {
Expand Down Expand Up @@ -162,6 +170,10 @@ func (s *Stream) closeInternal() {
s.cond.Broadcast()
s.writeBuffer.Close()
streamsCount.Dec()

if s.onClose != nil {
s.onClose(s.startTime, s.bytesRead, s.writeBuffer.bytesWritten)
}
})
}

Expand Down
30 changes: 28 additions & 2 deletions tunnel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// limitations under the Licen
package httptun

import (
Expand All @@ -28,6 +27,7 @@ import (
"time"

"github.com/cockroachdb/errors"
"github.com/gofrs/uuid"
"github.com/gorilla/websocket"
"go.uber.org/zap/zaptest"
"golang.org/x/net/nettest"
Expand All @@ -42,8 +42,22 @@ func TestTunnel(t *testing.T) {
// Start the upstream TCP server.
ln := listen(t)

type closeEvent struct {
bytesRead int64
bytesWritten int64
}

closeEvents := make(chan closeEvent, 10)

// Start the tunnel server, which uses ln as its upstream.
srv := NewServer(ln.Addr(), time.Millisecond*200, zaptest.NewLogger(t).Sugar().Named("server"))
srv.OnStreamClose(func(streamID uuid.UUID, startTime time.Time, bytesRead, bytesWritten int64) {
closeEvents <- closeEvent{
bytesRead: bytesRead,
bytesWritten: bytesWritten,
}
})

defer srv.Close()
// Make a test HTTP server with the standard library.
httpServer := httptest.NewServer(srv)
Expand Down Expand Up @@ -86,12 +100,24 @@ func TestTunnel(t *testing.T) {
assertWrite(t, srcTwo, []byte("ping on new conn"))
assertRead(t, []byte("ping on new conn"), dstTwo)

expectCloseEvent := func(bytesRead, bytesWritten int64) {
event := <-closeEvents
if event.bytesRead != bytesRead {
t.Fatalf("want bytes read %d, got %d", bytesRead, event.bytesRead)
}
if event.bytesWritten != bytesWritten {
t.Fatalf("want bytes written %d, got %d", bytesWritten, event.bytesWritten)
}
}

// Close the connection from the upstream side to test the downstream side is closed.
dstTwo.Close()
assertClosed(t, srcTwo)
expectCloseEvent(16, 0)
// Close the connection from the downstream side to test the upstream side is closed.
srcOne.Close()
assertClosed(t, dstOne)
expectCloseEvent(14, 4)

assertEqual(t, 0, ln.UnhandledConns())

Expand Down

0 comments on commit 8b3e927

Please sign in to comment.