Skip to content

Commit

Permalink
Record original node selectors (#660)
Browse files Browse the repository at this point in the history
* [jobframework] Improve node selectors tracking

Save and try to restore the original node selectors in/from a
job annotation "kueue.x-k8s.io/original-selectors".

* [test/integration/jobs] Check selectors restoration on workload deletion

* [jobframework] Validate original node selectors

* [jobframework] Record podSet name in originalNodeSelectors annotation
  • Loading branch information
trasc authored Apr 5, 2023
1 parent 5227f60 commit c25bbc0
Show file tree
Hide file tree
Showing 14 changed files with 412 additions and 25 deletions.
6 changes: 6 additions & 0 deletions pkg/controller/jobframework/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,10 @@ const (
// ignores this Job from admission, and takes control of its suspension
// status based on the admission status of the parent workload.
ParentWorkloadAnnotation = "kueue.x-k8s.io/parent-workload"

// OriginalNodeSelectorsAnnotation is the annotation in which the original
// node selectors are recorded upon a workload admission. This information
// will be used to restore them when the job is suspended.
// The content is a json marshaled slice of selectors.
OriginalNodeSelectorsAnnotation = "kueue.x-k8s.io/original-node-selectors"
)
4 changes: 2 additions & 2 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ type GenericJob interface {
// If true, status is modified, if not, status is as it was.
ResetStatus() bool
// RunWithNodeAffinity will inject the node affinity extracting from workload to job and unsuspend the job.
RunWithNodeAffinity(nodeSelectors []map[string]string)
RunWithNodeAffinity(nodeSelectors []PodSetNodeSelector)
// RestoreNodeAffinity will restore the original node affinity of job.
RestoreNodeAffinity(podSets []kueue.PodSet)
RestoreNodeAffinity(nodeSelectors []PodSetNodeSelector)
// Finished means whether the job is completed/failed or not,
// condition represents the workload finished condition.
Finished() (condition metav1.Condition, finished bool)
Expand Down
98 changes: 90 additions & 8 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package jobframework

import (
"context"
"encoding/json"
"fmt"

corev1 "k8s.io/api/core/v1"
Expand All @@ -34,6 +35,10 @@ import (
"sigs.k8s.io/kueue/pkg/workload"
)

var (
errNodeSelectorsNotFound = fmt.Errorf("annotation %s not found", OriginalNodeSelectorsAnnotation)
)

// JobReconciler reconciles a GenericJob object
type JobReconciler struct {
client client.Client
Expand Down Expand Up @@ -300,7 +305,13 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec

// startJob will unsuspend the job, and also inject the node affinity.
func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error {
nodeSelectors, err := r.getNodeSelectors(ctx, wl)
//get the original selectors and store them in the job object
originalSelectors := r.getNodeSelectorsFromPodSets(wl)
if err := setNodeSelectorsInAnnotation(object, originalSelectors); err != nil {
return fmt.Errorf("startJob, record original node selectors: %w", err)
}

nodeSelectors, err := r.getNodeSelectorsFromAdmission(ctx, wl)
if err != nil {
return err
}
Expand All @@ -318,6 +329,7 @@ func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object cli

// stopJob will suspend the job, and also restore node affinity, reset job status if needed.
func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload, eventMsg string) error {
log := ctrl.LoggerFrom(ctx)
// Suspend the job at first then we're able to update the scheduling directives.
job.Suspend()

Expand All @@ -333,8 +345,12 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object clie
}
}

if wl != nil {
job.RestoreNodeAffinity(wl.Spec.PodSets)
log.V(3).Info("restore node selectors from annotation")
selectors, err := getNodeSelectorsFromObjectAnnotation(object)
if err != nil {
log.V(3).Error(err, "Unable to get original node selectors")
} else {
job.RestoreNodeAffinity(selectors)
return r.client.Update(ctx, object)
}

Expand Down Expand Up @@ -369,17 +385,25 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o
return wl, nil
}

// getNodeSelectors will extract node selectors from admitted workloads.
func (r *JobReconciler) getNodeSelectors(ctx context.Context, w *kueue.Workload) ([]map[string]string, error) {
type PodSetNodeSelector struct {
Name string `json:"name"`
NodeSelector map[string]string `json:"nodeSelector"`
}

// getNodeSelectorsFromAdmission will extract node selectors from admitted workloads.
func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetNodeSelector, error) {
if len(w.Status.Admission.PodSetAssignments) == 0 {
return nil, nil
}

nodeSelectors := make([]map[string]string, len(w.Status.Admission.PodSetAssignments))
nodeSelectors := make([]PodSetNodeSelector, len(w.Status.Admission.PodSetAssignments))

for i, podSetFlavor := range w.Status.Admission.PodSetAssignments {
processedFlvs := sets.NewString()
nodeSelector := map[string]string{}
nodeSelector := PodSetNodeSelector{
Name: podSetFlavor.Name,
NodeSelector: make(map[string]string),
}
for _, flvRef := range podSetFlavor.Flavors {
flvName := string(flvRef)
if processedFlvs.Has(flvName) {
Expand All @@ -391,7 +415,7 @@ func (r *JobReconciler) getNodeSelectors(ctx context.Context, w *kueue.Workload)
return nil, err
}
for k, v := range flv.Spec.NodeLabels {
nodeSelector[k] = v
nodeSelector.NodeSelector[k] = v
}
processedFlvs.Insert(flvName)
}
Expand All @@ -401,6 +425,23 @@ func (r *JobReconciler) getNodeSelectors(ctx context.Context, w *kueue.Workload)
return nodeSelectors, nil
}

// getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets.
func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []PodSetNodeSelector {
podSets := w.Spec.PodSets
if len(podSets) == 0 {
return nil
}
ret := make([]PodSetNodeSelector, len(podSets))
for psi := range podSets {
ps := &podSets[psi]
ret[psi] = PodSetNodeSelector{
Name: ps.Name,
NodeSelector: cloneNodeSelector(ps.Template.Spec.NodeSelector),
}
}
return ret
}

func (r *JobReconciler) handleJobWithNoWorkload(ctx context.Context, job GenericJob, object client.Object) error {
log := ctrl.LoggerFrom(ctx)

Expand Down Expand Up @@ -442,3 +483,44 @@ func generatePodsReadyCondition(job GenericJob, wl *kueue.Workload) metav1.Condi
Message: message,
}
}

func cloneNodeSelector(src map[string]string) map[string]string {
ret := make(map[string]string, len(src))
for k, v := range src {
ret[k] = v
}
return ret
}

// getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the
// object's annotations fails if it's not found or is unable to unmarshal
func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]PodSetNodeSelector, error) {
str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation]
if !found {
return nil, errNodeSelectorsNotFound
}
// unmarshal
ret := []PodSetNodeSelector{}
if err := json.Unmarshal([]byte(str), &ret); err != nil {
return nil, err
}
return ret, nil
}

