From 96a072355c14151b7585e707e0de338734e372a7 Mon Sep 17 00:00:00 2001 From: Satish Ch Date: Tue, 24 Jun 2025 20:37:39 -0700 Subject: [PATCH 1/2] OS platform dependent code changed to platform independent --- .../airflow/providers/teradata/hooks/bteq.py | 18 ++- .../providers/teradata/operators/bteq.py | 2 - .../providers/teradata/utils/bteq_util.py | 29 +++- .../tests/unit/teradata/hooks/test_bteq.py | 32 ++++- .../unit/teradata/operators/test_bteq.py | 4 +- .../unit/teradata/utils/test_bteq_util.py | 124 +++++++++++++++--- 6 files changed, 181 insertions(+), 28 deletions(-) diff --git a/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py b/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py index 89aac594ddb18..dae400a638825 100644 --- a/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py +++ b/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py @@ -29,6 +29,8 @@ from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.teradata.hooks.ttu import TtuHook from airflow.providers.teradata.utils.bteq_util import ( + get_remote_tmp_dir, + identify_os, prepare_bteq_command_for_local_execution, prepare_bteq_command_for_remote_execution, transfer_file_sftp, @@ -161,7 +163,13 @@ def _transfer_to_and_execute_bteq_on_remote( password = generate_random_password() # Encryption/Decryption password encrypted_file_path = os.path.join(tmp_dir, "bteq_script.enc") generate_encrypted_file_with_openssl(file_path, password, encrypted_file_path) + if not remote_working_dir: + remote_working_dir = get_remote_tmp_dir(ssh_client) + self.log.debug( + "Transferring encrypted BTEQ script to remote host: %s", remote_working_dir + ) remote_encrypted_path = os.path.join(remote_working_dir or "", "bteq_script.enc") + remote_encrypted_path = remote_encrypted_path.replace("/", "\\") transfer_file_sftp(ssh_client, encrypted_file_path, remote_encrypted_path) @@ -219,14 +227,20 @@ def _transfer_to_and_execute_bteq_on_remote( if encrypted_file_path and os.path.exists(encrypted_file_path): os.remove(encrypted_file_path) # Cleanup: Delete the remote temporary file - if encrypted_file_path: - cleanup_en_command = f"rm -f {remote_encrypted_path}" + if remote_encrypted_path: if self.ssh_hook and self.ssh_hook.get_conn(): with self.ssh_hook.get_conn() as ssh_client: if ssh_client is None: raise AirflowException( "Failed to establish SSH connection. `ssh_client` is None." ) + # Detect OS + os_info = identify_os(ssh_client) + if "windows" in os_info: + cleanup_en_command = f'del /f /q "{remote_encrypted_path}"' + else: + cleanup_en_command = f"rm -f '{remote_encrypted_path}'" + self.log.debug("cleaning up remote file: %s", cleanup_en_command) ssh_client.exec_command(cleanup_en_command) def execute_bteq_script_at_local( diff --git a/providers/teradata/src/airflow/providers/teradata/operators/bteq.py b/providers/teradata/src/airflow/providers/teradata/operators/bteq.py index 22779be9bfb33..c1373d51e9cd1 100644 --- a/providers/teradata/src/airflow/providers/teradata/operators/bteq.py +++ b/providers/teradata/src/airflow/providers/teradata/operators/bteq.py @@ -141,8 +141,6 @@ def execute(self, context: Context) -> int | None: elif self.bteq_script_encoding == "UTF16": self.temp_file_read_encoding = "UTF-16" - if not self.remote_working_dir: - self.remote_working_dir = "/tmp" # Handling execution on local: if not self._ssh_hook: if self.sql: diff --git a/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py b/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py index 0741ebb20090c..140c4f6c41262 100644 --- a/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py +++ b/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py @@ -28,6 +28,11 @@ from airflow.exceptions import AirflowException +def identify_os(ssh_client: SSHClient) -> str: + stdin, stdout, stderr = ssh_client.exec_command("uname || ver") + return stdout.read().decode().lower() + + def verify_bteq_installed(): """Verify if BTEQ is installed and available in the system's PATH.""" if shutil.which("bteq") is None: @@ -36,7 +41,15 @@ def verify_bteq_installed(): def verify_bteq_installed_remote(ssh_client: SSHClient): """Verify if BTEQ is installed on the remote machine.""" - stdin, stdout, stderr = ssh_client.exec_command("which bteq") + # Detect OS + os_info = identify_os(ssh_client) + + if "windows" in os_info: + check_cmd = "where bteq" + else: + check_cmd = "which bteq" + + stdin, stdout, stderr = ssh_client.exec_command(check_cmd) exit_status = stdout.channel.recv_exit_status() output = stdout.read().strip() error = stderr.read().strip() @@ -53,6 +66,20 @@ def transfer_file_sftp(ssh_client, local_path, remote_path): sftp.close() +def get_remote_tmp_dir(ssh_client): + os_info = identify_os(ssh_client) + + if "windows" in os_info: + # Try getting Windows temp dir + stdin, stdout, stderr = ssh_client.exec_command("echo %TEMP%") + tmp_dir = stdout.read().decode().strip() + if not tmp_dir: + tmp_dir = "C:\\Temp" + else: + tmp_dir = "/tmp" + return tmp_dir + + # We can not pass host details with bteq command when executing on remote machine. Instead, we will prepare .logon in bteq script itself to avoid risk of # exposing sensitive information def prepare_bteq_script_for_remote_execution(conn: dict[str, Any], sql: str) -> str: diff --git a/providers/teradata/tests/unit/teradata/hooks/test_bteq.py b/providers/teradata/tests/unit/teradata/hooks/test_bteq.py index 46b50e652b597..38269ecb99d17 100644 --- a/providers/teradata/tests/unit/teradata/hooks/test_bteq.py +++ b/providers/teradata/tests/unit/teradata/hooks/test_bteq.py @@ -240,7 +240,14 @@ def test_execute_bteq_script_at_remote_success( mock_ssh_hook.get_conn.return_value.__enter__.return_value = mock_ssh_client mock_ssh_hook_class.return_value = mock_ssh_hook - # Instantiate BteqHook with ssh_conn_id (will use mocked SSHHook) + # Mock exec_command to simulate 'uname || ver' + mock_stdin = MagicMock() + mock_stdout = MagicMock() + mock_stderr = MagicMock() + mock_stdout.read.return_value = b"Linux\n" + mock_ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr) + + # Instantiate BteqHook hook = BteqHook(ssh_conn_id="ssh_conn_id", teradata_conn_id="teradata_conn") # Call method under test @@ -342,13 +349,28 @@ def test_remote_execution_cleanup_on_exception( temp_dir = "/tmp" local_file_path = os.path.join(temp_dir, "bteq_script.txt") remote_working_dir = temp_dir - - # Make sure the local encrypted file exists for cleanup encrypted_file_path = os.path.join(temp_dir, "bteq_script.enc") + + # Create dummy local encrypted file with open(encrypted_file_path, "w") as f: f.write("dummy") - with pytest.raises(AirflowException): + # Simulate decrypt failing + mock_decrypt.side_effect = Exception("mocked exception") + + # Patch exec_command for remote cleanup (identify_os, rm) + ssh_client = hook_with_ssh.ssh_hook.get_conn.return_value.__enter__.return_value + + mock_stdin = MagicMock() + mock_stdout = MagicMock() + mock_stderr = MagicMock() + + # For identify_os ("uname || ver") + mock_stdout.read.return_value = b"Linux\n" + ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr) + + # Run the test + with pytest.raises(AirflowException, match="mocked exception"): hook_with_ssh._transfer_to_and_execute_bteq_on_remote( file_path=local_file_path, remote_working_dir=remote_working_dir, @@ -360,5 +382,5 @@ def test_remote_execution_cleanup_on_exception( tmp_dir=temp_dir, ) - # After exception, encrypted file should be deleted + # Verify local encrypted file is deleted assert not os.path.exists(encrypted_file_path) diff --git a/providers/teradata/tests/unit/teradata/operators/test_bteq.py b/providers/teradata/tests/unit/teradata/operators/test_bteq.py index d6690096aa83f..0e9a0ed0b150d 100644 --- a/providers/teradata/tests/unit/teradata/operators/test_bteq.py +++ b/providers/teradata/tests/unit/teradata/operators/test_bteq.py @@ -51,7 +51,7 @@ def test_execute(self, mock_hook_init, mock_execute_bteq): # Then mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None) - mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", "/tmp", "", 600, None, "", None, "UTF-8") + mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", None, "", 600, None, "", None, "UTF-8") assert result == "BTEQ execution result" @mock.patch.object(BteqHook, "execute_bteq_script") @@ -81,7 +81,7 @@ def test_execute_sql_only(self, mock_hook_init, mock_execute_bteq): mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None) mock_execute_bteq.assert_called_once_with( sql + "\n.EXIT", # Assuming the prepare_bteq_script_for_local_execution appends ".EXIT" - "/tmp", # default remote_working_dir + None, # default remote_working_dir "", # bteq_script_encoding (default ASCII => empty string) 600, # timeout default None, # timeout_rc diff --git a/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py b/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py index f0ee54aaa6804..fbefe42213cb6 100644 --- a/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py +++ b/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py @@ -25,6 +25,7 @@ from airflow.exceptions import AirflowException from airflow.providers.teradata.utils.bteq_util import ( + identify_os, is_valid_encoding, is_valid_file, is_valid_remote_bteq_script_file, @@ -38,6 +39,62 @@ class TestBteqUtils: + def test_identify_os_linux(self): + # Arrange + ssh_client = MagicMock() + stdout_mock = MagicMock() + stdout_mock.read.return_value = b"Linux\n" + ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, MagicMock()) + + # Act + os_info = identify_os(ssh_client) + + # Assert + ssh_client.exec_command.assert_called_once_with("uname || ver") + assert os_info == "linux\n" + + def test_identify_os_windows(self): + # Arrange + ssh_client = MagicMock() + stdout_mock = MagicMock() + stdout_mock.read.return_value = b"Microsoft Windows [Version 10.0.19045.3324]\n" + ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, MagicMock()) + + # Act + os_info = identify_os(ssh_client) + + # Assert + ssh_client.exec_command.assert_called_once_with("uname || ver") + assert "windows" in os_info + + def test_identify_os_macos(self): + # Arrange + ssh_client = MagicMock() + stdout_mock = MagicMock() + stdout_mock.read.return_value = b"Darwin\n" + ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, MagicMock()) + + # Act + os_info = identify_os(ssh_client) + + # Assert + ssh_client.exec_command.assert_called_once_with("uname || ver") + assert os_info == "darwin\n" + + def test_identify_os_empty_response(self): + # Arrange + ssh_client = MagicMock() + stdout_mock = MagicMock() + stdout_mock.read.return_value = b"" + ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, MagicMock()) + + # Act + os_info = identify_os(ssh_client) + + # Assert + ssh_client.exec_command.assert_called_once_with("uname || ver") + assert os_info == "" + @patch("shutil.which") def test_verify_bteq_installed_success(self, mock_which): mock_which.return_value = "/usr/bin/bteq" @@ -65,6 +122,57 @@ def test_prepare_bteq_script_for_local_execution(self): assert "SELECT 1;" in script assert ".EXIT" in script + @patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="linux") + def test_verify_bteq_installed_remote_linux(self, mock_os): + ssh_client = MagicMock() + stdout_mock = MagicMock() + stdout_mock.read.return_value = b"/usr/bin/bteq" + stdout_mock.channel.recv_exit_status.return_value = 0 + + ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, MagicMock()) + + verify_bteq_installed_remote(ssh_client) + ssh_client.exec_command.assert_called_once_with("which bteq") + + @patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="windows") + def test_verify_bteq_installed_remote_windows(self, mock_os): + ssh_client = MagicMock() + stdout_mock = MagicMock() + stdout_mock.read.return_value = b"C:\\Program Files\\bteq.exe" + stdout_mock.channel.recv_exit_status.return_value = 0 + + ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, MagicMock()) + + verify_bteq_installed_remote(ssh_client) + ssh_client.exec_command.assert_called_once_with("where bteq") + + @patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="darwin") + def test_verify_bteq_installed_remote_macos(self, mock_os): + ssh_client = MagicMock() + stdout_mock = MagicMock() + stdout_mock.read.return_value = b"/usr/local/bin/bteq" + stdout_mock.channel.recv_exit_status.return_value = 0 + + ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, MagicMock()) + + verify_bteq_installed_remote(ssh_client) + ssh_client.exec_command.assert_called_once_with("which bteq") + + @patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="linux") + def test_verify_bteq_installed_remote_fail(self, mock_os): + ssh_client = MagicMock() + stdout_mock = MagicMock() + stderr_mock = MagicMock() + stdout_mock.read.return_value = b"" + stderr_mock.read.return_value = b"command not found" + stdout_mock.channel.recv_exit_status.return_value = 1 + + ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, stderr_mock) + + with pytest.raises(AirflowException, match="BTEQ is not installed or not available in PATH"): + verify_bteq_installed_remote(ssh_client) + ssh_client.exec_command.assert_called_once_with("which bteq") + @patch("paramiko.SSHClient.exec_command") def test_verify_bteq_installed_remote_success(self, mock_exec): mock_stdin = MagicMock() @@ -81,22 +189,6 @@ def test_verify_bteq_installed_remote_success(self, mock_exec): # Should not raise verify_bteq_installed_remote(ssh_client) - @patch("paramiko.SSHClient.exec_command") - def test_verify_bteq_installed_remote_fail(self, mock_exec): - mock_stdin = MagicMock() - mock_stdout = MagicMock() - mock_stderr = MagicMock() - mock_stdout.channel.recv_exit_status.return_value = 1 - mock_stdout.read.return_value = b"" - mock_stderr.read.return_value = b"command not found" - mock_exec.return_value = (mock_stdin, mock_stdout, mock_stderr) - - ssh_client = MagicMock() - ssh_client.exec_command = mock_exec - - with pytest.raises(AirflowException): - verify_bteq_installed_remote(ssh_client) - @patch("paramiko.SSHClient.open_sftp") def test_transfer_file_sftp(self, mock_open_sftp): mock_sftp = MagicMock() From 70aed001e2fc39d544d08f09c11be9a4dc4ecfe3 Mon Sep 17 00:00:00 2001 From: Satish Ch Date: Tue, 24 Jun 2025 23:02:46 -0700 Subject: [PATCH 2/2] mac platform verified and adjusted code to work with zsh and normal shell --- .../providers/teradata/utils/bteq_util.py | 8 ++ .../unit/teradata/utils/test_bteq_util.py | 79 ++++++++++++++++++- 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py b/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py index 140c4f6c41262..591e4c085c874 100644 --- a/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py +++ b/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py @@ -46,6 +46,14 @@ def verify_bteq_installed_remote(ssh_client: SSHClient): if "windows" in os_info: check_cmd = "where bteq" + elif "darwin" in os_info: + # Check if zsh exists first + stdin, stdout, stderr = ssh_client.exec_command("command -v zsh") + zsh_path = stdout.read().strip() + if zsh_path: + check_cmd = 'zsh -l -c "which bteq"' + else: + check_cmd = "which bteq" else: check_cmd = "which bteq" diff --git a/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py b/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py index fbefe42213cb6..78f2624d01e94 100644 --- a/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py +++ b/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py @@ -19,7 +19,7 @@ import os import stat import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import pytest @@ -156,7 +156,82 @@ def test_verify_bteq_installed_remote_macos(self, mock_os): ssh_client.exec_command.return_value = (MagicMock(), stdout_mock, MagicMock()) verify_bteq_installed_remote(ssh_client) - ssh_client.exec_command.assert_called_once_with("which bteq") + + ssh_client.exec_command.assert_has_calls( + [ + call("command -v zsh"), + call('zsh -l -c "which bteq"'), + ] + ) + + @patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="darwin") + def test_verify_bteq_installed_remote_macos_which_called_when_no_zsh(self, mock_os): + ssh_client = MagicMock() + + # Mock for "command -v zsh" returning empty (no zsh) + stdin_mock_1 = MagicMock() + stdout_mock_1 = MagicMock() + stderr_mock_1 = MagicMock() + stdout_mock_1.read.return_value = b"" # No zsh path found + stderr_mock_1.read.return_value = b"" # Return empty bytes here! + ssh_client.exec_command.side_effect = [ + (stdin_mock_1, stdout_mock_1, stderr_mock_1), # command -v zsh + (MagicMock(), MagicMock(), MagicMock()), # which bteq + ] + + # Mock for "which bteq" command response + stdin_mock_2 = MagicMock() + stdout_mock_2 = MagicMock() + stderr_mock_2 = MagicMock() + stdout_mock_2.channel.recv_exit_status.return_value = 0 + stdout_mock_2.read.return_value = b"/usr/local/bin/bteq" + stderr_mock_2.read.return_value = b"" # Also return bytes here + + # Since side_effect was already assigned, override second call manually + ssh_client.exec_command.side_effect = [ + (stdin_mock_1, stdout_mock_1, stderr_mock_1), # command -v zsh + (stdin_mock_2, stdout_mock_2, stderr_mock_2), # which bteq + ] + + verify_bteq_installed_remote(ssh_client) + + ssh_client.exec_command.assert_has_calls( + [ + call("command -v zsh"), + call("which bteq"), + ] + ) + + @patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="darwin") + def test_verify_bteq_installed_remote_macos_which_fails_no_zsh(self, mock_os): + ssh_client = MagicMock() + + # Mock for "command -v zsh" returning empty (no zsh) + stdin_mock_1 = MagicMock() + stdout_mock_1 = MagicMock() + stderr_mock_1 = MagicMock() + stdout_mock_1.read.return_value = b"" # No zsh path found + ssh_client.exec_command.side_effect = [ + (stdin_mock_1, stdout_mock_1, stderr_mock_1), # command -v zsh + (MagicMock(), MagicMock(), MagicMock()), # which bteq + ] + + # For which bteq failure + ssh_client.exec_command.return_value[1].channel.recv_exit_status.return_value = 1 + ssh_client.exec_command.return_value[1].read.return_value = b"" + ssh_client.exec_command.return_value[2].read.return_value = b"command not found" + + with pytest.raises(AirflowException) as exc_info: + verify_bteq_installed_remote(ssh_client) + + assert "BTEQ is not installed or not available in PATH" in str(exc_info.value) + + ssh_client.exec_command.assert_has_calls( + [ + call("command -v zsh"), + call("which bteq"), + ] + ) @patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="linux") def test_verify_bteq_installed_remote_fail(self, mock_os):