Skip to content

Commit

Permalink
[RORDEV-807] Structured groups support (#998)
Browse files Browse the repository at this point in the history
* [RORDEV-808] API definition for structured groups (#955)

* [RORDEV-809] Structured group in domain (#960)

* [RORDEV-812] Structured groups in external authorization rule (#965)

* [RORDEV-813] Structured groups in JWT authorization (#969)

* [RORDEV-814] Structured groups in local groups authorization rules (#985)
  • Loading branch information
mateuszkp96 authored Apr 2, 2024
1 parent 1fcc804 commit 519e600
Show file tree
Hide file tree
Showing 89 changed files with 2,518 additions and 3,686 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ dependencies {
implementation group: 'org.scala-lang', name: 'scala-library', version: '2.13.13'
}

tasks.withType(ScalaCompile).configureEach {
scalaCompileOptions.forkOptions.with {
memoryMaximumSize = '1g'
jvmArgs = ['-XX:MaxMetaspaceSize=512m']
}
}

test {
reports {
junitXml.getRequired().set(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ import eu.timepit.refined.types.string.NonEmptyString
import monix.eval.Task
import org.apache.logging.log4j.scala.Logging
import tech.beshu.ror.accesscontrol.blocks.definitions.ExternalAuthorizationService.Name
import tech.beshu.ror.accesscontrol.blocks.definitions.HttpExternalAuthorizationService.AuthTokenSendMethod.{UsingHeader, UsingQueryParam}
import tech.beshu.ror.accesscontrol.blocks.definitions.HttpExternalAuthorizationService.Config.AuthTokenSendMethod.{UsingHeader, UsingQueryParam}
import tech.beshu.ror.accesscontrol.blocks.definitions.HttpExternalAuthorizationService.Config._
import tech.beshu.ror.accesscontrol.blocks.definitions.HttpExternalAuthorizationService._
import tech.beshu.ror.accesscontrol.domain.GroupIdLike.GroupId
import tech.beshu.ror.accesscontrol.domain._
import tech.beshu.ror.accesscontrol.factory.HttpClientsFactory.HttpClient
import tech.beshu.ror.accesscontrol.factory.decoders.definitions.Definitions.Item
import tech.beshu.ror.accesscontrol.show.logs._
import tech.beshu.ror.accesscontrol.utils.CacheableAction
import tech.beshu.ror.com.jayway.jsonpath.JsonPath
import tech.beshu.ror.utils.json.JsonPath
import tech.beshu.ror.utils.uniquelist.UniqueList

import scala.jdk.CollectionConverters._
import scala.concurrent.duration.FiniteDuration
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

trait ExternalAuthorizationService extends Item {
Expand All @@ -58,16 +59,10 @@ object ExternalAuthorizationService {
}
}

class HttpExternalAuthorizationService(override val id: ExternalAuthorizationService#Id,
uri: Uri,
method: SupportedHttpMethod,
tokenName: AuthTokenName,
groupsJsonPath: JsonPath,
authTokenSendMethod: AuthTokenSendMethod,
defaultHeaders: Set[Header],
defaultQueryParams: Set[QueryParam],
override val serviceTimeout: Refined[FiniteDuration, Positive],
httpClient: HttpClient)
final class HttpExternalAuthorizationService(override val id: ExternalAuthorizationService#Id,
override val serviceTimeout: Refined[FiniteDuration, Positive],
val config: HttpExternalAuthorizationService.Config,
httpClient: HttpClient)
extends ExternalAuthorizationService
with Logging {

Expand All @@ -85,8 +80,8 @@ class HttpExternalAuthorizationService(override val id: ExternalAuthorizationSer
}

private def createRequest(userId: User.Id) = {
val uriWithParams = uri.params(queryParams(userId))
method match {
val uriWithParams = config.uri.params(queryParams(userId))
config.method match {
case SupportedHttpMethod.Get =>
sttp
.get(uriWithParams)
Expand All @@ -98,16 +93,25 @@ class HttpExternalAuthorizationService(override val id: ExternalAuthorizationSer
}
}

private def queryParams(userId: User.Id): Map[String, String] = {
config.defaultQueryParams.map(p => (autoUnwrap(p.name), autoUnwrap(p.value))).toMap ++
(config.authTokenSendMethod match {
case UsingQueryParam => Map(config.tokenName.value.value -> userId.value.value)
case UsingHeader => Map.empty[String, String]
})
}

private def headersMap(userId: User.Id): Map[String, String] = {
config.defaultHeaders.map(h => (h.name.value.value, h.value.value)).toMap ++
(config.authTokenSendMethod match {
case UsingHeader => Map(config.tokenName.value.value -> userId.value.value)
case UsingQueryParam => Map.empty
})
}

private def groupsFromResponseBody(body: String): UniqueList[Group] = {
val groupsFromPath =
Try(groupsJsonPath.read[java.util.List[String]](body))
.map(
_.asScala
.flatMap(NonEmptyString.from(_).toOption)
.map(GroupId.apply)
.map(Group.from)
)
groupsFromPath match {
val groupsFromBody = groupsFrom(body)
groupsFromBody match {
case Success(groups) =>
logger.debug(s"Groups returned by groups provider '${id.show}': ${groups.map(_.show).mkString(",")}")
UniqueList.fromIterable(groups)
Expand All @@ -117,44 +121,104 @@ class HttpExternalAuthorizationService(override val id: ExternalAuthorizationSer
}
}

private def queryParams(userId: User.Id): Map[String, String] = {
defaultQueryParams.map(p => (autoUnwrap(p.name), autoUnwrap(p.value))).toMap ++
(authTokenSendMethod match {
case UsingQueryParam => Map(tokenName.value.value -> userId.value.value)
case UsingHeader => Map.empty[String, String]
})
private def groupsFrom(body: String): Try[List[Group]] = {
for {
rawGroupIds <- groupIdsFrom(body)
groups <- groupsFrom(body, rawGroupIds)
} yield groups
}

private def headersMap(userId: User.Id): Map[String, String] = {
defaultHeaders.map(h => (h.name.value.value, h.value.value)).toMap ++
(authTokenSendMethod match {
case UsingHeader => Map(tokenName.value.value -> userId.value.value)
case UsingQueryParam => Map.empty
})
private def groupIdsFrom(body: String) = {
config.groupsConfig.idsConfig.jsonPath.read[java.util.List[String]](body)
.map {
_.asScala.toList
}
}

private def groupsFrom(body: String, rawGroupIds: List[String]): Try[List[Group]] = {
config.groupsConfig.namesConfig match {
case Some(namesConfig) =>
groupNamesFrom(body, namesConfig)
.flatMap {
case rawGroupNames if rawGroupNames.size == rawGroupIds.size =>
Success(formGroups(groupIdsWithNames = rawGroupIds.zip(rawGroupNames)))
case rawGroupNames =>
Failure(new IllegalArgumentException(
s"Group names array extracted from the response at json path ${namesConfig.jsonPath.rawPath} has different size [size=${rawGroupNames.size}] than " +
s"the group IDs array extracted from the response at json path ${config.groupsConfig.idsConfig.jsonPath.rawPath} [size=${rawGroupIds.size}]"
))
}
case None =>
Success(
rawGroupIds
.flatMap(toGroupId)
.map(Group.from)
)
}
}

private def groupNamesFrom(body: String, namesConfig: GroupsConfig.GroupNamesConfig): Try[List[String]] = {
namesConfig.jsonPath.read[java.util.List[String]](body)
.map {
_.asScala.toList
}
}

private def formGroups(groupIdsWithNames: List[(String, String)]) = {
groupIdsWithNames.flatMap { case (groupId, groupName) =>
toGroupId(groupId)
.map(id => Group(id, toGroupName(value = groupName, fallback = GroupName.from(id))))
}
}

private def toGroupId(value: String): Option[GroupId] = NonEmptyString.unapply(value).map(GroupId.apply)

private def toGroupName(value: String, fallback: GroupName) =
NonEmptyString
.unapply(value)
.map(GroupName.apply)
.getOrElse(fallback)
}

object HttpExternalAuthorizationService {
final case class QueryParam(name: NonEmptyString, value: NonEmptyString)
final case class AuthTokenName(value: NonEmptyString)

sealed trait AuthTokenSendMethod
object AuthTokenSendMethod {
case object UsingHeader extends AuthTokenSendMethod
case object UsingQueryParam extends AuthTokenSendMethod
}
final case class Config(uri: Uri,
method: SupportedHttpMethod,
tokenName: AuthTokenName,
groupsConfig: GroupsConfig,
authTokenSendMethod: AuthTokenSendMethod,
defaultHeaders: Set[Header],
defaultQueryParams: Set[QueryParam])

object Config {
sealed trait SupportedHttpMethod
object SupportedHttpMethod {
case object Get extends SupportedHttpMethod
case object Post extends SupportedHttpMethod
}

sealed trait SupportedHttpMethod
object SupportedHttpMethod {
case object Get extends SupportedHttpMethod
case object Post extends SupportedHttpMethod
final case class AuthTokenName(value: NonEmptyString)

final case class GroupsConfig(idsConfig: GroupsConfig.GroupIdsConfig, namesConfig: Option[GroupsConfig.GroupNamesConfig])
object GroupsConfig {
final case class GroupIdsConfig(jsonPath: JsonPath)
final case class GroupNamesConfig(jsonPath: JsonPath)
}

final case class QueryParam(name: NonEmptyString, value: NonEmptyString)

sealed trait AuthTokenSendMethod
object AuthTokenSendMethod {
case object UsingHeader extends AuthTokenSendMethod
case object UsingQueryParam extends AuthTokenSendMethod
}
}

final case class InvalidResponse(message: String) extends Exception(message)
}

class CacheableExternalAuthorizationServiceDecorator(underlying: ExternalAuthorizationService,
ttl: FiniteDuration Refined Positive)
final class CacheableExternalAuthorizationServiceDecorator(val underlying: ExternalAuthorizationService,
val ttl: FiniteDuration Refined Positive)
extends ExternalAuthorizationService {

private val cacheableGrantsFor = new CacheableAction[User.Id, UniqueList[Group]](ttl, underlying.grantsFor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ package tech.beshu.ror.accesscontrol.blocks.definitions
import java.security.PublicKey
import cats.{Eq, Show}
import eu.timepit.refined.types.string.NonEmptyString
import tech.beshu.ror.accesscontrol.blocks.definitions.JwtDef.{Name, SignatureCheckMethod}
import tech.beshu.ror.accesscontrol.blocks.definitions.JwtDef.{GroupsConfig, Name, SignatureCheckMethod}
import tech.beshu.ror.accesscontrol.domain.{AuthorizationTokenDef, Jwt}
import tech.beshu.ror.accesscontrol.factory.decoders.definitions.Definitions.Item

final case class JwtDef(id: Name,
authorizationTokenDef: AuthorizationTokenDef,
checkMethod: SignatureCheckMethod,
userClaim: Option[Jwt.ClaimName],
groupsClaim: Option[Jwt.ClaimName])
groupsConfig: Option[GroupsConfig])
extends Item {

override type Id = Name
Expand All @@ -44,6 +44,8 @@ object JwtDef {
final case class Ec(pubKey: PublicKey) extends SignatureCheckMethod
}

final case class GroupsConfig(idsClaim: Jwt.ClaimName, namesClaim: Option[Jwt.ClaimName])

implicit val nameEq: Eq[Name] = Eq.fromUniversalEquals
implicit val nameShow: Show[Name] = Show.show(_.value.value)
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import cats.implicits._
import tech.beshu.ror.accesscontrol.domain.Json._
import tech.beshu.ror.accesscontrol.domain.KibanaAllowedApiPath.AllowedHttpMethod
import tech.beshu.ror.accesscontrol.domain.KibanaAllowedApiPath.AllowedHttpMethod.HttpMethod
import tech.beshu.ror.accesscontrol.domain.{CorrelationId, KibanaAccess, KibanaApp}
import tech.beshu.ror.accesscontrol.domain.{CorrelationId, Group, KibanaAccess, KibanaApp}

import scala.jdk.CollectionConverters._

Expand Down Expand Up @@ -121,7 +121,7 @@ object MetadataValue {
private def availableGroups(userMetadata: UserMetadata) = {
NonEmptyList
.fromList(userMetadata.availableGroups.toList)
.map(groups => ("x-ror-available-groups", MetadataList(groups.map(_.id.value.value))))
.map(groups => ("x-ror-available-groups", MetadataListOfMaps(groups.map(serializeGroup))))
.toMap
}

Expand All @@ -134,13 +134,19 @@ object MetadataValue {
}

private def currentGroup(userMetadata: UserMetadata) = {
userMetadata.currentGroupId.map(g => ("x-ror-current-group", MetadataString(g.value.value))).toMap
userMetadata
.findCurrentGroup
.map(serializeGroup)
.map(group => ("x-ror-current-group", MetadataObject(group.asJava)))
.toMap
}

private def loggedUser(userMetadata: UserMetadata) = {
userMetadata.loggedUser.map(u => ("x-ror-username", MetadataString(u.id.value.value))).toMap
}

private def serializeGroup(group: Group) = Map("id" -> group.id.value.value, "name" -> group.name.value.value)

private implicit val kibanaAccessShow: Show[KibanaAccess] = Show {
case KibanaAccess.RO => "ro"
case KibanaAccess.ROStrict => "ro_strict"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,16 @@ final class JwtAuthRule(val settings: JwtAuthRule.Settings,
groups: Option[ClaimSearchResult[UniqueList[Group]]]): Unit = {
(settings.jwt.userClaim, user) match {
case (Some(userClaim), Some(u)) =>
logger.debug(s"JWT resolved user for claim ${userClaim.name.getPath}: ${u.show}")
logger.debug(s"JWT resolved user for claim ${userClaim.name.rawPath}: ${u.show}")
case _ =>
}
(settings.jwt.groupsClaim, groups) match {
case (Some(groupsClaim), Some(g)) =>
logger.debug(s"JWT resolved groups for claim ${groupsClaim.name.getPath}: ${g.show}")
(settings.jwt.groupsConfig, groups) match {
case (Some(groupsConfig), Some(g)) =>
val claimsDescription = groupsConfig.namesClaim match {
case Some(namesClaim) => s"claims (id:'${groupsConfig.idsClaim.name.show}',name:'${namesClaim.name.show}')"
case None => s"claim '${groupsConfig.idsClaim.name.show}'"
}
logger.debug(s"JWT resolved groups for $claimsDescription: ${g.show}")
case _ =>
}
}
Expand Down Expand Up @@ -179,7 +183,9 @@ final class JwtAuthRule(val settings: JwtAuthRule.Settings,
}

private def groupsFrom(payload: Jwt.Payload) = {
settings.jwt.groupsClaim.map(payload.claims.groupsClaim)
settings.jwt.groupsConfig.map(groupsConfig =>
payload.claims.groupsClaim(groupsConfig.idsClaim, groupsConfig.namesClaim)
)
}

private def handleUserClaimSearchResult[B <: BlockContext : BlockContextUpdater](blockContext: B,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import tech.beshu.ror.accesscontrol.request.RequestContextOps._
import tech.beshu.ror.accesscontrol.show.logs._
import tech.beshu.ror.accesscontrol.utils.ClaimsOps.ClaimSearchResult.{Found, NotFound}
import tech.beshu.ror.accesscontrol.utils.ClaimsOps._
import tech.beshu.ror.com.jayway.jsonpath.JsonPath
import tech.beshu.ror.utils.json.JsonPath
import tech.beshu.ror.utils.uniquelist.{UniqueList, UniqueNonEmptyList}

import scala.util.Try
Expand Down Expand Up @@ -107,7 +107,7 @@ final class RorKbnAuthRule(val settings: Settings,
(
tokenPayload,
tokenPayload.claims.userIdClaim(RorKbnAuthRule.userClaimName),
tokenPayload.claims.groupsClaim(RorKbnAuthRule.groupsClaimName),
tokenPayload.claims.groupsClaim(groupIdsClaimName = RorKbnAuthRule.groupIdsClaimName, groupNamesClaimName = None),
tokenPayload.claims.headerNameClaim(Header.Name.xUserOrigin)
)
}
Expand Down Expand Up @@ -181,6 +181,6 @@ object RorKbnAuthRule {
final case class Defined(groupsLogic: GroupsLogic) extends Groups
}

private val userClaimName = Jwt.ClaimName(JsonPath.compile("user"))
private val groupsClaimName = Jwt.ClaimName(JsonPath.compile("groups"))
private val userClaimName = Jwt.ClaimName(JsonPath("user").get)
private val groupIdsClaimName = Jwt.ClaimName(JsonPath("groups").get)
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import tech.beshu.ror.accesscontrol.show.logs._
import tech.beshu.ror.accesscontrol.utils.ClaimsOps.ClaimSearchResult.{Found, NotFound}
import tech.beshu.ror.accesscontrol.utils.ClaimsOps.CustomClaimValue.{CollectionValue, SingleValue}
import tech.beshu.ror.accesscontrol.utils.ClaimsOps._
import tech.beshu.ror.com.jayway.jsonpath.JsonPath
import tech.beshu.ror.utils.json.JsonPath

private[runtime] trait RuntimeResolvableVariable[VALUE] {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ import tech.beshu.ror.accesscontrol.blocks.variables.Tokenizer.Token
import tech.beshu.ror.accesscontrol.blocks.variables.Tokenizer.Token.Transformation
import tech.beshu.ror.accesscontrol.blocks.variables.runtime.MultiExtractable.SingleExtractableWrapper
import tech.beshu.ror.accesscontrol.blocks.variables.runtime.RuntimeResolvableVariable.Convertible
import tech.beshu.ror.accesscontrol.blocks.variables.runtime.RuntimeResolvableVariableCreator._
import tech.beshu.ror.accesscontrol.blocks.variables.transformation.TransformationCompiler
import tech.beshu.ror.accesscontrol.blocks.variables.transformation.TransformationCompiler.CompilationError
import tech.beshu.ror.accesscontrol.blocks.variables.transformation.domain.Function
import tech.beshu.ror.accesscontrol.blocks.variables.{Tokenizer, runtime}
import tech.beshu.ror.accesscontrol.domain.Header
import tech.beshu.ror.com.jayway.jsonpath.JsonPath
import RuntimeResolvableVariableCreator._
import tech.beshu.ror.accesscontrol.blocks.variables.transformation.domain.Function
import tech.beshu.ror.utils.json.JsonPath

import scala.util.matching.Regex
import scala.util.{Failure, Success, Try}
import scala.util.{Failure, Success}

class RuntimeResolvableVariableCreator(transformationCompiler: TransformationCompiler) extends Logging {

Expand Down Expand Up @@ -136,7 +136,7 @@ class RuntimeResolvableVariableCreator(transformationCompiler: TransformationCom
}

private def createJwtExtractable(jsonPathStr: String, maybeTransformation: Option[Function], `type`: ExtractableType): Either[CreationError, `type`.TYPE] = {
Try(JsonPath.compile(jsonPathStr)) match {
JsonPath(jsonPathStr) match {
case Success(compiledPath) =>
Right(`type`.createJwtVariableExtractable(compiledPath, maybeTransformation))
case Failure(ex) =>
Expand Down
Loading

0 comments on commit 519e600

Please sign in to comment.