From d1c13e039a29ccbc085e2d3ca8451f83825e8d32 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 9 Apr 2024 09:53:09 -0600 Subject: [PATCH] fix: return a friendly error if the dialer is closed (#766) This is a port of https://github.com/GoogleCloudPlatform/alloydb-go-connector/pull/538 --- dialer.go | 23 +++++++++++++++++++++-- dialer_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/dialer.go b/dialer.go index 1d7fd909..7f7f0747 100644 --- a/dialer.go +++ b/dialer.go @@ -53,6 +53,9 @@ const ( ) var ( + // ErrDialerClosed is used when a caller invokes Dial after closing the + // Dialer. + ErrDialerClosed = errors.New("cloudsqlconn: dialer is closed") // versionString indicates the version of this library. //go:embed version.txt versionString string @@ -91,8 +94,11 @@ type Dialer struct { instances map[instance.ConnName]connectionInfoCache key *rsa.PrivateKey refreshTimeout time.Duration - sqladmin *sqladmin.Service - logger debug.Logger + // closed reports if the dialer has been closed. + closed chan struct{} + + sqladmin *sqladmin.Service + logger debug.Logger // defaultDialConfig holds the constructor level DialOptions, so that it // can be copied and mutated by the Dial function. @@ -210,6 +216,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { return nil, err } d := &Dialer{ + closed: make(chan struct{}), instances: make(map[instance.ConnName]connectionInfoCache), key: cfg.rsaKey, refreshTimeout: cfg.refreshTimeout, @@ -227,6 +234,11 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { // icn argument must be the instance's connection name, which is in the format // "project-name:region:instance-name". func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn net.Conn, err error) { + select { + case <-d.closed: + return nil, ErrDialerClosed + default: + } startTime := time.Now() var endDial trace.EndSpanFunc ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn.Dial", @@ -420,6 +432,13 @@ func (i *instrumentedConn) Close() error { // needed to connect. Additional dial operations may succeed until the information // expires. func (d *Dialer) Close() error { + // Check if Close has already been called. + select { + case <-d.closed: + return nil + default: + } + close(d.closed) d.lock.Lock() defer d.lock.Unlock() for _, i := range d.instances { diff --git a/dialer_test.go b/dialer_test.go index 8a3d7960..c894eac3 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -795,3 +795,27 @@ func TestDialerSupportsOneOffDialFunction(t *testing.T) { t.Fatal("one-off dial func was not called") } } + +func TestDialerCloseReportsFriendlyError(t *testing.T) { + d, err := NewDialer( + context.Background(), + WithTokenSource(mock.EmptyTokenSource{}), + ) + if err != nil { + t.Fatal(err) + } + _ = d.Close() + + _, err = d.Dial(context.Background(), "p:r:i") + if !errors.Is(err, ErrDialerClosed) { + t.Fatalf("want = %v, got = %v", ErrDialerClosed, err) + } + + // Ensure multiple calls to close don't panic + _ = d.Close() + + _, err = d.Dial(context.Background(), "p:r:i") + if !errors.Is(err, ErrDialerClosed) { + t.Fatalf("want = %v, got = %v", ErrDialerClosed, err) + } +}