Skip to content

Commit

Permalink
chore: add checkpoint storage and max slots inv config override in PA…
Browse files Browse the repository at this point in the history
…TCH exp
  • Loading branch information
amandavialva01 committed Oct 24, 2024
1 parent 962810a commit a550364
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 8 deletions.
39 changes: 32 additions & 7 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1226,11 +1226,30 @@ func (a *apiServer) PatchExperiment(
return nil, echo.NewHTTPError(http.StatusForbidden, err.Error())
}

// Only allow setting checkpoint storage if it is not specified as an invariant config.
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace %s",
activeConfig.Workspace()))
}

storage := activeConfig.CheckpointStorage()
storage.SetSaveExperimentBest(int(newCheckpointStorage.SaveExperimentBest))
storage.SetSaveTrialBest(int(newCheckpointStorage.SaveTrialBest))
storage.SetSaveTrialLatest(int(newCheckpointStorage.SaveTrialLatest))
activeConfig.SetCheckpointStorage(storage)

// If checkpoint storage is set at the workspace or global level in an invariant config
// or constraint
enforcedChkptConf, err := configpolicy.GetEnforcedConfig[expconf.CheckpointStorageConfig](
ctx, &w.ID, "checkpoint_storage",
model.ExperimentType)
if err != nil {
return nil, fmt.Errorf("unable to fetch task config policies: %w", err)
}
if enforcedChkptConf != nil {
activeConfig.SetCheckpointStorage(*enforcedChkptConf)
}
}

// `patch` represents the allowed mutations that can be performed on an experiment, in JSON
Expand Down Expand Up @@ -1471,19 +1490,16 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string
}

// Determine which workspace the experiment is in.
wkspName := activeConfig.Workspace()
if wkspName == "" {
wkspName = model.DefaultWorkspaceName
}
ctx := context.TODO()
w, err := workspace.WorkspaceByName(ctx, wkspName)
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, false, status.Errorf(codes.Internal,
fmt.Sprintf("failed to get workspace %s", activeConfig.Workspace()))
}

// Merge the config with the optionally specified invariant config specified by task config
// policies.
configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs(ctx,
configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs(
context.TODO(),
w.ID, mergedConfig)
if err != nil {
return nil, false,
Expand All @@ -1499,6 +1515,15 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string
return bytes.([]byte), isSingle, nil
}

func getWorkspaceByConfig(config expconf.ExperimentConfig) (*model.Workspace, error) {
wkspName := config.Workspace()
if wkspName == "" {
wkspName = model.DefaultWorkspaceName
}
ctx := context.TODO()
return workspace.WorkspaceByName(ctx, wkspName)
}

var errContinueHPSearchCompleted = status.Error(codes.FailedPrecondition,
"experiment has been completed, cannot continue this experiment")

Expand Down
24 changes: 24 additions & 0 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2347,3 +2347,27 @@ func TestDeleteExperimentsFiltered(t *testing.T) {
}
t.Error("expected experiments to delete after 15 seconds and they did not")
}

func TestGetWorkspaceByConfig(t *testing.T) {
api, _, ctx := setupAPITest(t, nil)
resp, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{
Name: uuid.New().String(),
})
require.NoError(t, err)
wkspName := &resp.Workspace.Name

t.Run("no workspace name", func(t *testing.T) {
w, err := getWorkspaceByConfig(expconf.ExperimentConfig{RawWorkspace: ptrs.Ptr("")})
require.NoError(t, err)

// Verify we get the Uncategorized workspace.
require.Equal(t, 1, w.ID)
})
t.Run("has workspace name", func(t *testing.T) {
w, err := getWorkspaceByConfig(expconf.ExperimentConfig{
RawWorkspace: wkspName,
})
require.NoError(t, err)
require.Equal(t, *wkspName, w.Name)
})
}
49 changes: 49 additions & 0 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package configpolicy
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"

Expand Down Expand Up @@ -113,3 +114,51 @@ func DeleteConfigPolicies(ctx context.Context,
}
return nil
}

// GetEnforcedConfig gets the fields of the global invariant config if specified, and the workspace
// invariant config otherwise.
func GetEnforcedConfig[T any](ctx context.Context, wkspID *int, field, workloadType string) (*T,
error,
) {
var confBytes []byte
err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
globalField := tx.NewSelect().
ColumnExpr("invariant_config -> ? AS globconf", field).
Table("task_config_policies").
Where("workspace_id IS NULL").
Where("workload_type = ?", workloadType)

wkspField := tx.NewSelect().
ColumnExpr("invariant_config -> ? AS wkspconf", field).
Table("task_config_policies").
Where("workspace_id = '?'", wkspID).
Where("workload_type = ?", workloadType)

both := tx.NewSelect().TableExpr("global_field").
Join("NATURAL JOIN wksp_field")

err := tx.NewSelect().ColumnExpr("coalesce(globconf, wkspconf)").
With("global_field", globalField).
With("wksp_field", wkspField).
Table("both").With("both", both).
Scan(ctx, &confBytes)
if err != nil {
return err
}
return nil
})
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("error getting config field %s: %w", field, err)
}

