Skip to content

Commit 01f139d

Browse files
committed
Fix spinloop bug in TLSEngine
1 parent 17f6852 commit 01f139d

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed

io/jvm/src/main/scala/fs2/io/net/tls/TLSEngine.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,7 @@ private[tls] object TLSEngine {
215215
else
216216
binding.read(engine.getSession.getPacketBufferSize).flatMap {
217217
case Some(c) => unwrapBuffer.input(c) >> unwrapHandshake
218-
case None =>
219-
unwrapBuffer.inputRemains
220-
.flatMap(x => if (x > 0) Applicative[F].unit else stopUnwrap)
218+
case None => stopUnwrap
221219
}
222220
}
223221
case SSLEngineResult.HandshakeStatus.NEED_UNWRAP_AGAIN =>

io/jvm/src/test/scala/fs2/io/net/tls/TLSSocketSuite.scala

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,65 @@ class TLSSocketSuite extends TLSSuite {
219219
.to(Chunk)
220220
.assertEquals(msg)
221221
}
222+
223+
test("endOfOutput during handshake results in termination".only) {
224+
val msg = Chunk.array(("Hello, world! " * 20000).getBytes)
225+
226+
def limitWrites(raw: Socket[IO], limit: Int): Socket[IO] = new Socket[IO] {
227+
def endOfInput = raw.endOfInput
228+
def endOfOutput = raw.endOfOutput
229+
@deprecated("", "")
230+
def isOpen = raw.isOpen
231+
@deprecated("", "")
232+
def localAddress = raw.localAddress
233+
def peerAddress = raw.peerAddress
234+
def read(maxBytes: Int) = raw.read(maxBytes)
235+
def readN(numBytes: Int) = raw.readN(numBytes)
236+
def reads = raw.reads
237+
@deprecated("", "")
238+
def remoteAddress = raw.remoteAddress
239+
def writes = raw.writes
240+
241+
def address = raw.address
242+
def getOption[A](key: SocketOption.Key[A]) = raw.getOption(key)
243+
def setOption[A](key: SocketOption.Key[A], value: A) = raw.setOption(key, value)
244+
def supportedOptions = raw.supportedOptions
245+
246+
private var totalWritten: Long = 0
247+
def write(bytes: Chunk[Byte]) =
248+
if (totalWritten >= limit) endOfOutput
249+
else {
250+
val b = bytes.take(limit)
251+
raw.write(b) >> IO(totalWritten += b.size)
252+
}
253+
}
254+
255+
val setup = for {
256+
tlsContext <- Resource.eval(testTlsContext)
257+
serverSocket <- Network[IO].bind(SocketAddress(ip"127.0.0.1", Port.Wildcard))
258+
client <- Network[IO].connect(serverSocket.address).flatMap { rawClient =>
259+
tlsContext.clientBuilder(rawClient).withLogger(logger).build
260+
}
261+
} yield serverSocket.accept
262+
.flatMap(s => Stream.resource(tlsContext.server(limitWrites(s, 100)))) -> client
263+
264+
Stream
265+
.resource(setup)
266+
.flatMap { case (server, clientSocket) =>
267+
val echoServer = server.map { socket =>
268+
socket.reads.chunks.foreach(socket.write(_))
269+
}.parJoinUnbounded
270+
271+
val client =
272+
Stream.exec(clientSocket.write(msg)).onFinalize(clientSocket.endOfOutput) ++
273+
clientSocket.reads.take(msg.size.toLong)
274+
275+
client.concurrently(echoServer)
276+
}
277+
.compile
278+
.drain
279+
.intercept[javax.net.ssl.SSLException]
280+
}
222281
}
223282

224283
group("TLSContextBuilder") {

0 commit comments

Comments
 (0)