Skip to content

Commit

Permalink
Merge #114342
Browse files Browse the repository at this point in the history
114342: roachprod:add --aws-use-spot r=BabuSrithar a=BabuSrithar

This change adds an option to use AWS spot VMs to power CRDB clusters. AWS Spot VMs are significantly cheaper, but can be preempted by AWS at anytime.

Related PR : #105470

Epic: none

Release note: None

Co-authored-by: babusrithar <[email protected]>
  • Loading branch information
craig[bot] and BabuSrithar committed Nov 16, 2023
2 parents aa812a6 + 7f17dc4 commit fbb3446
Showing 1 changed file with 226 additions and 49 deletions.
275 changes: 226 additions & 49 deletions pkg/roachprod/vm/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ type ProviderOpts struct {
// The request limit from aws' side can vary across regions, as well as the
// size of cluster being created.
CreateRateLimit float64
// use spot vms, spot vms are significantly cheaper, but can be preempted AWS.
// see https://aws.amazon.com/ec2/spot/ for more details.
UseSpot bool
}

// Provider implements the vm.Provider interface for AWS.
Expand Down Expand Up @@ -365,6 +368,8 @@ func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) {
" rate limit (per second) for instance creation. This is used to avoid hitting the request"+
" limits from aws, which can vary based on the region, and the size of the cluster being"+
" created. Try lowering this limit when hitting 'Request limit exceeded' errors.")
flags.BoolVar(&o.UseSpot, ProviderName+"-use-spot",
false, "use AWS Spot VMs, which are significantly cheaper, but can be preempted by AWS.")
flags.StringVar(&providerInstance.IAMProfile, ProviderName+"-iam-profile", providerInstance.IAMProfile,
"the IAM instance profile to associate with created VMs if non-empty")

Expand Down Expand Up @@ -863,58 +868,93 @@ func (p *Provider) getVolumesForInstance(
return vols, err
}

// DescribeInstancesOutput represents the output of the aws ec2 describe-instances command
type DescribeInstancesOutput struct {
Reservations []struct {
Instances []struct {
InstanceID string `json:"InstanceId"`
Architecture string
LaunchTime string
Placement struct {
AvailabilityZone string
}
PrivateDNSName string `json:"PrivateDnsName"`
PrivateIPAddress string `json:"PrivateIpAddress"`
PublicDNSName string `json:"PublicDnsName"`
PublicIPAddress string `json:"PublicIpAddress"`
State struct {
Code int
Name string
}
RootDeviceName string

BlockDeviceMappings []struct {
DeviceName string `json:"DeviceName"`
Disk struct {
AttachTime time.Time `json:"AttachTime"`
DeleteOnTermination bool `json:"DeleteOnTermination"`
Status string `json:"Status"`
VolumeID string `json:"VolumeId"`
} `json:"Ebs"`
} `json:"BlockDeviceMappings"`

Tags Tags

VpcID string `json:"VpcId"`
InstanceType string
InstanceLifecycle string `json:"InstanceLifecycle"`
SpotInstanceRequestId string `json:"SpotInstanceRequestId"`
}
}
}

// CancelSpotInstanceRequestsOutput represents the output structure of the cancel-spot-instance-requests command.
type CancelSpotInstanceRequestsOutput struct {
CancelledSpotInstanceRequests []struct {
SpotInstanceRequestId string `json:"SpotInstanceRequestId"`
State string `json:"State"`
} `json:"CancelledSpotInstanceRequests"`
}

// DescribeSpotInstanceRequestsOutput represents the output of the aws ec2 describe-spot-instance-requests command
type DescribeSpotInstanceRequestsOutput struct {
SpotInstanceRequests []struct {
SpotInstanceRequestId string `json:"SpotInstanceRequestId"`
InstanceId string `json:"InstanceId"`
State string `json:"State"`
Status struct {
Code string `json:"Code"`
Message string `json:"Message"`
UpdateTime string `json:"UpdateTime"`
} `json:"Status"`
} `json:"SpotInstanceRequests"`
}

// RunInstancesOutput represents the output of the aws ec2 run-instances command
type RunInstancesOutput struct {
Instances []struct {
InstanceID string `json:"InstanceId"`
}
}

