Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add checkpoint and max slots config policy enforcements in PATCH experiment #10125

Merged
merged 4 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1219,7 +1219,6 @@ func (a *apiServer) PatchExperiment(
activeConfig.SetResources(resources)
}

// Only allow setting checkpoint storage if it is not specified as an invariant config.
newCheckpointStorage := req.Experiment.CheckpointStorage

if newCheckpointStorage != nil {
Expand All @@ -1241,7 +1240,7 @@ func (a *apiServer) PatchExperiment(
activeConfig.Workspace()))
}

enforcedChkptConf, err := configpolicy.GetEnforcedConfig[expconf.CheckpointStorageConfig](
enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig](
ctx, &w.ID, "invariant_config", "checkpoint_storage",
model.ExperimentType)
if err != nil {
Expand All @@ -1262,6 +1261,7 @@ func (a *apiServer) PatchExperiment(
if !ok {
return nil, api.NotFoundErrs("experiment", strconv.Itoa(int(exp.Id)), true)
}

if newResources.MaxSlots != nil {
msg := sproto.SetGroupMaxSlots{MaxSlots: ptrs.Ptr(int(*newResources.MaxSlots))}
e.SetGroupMaxSlots(msg)
Expand Down Expand Up @@ -1488,15 +1488,14 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string
fmt.Sprintf("override config must have single searcher type got '%s' instead", overrideName))
}

// Determine which workspace the experiment is in.
// Merge the config with the optionally specified invariant config specified by task config
// policies.
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, false, status.Errorf(codes.Internal,
fmt.Sprintf("failed to get workspace %s", activeConfig.Workspace()))
}

// Merge the config with the optionally specified invariant config specified by task config
// policies.
configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs(
context.TODO(),
w.ID, mergedConfig)
Expand Down
61 changes: 33 additions & 28 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import (
)

