Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Servant auth server PoC #1560

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions hie.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
cradle:
multi:
- path: "./.stack-work"
config: { cradle: { none } }
- path: "./"
config:
cradle:
stack:
- path: "./servant-auth/servant-auth/src"
component: "servant-auth:lib"
- path: "./servant-auth/servant-auth-client/src"
component: "servant-auth-client:lib"
- path: "./servant-auth/servant-auth-docs/src"
component: "servant-auth-docs:lib"
- path: "./servant-auth/servant-auth-server/src"
component: "servant-auth-server:lib"

- path: "./servant/src"
component: "servant:lib"
- path: "./servant-client/src"
component: "servant-client:lib"
- path: "./servant-client-core/src"
component: "servant-client-core:lib"
- path: "./servant-client-ghcjs/src"
component: "servant-client-ghcjs:lib"
- path: "./servant-conduit/src"
component: "servant-conduit:lib"
- path: "./servant-docs/src"
component: "servant-docs:lib"
- path: "./servant-foreign/src"
component: "servant-foreign:lib"
- path: "./servant-http-streams/src"
component: "servant-http-streams:lib"
- path: "./servant-machines/src"
component: "servant-machines:lib"
- path: "./servant-pipes/src"
component: "servant-pipes:lib"
- path: "./servant-server/src"
component: "servant-server:lib"
- path: "./servant-swagger/src"
component: "servant-swagger:lib"
160 changes: 154 additions & 6 deletions servant-auth/servant-auth-server/src/Servant/Auth/Server/Internal.hs
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Servant.Auth.Server.Internal where

import Control.Monad.Except (runExceptT, join)
import Control.Monad.Trans (liftIO)
import Servant ((:>), Handler, HasServer (..),
Proxy (..),
HasContextEntry(getContextEntry))
import Servant.Auth
import Data.Kind (Type)
import Data.Typeable (Typeable, typeRep)
import Network.Wai (Request, queryString)
import Servant
import Servant.API.Modifiers (FoldLenient, FoldRequired, RequestArgument, unfoldRequestArgument)
import Servant.Auth (Auth)
import Servant.Auth.JWT (ToJWT)
import Data.Text (Text)

import Servant.Auth.Server.Internal.AddSetCookie
import Servant.Auth.Server.Internal.Class
import Servant.Auth.Server.Internal.Cookie
import Servant.Auth.Server.Internal.ConfigTypes
import Servant.Auth.Server.Internal.JWT
-- import Servant.Auth.Server.Internal.JWT
import Servant.Auth.Server.Internal.Types

import Servant.Server.Internal (DelayedIO, addAuthCheck, withRequest)
import Servant.Server.Experimental.Auth (AuthHandler (..))
import Servant.Server.Internal (DelayedIO, addAuthCheck, delayedFail, delayedFailFatal, mkContextWithErrorFormatter, withRequest, MkContextWithErrorFormatter)
import qualified Data.Text as T
import qualified Data.Text.Lazy.Encoding as TLE
import qualified Data.Text.Lazy as TL
import Data.ByteString.Lazy (ByteString)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Network.HTTP.Types (queryToQueryText)
import Data.String (fromString)

