Skip to content

Commit

Permalink
Merge pull request #755 from jGauravGupta/FISH-7866-P5
Browse files Browse the repository at this point in the history
FISH-7866 AWS SDK Security Token Service (STS) support (Payara5)
  • Loading branch information
jGauravGupta authored Feb 21, 2024
2 parents 72bc190 + 3426bef commit ee5d620
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 22 deletions.
2 changes: 1 addition & 1 deletion AmazonSQS/AmazonSQSExample/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sqs</artifactId>
<version>1.12.286</version>
<version>1.12.661</version>
<type>jar</type>
<scope>provided</scope>
</dependency>
Expand Down
10 changes: 8 additions & 2 deletions AmazonSQS/AmazonSQSJCAAPI/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-bom</artifactId>
<version>1.12.286</version>
<version>1.12.661</version>
<type>pom</type>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sqs</artifactId>
<version>1.12.286</version>
<version>1.12.661</version>
<type>jar</type>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sts</artifactId>
<version>1.12.661</version>
<type>jar</type>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
*
* Copyright (c) 2017 Payara Foundation and/or its affiliates. All rights reserved.
* Copyright (c) 2017-2024 Payara Foundation and/or its affiliates. All rights reserved.
*
* The contents of this file are subject to the terms of either the GNU
* General Public License Version 2 only ("GPL") or the Common Development
Expand Down Expand Up @@ -43,15 +43,18 @@
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.auth.profile.ProfileCredentialsProvider;
import com.amazonaws.regions.Regions;
import com.amazonaws.util.StringUtils;
import fish.payara.cloud.connectors.amazonsqs.api.AmazonSQSListener;
import fish.payara.cloud.connectors.amazonsqs.api.outbound.STSCredentialsProvider;

import javax.resource.ResourceException;
import javax.resource.spi.Activation;
import javax.resource.spi.ActivationSpec;
import javax.resource.spi.InvalidPropertyException;
import javax.resource.spi.ResourceAdapter;


