diff --git a/pkg/controller/jobframework/podsetinfo.go b/pkg/controller/jobframework/podsetinfo.go index 2755b6c9f1..3caa39136a 100644 --- a/pkg/controller/jobframework/podsetinfo.go +++ b/pkg/controller/jobframework/podsetinfo.go @@ -25,13 +25,12 @@ import ( ) type PodSetInfo struct { - Name string - Count int32 - NodeSelectorOverwrite map[string]string - Annotations map[string]string - Labels map[string]string - NodeSelector map[string]string - Tolerations []corev1.Toleration + Name string + Count int32 + Annotations map[string]string + Labels map[string]string + NodeSelector map[string]string + Tolerations []corev1.Toleration } func (podSetInfo *PodSetInfo) Merge(o PodSetInfo) error { @@ -41,18 +40,12 @@ func (podSetInfo *PodSetInfo) Merge(o PodSetInfo) error { if err := utilmaps.HaveConflict(podSetInfo.Labels, o.Labels); err != nil { return BadPodSetsUpdateError("labels", err) } - newNodeSelector := make(map[string]string) - for k, v := range o.NodeSelector { - if _, exists := podSetInfo.NodeSelectorOverwrite[k]; !exists { - newNodeSelector[k] = v - } - } - if err := utilmaps.HaveConflict(podSetInfo.NodeSelector, newNodeSelector); err != nil { + if err := utilmaps.HaveConflict(podSetInfo.NodeSelector, o.NodeSelector); err != nil { return BadPodSetsUpdateError("nodeSelector", err) } podSetInfo.Annotations = utilmaps.MergeKeepFirst(podSetInfo.Annotations, o.Annotations) podSetInfo.Labels = utilmaps.MergeKeepFirst(podSetInfo.Labels, o.Labels) - podSetInfo.NodeSelector = utilmaps.MergeKeepFirst(podSetInfo.NodeSelector, newNodeSelector) + podSetInfo.NodeSelector = utilmaps.MergeKeepFirst(podSetInfo.NodeSelector, o.NodeSelector) podSetInfo.Tolerations = append(podSetInfo.Tolerations, o.Tolerations...) return nil } @@ -71,7 +64,6 @@ func Merge(meta *metav1.ObjectMeta, spec *v1.PodSpec, info PodSetInfo) error { meta.Annotations = info.Annotations meta.Labels = info.Labels spec.NodeSelector = info.NodeSelector - spec.NodeSelector = utilmaps.MergeKeepFirst(info.NodeSelectorOverwrite, spec.NodeSelector) spec.Tolerations = info.Tolerations return nil } diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go index 9368a90802..2d6fcc050b 100644 --- a/pkg/controller/jobframework/reconciler.go +++ b/pkg/controller/jobframework/reconciler.go @@ -657,12 +657,11 @@ func (r *JobReconciler) getPodSetsInfoFromStatus(ctx context.Context, w *kueue.W for i, podSetFlavor := range w.Status.Admission.PodSetAssignments { processedFlvs := sets.NewString() podSetInfo := PodSetInfo{ - Name: podSetFlavor.Name, - NodeSelector: make(map[string]string), - NodeSelectorOverwrite: make(map[string]string), - Count: ptr.Deref(podSetFlavor.Count, w.Spec.PodSets[i].Count), - Labels: make(map[string]string), - Annotations: make(map[string]string), + Name: podSetFlavor.Name, + NodeSelector: make(map[string]string), + Count: ptr.Deref(podSetFlavor.Count, w.Spec.PodSets[i].Count), + Labels: make(map[string]string), + Annotations: make(map[string]string), } for _, flvRef := range podSetFlavor.Flavors { flvName := string(flvRef) @@ -675,7 +674,7 @@ func (r *JobReconciler) getPodSetsInfoFromStatus(ctx context.Context, w *kueue.W return nil, err } for k, v := range flv.Spec.NodeLabels { - podSetInfo.NodeSelectorOverwrite[k] = v + podSetInfo.NodeSelector[k] = v } processedFlvs.Insert(flvName) } diff --git a/pkg/controller/jobs/job/job_controller_test.go b/pkg/controller/jobs/job/job_controller_test.go index 6066d65994..240a41b3ac 100644 --- a/pkg/controller/jobs/job/job_controller_test.go +++ b/pkg/controller/jobs/job/job_controller_test.go @@ -205,14 +205,15 @@ func TestPodSetsInfo(t *testing.T) { Obj()), runInfo: []jobframework.PodSetInfo{ { - NodeSelectorOverwrite: map[string]string{ + NodeSelector: map[string]string{ "orig-key": "new-val", }, }, }, + wantRunError: jobframework.ErrInvalidPodSetUpdate, wantUnsuspended: utiltestingjob.MakeJob("job", "ns"). Parallelism(1). - NodeSelector("orig-key", "new-val"). + NodeSelector("orig-key", "orig-val"). Suspend(false). Obj(), restoreInfo: []jobframework.PodSetInfo{ diff --git a/pkg/controller/jobs/rayjob/rayjob_controller_test.go b/pkg/controller/jobs/rayjob/rayjob_controller_test.go index 4667a87a80..9ff7421854 100644 --- a/pkg/controller/jobs/rayjob/rayjob_controller_test.go +++ b/pkg/controller/jobs/rayjob/rayjob_controller_test.go @@ -168,8 +168,8 @@ func TestNodeSelectors(t *testing.T) { }, }, { - NodeSelectorOverwrite: map[string]string{ - "key-wg1": "updated-value-wg1", + NodeSelector: map[string]string{ + "key-wg1": "value-wg1", }, }, { @@ -210,7 +210,7 @@ func TestNodeSelectors(t *testing.T) { Template: corev1.PodTemplateSpec{ Spec: corev1.PodSpec{ NodeSelector: map[string]string{ - "key-wg1": "updated-value-wg1", + "key-wg1": "value-wg1", }, }, },