From 800e648fc30710dc779dcc056441159bf8f8d810 Mon Sep 17 00:00:00 2001
From: asavchkov <pge@mail.ee>
Date: Tue, 25 Jun 2024 17:24:24 +0700
Subject: [PATCH] Add an SSH port parameter

---
 testgres/operations/os_ops.py     |  3 ++-
 testgres/operations/remote_ops.py | 26 +++++++++++++++-----------
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py
index dd6613cf..236a08c6 100644
--- a/testgres/operations/os_ops.py
+++ b/testgres/operations/os_ops.py
@@ -10,8 +10,9 @@
 
 
 class ConnectionParams:
-    def __init__(self, host='127.0.0.1', ssh_key=None, username=None):
+    def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None):
         self.host = host
+        self.port = port
         self.ssh_key = ssh_key
         self.username = username
 
diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py
index 01251e1c..697b4258 100644
--- a/testgres/operations/remote_ops.py
+++ b/testgres/operations/remote_ops.py
@@ -44,11 +44,13 @@ def __init__(self, conn_params: ConnectionParams):
         super().__init__(conn_params.username)
         self.conn_params = conn_params
         self.host = conn_params.host
+        self.port = conn_params.port
         self.ssh_key = conn_params.ssh_key
+        self.ssh_args = []
         if self.ssh_key:
-            self.ssh_cmd = ["-i", self.ssh_key]
-        else:
-            self.ssh_cmd = []
+            self.ssh_args += ["-i", self.ssh_key]
+        if self.port:
+            self.ssh_args += ["-p", self.port]
         self.remote = True
         self.username = conn_params.username or self.get_user()
         self.add_known_host(self.host)
@@ -95,9 +97,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
         """
         ssh_cmd = []
         if isinstance(cmd, str):
-            ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd]
+            ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + [cmd]
         elif isinstance(cmd, list):
-            ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd
+            ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + cmd
         process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         if get_process:
             return process
@@ -246,9 +248,9 @@ def mkdtemp(self, prefix=None):
         - prefix (str): The prefix of the temporary directory name.
         """
         if prefix:
-            command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
+            command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
         else:
-            command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"]
+            command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", "mktemp -d"]
 
         result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
 
@@ -291,8 +293,10 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
             mode = "r+b" if binary else "r+"
 
         with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
+            # For scp the port is specified by a "-P" option
+            scp_args = ['-P' if x == '-p' else x for x in self.ssh_args]
             if not truncate:
-                scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
+                scp_cmd = ['scp'] + scp_args + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
                 subprocess.run(scp_cmd, check=False)  # The file might not exist yet
                 tmp_file.seek(0, os.SEEK_END)
 
@@ -308,11 +312,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
                 tmp_file.write(data)
 
             tmp_file.flush()
-            scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
+            scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
             subprocess.run(scp_cmd, check=True)
 
             remote_directory = os.path.dirname(filename)
-            mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
+            mkdir_cmd = ['ssh'] + self.ssh_args + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
             subprocess.run(mkdir_cmd, check=True)
 
             os.remove(tmp_file.name)
@@ -377,7 +381,7 @@ def get_pid(self):
         return int(self.exec_command("echo $$", encoding=get_default_encoding()))
 
     def get_process_children(self, pid):
-        command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
+        command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
 
         result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)