Skip to content

Commit

Permalink
change cache key to be either domain name or instance name
Browse files Browse the repository at this point in the history
  • Loading branch information
hessjcg committed Sep 19, 2024
1 parent 068a5f5 commit 55ada68
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 119 deletions.
89 changes: 37 additions & 52 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,19 @@ type connectionInfoCache interface {
io.Closer
}

// cacheKey is the comparable key type of the Dialer's cache map.
// createKey populates either domainName on its own, or the project, region
// and name fields together.
type cacheKey struct {
	domainName            string
	project, region, name string
}

// A Dialer is used to create connections to Cloud SQL instances.
//
// Use NewDialer to initialize a Dialer.
type Dialer struct {
lock sync.RWMutex
cache map[instance.ConnName]*monitoredCache
cache map[cacheKey]*monitoredCache
keyGenerator *keyGenerator
refreshTimeout time.Duration
// closed reports if the dialer has been closed.
Expand Down Expand Up @@ -258,7 +265,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {

d := &Dialer{
closed: make(chan struct{}),
cache: make(map[instance.ConnName]*monitoredCache),
cache: make(map[cacheKey]*monitoredCache),
lazyRefresh: cfg.lazyRefresh,
keyGenerator: g,
refreshTimeout: cfg.refreshTimeout,
Expand Down Expand Up @@ -295,6 +302,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
endDial(err)
}()
cn, err := d.resolver.Resolve(ctx, icn)
//TODO does this get logged?
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -389,7 +397,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn

// If this connection was opened using a Domain Name, then store it for later
// in case it needs to be forcibly closed.
if cn.DomainName() != "" {
if cn.HasDomainName() {
c.mu.Lock()
c.openConns = append(c.openConns, iConn)
c.mu.Unlock()
Expand All @@ -412,7 +420,7 @@ func (d *Dialer) removeCached(
d.lock.Lock()
defer d.lock.Unlock()
c.Close()
delete(d.cache, i)
delete(d.cache, createKey(i))
}

// validClientCert checks that the ephemeral client certificate retrieved from
Expand Down Expand Up @@ -564,21 +572,34 @@ func (d *Dialer) Close() error {
return nil
}

// createKey derives the dialer cache key for an instance.ConnName.
//
// A ConnName identifies a connection by project:region:instance plus an
// optional domain name. The dialer cache identifies an entry by exactly one
// of those: the domain name when cn has one, otherwise the project, region
// and instance name. The two are never combined in a single key.
func createKey(cn instance.ConnName) cacheKey {
	var key cacheKey
	if cn.HasDomainName() {
		key.domainName = cn.DomainName()
		return key
	}
	key.name = cn.Name()
	key.project = cn.Project()
	key.region = cn.Region()
	return key
}

// connectionInfoCache is a helper function for returning the appropriate
// connection info Cache in a threadsafe way. It will create a new cache,
// modify the existing one, or leave it unchanged as needed.
func (d *Dialer) connectionInfoCache(
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
) (*monitoredCache, error) {
k := createKey(cn)

d.lock.RLock()
c, ok := d.cache[cn]
c, ok := d.cache[k]
d.lock.RUnlock()

// recheck the domain name, this may close the cache.
if ok {
c.checkDomainName(ctx)
}

if ok && !c.isClosed() {
c.UpdateRefresh(useIAMAuthN)
return c, nil
Expand All @@ -588,7 +609,7 @@ func (d *Dialer) connectionInfoCache(
defer d.lock.Unlock()

// Recheck to ensure instance wasn't created or changed between locks
c, ok = d.cache[cn]
c, ok = d.cache[k]

// c exists and is not closed
if ok && !c.isClosed() {
Expand All @@ -598,16 +619,7 @@ func (d *Dialer) connectionInfoCache(

// c exists and is closed, remove it from the cache
if ok {
// remove it.
_ = c.Close()
delete(d.cache, cn)
}

// c does not exist, check for matching domain and close it
oldCn, old, ok := d.findByDn(cn)
if ok {
_ = old.Close()
delete(d.cache, oldCn)
delete(d.cache, k)
}

// Create a new instance of monitoredCache
Expand All @@ -616,7 +628,7 @@ func (d *Dialer) connectionInfoCache(
useIAMAuthNDial = *useIAMAuthN
}
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
k, err := d.keyGenerator.rsaKey()
rsaKey, err := d.keyGenerator.rsaKey()
if err != nil {
return nil, err
}
Expand All @@ -625,48 +637,21 @@ func (d *Dialer) connectionInfoCache(
cache = cloudsql.NewLazyRefreshCache(
cn,
d.logger,
d.sqladmin, k,
d.sqladmin, rsaKey,
d.refreshTimeout, d.iamTokenSource,
d.dialerID, useIAMAuthNDial,
)
} else {
cache = cloudsql.NewRefreshAheadCache(
cn,
d.logger,
d.sqladmin, k,
d.sqladmin, rsaKey,
d.refreshTimeout, d.iamTokenSource,
d.dialerID, useIAMAuthNDial,
)
}
c = newMonitoredCache(ctx, cache, cn, d.failoverPeriod, d.resolver, d.logger)
d.cache[cn] = c
d.cache[k] = c

return c, nil
}

// findByDn scans the cache for an entry whose key has the same domain name
// as cn but is not equal to cn itself.
//
// cn - the connection name whose domain name is searched for
//
// returns:
//
//	instance.ConnName - the key of the matching entry
//	*monitoredCache   - the matching entry
//	bool              - true when a matching entry was found
//
// When cn has no domain name, or nothing matches, the zero ConnName, nil and
// false are returned. This method performs no locking of its own.
func (d *Dialer) findByDn(cn instance.ConnName) (instance.ConnName, *monitoredCache, bool) {
	// Without a domain name there is nothing to match against.
	if !cn.HasDomainName() {
		return instance.ConnName{}, nil, false
	}

	for key, entry := range d.cache {
		// Skip entries for another domain, and the entry for cn itself.
		if key.DomainName() != cn.DomainName() || key == cn {
			continue
		}
		return key, entry, true
	}

	return instance.ConnName{}, nil, false
}
77 changes: 10 additions & 67 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ func TestEngineVersionRemovesInvalidInstancesFromCache(t *testing.T) {
spy := &spyConnectionInfoCache{
connectInfoCalls: []connectionInfoResp{tc.resp},
}
d.cache[inst] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)

_, err = d.EngineVersion(context.Background(), tc.icn)
if err == nil {
Expand All @@ -491,7 +491,7 @@ func TestEngineVersionRemovesInvalidInstancesFromCache(t *testing.T) {

// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.cache[inst]
_, ok := d.cache[createKey(inst)]
d.lock.RUnlock()
if ok {
t.Fatal("connection info was not removed from cache")
Expand Down Expand Up @@ -625,7 +625,7 @@ func TestWarmupRemovesInvalidInstancesFromCache(t *testing.T) {
spy := &spyConnectionInfoCache{
connectInfoCalls: []connectionInfoResp{tc.resp},
}
d.cache[inst] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)

err = d.Warmup(context.Background(), tc.icn, tc.opts...)
if err == nil {
Expand All @@ -639,7 +639,7 @@ func TestWarmupRemovesInvalidInstancesFromCache(t *testing.T) {

// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.cache[inst]
_, ok := d.cache[createKey(inst)]
d.lock.RUnlock()
if ok {
t.Fatal("connection info was not removed from cache")
Expand Down Expand Up @@ -799,7 +799,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
spy := &spyConnectionInfoCache{
connectInfoCalls: []connectionInfoResp{tc.resp},
}
d.cache[inst] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)

_, err = d.Dial(context.Background(), tc.icn, tc.opts...)
if err == nil {
Expand All @@ -813,7 +813,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {

// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.cache[inst]
_, ok := d.cache[createKey(inst)]
d.lock.RUnlock()
if ok {
t.Fatal("connection info was not removed from cache")
Expand Down Expand Up @@ -849,7 +849,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
},
},
}
d.cache[cn] = newMonitoredCache(nil, spy, cn, 0, nil, nil)
d.cache[createKey(cn)] = newMonitoredCache(nil, spy, cn, 0, nil, nil)

_, err = d.Dial(context.Background(), icn)
if !errors.Is(err, sentinel) {
Expand All @@ -869,7 +869,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {

// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.cache[cn]
_, ok := d.cache[createKey(cn)]
d.lock.RUnlock()
if ok {
t.Fatal("bad instance was not removed from the cache")
Expand Down Expand Up @@ -1010,7 +1010,7 @@ func TestDialerInitializesLazyCache(t *testing.T) {
t.Fatal(err)
}

c, ok := d.cache[cn]
c, ok := d.cache[createKey(cn)]
if !ok {
t.Fatal("cache was not populated")
}
Expand Down Expand Up @@ -1103,63 +1103,6 @@ func (r *changingResolver) Resolve(_ context.Context, name string) (instance.Con
return instance.ConnName{}, fmt.Errorf("no resolution for %q", name)
}

// TestDialerUpdatesOnDialAfterDnsChange dials "update.example.com" twice
// through a changingResolver, switching the resolver's stage between the two
// dials, and checks the second dial against "my-instance2".
func TestDialerUpdatesOnDialAfterDnsChange(t *testing.T) {
// At first, the resolver will resolve
// update.example.com to "my-instance"
// Then, the resolver will resolve the same domain name to
// "my-instance2".
// This shows that on every call to Dial(), the dialer will resolve the
// SRV record and connect to the correct instance.
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
)
inst2 := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance2",
)
// The resolver starts with a zero-valued stage counter.
r := &changingResolver{
stage: new(int32),
}

// Mock one instance-get and one ephemeral-cert request per instance.
// NOTE(review): the trailing 1 is presumably the expected call count - confirm.
d := setupDialer(t, setupConfig{
skipServer: true,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
mock.InstanceGetSuccess(inst2, 1),
mock.CreateEphemeralSuccess(inst2, 1),
},
dialerOptions: []Option{
WithResolver(r),
WithTokenSource(mock.EmptyTokenSource{}),
},
})

// Start the proxy for instance 1
stop1 := mock.StartServerProxy(t, inst)
t.Cleanup(func() {
stop1()
})

testSuccessfulDial(
context.Background(), t, d,
"update.example.com",
)
// Stop the first proxy before the resolver is switched. stop1 is also
// registered with t.Cleanup above, so it is invoked again at test end.
stop1()

// Move the resolver to stage 1; per the comment at the top of this test,
// the domain name should now resolve to "my-instance2".
atomic.StoreInt32(r.stage, 1)

// Start the proxy for instance 2
stop2 := mock.StartServerProxy(t, inst2)
t.Cleanup(func() {
stop2()
})

// Dial the same domain name again, passing "my-instance2" as the
// instance name expected by the helper.
testSucessfulDialWithInstanceName(
context.Background(), t, d,
"update.example.com", "my-instance2",
)
}

func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
// At first, the resolver will resolve
// update.example.com to "my-instance"
Expand Down Expand Up @@ -1207,7 +1150,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {

time.Sleep(1 * time.Second)
instCn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
c, _ := d.cache[instCn]
c, _ := d.cache[createKey(instCn)]
if !c.isClosed() {
t.Fatal("Expected monitoredCache to be closed after domain name changed. monitoredCache was not closed.")
}
Expand Down

0 comments on commit 55ada68

Please sign in to comment.