Skip to content

Commit

Permalink
add checkpoint and max slots config policy enforcements in PATCH expe…
Browse files Browse the repository at this point in the history
…riment
  • Loading branch information
amandavialva01 committed Oct 24, 2024
1 parent a550364 commit bb87e37
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 27 deletions.
21 changes: 11 additions & 10 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,8 @@ func (a *apiServer) PatchExperiment(
}
activeConfig.SetResources(resources)
}

// Only allow setting checkpoint storage if it is not specified as an invariant config.
newCheckpointStorage := req.Experiment.CheckpointStorage

if newCheckpointStorage != nil {
Expand All @@ -1226,23 +1228,21 @@ func (a *apiServer) PatchExperiment(
return nil, echo.NewHTTPError(http.StatusForbidden, err.Error())
}

// Only allow setting checkpoint storage if it is not specified as an invariant config.
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace %s",
activeConfig.Workspace()))
}

storage := activeConfig.CheckpointStorage()
storage.SetSaveExperimentBest(int(newCheckpointStorage.SaveExperimentBest))
storage.SetSaveTrialBest(int(newCheckpointStorage.SaveTrialBest))
storage.SetSaveTrialLatest(int(newCheckpointStorage.SaveTrialLatest))
activeConfig.SetCheckpointStorage(storage)

// If checkpoint storage is set at the workspace or global level in an invariant config
// or constraint
// Only allow checkpoint storage changes if it is not specified as an invariant config.
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace %s",
activeConfig.Workspace()))
}