/**
* Activation Specification for Amazon SQS
*
Expand All @@ -73,7 +76,9 @@ public class AmazonSQSActivationSpec implements ActivationSpec, AWSCredentialsPr
private String messageAttributeNames = "All";
private String attributeNames = "All";
private String profileName;

private String roleArn;
private String roleSessionName;

@Override
public void validate() throws InvalidPropertyException {
if (StringUtils.isNullOrEmpty(region)) {
Expand Down Expand Up @@ -182,20 +187,38 @@ public String getProfileName() {
public void setProfileName(String profileName) {
this.profileName = profileName;
}

public String getRoleArn() {
return roleArn;
}

public void setRoleArn(String roleArn) {
this.roleArn = roleArn;
}

public String getRoleSessionName() {
return roleSessionName;
}

public void setRoleSessionName(String roleSessionName) {
this.roleSessionName = roleSessionName;
}

@Override
public AWSCredentials getCredentials() {

// Return Credentials based on what is present, profileName taking priority.
if (StringUtils.isNullOrEmpty(getProfileName())) {

// Return Credentials based on what is present, roleArn taking priority.
if (!StringUtils.isNullOrEmpty(getRoleArn())) {
return STSCredentialsProvider.create(getRoleArn(), getRoleSessionName(), Regions.fromName(getRegion())).getCredentials();
} else if (StringUtils.isNullOrEmpty(getProfileName())) {

if (!StringUtils.isNullOrEmpty(awsAccessKeyId) && !StringUtils.isNullOrEmpty(awsSecretKey)) {
return new AWSCredentials() {
@Override
public String getAWSAccessKeyId() {
return awsAccessKeyId;
}

@Override
public String getAWSSecretKey() {
return awsSecretKey;
Expand All @@ -204,7 +227,7 @@ public String getAWSSecretKey() {
} else {
return DefaultAWSCredentialsProviderChain.getInstance().getCredentials();
}

} else {
return new ProfileCredentialsProvider(getProfileName()).getCredentials();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
*
* Copyright (c) 2017 Payara Foundation and/or its affiliates. All rights reserved.
* Copyright (c) 2017-2024 Payara Foundation and/or its affiliates. All rights reserved.
*
* The contents of this file are subject to the terms of either the GNU
* General Public License Version 2 only ("GPL") or the Common Development
Expand Down Expand Up @@ -44,6 +44,7 @@
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.auth.profile.ProfileCredentialsProvider;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
Expand Down Expand Up @@ -198,10 +199,12 @@ public SendMessageBatchResult sendMessageBatch(SendMessageBatchRequest batch) {
public void close() throws Exception {
destroy();
}
private AWSCredentialsProvider getCredentials(AmazonSQSManagedConnectionFactory aThis) {

private AWSCredentialsProvider getCredentials(AmazonSQSManagedConnectionFactory aThis) {
AWSCredentialsProvider credentialsProvider;
if (!StringUtils.isNullOrEmpty(aThis.getProfileName())) {
if (!StringUtils.isNullOrEmpty(aThis.getRoleArn())) {
credentialsProvider = STSCredentialsProvider.create(aThis.getRoleArn(), aThis.getRoleSessionName(), Regions.fromName(aThis.getRegion()));
} else if (!StringUtils.isNullOrEmpty(aThis.getProfileName())) {
credentialsProvider = new ProfileCredentialsProvider(aThis.getProfileName());
} else if (!StringUtils.isNullOrEmpty(aThis.getAwsAccessKeyId()) && !StringUtils.isNullOrEmpty(aThis.getAwsSecretKey()) ) {
credentialsProvider = new AWSStaticCredentialsProvider(new AWSCredentials() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
*
* Copyright (c) 2017 Payara Foundation and/or its affiliates. All rights reserved.
* Copyright (c) 2017-2024 Payara Foundation and/or its affiliates. All rights reserved.
*
* The contents of this file are subject to the terms of either the GNU
* General Public License Version 2 only ("GPL") or the Common Development
Expand Down Expand Up @@ -76,6 +76,12 @@ public class AmazonSQSManagedConnectionFactory implements ManagedConnectionFacto
@ConfigProperty(description = "AWS Profile Name", type = String.class)
private String profileName;

@ConfigProperty(description = "AWS Role ARN", type = String.class)
private String roleArn;

@ConfigProperty(description = "AWS Session name", type = String.class)
private String roleSessionName;

private PrintWriter logger;

public String getAwsSecretKey() {
Expand Down Expand Up @@ -110,10 +116,21 @@ public void setProfileName(String profileName) {
this.profileName = profileName;
}

public AmazonSQSManagedConnectionFactory() {
public String getRoleArn() {
return roleArn;
}

public void setRoleArn(String roleArn) {
this.roleArn = roleArn;
}

public String getRoleSessionName() {
return roleSessionName;
}

public void setRoleSessionName(String roleSessionName) {
this.roleSessionName = roleSessionName;
}

@Override
public Object createConnectionFactory(ConnectionManager cxManager) throws ResourceException {
Expand Down Expand Up @@ -148,11 +165,13 @@ public PrintWriter getLogWriter() throws ResourceException {

@Override
public int hashCode() {
int hash = 5;
int hash = 7;
hash = 97 * hash + Objects.hashCode(this.awsSecretKey);
hash = 97 * hash + Objects.hashCode(this.awsAccessKeyId);
hash = 97 * hash + Objects.hashCode(this.region);
hash = 97 * hash + Objects.hashCode(this.profileName);
hash = 97 * hash + Objects.hashCode(this.roleArn);
hash = 97 * hash + Objects.hashCode(this.roleSessionName);
return hash;
}

Expand Down Expand Up @@ -180,7 +199,10 @@ public boolean equals(Object obj) {
if (!Objects.equals(this.profileName, other.profileName)) {
return false;
}
return true;
if (!Objects.equals(this.roleArn, other.roleArn)) {
return false;
}
return Objects.equals(this.roleSessionName, other.roleSessionName);
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
*
* Copyright (c) 2024 Payara Foundation and/or its affiliates. All rights reserved.
*
* The contents of this file are subject to the terms of either the GNU
* General Public License Version 2 only ("GPL") or the Common Development
* and Distribution License("CDDL") (collectively, the "License"). You
* may not use this file except in compliance with the License. You can
* obtain a copy of the License at
* https://github.com/payara/Payara/blob/master/LICENSE.txt
* See the License for the specific
* language governing permissions and limitations under the License.
*
* When distributing the software, include this License Header Notice in each
* file and include the License file at glassfish/legal/LICENSE.txt.
*
* GPL Classpath Exception:
* The Payara Foundation designates this particular file as subject to the "Classpath"
* exception as provided by the Payara Foundation in the GPL Version 2 section of the License
* file that accompanied this code.
*
* Modifications:
* If applicable, add the following below the License Header, with the fields
* enclosed by brackets [] replaced by your own identifying information:
* "Portions Copyright [year] [name of copyright owner]"
*
* Contributor(s):
* If you wish your version of this file to be governed by only the CDDL or
* only the GPL Version 2, indicate your decision by adding "[Contributor]
* elects to include this software in this distribution under the [CDDL or GPL
* Version 2] license." If you don't indicate a single choice of license, a
* recipient has the option to distribute your version of this file under
* either the CDDL, the GPL Version 2 or to extend the choice of license to
* its licensees as provided above. However, if you add GPL Version 2 code
* and therefore, elected the GPL Version 2 license, then the option applies
* only if the new code is made subject to such option by the copyright
* holder.
*/
package fish.payara.cloud.connectors.amazonsqs.api.outbound;

