diff --git a/Makefile b/Makefile index a64ac393ca..efef26f7b9 100644 --- a/Makefile +++ b/Makefile @@ -106,7 +106,7 @@ help: ## Display this help. manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects. $(CONTROLLER_GEN) \ rbac:roleName=manager-role output:rbac:artifacts:config=config/components/rbac\ - crd output:crd:artifacts:config=config/components/crd/bases\ + crd:generateEmbeddedObjectMeta=true output:crd:artifacts:config=config/components/crd/bases\ webhook output:webhook:artifacts:config=config/components/webhook\ paths="./..." diff --git a/charts/kueue/templates/crd/kueue.x-k8s.io_workloads.yaml b/charts/kueue/templates/crd/kueue.x-k8s.io_workloads.yaml index 9ba326d0f0..4ea9d162a6 100644 --- a/charts/kueue/templates/crd/kueue.x-k8s.io_workloads.yaml +++ b/charts/kueue/templates/crd/kueue.x-k8s.io_workloads.yaml @@ -98,6 +98,23 @@ spec: properties: metadata: description: 'Standard object''s metadata. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#metadata' + properties: + annotations: + additionalProperties: + type: string + type: object + finalizers: + items: + type: string + type: array + labels: + additionalProperties: + type: string + type: object + name: + type: string + namespace: + type: string type: object spec: description: 'Specification of the desired behavior of the @@ -6880,6 +6897,23 @@ spec: that will be copied into the PVC when creating it. No other fields are allowed and will be rejected during validation. + properties: + annotations: + additionalProperties: + type: string + type: object + finalizers: + items: + type: string + type: array + labels: + additionalProperties: + type: string + type: object + name: + type: string + namespace: + type: string type: object spec: description: The specification for the diff --git a/config/components/crd/bases/kueue.x-k8s.io_workloads.yaml b/config/components/crd/bases/kueue.x-k8s.io_workloads.yaml index 0ede9e24d0..2dda7e5730 100644 --- a/config/components/crd/bases/kueue.x-k8s.io_workloads.yaml +++ b/config/components/crd/bases/kueue.x-k8s.io_workloads.yaml @@ -85,6 +85,23 @@ spec: properties: metadata: description: 'Standard object''s metadata. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#metadata' + properties: + annotations: + additionalProperties: + type: string + type: object + finalizers: + items: + type: string + type: array + labels: + additionalProperties: + type: string + type: object + name: + type: string + namespace: + type: string type: object spec: description: 'Specification of the desired behavior of the @@ -6867,6 +6884,23 @@ spec: that will be copied into the PVC when creating it. No other fields are allowed and will be rejected during validation. + properties: + annotations: + additionalProperties: + type: string + type: object + finalizers: + items: + type: string + type: array + labels: + additionalProperties: + type: string + type: object + name: + type: string + namespace: + type: string type: object spec: description: The specification for the diff --git a/pkg/controller/jobframework/podsetinfo.go b/pkg/controller/jobframework/podsetinfo.go new file mode 100644 index 0000000000..3caa39136a --- /dev/null +++ b/pkg/controller/jobframework/podsetinfo.go @@ -0,0 +1,92 @@ +/* +Copyright 2023 The Kubernetes Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package jobframework
+
+import (
+	"maps"
+	"slices"
+
+	corev1 "k8s.io/api/core/v1"
+	v1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+
+	utilmaps "sigs.k8s.io/kueue/pkg/util/maps"
+)
+
+type PodSetInfo struct {
+	Name         string
+	Count        int32
+	Annotations  map[string]string
+	Labels       map[string]string
+	NodeSelector map[string]string
+	Tolerations  []corev1.Toleration
+}
+
+func (podSetInfo *PodSetInfo) Merge(o PodSetInfo) error {
+	if err := utilmaps.HaveConflict(podSetInfo.Annotations, o.Annotations); err != nil {
+		return BadPodSetsUpdateError("annotations", err)
+	}
+	if err := utilmaps.HaveConflict(podSetInfo.Labels, o.Labels); err != nil {
+		return BadPodSetsUpdateError("labels", err)
+	}
+	if err := utilmaps.HaveConflict(podSetInfo.NodeSelector, o.NodeSelector); err != nil {
+		return BadPodSetsUpdateError("nodeSelector", err)
+	}
+	podSetInfo.Annotations = utilmaps.MergeKeepFirst(podSetInfo.Annotations, o.Annotations)
+	podSetInfo.Labels = utilmaps.MergeKeepFirst(podSetInfo.Labels, o.Labels)
+	podSetInfo.NodeSelector = utilmaps.MergeKeepFirst(podSetInfo.NodeSelector, o.NodeSelector)
+	podSetInfo.Tolerations = append(podSetInfo.Tolerations, o.Tolerations...)
+	return nil
+}
+
+// Merge updates or appends the replica metadata & spec fields based on PodSetInfo.
+// It returns an error if there is a conflict.
+func Merge(meta *metav1.ObjectMeta, spec *v1.PodSpec, info PodSetInfo) error {
+	if err := info.Merge(PodSetInfo{
+		Annotations:  meta.Annotations,
+		Labels:       meta.Labels,
+		NodeSelector: spec.NodeSelector,
+		Tolerations:  spec.Tolerations,
+	}); err != nil {
+		return err
+	}
+	meta.Annotations = info.Annotations
+	meta.Labels = info.Labels
+	spec.NodeSelector = info.NodeSelector
+	spec.Tolerations = info.Tolerations
+	return nil
+}
+
+// Restore sets replica metadata and spec fields based on PodSetInfo.
+// It returns true if there is any change.
+func Restore(meta *metav1.ObjectMeta, spec *v1.PodSpec, info PodSetInfo) bool { + changed := false + if !maps.Equal(meta.Annotations, info.Annotations) { + meta.Annotations = maps.Clone(info.Annotations) + changed = true + } + if !maps.Equal(meta.Labels, info.Labels) { + meta.Labels = maps.Clone(info.Labels) + changed = true + } + if !maps.Equal(spec.NodeSelector, info.NodeSelector) { + spec.NodeSelector = maps.Clone(info.NodeSelector) + changed = true + } + if !slices.Equal(spec.Tolerations, info.Tolerations) { + spec.Tolerations = slices.Clone(info.Tolerations) + changed = true + } + return changed +} diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go index 05107ed2c5..a772dadee7 100644 --- a/pkg/controller/jobframework/reconciler.go +++ b/pkg/controller/jobframework/reconciler.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "maps" + "slices" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -40,7 +41,7 @@ import ( "sigs.k8s.io/kueue/pkg/util/equality" "sigs.k8s.io/kueue/pkg/util/kubeversion" utilpriority "sigs.k8s.io/kueue/pkg/util/priority" - "sigs.k8s.io/kueue/pkg/util/slices" + utilslices "sigs.k8s.io/kueue/pkg/util/slices" "sigs.k8s.io/kueue/pkg/workload" ) @@ -55,6 +56,7 @@ var ( ErrNoMatchingWorkloads = errors.New("no matching workloads") ErrExtraWorkloads = errors.New("extra workloads") ErrInvalidPodsetInfo = errors.New("invalid podset infos") + ErrInvalidPodSetUpdate = errors.New("invalid admission check PodSetUpdate") ) // JobReconciler reconciles a GenericJob object @@ -362,7 +364,7 @@ func (r *JobReconciler) ReconcileGenericJob(ctx context.Context, req ctrl.Reques } func isPermanent(e error) bool { - return errors.Is(e, ErrInvalidPodsetInfo) + return errors.Is(e, ErrInvalidPodsetInfo) || errors.Is(e, ErrInvalidPodSetUpdate) } // IsParentJobManaged checks whether the parent job is managed by kueue. @@ -513,7 +515,7 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec // startJob will unsuspend the job, and also inject the node affinity. func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error { - info, err := r.getPodSetsInfoFromAdmission(ctx, wl) + info, err := r.getPodSetsInfoFromStatus(ctx, wl) if err != nil { return err } @@ -643,26 +645,23 @@ func extractPriorityFromPodSets(podSets []kueue.PodSet) string { return "" } -type PodSetInfo struct { - Name string `json:"name"` - NodeSelector map[string]string `json:"nodeSelector"` - Count int32 `json:"count"` -} - -// getPodSetsInfoFromAdmission will extract podSetsInfo and podSets count from admitted workloads. -func (r *JobReconciler) getPodSetsInfoFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetInfo, error) { +// getPodSetsInfoFromStatus extracts podSetInfos from workload status, based on +// admission, and admission checks. 
+func (r *JobReconciler) getPodSetsInfoFromStatus(ctx context.Context, w *kueue.Workload) ([]PodSetInfo, error) { if len(w.Status.Admission.PodSetAssignments) == 0 { return nil, nil } - nodeSelectors := make([]PodSetInfo, len(w.Status.Admission.PodSetAssignments)) + podSetInfos := make([]PodSetInfo, len(w.Status.Admission.PodSetAssignments)) for i, podSetFlavor := range w.Status.Admission.PodSetAssignments { processedFlvs := sets.NewString() - nodeSelector := PodSetInfo{ + podSetInfo := PodSetInfo{ Name: podSetFlavor.Name, NodeSelector: make(map[string]string), Count: ptr.Deref(podSetFlavor.Count, w.Spec.PodSets[i].Count), + Labels: make(map[string]string), + Annotations: make(map[string]string), } for _, flvRef := range podSetFlavor.Flavors { flvName := string(flvRef) @@ -675,14 +674,27 @@ func (r *JobReconciler) getPodSetsInfoFromAdmission(ctx context.Context, w *kueu return nil, err } for k, v := range flv.Spec.NodeLabels { - nodeSelector.NodeSelector[k] = v + podSetInfo.NodeSelector[k] = v } processedFlvs.Insert(flvName) } - - nodeSelectors[i] = nodeSelector + for _, admissionCheck := range w.Status.AdmissionChecks { + for _, podSetUpdate := range admissionCheck.PodSetUpdates { + if podSetUpdate.Name == podSetInfo.Name { + if err := podSetInfo.Merge(PodSetInfo{ + Labels: podSetUpdate.Labels, + Annotations: podSetUpdate.Annotations, + Tolerations: podSetUpdate.Tolerations, + NodeSelector: podSetUpdate.NodeSelector, + }); err != nil { + return nil, fmt.Errorf("in admission check %q: %w", admissionCheck.Name, err) + } + } + } + } + podSetInfos[i] = podSetInfo } - return nodeSelectors, nil + return podSetInfos, nil } func (r *JobReconciler) handleJobWithNoWorkload(ctx context.Context, job GenericJob, object client.Object) error { @@ -734,11 +746,14 @@ func getPodSetsInfoFromWorkload(wl *kueue.Workload) []PodSetInfo { return nil } - return slices.Map(wl.Spec.PodSets, func(ps *kueue.PodSet) PodSetInfo { + return utilslices.Map(wl.Spec.PodSets, func(ps *kueue.PodSet) PodSetInfo { return PodSetInfo{ Name: ps.Name, - NodeSelector: maps.Clone(ps.Template.Spec.NodeSelector), Count: ps.Count, + Annotations: maps.Clone(ps.Template.Annotations), + Labels: maps.Clone(ps.Template.Labels), + NodeSelector: maps.Clone(ps.Template.Spec.NodeSelector), + Tolerations: slices.Clone(ps.Template.Spec.Tolerations), } }) } @@ -790,3 +805,7 @@ func resetMinCounts(in []kueue.PodSet) []kueue.PodSet { func BadPodSetsInfoLenError(want, got int) error { return fmt.Errorf("%w: expecting %d podset, got %d", ErrInvalidPodsetInfo, got, want) } + +func BadPodSetsUpdateError(update string, err error) error { + return fmt.Errorf("%w: conflict for %v: %v", ErrInvalidPodSetUpdate, update, err) +} diff --git a/pkg/controller/jobs/job/job_controller.go b/pkg/controller/jobs/job/job_controller.go index 388d9f8a56..bdc77b2bfb 100644 --- a/pkg/controller/jobs/job/job_controller.go +++ b/pkg/controller/jobs/job/job_controller.go @@ -19,12 +19,10 @@ package job import ( "context" "fmt" - "maps" "strconv" batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -41,7 +39,6 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/controller/jobframework" - utilmaps "sigs.k8s.io/kueue/pkg/util/maps" ) var ( @@ -225,7 +222,6 @@ func (j *Job) RunWithPodSetsInfo(podSetsInfo []jobframework.PodSetInfo) error { } info := podSetsInfo[0] - 
j.Spec.Template.Spec.NodeSelector = utilmaps.MergeKeepFirst(info.NodeSelector, j.Spec.Template.Spec.NodeSelector) if j.minPodsCount() != nil { j.Spec.Parallelism = ptr.To(info.Count) @@ -233,7 +229,7 @@ func (j *Job) RunWithPodSetsInfo(podSetsInfo []jobframework.PodSetInfo) error { j.Spec.Completions = j.Spec.Parallelism } } - return nil + return jobframework.Merge(&j.Spec.Template.ObjectMeta, &j.Spec.Template.Spec, info) } func (j *Job) RestorePodSetsInfo(podSetsInfo []jobframework.PodSetInfo) bool { @@ -250,12 +246,8 @@ func (j *Job) RestorePodSetsInfo(podSetsInfo []jobframework.PodSetInfo) bool { j.Spec.Completions = j.Spec.Parallelism } } - - if equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, podSetsInfo[0].NodeSelector) { - return changed - } - j.Spec.Template.Spec.NodeSelector = maps.Clone(podSetsInfo[0].NodeSelector) - return true + changed = jobframework.Restore(&j.Spec.Template.ObjectMeta, &j.Spec.Template.Spec, podSetsInfo[0]) || changed + return changed } func (j *Job) Finished() (metav1.Condition, bool) { diff --git a/pkg/controller/jobs/job/job_controller_test.go b/pkg/controller/jobs/job/job_controller_test.go index a490c3700c..240a41b3ac 100644 --- a/pkg/controller/jobs/job/job_controller_test.go +++ b/pkg/controller/jobs/job/job_controller_test.go @@ -210,9 +210,10 @@ func TestPodSetsInfo(t *testing.T) { }, }, }, + wantRunError: jobframework.ErrInvalidPodSetUpdate, wantUnsuspended: utiltestingjob.MakeJob("job", "ns"). Parallelism(1). - NodeSelector("orig-key", "new-val"). + NodeSelector("orig-key", "orig-val"). Suspend(false). Obj(), restoreInfo: []jobframework.PodSetInfo{ @@ -341,6 +342,7 @@ var ( "ObjectMeta.Name", "ObjectMeta.ResourceVersion", ), cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime"), + cmpopts.IgnoreFields(kueue.AdmissionCheckState{}, "LastTransitionTime"), } ) @@ -349,6 +351,7 @@ func TestReconciler(t *testing.T) { baseJobWrapper := utiltestingjob.MakeJob("job", "ns"). Suspend(true). + Queue("foo"). Parallelism(10). Request(corev1.ResourceCPU, "1"). Image("", nil) @@ -368,6 +371,438 @@ func TestReconciler(t *testing.T) { wantWorkloads []kueue.Workload wantErr error }{ + "when workload is admitted the PodSetUpdates are propagated to job": { + job: *baseJobWrapper.Clone(). + Obj(), + wantJob: *baseJobWrapper.Clone(). + Suspend(false). + PodLabel("ac-key", "ac-value"). + Obj(), + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "ac-key": "ac-value", + }, + }, + }, + }). + Obj(), + }, + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "ac-key": "ac-value", + }, + }, + }, + }). 
+ Obj(), + }, + }, + "when workload is admitted and PodSetUpdates conflict between admission checks on labels, the workload is finished with failure": { + job: *baseJobWrapper.Clone(). + Obj(), + wantJob: *baseJobWrapper.Clone(). + Suspend(true). + Obj(), + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check1", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "ac-key": "ac-value1", + }, + }, + }, + }). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check2", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "ac-key": "ac-value2", + }, + }, + }, + }). + Obj(), + }, + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check1", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "ac-key": "ac-value1", + }, + }, + }, + }). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check2", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "ac-key": "ac-value2", + }, + }, + }, + }). + Condition(metav1.Condition{ + Type: kueue.WorkloadFinished, + Status: metav1.ConditionTrue, + Reason: "FailedToStart", + Message: `in admission check "check2": invalid admission check PodSetUpdate: conflict for labels: conflict for key=ac-key, value1=ac-value1, value2=ac-value2`, + }). + Obj(), + }, + }, + "when workload is admitted and PodSetUpdates conflict between admission checks on annotations, the workload is finished with failure": { + job: *baseJobWrapper.Clone(). + Obj(), + wantJob: *baseJobWrapper.Clone(). + Suspend(true). + Obj(), + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check1", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Annotations: map[string]string{ + "ac-key": "ac-value1", + }, + }, + }, + }). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check2", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Annotations: map[string]string{ + "ac-key": "ac-value2", + }, + }, + }, + }). + Obj(), + }, + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). 
+ ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check1", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Annotations: map[string]string{ + "ac-key": "ac-value1", + }, + }, + }, + }). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check2", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Annotations: map[string]string{ + "ac-key": "ac-value2", + }, + }, + }, + }). + Condition(metav1.Condition{ + Type: kueue.WorkloadFinished, + Status: metav1.ConditionTrue, + Reason: "FailedToStart", + Message: `in admission check "check2": invalid admission check PodSetUpdate: conflict for annotations: conflict for key=ac-key, value1=ac-value1, value2=ac-value2`, + }). + Obj(), + }, + }, + "when workload is admitted and PodSetUpdates conflict between admission checks on nodeSelector, the workload is finished with failure": { + job: *baseJobWrapper.Clone(). + Obj(), + wantJob: *baseJobWrapper.Clone(). + Suspend(true). + Obj(), + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check1", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + NodeSelector: map[string]string{ + "ac-key": "ac-value1", + }, + }, + }, + }). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check2", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + NodeSelector: map[string]string{ + "ac-key": "ac-value2", + }, + }, + }, + }). + Obj(), + }, + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check1", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + NodeSelector: map[string]string{ + "ac-key": "ac-value1", + }, + }, + }, + }). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check2", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + NodeSelector: map[string]string{ + "ac-key": "ac-value2", + }, + }, + }, + }). + Condition(metav1.Condition{ + Type: kueue.WorkloadFinished, + Status: metav1.ConditionTrue, + Reason: "FailedToStart", + Message: `in admission check "check2": invalid admission check PodSetUpdate: conflict for nodeSelector: conflict for key=ac-key, value1=ac-value1, value2=ac-value2`, + }). + Obj(), + }, + }, + "when workload is admitted and PodSetUpdates conflict between admission check nodeSelector and current node selector, the workload is finished with failure": { + job: *baseJobWrapper.Clone(). + NodeSelector("provisioning", "spot"). + Obj(), + wantJob: *baseJobWrapper.Clone(). + Suspend(true). + NodeSelector("provisioning", "spot"). + Obj(), + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). 
+ PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + NodeSelector: map[string]string{ + "provisioning": "on-demand", + }, + }, + }, + }). + Obj(), + }, + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + NodeSelector: map[string]string{ + "provisioning": "on-demand", + }, + }, + }, + }). + Condition(metav1.Condition{ + Type: kueue.WorkloadFinished, + Status: metav1.ConditionTrue, + Reason: "FailedToStart", + Message: `invalid admission check PodSetUpdate: conflict for nodeSelector: conflict for key=provisioning, value1=on-demand, value2=spot`, + }). + Obj(), + }, + }, + "when workload is admitted the PodSetUpdates values matching for key": { + job: *baseJobWrapper.Clone(). + Obj(), + wantJob: *baseJobWrapper.Clone(). + Suspend(false). + PodAnnotation("annotation-key1", "common-value"). + PodAnnotation("annotation-key2", "only-in-check1"). + PodLabel("label-key1", "common-value"). + NodeSelector("node-selector-key1", "common-value"). + NodeSelector("node-selector-key2", "only-in-check2"). + Obj(), + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check1", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "label-key1": "common-value", + }, + Annotations: map[string]string{ + "annotation-key1": "common-value", + "annotation-key2": "only-in-check1", + }, + NodeSelector: map[string]string{ + "node-selector-key1": "common-value", + }, + }, + }, + }). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check2", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "label-key1": "common-value", + }, + Annotations: map[string]string{ + "annotation-key1": "common-value", + }, + NodeSelector: map[string]string{ + "node-selector-key1": "common-value", + "node-selector-key2": "only-in-check2", + }, + }, + }, + }). + Obj(), + }, + wantWorkloads: []kueue.Workload{ + *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq").AssignmentPodCount(10).Obj()). + Admitted(true). 
+ AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check1", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "label-key1": "common-value", + }, + Annotations: map[string]string{ + "annotation-key1": "common-value", + "annotation-key2": "only-in-check1", + }, + NodeSelector: map[string]string{ + "node-selector-key1": "common-value", + }, + }, + }, + }). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: "check2", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "label-key1": "common-value", + }, + Annotations: map[string]string{ + "annotation-key1": "common-value", + }, + NodeSelector: map[string]string{ + "node-selector-key1": "common-value", + "node-selector-key2": "only-in-check2", + }, + }, + }, + }). + Obj(), + }, + }, "suspended job with matching admitted workload is unsuspended": { reconcilerOptions: []jobframework.Option{ jobframework.WithManageJobsWithoutQueueName(true), @@ -510,12 +945,10 @@ func TestReconciler(t *testing.T) { }, }, "the workload is not created when queue name is not set": { - job: *baseJobWrapper. - Clone(). + job: *utiltestingjob.MakeJob("job", "ns"). Suspend(false). Obj(), - wantJob: *baseJobWrapper. - Clone(). + wantJob: *utiltestingjob.MakeJob("job", "ns"). Suspend(false). Obj(), }, @@ -664,7 +1097,6 @@ func TestReconciler(t *testing.T) { }, "when workload is evicted, suspend, reset startTime and restore node affinity": { job: *baseJobWrapper.Clone(). - Queue("foo"). Suspend(false). StartTime(time.Now()). NodeSelector("provisioning", "spot"). @@ -683,7 +1115,6 @@ func TestReconciler(t *testing.T) { Obj(), }, wantJob: *baseJobWrapper.Clone(). - Queue("foo"). Suspend(true). Active(10). Obj(), @@ -701,7 +1132,6 @@ func TestReconciler(t *testing.T) { }, "when workload is evicted but suspended, reset startTime and restore node affinity": { job: *baseJobWrapper.Clone(). - Queue("foo"). Suspend(true). StartTime(time.Now()). NodeSelector("provisioning", "spot"). @@ -719,7 +1149,6 @@ func TestReconciler(t *testing.T) { Obj(), }, wantJob: *baseJobWrapper.Clone(). - Queue("foo"). Suspend(true). Active(10). Obj(), @@ -737,7 +1166,6 @@ func TestReconciler(t *testing.T) { }, "when workload is evicted, suspended and startTime is reset, restore node affinity": { job: *baseJobWrapper.Clone(). - Queue("foo"). Suspend(true). NodeSelector("provisioning", "spot"). Active(10). @@ -754,7 +1182,6 @@ func TestReconciler(t *testing.T) { Obj(), }, wantJob: *baseJobWrapper.Clone(). - Queue("foo"). Suspend(true). Active(10). Obj(), @@ -772,7 +1199,6 @@ func TestReconciler(t *testing.T) { }, "when job completes, workload is marked as finished": { job: *baseJobWrapper.Clone(). - Queue("foo"). Condition(batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue}). Obj(), workloads: []kueue.Workload{ @@ -784,7 +1210,6 @@ func TestReconciler(t *testing.T) { Obj(), }, wantJob: *baseJobWrapper.Clone(). - Queue("foo"). Condition(batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue}). Obj(), wantWorkloads: []kueue.Workload{ @@ -803,7 +1228,7 @@ func TestReconciler(t *testing.T) { }, }, "when the workload is finished, its finalizer is removed": { - job: *baseJobWrapper.Clone().Queue("foo").Obj(), + job: *baseJobWrapper.Clone().Obj(), workloads: []kueue.Workload{ *utiltesting.MakeWorkload("a", "ns").Finalizers(kueue.ResourceInUseFinalizerName). Finalizers(kueue.ResourceInUseFinalizerName). 
@@ -814,7 +1239,7 @@ func TestReconciler(t *testing.T) {
 				}).
 				Obj(),
 		},
-		wantJob: *baseJobWrapper.Clone().Queue("foo").Obj(),
+		wantJob: *baseJobWrapper.Clone().Obj(),
 		wantWorkloads: []kueue.Workload{
 			*utiltesting.MakeWorkload("a", "ns").
 				PodSets(*utiltesting.MakePodSet("main", 10).Request(corev1.ResourceCPU, "1").Obj()).
diff --git a/pkg/controller/jobs/jobset/jobset_controller.go b/pkg/controller/jobs/jobset/jobset_controller.go
index b81cc58c46..cd242a800d 100644
--- a/pkg/controller/jobs/jobset/jobset_controller.go
+++ b/pkg/controller/jobs/jobset/jobset_controller.go
@@ -18,10 +18,8 @@ package jobset
 
 import (
 	"context"
-	"maps"
 	"strings"
 
-	"k8s.io/apimachinery/pkg/api/equality"
 	apimeta "k8s.io/apimachinery/pkg/api/meta"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/runtime"
@@ -33,7 +31,6 @@ import (
 
 	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
 	"sigs.k8s.io/kueue/pkg/controller/jobframework"
-	utilmaps "sigs.k8s.io/kueue/pkg/util/maps"
 	"sigs.k8s.io/kueue/pkg/util/slices"
 )
 
@@ -124,8 +121,11 @@ func (j *JobSet) RunWithPodSetsInfo(podSetInfos []jobframework.PodSetInfo) error
 	// If there are Jobs already created by the JobSet, their node selectors will be updated by the JobSet controller
 	// before unsuspending the individual Jobs.
 	for index := range j.Spec.ReplicatedJobs {
-		templateSpec := &j.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec
-		templateSpec.NodeSelector = utilmaps.MergeKeepFirst(podSetInfos[index].NodeSelector, templateSpec.NodeSelector)
+		template := &j.Spec.ReplicatedJobs[index].Template.Spec.Template
+		info := podSetInfos[index]
+		if err := jobframework.Merge(&template.ObjectMeta, &template.Spec, info); err != nil {
+			return err
+		}
 	}
 	return nil
 }
@@ -136,11 +136,9 @@ func (j *JobSet) RestorePodSetsInfo(podSetInfos []jobframework.PodSetInfo) bool
 	}
 	changed := false
 	for index := range j.Spec.ReplicatedJobs {
-		if equality.Semantic.DeepEqual(j.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec.NodeSelector, podSetInfos[index].NodeSelector) {
-			continue
-		}
-		changed = true
-		j.Spec.ReplicatedJobs[index].Template.Spec.Template.Spec.NodeSelector = maps.Clone(podSetInfos[index].NodeSelector)
+		replica := &j.Spec.ReplicatedJobs[index].Template.Spec.Template
+		info := podSetInfos[index]
+		changed = jobframework.Restore(&replica.ObjectMeta, &replica.Spec, info) || changed
 	}
 	return changed
 }
diff --git a/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go b/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go
index d10fbea5be..0871ce3f7a 100644
--- a/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go
+++ b/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go
@@ -17,12 +17,10 @@ limitations under the License.
package kubeflowjob import ( - "maps" "strings" kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/utils/ptr" @@ -30,7 +28,6 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/controller/jobframework" - utilmaps "sigs.k8s.io/kueue/pkg/util/maps" ) type KubeflowJob struct { @@ -64,8 +61,11 @@ func (j *KubeflowJob) RunWithPodSetsInfo(podSetInfos []jobframework.PodSetInfo) for index := range podSetInfos { replicaType := orderedReplicaTypes[index] info := podSetInfos[index] - replicaSpec := &j.KFJobControl.ReplicaSpecs()[replicaType].Template.Spec - replicaSpec.NodeSelector = utilmaps.MergeKeepFirst(info.NodeSelector, replicaSpec.NodeSelector) + replica := &j.KFJobControl.ReplicaSpecs()[replicaType].Template + if err := jobframework.Merge(&replica.ObjectMeta, &replica.Spec, info); err != nil { + return err + } + } return nil } @@ -75,11 +75,8 @@ func (j *KubeflowJob) RestorePodSetsInfo(podSetInfos []jobframework.PodSetInfo) changed := false for index, info := range podSetInfos { replicaType := orderedReplicaTypes[index] - replicaSpec := &j.KFJobControl.ReplicaSpecs()[replicaType].Template.Spec - if !equality.Semantic.DeepEqual(replicaSpec.NodeSelector, info.NodeSelector) { - changed = true - replicaSpec.NodeSelector = maps.Clone(info.NodeSelector) - } + replica := &j.KFJobControl.ReplicaSpecs()[replicaType].Template + changed = jobframework.Restore(&replica.ObjectMeta, &replica.Spec, info) || changed } return changed } diff --git a/pkg/controller/jobs/mpijob/mpijob_controller.go b/pkg/controller/jobs/mpijob/mpijob_controller.go index 40d87ba005..9d80567bfe 100644 --- a/pkg/controller/jobs/mpijob/mpijob_controller.go +++ b/pkg/controller/jobs/mpijob/mpijob_controller.go @@ -18,12 +18,10 @@ package mpijob import ( "context" - "maps" "strings" kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -33,7 +31,6 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/controller/jobframework" - utilmaps "sigs.k8s.io/kueue/pkg/util/maps" ) var ( @@ -129,8 +126,10 @@ func (j *MPIJob) RunWithPodSetsInfo(podSetInfos []jobframework.PodSetInfo) error for index := range podSetInfos { replicaType := orderedReplicaTypes[index] info := podSetInfos[index] - replicaSpec := &j.Spec.MPIReplicaSpecs[replicaType].Template.Spec - replicaSpec.NodeSelector = utilmaps.MergeKeepFirst(info.NodeSelector, replicaSpec.NodeSelector) + replica := &j.Spec.MPIReplicaSpecs[replicaType].Template + if err := jobframework.Merge(&replica.ObjectMeta, &replica.Spec, info); err != nil { + return err + } } return nil } @@ -140,11 +139,8 @@ func (j *MPIJob) RestorePodSetsInfo(podSetInfos []jobframework.PodSetInfo) bool changed := false for index, info := range podSetInfos { replicaType := orderedReplicaTypes[index] - replicaSpec := &j.Spec.MPIReplicaSpecs[replicaType].Template.Spec - if !equality.Semantic.DeepEqual(replicaSpec.NodeSelector, info.NodeSelector) { - changed = true - replicaSpec.NodeSelector = maps.Clone(info.NodeSelector) - } + replica := &j.Spec.MPIReplicaSpecs[replicaType].Template + changed = jobframework.Restore(&replica.ObjectMeta, &replica.Spec, info) || changed } return 
changed } diff --git a/pkg/controller/jobs/pod/pod_controller.go b/pkg/controller/jobs/pod/pod_controller.go index c27521ce2d..7031afefa8 100644 --- a/pkg/controller/jobs/pod/pod_controller.go +++ b/pkg/controller/jobs/pod/pod_controller.go @@ -34,7 +34,6 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/constants" "sigs.k8s.io/kueue/pkg/controller/jobframework" - "sigs.k8s.io/kueue/pkg/util/maps" ) const ( @@ -114,11 +113,7 @@ func (p *Pod) RunWithPodSetsInfo(podSetsInfo []jobframework.PodSetInfo) error { if idx != gateNotFound { p.Spec.SchedulingGates = append(p.Spec.SchedulingGates[:idx], p.Spec.SchedulingGates[idx+1:]...) } - - p.Spec.NodeSelector = maps.MergeKeepFirst(podSetsInfo[0].NodeSelector, p.Spec.NodeSelector) - - return nil - + return jobframework.Merge(&p.ObjectMeta, &p.Spec, podSetsInfo[0]) } // RestorePodSetsInfo will restore the original node affinity and podSet counts of the job. diff --git a/pkg/controller/jobs/rayjob/rayjob_controller.go b/pkg/controller/jobs/rayjob/rayjob_controller.go index 5e7551fad1..a1266ffee7 100644 --- a/pkg/controller/jobs/rayjob/rayjob_controller.go +++ b/pkg/controller/jobs/rayjob/rayjob_controller.go @@ -18,11 +18,9 @@ package rayjob import ( "context" - "maps" "strings" rayjobapi "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1" - "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" utilruntime "k8s.io/apimachinery/pkg/util/runtime" @@ -30,7 +28,6 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/controller/jobframework" - utilmaps "sigs.k8s.io/kueue/pkg/util/maps" ) var ( @@ -123,13 +120,19 @@ func (j *RayJob) RunWithPodSetsInfo(podSetInfos []jobframework.PodSetInfo) error j.Spec.Suspend = false // head - headPodSpec := &j.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - headPodSpec.NodeSelector = utilmaps.MergeKeepFirst(podSetInfos[0].NodeSelector, headPodSpec.NodeSelector) + headPod := &j.Spec.RayClusterSpec.HeadGroupSpec.Template + info := podSetInfos[0] + if err := jobframework.Merge(&headPod.ObjectMeta, &headPod.Spec, info); err != nil { + return err + } // workers for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs { - workerPodSpec := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template.Spec - workerPodSpec.NodeSelector = utilmaps.MergeKeepFirst(podSetInfos[index+1].NodeSelector, workerPodSpec.NodeSelector) + workerPod := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template + info := podSetInfos[index+1] + if err := jobframework.Merge(&workerPod.ObjectMeta, &workerPod.Spec, info); err != nil { + return err + } } return nil } @@ -141,19 +144,14 @@ func (j *RayJob) RestorePodSetsInfo(podSetInfos []jobframework.PodSetInfo) bool changed := false // head - headPodSpec := &j.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - if !equality.Semantic.DeepEqual(headPodSpec.NodeSelector, podSetInfos[0].NodeSelector) { - headPodSpec.NodeSelector = maps.Clone(podSetInfos[0].NodeSelector) - changed = true - } + headPod := &j.Spec.RayClusterSpec.HeadGroupSpec.Template + changed = jobframework.Restore(&headPod.ObjectMeta, &headPod.Spec, podSetInfos[0]) || changed // workers for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs { - workerPodSpec := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template.Spec - if !equality.Semantic.DeepEqual(workerPodSpec.NodeSelector, podSetInfos[index+1].NodeSelector) { - workerPodSpec.NodeSelector = maps.Clone(podSetInfos[index+1].NodeSelector) - changed 
= true - } + workerPod := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template + info := podSetInfos[index+1] + changed = jobframework.Restore(&workerPod.ObjectMeta, &workerPod.Spec, info) || changed } return changed } diff --git a/pkg/controller/jobs/rayjob/rayjob_controller_test.go b/pkg/controller/jobs/rayjob/rayjob_controller_test.go index 352320c071..9ff7421854 100644 --- a/pkg/controller/jobs/rayjob/rayjob_controller_test.go +++ b/pkg/controller/jobs/rayjob/rayjob_controller_test.go @@ -169,7 +169,7 @@ func TestNodeSelectors(t *testing.T) { }, { NodeSelector: map[string]string{ - "key-wg1": "updated-value-wg1", + "key-wg1": "value-wg1", }, }, { @@ -210,7 +210,7 @@ func TestNodeSelectors(t *testing.T) { Template: corev1.PodTemplateSpec{ Spec: corev1.PodSpec{ NodeSelector: map[string]string{ - "key-wg1": "updated-value-wg1", + "key-wg1": "value-wg1", }, }, }, diff --git a/pkg/util/maps/maps.go b/pkg/util/maps/maps.go index cebcde4ed2..3d0e89e2e1 100644 --- a/pkg/util/maps/maps.go +++ b/pkg/util/maps/maps.go @@ -19,6 +19,7 @@ limitations under the License. package maps import ( + "fmt" "maps" ) @@ -61,6 +62,16 @@ func MergeKeepFirst[K comparable, V any, S ~map[K]V](a, b S) S { return Merge(a, b, func(v, _ V) V { return v }) } +// HaveConflict checks if a and b have the same key, but different value +func HaveConflict[K comparable, V comparable, S ~map[K]V](a, b S) error { + for k, av := range a { + if bv, found := b[k]; found && av != bv { + return fmt.Errorf("conflict for key=%v, value1=%v, value2=%v", k, av, bv) + } + } + return nil +} + // Contains returns true if a contains all the keys in b with the same value func Contains[K, V comparable, A ~map[K]V, B ~map[K]V](a A, b B) bool { for k, bv := range b { diff --git a/pkg/util/testing/wrappers.go b/pkg/util/testing/wrappers.go index 3afc2cc570..0df8e79350 100644 --- a/pkg/util/testing/wrappers.go +++ b/pkg/util/testing/wrappers.go @@ -176,6 +176,11 @@ func (w *WorkloadWrapper) Condition(condition metav1.Condition) *WorkloadWrapper return w } +func (w *WorkloadWrapper) AdmissionCheck(ac kueue.AdmissionCheckState) *WorkloadWrapper { + w.Status.AdmissionChecks = append(w.Status.AdmissionChecks, ac) + return w +} + func (w *WorkloadWrapper) SetOrReplaceCondition(condition metav1.Condition) *WorkloadWrapper { existingCondition := apimeta.FindStatusCondition(w.Status.Conditions, condition.Type) if existingCondition != nil { diff --git a/pkg/util/testingjobs/job/wrappers.go b/pkg/util/testingjobs/job/wrappers.go index 9a3832e383..d277486fdf 100644 --- a/pkg/util/testingjobs/job/wrappers.go +++ b/pkg/util/testingjobs/job/wrappers.go @@ -107,26 +107,26 @@ func (j *JobWrapper) PriorityClass(pc string) *JobWrapper { // WorkloadPriorityClass updates job workloadpriorityclass. 
 func (j *JobWrapper) WorkloadPriorityClass(wpc string) *JobWrapper {
-	if j.Labels == nil {
-		j.Labels = make(map[string]string)
-	}
-	j.Labels[constants.WorkloadPriorityClassLabel] = wpc
-	return j
+	return j.Label(constants.WorkloadPriorityClassLabel, wpc)
 }
 
 // Queue updates the queue name of the job
 func (j *JobWrapper) Queue(queue string) *JobWrapper {
+	return j.Label(constants.QueueLabel, queue)
+}
+
+// Label sets the label with the given key and value
+func (j *JobWrapper) Label(key, value string) *JobWrapper {
 	if j.Labels == nil {
 		j.Labels = make(map[string]string)
 	}
-	j.Labels[constants.QueueLabel] = queue
+	j.Labels[key] = value
 	return j
 }
 
 // QueueNameAnnotation updates the queue name of the job by annotation (deprecated)
 func (j *JobWrapper) QueueNameAnnotation(queue string) *JobWrapper {
-	j.Annotations[constants.QueueAnnotation] = queue
-	return j
+	return j.SetAnnotation(constants.QueueAnnotation, queue)
 }
 
 // ParentWorkload sets the parent-workload annotation
@@ -152,6 +152,24 @@ func (j *JobWrapper) NodeSelector(k, v string) *JobWrapper {
 	return j
 }
 
+// PodAnnotation sets an annotation at the pod template level
+func (j *JobWrapper) PodAnnotation(k, v string) *JobWrapper {
+	if j.Spec.Template.Annotations == nil {
+		j.Spec.Template.Annotations = make(map[string]string)
+	}
+	j.Spec.Template.Annotations[k] = v
+	return j
+}
+
+// PodLabel sets a label at the pod template level
+func (j *JobWrapper) PodLabel(k, v string) *JobWrapper {
+	if j.Spec.Template.Labels == nil {
+		j.Spec.Template.Labels = make(map[string]string)
+	}
+	j.Spec.Template.Labels[k] = v
+	return j
+}
+
 // Request adds a resource request to the default container.
 func (j *JobWrapper) Request(r corev1.ResourceName, v string) *JobWrapper {
 	j.Spec.Template.Spec.Containers[0].Resources.Requests[r] = resource.MustParse(v)
diff --git a/pkg/util/testingjobs/mpijob/wrappers_mpijob.go b/pkg/util/testingjobs/mpijob/wrappers_mpijob.go
index 179d34e583..0f1e99d3df 100644
--- a/pkg/util/testingjobs/mpijob/wrappers_mpijob.go
+++ b/pkg/util/testingjobs/mpijob/wrappers_mpijob.go
@@ -138,3 +138,21 @@ func (j *MPIJobWrapper) UID(uid string) *MPIJobWrapper {
 	j.ObjectMeta.UID = types.UID(uid)
 	return j
 }
+
+// PodAnnotation sets an annotation at the pod template level
+func (j *MPIJobWrapper) PodAnnotation(replicaType kubeflow.MPIReplicaType, k, v string) *MPIJobWrapper {
+	if j.Spec.MPIReplicaSpecs[replicaType].Template.Annotations == nil {
+		j.Spec.MPIReplicaSpecs[replicaType].Template.Annotations = make(map[string]string)
+	}
+	j.Spec.MPIReplicaSpecs[replicaType].Template.Annotations[k] = v
+	return j
+}
+
+// PodLabel sets a label at the pod template level
+func (j *MPIJobWrapper) PodLabel(replicaType kubeflow.MPIReplicaType, k, v string) *MPIJobWrapper {
+	if j.Spec.MPIReplicaSpecs[replicaType].Template.Labels == nil {
+		j.Spec.MPIReplicaSpecs[replicaType].Template.Labels = make(map[string]string)
+	}
+	j.Spec.MPIReplicaSpecs[replicaType].Template.Labels[k] = v
+	return j
+}
diff --git a/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go b/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go
index 270d42b53a..ab2cfefe23 100644
--- a/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go
+++ b/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go
@@ -137,3 +137,21 @@ func (j *PyTorchJobWrapper) UID(uid string) *PyTorchJobWrapper {
 	j.ObjectMeta.UID = types.UID(uid)
 	return j
 }
+
+// PodAnnotation sets an annotation at the pod template level
+func (j *PyTorchJobWrapper) PodAnnotation(replicaType kftraining.ReplicaType, k, v string)
*PyTorchJobWrapper { + if j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations == nil { + j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations = make(map[string]string) + } + j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations[k] = v + return j +} + +// PodLabel sets label at the pod template level +func (j *PyTorchJobWrapper) PodLabel(replicaType kftraining.ReplicaType, k, v string) *PyTorchJobWrapper { + if j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels == nil { + j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels = make(map[string]string) + } + j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels[k] = v + return j +} diff --git a/test/integration/controller/jobs/job/job_controller_test.go b/test/integration/controller/jobs/job/job_controller_test.go index 6e6b146478..d0f6e95cdd 100644 --- a/test/integration/controller/jobs/job/job_controller_test.go +++ b/test/integration/controller/jobs/job/job_controller_test.go @@ -438,6 +438,231 @@ var _ = ginkgo.Describe("Job controller", ginkgo.Ordered, ginkgo.ContinueOnFailu util.ExpectWorkloadsToBePending(ctx, k8sClient, wl) }) }) + + ginkgo.When("the queue has admission checks", func() { + var ( + clusterQueueAc *kueue.ClusterQueue + localQueue *kueue.LocalQueue + testFlavor *kueue.ResourceFlavor + jobLookupKey *types.NamespacedName + wlLookupKey *types.NamespacedName + admissionCheck *kueue.AdmissionCheck + ) + + ginkgo.BeforeEach(func() { + admissionCheck = testing.MakeAdmissionCheck("check").Obj() + gomega.Expect(k8sClient.Create(ctx, admissionCheck)).To(gomega.Succeed()) + util.SetAdmissionCheckActive(ctx, k8sClient, admissionCheck, metav1.ConditionTrue) + clusterQueueAc = testing.MakeClusterQueue("prod-cq-with-checks"). + ResourceGroup( + *testing.MakeFlavorQuotas("test-flavor").Resource(corev1.ResourceCPU, "5").Obj(), + ).AdmissionChecks("check").Obj() + gomega.Expect(k8sClient.Create(ctx, clusterQueueAc)).Should(gomega.Succeed()) + localQueue = testing.MakeLocalQueue("queue", ns.Name).ClusterQueue(clusterQueueAc.Name).Obj() + gomega.Expect(k8sClient.Create(ctx, localQueue)).To(gomega.Succeed()) + testFlavor = testing.MakeResourceFlavor("test-flavor").Label(instanceKey, "test-flavor").Obj() + gomega.Expect(k8sClient.Create(ctx, testFlavor)).Should(gomega.Succeed()) + + jobLookupKey = &types.NamespacedName{Name: jobName, Namespace: ns.Name} + wlLookupKey = &types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(jobName), Namespace: ns.Name} + }) + + ginkgo.AfterEach(func() { + gomega.Expect(util.DeleteAdmissionCheck(ctx, k8sClient, admissionCheck)).To(gomega.Succeed()) + util.ExpectResourceFlavorToBeDeleted(ctx, k8sClient, testFlavor, true) + gomega.Expect(util.DeleteNamespace(ctx, k8sClient, ns)).To(gomega.Succeed()) + util.ExpectClusterQueueToBeDeleted(ctx, k8sClient, clusterQueueAc, true) + }) + + ginkgo.It("labels and annotations should be propagated from admission check to job", func() { + createdJob := &batchv1.Job{} + createdWorkload := &kueue.Workload{} + + ginkgo.By("creating the job with pod labels & annotations", func() { + job := testingjob.MakeJob(jobName, ns.Name). + Queue(localQueue.Name). + Request(corev1.ResourceCPU, "5"). + PodAnnotation("old-ann-key", "old-ann-value"). + PodLabel("old-label-key", "old-label-value"). 
+ Obj() + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + }) + + ginkgo.By("fetch the job and verify it is suspended as the checks are not ready", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true))) + }) + + ginkgo.By("fetch the created workload", func() { + gomega.Eventually(func() error { + return k8sClient.Get(ctx, *wlLookupKey, createdWorkload) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("add labels & annotations to the workload admission check in PodSetUpdates", func() { + gomega.Eventually(func() error { + var newWL kueue.Workload + gomega.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(createdWorkload), &newWL)).To(gomega.Succeed()) + workload.SetAdmissionCheckState(&newWL.Status.AdmissionChecks, kueue.AdmissionCheckState{ + Name: "check", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "label1": "label-value1", + }, + Annotations: map[string]string{ + "ann1": "ann-value1", + }, + NodeSelector: map[string]string{ + "selector1": "selector-value1", + }, + Tolerations: []corev1.Toleration{ + { + Key: "selector1", + Value: "selector-value1", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + }, + }) + return k8sClient.Status().Update(ctx, &newWL) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("admit the workload", func() { + admission := testing.MakeAdmission(clusterQueueAc.Name). + Assignment(corev1.ResourceCPU, "test-flavor", "1"). + AssignmentPodCount(createdWorkload.Spec.PodSets[0].Count). 
+ Obj() + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, admission)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("await for the job to be admitted", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(false))) + }) + + ginkgo.By("verify the PodSetUpdates are propagated to the running job", func() { + gomega.Expect(createdJob.Spec.Template.Annotations).Should(gomega.HaveKeyWithValue("ann1", "ann-value1")) + gomega.Expect(createdJob.Spec.Template.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value")) + gomega.Expect(createdJob.Spec.Template.Labels).Should(gomega.HaveKeyWithValue("label1", "label-value1")) + gomega.Expect(createdJob.Spec.Template.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value")) + gomega.Expect(createdJob.Spec.Template.Spec.NodeSelector).Should(gomega.HaveKeyWithValue(instanceKey, "test-flavor")) + gomega.Expect(createdJob.Spec.Template.Spec.NodeSelector).Should(gomega.HaveKeyWithValue("selector1", "selector-value1")) + gomega.Expect(createdJob.Spec.Template.Spec.Tolerations).Should(gomega.BeComparableTo( + []corev1.Toleration{ + { + Key: "selector1", + Value: "selector-value1", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + )) + }) + + ginkgo.By("delete the localQueue to prevent readmission", func() { + gomega.Expect(util.DeleteLocalQueue(ctx, k8sClient, localQueue)).Should(gomega.Succeed()) + }) + + ginkgo.By("clear the workload's admission to stop the job", func() { + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, nil)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("await for the job to be suspended", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true))) + }) + + ginkgo.By("verify the PodSetUpdates are restored", func() { + gomega.Expect(createdJob.Spec.Template.Annotations).ShouldNot(gomega.HaveKey("ann1")) + gomega.Expect(createdJob.Spec.Template.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value")) + gomega.Expect(createdJob.Spec.Template.Labels).ShouldNot(gomega.HaveKey("label1")) + gomega.Expect(createdJob.Spec.Template.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value")) + gomega.Expect(createdJob.Spec.Template.Spec.NodeSelector).ShouldNot(gomega.HaveKey(instanceKey)) + gomega.Expect(createdJob.Spec.Template.Spec.NodeSelector).ShouldNot(gomega.HaveKey("selector1")) + }) + }) + + ginkgo.It("should not admit workload if there is a conflict in labels", func() { + createdJob := &batchv1.Job{} + createdWorkload := &kueue.Workload{} + + ginkgo.By("creating the job with default priority", func() { + job := testingjob.MakeJob(jobName, ns.Name). + Queue(localQueue.Name). + Request(corev1.ResourceCPU, "5"). + PodLabel("label-key", "old-label-value"). 
+ Obj() + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + }) + + ginkgo.By("fetch the created job & workload", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true))) + gomega.Eventually(func() error { + return k8sClient.Get(ctx, *wlLookupKey, createdWorkload) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("add a conflicting label to the admission check in PodSetUpdates", func() { + gomega.Eventually(func() error { + var newWL kueue.Workload + gomega.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(createdWorkload), &newWL)).To(gomega.Succeed()) + workload.SetAdmissionCheckState(&newWL.Status.AdmissionChecks, kueue.AdmissionCheckState{ + Name: "check", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "main", + Labels: map[string]string{ + "label-key": "new-label-value", + }, + }, + }, + }) + return k8sClient.Status().Update(ctx, &newWL) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("attempt to admit the workload", func() { + admission := testing.MakeAdmission(clusterQueueAc.Name). + Assignment(corev1.ResourceCPU, "test-flavor", "1"). + AssignmentPodCount(createdWorkload.Spec.PodSets[0].Count). + Obj() + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, admission)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("verify the job is not started", func() { + gomega.Consistently(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.ConsistentDuration, util.Interval).Should(gomega.Equal(ptr.To(true))) + }) + + ginkgo.By("verify the job has the old label value", func() { + gomega.Expect(createdJob.Spec.Template.Labels).Should(gomega.HaveKeyWithValue("label-key", "old-label-value")) + }) + }) + }) }) var _ = ginkgo.Describe("Job controller when waitForPodsReady enabled", ginkgo.Ordered, ginkgo.ContinueOnFailure, func() { diff --git a/test/integration/controller/jobs/jobset/jobset_controller_test.go b/test/integration/controller/jobs/jobset/jobset_controller_test.go index f5134da358..c95f9c75ba 100644 --- a/test/integration/controller/jobs/jobset/jobset_controller_test.go +++ b/test/integration/controller/jobs/jobset/jobset_controller_test.go @@ -30,6 +30,7 @@ import ( "k8s.io/client-go/kubernetes/scheme" "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" jobsetapi "sigs.k8s.io/jobset/api/jobset/v1alpha2" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" @@ -38,6 +39,7 @@ import ( workloadjobset "sigs.k8s.io/kueue/pkg/controller/jobs/jobset" "sigs.k8s.io/kueue/pkg/util/testing" testingjobset "sigs.k8s.io/kueue/pkg/util/testingjobs/jobset" + "sigs.k8s.io/kueue/pkg/workload" "sigs.k8s.io/kueue/test/integration/framework" "sigs.k8s.io/kueue/test/util" ) @@ -284,6 +286,220 @@ var _ = ginkgo.Describe("JobSet controller", ginkgo.Ordered, ginkgo.ContinueOnFa return apimeta.IsStatusConditionTrue(createdWorkload.Status.Conditions, kueue.WorkloadFinished) }, util.Timeout, util.Interval).Should(gomega.BeTrue()) }) + + ginkgo.When("the queue has admission checks", func() { + var ( + clusterQueueAc 
*kueue.ClusterQueue + localQueue *kueue.LocalQueue + testFlavor *kueue.ResourceFlavor + jobLookupKey *types.NamespacedName + wlLookupKey *types.NamespacedName + admissionCheck *kueue.AdmissionCheck + ) + + ginkgo.BeforeEach(func() { + admissionCheck = testing.MakeAdmissionCheck("check").Obj() + gomega.Expect(k8sClient.Create(ctx, admissionCheck)).To(gomega.Succeed()) + util.SetAdmissionCheckActive(ctx, k8sClient, admissionCheck, metav1.ConditionTrue) + clusterQueueAc = testing.MakeClusterQueue("prod-cq-with-checks"). + ResourceGroup( + *testing.MakeFlavorQuotas("test-flavor").Resource(corev1.ResourceCPU, "5").Obj(), + ).AdmissionChecks("check").Obj() + gomega.Expect(k8sClient.Create(ctx, clusterQueueAc)).Should(gomega.Succeed()) + localQueue = testing.MakeLocalQueue("queue", ns.Name).ClusterQueue(clusterQueueAc.Name).Obj() + gomega.Expect(k8sClient.Create(ctx, localQueue)).To(gomega.Succeed()) + testFlavor = testing.MakeResourceFlavor("test-flavor").Label(instanceKey, "test-flavor").Obj() + gomega.Expect(k8sClient.Create(ctx, testFlavor)).Should(gomega.Succeed()) + + jobLookupKey = &types.NamespacedName{Name: jobSetName, Namespace: ns.Name} + wlLookupKey = &types.NamespacedName{Name: workloadjobset.GetWorkloadNameForJobSet(jobSetName), Namespace: ns.Name} + }) + + ginkgo.AfterEach(func() { + gomega.Expect(util.DeleteAdmissionCheck(ctx, k8sClient, admissionCheck)).To(gomega.Succeed()) + util.ExpectResourceFlavorToBeDeleted(ctx, k8sClient, testFlavor, true) + gomega.Expect(util.DeleteNamespace(ctx, k8sClient, ns)).To(gomega.Succeed()) + util.ExpectClusterQueueToBeDeleted(ctx, k8sClient, clusterQueueAc, true) + }) + + ginkgo.It("labels and annotations should be propagated from admission check to job", func() { + createdJob := &jobsetapi.JobSet{} + createdWorkload := &kueue.Workload{} + + ginkgo.By("creating the job", func() { + job := testingjobset.MakeJobSet(jobSetName, ns.Name).ReplicatedJobs( + testingjobset.ReplicatedJobRequirements{ + Name: "replicated-job-1", + Replicas: 1, + Parallelism: 1, + Completions: 1, + }, testingjobset.ReplicatedJobRequirements{ + Name: "replicated-job-2", + Replicas: 3, + Parallelism: 1, + Completions: 1, + }, + ). + Queue("queue"). + Request("replicated-job-1", corev1.ResourceCPU, "1"). + Request("replicated-job-2", corev1.ResourceCPU, "1"). 
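+ // Each replicated job becomes its own PodSet in the Workload, so the admission check can apply different PodSetUpdates to replicated-job-1 and replicated-job-2.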
+ Obj() + job.Spec.ReplicatedJobs[0].Template.Spec.Template.Annotations = map[string]string{ + "old-ann-key": "old-ann-value", + } + job.Spec.ReplicatedJobs[0].Template.Spec.Template.Labels = map[string]string{ + "old-label-key": "old-label-value", + } + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + }) + + ginkgo.By("fetch the job and verify it is suspended as the checks are not ready", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true))) + }) + + ginkgo.By("checking the workload is created", func() { + gomega.Eventually(func() error { + return k8sClient.Get(ctx, *wlLookupKey, createdWorkload) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("add labels & annotations to the admission check in PodSetUpdates", func() { + gomega.Eventually(func() error { + var newWL kueue.Workload + gomega.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(createdWorkload), &newWL)).To(gomega.Succeed()) + workload.SetAdmissionCheckState(&newWL.Status.AdmissionChecks, kueue.AdmissionCheckState{ + Name: "check", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "replicated-job-1", + Annotations: map[string]string{ + "ann1": "ann-value1", + }, + Labels: map[string]string{ + "label1": "label-value1", + }, + NodeSelector: map[string]string{ + "selector1": "selector-value1", + }, + Tolerations: []corev1.Toleration{ + { + Key: "selector1", + Value: "selector-value1", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + { + Name: "replicated-job-2", + Annotations: map[string]string{ + "ann1": "ann-value2", + }, + Labels: map[string]string{ + "label1": "label-value2", + }, + NodeSelector: map[string]string{ + "selector1": "selector-value2", + }, + }, + }, + }) + return k8sClient.Status().Update(ctx, &newWL) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("admit the workload", func() { + admission := testing.MakeAdmission(clusterQueueAc.Name). + PodSets( + kueue.PodSetAssignment{ + Name: createdWorkload.Spec.PodSets[0].Name, + Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{ + corev1.ResourceCPU: "test-flavor", + }, + }, kueue.PodSetAssignment{ + Name: createdWorkload.Spec.PodSets[1].Name, + Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{ + corev1.ResourceCPU: "test-flavor", + }, + }, + ). 
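+ // Assign the test-flavor to both PodSets so the quota reservation covers the whole JobSet.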
+ Obj() + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, admission)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("await for the job to be admitted", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(false))) + }) + + ginkgo.By("verify the PodSetUpdates are propagated to the running job, for replicated-job-1", func() { + replica1 := createdJob.Spec.ReplicatedJobs[0].Template.Spec.Template + gomega.Expect(replica1.Annotations).Should(gomega.HaveKeyWithValue("ann1", "ann-value1")) + gomega.Expect(replica1.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value")) + gomega.Expect(replica1.Labels).Should(gomega.HaveKeyWithValue("label1", "label-value1")) + gomega.Expect(replica1.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value")) + gomega.Expect(replica1.Spec.NodeSelector).Should(gomega.HaveKeyWithValue("selector1", "selector-value1")) + gomega.Expect(replica1.Spec.Tolerations).Should(gomega.BeComparableTo( + []corev1.Toleration{ + { + Key: "selector1", + Value: "selector-value1", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + )) + }) + + ginkgo.By("verify the PodSetUpdates are propagated to the running job, for replicated-job-2", func() { + replica2 := createdJob.Spec.ReplicatedJobs[1].Template.Spec.Template + gomega.Expect(replica2.Spec.NodeSelector).Should(gomega.HaveKeyWithValue("selector1", "selector-value2")) + gomega.Expect(replica2.Annotations).Should(gomega.HaveKeyWithValue("ann1", "ann-value2")) + gomega.Expect(replica2.Labels).Should(gomega.HaveKeyWithValue("label1", "label-value2")) + }) + + ginkgo.By("delete the localQueue to prevent readmission", func() { + gomega.Expect(util.DeleteLocalQueue(ctx, k8sClient, localQueue)).Should(gomega.Succeed()) + }) + + ginkgo.By("clear the workload's admission to stop the job", func() { + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, nil)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("await for the job to be suspended", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true))) + }) + + ginkgo.By("verify the PodSetUpdates are restored for replicated-job-1", func() { + replica1 := createdJob.Spec.ReplicatedJobs[0].Template.Spec.Template + gomega.Expect(replica1.Annotations).ShouldNot(gomega.HaveKey("ann1")) + gomega.Expect(replica1.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value")) + gomega.Expect(replica1.Labels).ShouldNot(gomega.HaveKey("label1")) + gomega.Expect(replica1.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value")) + gomega.Expect(replica1.Spec.NodeSelector).ShouldNot(gomega.HaveKey("selector1")) + }) + + ginkgo.By("verify the PodSetUpdates are restored for replicated-job-2", func() { + replica2 := createdJob.Spec.ReplicatedJobs[1].Template.Spec.Template + 
gomega.Expect(replica2.Spec.NodeSelector).ShouldNot(gomega.HaveKey("selector1")) + gomega.Expect(replica2.Annotations).ShouldNot(gomega.HaveKey("ann1")) + gomega.Expect(replica2.Labels).ShouldNot(gomega.HaveKey("label1")) + }) + }) + }) }) var _ = ginkgo.Describe("JobSet controller for workloads when only jobs with queue are managed", ginkgo.Ordered, ginkgo.ContinueOnFailure, func() { diff --git a/test/integration/controller/jobs/mpijob/mpijob_controller_test.go b/test/integration/controller/jobs/mpijob/mpijob_controller_test.go index 0ce2e6df0c..ed8dc2b35d 100644 --- a/test/integration/controller/jobs/mpijob/mpijob_controller_test.go +++ b/test/integration/controller/jobs/mpijob/mpijob_controller_test.go @@ -31,6 +31,7 @@ import ( "k8s.io/client-go/kubernetes/scheme" "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/controller/constants" @@ -39,6 +40,7 @@ import ( "sigs.k8s.io/kueue/pkg/util/testing" testingjob "sigs.k8s.io/kueue/pkg/util/testingjobs/job" testingmpijob "sigs.k8s.io/kueue/pkg/util/testingjobs/mpijob" + "sigs.k8s.io/kueue/pkg/workload" "sigs.k8s.io/kueue/test/integration/framework" "sigs.k8s.io/kueue/test/util" ) @@ -285,6 +287,209 @@ var _ = ginkgo.Describe("Job controller", ginkgo.Ordered, ginkgo.ContinueOnFailu return apimeta.IsStatusConditionTrue(createdWorkload.Status.Conditions, kueue.WorkloadFinished) }, util.Timeout, util.Interval).Should(gomega.BeTrue()) }) + + ginkgo.When("the queue has admission checks", func() { + var ( + clusterQueueAc *kueue.ClusterQueue + localQueue *kueue.LocalQueue + testFlavor *kueue.ResourceFlavor + jobLookupKey *types.NamespacedName + wlLookupKey *types.NamespacedName + admissionCheck *kueue.AdmissionCheck + ) + + ginkgo.BeforeEach(func() { + admissionCheck = testing.MakeAdmissionCheck("check").Obj() + gomega.Expect(k8sClient.Create(ctx, admissionCheck)).To(gomega.Succeed()) + util.SetAdmissionCheckActive(ctx, k8sClient, admissionCheck, metav1.ConditionTrue) + clusterQueueAc = testing.MakeClusterQueue("prod-cq-with-checks"). + ResourceGroup( + *testing.MakeFlavorQuotas("test-flavor").Resource(corev1.ResourceCPU, "5").Obj(), + ).AdmissionChecks("check").Obj() + gomega.Expect(k8sClient.Create(ctx, clusterQueueAc)).Should(gomega.Succeed()) + localQueue = testing.MakeLocalQueue("queue", ns.Name).ClusterQueue(clusterQueueAc.Name).Obj() + gomega.Expect(k8sClient.Create(ctx, localQueue)).To(gomega.Succeed()) + testFlavor = testing.MakeResourceFlavor("test-flavor").Label(instanceKey, "test-flavor").Obj() + gomega.Expect(k8sClient.Create(ctx, testFlavor)).Should(gomega.Succeed()) + + jobLookupKey = &types.NamespacedName{Name: jobName, Namespace: ns.Name} + wlLookupKey = &types.NamespacedName{Name: workloadmpijob.GetWorkloadNameForMPIJob(jobName), Namespace: ns.Name} + }) + + ginkgo.AfterEach(func() { + gomega.Expect(util.DeleteAdmissionCheck(ctx, k8sClient, admissionCheck)).To(gomega.Succeed()) + util.ExpectResourceFlavorToBeDeleted(ctx, k8sClient, testFlavor, true) + gomega.Expect(util.DeleteNamespace(ctx, k8sClient, ns)).To(gomega.Succeed()) + util.ExpectClusterQueueToBeDeleted(ctx, k8sClient, clusterQueueAc, true) + }) + + ginkgo.It("labels and annotations should be propagated from admission check to job", func() { + createdJob := &kubeflow.MPIJob{} + createdWorkload := &kueue.Workload{} + + ginkgo.By("creating the job with pod labels & annotations", func() { + job := testingmpijob.MakeMPIJob(jobName, ns.Name). 
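+ // The worker template starts with its own annotation and label so the test can verify they survive the PodSetUpdates merge and the later restore.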
+ Queue(localQueue.Name). + PodAnnotation(kubeflow.MPIReplicaTypeWorker, "old-ann-key", "old-ann-value"). + PodLabel(kubeflow.MPIReplicaTypeWorker, "old-label-key", "old-label-value"). + Obj() + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + }) + + ginkgo.By("fetch the job and verify it is suspended as the checks are not ready", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true))) + }) + + ginkgo.By("fetch the created workload", func() { + gomega.Eventually(func() error { + return k8sClient.Get(ctx, *wlLookupKey, createdWorkload) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("add labels & annotations to the admission check", func() { + gomega.Eventually(func() error { + var newWL kueue.Workload + gomega.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(createdWorkload), &newWL)).To(gomega.Succeed()) + workload.SetAdmissionCheckState(&newWL.Status.AdmissionChecks, kueue.AdmissionCheckState{ + Name: "check", + State: kueue.CheckStateReady, + PodSetUpdates: []kueue.PodSetUpdate{ + { + Name: "launcher", + Annotations: map[string]string{ + "ann1": "ann-value-for-launcher", + }, + Labels: map[string]string{ + "label1": "label-value-for-launcher", + }, + NodeSelector: map[string]string{ + "selector1": "selector-value-for-launcher", + }, + }, + { + Name: "worker", + Annotations: map[string]string{ + "ann1": "ann-value1", + }, + Labels: map[string]string{ + "label1": "label-value1", + }, + NodeSelector: map[string]string{ + "selector1": "selector-value1", + }, + Tolerations: []corev1.Toleration{ + { + Key: "selector1", + Value: "selector-value1", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + }, + }) + return k8sClient.Status().Update(ctx, &newWL) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("admit the workload", func() { + admission := testing.MakeAdmission(clusterQueueAc.Name). + PodSets( + kueue.PodSetAssignment{ + Name: "launcher", + Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{ + corev1.ResourceCPU: "test-flavor", + }, + Count: ptr.To(createdWorkload.Spec.PodSets[0].Count), + }, + kueue.PodSetAssignment{ + Name: "worker", + Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{ + corev1.ResourceCPU: "test-flavor", + }, + Count: ptr.To(createdWorkload.Spec.PodSets[1].Count), + }, + ). 
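+ // Both the launcher and the worker PodSets get the test-flavor assignment, with counts taken from the Workload's PodSets.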
+ Obj() + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, admission)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("await for the job to start", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(false))) + }) + + ginkgo.By("verify the PodSetUpdates are propagated to the running job, for worker", func() { + worker := createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Template + gomega.Expect(worker.Annotations).Should(gomega.HaveKeyWithValue("ann1", "ann-value1")) + gomega.Expect(worker.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value")) + gomega.Expect(worker.Labels).Should(gomega.HaveKeyWithValue("label1", "label-value1")) + gomega.Expect(worker.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value")) + gomega.Expect(worker.Spec.NodeSelector).Should(gomega.HaveKeyWithValue(instanceKey, "test-flavor")) + gomega.Expect(worker.Spec.NodeSelector).Should(gomega.HaveKeyWithValue("selector1", "selector-value1")) + gomega.Expect(worker.Spec.Tolerations).Should(gomega.BeComparableTo( + []corev1.Toleration{ + { + Key: "selector1", + Value: "selector-value1", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + )) + }) + + ginkgo.By("verify the PodSetUpdates are propagated to the running job, for launcher", func() { + launcher := createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template + gomega.Expect(launcher.Annotations).Should(gomega.HaveKeyWithValue("ann1", "ann-value-for-launcher")) + gomega.Expect(launcher.Labels).Should(gomega.HaveKeyWithValue("label1", "label-value-for-launcher")) + gomega.Expect(launcher.Spec.NodeSelector).Should(gomega.HaveKeyWithValue(instanceKey, "test-flavor")) + gomega.Expect(launcher.Spec.NodeSelector).Should(gomega.HaveKeyWithValue("selector1", "selector-value-for-launcher")) + }) + + ginkgo.By("delete the localQueue to prevent readmission", func() { + gomega.Expect(util.DeleteLocalQueue(ctx, k8sClient, localQueue)).Should(gomega.Succeed()) + }) + + ginkgo.By("clear the workload's admission to stop the job", func() { + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, nil)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("await for the job to be suspended", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true))) + }) + + ginkgo.By("verify the PodSetUpdates are restored for worker", func() { + worker := createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Template + gomega.Expect(worker.Annotations).ShouldNot(gomega.HaveKey("ann1")) + gomega.Expect(worker.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value")) + gomega.Expect(worker.Labels).ShouldNot(gomega.HaveKey("label1")) + gomega.Expect(worker.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value")) + 
gomega.Expect(worker.Spec.NodeSelector).ShouldNot(gomega.HaveKey(instanceKey)) + gomega.Expect(worker.Spec.NodeSelector).ShouldNot(gomega.HaveKey("selector1")) + }) + + ginkgo.By("verify the PodSetUpdates are restored for launcher", func() { + launcher := createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template + gomega.Expect(launcher.Annotations).ShouldNot(gomega.HaveKey("ann1")) + gomega.Expect(launcher.Labels).ShouldNot(gomega.HaveKey("label1")) + gomega.Expect(launcher.Spec.NodeSelector).ShouldNot(gomega.HaveKey(instanceKey)) + gomega.Expect(launcher.Spec.NodeSelector).ShouldNot(gomega.HaveKey("selector1")) + }) + }) + }) }) var _ = ginkgo.Describe("Job controller for workloads when only jobs with queue are managed", ginkgo.Ordered, ginkgo.ContinueOnFailure, func() { diff --git a/test/integration/controller/jobs/pod/pod_controller_test.go b/test/integration/controller/jobs/pod/pod_controller_test.go index 7ae26e5fcd..32df544af1 100644 --- a/test/integration/controller/jobs/pod/pod_controller_test.go +++ b/test/integration/controller/jobs/pod/pod_controller_test.go @@ -27,6 +27,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/controller/jobframework" @@ -313,6 +314,126 @@ var _ = ginkgo.Describe("Pod controller", ginkgo.Ordered, ginkgo.ContinueOnFailu gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(testing.BeNotFoundError()) }) }) + + ginkgo.When("the queue has admission checks", func() { + var ( + clusterQueueAc *kueue.ClusterQueue + localQueue *kueue.LocalQueue + testFlavor *kueue.ResourceFlavor + podLookupKey *types.NamespacedName + wlLookupKey *types.NamespacedName + admissionCheck *kueue.AdmissionCheck + ) + + ginkgo.BeforeEach(func() { + admissionCheck = testing.MakeAdmissionCheck("check").Obj() + gomega.Expect(k8sClient.Create(ctx, admissionCheck)).To(gomega.Succeed()) + util.SetAdmissionCheckActive(ctx, k8sClient, admissionCheck, metav1.ConditionTrue) + clusterQueueAc = testing.MakeClusterQueue("prod-cq-with-checks"). 
+ ResourceGroup(
+ *testing.MakeFlavorQuotas("test-flavor").Resource(corev1.ResourceCPU, "5").Obj(),
+ ).AdmissionChecks("check").Obj()
+ gomega.Expect(k8sClient.Create(ctx, clusterQueueAc)).Should(gomega.Succeed())
+ localQueue = testing.MakeLocalQueue("queue", ns.Name).ClusterQueue(clusterQueueAc.Name).Obj()
+ gomega.Expect(k8sClient.Create(ctx, localQueue)).To(gomega.Succeed())
+ testFlavor = testing.MakeResourceFlavor("test-flavor").Label(instanceKey, "test-flavor").Obj()
+ gomega.Expect(k8sClient.Create(ctx, testFlavor)).Should(gomega.Succeed())
+
+ podLookupKey = &types.NamespacedName{Name: podName, Namespace: ns.Name}
+ wlLookupKey = &types.NamespacedName{Name: podcontroller.GetWorkloadNameForPod(podName), Namespace: ns.Name}
+ })
+
+ ginkgo.AfterEach(func() {
+ gomega.Expect(util.DeleteLocalQueue(ctx, k8sClient, localQueue)).Should(gomega.Succeed())
+ gomega.Expect(util.DeleteNamespace(ctx, k8sClient, ns)).To(gomega.Succeed())
+ gomega.Expect(util.DeleteAdmissionCheck(ctx, k8sClient, admissionCheck)).To(gomega.Succeed())
+ util.ExpectClusterQueueToBeDeleted(ctx, k8sClient, clusterQueueAc, true)
+ util.ExpectResourceFlavorToBeDeleted(ctx, k8sClient, testFlavor, true)
+ })
+
+ ginkgo.It("labels and annotations should be propagated from admission check to the pod", func() {
+ createdPod := &corev1.Pod{}
+ createdWorkload := &kueue.Workload{}
+
+ ginkgo.By("creating the pod with labels & annotations", func() {
+ pod := testingpod.MakePod(podName, ns.Name).
+ Queue(localQueue.Name).
+ Request(corev1.ResourceCPU, "5").
+ Annotation("old-ann-key", "old-ann-value").
+ Label("old-label-key", "old-label-value").
+ Obj()
+ gomega.Expect(k8sClient.Create(ctx, pod)).Should(gomega.Succeed())
+ })
+
+ ginkgo.By("fetch the pod and verify it is gated as the checks are not ready", func() {
+ gomega.Eventually(func(g gomega.Gomega) []corev1.PodSchedulingGate {
+ g.Expect(k8sClient.Get(ctx, *podLookupKey, createdPod)).To(gomega.Succeed())
+ return createdPod.Spec.SchedulingGates
+ }, util.Timeout, util.Interval).Should(
+ gomega.ContainElement(corev1.PodSchedulingGate{Name: "kueue.x-k8s.io/admission"}),
+ )
+ })
+
+ ginkgo.By("fetch the created workload", func() {
+ gomega.Eventually(func() error {
+ return k8sClient.Get(ctx, *wlLookupKey, createdWorkload)
+ }, util.Timeout, util.Interval).Should(gomega.Succeed())
+ })
+
+ ginkgo.By("add labels & annotations to the admission check", func() {
+ gomega.Eventually(func() error {
+ var newWL kueue.Workload
+ gomega.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(createdWorkload), &newWL)).To(gomega.Succeed())
+ workload.SetAdmissionCheckState(&newWL.Status.AdmissionChecks, kueue.AdmissionCheckState{
+ Name: "check",
+ State: kueue.CheckStateReady,
+ PodSetUpdates: []kueue.PodSetUpdate{
+ {
+ Name: "main",
+ Labels: map[string]string{
+ "label1": "label-value1",
+ },
+ Annotations: map[string]string{
+ "ann1": "ann-value1",
+ },
+ NodeSelector: map[string]string{
+ "selector1": "selector-value1",
+ },
+ },
+ },
+ })
+ return k8sClient.Status().Update(ctx, &newWL)
+ }, util.Timeout, util.Interval).Should(gomega.Succeed())
+ })
+
+ ginkgo.By("admit the workload", func() {
+ admission := testing.MakeAdmission(clusterQueueAc.Name).
+ Assignment(corev1.ResourceCPU, "test-flavor", "1").
+ AssignmentPodCount(createdWorkload.Spec.PodSets[0].Count).
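+ // A plain pod produces a single "main" PodSet, so a single assignment is enough here.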
+ Obj()
+ gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed())
+ gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, admission)).Should(gomega.Succeed())
+ util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload)
+ })
+
+ ginkgo.By("await for the pod to be admitted", func() {
+ gomega.Eventually(func(g gomega.Gomega) []corev1.PodSchedulingGate {
+ g.Expect(k8sClient.Get(ctx, *podLookupKey, createdPod)).
+ To(gomega.Succeed())
+ return createdPod.Spec.SchedulingGates
+ }, util.Timeout, util.Interval).Should(gomega.BeEmpty())
+ })
+
+ ginkgo.By("verify the PodSetUpdates are propagated to the running pod", func() {
+ gomega.Expect(createdPod.Annotations).Should(gomega.HaveKeyWithValue("ann1", "ann-value1"))
+ gomega.Expect(createdPod.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value"))
+ gomega.Expect(createdPod.Labels).Should(gomega.HaveKeyWithValue("label1", "label-value1"))
+ gomega.Expect(createdPod.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value"))
+ gomega.Expect(createdPod.Spec.NodeSelector).Should(gomega.HaveKeyWithValue(instanceKey, "test-flavor"))
+ gomega.Expect(createdPod.Spec.NodeSelector).Should(gomega.HaveKeyWithValue("selector1", "selector-value1"))
+ })
+ })
+ })
})
})
diff --git a/test/integration/controller/jobs/pytorchjob/pytorchjob_controller_test.go b/test/integration/controller/jobs/pytorchjob/pytorchjob_controller_test.go
index b1f97120cd..c6c960ffc0 100644
--- a/test/integration/controller/jobs/pytorchjob/pytorchjob_controller_test.go
+++ b/test/integration/controller/jobs/pytorchjob/pytorchjob_controller_test.go
@@ -31,6 +31,7 @@ import (
"k8s.io/client-go/kubernetes/scheme"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
+ "sigs.k8s.io/controller-runtime/pkg/client"
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/controller/constants"
@@ -38,6 +39,7 @@ import (
workloadpytorchjob "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/jobs/pytorchjob"
"sigs.k8s.io/kueue/pkg/util/testing"
testingpytorchjob "sigs.k8s.io/kueue/pkg/util/testingjobs/pytorchjob"
+ "sigs.k8s.io/kueue/pkg/workload"
"sigs.k8s.io/kueue/test/integration/framework"
"sigs.k8s.io/kueue/test/util"
)
@@ -335,6 +337,184 @@ var _ = ginkgo.Describe("Job controller for workloads when only jobs with queue
return k8sClient.Get(ctx, wlLookupKey, createdWorkload)
}, util.Timeout, util.Interval).Should(gomega.Succeed())
})
+
+ ginkgo.When("the queue has admission checks", func() {
+ var (
+ clusterQueueAc *kueue.ClusterQueue
+ localQueue *kueue.LocalQueue
+ testFlavor *kueue.ResourceFlavor
+ jobLookupKey *types.NamespacedName
+ wlLookupKey *types.NamespacedName
+ admissionCheck *kueue.AdmissionCheck
+ )
+
+ ginkgo.BeforeEach(func() {
+ admissionCheck = testing.MakeAdmissionCheck("check").Obj()
+ gomega.Expect(k8sClient.Create(ctx, admissionCheck)).To(gomega.Succeed())
+ util.SetAdmissionCheckActive(ctx, k8sClient, admissionCheck, metav1.ConditionTrue)
+ clusterQueueAc = testing.MakeClusterQueue("prod-cq-with-checks").
+ ResourceGroup(
+ *testing.MakeFlavorQuotas("test-flavor").Resource(corev1.ResourceCPU, "5").Obj(),
+ ).AdmissionChecks("check").Obj()
+ gomega.Expect(k8sClient.Create(ctx, clusterQueueAc)).Should(gomega.Succeed())
+ localQueue = testing.MakeLocalQueue("queue", ns.Name).ClusterQueue(clusterQueueAc.Name).Obj()
+ gomega.Expect(k8sClient.Create(ctx, localQueue)).To(gomega.Succeed())
+ testFlavor = testing.MakeResourceFlavor("test-flavor").Label(instanceKey, "test-flavor").Obj()
+ gomega.Expect(k8sClient.Create(ctx, testFlavor)).Should(gomega.Succeed())
+
+ jobLookupKey = &types.NamespacedName{Name: jobName, Namespace: ns.Name}
+ wlLookupKey = &types.NamespacedName{Name: workloadpytorchjob.GetWorkloadNameForPyTorchJob(jobName), Namespace: ns.Name}
+ })
+
+ ginkgo.AfterEach(func() {
+ gomega.Expect(util.DeleteAdmissionCheck(ctx, k8sClient, admissionCheck)).To(gomega.Succeed())
+ util.ExpectResourceFlavorToBeDeleted(ctx, k8sClient, testFlavor, true)
+ gomega.Expect(util.DeleteNamespace(ctx, k8sClient, ns)).To(gomega.Succeed())
+ util.ExpectClusterQueueToBeDeleted(ctx, k8sClient, clusterQueueAc, true)
+ })
+
+ ginkgo.It("labels and annotations should be propagated from admission check to job", func() {
+ createdJob := &kftraining.PyTorchJob{}
+ createdWorkload := &kueue.Workload{}
+
+ ginkgo.By("creating the job with pod labels & annotations", func() {
+ job := testingpytorchjob.MakePyTorchJob(jobName, ns.Name).
+ Queue(localQueue.Name).
+ PodAnnotation(kftraining.PyTorchJobReplicaTypeWorker, "old-ann-key", "old-ann-value").
+ PodLabel(kftraining.PyTorchJobReplicaTypeWorker, "old-label-key", "old-label-value").
+ Obj()
+ gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed())
+ })
+
+ ginkgo.By("fetch the job and verify it is suspended as the checks are not ready", func() {
+ gomega.Eventually(func() *bool {
+ gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed())
+ return createdJob.Spec.RunPolicy.Suspend
+ }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true)))
+ })
+
+ ginkgo.By("fetch the created workload", func() {
+ gomega.Eventually(func() error {
+ return k8sClient.Get(ctx, *wlLookupKey, createdWorkload)
+ }, util.Timeout, util.Interval).Should(gomega.Succeed())
+ })
+
+ ginkgo.By("add labels & annotations to the admission check", func() {
+ gomega.Eventually(func() error {
+ var newWL kueue.Workload
+ gomega.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(createdWorkload), &newWL)).To(gomega.Succeed())
+ workload.SetAdmissionCheckState(&newWL.Status.AdmissionChecks, kueue.AdmissionCheckState{
+ Name: "check",
+ State: kueue.CheckStateReady,
+ PodSetUpdates: []kueue.PodSetUpdate{
+ {
+ Name: "master",
+ },
+ {
+ Name: "worker",
+ Annotations: map[string]string{
+ "ann1": "ann-value1",
+ },
+ Labels: map[string]string{
+ "label1": "label-value1",
+ },
+ NodeSelector: map[string]string{
+ "selector1": "selector-value1",
+ },
+ Tolerations: []corev1.Toleration{
+ {
+ Key: "selector1",
+ Value: "selector-value1",
+ Operator: corev1.TolerationOpEqual,
+ Effect: corev1.TaintEffectNoSchedule,
+ },
+ },
+ },
+ },
+ })
+ return k8sClient.Status().Update(ctx, &newWL)
+ }, util.Timeout, util.Interval).Should(gomega.Succeed())
+ })
+
+ ginkgo.By("admit the workload", func() {
+ admission := testing.MakeAdmission(clusterQueueAc.Name).
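+ // The PyTorchJob Workload has a master and a worker PodSet; both receive the test-flavor assignment below.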
+ PodSets( + kueue.PodSetAssignment{ + Name: "master", + Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{ + corev1.ResourceCPU: "test-flavor", + }, + Count: ptr.To(createdWorkload.Spec.PodSets[0].Count), + }, + kueue.PodSetAssignment{ + Name: "worker", + Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{ + corev1.ResourceCPU: "test-flavor", + }, + Count: ptr.To(createdWorkload.Spec.PodSets[1].Count), + }, + ). + Obj() + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, admission)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("await for the job to start", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(false))) + }) + + ginkgo.By("verify the PodSetUpdates are propagated to the running job", func() { + worker := createdJob.Spec.PyTorchReplicaSpecs[kftraining.PyTorchJobReplicaTypeWorker].Template + gomega.Expect(worker.Annotations).Should(gomega.HaveKeyWithValue("ann1", "ann-value1")) + gomega.Expect(worker.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value")) + gomega.Expect(worker.Labels).Should(gomega.HaveKeyWithValue("label1", "label-value1")) + gomega.Expect(worker.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value")) + gomega.Expect(worker.Spec.NodeSelector).Should(gomega.HaveKeyWithValue(instanceKey, "test-flavor")) + gomega.Expect(worker.Spec.NodeSelector).Should(gomega.HaveKeyWithValue("selector1", "selector-value1")) + gomega.Expect(worker.Spec.Tolerations).Should(gomega.BeComparableTo( + []corev1.Toleration{ + { + Key: "selector1", + Value: "selector-value1", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + )) + }) + + ginkgo.By("delete the localQueue to prevent readmission", func() { + gomega.Expect(util.DeleteLocalQueue(ctx, k8sClient, localQueue)).Should(gomega.Succeed()) + }) + + ginkgo.By("clear the workload's admission to stop the job", func() { + gomega.Expect(k8sClient.Get(ctx, *wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(util.SetQuotaReservation(ctx, k8sClient, createdWorkload, nil)).Should(gomega.Succeed()) + util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload) + }) + + ginkgo.By("await for the job to be suspended", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, *jobLookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true))) + }) + + ginkgo.By("verify the PodSetUpdates are restored", func() { + worker := createdJob.Spec.PyTorchReplicaSpecs[kftraining.PyTorchJobReplicaTypeWorker].Template + gomega.Expect(worker.Annotations).ShouldNot(gomega.HaveKey("ann1")) + gomega.Expect(worker.Annotations).Should(gomega.HaveKeyWithValue("old-ann-key", "old-ann-value")) + gomega.Expect(worker.Labels).ShouldNot(gomega.HaveKey("label1")) + gomega.Expect(worker.Labels).Should(gomega.HaveKeyWithValue("old-label-key", "old-label-value")) + gomega.Expect(worker.Spec.NodeSelector).ShouldNot(gomega.HaveKey(instanceKey)) + gomega.Expect(worker.Spec.NodeSelector).ShouldNot(gomega.HaveKey("selector1")) + }) + }) + }) }) var _ = 
ginkgo.Describe("Job controller when waitForPodsReady enabled", ginkgo.Ordered, ginkgo.ContinueOnFailure, func() {