Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions axlearn/cloud/gcp/jobset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import dataclass
from typing import Any, Optional, Sequence
from urllib.parse import urlparse
import re

from absl import flags

Expand Down Expand Up @@ -84,6 +85,7 @@ class GCSFuseMount(VolumeMount):
can wait to get a response from the server before timing out. Setting it to 0s means no
limit.
read_only: Whether the mount should be read-only.
mount_options: Comma-separated GCS FUSE mount options.
"""

gcs_path: str
Expand All @@ -94,6 +96,7 @@ class GCSFuseMount(VolumeMount):
ephemeral_gb: str = "5Gi"
shared_memory: str = "1Gi"
http_client_timeout: str = "0s"
mount_options: str = ""


@dataclass(kw_only=True)
Expand Down Expand Up @@ -298,7 +301,21 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs):
# pylint: disable=missing-kwoa
# pytype: disable=missing-parameter
if fv.gcsfuse_mount_spec:
cfg.gcsfuse_mount = GCSFuseMount(**parse_kv_flags(fv.gcsfuse_mount_spec, delimiter="="))
specs = []
# This regex splits by comma, but only if the comma is not inside double quotes.
quote_aware_splitter = re.compile(r',(?=(?:[^"]*"[^"]*")*[^"]*$)')
for spec_group in fv.gcsfuse_mount_spec:
specs.extend(quote_aware_splitter.split(spec_group))

# Parse the specs into a dictionary. This will preserve quotes in the values.
parsed_args = parse_kv_flags(specs, delimiter="=")

# If mount_options was parsed, remove the outer quotes from its value.
if "mount_options" in parsed_args:
parsed_args["mount_options"] = parsed_args["mount_options"].strip('"\'')

# Create the GCSFuseMount object with the cleaned arguments.
cfg.gcsfuse_mount = GCSFuseMount(**parsed_args)
if fv.host_mount_spec:
cfg.host_mounts = [
HostMount(**parse_kv_flags(item.split(","), delimiter="="))
Expand Down Expand Up @@ -563,6 +580,7 @@ def _build_pod(self) -> Nested[Any]:
)
# Parse GCSFuseMount path into bucket, prefix.
parsed = urlparse(cfg.gcsfuse_mount.gcs_path)
gcsMountOptions = cfg.gcsfuse_mount.mount_options or f"only-dir={parsed.path.lstrip('/')},implicit-dirs,metadata-cache:ttl-secs:-1,metadata-cache:stat-cache-max-size-mb:-1,metadata-cache:type-cache-max-size-mb:-1,kernel-list-cache-ttl-secs=-1,gcs-connection:http-client-timeout:{cfg.gcsfuse_mount.http_client_timeout}"
# https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/cloud-storage-fuse-csi-driver#consume-ephemeral-volume-pod
# Caveat: --implicit-dirs might have negative impacts on i/o performance. See
# https://github.com/googlecloudplatform/gcsfuse/blob/master/docs/semantics.md .
Expand All @@ -578,7 +596,7 @@ def _build_pod(self) -> Nested[Any]:
volumeAttributes=dict(
bucketName=parsed.netloc,
# pylint: disable=line-too-long
mountOptions=f"only-dir={parsed.path.lstrip('/')},implicit-dirs,metadata-cache:ttl-secs:-1,metadata-cache:stat-cache-max-size-mb:-1,metadata-cache:type-cache-max-size-mb:-1,kernel-list-cache-ttl-secs=-1,gcs-connection:http-client-timeout:{cfg.gcsfuse_mount.http_client_timeout}",
mountOptions=gcsMountOptions,
gcsfuseMetadataPrefetchOnMount="false", # Improves first-time read.
disableMetrics="false", # Enables GCSFuse metrics by default.
),
Expand Down
41 changes: 41 additions & 0 deletions axlearn/cloud/gcp/jobset_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,47 @@ def test_env_override(self):
self.assertIn("key1", cfg.env_vars)
self.assertEqual(cfg.env_vars["key1"], "value1")

@parameterized.named_parameters(
{
"testcase_name": "basic_spec_no_mount_options",
"spec": ["gcs_path=gs://a/b,mount_path=/c"],
"expected_gcs_path": "gs://a/b",
"expected_mount_path": "/c",
"expected_mount_options": "",
},
{
"testcase_name": "complex_spec_with_quoted_mount_options",
"spec": [
'gcs_path=gs://a/b,mount_path=/c,mount_options="implicit-dirs,foo=bar,baz=qux"'
],
"expected_gcs_path": "gs://a/b",
"expected_mount_path": "/c",
"expected_mount_options": "implicit-dirs,foo=bar,baz=qux",
},
{
"testcase_name": "spec_with_empty_mount_options",
"spec": ['gcs_path=gs://a/b,mount_path=/c,mount_options=""'],
"expected_gcs_path": "gs://a/b",
"expected_mount_path": "/c",
"expected_mount_options": "",
},
)
def test_gcsfuse_mount_spec_parsing(
self,
spec: list[str],
expected_gcs_path: str,
expected_mount_path: str,
expected_mount_options: str,
):
"""Tests that gcsfuse_mount_spec is parsed correctly."""
with self._job_config(
ArtifactRegistryBundler, gcsfuse_mount_spec=spec
) as (cfg, _):
self.assertIsNotNone(cfg.gcsfuse_mount)
self.assertEqual(cfg.gcsfuse_mount.gcs_path, expected_gcs_path)
self.assertEqual(cfg.gcsfuse_mount.mount_path, expected_mount_path)
self.assertEqual(cfg.gcsfuse_mount.mount_options, expected_mount_options)

def test_validate_jobset_name(self):
with (
self.assertRaisesRegex(ValueError, "invalid"),
Expand Down
Loading