// setNodeSelectorsInAnnotation - sets an annotation containing the provided node selectors into
// a job object, even if very unlikely it could return an error related to json.marshaling
func setNodeSelectorsInAnnotation(obj client.Object, nodeSelectors []PodSetNodeSelector) error {
nodeSelectorsBytes, err := json.Marshal(nodeSelectors)
if err != nil {
return err
}

annotations := obj.GetAnnotations()
if annotations == nil {
annotations = map[string]string{OriginalNodeSelectorsAnnotation: string(nodeSelectorsBytes)}
} else {
annotations[OriginalNodeSelectorsAnnotation] = string(nodeSelectorsBytes)
}
obj.SetAnnotations(annotations)
return nil
}
19 changes: 19 additions & 0 deletions pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
package jobframework

import (
"encoding/json"
"strings"

apivalidation "k8s.io/apimachinery/pkg/api/validation"
Expand All @@ -26,6 +27,8 @@ var (
labelsPath = field.NewPath("metadata", "labels")
parentWorkloadKeyPath = annotationsPath.Key(ParentWorkloadAnnotation)
queueNameLabelPath = labelsPath.Key(QueueLabel)

originalNodeSelectorsWorkloadKeyPath = annotationsPath.Key(OriginalNodeSelectorsAnnotation)
)

func ValidateCreateForQueueName(job GenericJob) field.ErrorList {
Expand Down Expand Up @@ -71,3 +74,19 @@ func ValidateUpdateForParentWorkload(oldJob, newJob GenericJob) field.ErrorList
}
return allErrs
}

func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.ErrorList {
var allErrs field.ErrorList
if oldJob.IsSuspended() == newJob.IsSuspended() {
if errList := apivalidation.ValidateImmutableField(oldJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation],
newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation], originalNodeSelectorsWorkloadKeyPath); len(errList) > 0 {
allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state"))
}
} else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found {
out := []PodSetNodeSelector{}
if err := json.Unmarshal([]byte(av), &out); err != nil {
allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error()))
}
}
return allErrs
}
12 changes: 6 additions & 6 deletions pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,29 +152,29 @@ func (j *Job) PodSets() []kueue.PodSet {
}
}

func (j *Job) RunWithNodeAffinity(nodeSelectors []map[string]string) {
func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
j.Spec.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
}

