Skip to content

Commit ea77156

Browse files
KAFKA-14604: SASL session expiration time will be overflowed when calculation (#18526)
The timeout value may be overflowed if users set a large expiration time. ``` sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * sessionLifetimeMs; ``` Fixed it by throwing exception if the value is overflowed. Reviewers: TaiJuWu <[email protected]>, Luke Chen <[email protected]>, TengYao Chi <[email protected]> Signed-off-by: PoAn Yang <[email protected]>
1 parent 3f1d830 commit ea77156

File tree

6 files changed

+136
-14
lines changed

6 files changed

+136
-14
lines changed

clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ public void setAuthenticationEndAndSessionReauthenticationTimes(long nowNanos) {
690690
double pctToUse = pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + RNG.nextDouble()
691691
* pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously;
692692
sessionLifetimeMsToUse = (long) (positiveSessionLifetimeMs * pctToUse);
693-
clientSessionReauthenticationTimeNanos = authenticationEndNanos + 1000 * 1000 * sessionLifetimeMsToUse;
693+
clientSessionReauthenticationTimeNanos = Math.addExact(authenticationEndNanos, Utils.msToNs(sessionLifetimeMsToUse));
694694
log.debug(
695695
"Finished {} with session expiration in {} ms and session re-authentication on or after {} ms",
696696
authenticationOrReauthenticationText(), positiveSessionLifetimeMs, sessionLifetimeMsToUse);

clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ else if (!maxReauthSet)
681681
else
682682
retvalSessionLifetimeMs = zeroIfNegative(Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs));
683683

684-
sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
684+
sessionExpirationTimeNanos = Math.addExact(authenticationEndNanos, Utils.msToNs(retvalSessionLifetimeMs));
685685
}
686686