const (
wkspIDQuery = "workspace_id = ?"
wkspIDGlobalQuery = "workspace_id IS ?"
wkspIDQuery = "workspace_id = ?"
wkspIDGlobalQuery = "workspace_id IS ?"
invalidPolicyTypeErr = "invalid policy type"
// DefaultInvariantConfigStr is the default invariant config val used for tests.
DefaultInvariantConfigStr = `{
"description": "random description",
Expand Down Expand Up @@ -115,44 +116,48 @@ func DeleteConfigPolicies(ctx context.Context,
return nil
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetEnforcedConfig is fine as a name. I think something like GetConfigPolicyField is more descriptive. When I first read the function name, I thought it was only for invariant configs. It also wasn't clear that it was fetching a single field, rather than a whole or partial config.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point! Changed this

// GetEnforcedConfig gets the fields of the global invariant config or constraint if specified, and
// the workspace invariant config or constraint otherwise. If neither is specified, returns nil.
func GetEnforcedConfig[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T,
// GetConfigPolicyField fetches the field from an invariant_config or constraints policyType, in order
// of precedence. Global scope has highest precedence, then workspace. Returns nil if none is found.
// **NOTE** The field arguments are wrapped in bun.Safe, so you must specify the "raw" string
// exactly as you wish for it to be accessed in the database. For example, if you want to access
// resources.max_slots, the field argument should be "'resources' -> 'max_slots'" NOT
// "resources -> max_slots".
// **NOTE**When using this function to retrieve an object of Kind Pointer, set T as the Type of
// object that the Pointer wraps. For example, if we want an object of type *int, set T to int, so
// that when its pointer is returned, you get an object of type *int.
func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T,
error,
) {
if policyType != "invariant_config" && policyType != "constraints" {
return nil, fmt.Errorf("invalid policy type :%s", policyType)
return nil, fmt.Errorf("%s :%s", invalidPolicyTypeErr, policyType)
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also validate workloadType; I think all our other postgres functions do. There's no need to add a test case for it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like we actually don't validate workloadType in any of the postgres functions! At first this seemed odd, but then I remembered we made workload_type an enum!
I can still perform the validation if you'd like, but this function would be unique in that regard

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, don't add it. I was mistaken!

var confBytes []byte
var conf T
err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
globalField := tx.NewSelect().
ColumnExpr("? -> ? AS globconf", bun.Safe(policyType), bun.Safe(field)).
Table("task_config_policies").
var globalBytes []byte
err := tx.NewSelect().Table("task_config_policies").
ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)).
Where("workspace_id IS NULL").
Where("workload_type = ?", workloadType)

wkspField := tx.NewSelect().
ColumnExpr("? -> ? AS wkspconf", bun.Safe(policyType), bun.Safe(field)).
Table("task_config_policies").
Where("workspace_id = '?'", wkspID).
Where("workload_type = ?", workloadType)

both := tx.NewSelect().TableExpr("global_field").
Join("NATURAL JOIN wksp_field")

err := tx.NewSelect().ColumnExpr("coalesce(globconf, wkspconf)").
With("global_field", globalField).
With("wksp_field", wkspField).
Table("both").With("both", both).
Scan(ctx, &confBytes)
if err != nil {
Where("workload_type = ?", workloadType).Scan(ctx, &globalBytes)
if err == nil && len(globalBytes) > 0 {
confBytes = globalBytes
}
if err != nil && err != sql.ErrNoRows {
return err
}
return nil

var wkspBytes []byte
err = tx.NewSelect().Table("task_config_policies").
ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)).
Where("workspace_id = ?", wkspID).
Where("workload_type = ?", workloadType).Scan(ctx, &wkspBytes)
if err == nil && len(globalBytes) == 0 {
confBytes = wkspBytes
}
return err
})
if err == sql.ErrNoRows {
if err == sql.ErrNoRows || len(confBytes) == 0 {
return nil, nil
}
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,8 @@ func requireEqualTaskPolicy(t *testing.T, exp *model.TaskConfigPolicies, act *mo
func TestGetEnforcedConfig(t *testing.T) {
ctx := context.Background()
require.NoError(t, etc.SetRootPath(db.RootFromDB))
pgDB, cleanup := db.MustResolveNewPostgresDatabase(t)
defer cleanup()
pgDB, _ := db.MustResolveNewPostgresDatabase(t)
// defer cleanup()
db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB)

user := db.RequireMockUser(t, pgDB)
Expand Down Expand Up @@ -710,10 +710,12 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

checkpointStorage, err := GetEnforcedConfig[expconf.CheckpointStorageConfig](ctx, &w.ID,
checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID,
"invariant_config", "'checkpoint_storage'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, checkpointStorage)

// global config enforced?
require.Equal(t, expconf.CheckpointStorageConfigV0{
RawSharedFSConfig: &expconf.SharedFSConfigV0{
RawHostPath: ptrs.Ptr("global_host_path"),
Expand Down Expand Up @@ -749,10 +751,12 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "invariant_config",
maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)

// workspace config enforced?
require.Equal(t, 15, *maxSlots)
})

Expand Down Expand Up @@ -788,10 +792,12 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "constraints",
maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)

// global constraint enforced?
require.Equal(t, 25, *maxSlots)
})

Expand Down Expand Up @@ -823,10 +829,54 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "constraints",
priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'priority_limit'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)
require.Equal(t, 40, *maxSlots)
require.NotNil(t, priority)

// global constraint enforced?
require.Equal(t, 40, *priority)
})

t.Run("priority constraints wksp", func(t *testing.T) {
// delete global config policies
err = DeleteConfigPolicies(ctx, nil, model.ExperimentType)
require.NoError(t, err)

err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkspaceID: &w.ID,
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Constraints: &wkspConstraints,
})
require.NoError(t, err)

priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'priority_limit'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, priority)

// workspace constraint enforced?
require.Equal(t, 50, *priority)
})

t.Run("field not set in config", func(t *testing.T) {
maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config",
"'max_restarts'", model.ExperimentType)
require.NoError(t, err)
require.Nil(t, maxRestarts)
})

t.Run("nonexistent constraints field", func(t *testing.T) {
maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'max_restarts'", model.ExperimentType)
require.NoError(t, err)
require.Nil(t, maxRestarts)
})

t.Run("invalid policy type", func(t *testing.T) {
_, err := GetConfigPolicyField[int](ctx, &w.ID, "bad policy",
"'debug'", model.ExperimentType)
require.ErrorContains(t, err, invalidPolicyTypeErr)
})
}
41 changes: 41 additions & 0 deletions master/internal/configpolicy/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package configpolicy

import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"
Expand All @@ -27,6 +28,9 @@ const (
InvalidNTSCConfigPolicyErr = "invalid ntsc config policy"
// NotSupportedConfigPolicyErr is the error reported when admins attempt to set NTSC invariant config.
NotSupportedConfigPolicyErr = "not supported"
// SlotsReqTooHighErr is the error reported when the requested slots violates the max slots
// constraint.
SlotsReqTooHighErr = "requested slots is violates max slots constraint"
)

// ConfigPolicyWarning logs a warning for the configuration policy component.
Expand Down Expand Up @@ -298,3 +302,40 @@ func configPolicyOverlap(config1, config2 interface{}) {
}
}
}

// CanSetMaxSlots returns true if the slots requested don't violate a constraint. It returns the
// enforced max slots for the workspace if that's set as an invariant config, and returns the
// requested max slots otherwise. Returns an error when max slots is not set as an invariant config
// and the requested max slots violates the constriant.
func CanSetMaxSlots(slotsReq *int, wkspID int) (bool, *int, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function should return bool, int, error or *int, error. In the first return group, the bool lets the caller know if int was set or not. In the second group, a valid int is inferred from whether or not the pointer is nil.

I would simplify it further to just int, error. If there's an error, max_slots cannot be updated. If error is nil, then set max_slots to the returned int value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point, changed return type to *int, error!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh wait ok i see your point about int, error!
hmm, yes i see! ok ill change to this

Copy link
Contributor Author

@amandavialva01 amandavialva01 Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gah ok actually, I think it's easier to keep this as is since the func takes in an optional *int (so it can return that same optional *int).
So when the caller gets the func output, it can just replace its input w the func output.
Are you cool w leaving it *int, error instead of int, error?

if slotsReq == nil {
return true, slotsReq, nil
}
enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID,
"invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
if err != nil {
return false, nil, err
}

if enforcedMaxSlots != nil {
return true, enforcedMaxSlots, nil
}

maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID,
"constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
if err != nil {
return false, nil, err
}

var canSetReqSlots bool
if maxSlotsLimit == nil || *slotsReq <= *maxSlotsLimit {
canSetReqSlots = true
}
if !canSetReqSlots {
return false, nil, fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit)
}

return true, slotsReq, nil
}
Loading
Loading