From 115e96b2dcf2e7aa496a6202df05eab5fa24dcce Mon Sep 17 00:00:00 2001 From: Deepika Date: Fri, 11 Jul 2025 00:03:02 +0000 Subject: [PATCH] gcsfuse_mount_spec parsing and supporting mount_options from flags --- axlearn/cloud/gcp/jobset_utils.py | 22 ++++++++++++-- axlearn/cloud/gcp/jobset_utils_test.py | 41 ++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index 710f96701..513277386 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import Any, Optional, Sequence from urllib.parse import urlparse +import re from absl import flags @@ -84,6 +85,7 @@ class GCSFuseMount(VolumeMount): can wait to get a response from the server before timing out. Setting it to 0s means no limit. read_only: Whether the mount should be read-only. + mount_options: Comma-separated GCS FUSE mount options. """ gcs_path: str @@ -94,6 +96,7 @@ class GCSFuseMount(VolumeMount): ephemeral_gb: str = "5Gi" shared_memory: str = "1Gi" http_client_timeout: str = "0s" + mount_options: str = "" @dataclass(kw_only=True) @@ -298,7 +301,21 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs): # pylint: disable=missing-kwoa # pytype: disable=missing-parameter if fv.gcsfuse_mount_spec: - cfg.gcsfuse_mount = GCSFuseMount(**parse_kv_flags(fv.gcsfuse_mount_spec, delimiter="=")) + specs = [] + # This regex splits by comma, but only if the comma is not inside double quotes. + quote_aware_splitter = re.compile(r',(?=(?:[^"]*"[^"]*")*[^"]*$)') + for spec_group in fv.gcsfuse_mount_spec: + specs.extend(quote_aware_splitter.split(spec_group)) + + # Parse the specs into a dictionary. This will preserve quotes in the values. + parsed_args = parse_kv_flags(specs, delimiter="=") + + # If mount_options was parsed, remove the outer quotes from its value. + if "mount_options" in parsed_args: + parsed_args["mount_options"] = parsed_args["mount_options"].strip('"\'') + + # Create the GCSFuseMount object with the cleaned arguments. + cfg.gcsfuse_mount = GCSFuseMount(**parsed_args) if fv.host_mount_spec: cfg.host_mounts = [ HostMount(**parse_kv_flags(item.split(","), delimiter="=")) @@ -563,6 +580,7 @@ def _build_pod(self) -> Nested[Any]: ) # Parse GCSFuseMount path into bucket, prefix. parsed = urlparse(cfg.gcsfuse_mount.gcs_path) + gcsMountOptions = cfg.gcsfuse_mount.mount_options or f"only-dir={parsed.path.lstrip('/')},implicit-dirs,metadata-cache:ttl-secs:-1,metadata-cache:stat-cache-max-size-mb:-1,metadata-cache:type-cache-max-size-mb:-1,kernel-list-cache-ttl-secs=-1,gcs-connection:http-client-timeout:{cfg.gcsfuse_mount.http_client_timeout}" # https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/cloud-storage-fuse-csi-driver#consume-ephemeral-volume-pod # Caveat: --implicit-dirs might have negative impacts on i/o performance. See # https://github.com/googlecloudplatform/gcsfuse/blob/master/docs/semantics.md . @@ -578,7 +596,7 @@ def _build_pod(self) -> Nested[Any]: volumeAttributes=dict( bucketName=parsed.netloc, # pylint: disable=line-too-long - mountOptions=f"only-dir={parsed.path.lstrip('/')},implicit-dirs,metadata-cache:ttl-secs:-1,metadata-cache:stat-cache-max-size-mb:-1,metadata-cache:type-cache-max-size-mb:-1,kernel-list-cache-ttl-secs=-1,gcs-connection:http-client-timeout:{cfg.gcsfuse_mount.http_client_timeout}", + mountOptions=gcsMountOptions, gcsfuseMetadataPrefetchOnMount="false", # Improves first-time read. disableMetrics="false", # Enables GCSFuse metrics by default. ), diff --git a/axlearn/cloud/gcp/jobset_utils_test.py b/axlearn/cloud/gcp/jobset_utils_test.py index 62f7a8c6d..72f4578a3 100644 --- a/axlearn/cloud/gcp/jobset_utils_test.py +++ b/axlearn/cloud/gcp/jobset_utils_test.py @@ -106,6 +106,47 @@ def test_env_override(self): self.assertIn("key1", cfg.env_vars) self.assertEqual(cfg.env_vars["key1"], "value1") + @parameterized.named_parameters( + { + "testcase_name": "basic_spec_no_mount_options", + "spec": ["gcs_path=gs://a/b,mount_path=/c"], + "expected_gcs_path": "gs://a/b", + "expected_mount_path": "/c", + "expected_mount_options": "", + }, + { + "testcase_name": "complex_spec_with_quoted_mount_options", + "spec": [ + 'gcs_path=gs://a/b,mount_path=/c,mount_options="implicit-dirs,foo=bar,baz=qux"' + ], + "expected_gcs_path": "gs://a/b", + "expected_mount_path": "/c", + "expected_mount_options": "implicit-dirs,foo=bar,baz=qux", + }, + { + "testcase_name": "spec_with_empty_mount_options", + "spec": ['gcs_path=gs://a/b,mount_path=/c,mount_options=""'], + "expected_gcs_path": "gs://a/b", + "expected_mount_path": "/c", + "expected_mount_options": "", + }, + ) + def test_gcsfuse_mount_spec_parsing( + self, + spec: list[str], + expected_gcs_path: str, + expected_mount_path: str, + expected_mount_options: str, + ): + """Tests that gcsfuse_mount_spec is parsed correctly.""" + with self._job_config( + ArtifactRegistryBundler, gcsfuse_mount_spec=spec + ) as (cfg, _): + self.assertIsNotNone(cfg.gcsfuse_mount) + self.assertEqual(cfg.gcsfuse_mount.gcs_path, expected_gcs_path) + self.assertEqual(cfg.gcsfuse_mount.mount_path, expected_mount_path) + self.assertEqual(cfg.gcsfuse_mount.mount_options, expected_mount_options) + def test_validate_jobset_name(self): with ( self.assertRaisesRegex(ValueError, "invalid"),