diff --git a/testgres/node.py b/testgres/node.py index f7109b0c..13d13294 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -63,7 +63,6 @@ from .defaults import \ default_dbname, \ - default_username, \ generate_app_name from .exceptions import \ @@ -683,8 +682,6 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem If False, waits for the instance to be in primary mode. Default is False. max_attempts: """ - if not username: - username = default_username() self.start() if replica: @@ -694,7 +691,7 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem # Call poll_query_until until the expected value is returned self.poll_query_until(query=query, dbname=dbname, - username=username, + username=username or self.os_ops.username, suppress={InternalError, QueryException, ProgrammingError, @@ -967,15 +964,13 @@ def psql(self, >>> psql(query='select 3', ON_ERROR_STOP=1) """ - # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() psql_params = [ self._get_bin_path("psql"), "-p", str(self.port), "-h", self.host, - "-U", username, + "-U", username or self.os_ops.username, "-X", # no .psqlrc "-A", # unaligned output "-t", # print rows only @@ -1087,9 +1082,6 @@ def tmpfile(): fname = self.os_ops.mkstemp(prefix=TMP_DUMP) return fname - # Set default arguments - dbname = dbname or default_dbname() - username = username or default_username() filename = filename or tmpfile() _params = [ @@ -1097,8 +1089,8 @@ def tmpfile(): "-p", str(self.port), "-h", self.host, "-f", filename, - "-U", username, - "-d", dbname, + "-U", username or self.os_ops.username, + "-d", dbname or default_dbname(), "-F", format.value ] # yapf: disable @@ -1118,7 +1110,7 @@ def restore(self, filename, dbname=None, username=None): # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() + username = username or self.os_ops.username _params = [ self._get_bin_path("pg_restore"), @@ -1388,15 +1380,13 @@ def pgbench(self, if options is None: options = [] - # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() _params = [ self._get_bin_path("pgbench"), "-p", str(self.port), "-h", self.host, - "-U", username, + "-U", username or self.os_ops.username ] + options # yapf: disable # should be the last one @@ -1463,15 +1453,13 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): >>> pgbench_run(time=10) """ - # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() _params = [ self._get_bin_path("pgbench"), "-p", str(self.port), "-h", self.host, - "-U", username, + "-U", username or self.os_ops.username ] + options # yapf: disable for key, value in iteritems(kwargs): diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index ef360d3b..313d7060 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -38,7 +38,7 @@ def __init__(self, conn_params=None): self.host = conn_params.host self.ssh_key = None self.remote = False - self.username = conn_params.username or self.get_user() + self.username = conn_params.username or getpass.getuser() @staticmethod def _raise_exec_exception(message, command, exit_code, output): @@ -130,10 +130,6 @@ def set_env(self, var_name, var_val): # Check if the directory is already in PATH os.environ[var_name] = var_val - # Get environment variables - def get_user(self): - return self.username or getpass.getuser() - def get_name(self): return os.name diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 236a08c6..0b5efff9 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -45,9 +45,8 @@ def set_env(self, var_name, var_val): # Check if the directory is already in PATH raise NotImplementedError() - # Get environment variables def get_user(self): - raise NotImplementedError() + return self.username def get_name(self): raise NotImplementedError() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 697b4258..83965336 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,8 +1,9 @@ -import logging +import getpass import os +import logging +import platform import subprocess import tempfile -import platform # we support both pg8000 and psycopg2 try: @@ -52,7 +53,8 @@ def __init__(self, conn_params: ConnectionParams): if self.port: self.ssh_args += ["-p", self.port] self.remote = True - self.username = conn_params.username or self.get_user() + self.username = conn_params.username or getpass.getuser() + self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host self.add_known_host(self.host) self.tunnel_process = None @@ -97,9 +99,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, """ ssh_cmd = [] if isinstance(cmd, str): - ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + [cmd] + ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_args + [cmd] elif isinstance(cmd, list): - ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + cmd + ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_args + cmd process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if get_process: return process @@ -174,10 +176,6 @@ def set_env(self, var_name: str, var_val: str): """ return self.exec_command("export {}={}".format(var_name, var_val)) - # Get environment variables - def get_user(self): - return self.exec_command("echo $USER", encoding=get_default_encoding()).strip() - def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' return self.exec_command(cmd, encoding=get_default_encoding()).strip() @@ -248,9 +246,9 @@ def mkdtemp(self, prefix=None): - prefix (str): The prefix of the temporary directory name. """ if prefix: - command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"] + command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"mktemp -d {prefix}XXXXX"] else: - command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", "mktemp -d"] + command = ["ssh"] + self.ssh_args + [self.ssh_dest, "mktemp -d"] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) @@ -296,7 +294,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal # For scp the port is specified by a "-P" option scp_args = ['-P' if x == '-p' else x for x in self.ssh_args] if not truncate: - scp_cmd = ['scp'] + scp_args + [f"{self.username}@{self.host}:{filename}", tmp_file.name] + scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name] subprocess.run(scp_cmd, check=False) # The file might not exist yet tmp_file.seek(0, os.SEEK_END) @@ -312,11 +310,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal tmp_file.write(data) tmp_file.flush() - scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.username}@{self.host}:{filename}"] + scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.ssh_dest}:{filename}"] subprocess.run(scp_cmd, check=True) remote_directory = os.path.dirname(filename) - mkdir_cmd = ['ssh'] + self.ssh_args + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"] + mkdir_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, f"mkdir -p {remote_directory}"] subprocess.run(mkdir_cmd, check=True) os.remove(tmp_file.name) @@ -381,7 +379,7 @@ def get_pid(self): return int(self.exec_command("echo $$", encoding=get_default_encoding())) def get_process_children(self, pid): - command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"pgrep -P {pid}"] + command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"pgrep -P {pid}"] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)