Skip to content

Commit

Permalink
[jobframework] Validate original node selectors
Browse files Browse the repository at this point in the history
  • Loading branch information
trasc committed Mar 31, 2023
1 parent 3bc3889 commit 68bc891
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 1 deletion.
13 changes: 13 additions & 0 deletions pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ var (
labelsPath = field.NewPath("metadata", "labels")
parentWorkloadKeyPath = annotationsPath.Key(ParentWorkloadAnnotation)
queueNameLabelPath = labelsPath.Key(QueueLabel)

originalNodeSelectorsWorkloadKeyPath = annotationsPath.Key(OriginalNodeSelectorsAnnotation)
)

func ValidateCreateForQueueName(job GenericJob) field.ErrorList {
Expand Down Expand Up @@ -71,3 +73,14 @@ func ValidateUpdateForParentWorkload(oldJob, newJob GenericJob) field.ErrorList
}
return allErrs
}

func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.ErrorList {
var allErrs field.ErrorList
if !oldJob.IsSuspended() && !newJob.IsSuspended() {
if errList := apivalidation.ValidateImmutableField(oldJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation],
newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation], originalNodeSelectorsWorkloadKeyPath); len(errList) > 0 {
allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not suspended"))
}
}
return allErrs
}
1 change: 1 addition & 0 deletions pkg/controller/jobs/job/job_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.
func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList {
allErrs := validateCreate(newJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...)
return allErrs
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/controller/jobs/job/job_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ var (
parentWorkloadKeyPath = annotationsPath.Key(jobframework.ParentWorkloadAnnotation)
queueNameLabelPath = labelsPath.Key(jobframework.QueueLabel)
queueNameAnnotationsPath = annotationsPath.Key(jobframework.QueueAnnotation)

originalNodeSelectorsKeyPath = annotationsPath.Key(jobframework.OriginalNodeSelectorsAnnotation)
)

func TestValidateCreate(t *testing.T) {
Expand Down Expand Up @@ -156,6 +158,20 @@ func TestValidateUpdate(t *testing.T) {
field.Forbidden(parentWorkloadKeyPath, "this annotation is immutable"),
},
},
{
name: "original node selectors can be set while suspended",
oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("new selectors").Obj(),
wantErr: nil,
},
{
name: "immutable original node selectors while not suspended",
oldJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("old selectors").Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("new selectors").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not suspended"),
},
},
}

for _, tc := range testcases {
Expand Down
6 changes: 5 additions & 1 deletion pkg/controller/jobs/mpijob/mpijob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,14 @@ func validateCreate(job jobframework.GenericJob) field.ErrorList {
// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type
func (w *MPIJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) error {
oldJob := oldObj.(*kubeflow.MPIJob)
oldGenJob := &MPIJob{*oldJob}
newJob := newObj.(*kubeflow.MPIJob)
newGenJob := &MPIJob{*newJob}
log := ctrl.LoggerFrom(ctx).WithName("job-webhook")
log.Info("Validating update", "job", klog.KObj(newJob))
return jobframework.ValidateUpdateForQueueName(&MPIJob{*oldJob}, &MPIJob{*newJob}).ToAggregate()
allErrs := jobframework.ValidateUpdateForQueueName(oldGenJob, newGenJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldGenJob, newGenJob)...)
return allErrs.ToAggregate()
}

// ValidateDelete implements webhook.CustomValidator so a webhook will be registered for the type
Expand Down
57 changes: 57 additions & 0 deletions pkg/controller/jobs/mpijob/mpijob_webhook_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mpijob

import (
"context"
"testing"

kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"

testingutil "sigs.k8s.io/kueue/pkg/util/testingjobs/mpijob"
)

func TestUpdate(t *testing.T) {
testcases := map[string]struct {
oldJob *kubeflow.MPIJob
newJob *kubeflow.MPIJob
wantError bool
}{
"original node selectors can be set": {
oldJob: testingutil.MakeMPIJob("job", "ns").Suspend(true).Obj(),
newJob: testingutil.MakeMPIJob("job", "ns").Suspend(false).OriginalNodeSelectorsAnnotation("new selectors").Obj(),
wantError: false,
},
"original node selectors immutable while not suspended": {
oldJob: testingutil.MakeMPIJob("job", "ns").Suspend(false).OriginalNodeSelectorsAnnotation("old selectors").Obj(),
newJob: testingutil.MakeMPIJob("job", "ns").Suspend(false).OriginalNodeSelectorsAnnotation("new selectors").Obj(),
wantError: true,
},
}

for name, tc := range testcases {
t.Run(name, func(t *testing.T) {
wh := &MPIJobWebhook{}
result := wh.ValidateUpdate(context.Background(), tc.oldJob, tc.newJob)
if result != nil && !tc.wantError {
t.Errorf("Unexpected error: %s", result)
} else if result == nil && tc.wantError {
t.Errorf("Expecting error")
}
})
}
}
5 changes: 5 additions & 0 deletions pkg/util/testingjobs/job/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ func (j *JobWrapper) ParentWorkload(parentWorkload string) *JobWrapper {
return j
}

func (j *JobWrapper) OriginalNodeSelectorsAnnotation(content string) *JobWrapper {
j.Annotations[jobframework.OriginalNodeSelectorsAnnotation] = content
return j
}

// Toleration adds a toleration to the job.
func (j *JobWrapper) Toleration(t corev1.Toleration) *JobWrapper {
j.Spec.Template.Spec.Tolerations = append(j.Spec.Template.Spec.Tolerations, t)
Expand Down
12 changes: 12 additions & 0 deletions pkg/util/testingjobs/mpijob/wrappers_mpijob.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,15 @@ func (j *MPIJobWrapper) Parallelism(p int32) *MPIJobWrapper {
j.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Replicas = pointer.Int32(p)
return j
}

// OriginalNodeSelectorsAnnotation updates the original node selectors annotation
func (j *MPIJobWrapper) OriginalNodeSelectorsAnnotation(content string) *MPIJobWrapper {
j.Annotations[jobframework.OriginalNodeSelectorsAnnotation] = content
return j
}

// Suspend updates the suspend status of the job
func (j *MPIJobWrapper) Suspend(s bool) *MPIJobWrapper {
j.Spec.RunPolicy.Suspend = &s
return j
}

0 comments on commit 68bc891

Please sign in to comment.