var conf T
err = json.Unmarshal(confBytes, &conf)
if err != nil {
return nil, fmt.Errorf("error unmarshaling config field: %w", err)
}

return &conf, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/determined-ai/determined/master/pkg/etc"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -659,3 +660,98 @@ func requireEqualTaskPolicy(t *testing.T, exp *model.TaskConfigPolicies, act *mo
require.Equal(t, expJSONMap, actJSONMap)
}
}

func TestGetMergedConfigPolicies(t *testing.T) {
ctx := context.Background()
require.NoError(t, etc.SetRootPath(db.RootFromDB))
pgDB, cleanup := db.MustResolveNewPostgresDatabase(t)
defer cleanup()
db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB)

user := db.RequireMockUser(t, pgDB)

w := model.Workspace{Name: uuid.NewString(), UserID: user.ID}
_, err := db.Bun().NewInsert().Model(&w).Exec(ctx)
require.NoError(t, err)

globalConf := `
{
"checkpoint_storage": {
"type": "shared_fs",
"host_path": "global_host_path",
"container_path": "global_container_path"
}
}
`
wkspConf := `
{
"checkpoint_storage": {
"type": "shared_fs",
"host_path": "wksp_host_path",
"container_path": "wksp_container_path",
"checkpoint_path": "wksp_checkpoint_path"
}
}
`

t.Run("checkpoint storage", func(t *testing.T) {
err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
InvariantConfig: &globalConf,
})
require.NoError(t, err)

err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkspaceID: &w.ID,
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
InvariantConfig: &wkspConf,
})
require.NoError(t, err)

checkpointStorage, err := GetEnforcedConfig[expconf.CheckpointStorageConfig](ctx, &w.ID,
"checkpoint_storage", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, checkpointStorage)
require.Equal(t, expconf.CheckpointStorageConfigV0{
RawSharedFSConfig: &expconf.SharedFSConfigV0{
RawHostPath: ptrs.Ptr("global_host_path"),
RawContainerPath: ptrs.Ptr("global_container_path"),
}}, *checkpointStorage)
})

globalConf = `{
"debug": true
}`
wkspConf = `
{
"resources": {
"max_slots": 15
}
}
`

t.Run("max slots", func(t *testing.T) {
err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
InvariantConfig: &globalConf,
})
require.NoError(t, err)

err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkspaceID: &w.ID,
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
InvariantConfig: &wkspConf,
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[*int](ctx, &w.ID,
"resources"+" "+"-> "+"max_slots", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)
require.Equal(t, 15, *maxSlots)
})
}
5 changes: 4 additions & 1 deletion master/internal/db/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ func initTheOneBun(db *PgDB) {
bunMutex.Lock()
defer bunMutex.Unlock()
if theOneBun != nil {
fmt.Println("its not nil! \n \n ")
log.Warn(
"detected re-initialization of Bun that should never occur outside of tests",
)
} else {
fmt.Println("its IS nil! \n \n ")
}
theOneBun = bun.NewDB(db.sql.DB, pgdialect.New())
theOneDB = db
Expand All @@ -59,7 +62,7 @@ func initTheOneBun(db *PgDB) {
}

// This will print every query that runs.
// theOneBun.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
theOneBun.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))

// This will print only the failed queries.
theOneBun.AddQueryHook(bundebug.NewQueryHook())
Expand Down
3 changes: 3 additions & 0 deletions master/internal/db/postgres_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ func ResolveNewPostgresDatabase() (*PgDB, func(), error) {
}

dbname := fmt.Sprintf("intg-%x", randomSuffix)
fmt.Printf("This is the db name: %s \n \n \n ", dbname)
_, err = sql.Exec(fmt.Sprintf("CREATE DATABASE %q", dbname))
if err != nil {
return nil, nil, errors.Wrapf(err, "failed to create new database %q", dbname)
Expand All @@ -155,6 +156,8 @@ func ResolveNewPostgresDatabase() (*PgDB, func(), error) {
}
if _, err := sql.Exec(fmt.Sprintf("DROP DATABASE %q", dbname)); err != nil {
log.WithError(err).Errorf("failed to delete temp database %q", dbname)
} else {
fmt.Printf("we dropped database: %q \n \n \n ", dbname)
}
}

Expand Down

0 comments on commit a550364

Please sign in to comment.