Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions container-tests/src/test/java/okhttp3/containers/DeadSocketTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
/*
* Copyright (c) 2026 OkHttp Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.containers

import assertk.assertFailure
import assertk.assertThat
import assertk.assertions.isInstanceOf
import assertk.assertions.isLessThan
import com.github.dockerjava.api.model.Capability
import java.io.IOException
import java.time.Duration
import java.util.concurrent.TimeUnit
import mockwebserver3.Dispatcher
import mockwebserver3.MockResponse
import mockwebserver3.MockWebServer
import mockwebserver3.RecordedRequest
import okhttp3.MediaType
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.Response
import okhttp3.tls.HandshakeCertificates
import okhttp3.tls.HeldCertificate
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Timeout
import org.junit.jupiter.api.parallel.Isolated
import org.testcontainers.Testcontainers
import org.testcontainers.containers.GenericContainer
import org.testcontainers.containers.wait.strategy.Wait
import org.testcontainers.images.builder.ImageFromDockerfile

/**
* Reproduces a stateful firewall silently killing a pooled HTTP/2 connection at L3 (no RST, no FIN)
* while a new stream is being written. A socat tunnel in an Alpine container with `NET_ADMIN`
* engages `iptables -j DROP` mid-test, after which the next HEADERS write fills the kernel TCP
* send buffer and blocks indefinitely waiting for ACKs that will never arrive.
*
* Asserts the design intent of [OkHttpClient.Builder.pingInterval]: a configured ping interval
* detects the dead connection and fails the call within a small multiple of the interval.
*
* Requires Docker with the `NET_ADMIN` capability (for `iptables -j DROP`).
*/
@Isolated
@Timeout(value = 30, unit = TimeUnit.SECONDS)
class DeadSocketTest {
private lateinit var rootCa: HeldCertificate
private lateinit var mockWebServer: MockWebServer
private lateinit var tunnel: GenericContainer<*>
private lateinit var client: OkHttpClient

@BeforeEach
fun setUp() {
rootCa =
HeldCertificate
.Builder()
.certificateAuthority(0)
.build()

val serverCert =
HeldCertificate
.Builder()
.addSubjectAlternativeName("localhost")
.signedBy(rootCa)
.build()

val serverHandshakeCerts =
HandshakeCertificates
.Builder()
.heldCertificate(serverCert)
.build()

mockWebServer = MockWebServer()
mockWebServer.useHttps(serverHandshakeCerts.sslSocketFactory())
mockWebServer.protocols = listOf(Protocol.HTTP_2, Protocol.HTTP_1_1)
mockWebServer.dispatcher =
object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse =
MockResponse
.Builder()
.code(200)
.body("ok")
.build()
}
mockWebServer.start()

Testcontainers.exposeHostPorts(mockWebServer.port)

tunnel =
GenericContainer(
ImageFromDockerfile()
.withDockerfileFromBuilder { builder ->
builder
.from("alpine:3")
.run("apk add --no-cache socat iptables")
.build()
},
).withCreateContainerCmdModifier { cmd ->
cmd.hostConfig!!.withCapAdd(Capability.NET_ADMIN)
}.withExposedPorts(TUNNEL_PORT)
.withCommand(
"socat",
"TCP-LISTEN:$TUNNEL_PORT,fork,reuseaddr",
"TCP:host.testcontainers.internal:${mockWebServer.port}",
).waitingFor(Wait.forListeningPort())
tunnel.start()

val clientCerts =
HandshakeCertificates
.Builder()
.addTrustedCertificate(rootCa.certificate)
.build()

client =
OkHttpClient
.Builder()
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
.pingInterval(Duration.ofSeconds(1))
// Isolate pingInterval detection latency from connect-retry latency: with retry on, the
// retry's new TCP+TLS handshake races against connectTimeout through the still-dead tunnel
// and would dominate the measured elapsed time.
.retryOnConnectionFailure(false)
.sslSocketFactory(clientCerts.sslSocketFactory(), clientCerts.trustManager)
.hostnameVerifier { hostname, _ -> hostname == "localhost" }
.build()
}

@AfterEach
fun tearDown() {
client.connectionPool.evictAll()
client.dispatcher.executorService.shutdown()
mockWebServer.close()
tunnel.stop()
}

/**
* Warms up a pooled HTTP/2 connection through a socat tunnel, then engages `iptables -j DROP` to
* silently kill the connection at L3. The next request carries oversized headers so the HEADERS
* write overflows the TCP send buffer and blocks. The 1s pingInterval should detect the dead
* connection within a few intervals.
*/
@Test
fun pingIntervalBoundsFailureOnDeadConnection() {
repeat(3) { i ->
sendPost("warmup-$i").use { response ->
check(response.code == 200) { "warm-up $i failed: ${response.code}" }
response.body.string()
}
}

execInTunnel("iptables", "-I", "INPUT", "-p", "tcp", "--dport", "$TUNNEL_PORT", "-j", "DROP")
execInTunnel("iptables", "-I", "OUTPUT", "-p", "tcp", "--sport", "$TUNNEL_PORT", "-j", "DROP")

val startNanos = System.nanoTime()
assertFailure {
sendPostWithLargeHeaders().use { it.body.string() }
}.isInstanceOf<IOException>()
val elapsedSeconds = (System.nanoTime() - startNanos) / 1_000_000_000.0

// pingInterval is 1s; design intent is failure within a small multiple of the interval.
assertThat(elapsedSeconds).isLessThan(3.0)
}