instance ( n ~ 'S ('S 'Z)
, HasServer (AddSetCookiesApi n api) ctxs, AreAuths auths ctxs v
Expand Down Expand Up @@ -68,3 +82,137 @@ instance ( n ~ 'S ('S 'Z)
-> (AuthResult v, SetCookieList n)
-> ServerT (AddSetCookiesApi n api) Handler
go fn (authResult, cookies) = addSetCookies cookies $ fn authResult


{-
NewAuth is a "quick" PoC to have a more modular way of providing
authentications and the checking thereof.

In the current implementation, all of the 'auths' are checked one
by one, and the first that is present is tried and will either
be returned when successful, or when not, throw an error given
the appropriate 'ErrorFormatter' or return a 'Left err' if the
'Lenient' modifier is set.
Only if NONE of the auths are present will it either throw an
'err401' or, if the 'Optional' modifier is set, return a 'Nothing'.
This error might also be customizeable if we make it be required
from the context.
-}

data NewAuth (mods :: [Type]) (auths :: [Type]) (a :: Type)
deriving (Typeable)

type NewAuthResult a = Maybe (Either Text a)

instance
( HasServer api ctxs
, HasContextEntry (MkContextWithErrorFormatter ctxs) ErrorFormatters
, SBoolI (FoldRequired mods)
, SBoolI (FoldLenient mods)
, AllAuth auths a
, HasContextEntry ctxs (AuthHandler Request (NewAuthResult a))
) => HasServer (NewAuth mods auths a :> api) ctxs where
type ServerT (NewAuth mods auths a :> api) m =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense define the instance for NewAuth mods (auth ': auths) a :> api to statically ensure that at least one auth mode is defined ? We could define a custom type error for the case where the list of auths is empty.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, probably a good idea, yeah.

RequestArgument mods a -> ServerT api m
hoistServerWithContext _ pc nt s =
hoistServerWithContext (Proxy @api) pc nt . s
route _ context subserver =
route (Proxy @api) context $ subserver `addAuthCheck` authCheck
where
errorFormatters :: ErrorFormatters
errorFormatters = getContextEntry $ mkContextWithErrorFormatter context

authCheck :: DelayedIO (RequestArgument mods a)
authCheck = withRequest $ \req -> do
mRes <- processAllAuthHandlers req authHandlers
case mRes of
Nothing -> case sbool :: SBool (FoldRequired mods) of
STrue -> absent
SFalse -> pure Nothing
Just (badForm, eRes) ->
unfoldRequestArgument (Proxy @mods) absent badForm $ Just eRes

processAllAuthHandlers ::
Request ->
[NewAuthHandler Request (NewAuthResult a)] ->
DelayedIO (Maybe (Text -> DelayedIO (RequestArgument mods a), Either Text a))
processAllAuthHandlers _req [] = pure Nothing
processAllAuthHandlers req (auth : auths) = do
eRes <- liftIO . runExceptT . runHandler' $ getHandler auth req
either delayedFail go eRes
where
go Nothing = processAllAuthHandlers req auths
go (Just res) = pure $ Just (badForm, res)
badForm err = delayedFailFatal $ errFmtr rep req . T.unpack $ msg <> err
errFmtr = getErrorFormatter auth errorFormatters
rep = typeRep (Proxy :: Proxy NewAuth)
msg = "Authentication via " <> getAuthName auth <> " failed: "


authHandlers :: [NewAuthHandler Request (NewAuthResult a)]
authHandlers = allAuthHandlers (Proxy @auths)

toLBS :: Text -> ByteString
toLBS = TLE.encodeUtf8 . TL.fromStrict

absent :: DelayedIO (RequestArgument mods a)
absent =
delayedFailFatal $ err401{ errBody = toLBS msg }
where
allAuthMethodNames = getAuthName <$> authHandlers
msg = case allAuthMethodNames of
[] -> "No authentication required, something went wrong."
[auth] -> "Authentication required: " <> auth
auths -> "One of the following authentications required: " <> T.intercalate ", " auths

data NewAuthHandler r res = NewAuthHandler
{ getHandler :: r -> Handler res
-- Used in the following errors:
-- "One of the following authentications required: {authName}, {authName}, etc."
-- And
-- "Authentication via {authName}: (Left Text)"
, getAuthName :: Text
, getErrorFormatter :: ErrorFormatters -> ErrorFormatter
} deriving Typeable

class AllAuth (auths :: [Type]) a where
allAuthHandlers :: proxy auths -> [NewAuthHandler Request (NewAuthResult a)]

instance AllAuth '[] a where
allAuthHandlers _ = []

instance (AllAuth auths a, HasAuthHandler auth a) => AllAuth (auth ': auths) a where
allAuthHandlers _ = getAuthHandler (Proxy @auth) : allAuthHandlers (Proxy @auths)

class HasAuthHandler auth a where
getAuthHandler :: proxy auth -> NewAuthHandler Request (NewAuthResult a)

{-
The following is an example of a partial implementation to be used by users
where they will only have to supply a 'FromJWT' instance for their type.
There's a lot of possibilities, but this is just a quick and easy example.
-}

-- | Designates a query parameter named @sym@ which should contain a valid JWT
--
-- E.g. @JWTQueryParam "token"@ will try to parse the value from the query
-- parameter named "token" in the path (i.e. "example.com/path?token=...")
data JWTQueryParam sym

-- | Pretty library agnostic way of providing a JWT API, but less plug'n'play.
--
-- The Servant.Auth.JWT 'FromJWT' forces the 'Value' to be in the "dat" field,
-- but makes this implementation easier and user-friendly. So it depends on
-- what we'd prefer.
class FromJWT a where
parseJWT :: Text -> Either Text a

instance (FromJWT a, KnownSymbol sym) => HasAuthHandler (JWTQueryParam sym) a where
getAuthHandler _ = NewAuthHandler
{ getHandler = \req ->
let paramname = fromString $ symbolVal (Proxy :: Proxy sym)
mev = join . lookup paramname . queryToQueryText $ queryString req
in pure $ parseJWT <$> mev
, getAuthName = "JWT Query Parameter"
, getErrorFormatter = urlParseErrorFormatter
}
8 changes: 7 additions & 1 deletion stack.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
resolver: lts-18.5
resolver: lts-18.27
packages:
- servant-client-core/
- servant-client/
Expand All @@ -12,6 +12,12 @@ packages:
- servant-machines/
- servant-pipes/

- servant-auth/servant-auth
- servant-auth/servant-auth-docs
- servant-auth/servant-auth-client
- servant-auth/servant-auth-server
- servant-auth/servant-auth-swagger

# allow-newer: true # ignores all bounds, that's a sledgehammer
# - doc/tutorial/

Expand Down
8 changes: 4 additions & 4 deletions stack.yaml.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ packages:
hackage: hspec-wai-0.10.1
snapshots:
- completed:
size: 585817
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/18/5.yaml
sha256: 22d24d0dacad9c1450b9a174c28d203f9bb482a2a8da9710a2f2a9f4afee2887
original: lts-18.5
size: 590102
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/18/27.yaml
sha256: 79a786674930a89301b0e908fad2822a48882f3d01486117693c377b8edffdbe
original: lts-18.27