Skip to content

Commit

Permalink
feat: make invalidations execution inside 1 query concurrent.
Browse files Browse the repository at this point in the history
It is a low-hanging fruit that can boost performance for mutations
that have multiple invalidation.
  • Loading branch information
Stumble committed Jul 11, 2024
1 parent 3991503 commit c62907b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 25 deletions.
1 change: 1 addition & 0 deletions internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ func (i *importer) queryImports(filename string) fileImports {
std["encoding/json"] = struct{}{}
std["crypto/sha256"] = struct{}{}
std["encoding/hex"] = struct{}{}
std["sync"] = struct{}{}
}

sqlpkg := parseDriver(i.Settings.Go.SqlPackage)
Expand Down
76 changes: 51 additions & 25 deletions internal/codegen/golang/templates/wpgx/queryCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
{{ if .Option.Invalidates -}}
// invalidate
_ = q.db.PostExec(func() error {
var anyErr error
anyErr := make(chan error, {{len .Invalidates}})
var wg sync.WaitGroup
wg.Add({{len .Invalidates}})
{{ range .Invalidates -}}
{
go func() {
defer wg.Done()
{{ if not .NoArg -}}
if {{.ArgName}} != nil {
{{ end -}}
Expand All @@ -89,14 +92,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
if err != nil {
log.Ctx(ctx).Error().Err(err).Msgf(
"Failed to invalidate: %s", key)
anyErr = err
anyErr <- err
}
{{ if not .NoArg -}}
}
{{ end -}}
}
}()
{{ end -}}
return anyErr
wg.Wait()
close(anyErr)
return <-anyErr
})
{{- end }}
return {{.Ret.Name}}, err
Expand Down Expand Up @@ -166,9 +171,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
{{ if .Option.Invalidates -}}
// invalidate
_ = q.db.PostExec(func() error {
var anyErr error
anyErr := make(chan error, {{len .Invalidates}})
var wg sync.WaitGroup
wg.Add({{len .Invalidates}})
{{ range .Invalidates -}}
{
go func() {
defer wg.Done()
{{ if not .NoArg -}}
if {{.ArgName}} != nil {
{{ end -}}
Expand All @@ -177,14 +185,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
if err != nil {
log.Ctx(ctx).Error().Err(err).Msgf(
"Failed to invalidate: %s", key)
anyErr = err
anyErr <- err
}
{{ if not .NoArg -}}
}
{{ end -}}
}
}()
{{ end -}}
return anyErr
wg.Wait()
close(anyErr)
return <-anyErr
})
{{- end }}
return items, err
Expand All @@ -206,9 +216,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
{{ if .Option.Invalidates -}}
// invalidate
_ = q.db.PostExec(func() error {
var anyErr error
anyErr := make(chan error, {{len .Invalidates}})
var wg sync.WaitGroup
wg.Add({{len .Invalidates}})
{{ range .Invalidates -}}
{
go func() {
defer wg.Done()
{{ if not .NoArg -}}
if {{.ArgName}} != nil {
{{ end -}}
Expand All @@ -217,14 +230,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
if err != nil {
log.Ctx(ctx).Error().Err(err).Msgf(
"Failed to invalidate: %s", key)
anyErr = err
anyErr <- err
}
{{ if not .NoArg -}}
}
{{ end -}}
}
}()
{{ end -}}
return anyErr
wg.Wait()
close(anyErr)
return <-anyErr
})
{{- end }}
return nil
Expand All @@ -246,9 +261,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
{{ if .Option.Invalidates -}}
// invalidate
_ = q.db.PostExec(func() error {
var anyErr error
anyErr := make(chan error, {{len .Invalidates}})
var wg sync.WaitGroup
wg.Add({{len .Invalidates}})
{{ range .Invalidates -}}
{
go func() {
defer wg.Done()
{{ if not .NoArg -}}
if {{.ArgName}} != nil {
{{ end -}}
Expand All @@ -257,14 +275,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
if err != nil {
log.Ctx(ctx).Error().Err(err).Msgf(
"Failed to invalidate: %s", key)
anyErr = err
anyErr <- err
}
{{ if not .NoArg -}}
}
{{ end -}}
}
}()
{{ end -}}
return anyErr
wg.Wait()
close(anyErr)
return <-anyErr
})
{{- end }}
return result.RowsAffected(), nil
Expand All @@ -286,9 +306,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
{{ if .Option.Invalidates -}}
// invalidate
_ = q.db.PostExec(func() error {
var anyErr error
anyErr := make(chan error, {{len .Invalidates}})
var wg sync.WaitGroup
wg.Add({{len .Invalidates}})
{{ range .Invalidates -}}
{
go func() {
defer wg.Done()
{{ if not .NoArg -}}
if {{.ArgName}} != nil {
{{ end -}}
Expand All @@ -297,14 +320,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}} {{.Invalida
if err != nil {
log.Ctx(ctx).Error().Err(err).Msgf(
"Failed to invalidate: %s", key)
anyErr = err
anyErr <- err
}
{{ if not .NoArg -}}
}
{{ end -}}
}
}()
{{ end -}}
return anyErr
wg.Wait()
close(anyErr)
return <-anyErr
})
{{- end }}
return rv, nil
Expand Down Expand Up @@ -383,5 +408,6 @@ var _ = time.Now()
var _ = json.RawMessage{}
var _ = sha256.Sum256(nil)
var _ = hex.EncodeToString(nil)
var _ = sync.WaitGroup{}

{{end}}

0 comments on commit c62907b

Please sign in to comment.