if j.Spec.Template.Spec.NodeSelector == nil {
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0]
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0].NodeSelector
} else {
for k, v := range nodeSelectors[0] {
for k, v := range nodeSelectors[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
}

func (j *Job) RestoreNodeAffinity(podSets []kueue.PodSet) {
if len(podSets) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, podSets[0].Template.Spec.NodeSelector) {
func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) {
return
}

j.Spec.Template.Spec.NodeSelector = map[string]string{}

for k, v := range podSets[0].Template.Spec.NodeSelector {
for k, v := range nodeSelectors[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
Expand Down
1 change: 1 addition & 0 deletions pkg/controller/jobs/job/job_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.
func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList {
allErrs := validateCreate(newJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...)
return allErrs
}
Expand Down
41 changes: 41 additions & 0 deletions pkg/controller/jobs/job/job_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ var (
parentWorkloadKeyPath = annotationsPath.Key(jobframework.ParentWorkloadAnnotation)
queueNameLabelPath = labelsPath.Key(jobframework.QueueLabel)
queueNameAnnotationsPath = annotationsPath.Key(jobframework.QueueAnnotation)

originalNodeSelectorsKeyPath = annotationsPath.Key(jobframework.OriginalNodeSelectorsAnnotation)
)

func TestValidateCreate(t *testing.T) {
Expand Down Expand Up @@ -93,6 +95,17 @@ func TestValidateCreate(t *testing.T) {
}

func TestValidateUpdate(t *testing.T) {

validPodSelectors := `
[
{
"name": "podSetName",
"nodeSelector": {
"l1": "v1"
}
}
]
`
testcases := []struct {
name string
oldJob *batchv1.Job
Expand Down Expand Up @@ -156,6 +169,34 @@ func TestValidateUpdate(t *testing.T) {
field.Forbidden(parentWorkloadKeyPath, "this annotation is immutable"),
},
},
{
name: "original node selectors can be set while unsuspending",
oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
wantErr: nil,
},
{
name: "original node selectors can be set while suspending",
oldJob: testingutil.MakeJob("job", "default").Suspend(false).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
wantErr: nil,
},
{
name: "immutable original node selectors while not suspended",
oldJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
},
},
{
name: "immutable original node selectors while suspended",
oldJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
},
},
}

for _, tc := range testcases {
Expand Down
17 changes: 9 additions & 8 deletions pkg/controller/jobs/mpijob/mpijob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,35 +106,36 @@ func (j *MPIJob) PodSets() []kueue.PodSet {
return podSets
}

func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []map[string]string) {
func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
j.Spec.RunPolicy.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
}
// The node selectors are provided in the same order as the generated list of
// podSets, use the same ordering logic to restore them.
orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
for index := range nodeSelectors {
replicaType := orderedReplicaTypes[index]
nodeSelector := nodeSelectors[index]
if len(nodeSelector) != 0 {
if len(nodeSelector.NodeSelector) != 0 {
if j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector == nil {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector.NodeSelector
} else {
for k, v := range nodeSelector {
for k, v := range nodeSelector.NodeSelector {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v
}
}
}
}
}

func (j *MPIJob) RestoreNodeAffinity(podSets []kueue.PodSet) {
func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
for index := range podSets {
for index, nodeSelector := range nodeSelectors {
replicaType := orderedReplicaTypes[index]
nodeSelector := podSets[index].Template.Spec.NodeSelector
if !equality.Semantic.DeepEqual(j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, nodeSelector) {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = map[string]string{}
for k, v := range nodeSelector {
for k, v := range nodeSelector.NodeSelector {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v
}
}
Expand Down
6 changes: 5 additions & 1 deletion pkg/controller/jobs/mpijob/mpijob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ func validateCreate(job jobframework.GenericJob) field.ErrorList {
// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type
func (w *MPIJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) error {
oldJob := oldObj.(*kubeflow.MPIJob)
oldGenJob := &MPIJob{*oldJob}
newJob := newObj.(*kubeflow.MPIJob)
newGenJob := &MPIJob{*newJob}
log := ctrl.LoggerFrom(ctx).WithName("job-webhook")
log.Info("Validating update", "job", klog.KObj(newJob))
return jobframework.ValidateUpdateForQueueName(&MPIJob{*oldJob}, &MPIJob{*newJob}).ToAggregate()
allErrs := jobframework.ValidateUpdateForQueueName(oldGenJob, newGenJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldGenJob, newGenJob)...)
return allErrs.ToAggregate()
}

// ValidateDelete implements webhook.CustomValidator so a webhook will be registered for the type
Expand Down
Loading

0 comments on commit c25bbc0

Please sign in to comment.