diff --git a/io/js/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala b/io/js/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala index eebb3c8982..78732d17bc 100644 --- a/io/js/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala +++ b/io/js/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala @@ -70,6 +70,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type => } yield new AsyncTLSSocket( tlsSock, readStream, + socket, sessionRef.discrete.unNone.head.compile.lastOrError, F.delay[Any](tlsSock.alpnProtocol).flatMap { case false => "".pure // mimicking JVM @@ -81,8 +82,12 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type => private[tls] final class AsyncTLSSocket[F[_]: Async]( sock: facade.tls.TLSSocket, readStream: SuspendedStream[F, Byte], + underlying: Socket[F], val session: F[SSLSession], val applicationProtocol: F[String] ) extends Socket.AsyncSocket[F](sock, readStream) - with UnsealedTLSSocket[F] + with UnsealedTLSSocket[F] { + override def localAddress = underlying.localAddress + override def remoteAddress = underlying.remoteAddress + } } diff --git a/io/js/src/test/scala/fs2/io/net/tls/TLSSocketSuite.scala b/io/js/src/test/scala/fs2/io/net/tls/TLSSocketSuite.scala index f88e35f6fc..87014b2427 100644 --- a/io/js/src/test/scala/fs2/io/net/tls/TLSSocketSuite.scala +++ b/io/js/src/test/scala/fs2/io/net/tls/TLSSocketSuite.scala @@ -360,5 +360,46 @@ class TLSSocketSuite extends TLSSuite { .to(Chunk) .intercept[SSLException] } + + test("get local and remote address") { + val setup = for { + tlsContext <- Resource.eval(testTlsContext(true)) + addressAndConnections <- Network[IO].serverResource(Some(ip"127.0.0.1")) + (serverAddress, server) = addressAndConnections + client = Network[IO] + .client(serverAddress) + .flatMap( + tlsContext + .clientBuilder(_) + .withParameters( + TLSParameters(checkServerIdentity = + Some((sn, _) => Either.cond(sn == "localhost", (), new RuntimeException())) + ) + ) + .build + ) + } yield server.flatMap(s => Stream.resource(tlsContext.server(s))) -> client + + Stream + .resource(setup) + .flatMap { case (server, clientSocket) => + val serverSocketAddresses = server.evalMap { socket => + socket.localAddress.product(socket.remoteAddress) + } + + val clientSocketAddresses = + Stream.resource(clientSocket).evalMap { socket => + socket.localAddress.product(socket.remoteAddress) + } + + serverSocketAddresses.parZip(clientSocketAddresses).map { + case ((serverLocal, serverRemote), (clientLocal, clientRemote)) => + assertEquals(clientRemote, serverLocal) + assertEquals(clientLocal, serverRemote) + } + } + .compile + .drain + } } }