Commit 8b7437a

improvements for s3 environment variables (#1343)

* lazy loading for `s3` environment variables
* `S3_ENDPOINT` supports http/https
* remove `S3_USE_HTTPS` and `S3_VERIFY_SSL`
1 parent 653453a commit 8b7437a
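As context for the change: the old configuration read separate `S3_USE_HTTPS` and `S3_VERIFY_SSL` flags, while the new code passes a scheme-qualified `S3_ENDPOINT` (for example `http://localhost:4566`) directly to the client's `OverrideEndpoint`, presumably letting the SDK pick the protocol from the URL. Below is a minimal, standalone sketch of how such an endpoint string can be split into scheme and authority; `ParseEndpoint`, its HTTPS default, and the `main` driver are assumptions of this sketch, not code from this commit.

```cpp
// Hypothetical sketch, not code from this commit: one way a scheme-qualified
// endpoint such as "http://localhost:4566" can replace the removed
// S3_USE_HTTPS / S3_VERIFY_SSL toggles.
#include <cstdlib>
#include <iostream>
#include <string>
#include <utility>

// Splits an endpoint into {use_https, authority}; defaults to HTTPS when no
// scheme prefix is present.
std::pair<bool, std::string> ParseEndpoint(const std::string& endpoint) {
  const std::string http = "http://";
  const std::string https = "https://";
  if (endpoint.compare(0, https.size(), https) == 0)
    return {true, endpoint.substr(https.size())};
  if (endpoint.compare(0, http.size(), http) == 0)
    return {false, endpoint.substr(http.size())};
  return {true, endpoint};
}

int main() {
  const char* endpoint = std::getenv("S3_ENDPOINT");
  if (endpoint == nullptr) {
    std::cout << "S3_ENDPOINT not set, using the SDK default endpoint\n";
    return 0;
  }
  auto parsed = ParseEndpoint(endpoint);
  std::cout << "scheme=" << (parsed.first ? "https" : "http")
            << " authority=" << parsed.second << "\n";
  return 0;
}
```

Folding the scheme into the endpoint keeps a single environment variable authoritative for both host and protocol, which is what the updated test at the end of this commit relies on.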

3 files changed: +30 additions, −46 deletions

tensorflow_io/core/plugins/s3/s3_filesystem.cc

Lines changed: 25 additions & 40 deletions
@@ -135,8 +135,6 @@ static Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
   absl::MutexLock l(&cfg_lock);
 
   if (!init) {
-    const char* endpoint = getenv("S3_ENDPOINT");
-    if (endpoint) cfg.endpointOverride = Aws::String(endpoint);
     const char* region = getenv("AWS_REGION");
     // TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
     if (!region) region = getenv("S3_REGION");
@@ -168,20 +166,6 @@ static Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
           cfg.region = profiles["default"].GetRegion();
       }
     }
-    const char* use_https = getenv("S3_USE_HTTPS");
-    if (use_https) {
-      if (use_https[0] == '0')
-        cfg.scheme = Aws::Http::Scheme::HTTP;
-      else
-        cfg.scheme = Aws::Http::Scheme::HTTPS;
-    }
-    const char* verify_ssl = getenv("S3_VERIFY_SSL");
-    if (verify_ssl) {
-      if (verify_ssl[0] == '0')
-        cfg.verifySSL = false;
-      else
-        cfg.verifySSL = true;
-    }
     // if these timeouts are low, you may see an error when
     // uploading/downloading large files: Unable to connect to endpoint
     int64_t timeout;
@@ -241,6 +225,13 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
             tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging();
           }
         });
+
+    int temp_value;
+    if (absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value))
+      s3_file->use_multi_part_download = (temp_value != 1);
+
+    const char* endpoint = getenv("S3_ENDPOINT");
+    if (endpoint) s3_file->s3_client->OverrideEndpoint(endpoint);
   }
 }
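The hunk above moves `S3_DISABLE_MULTI_PART_DOWNLOAD` and `S3_ENDPOINT` handling to the point where the client is created, parsing integers with `absl::SimpleAtoi` and keeping a default when the variable is absent or unparseable. A self-contained sketch of that parse-with-default idiom follows; the explicit `getenv` null check, the `GetEnvOrDefault` helper, and the 50 MB placeholder are assumptions of this sketch, not part of the diff.

```cpp
// Sketch of the SimpleAtoi-with-default idiom used in GetS3Client and
// GetTransferManager; helper name and default value are illustrative only.
#include <cstdint>
#include <cstdlib>
#include <iostream>

#include "absl/strings/numbers.h"

// Returns the parsed value of `name` when the variable is set and numeric,
// otherwise `default_value`.
uint64_t GetEnvOrDefault(const char* name, uint64_t default_value) {
  const char* value = std::getenv(name);
  uint64_t parsed = 0;
  if (value != nullptr && absl::SimpleAtoi(value, &parsed)) return parsed;
  return default_value;
}

int main() {
  // 50 MB is a placeholder, standing in for a compiled-in constant such as
  // kS3MultiPartUploadChunkSize.
  constexpr uint64_t kDefaultChunkSize = 50 * 1024 * 1024;
  const uint64_t chunk =
      GetEnvOrDefault("S3_MULTI_PART_UPLOAD_CHUNK_SIZE", kDefaultChunkSize);
  std::cout << "multipart upload chunk size: " << chunk << " bytes\n";
  return 0;
}
```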

@@ -263,15 +254,26 @@ static void GetTransferManager(
 
   absl::MutexLock l(&s3_file->initialization_lock);
 
-  if (s3_file->transfer_managers[direction].get() == nullptr) {
+  if (s3_file->transfer_managers.count(direction) == 0) {
+    uint64_t temp_value;
+    if (direction == Aws::Transfer::TransferDirection::UPLOAD) {
+      if (!absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"),
+                            &temp_value))
+        temp_value = kS3MultiPartUploadChunkSize;
+    } else if (direction == Aws::Transfer::TransferDirection::DOWNLOAD) {
+      if (!absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"),
+                            &temp_value))
+        temp_value = kS3MultiPartDownloadChunkSize;
+    }
+    s3_file->multi_part_chunk_sizes.emplace(direction, temp_value);
+
     Aws::Transfer::TransferManagerConfiguration config(s3_file->executor.get());
     config.s3Client = s3_file->s3_client;
-    config.bufferSize = s3_file->multi_part_chunk_sizes[direction];
+    config.bufferSize = temp_value;
     // must be larger than pool size * multi part chunk size
-    config.transferBufferMaxHeapSize =
-        (kExecutorPoolSize + 1) * s3_file->multi_part_chunk_sizes[direction];
-    s3_file->transfer_managers[direction] =
-        Aws::Transfer::TransferManager::Create(config);
+    config.transferBufferMaxHeapSize = (kExecutorPoolSize + 1) * temp_value;
+    s3_file->transfer_managers.emplace(
+        direction, Aws::Transfer::TransferManager::Create(config));
   }
 }
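`GetTransferManager` now tests `count(direction) == 0` and `emplace`s the entry, instead of relying on `operator[]` to have pre-created a null `shared_ptr` in the constructor. The standalone sketch below shows that lazy-emplace shape with `std::unordered_map` standing in for `Aws::UnorderedMap`; `TransferManager`, `GetManager`, and the chunk sizes are placeholders, not the plugin's types.

```cpp
// Standalone sketch of lazily populating one entry per transfer direction
// on first use, rather than pre-seeding the map with nullptr placeholders.
#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>

enum class TransferDirection { UPLOAD, DOWNLOAD };

struct TransferManager {
  explicit TransferManager(uint64_t chunk) : chunk_size(chunk) {}
  uint64_t chunk_size;
};

std::unordered_map<TransferDirection, std::shared_ptr<TransferManager>>
    managers;

std::shared_ptr<TransferManager> GetManager(TransferDirection direction,
                                            uint64_t chunk_size) {
  // count() == 0 mirrors the new check above; the old code instead checked
  // whether the pre-inserted shared_ptr was still null.
  if (managers.count(direction) == 0)
    managers.emplace(direction, std::make_shared<TransferManager>(chunk_size));
  return managers.at(direction);
}

int main() {
  auto upload = GetManager(TransferDirection::UPLOAD, 50 * 1024 * 1024);
  GetManager(TransferDirection::DOWNLOAD, 50 * 1024 * 1024);
  std::cout << "upload chunk: " << upload->chunk_size
            << ", managers created: " << managers.size() << "\n";
  return 0;
}
```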

@@ -529,24 +531,7 @@ S3File::S3File()
       transfer_managers(),
       multi_part_chunk_sizes(),
       use_multi_part_download(false),  // TODO: change to true after fix
-      initialization_lock() {
-  uint64_t temp_value;
-  multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] =
-      absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"), &temp_value)
-          ? temp_value
-          : kS3MultiPartUploadChunkSize;
-  multi_part_chunk_sizes[Aws::Transfer::TransferDirection::DOWNLOAD] =
-      absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"), &temp_value)
-          ? temp_value
-          : kS3MultiPartDownloadChunkSize;
-  use_multi_part_download =
-      absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value)
-          ? (temp_value != 1)
-          : use_multi_part_download;
-  transfer_managers.emplace(Aws::Transfer::TransferDirection::UPLOAD, nullptr);
-  transfer_managers.emplace(Aws::Transfer::TransferDirection::DOWNLOAD,
-                            nullptr);
-}
+      initialization_lock() {}
 void Init(TF_Filesystem* filesystem, TF_Status* status) {
   filesystem->plugin_filesystem = new S3File();
   TF_SetStatus(status, TF_OK, "");
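With the constructor reduced to an empty body, every environment read now happens lazily under `initialization_lock`, which is presumably why the updated test can export `S3_ENDPOINT` right before calling `tf.io.read_file` and still have it honored. Below is a minimal sketch of that guarded lazy-initialization shape; `FileSystemState`, `LazyClient`, and `GetClient` are invented names for illustration, not the plugin's API.

```cpp
// Minimal sketch of guarded lazy initialization: the environment is read on
// first use, under a mutex, rather than in the constructor.
#include <cstdlib>
#include <memory>
#include <string>

#include "absl/synchronization/mutex.h"

struct LazyClient {
  std::string endpoint;
};

struct FileSystemState {
  std::shared_ptr<LazyClient> client;
  absl::Mutex initialization_lock;
};

void GetClient(FileSystemState* state) {
  absl::MutexLock l(&state->initialization_lock);
  if (state->client == nullptr) {
    state->client = std::make_shared<LazyClient>();
    // The environment is consulted only here, on first use, so a variable
    // exported after FileSystemState was constructed is still picked up.
    const char* endpoint = std::getenv("S3_ENDPOINT");
    if (endpoint != nullptr) state->client->endpoint = endpoint;
  }
}

int main() {
  FileSystemState state;
  GetClient(&state);  // first call reads S3_ENDPOINT and builds the client
  GetClient(&state);  // subsequent calls reuse it
  return 0;
}
```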

tensorflow_io/core/plugins/s3/s3_filesystem.h

Lines changed: 4 additions & 3 deletions
@@ -60,11 +60,12 @@ typedef struct S3File {
   std::shared_ptr<Aws::S3::S3Client> s3_client;
   std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> executor;
   // We need 2 `TransferManager`, for multipart upload/download.
-  Aws::Map<Aws::Transfer::TransferDirection,
-           std::shared_ptr<Aws::Transfer::TransferManager>>
+  Aws::UnorderedMap<Aws::Transfer::TransferDirection,
+                    std::shared_ptr<Aws::Transfer::TransferManager>>
       transfer_managers;
   // Sizes to split objects during multipart upload/download.
-  Aws::Map<Aws::Transfer::TransferDirection, uint64_t> multi_part_chunk_sizes;
+  Aws::UnorderedMap<Aws::Transfer::TransferDirection, uint64_t>
+      multi_part_chunk_sizes;
   bool use_multi_part_download;
   absl::Mutex initialization_lock;
   S3File();

tests/test_s3.py

Lines changed: 1 addition & 3 deletions
@@ -51,9 +51,7 @@ def test_read_file():
     response = client.get_object(Bucket=bucket_name, Key=key_name)
     assert response["Body"].read() == body
 
-    os.environ["S3_ENDPOINT"] = "localhost:4566"
-    os.environ["S3_USE_HTTPS"] = "0"
-    os.environ["S3_VERIFY_SSL"] = "0"
+    os.environ["S3_ENDPOINT"] = "http://localhost:4566"
 
     content = tf.io.read_file("s3://{}/{}".format(bucket_name, key_name))
     assert content == body
