Skip to content

Commit 79d5740

Browse files
Enhanced the BTEQ operator to ensure platform independence. (apache#52252)
* OS platform dependent code changed to platform independent (#59) Co-authored-by: Satish Ch <[email protected]> * Bteq platform independent (#61) * OS platform dependent code changed to platform independent * mac platform verified and adjusted code to work with zsh and normal shell --------- Co-authored-by: Satish Ch <[email protected]> --------- Co-authored-by: Satish Ch <[email protected]>
1 parent 85e559d commit 79d5740

File tree

6 files changed

+265
-29
lines changed

6 files changed

+265
-29
lines changed

providers/teradata/src/airflow/providers/teradata/hooks/bteq.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from airflow.providers.ssh.hooks.ssh import SSHHook
3030
from airflow.providers.teradata.hooks.ttu import TtuHook
3131
from airflow.providers.teradata.utils.bteq_util import (
32+
get_remote_tmp_dir,
33+
identify_os,
3234
prepare_bteq_command_for_local_execution,
3335
prepare_bteq_command_for_remote_execution,
3436
transfer_file_sftp,
@@ -161,7 +163,13 @@ def _transfer_to_and_execute_bteq_on_remote(
161163
password = generate_random_password() # Encryption/Decryption password
162164
encrypted_file_path = os.path.join(tmp_dir, "bteq_script.enc")
163165
generate_encrypted_file_with_openssl(file_path, password, encrypted_file_path)
166+
if not remote_working_dir:
167+
remote_working_dir = get_remote_tmp_dir(ssh_client)
168+
self.log.debug(
169+
"Transferring encrypted BTEQ script to remote host: %s", remote_working_dir
170+
)
164171
remote_encrypted_path = os.path.join(remote_working_dir or "", "bteq_script.enc")
172+
remote_encrypted_path = remote_encrypted_path.replace("/", "\\")
165173

166174
transfer_file_sftp(ssh_client, encrypted_file_path, remote_encrypted_path)
167175

@@ -219,14 +227,20 @@ def _transfer_to_and_execute_bteq_on_remote(
219227
if encrypted_file_path and os.path.exists(encrypted_file_path):
220228
os.remove(encrypted_file_path)
221229
# Cleanup: Delete the remote temporary file
222-
if encrypted_file_path:
223-
cleanup_en_command = f"rm -f {remote_encrypted_path}"
230+
if remote_encrypted_path:
224231
if self.ssh_hook and self.ssh_hook.get_conn():
225232
with self.ssh_hook.get_conn() as ssh_client:
226233
if ssh_client is None:
227234
raise AirflowException(
228235
"Failed to establish SSH connection. `ssh_client` is None."
229236
)
237+
# Detect OS
238+
os_info = identify_os(ssh_client)
239+
if "windows" in os_info:
240+
cleanup_en_command = f'del /f /q "{remote_encrypted_path}"'
241+
else:
242+
cleanup_en_command = f"rm -f '{remote_encrypted_path}'"
243+
self.log.debug("cleaning up remote file: %s", cleanup_en_command)
230244
ssh_client.exec_command(cleanup_en_command)
231245

232246
def execute_bteq_script_at_local(

providers/teradata/src/airflow/providers/teradata/operators/bteq.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,6 @@ def execute(self, context: Context) -> int | None:
141141
elif self.bteq_script_encoding == "UTF16":
142142
self.temp_file_read_encoding = "UTF-16"
143143

144-
if not self.remote_working_dir:
145-
self.remote_working_dir = "/tmp"
146144
# Handling execution on local:
147145
if not self._ssh_hook:
148146
if self.sql:

providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
from airflow.exceptions import AirflowException
2929

3030

31+
def identify_os(ssh_client: SSHClient) -> str:
32+
stdin, stdout, stderr = ssh_client.exec_command("uname || ver")
33+
return stdout.read().decode().lower()
34+
35+
3136
def verify_bteq_installed():
3237
"""Verify if BTEQ is installed and available in the system's PATH."""
3338
if shutil.which("bteq") is None:
@@ -36,7 +41,23 @@ def verify_bteq_installed():
3641

3742
def verify_bteq_installed_remote(ssh_client: SSHClient):
3843
"""Verify if BTEQ is installed on the remote machine."""
39-
stdin, stdout, stderr = ssh_client.exec_command("which bteq")
44+
# Detect OS
45+
os_info = identify_os(ssh_client)
46+
47+
if "windows" in os_info:
48+
check_cmd = "where bteq"
49+
elif "darwin" in os_info:
50+
# Check if zsh exists first
51+
stdin, stdout, stderr = ssh_client.exec_command("command -v zsh")
52+
zsh_path = stdout.read().strip()
53+
if zsh_path:
54+
check_cmd = 'zsh -l -c "which bteq"'
55+
else:
56+
check_cmd = "which bteq"
57+
else:
58+
check_cmd = "which bteq"
59+
60+
stdin, stdout, stderr = ssh_client.exec_command(check_cmd)
4061
exit_status = stdout.channel.recv_exit_status()
4162
output = stdout.read().strip()
4263
error = stderr.read().strip()
@@ -53,6 +74,20 @@ def transfer_file_sftp(ssh_client, local_path, remote_path):
5374
sftp.close()
5475

5576

77+
def get_remote_tmp_dir(ssh_client):
78+
os_info = identify_os(ssh_client)
79+
80+
if "windows" in os_info:
81+
# Try getting Windows temp dir
82+
stdin, stdout, stderr = ssh_client.exec_command("echo %TEMP%")
83+
tmp_dir = stdout.read().decode().strip()
84+
if not tmp_dir:
85+
tmp_dir = "C:\\Temp"
86+
else:
87+
tmp_dir = "/tmp"
88+
return tmp_dir
89+
90+
5691
# We can not pass host details with bteq command when executing on remote machine. Instead, we will prepare .logon in bteq script itself to avoid risk of
5792
# exposing sensitive information
5893
def prepare_bteq_script_for_remote_execution(conn: dict[str, Any], sql: str) -> str:

providers/teradata/tests/unit/teradata/hooks/test_bteq.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,14 @@ def test_execute_bteq_script_at_remote_success(
240240
mock_ssh_hook.get_conn.return_value.__enter__.return_value = mock_ssh_client
241241
mock_ssh_hook_class.return_value = mock_ssh_hook
242242

243-
# Instantiate BteqHook with ssh_conn_id (will use mocked SSHHook)
243+
# Mock exec_command to simulate 'uname || ver'
244+
mock_stdin = MagicMock()
245+
mock_stdout = MagicMock()
246+
mock_stderr = MagicMock()
247+
mock_stdout.read.return_value = b"Linux\n"
248+
mock_ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr)
249+
250+
# Instantiate BteqHook
244251
hook = BteqHook(ssh_conn_id="ssh_conn_id", teradata_conn_id="teradata_conn")
245252

246253
# Call method under test
@@ -342,13 +349,28 @@ def test_remote_execution_cleanup_on_exception(
342349
temp_dir = "/tmp"
343350
local_file_path = os.path.join(temp_dir, "bteq_script.txt")
344351
remote_working_dir = temp_dir
345-
346-
# Make sure the local encrypted file exists for cleanup
347352
encrypted_file_path = os.path.join(temp_dir, "bteq_script.enc")
353+
354+
# Create dummy local encrypted file
348355
with open(encrypted_file_path, "w") as f:
349356
f.write("dummy")
350357

351-
with pytest.raises(AirflowException):
358+
# Simulate decrypt failing
359+
mock_decrypt.side_effect = Exception("mocked exception")
360+
361+
# Patch exec_command for remote cleanup (identify_os, rm)
362+
ssh_client = hook_with_ssh.ssh_hook.get_conn.return_value.__enter__.return_value
363+
364+
mock_stdin = MagicMock()
365+
mock_stdout = MagicMock()
366+
mock_stderr = MagicMock()
367+
368+
# For identify_os ("uname || ver")
369+
mock_stdout.read.return_value = b"Linux\n"
370+
ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr)
371+
372+
# Run the test
373+
with pytest.raises(AirflowException, match="mocked exception"):
352374
hook_with_ssh._transfer_to_and_execute_bteq_on_remote(
353375
file_path=local_file_path,
354376
remote_working_dir=remote_working_dir,
@@ -360,5 +382,5 @@ def test_remote_execution_cleanup_on_exception(
360382
tmp_dir=temp_dir,
361383
)
362384

363-
# After exception, encrypted file should be deleted
385+
# Verify local encrypted file is deleted
364386
assert not os.path.exists(encrypted_file_path)

providers/teradata/tests/unit/teradata/operators/test_bteq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_execute(self, mock_hook_init, mock_execute_bteq):
5151

5252
# Then
5353
mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None)
54-
mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", "/tmp", "", 600, None, "", None, "UTF-8")
54+
mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", None, "", 600, None, "", None, "UTF-8")
5555
assert result == "BTEQ execution result"
5656

5757
@mock.patch.object(BteqHook, "execute_bteq_script")
@@ -81,7 +81,7 @@ def test_execute_sql_only(self, mock_hook_init, mock_execute_bteq):
8181
mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None)
8282
mock_execute_bteq.assert_called_once_with(
8383
sql + "\n.EXIT", # Assuming the prepare_bteq_script_for_local_execution appends ".EXIT"
84-
"/tmp", # default remote_working_dir
84+
None, # default remote_working_dir
8585
"", # bteq_script_encoding (default ASCII => empty string)
8686
600, # timeout default
8787
None, # timeout_rc

0 commit comments

Comments
 (0)