enforcedChkptConf, err := configpolicy.GetEnforcedConfig[expconf.CheckpointStorageConfig](
ctx, &w.ID, "checkpoint_storage",
ctx, &w.ID, "invariant_config", "checkpoint_storage",
model.ExperimentType)
if err != nil {
return nil, fmt.Errorf("unable to fetch task config policies: %w", err)
Expand All @@ -1266,6 +1266,7 @@ func (a *apiServer) PatchExperiment(
if newResources.MaxSlots != nil {
msg := sproto.SetGroupMaxSlots{MaxSlots: ptrs.Ptr(int(*newResources.MaxSlots))}
e.SetGroupMaxSlots(msg)

}
if newResources.Weight != nil {
err := e.SetGroupWeight(*newResources.Weight)
Expand Down
12 changes: 8 additions & 4 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,24 @@ func DeleteConfigPolicies(ctx context.Context,

// GetEnforcedConfig gets the given field of the global policy of the given type
// ("invariant_config" or "constraints") if specified, and of the workspace policy otherwise.
func GetEnforcedConfig[T any](ctx context.Context, wkspID *int, field, workloadType string) (*T,
func GetEnforcedConfig[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T,
error,
) {
if policyType != "invariant_config" && policyType != "constraints" {
return nil, fmt.Errorf("invalid policy type :%s", policyType)
}

var confBytes []byte
var conf T
err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
globalField := tx.NewSelect().
ColumnExpr("invariant_config -> ? AS globconf", field).
ColumnExpr("? -> ? AS globconf", bun.Safe(policyType), bun.Safe(field)).
Table("task_config_policies").
Where("workspace_id IS NULL").
Where("workload_type = ?", workloadType)

wkspField := tx.NewSelect().
ColumnExpr("invariant_config -> ? AS wkspconf", field).
ColumnExpr("? -> ? AS wkspconf", bun.Safe(policyType), bun.Safe(field)).
Table("task_config_policies").
Where("workspace_id = '?'", wkspID).
Where("workload_type = ?", workloadType)
Expand All @@ -154,7 +159,6 @@ func GetEnforcedConfig[T any](ctx context.Context, wkspID *int, field, workloadT
return nil, fmt.Errorf("error getting config field %s: %w", field, err)
}

var conf T
err = json.Unmarshal(confBytes, &conf)
if err != nil {
return nil, fmt.Errorf("error unmarshaling config field: %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ func requireEqualTaskPolicy(t *testing.T, exp *model.TaskConfigPolicies, act *mo
}
}

func TestGetMergedConfigPolicies(t *testing.T) {
func TestGetEnforcedConfig(t *testing.T) {
ctx := context.Background()
require.NoError(t, etc.SetRootPath(db.RootFromDB))
pgDB, cleanup := db.MustResolveNewPostgresDatabase(t)
Expand Down Expand Up @@ -694,7 +694,7 @@ func TestGetMergedConfigPolicies(t *testing.T) {
}
`

t.Run("checkpoint storage", func(t *testing.T) {
t.Run("checkpoint storage config", func(t *testing.T) {
err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Expand All @@ -711,7 +711,7 @@ func TestGetMergedConfigPolicies(t *testing.T) {
require.NoError(t, err)

checkpointStorage, err := GetEnforcedConfig[expconf.CheckpointStorageConfig](ctx, &w.ID,
"checkpoint_storage", model.ExperimentType)
"invariant_config", "'checkpoint_storage'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, checkpointStorage)
require.Equal(t, expconf.CheckpointStorageConfigV0{
Expand All @@ -732,7 +732,7 @@ func TestGetMergedConfigPolicies(t *testing.T) {
}
`

t.Run("max slots", func(t *testing.T) {
t.Run("max slots config", func(t *testing.T) {
err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Expand All @@ -748,10 +748,84 @@ func TestGetMergedConfigPolicies(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[*int](ctx, &w.ID,
"resources"+" "+"-> "+"max_slots", model.ExperimentType)
maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)
require.Equal(t, 15, *maxSlots)
})

globalConstraints := `
{
"resources": {
"max_slots": 25
}
}
`

wkspConstraints := `
{
"resources": {
"max_slots": 20
}
}
`

t.Run("max slots constraints", func(t *testing.T) {
err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Constraints: &globalConstraints,
})
require.NoError(t, err)

err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkspaceID: &w.ID,
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Constraints: &wkspConstraints,
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)
require.Equal(t, 25, *maxSlots)
})

globalConstraints = `
{
"priority_limit": 40
}
`

wkspConstraints = `
{
"priority_limit": 50
}
`

t.Run("priority constraints", func(t *testing.T) {
err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Constraints: &globalConstraints,
})
require.NoError(t, err)

err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkspaceID: &w.ID,
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Constraints: &wkspConstraints,
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "constraints",
"'priority_limit'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)
require.Equal(t, 40, *maxSlots)
})
}
5 changes: 1 addition & 4 deletions master/internal/db/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,9 @@ func initTheOneBun(db *PgDB) {
bunMutex.Lock()
defer bunMutex.Unlock()
if theOneBun != nil {
fmt.Println("its not nil! \n \n ")
log.Warn(
"detected re-initialization of Bun that should never occur outside of tests",
)
} else {
fmt.Println("its IS nil! \n \n ")
}
theOneBun = bun.NewDB(db.sql.DB, pgdialect.New())
theOneDB = db
Expand All @@ -62,7 +59,7 @@ func initTheOneBun(db *PgDB) {
}

// This will print every query that runs.
theOneBun.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
// theOneBun.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))

// This will print only the failed queries.
theOneBun.AddQueryHook(bundebug.NewQueryHook())
Expand Down
3 changes: 0 additions & 3 deletions master/internal/db/postgres_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ func ResolveNewPostgresDatabase() (*PgDB, func(), error) {
}

dbname := fmt.Sprintf("intg-%x", randomSuffix)
fmt.Printf("This is the db name: %s \n \n \n ", dbname)
_, err = sql.Exec(fmt.Sprintf("CREATE DATABASE %q", dbname))
if err != nil {
return nil, nil, errors.Wrapf(err, "failed to create new database %q", dbname)
Expand All @@ -156,8 +155,6 @@ func ResolveNewPostgresDatabase() (*PgDB, func(), error) {
}
if _, err := sql.Exec(fmt.Sprintf("DROP DATABASE %q", dbname)); err != nil {
log.WithError(err).Errorf("failed to delete temp database %q", dbname)
} else {
fmt.Printf("we dropped database: %q \n \n \n ", dbname)
}
}

Expand Down
34 changes: 34 additions & 0 deletions master/internal/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/pkg/errors"
"github.com/shopspring/decimal"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/uptrace/bun"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -411,6 +412,39 @@ func (e *internalExperiment) PatchTrialState(msg experiment.PatchTrialState) err
func (e *internalExperiment) SetGroupMaxSlots(msg sproto.SetGroupMaxSlots) {
e.mu.Lock()
defer e.mu.Unlock()
// Max slots set by an invariant config overrides the requested value. Otherwise, the change is
// only allowed if the requested value does not exceed the max slots limit enforced by a constraint.

w, err := getWorkspaceByConfig(e.activeConfig)
if err != nil {
log.Warnf("unable to set max slots")
return
}
enforcedMaxSlots, err := configpolicy.GetEnforcedConfig[int](context.TODO(), &w.ID,
"invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
if err != nil {
log.Warnf("unable to set max slots")
return
}

if enforcedMaxSlots != nil {
msg.MaxSlots = enforcedMaxSlots
}

maxSlotsLimit, err := configpolicy.GetEnforcedConfig[int](context.TODO(), &w.ID,
"constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
if err != nil {
log.Warnf("unable to set max slots")
return
}

if enforcedMaxSlots == nil && maxSlotsLimit != nil && msg.MaxSlots != nil &&
*msg.MaxSlots > *maxSlotsLimit {
log.Warnf("unable to set max slots")
return
}

resources := e.activeConfig.Resources()
resources.SetMaxSlots(msg.MaxSlots)
Expand Down

0 comments on commit bb87e37

Please sign in to comment.