687687
if (credentialExpirationMs != null) {

clients/src/main/java/org/apache/kafka/common/utils/Utils.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,4 +1719,17 @@ public static ConfigDef mergeConfigs(List<ConfigDef> configDefs) {
17191719
public interface ThrowingRunnable {
17201720
void run() throws Exception;
17211721
}
1722+
1723+
/**
1724+
* convert millisecond to nanosecond, or throw exception if overflow
1725+
* @param timeMs the time in millisecond
1726+
* @return the converted nanosecond
1727+
*/
1728+
public static long msToNs(long timeMs) {
1729+
try {
1730+
return Math.multiplyExact(1000 * 1000, timeMs);
1731+
} catch (ArithmeticException e) {
1732+
throw new IllegalArgumentException("Cannot convert " + timeMs + " millisecond to nanosecond due to arithmetic overflow", e);
1733+
}
1734+
}
17221735
}

clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ public class SaslAuthenticatorTest {
158158
private static final long CONNECTIONS_MAX_REAUTH_MS_VALUE = 100L;
159159
private static final int BUFFER_SIZE = 4 * 1024;
160160
private static Time time = Time.SYSTEM;
161+
private static boolean needLargeExpiration = false;
161162

162163
private NioEchoServer server;
163164
private Selector selector;
@@ -181,6 +182,7 @@ public void setup() throws Exception {
181182

182183
@AfterEach
183184
public void teardown() throws Exception {
185+
needLargeExpiration = false;
184186
if (server != null)
185187
this.server.close();
186188
if (selector != null)
@@ -1610,6 +1612,42 @@ public void testCannotReauthenticateWithDifferentPrincipal() throws Exception {
16101612
server.verifyReauthenticationMetrics(0, 1);
16111613
}
16121614

1615+
@Test
1616+
public void testReauthenticateWithLargeReauthValue() throws Exception {
1617+
// enable it, we'll get a large expiration timestamp token
1618+
needLargeExpiration = true;
1619+
String node = "0";
1620+
SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
1621+
1622+
configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
1623+
List.of(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM));
1624+
// set a large re-auth timeout in server side
1625+
saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_CONFIG, Long.MAX_VALUE);
1626+
server = createEchoServer(securityProtocol);
1627+
1628+
// set to default value for sasl login configs for initialization in ExpiringCredentialRefreshConfig
1629+
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR);
1630+
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER);
1631+
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS);
1632+
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS);
1633+
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, AlternateLoginCallbackHandler.class);
1634+
1635+
createCustomClientConnection(securityProtocol, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, node, true);
1636+
1637+
// channel should be not null before sasl handshake
1638+
assertNotNull(selector.channel(node));
1639+
1640+
TestUtils.waitForCondition(() -> {
1641+
selector.poll(1000);
1642+
// this channel should be closed due to session timeout calculation overflow
1643+
return selector.channel(node) == null;
1644+
}, "channel didn't close with large re-authentication value");
1645+
1646+
// ensure metrics are as expected
1647+
server.verifyAuthenticationMetrics(0, 0);
1648+
server.verifyReauthenticationMetrics(0, 0);
1649+
}
1650+
16131651
@Test
16141652
public void testCorrelationId() {
16151653
SaslClientAuthenticator authenticator = new SaslClientAuthenticator(
@@ -2002,7 +2040,7 @@ private void createClientConnection(SecurityProtocol securityProtocol, String sa
20022040
if (enableSaslAuthenticateHeader)
20032041
createClientConnection(securityProtocol, node);
20042042
else
2005-
createClientConnectionWithoutSaslAuthenticateHeader(securityProtocol, saslMechanism, node);
2043+
createCustomClientConnection(securityProtocol, saslMechanism, node, false);
20062044
}
20072045

20082046
private NioEchoServer startServerApiVersionsUnsupportedByClient(final SecurityProtocol securityProtocol, String saslMechanism) throws Exception {
@@ -2090,15 +2128,13 @@ protected void enableKafkaSaslAuthenticateHeaders(boolean flag) {
20902128
return server;
20912129
}
20922130

2093-
private void createClientConnectionWithoutSaslAuthenticateHeader(final SecurityProtocol securityProtocol,
2094-
final String saslMechanism, String node) throws Exception {
2095-
2096-
final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
2097-
final Map<String, ?> configs = Collections.emptyMap();
2098-
final JaasContext jaasContext = JaasContext.loadClientContext(configs);
2099-
final Map<String, JaasContext> jaasContexts = Collections.singletonMap(saslMechanism, jaasContext);
2100-
2101-
SaslChannelBuilder clientChannelBuilder = new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
2131+
private SaslChannelBuilder saslChannelBuilderWithoutHeader(
2132+
final SecurityProtocol securityProtocol,
2133+
final String saslMechanism,
2134+
final Map<String, JaasContext> jaasContexts,
2135+
final ListenerName listenerName
2136+
) {
2137+
return new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
21022138
securityProtocol, listenerName, false, saslMechanism,
21032139
null, null, null, time, new LogContext(), null) {
21042140

@@ -2125,6 +2161,42 @@ protected void setSaslAuthenticateAndHandshakeVersions(ApiVersionsResponse apiVe
21252161
};
21262162
}
21272163
};
2164+
}
2165+
2166+
private void createCustomClientConnection(
2167+
final SecurityProtocol securityProtocol,
2168+
final String saslMechanism,
2169+
String node,
2170+
boolean withSaslAuthenticateHeader
2171+
) throws Exception {
2172+
2173+
final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
2174+
final Map<String, ?> configs = Collections.emptyMap();
2175+
final JaasContext jaasContext = JaasContext.loadClientContext(configs);
2176+
final Map<String, JaasContext> jaasContexts = Collections.singletonMap(saslMechanism, jaasContext);
2177+
2178+
SaslChannelBuilder clientChannelBuilder;
2179+
if (!withSaslAuthenticateHeader) {
2180+
clientChannelBuilder = saslChannelBuilderWithoutHeader(securityProtocol, saslMechanism, jaasContexts, listenerName);
2181+
} else {
2182+
clientChannelBuilder = new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
2183+
securityProtocol, listenerName, false, saslMechanism,
2184+
null, null, null, time, new LogContext(), null) {
2185+
2186+
@Override
2187+
protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
2188+
AuthenticateCallbackHandler callbackHandler,
2189+
String id,
2190+
String serverHost,
2191+
String servicePrincipal,
2192+
TransportLayer transportLayer,
2193+
Subject subject) {
2194+
2195+
return new SaslClientAuthenticator(configs, callbackHandler, id, subject,
2196+
servicePrincipal, serverHost, saslMechanism, transportLayer, time, new LogContext());
2197+
}
2198+
};
2199+
}
21282200
clientChannelBuilder.configure(saslClientConfigs);
21292201
this.selector = NetworkTestUtils.createSelector(clientChannelBuilder, time);
21302202
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
@@ -2581,10 +2653,11 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
25812653
+ ++numInvocations;
25822654
String headerJson = "{" + claimOrHeaderJsonText("alg", "none") + "}";
25832655
/*
2584-
* Use a short lifetime so the background refresh thread replaces it before we
2656+
* If we're testing large expiration scenario, use a large lifetime.
2657+
* Otherwise, use a short lifetime so the background refresh thread replaces it before we
25852658
* re-authenticate
25862659
*/
2587-
String lifetimeSecondsValueToUse = "1";
2660+
String lifetimeSecondsValueToUse = needLargeExpiration ? String.valueOf(Long.MAX_VALUE) : "1";
25882661
String claimsJson;
25892662
try {
25902663
claimsJson = String.format("{%s,%s,%s}",

clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,35 @@ public void testSessionExpiresAtTokenExpiry() throws IOException {
270270
}
271271
}
272272

273+
@Test
274+
public void testSessionWontExpireWithLargeExpirationTime() throws IOException {
275+
String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
276+
SaslServer saslServer = mock(SaslServer.class);
277+
MockTime time = new MockTime(0, 1, 1000);
278+
// set a Long.MAX_VALUE as the expiration time
279+
Duration largeExpirationTime = Duration.ofMillis(Long.MAX_VALUE);
280+
281+
try (
282+
MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, largeExpirationTime);
283+
MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
284+
TransportLayer transportLayer = mockTransportLayer()
285+
) {
286+
287+
SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, largeExpirationTime.toMillis());
288+
289+
mockRequest(saslHandshakeRequest(mechanism), transportLayer);
290+
authenticator.authenticate();
291+
292+
when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
293+
mockRequest(saslAuthenticateRequest(), transportLayer);
294+
295+
Throwable t = assertThrows(IllegalArgumentException.class, () -> authenticator.authenticate());
296+
assertEquals(ArithmeticException.class, t.getCause().getClass());
297+
assertEquals("Cannot convert " + Long.MAX_VALUE + " millisecond to nanosecond due to arithmetic overflow",
298+
t.getMessage());
299+
}
300+
}
301+
273302
private SaslServerAuthenticator getSaslServerAuthenticatorForOAuth(String mechanism, TransportLayer transportLayer, Time time, Long maxReauth) {
274303
Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
275304
Collections.singletonList(mechanism));

clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,13 @@ public void testTryAll() throws Throwable {
12691269
assertEquals(expected, recorded);
12701270
}
12711271

1272+
@Test
1273+
public void testMsToNs() {
1274+
assertEquals(1000000, Utils.msToNs(1));
1275+
assertEquals(0, Utils.msToNs(0));
1276+
assertThrows(IllegalArgumentException.class, () -> Utils.msToNs(Long.MAX_VALUE));
1277+
}
1278+
12721279
private Callable<Void> recordingCallable(Map<String, Object> recordingMap, String success, TestException failure) {
12731280
return () -> {
12741281
if (success == null)

0 commit comments

Comments
 (0)