import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleResult;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSSessionCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.regions.Regions;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
* AWS STS Credentials Provider with caching and thread safety.
*
* This class provides AWS credentials by assuming a role using the AWS Security Token Service (STS).
* It caches the credentials and ensures thread safety using locks.
*
* @author Gaurav Gupta
*/
public class STSCredentialsProvider implements AWSCredentialsProvider {

private static final Logger LOGGER = Logger.getLogger(STSCredentialsProvider.class.getName());
private static final Duration EXPIRATION_THRESHOLD = Duration.ofMinutes(5);
private final String roleArn;
private final String roleSessionName;
private final Regions region;
private volatile AWSSessionCredentials cachedCredentials;
private volatile Instant expirationTime;
private final Lock lock = new ReentrantLock();
private static final Map<String, STSCredentialsProvider> providerInstances = new HashMap<>();

/**
* Returns a singleton instance of STSCredentialsProvider for a unique session name.
*
* @param roleArn The ARN of the role to assume.
* @param roleSessionName The name of the role session.
* @param region The AWS region.
* @return The STSCredentialsProvider instance.
*/
public static STSCredentialsProvider create(String roleArn, String roleSessionName, Regions region) {
String uniqueSessionKey = roleSessionName + "@" + region.getName();
return providerInstances.computeIfAbsent(uniqueSessionKey, key -> new STSCredentialsProvider(roleArn, roleSessionName, region));
}

private STSCredentialsProvider(String roleArn, String roleSessionName, Regions region) {
this.roleArn = roleArn;
this.roleSessionName = roleSessionName;
this.region = region;
}

@Override
public AWSCredentials getCredentials() {
if (cachedCredentials != null && !isCredentialsExpired()) {
LOGGER.fine("Reusing cached AWS session credentials");
return cachedCredentials;
} else {
lock.lock();
try {
if (cachedCredentials != null && !isCredentialsExpired()) {
LOGGER.fine("Reusing cached AWS session credentials after lock");
return cachedCredentials;
}
LOGGER.fine("Cached AWS session credentials expired or not present");
refresh();
return cachedCredentials;
} finally {
lock.unlock();
}
}
}

private boolean isCredentialsExpired() {
// Check if the credentials are expired or about to expire
return expirationTime == null || Instant.now().isAfter(expirationTime.minus(EXPIRATION_THRESHOLD));
}

@Override
public void refresh() {
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard().withRegion(region).build();
AssumeRoleRequest assumeRoleRequest = new AssumeRoleRequest()
.withRoleArn(roleArn)
.withRoleSessionName(roleSessionName);

AssumeRoleResult assumeRoleResponse = stsClient.assumeRole(assumeRoleRequest);
Credentials stsCredentials = assumeRoleResponse.getCredentials();
cachedCredentials = new BasicSessionCredentials(
stsCredentials.getAccessKeyId(),
stsCredentials.getSecretAccessKey(),
stsCredentials.getSessionToken()
);
expirationTime = stsCredentials.getExpiration().toInstant();
LOGGER.log(Level.FINE, "Obtained new AWS session credentials - Session Token: {0}, Expiration Time: {1}", new Object[]{stsCredentials.getSessionToken(), stsCredentials.getExpiration()});
}
}
Loading

0 comments on commit ee5d620

Please sign in to comment.