diff --git a/setup.py b/setup.py index 074de8a1..16d4c300 100755 --- a/setup.py +++ b/setup.py @@ -11,9 +11,7 @@ "port-for>=0.4", "six>=1.9.0", "psutil", - "packaging", - "fabric", - "sshtunnel" + "packaging" ] # Add compatibility enum class @@ -29,7 +27,7 @@ readme = f.read() setup( - version='1.9.1', + version='1.9.2', name='testgres', packages=['testgres', 'testgres.operations'], description='Testing utility for PostgreSQL and its extensions', diff --git a/testgres/exceptions.py b/testgres/exceptions.py index 6832c788..ee329031 100644 --- a/testgres/exceptions.py +++ b/testgres/exceptions.py @@ -32,7 +32,16 @@ def __str__(self): if self.out: msg.append(u'----\n{}'.format(self.out)) - return six.text_type('\n').join(msg) + return self.convert_and_join(msg) + + @staticmethod + def convert_and_join(msg_list): + # Convert each byte element in the list to str + str_list = [six.text_type(item, 'utf-8') if isinstance(item, bytes) else six.text_type(item) for item in + msg_list] + + # Join the list into a single string with the specified delimiter + return six.text_type('\n').join(str_list) @six.python_2_unicode_compatible diff --git a/testgres/node.py b/testgres/node.py index 6483514b..84c25327 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1371,7 +1371,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True) + proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, get_process=True) return proc diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 318ae675..a692750e 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -18,7 +18,6 @@ from distutils.spawn import find_executable from distutils import rmtree - CMD_TIMEOUT_SEC = 60 error_markers = [b'error', b'Permission denied', b'fatal'] @@ -37,7 +36,8 @@ def __init__(self, conn_params=None): # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=False, text=False, - input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): + input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + get_process=None, timeout=None): """ Execute a command in a subprocess. @@ -69,9 +69,14 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, stdout=stdout, stderr=stderr, ) - if proc: + if get_process: return process - result, error = process.communicate(input) + + try: + result, error = process.communicate(input, timeout=timeout) + except subprocess.TimeoutExpired: + process.kill() + raise ExecUtilException("Command timed out after {} seconds.".format(timeout)) exit_status = process.returncode error_found = exit_status != 0 or any(marker in error for marker in error_markers) diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 5d9bfe7e..421c0a6d 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -3,17 +3,19 @@ import os import subprocess import tempfile -import time -import sshtunnel +# we support both pg8000 and psycopg2 +try: + import psycopg2 as pglib +except ImportError: + try: + import pg8000 as pglib + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") from ..exceptions import ExecUtilException from .os_ops import OsOperations, ConnectionParams -from .os_ops import pglib - -sshtunnel.SSH_TIMEOUT = 5.0 -sshtunnel.TUNNEL_TIMEOUT = 5.0 ConsoleEncoding = locale.getdefaultlocale()[1] if not ConsoleEncoding: @@ -50,21 +52,28 @@ def __init__(self, conn_params: ConnectionParams): self.remote = True self.username = conn_params.username or self.get_user() self.add_known_host(self.host) + self.tunnel_process = None def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - self.close_tunnel() + self.close_ssh_tunnel() - def close_tunnel(self): - if getattr(self, 'tunnel', None): - self.tunnel.stop(force=True) - start_time = time.time() - while self.tunnel.is_active: - if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT: - break - time.sleep(0.5) + def establish_ssh_tunnel(self, local_port, remote_port): + """ + Establish an SSH tunnel from a local port to a remote PostgreSQL port. + """ + ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"] + self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300) + + def close_ssh_tunnel(self): + if hasattr(self, 'tunnel_process'): + self.tunnel_process.terminate() + self.tunnel_process.wait() + del self.tunnel_process + else: + print("No active tunnel to close.") def add_known_host(self, host): cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host, os.getlogin()) @@ -78,21 +87,29 @@ def add_known_host(self, host): raise ExecUtilException(message="Failed to add %s to known_hosts. Error: %s" % (host, str(e)), command=cmd, exit_code=e.returncode, out=e.stderr) - def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, - stderr=None, proc=None): + stderr=None, get_process=None, timeout=None): """ Execute a command in the SSH session. Args: - cmd (str): The command to be executed. """ + ssh_cmd = [] if isinstance(cmd, str): ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd] elif isinstance(cmd, list): ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if get_process: + return process + + try: + result, error = process.communicate(input, timeout=timeout) + except subprocess.TimeoutExpired: + process.kill() + raise ExecUtilException("Command timed out after {} seconds.".format(timeout)) - result, error = process.communicate(input) exit_status = process.returncode if encoding: @@ -372,41 +389,19 @@ def get_process_children(self, pid): raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}") # Database control - def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None): + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): """ - Connects to a PostgreSQL database on the remote system. - Args: - - dbname (str): The name of the database to connect to. - - user (str): The username for the database connection. - - password (str, optional): The password for the database connection. Defaults to None. - - host (str, optional): The IP address of the remote system. Defaults to "localhost". - - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432. - - This function establishes a connection to a PostgreSQL database on the remote system using the specified - parameters. It returns a connection object that can be used to interact with the database. + Established SSH tunnel and Connects to a PostgreSQL """ - self.close_tunnel() - self.tunnel = sshtunnel.open_tunnel( - (self.host, 22), # Remote server IP and SSH port - ssh_username=self.username, - ssh_pkey=self.ssh_key, - remote_bind_address=(self.host, port), # PostgreSQL server IP and PostgreSQL port - local_bind_address=('localhost', 0) - # Local machine IP and available port (0 means it will pick any available port) - ) - self.tunnel.start() - + self.establish_ssh_tunnel(local_port=port, remote_port=5432) try: - # Use localhost and self.tunnel.local_bind_port to connect conn = pglib.connect( - host='localhost', # Connect to localhost - port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel + host=host, + port=port, database=dbname, - user=user or self.username, - password=password + user=user, + password=password, ) - return conn except Exception as e: - self.tunnel.stop() - raise ExecUtilException("Could not create db tunnel. {}".format(e)) + raise Exception(f"Could not connect to the database. Error: {e}") diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 44e77fbd..1042f3c4 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -735,9 +735,10 @@ def test_pgbench(self): options=['-q']).pgbench_run(time=2) # run TPC-B benchmark - out = node.pgbench(stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - options=['-T3']) + proc = node.pgbench(stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + options=['-T3']) + out = proc.communicate()[0] self.assertTrue(b'tps = ' in out) def test_pg_config(self):