private fun sendPost(body: String): Response =
client
.newCall(
Request
.Builder()
.url("https://localhost:${tunnel.getMappedPort(TUNNEL_PORT)}/echo")
.post(body.toRequestBody(JSON))
.build(),
).execute()

private fun sendPostWithLargeHeaders(): Response =
client
.newCall(
Request
.Builder()
.url("https://localhost:${tunnel.getMappedPort(TUNNEL_PORT)}/echo")
.post("hello".toRequestBody(JSON))
.header("X-Padding", "X".repeat(LARGE_HEADER_SIZE))
.build(),
).execute()

private fun execInTunnel(vararg command: String) {
val result = tunnel.execInContainer(*command)
check(result.exitCode == 0) { "container exec failed: ${result.stderr}" }
}

companion object {
private const val TUNNEL_PORT = 8443
private const val LARGE_HEADER_SIZE = 8 * 1024 * 1024
private val JSON: MediaType = "application/json".toMediaType()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,17 @@ class ConnectPlan internal constructor(
null
}
javaNetSocket = sslSocket
socket = sslSocket.asBufferedSocket()
val rawSocket = requireNotNull(rawSocket) { "TCP not connected" }
val sslBuffered = sslSocket.asBufferedSocket()
// Close the raw TCP socket before the SSL layer: SSLSocket.close() may block for many
// seconds attempting to send close_notify on a half-open connection.
socket =
object : BufferedSocket by sslBuffered {
override fun cancel() {
rawSocket.closeQuietly()
sslBuffered.cancel()
}
}
protocol = if (maybeProtocol != null) Protocol.get(maybeProtocol) else Protocol.HTTP_1_1
success = true
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ class Http2Connection internal constructor(
/** Asynchronously writes frames to the outgoing socket. */
private val writerQueue = taskRunner.newQueue()

/**
* Runs the periodic pingInterval watchdog. Separate from [writerQueue] so missing-pong detection
* still fires when the writer is wedged on a frame write to a half-open connection.
*/
private val pingQueue = taskRunner.newQueue()

/** Ensures push promise callbacks events are sent in order per stream. */
private val pushQueue = taskRunner.newQueue()

Expand Down Expand Up @@ -147,7 +153,7 @@ class Http2Connection internal constructor(
init {
if (builder.pingIntervalMillis != 0) {
val pingIntervalNanos = TimeUnit.MILLISECONDS.toNanos(builder.pingIntervalMillis.toLong())
writerQueue.schedule("$connectionName ping", pingIntervalNanos) {
pingQueue.schedule("$connectionName ping", pingIntervalNanos) {
val failDueToMissingPong =
withLock {
if (intervalPongsReceived < intervalPingsSent) {
Expand All @@ -161,7 +167,12 @@ class Http2Connection internal constructor(
failConnection(null)
return@schedule -1L
} else {
writePing(false, INTERVAL_PING, 0)
// Fire-and-forget the ping write so the watchdog returns promptly. If writer.lock is
// held by a wedged frame write, the ping queues behind it; the next tick sees
// pongs < pings and fails the connection, which cancels the socket and unblocks the writer.
writerQueue.execute("$connectionName ping write") {
writePing(false, INTERVAL_PING, 0)
}
return@schedule pingIntervalNanos
}
}
Expand Down Expand Up @@ -481,11 +492,31 @@ class Http2Connection internal constructor(

// Release the threads.
writerQueue.shutdown()
pingQueue.shutdown()
pushQueue.shutdown()
settingsListenerQueue.shutdown()
}

private fun failConnection(e: IOException?) {
// Mark active streams with PROTOCOL_ERROR before socket.cancel() so the caller surfaces
// StreamResetException(PROTOCOL_ERROR) on its waiting read or write.
//
// Note that if we cancel first: the reader thread's blocked socket read returns SocketException,
// which can reach the caller before close() -> stream.close() sets errorCode in Http2Stream.closeInternal.
//
// We use closeLater() (which enqueues RST_STREAM via writerQueue) rather than close(): the latter
// writes RST_STREAM synchronously under writer.lock, and that lock is already held by the
// frame write that's blocked on the half-open socket. See pingQueue.
val streamsToClose: Array<Http2Stream>? =
withLock {
if (streams.isNotEmpty()) streams.values.toTypedArray() else null
}
streamsToClose?.forEach { stream ->
stream.closeLater(ErrorCode.PROTOCOL_ERROR)
}
ignoreIoExceptions {
socket.cancel()
}
close(ErrorCode.PROTOCOL_ERROR, ErrorCode.PROTOCOL_ERROR, e)
}

Expand Down