Skip to content

Commit 77ab80f

Browse files
authored
Fix OOM in UploadBsaUnavailableDomains action (#2817)
* Fix OOM in UploadBsaUnavailableDomains action The action was using string concatenation to generate the upload content. This causes an OOM when string length exceeds 25MB on our current VM. This PR witches to streaming upload. Also added an HTTP upload test. * Fix OOM in UploadBsaUnavailableDomains action The action was using string concatenation to generate the upload content. This causes an OOM when string length exceeds 25MB on our current VM. This PR witches to streaming upload. Also added an HTTP upload test.
1 parent 5e1cd01 commit 77ab80f

File tree

4 files changed

+256
-24
lines changed

4 files changed

+256
-24
lines changed

core/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def fragileTestPatterns = [
5858
// Changes cache timeouts and for some reason appears to have contention
5959
// with other tests.
6060
"google/registry/whois/WhoisCommandFactoryTest.*",
61+
// Breaks random other tests when running with standardTests.
62+
"google/registry/bsa/UploadBsaUnavailableDomainsActionTest.*",
6163
// Currently changes a global configuration parameter that for some reason
6264
// results in timestamp inversions for other tests. TODO(mmuller): fix.
6365
"google/registry/flows/host/HostInfoFlowTest.*",

core/src/main/java/google/registry/bsa/UploadBsaUnavailableDomainsAction.java

Lines changed: 89 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@
2525
import static google.registry.request.Action.Method.POST;
2626
import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR;
2727
import static java.nio.charset.StandardCharsets.US_ASCII;
28+
import static java.nio.charset.StandardCharsets.UTF_8;
2829

2930
import com.google.cloud.storage.BlobId;
30-
import com.google.common.base.Joiner;
3131
import com.google.common.collect.ImmutableList;
3232
import com.google.common.collect.ImmutableSet;
3333
import com.google.common.collect.ImmutableSortedSet;
3434
import com.google.common.collect.Ordering;
3535
import com.google.common.flogger.FluentLogger;
36+
import com.google.common.hash.Hasher;
3637
import com.google.common.hash.Hashing;
37-
import com.google.common.io.ByteSource;
3838
import google.registry.bsa.api.BsaCredential;
3939
import google.registry.config.RegistryConfig.Config;
4040
import google.registry.gcs.GcsUtils;
@@ -47,10 +47,13 @@
4747
import google.registry.util.Clock;
4848
import jakarta.inject.Inject;
4949
import jakarta.persistence.TypedQuery;
50-
import java.io.ByteArrayOutputStream;
50+
import java.io.BufferedInputStream;
5151
import java.io.IOException;
52+
import java.io.InputStream;
5253
import java.io.OutputStream;
5354
import java.io.OutputStreamWriter;
55+
import java.io.PipedInputStream;
56+
import java.io.PipedOutputStream;
5457
import java.io.Writer;
5558
import java.util.Optional;
5659
import java.util.zip.GZIPOutputStream;
@@ -60,14 +63,17 @@
6063
import okhttp3.Request;
6164
import okhttp3.RequestBody;
6265
import okhttp3.Response;
66+
import okio.BufferedSink;
67+
import org.jetbrains.annotations.NotNull;
68+
import org.jetbrains.annotations.Nullable;
6369
import org.joda.time.DateTime;
6470

6571
/**
6672
* Daily action that uploads unavailable domain names on applicable TLDs to BSA.
6773
*
6874
* <p>The upload is a single zipped text file containing combined details for all BSA-enrolled TLDs.
69-
* The text is a newline-delimited list of punycoded fully qualified domain names, and contains all
70-
* domains on each TLD that are registered and/or reserved.
75+
* The text is a newline-delimited list of punycoded fully qualified domain names with a trailing
76+
* newline at the end, and contains all domains on each TLD that are registered and/or reserved.
7177
*
7278
* <p>The file is also uploaded to GCS to preserve it as a record for ourselves.
7379
*/
@@ -118,7 +124,7 @@ public void run() {
118124
// TODO(mcilwain): Implement a date Cursor, have the cronjob run frequently, and short-circuit
119125
// the run if the daily upload is already completed.
120126
DateTime runTime = clock.nowUtc();
121-
String unavailableDomains = Joiner.on("\n").join(getUnavailableDomains(runTime));
127+
ImmutableSortedSet<String> unavailableDomains = getUnavailableDomains(runTime);
122128
if (unavailableDomains.isEmpty()) {
123129
logger.atWarning().log("No unavailable domains found; terminating.");
124130
emailSender.sendNotification(
@@ -136,12 +142,16 @@ public void run() {
136142
}
137143

138144
/** Uploads the unavailable domains list to GCS in the unavailable domains bucket. */
139-
boolean uploadToGcs(String unavailableDomains, DateTime runTime) {
145+
boolean uploadToGcs(ImmutableSortedSet<String> unavailableDomains, DateTime runTime) {
140146
logger.atInfo().log("Uploading unavailable names file to GCS in bucket %s", gcsBucket);
141147
BlobId blobId = BlobId.of(gcsBucket, createFilename(runTime));
148+
// `gcsUtils.openOutputStream` returns a buffered stream
142149
try (OutputStream gcsOutput = gcsUtils.openOutputStream(blobId);
143150
Writer osWriter = new OutputStreamWriter(gcsOutput, US_ASCII)) {
144-
osWriter.write(unavailableDomains);
151+
for (var domainName : unavailableDomains) {
152+
osWriter.write(domainName);
153+
osWriter.write("\n");
154+
}
145155
return true;
146156
} catch (Exception e) {
147157
logger.atSevere().withCause(e).log(
@@ -150,10 +160,14 @@ boolean uploadToGcs(String unavailableDomains, DateTime runTime) {
150160
}
151161
}
152162

153-
boolean uploadToBsa(String unavailableDomains, DateTime runTime) {
163+
boolean uploadToBsa(ImmutableSortedSet<String> unavailableDomains, DateTime runTime) {
154164
try {
155-
byte[] gzippedContents = gzipUnavailableDomains(unavailableDomains);
156-
String sha512Hash = ByteSource.wrap(gzippedContents).hash(Hashing.sha512()).toString();
165+
Hasher sha512Hasher = Hashing.sha512().newHasher();
166+
unavailableDomains.stream()
167+
.map(name -> name + "\n")
168+
.forEachOrdered(line -> sha512Hasher.putString(line, UTF_8));
169+
String sha512Hash = sha512Hasher.hash().toString();
170+
157171
String filename = createFilename(runTime);
158172
OkHttpClient client = new OkHttpClient().newBuilder().build();
159173

@@ -169,7 +183,9 @@ boolean uploadToBsa(String unavailableDomains, DateTime runTime) {
169183
.addFormDataPart(
170184
"file",
171185
String.format("%s.gz", filename),
172-
RequestBody.create(gzippedContents, MediaType.parse("application/octet-stream")))
186+
new StreamingRequestBody(
187+
gzippedStream(unavailableDomains),
188+
MediaType.parse("application/octet-stream")))
173189
.build();
174190

175191
Request request =
@@ -196,15 +212,6 @@ boolean uploadToBsa(String unavailableDomains, DateTime runTime) {
196212
}
197213
}
198214

199-
private byte[] gzipUnavailableDomains(String unavailableDomains) throws IOException {
200-
try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) {
201-
try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)) {
202-
gzipOutputStream.write(unavailableDomains.getBytes(US_ASCII));
203-
}
204-
return byteArrayOutputStream.toByteArray();
205-
}
206-
}
207-
208215
private static String createFilename(DateTime runTime) {
209216
return String.format("unavailable_domains_%s.txt", runTime.toString());
210217
}
@@ -280,4 +287,65 @@ private ImmutableSortedSet<String> getUnavailableDomains(DateTime runTime) {
280287
private static String toDomain(String domainLabel, Tld tld) {
281288
return String.format("%s.%s", domainLabel, tld.getTldStr());
282289
}
290+
291+
private InputStream gzippedStream(ImmutableSortedSet<String> unavailableDomains)
292+
throws IOException {
293+
PipedInputStream inputStream = new PipedInputStream();
294+
PipedOutputStream outputStream = new PipedOutputStream(inputStream);
295+
296+
new Thread(
297+
() -> {
298+
try {
299+
gzipUnavailableDomains(outputStream, unavailableDomains);
300+
} catch (Throwable e) {
301+
logger.atSevere().withCause(e).log("Failed to gzip unavailable domains.");
302+
try {
303+
// This will cause the next read to throw an IOException.
304+
inputStream.close();
305+
} catch (IOException ignore) {
306+
// Won't happen for `PipedInputStream.close()`
307+
}
308+
}
309+
})
310+
.start();
311+
312+
return inputStream;
313+
}
314+
315+
private void gzipUnavailableDomains(
316+
PipedOutputStream outputStream, ImmutableSortedSet<String> unavailableDomains)
317+
throws IOException {
318+
// `GZIPOutputStream` is buffered.
319+
try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(outputStream)) {
320+
for (String name : unavailableDomains) {
321+
var line = name + "\n";
322+
gzipOutputStream.write(line.getBytes(US_ASCII));
323+
}
324+
}
325+
}
326+
327+
private static class StreamingRequestBody extends RequestBody {
328+
private final BufferedInputStream inputStream;
329+
private final MediaType mediaType;
330+
331+
StreamingRequestBody(InputStream inputStream, MediaType mediaType) {
332+
this.inputStream = new BufferedInputStream(inputStream);
333+
this.mediaType = mediaType;
334+
}
335+
336+
@Nullable
337+
@Override
338+
public MediaType contentType() {
339+
return mediaType;
340+
}
341+
342+
@Override
343+
public void writeTo(@NotNull BufferedSink bufferedSink) throws IOException {
344+
byte[] buffer = new byte[2048];
345+
int bytesRead;
346+
while ((bytesRead = inputStream.read(buffer)) != -1) {
347+
bufferedSink.write(buffer, 0, bytesRead);
348+
}
349+
}
350+
}
283351
}

core/src/test/java/google/registry/bsa/UploadBsaUnavailableDomainsActionTest.java

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,24 @@
2020
import static google.registry.testing.DatabaseHelper.persistDeletedDomain;
2121
import static google.registry.testing.DatabaseHelper.persistReservedList;
2222
import static google.registry.testing.DatabaseHelper.persistResource;
23+
import static google.registry.testing.LogsSubject.assertAboutLogs;
2324
import static google.registry.util.DateTimeUtils.START_OF_TIME;
25+
import static google.registry.util.NetworkUtils.pickUnusedPort;
2426
import static java.nio.charset.StandardCharsets.UTF_8;
27+
import static java.util.concurrent.Executors.newSingleThreadExecutor;
2528
import static org.mockito.Mockito.times;
2629
import static org.mockito.Mockito.verify;
2730

2831
import com.google.cloud.storage.BlobId;
2932
import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper;
33+
import com.google.common.collect.ImmutableList;
34+
import com.google.common.collect.ImmutableMap;
35+
import com.google.common.flogger.FluentLogger;
36+
import com.google.common.hash.Hashing;
37+
import com.google.common.io.ByteStreams;
38+
import com.google.common.net.HostAndPort;
39+
import com.google.common.testing.TestLogHandler;
40+
import com.google.gson.Gson;
3041
import google.registry.bsa.api.BsaCredential;
3142
import google.registry.gcs.GcsUtils;
3243
import google.registry.model.tld.Tld;
@@ -35,9 +46,25 @@
3546
import google.registry.persistence.transaction.JpaTestExtensions;
3647
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension;
3748
import google.registry.request.UrlConnectionService;
49+
import google.registry.server.Route;
50+
import google.registry.server.TestServer;
3851
import google.registry.testing.FakeClock;
3952
import google.registry.testing.FakeResponse;
53+
import jakarta.servlet.ServletException;
54+
import jakarta.servlet.annotation.MultipartConfig;
55+
import jakarta.servlet.http.HttpServlet;
56+
import jakarta.servlet.http.HttpServletRequest;
57+
import jakarta.servlet.http.HttpServletResponse;
58+
import jakarta.servlet.http.Part;
59+
import java.io.IOException;
60+
import java.io.InputStream;
61+
import java.io.PrintWriter;
62+
import java.net.InetAddress;
63+
import java.util.Map;
4064
import java.util.Optional;
65+
import java.util.logging.Level;
66+
import java.util.logging.Logger;
67+
import java.util.zip.GZIPInputStream;
4168
import org.joda.time.DateTime;
4269
import org.junit.jupiter.api.BeforeEach;
4370
import org.junit.jupiter.api.Test;
@@ -102,13 +129,112 @@ void calculatesEntriesCorrectly() throws Exception {
102129
BlobId existingFile =
103130
BlobId.of(BUCKET, String.format("unavailable_domains_%s.txt", clock.nowUtc()));
104131
String blockList = new String(gcsUtils.readBytesFrom(existingFile), UTF_8);
105-
assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld");
132+
assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n");
106133
assertThat(blockList).doesNotContain("not-blocked.tld");
107134

108135
// This test currently fails in the upload-to-bsa step.
109136
verify(emailSender, times(1))
110137
.sendNotification("BSA daily upload completed with errors", "Please see logs for details.");
138+
}
139+
140+
@Test
141+
void uploadToBsaTest() throws Exception {
142+
TestLogHandler logHandler = new TestLogHandler();
143+
Logger loggerToIntercept =
144+
Logger.getLogger(UploadBsaUnavailableDomainsAction.class.getCanonicalName());
145+
loggerToIntercept.addHandler(logHandler);
146+
147+
persistActiveDomain("foobar.tld");
148+
persistActiveDomain("ace.tld");
149+
persistDeletedDomain("not-blocked.tld", clock.nowUtc().minusDays(1));
150+
151+
var testServer = startTestServer();
152+
action.apiUrl = testServer.getUrl("/upload").toURI().toString();
153+
try {
154+
action.run();
155+
} finally {
156+
testServer.stop();
157+
}
158+
String dataSent = "ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n";
159+
String checkSum = Hashing.sha512().hashString(dataSent, UTF_8).toString();
160+
String expectedResponse =
161+
"Received response with code 200 from server: "
162+
+ String.format("Checksum: [%s]\n%s\n", checkSum, dataSent);
163+
assertAboutLogs().that(logHandler).hasLogAtLevelWithMessage(Level.INFO, expectedResponse);
164+
verify(emailSender, times(1)).sendNotification("BSA daily upload completed successfully", "");
165+
}
166+
167+
private TestServer startTestServer() throws Exception {
168+
TestServer testServer =
169+
new TestServer(
170+
HostAndPort.fromParts(InetAddress.getLocalHost().getHostAddress(), pickUnusedPort()),
171+
ImmutableMap.of(),
172+
ImmutableList.of(Route.route("/upload", Servelet.class)));
173+
testServer.start();
174+
newSingleThreadExecutor()
175+
.execute(
176+
() -> {
177+
try {
178+
while (true) {
179+
testServer.process();
180+
}
181+
} catch (InterruptedException e) {
182+
// Expected
183+
}
184+
});
185+
return testServer;
186+
}
111187

112-
// TODO(mcilwain): Add test of BSA API upload as well.
188+
@MultipartConfig(
189+
location = "", // Directory for storing uploaded files. Use default when blank
190+
maxFileSize = 10485760L, // 10MB
191+
maxRequestSize = 20971520L, // 20MB
192+
fileSizeThreshold = 1048576 // Save in memory if file size < 1MB
193+
)
194+
public static class Servelet extends HttpServlet {
195+
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
196+
197+
@Override
198+
protected void doPost(HttpServletRequest req, HttpServletResponse resp)
199+
throws ServletException, IOException {
200+
String checkSum = null;
201+
String content = null;
202+
try {
203+
for (Part part : req.getParts()) {
204+
switch (part.getName()) {
205+
case "zone" -> checkSum = readChecksum(part);
206+
case "file" -> content = readGzipped(part);
207+
}
208+
}
209+
} catch (Exception e) {
210+
logger.atInfo().withCause(e).log("");
211+
}
212+
int status = checkSum == null || content == null ? 400 : 200;
213+
resp.setStatus(status);
214+
resp.setContentType("text/plain");
215+
try (PrintWriter writer = resp.getWriter()) {
216+
writer.printf("Checksum: [%s]\n%s\n", checkSum, content);
217+
}
218+
}
219+
220+
private String readChecksum(Part part) {
221+
try (InputStream is = part.getInputStream()) {
222+
return new Gson()
223+
.fromJson(new String(ByteStreams.toByteArray(is), UTF_8), Map.class)
224+
.getOrDefault("checkSum", "Not found")
225+
.toString();
226+
} catch (IOException e) {
227+
throw new RuntimeException(e);
228+
}
229+
}
230+
231+
private String readGzipped(Part part) {
232+
try (InputStream is = part.getInputStream();
233+
GZIPInputStream gis = new GZIPInputStream(is)) {
234+
return new String(ByteStreams.toByteArray(gis), UTF_8);
235+
} catch (IOException e) {
236+
throw new RuntimeException(e);
237+
}
238+
}
113239
}
114240
}

0 commit comments

Comments
 (0)