// listRegion extracts the roachprod-managed instances in the
// given region.
func (p *Provider) listRegion(
l *logger.Logger, region string, opts ProviderOpts, listOpt vm.ListOptions,
) (vm.List, error) {
var data struct {
Reservations []struct {
Instances []struct {
InstanceID string `json:"InstanceId"`
Architecture string
LaunchTime string
Placement struct {
AvailabilityZone string
}
PrivateDNSName string `json:"PrivateDnsName"`
PrivateIPAddress string `json:"PrivateIpAddress"`
PublicDNSName string `json:"PublicDnsName"`
PublicIPAddress string `json:"PublicIpAddress"`
State struct {
Code int
Name string
}
RootDeviceName string

BlockDeviceMappings []struct {
DeviceName string `json:"DeviceName"`
Disk struct {
AttachTime time.Time `json:"AttachTime"`
DeleteOnTermination bool `json:"DeleteOnTermination"`
Status string `json:"Status"`
VolumeID string `json:"VolumeId"`
} `json:"Ebs"`
} `json:"BlockDeviceMappings"`

Tags Tags

VpcID string `json:"VpcId"`
InstanceType string
}
}
}

args := []string{
"ec2", "describe-instances",
"--region", region,
}
err := p.runJSONCommand(l, args, &data)
var describeInstancesResponse DescribeInstancesOutput
err := p.runJSONCommand(l, args, &describeInstancesResponse)
if err != nil {
return nil, err
}

var ret vm.List
for _, res := range data.Reservations {
for _, res := range describeInstancesResponse.Reservations {
in:
for _, in := range res.Instances {
// Ignore any instances that are not pending or running
Expand Down Expand Up @@ -993,6 +1033,7 @@ func (p *Provider) listRegion(
CPUArch: vm.ParseArch(in.Architecture),
Zone: in.Placement.AvailabilityZone,
NonBootAttachedVolumes: nonBootableVolumes,
Preemptible: in.InstanceLifecycle == "spot",
}
ret = append(ret, m)
}
Expand Down Expand Up @@ -1063,16 +1104,6 @@ func (p *Provider) runInstance(
vmTagSpecs := fmt.Sprintf("ResourceType=instance,Tags=[%s]", labels)
volumeTagSpecs := fmt.Sprintf("ResourceType=volume,Tags=[%s]", labels)

var data struct {
Instances []struct {
InstanceID string `json:"InstanceId"`
}
}
_ = data.Instances // silence unused warning
if len(data.Instances) > 0 {
_ = data.Instances[0].InstanceID // silence unused warning
}

// Create AWS startup script file.
extraMountOpts := ""
// Dynamic args.
Expand Down Expand Up @@ -1140,7 +1171,153 @@ func (p *Provider) runInstance(
if err != nil {
return err
}
return p.runJSONCommand(l, args, &data)

if providerOpts.UseSpot {
return runSpotInstance(l, p, args, az.region.Name)
//todo(babusrithar): Add fallback to on-demand instances if spot instances are not available.
}
runInstancesOutput := RunInstancesOutput{}
return p.runJSONCommand(l, args, &runInstancesOutput)
}

// runSpotInstance uses run-instances command to create a spot instance.
// It returns an error if the spot request is not fulfilled within 2 minutes.
// It uses describe-spot-instance-requests command to get the status of the spot request.
func runSpotInstance(l *logger.Logger, p *Provider, args []string, regionName string) error {
waitForSpotDuration := 2 * time.Minute

// Add spot instance options to the run-instances command.
spotArgs := append(args, "--instance-market-options",
fmt.Sprintf("MarketType=spot,SpotOptions={SpotInstanceType=one-time,"+
"InstanceInterruptionBehavior=terminate}"))
runInstancesOutput := RunInstancesOutput{}
err := p.runJSONCommand(l, spotArgs, &runInstancesOutput)
if err != nil {
return err
}
// If the spot request is accepted, the run-instances command will return an instance-id.
if len(runInstancesOutput.Instances) == 0 {
return errors.Errorf("No instances found for spot request, likely the spot request had bad parameter")
}
instanceId := runInstancesOutput.Instances[0].InstanceID
spotInstanceRequestId, err := getSpotInstanceRequestId(l, p, regionName, instanceId)
if err != nil {
return err
}

// Loop every 10 seconds till the spot instance is fulfilled, for a maximum of 2 minutes.
startTime := timeutil.Now()
duration := waitForSpotDuration
for {
describeSpotInstanceRequestsOutput, err := describeSpotInstanceRequest(l, p, regionName, spotInstanceRequestId)
if err != nil {
return err
}
spotRequestFulfilled, err := processSpotInstanceRequestStatus(l, describeSpotInstanceRequestsOutput, spotInstanceRequestId, instanceId)
if err != nil {
return err
}
if spotRequestFulfilled {
return nil
}
// This part of the code depends on demand/supply of AWS and can be hard to test.
// One way to manually test is tested by commenting out return nil above and check cancellation after 2 minutes.
if timeutil.Since(startTime) >= duration {
l.Printf("waitForSpotDuration passed, cancel the spot instance request and exit loop")
err := cancelSpotRequest(l, p, regionName, spotInstanceRequestId)
if err != nil {
return err
}
return errors.New("waitForSpotDuration over")
}
l.Printf("Sleeping for 10 seconds before checking the status of the spot instance request again")
time.Sleep(10 * time.Second)
}
}

func cancelSpotRequest(
l *logger.Logger, p *Provider, regionName string, spotInstanceRequestId string,
) error {
// Cancel the spot instance request.
csrArgs := []string{
"ec2", "cancel-spot-instance-requests",
"--region", regionName,
"--spot-instance-request-ids", spotInstanceRequestId,
}
err := p.runJSONCommand(l, csrArgs, &CancelSpotInstanceRequestsOutput{})
if err != nil {
// This code path is not expected to be hit, but if it does, we should return the error, so that roachprod
// can destroy the cluster being created.
return err
}
return nil
}

func describeSpotInstanceRequest(
l *logger.Logger, p *Provider, regionName string, spotInstanceRequestId string,
) (DescribeSpotInstanceRequestsOutput, error) {
// Use describe-spot-instance-requests to get the status of the spot request.
dsirArgs := []string{
"ec2", "describe-spot-instance-requests",
"--region", regionName,
"--spot-instance-request-ids", spotInstanceRequestId,
}
var describeSpotInstanceRequestsOutput DescribeSpotInstanceRequestsOutput
err := p.runJSONCommand(l, dsirArgs, &describeSpotInstanceRequestsOutput)
if err != nil {
return DescribeSpotInstanceRequestsOutput{}, err
}
return describeSpotInstanceRequestsOutput, nil
}

func processSpotInstanceRequestStatus(
l *logger.Logger,
describeSpotInstanceRequestsOutput DescribeSpotInstanceRequestsOutput,
spotInstanceRequestId string,
instanceId string,
) (fullFilled bool, err error) {
if len(describeSpotInstanceRequestsOutput.SpotInstanceRequests) == 0 {
return false, errors.Errorf("No Spot Instance Request found for instance-id: %s", instanceId)
}
requestState := describeSpotInstanceRequestsOutput.SpotInstanceRequests[0].State
requestStatusCode := describeSpotInstanceRequestsOutput.SpotInstanceRequests[0].Status.Code
if requestState == "closed" || requestState == "cancelled" || requestState == "failed" {
return false, errors.Errorf("Spot request %s for instance %s not active with state: %s",
spotInstanceRequestId, instanceId, requestState)
}
if requestStatusCode == "fulfilled" {
l.Printf("Spot request %s for instance %s fulfilled.", spotInstanceRequestId, instanceId)
return true, nil
} else {
// Spot instance request is not fulfilled yet, but active, continue looping.
l.Printf("Spot request %s for instance %s not fulfilled yet, status.code: %s., state: %s",
spotInstanceRequestId, instanceId, requestStatusCode, requestState)
}
return false, nil
}

func getSpotInstanceRequestId(
l *logger.Logger, p *Provider, regionName string, instanceId string,
) (string, error) {
diArgs := []string{
"ec2", "describe-instances",
"--region", regionName,
"--instance-ids", instanceId,
}
var describeInstancesResponse DescribeInstancesOutput
err := p.runJSONCommand(l, diArgs, &describeInstancesResponse)
if err != nil {
return "", err
}

// Sanity check to make sure that the instance-id is valid.
if len(describeInstancesResponse.Reservations) < 1 ||
len(describeInstancesResponse.Reservations[0].Instances) < 1 ||
describeInstancesResponse.Reservations[0].Instances[0].SpotInstanceRequestId == "" {
return "", errors.Errorf("No SpotInstanceRequestId found for instance-id: %s", instanceId)
}
spotInstanceRequestId := describeInstancesResponse.Reservations[0].Instances[0].SpotInstanceRequestId
return spotInstanceRequestId, nil
}

func genDeviceMapping(ebsVolumes ebsVolumeList, args []string) ([]string, error) {
Expand Down

0 comments on commit fbb3446

Please sign in to comment.