diff --git a/build/build.go b/build/build.go index 194b1897a81..20c54ea8498 100644 --- a/build/build.go +++ b/build/build.go @@ -135,6 +135,218 @@ func filterAvailableNodes(nodes []builder.Node) ([]builder.Node, error) { return nil, err } +type driverPair struct { + driverIndex int + platforms []specs.Platform + so *client.SolveOpt + bopts gateway.BuildOpts +} + +func driverIndexes(m map[string][]driverPair) []int { + out := make([]int, 0, len(m)) + visited := map[int]struct{}{} + for _, dp := range m { + for _, d := range dp { + if _, ok := visited[d.driverIndex]; ok { + continue + } + visited[d.driverIndex] = struct{}{} + out = append(out, d.driverIndex) + } + } + return out +} + +func allIndexes(l int) []int { + out := make([]int, 0, l) + for i := 0; i < l; i++ { + out = append(out, i) + } + return out +} + +func ensureBooted(ctx context.Context, nodes []builder.Node, idxs []int, pw progress.Writer) ([]*client.Client, error) { + clients := make([]*client.Client, len(nodes)) + + baseCtx := ctx + eg, ctx := errgroup.WithContext(ctx) + + for _, i := range idxs { + func(i int) { + eg.Go(func() error { + c, err := driver.Boot(ctx, baseCtx, nodes[i].Driver, pw) + if err != nil { + return err + } + clients[i] = c + return nil + }) + }(i) + } + + if err := eg.Wait(); err != nil { + return nil, err + } + + return clients, nil +} + +func splitToDriverPairs(availablePlatforms map[string]int, opt map[string]Options) map[string][]driverPair { + m := map[string][]driverPair{} + for k, opt := range opt { + mm := map[int][]specs.Platform{} + for _, p := range opt.Platforms { + k := platforms.Format(p) + idx := availablePlatforms[k] // default 0 + pp := mm[idx] + pp = append(pp, p) + mm[idx] = pp + } + // if no platform is specified, use first driver + if len(mm) == 0 { + mm[0] = nil + } + dps := make([]driverPair, 0, 2) + for idx, pp := range mm { + dps = append(dps, driverPair{driverIndex: idx, platforms: pp}) + } + m[k] = dps + } + return m +} + +func resolveDrivers(ctx context.Context, nodes []builder.Node, opt map[string]Options, pw progress.Writer) (map[string][]driverPair, []*client.Client, error) { + dps, clients, err := resolveDriversBase(ctx, nodes, opt, pw) + if err != nil { + return nil, nil, err + } + + bopts := make([]gateway.BuildOpts, len(clients)) + + span, ctx := tracing.StartSpan(ctx, "load buildkit capabilities", trace.WithSpanKind(trace.SpanKindInternal)) + + eg, ctx := errgroup.WithContext(ctx) + for i, c := range clients { + if c == nil { + continue + } + + func(i int, c *client.Client) { + eg.Go(func() error { + clients[i].Build(ctx, client.SolveOpt{ + Internal: true, + }, "buildx", func(ctx context.Context, c gateway.Client) (*gateway.Result, error) { + bopts[i] = c.BuildOpts() + return nil, nil + }, nil) + return nil + }) + }(i, c) + } + + err = eg.Wait() + tracing.FinishWithError(span, err) + if err != nil { + return nil, nil, err + } + for key := range dps { + for i, dp := range dps[key] { + dps[key][i].bopts = bopts[dp.driverIndex] + } + } + + return dps, clients, nil +} + +func resolveDriversBase(ctx context.Context, nodes []builder.Node, opt map[string]Options, pw progress.Writer) (map[string][]driverPair, []*client.Client, error) { + availablePlatforms := map[string]int{} + for i, node := range nodes { + for _, p := range node.Platforms { + availablePlatforms[platforms.Format(p)] = i + } + } + + undetectedPlatform := false + allPlatforms := map[string]struct{}{} + for _, opt := range opt { + for _, p := range opt.Platforms { + k := platforms.Format(p) + allPlatforms[k] = struct{}{} + if _, ok := availablePlatforms[k]; !ok { + undetectedPlatform = true + } + } + } + + // fast path + if len(nodes) == 1 || len(allPlatforms) == 0 { + m := map[string][]driverPair{} + for k, opt := range opt { + m[k] = []driverPair{{driverIndex: 0, platforms: opt.Platforms}} + } + clients, err := ensureBooted(ctx, nodes, driverIndexes(m), pw) + if err != nil { + return nil, nil, err + } + return m, clients, nil + } + + // map based on existing platforms + if !undetectedPlatform { + m := splitToDriverPairs(availablePlatforms, opt) + clients, err := ensureBooted(ctx, nodes, driverIndexes(m), pw) + if err != nil { + return nil, nil, err + } + return m, clients, nil + } + + // boot all drivers in k + clients, err := ensureBooted(ctx, nodes, allIndexes(len(nodes)), pw) + if err != nil { + return nil, nil, err + } + + eg, ctx := errgroup.WithContext(ctx) + workers := make([][]*client.WorkerInfo, len(clients)) + + for i, c := range clients { + if c == nil { + continue + } + func(i int) { + eg.Go(func() error { + ww, err := clients[i].ListWorkers(ctx) + if err != nil { + return errors.Wrap(err, "listing workers") + } + workers[i] = ww + return nil + }) + + }(i) + } + + if err := eg.Wait(); err != nil { + return nil, nil, err + } + + for i, ww := range workers { + for _, w := range ww { + for _, p := range w.Platforms { + p = platforms.Normalize(p) + ps := platforms.Format(p) + + if _, ok := availablePlatforms[ps]; !ok { + availablePlatforms[ps] = i + } + } + } + } + + return splitToDriverPairs(availablePlatforms, opt), clients, nil +} + func toRepoOnly(in string) (string, error) { m := map[string]struct{}{} p := strings.Split(in, ",") @@ -505,14 +717,10 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s } } - drivers, err := resolveDrivers(ctx, nodes, opt, w) + m, clients, err := resolveDrivers(ctx, nodes, opt, w) if err != nil { return nil, err } - driversSolveOpts := make(map[string][]*client.SolveOpt, len(drivers)) - for k, dps := range drivers { - driversSolveOpts[k] = make([]*client.SolveOpt, len(dps)) - } defers := make([]func(), 0, 2) defer func() { @@ -526,33 +734,30 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s eg, ctx := errgroup.WithContext(ctx) for k, opt := range opt { - multiDriver := len(drivers[k]) > 1 + multiDriver := len(m[k]) > 1 hasMobyDriver := false gitattrs, err := getGitAttributes(ctx, opt.Inputs.ContextPath, opt.Inputs.DockerfilePath) if err != nil { logrus.WithError(err).Warn("current commit information was not captured by the build") } - for i, np := range drivers[k] { - if np.Node().Driver.IsMobyDriver() { + for i, np := range m[k] { + node := nodes[np.driverIndex] + if node.Driver.IsMobyDriver() { hasMobyDriver = true } opt.Platforms = np.platforms - gatewayOpts, err := np.BuildOpts(ctx) + so, release, err := toSolveOpt(ctx, node, multiDriver, opt, np.bopts, configDir, w, docker) if err != nil { return nil, err } - so, release, err := toSolveOpt(ctx, np.Node(), multiDriver, opt, gatewayOpts, configDir, w, docker) - if err != nil { - return nil, err - } - if err := saveLocalState(so, k, opt, np.Node(), configDir); err != nil { + if err := saveLocalState(so, k, opt, node, configDir); err != nil { return nil, err } for k, v := range gitattrs { so.FrontendAttrs[k] = v } defers = append(defers, release) - driversSolveOpts[k][i] = so + m[k][i].so = so } for _, at := range opt.Session { if s, ok := at.(interface { @@ -566,8 +771,8 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s // validate for multi-node push if hasMobyDriver && multiDriver { - for _, so := range driversSolveOpts[k] { - for _, e := range so.Exports { + for _, dp := range m[k] { + for _, e := range dp.so.Exports { if e.Type == "moby" { if ok, _ := strconv.ParseBool(e.Attrs["push"]); ok { return nil, errors.Errorf("multi-node push can't currently be performed with the docker driver, please switch to a different driver") @@ -580,13 +785,12 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s // validate that all links between targets use same drivers for name := range opt { - dps := drivers[name] - for i, dp := range dps { - so := driversSolveOpts[name][i] - for k, v := range so.FrontendAttrs { + dps := m[name] + for _, dp := range dps { + for k, v := range dp.so.FrontendAttrs { if strings.HasPrefix(k, "context:") && strings.HasPrefix(v, "target:") { k2 := strings.TrimPrefix(v, "target:") - dps2, ok := drivers[k2] + dps2, ok := m[k2] if !ok { return nil, errors.Errorf("failed to find target %s for context %s", k2, strings.TrimPrefix(k, "context:")) // should be validated before already } @@ -610,13 +814,13 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s results := waitmap.New() multiTarget := len(opt) > 1 - childTargets := calculateChildTargets(drivers, driversSolveOpts, opt) + childTargets := calculateChildTargets(m, opt) for k, opt := range opt { err := func(k string) error { opt := opt - dps := drivers[k] - multiDriver := len(drivers[k]) > 1 + dps := m[k] + multiDriver := len(m[k]) > 1 var span trace.Span ctx := ctx @@ -632,9 +836,8 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s var insecurePush bool for i, dp := range dps { - i, dp := i, dp - node := dp.Node() - so := driversSolveOpts[k][i] + i, dp, so := i, dp, *dp.so + node := nodes[dp.driverIndex] if multiDriver { for i, e := range so.Exports { switch e.Type { @@ -665,14 +868,11 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s pw := progress.WithPrefix(w, k, multiTarget) - c, err := dp.Client(ctx) - if err != nil { - return err - } + c := clients[dp.driverIndex] eg2.Go(func() error { pw = progress.ResetTime(pw) - if err := waitContextDeps(ctx, dp.driverIndex, results, so); err != nil { + if err := waitContextDeps(ctx, dp.driverIndex, results, &so); err != nil { return err } @@ -770,10 +970,10 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s var rr *client.SolveResponse if resultHandleFunc != nil { var resultHandle *ResultHandle - resultHandle, rr, err = NewResultHandle(ctx, cc, *so, "buildx", buildFunc, ch) + resultHandle, rr, err = NewResultHandle(ctx, cc, so, "buildx", buildFunc, ch) resultHandleFunc(dp.driverIndex, resultHandle) } else { - rr, err = c.Build(ctx, *so, "buildx", buildFunc, ch) + rr, err = c.Build(ctx, so, "buildx", buildFunc, ch) } if desktop.BuildBackendEnabled() && node.Driver.HistoryAPISupported(ctx) { buildRef := fmt.Sprintf("%s/%s/%s", node.Builder, node.Name, so.Ref) @@ -797,7 +997,7 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s rr.ExporterResponse[k] = string(v) } - node := dp.Node().Driver + node := nodes[dp.driverIndex].Driver if node.IsMobyDriver() { for _, e := range so.Exports { if e.Type == "moby" && e.Attrs["push"] != "" { @@ -889,7 +1089,7 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opt map[s if len(descs) > 0 { var imageopt imagetools.Opt for _, dp := range dps { - imageopt = dp.Node().ImageOpt + imageopt = nodes[dp.driverIndex].ImageOpt break } names := strings.Split(pushNames, ",") @@ -1303,16 +1503,16 @@ func resultKey(index int, name string) string { } // calculateChildTargets returns all the targets that depend on current target for reverse index -func calculateChildTargets(drivers map[string][]*resolvedNode, driversSolveOpts map[string][]*client.SolveOpt, opt map[string]Options) map[string][]string { +func calculateChildTargets(drivers map[string][]driverPair, opt map[string]Options) map[string][]string { out := make(map[string][]string) - for name := range opt { - dps := drivers[name] - for i, dp := range dps { - so := driversSolveOpts[name][i] + for src := range opt { + dps := drivers[src] + for _, dp := range dps { + so := *dp.so for k, v := range so.FrontendAttrs { if strings.HasPrefix(k, "context:") && strings.HasPrefix(v, "target:") { target := resultKey(dp.driverIndex, strings.TrimPrefix(v, "target:")) - out[target] = append(out[target], resultKey(dp.driverIndex, name)) + out[target] = append(out[target], resultKey(dp.driverIndex, src)) } } } diff --git a/build/driver.go b/build/driver.go deleted file mode 100644 index 9e89b4407ee..00000000000 --- a/build/driver.go +++ /dev/null @@ -1,305 +0,0 @@ -package build - -import ( - "context" - "fmt" - - "github.com/containerd/containerd/platforms" - "github.com/docker/buildx/builder" - "github.com/docker/buildx/driver" - "github.com/docker/buildx/util/progress" - "github.com/moby/buildkit/client" - gateway "github.com/moby/buildkit/frontend/gateway/client" - "github.com/moby/buildkit/util/flightcontrol" - "github.com/moby/buildkit/util/tracing" - specs "github.com/opencontainers/image-spec/specs-go/v1" - "github.com/pkg/errors" - "go.opentelemetry.io/otel/trace" - "golang.org/x/sync/errgroup" -) - -type resolvedNode struct { - resolver *nodeResolver - driverIndex int - platforms []specs.Platform -} - -func (dp resolvedNode) Node() builder.Node { - return dp.resolver.nodes[dp.driverIndex] -} - -func (dp resolvedNode) Client(ctx context.Context) (*client.Client, error) { - clients, err := dp.resolver.boot(ctx, []int{dp.driverIndex}, nil) - if err != nil { - return nil, err - } - return clients[0], nil -} - -func (dp resolvedNode) BuildOpts(ctx context.Context) (gateway.BuildOpts, error) { - opts, err := dp.resolver.opts(ctx, []int{dp.driverIndex}, nil) - if err != nil { - return gateway.BuildOpts{}, err - } - return opts[0], nil -} - -type matchMaker func(specs.Platform) platforms.MatchComparer - -type nodeResolver struct { - nodes []builder.Node - clients flightcontrol.Group[*client.Client] - opt flightcontrol.Group[gateway.BuildOpts] -} - -func resolveDrivers(ctx context.Context, nodes []builder.Node, opt map[string]Options, pw progress.Writer) (map[string][]*resolvedNode, error) { - driverRes := newDriverResolver(nodes) - drivers, err := driverRes.Resolve(ctx, opt, pw) - if err != nil { - return nil, err - } - return drivers, err -} - -func newDriverResolver(nodes []builder.Node) *nodeResolver { - r := &nodeResolver{ - nodes: nodes, - } - return r -} - -func (r *nodeResolver) Resolve(ctx context.Context, opt map[string]Options, pw progress.Writer) (map[string][]*resolvedNode, error) { - if len(r.nodes) == 0 { - return nil, nil - } - - nodes := map[string][]*resolvedNode{} - for k, opt := range opt { - node, perfect, err := r.resolve(ctx, opt.Platforms, pw, platforms.OnlyStrict, nil) - if err != nil { - return nil, err - } - if !perfect { - break - } - nodes[k] = node - } - if len(nodes) != len(opt) { - // if we didn't get a perfect match, we need to boot all drivers - allIndexes := make([]int, len(r.nodes)) - for i := range allIndexes { - allIndexes[i] = i - } - - clients, err := r.boot(ctx, allIndexes, pw) - if err != nil { - return nil, err - } - eg, egCtx := errgroup.WithContext(ctx) - workers := make([][]specs.Platform, len(clients)) - for i, c := range clients { - i, c := i, c - if c == nil { - continue - } - eg.Go(func() error { - ww, err := c.ListWorkers(egCtx) - if err != nil { - return errors.Wrap(err, "listing workers") - } - - ps := make(map[string]specs.Platform, len(ww)) - for _, w := range ww { - for _, p := range w.Platforms { - pk := platforms.Format(platforms.Normalize(p)) - ps[pk] = p - } - } - for _, p := range ps { - workers[i] = append(workers[i], p) - } - return nil - }) - } - if err := eg.Wait(); err != nil { - return nil, err - } - - // then we can attempt to match against all the available platforms - // (this time we don't care about imperfect matches) - nodes = map[string][]*resolvedNode{} - for k, opt := range opt { - node, _, err := r.resolve(ctx, opt.Platforms, pw, platforms.Only, func(idx int, n builder.Node) []specs.Platform { - return workers[idx] - }) - if err != nil { - return nil, err - } - nodes[k] = node - } - } - - idxs := make([]int, 0, len(r.nodes)) - for _, nodes := range nodes { - for _, node := range nodes { - idxs = append(idxs, node.driverIndex) - } - } - - // preload capabilities - span, ctx := tracing.StartSpan(ctx, "load buildkit capabilities", trace.WithSpanKind(trace.SpanKindInternal)) - _, err := r.opts(ctx, idxs, pw) - tracing.FinishWithError(span, err) - if err != nil { - return nil, err - } - - return nodes, nil -} - -func (r *nodeResolver) resolve(ctx context.Context, ps []specs.Platform, pw progress.Writer, matcher matchMaker, additional func(idx int, n builder.Node) []specs.Platform) ([]*resolvedNode, bool, error) { - if len(r.nodes) == 0 { - return nil, true, nil - } - - if len(ps) == 0 { - ps = []specs.Platform{platforms.DefaultSpec()} - } - - perfect := true - nodeIdxs := make([]int, 0) - for _, p := range ps { - idx := r.get(p, matcher, additional) - if idx == -1 { - idx = 0 - perfect = false - } - nodeIdxs = append(nodeIdxs, idx) - } - - var nodes []*resolvedNode - for i, idx := range nodeIdxs { - nodes = append(nodes, &resolvedNode{ - resolver: r, - driverIndex: idx, - platforms: []specs.Platform{ps[i]}, - }) - } - nodes = recombineNodes(nodes) - if _, err := r.boot(ctx, nodeIdxs, pw); err != nil { - return nil, false, err - } - return nodes, perfect, nil -} - -func (r *nodeResolver) get(p specs.Platform, matcher matchMaker, additionalPlatforms func(int, builder.Node) []specs.Platform) int { - best := -1 - bestPlatform := specs.Platform{} - for i, node := range r.nodes { - platforms := node.Platforms - if additionalPlatforms != nil { - platforms = append([]specs.Platform{}, platforms...) - platforms = append(platforms, additionalPlatforms(i, node)...) - } - for _, p2 := range platforms { - m := matcher(p2) - if !m.Match(p) { - continue - } - - if best == -1 { - best = i - bestPlatform = p2 - continue - } - if matcher(p2).Less(p, bestPlatform) { - best = i - bestPlatform = p2 - } - } - } - return best -} - -func (r *nodeResolver) boot(ctx context.Context, idxs []int, pw progress.Writer) ([]*client.Client, error) { - clients := make([]*client.Client, len(idxs)) - - baseCtx := ctx - eg, ctx := errgroup.WithContext(ctx) - - for i, idx := range idxs { - i, idx := i, idx - eg.Go(func() error { - c, err := r.clients.Do(ctx, fmt.Sprint(idx), func(ctx context.Context) (*client.Client, error) { - if r.nodes[idx].Driver == nil { - return nil, nil - } - return driver.Boot(ctx, baseCtx, r.nodes[idx].Driver, pw) - }) - if err != nil { - return err - } - clients[i] = c - return nil - }) - } - if err := eg.Wait(); err != nil { - return nil, err - } - - return clients, nil -} - -func (r *nodeResolver) opts(ctx context.Context, idxs []int, pw progress.Writer) ([]gateway.BuildOpts, error) { - clients, err := r.boot(ctx, idxs, pw) - if err != nil { - return nil, err - } - - bopts := make([]gateway.BuildOpts, len(clients)) - eg, ctx := errgroup.WithContext(ctx) - for i, idxs := range idxs { - i, idx := i, idxs - c := clients[i] - if c == nil { - continue - } - eg.Go(func() error { - opt, err := r.opt.Do(ctx, fmt.Sprint(idx), func(ctx context.Context) (gateway.BuildOpts, error) { - opt := gateway.BuildOpts{} - _, err := c.Build(ctx, client.SolveOpt{ - Internal: true, - }, "buildx", func(ctx context.Context, c gateway.Client) (*gateway.Result, error) { - opt = c.BuildOpts() - return nil, nil - }, nil) - return opt, err - }) - if err != nil { - return err - } - bopts[i] = opt - return nil - }) - } - if err := eg.Wait(); err != nil { - return nil, err - } - return bopts, nil -} - -// recombineDriverPairs recombines resolved nodes that are on the same driver -// back together into a single node. -func recombineNodes(nodes []*resolvedNode) []*resolvedNode { - result := make([]*resolvedNode, 0, len(nodes)) - lookup := map[int]int{} - for _, node := range nodes { - if idx, ok := lookup[node.driverIndex]; ok { - result[idx].platforms = append(result[idx].platforms, node.platforms...) - } else { - lookup[node.driverIndex] = len(result) - result = append(result, node) - } - } - return result -} diff --git a/build/driver_test.go b/build/driver_test.go deleted file mode 100644 index 068ec7eec37..00000000000 --- a/build/driver_test.go +++ /dev/null @@ -1,313 +0,0 @@ -package build - -import ( - "context" - "sort" - "testing" - - "github.com/containerd/containerd/platforms" - "github.com/docker/buildx/builder" - specs "github.com/opencontainers/image-spec/specs-go/v1" - "github.com/stretchr/testify/require" -) - -func TestFindDriverSanity(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.DefaultSpec()}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.DefaultSpec()}, nil, platforms.OnlyStrict, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 0, res[0].driverIndex) - require.Equal(t, "aaa", res[0].Node().Builder) -} - -func TestFindDriverEmpty(t *testing.T) { - r := makeTestResolver(nil) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.DefaultSpec()}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Nil(t, res) -} - -func TestFindDriverWeirdName(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - "bbb": {platforms.MustParse("linux/foobar")}, - }) - - // find first platform - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/foobar")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 1, res[0].driverIndex) - require.Equal(t, "bbb", res[0].Node().Builder) -} - -func TestFindDriverUnknown(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/riscv64")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.False(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 0, res[0].driverIndex) - require.Equal(t, "aaa", res[0].Node().Builder) -} - -func TestSelectNodeSinglePlatform(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - "bbb": {platforms.MustParse("linux/riscv64")}, - }) - - // find first platform - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/amd64")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 0, res[0].driverIndex) - require.Equal(t, "aaa", res[0].Node().Builder) - - // find second platform - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/riscv64")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 1, res[0].driverIndex) - require.Equal(t, "bbb", res[0].Node().Builder) - - // find an unknown platform, should match the first driver - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/s390x")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.False(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 0, res[0].driverIndex) - require.Equal(t, "aaa", res[0].Node().Builder) -} - -func TestSelectNodeMultiPlatform(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64"), platforms.MustParse("linux/arm64")}, - "bbb": {platforms.MustParse("linux/riscv64")}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/amd64")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 0, res[0].driverIndex) - require.Equal(t, "aaa", res[0].Node().Builder) - - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm64")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 0, res[0].driverIndex) - require.Equal(t, "aaa", res[0].Node().Builder) - - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/riscv64")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, 1, res[0].driverIndex) - require.Equal(t, "bbb", res[0].Node().Builder) -} - -func TestSelectNodeNonStrict(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - "bbb": {platforms.MustParse("linux/arm64")}, - }) - - // arm64 should match itself - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm64")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "bbb", res[0].Node().Builder) - - // arm64 may support arm/v8 - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v8")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "bbb", res[0].Node().Builder) - - // arm64 may support arm/v7 - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v7")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "bbb", res[0].Node().Builder) -} - -func TestSelectNodeNonStrictARM(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - "bbb": {platforms.MustParse("linux/arm64")}, - "ccc": {platforms.MustParse("linux/arm/v8")}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v8")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "ccc", res[0].Node().Builder) - - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v7")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "ccc", res[0].Node().Builder) -} - -func TestSelectNodeNonStrictLower(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - "bbb": {platforms.MustParse("linux/arm/v7")}, - }) - - // v8 can't be built on v7 (so we should select the default)... - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v8")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.False(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "aaa", res[0].Node().Builder) - - // ...but v6 can be built on v8 - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v6")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "bbb", res[0].Node().Builder) -} - -func TestSelectNodePreferStart(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - "bbb": {platforms.MustParse("linux/riscv64")}, - "ccc": {platforms.MustParse("linux/riscv64")}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/riscv64")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "bbb", res[0].Node().Builder) -} - -func TestSelectNodePreferExact(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/arm/v8")}, - "bbb": {platforms.MustParse("linux/arm/v7")}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v7")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "bbb", res[0].Node().Builder) -} - -func TestSelectNodeCurrentPlatform(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/foobar")}, - "bbb": {platforms.DefaultSpec()}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "bbb", res[0].Node().Builder) -} - -func TestSelectNodeAdditionalPlatforms(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - "bbb": {platforms.MustParse("linux/arm/v8")}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v7")}, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "bbb", res[0].Node().Builder) - - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{platforms.MustParse("linux/arm/v7")}, nil, platforms.Only, func(idx int, n builder.Node) []specs.Platform { - if n.Builder == "aaa" { - return []specs.Platform{platforms.MustParse("linux/arm/v7")} - } - return nil - }) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "aaa", res[0].Node().Builder) -} - -func TestSplitNodeMultiPlatform(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64"), platforms.MustParse("linux/arm64")}, - "bbb": {platforms.MustParse("linux/riscv64")}, - }) - - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{ - platforms.MustParse("linux/amd64"), - platforms.MustParse("linux/arm64"), - }, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 1) - require.Equal(t, "aaa", res[0].Node().Builder) - - res, perfect, err = r.resolve(context.TODO(), []specs.Platform{ - platforms.MustParse("linux/amd64"), - platforms.MustParse("linux/riscv64"), - }, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 2) - require.Equal(t, "aaa", res[0].Node().Builder) - require.Equal(t, "bbb", res[1].Node().Builder) -} - -func TestSplitNodeMultiPlatformNoUnify(t *testing.T) { - r := makeTestResolver(map[string][]specs.Platform{ - "aaa": {platforms.MustParse("linux/amd64")}, - "bbb": {platforms.MustParse("linux/amd64"), platforms.MustParse("linux/riscv64")}, - }) - - // the "best" choice would be the node with both platforms, but we're using - // a naive algorithm that doesn't try to unify the platforms - res, perfect, err := r.resolve(context.TODO(), []specs.Platform{ - platforms.MustParse("linux/amd64"), - platforms.MustParse("linux/riscv64"), - }, nil, platforms.Only, nil) - require.NoError(t, err) - require.True(t, perfect) - require.Len(t, res, 2) - require.Equal(t, "aaa", res[0].Node().Builder) - require.Equal(t, "bbb", res[1].Node().Builder) -} - -func makeTestResolver(nodes map[string][]specs.Platform) *nodeResolver { - var ns []builder.Node - for name, platforms := range nodes { - ns = append(ns, builder.Node{ - Builder: name, - Platforms: platforms, - }) - } - sort.Slice(ns, func(i, j int) bool { - return ns[i].Builder < ns[j].Builder - }) - return newDriverResolver(ns) -}