diff --git a/README.md b/README.md
index 6b26ba96..29b974dc 100644
--- a/README.md
+++ b/README.md
@@ -173,6 +173,33 @@ with testgres.get_new_node().init() as master:
 Note that `default_conf()` is called by `init()` function; both of them overwrite
 the configuration file, which means that they should be called before `append_conf()`.
 
+### Remote mode
+Testgres supports the creation of PostgreSQL nodes on a remote host. This is useful when you want to run distributed tests involving multiple nodes spread across different machines.
+
+To use this feature, create a `RemoteOperations` instance and register it in the testgres configuration.
+Here is an example of how you might set this up:
+
+```python
+from testgres import ConnectionParams, RemoteOperations, TestgresConfig, get_remote_node
+
+# Set up connection params
+conn_params = ConnectionParams(
+    host='your_host',  # replace with your host
+    username='user_name',  # replace with your username
+    ssh_key='path_to_ssh_key'  # replace with your SSH key path
+)
+os_ops = RemoteOperations(conn_params)
+
+# Add remote testgres config before test
+TestgresConfig.set_os_ops(os_ops=os_ops)
+
+# Proceed with your test
+def test_basic_query(self):
+    with get_remote_node(conn_params=conn_params) as node:
+        node.init().start()
+        res = node.execute('SELECT 1')
+        self.assertEqual(res, [(1,)])
+```
 
 ## Authors
diff --git a/setup.py b/setup.py
index 6d0c2256..8cb0f70a 100755
--- a/setup.py
+++ b/setup.py
@@ -12,6 +12,9 @@
     "six>=1.9.0",
     "psutil",
     "packaging",
+    "paramiko",
+    "fabric",
+    "sshtunnel"
 ]
 
 # Add compatibility enum class
@@ -27,9 +30,9 @@
     readme = f.read()
 
 setup(
-    version='1.8.9',
+    version='1.9.0',
     name='testgres',
-    packages=['testgres'],
+    packages=['testgres', 'testgres.operations'],
     description='Testing utility for PostgreSQL and its extensions',
     url='https://github.com/postgrespro/testgres',
     long_description=readme,
diff --git a/testgres/__init__.py b/testgres/__init__.py
index 1b33ba3b..b63c7df1 100644
--- a/testgres/__init__.py
+++ b/testgres/__init__.py
@@ -1,4 +1,4 @@
-from .api import get_new_node
+from .api import get_new_node, get_remote_node
 from .backup import NodeBackup
 
 from .config import \
@@ -46,8 +46,13 @@
     First, \
     Any
 
+from .operations.os_ops import OsOperations, ConnectionParams
+from .operations.local_ops import LocalOperations
+from .operations.remote_ops import RemoteOperations
+
 __all__ = [
     "get_new_node",
+    "get_remote_node",
     "NodeBackup",
     "TestgresConfig", "configure_testgres", "scoped_config", "push_config", "pop_config",
     "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError",
@@ -56,4 +61,5 @@
     "PostgresNode", "NodeApp",
     "reserve_port", "release_port", "bound_ports",
     "get_bin_path", "get_pg_config", "get_pg_version",
     "First", "Any",
+    "OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams"
 ]
diff --git a/testgres/api.py b/testgres/api.py
index e90cf7bd..e4b1cdd5 100644
--- a/testgres/api.py
+++ b/testgres/api.py
@@ -40,3 +40,15 @@ def get_new_node(name=None, base_dir=None, **kwargs):
     """
     # NOTE: leave explicit 'name' and 'base_dir' for compatibility
     return PostgresNode(name=name, base_dir=base_dir, **kwargs)
+
+
+def get_remote_node(name=None, conn_params=None):
+    """
+    Simply a wrapper around the :class:`.PostgresNode` constructor for a remote node.
+    See :meth:`.PostgresNode.__init__` for details.
+    For a remote connection, you can pass the following parameter:
+    conn_params = ConnectionParams(host='127.0.0.1',
+                                   ssh_key=None,
+                                   username=default_username())
+    """
+    return get_new_node(name=name, conn_params=conn_params)
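The new `get_remote_node()` mirrors `get_new_node()`; only the connection parameters differ. A minimal usage sketch (the host, username, and key path below are placeholders, not values from this patch):

```python
from testgres import ConnectionParams, get_new_node, get_remote_node

# Local node -- behaviour is unchanged
with get_new_node() as local_node:
    local_node.init().start()

# Remote node -- same API, but file and process operations go over SSH
conn_params = ConnectionParams(host='10.0.0.5',              # placeholder
                               username='postgres',          # placeholder
                               ssh_key='~/.ssh/id_ed25519')  # placeholder
with get_remote_node(conn_params=conn_params) as remote_node:
    remote_node.init().start()
    assert remote_node.execute('SELECT 1') == [(1,)]
```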
diff --git a/testgres/backup.py b/testgres/backup.py
index a725a1df..a89e214d 100644
--- a/testgres/backup.py
+++ b/testgres/backup.py
@@ -2,9 +2,7 @@
 
 import os
 
-from shutil import rmtree, copytree
 from six import raise_from
-from tempfile import mkdtemp
 
 from .enums import XLogMethod
 
@@ -15,8 +13,6 @@
     PG_CONF_FILE, \
     BACKUP_LOG_FILE
 
-from .defaults import default_username
-
 from .exceptions import BackupException
 
 from .utils import \
@@ -47,7 +43,7 @@ def __init__(self,
             username: database user name.
             xlog_method: none | fetch | stream (see docs)
         """
-
+        self.os_ops = node.os_ops
         if not node.status():
             raise BackupException('Node must be running')
 
@@ -60,8 +56,8 @@ def __init__(self,
             raise BackupException(msg)
 
         # Set default arguments
-        username = username or default_username()
-        base_dir = base_dir or mkdtemp(prefix=TMP_BACKUP)
+        username = username or self.os_ops.get_user()
+        base_dir = base_dir or self.os_ops.mkdtemp(prefix=TMP_BACKUP)
 
         # public
         self.original_node = node
@@ -107,14 +103,14 @@ def _prepare_dir(self, destroy):
         available = not destroy
 
         if available:
-            dest_base_dir = mkdtemp(prefix=TMP_NODE)
+            dest_base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE)
 
             data1 = os.path.join(self.base_dir, DATA_DIR)
             data2 = os.path.join(dest_base_dir, DATA_DIR)
 
             try:
                 # Copy backup to new data dir
-                copytree(data1, data2)
+                self.os_ops.copytree(data1, data2)
             except Exception as e:
                 raise_from(BackupException('Failed to copy files'), e)
         else:
@@ -143,7 +139,7 @@ def spawn_primary(self, name=None, destroy=True):
 
         # Build a new PostgresNode
         NodeClass = self.original_node.__class__
-        with clean_on_error(NodeClass(name=name, base_dir=base_dir)) as node:
+        with clean_on_error(NodeClass(name=name, base_dir=base_dir, conn_params=self.original_node.os_ops.conn_params)) as node:
 
             # New nodes should always remove dir tree
             node._should_rm_dirs = True
@@ -185,4 +181,4 @@ def cleanup(self):
         if self._available:
             self._available = False
 
-            rmtree(self.base_dir, ignore_errors=True)
+            self.os_ops.rmdirs(self.base_dir, ignore_errors=True)
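Because `NodeBackup` now routes `mkdtemp`, `copytree`, and `rmdirs` through `node.os_ops`, backup files are created on the same host as the node itself. The user-facing flow is unchanged; a sketch, reusing `conn_params` from the example above:

```python
with get_remote_node(conn_params=conn_params) as node:
    node.init().start()
    backup = node.backup()            # temp dir is created via node.os_ops
    with backup.spawn_primary() as clone:
        clone.start()                 # clone inherits the node's conn_params
```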
""" - def call_initdb(initdb_dir, log=None): + + def call_initdb(initdb_dir, log=logfile): try: _params = [get_bin_path("initdb"), "-D", initdb_dir, "-N"] execute_utility(_params + (params or []), log) @@ -39,13 +41,14 @@ def call_initdb(initdb_dir, log=None): cached_data_dir = testgres_config.cached_initdb_dir # Initialize cached initdb - if not os.path.exists(cached_data_dir) or \ - not os.listdir(cached_data_dir): + + if not os_ops.path_exists(cached_data_dir) or \ + not os_ops.listdir(cached_data_dir): call_initdb(cached_data_dir) try: # Copy cached initdb to current data dir - copytree(cached_data_dir, data_dir) + os_ops.copytree(cached_data_dir, data_dir) # Assign this node a unique system id if asked to if testgres_config.cached_initdb_unique: @@ -53,8 +56,8 @@ def call_initdb(initdb_dir, log=None): # Some users might rely upon unique system ids, but # our initdb caching mechanism breaks this contract. pg_control = os.path.join(data_dir, XLOG_CONTROL_FILE) - with io.open(pg_control, "r+b") as f: - f.write(generate_system_id()) # overwrite id + system_id = generate_system_id() + os_ops.write(pg_control, system_id, truncate=True, binary=True, read_and_write=True) # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] diff --git a/testgres/config.py b/testgres/config.py index cfcdadc2..b6c43926 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -5,10 +5,10 @@ import tempfile from contextlib import contextmanager -from shutil import rmtree -from tempfile import mkdtemp from .consts import TMP_CACHE +from .operations.os_ops import OsOperations +from .operations.local_ops import LocalOperations class GlobalConfig(object): @@ -43,6 +43,9 @@ class GlobalConfig(object): _cached_initdb_dir = None """ underlying class attribute for cached_initdb_dir property """ + + os_ops = LocalOperations() + """ OsOperation object that allows work on remote host """ @property def cached_initdb_dir(self): """ path to a temp directory for cached initdb. 
""" @@ -54,6 +57,7 @@ def cached_initdb_dir(self, value): if value: cached_initdb_dirs.add(value) + return testgres_config.cached_initdb_dir @property def temp_dir(self): @@ -118,6 +122,11 @@ def copy(self): return copy.copy(self) + @staticmethod + def set_os_ops(os_ops: OsOperations): + testgres_config.os_ops = os_ops + testgres_config.cached_initdb_dir = os_ops.mkdtemp(prefix=TMP_CACHE) + # cached dirs to be removed cached_initdb_dirs = set() @@ -135,7 +144,7 @@ def copy(self): @atexit.register def _rm_cached_initdb_dirs(): for d in cached_initdb_dirs: - rmtree(d, ignore_errors=True) + testgres_config.os_ops.rmdirs(d, ignore_errors=True) def push_config(**options): @@ -198,4 +207,4 @@ def configure_testgres(**options): # NOTE: assign initial cached dir for initdb -testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE) +testgres_config.cached_initdb_dir = testgres_config.os_ops.mkdtemp(prefix=TMP_CACHE) diff --git a/testgres/connection.py b/testgres/connection.py index ee2a2128..aeb040ce 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -41,11 +41,11 @@ def __init__(self, self._node = node - self._connection = pglib.connect(database=dbname, - user=username, - password=password, - host=node.host, - port=node.port) + self._connection = node.os_ops.db_connect(dbname=dbname, + user=username, + password=password, + host=node.host, + port=node.port) self._connection.autocommit = autocommit self._cursor = self.connection.cursor() @@ -103,16 +103,15 @@ def rollback(self): def execute(self, query, *args): self.cursor.execute(query, args) - try: res = self.cursor.fetchall() - # pg8000 might return tuples if isinstance(res, tuple): res = [tuple(t) for t in res] return res - except Exception: + except Exception as e: + print("Error executing query: {}".format(e)) return None def close(self): diff --git a/testgres/defaults.py b/testgres/defaults.py index 8d5b892e..d77361d7 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -1,9 +1,9 @@ import datetime -import getpass -import os import struct import uuid +from .config import testgres_config as tconf + def default_dbname(): """ @@ -17,8 +17,7 @@ def default_username(): """ Return default username (current user). 
""" - - return getpass.getuser() + return tconf.os_ops.get_user() def generate_app_name(): @@ -44,7 +43,7 @@ def generate_system_id(): system_id = 0 system_id |= (secs << 32) system_id |= (usecs << 12) - system_id |= (os.getpid() & 0xFFF) + system_id |= (tconf.os_ops.get_pid() & 0xFFF) # pack ULL in native byte order return struct.pack('=Q', system_id) diff --git a/testgres/node.py b/testgres/node.py index 659a62f8..6483514b 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1,18 +1,14 @@ # coding: utf-8 -import io import os import random -import shutil import signal +import subprocess import threading from queue import Queue -import psutil -import subprocess import time - try: from collections.abc import Iterable except ImportError: @@ -27,9 +23,7 @@ except ImportError: raise ImportError("You must have psycopg2 or pg8000 modules installed") -from shutil import rmtree from six import raise_from, iteritems, text_type -from tempfile import mkstemp, mkdtemp from .enums import \ NodeStatus, \ @@ -93,7 +87,6 @@ eprint, \ get_bin_path, \ get_pg_version, \ - file_tail, \ reserve_port, \ release_port, \ execute_utility, \ @@ -102,6 +95,10 @@ from .backup import NodeBackup +from .operations.os_ops import ConnectionParams +from .operations.local_ops import LocalOperations +from .operations.remote_ops import RemoteOperations + InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError OperationalError = pglib.OperationalError @@ -130,7 +127,7 @@ def __repr__(self): class PostgresNode(object): - def __init__(self, name=None, port=None, base_dir=None): + def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionParams = ConnectionParams()): """ PostgresNode constructor. @@ -148,10 +145,19 @@ def __init__(self, name=None, port=None, base_dir=None): self._master = None # basic - self.host = '127.0.0.1' self.name = name or generate_app_name() + if testgres_config.os_ops: + self.os_ops = testgres_config.os_ops + elif conn_params.ssh_key: + self.os_ops = RemoteOperations(conn_params) + else: + self.os_ops = LocalOperations(conn_params) + self.port = port or reserve_port() + self.host = self.os_ops.host + self.ssh_key = self.os_ops.ssh_key + # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit @@ -195,8 +201,9 @@ def pid(self): if self.status(): pid_file = os.path.join(self.data_dir, PG_PID_FILE) - with io.open(pid_file) as f: - return int(f.readline()) + lines = self.os_ops.readlines(pid_file) + pid = int(lines[0]) if lines else None + return pid # for clarity return 0 @@ -236,7 +243,7 @@ def child_processes(self): """ # get a list of postmaster's children - children = psutil.Process(self.pid).children() + children = self.os_ops.get_process_children(self.pid) return [ProcessProxy(p) for p in children] @@ -274,11 +281,11 @@ def master(self): @property def base_dir(self): if not self._base_dir: - self._base_dir = mkdtemp(prefix=TMP_NODE) + self._base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE) # NOTE: it's safe to create a new dir - if not os.path.exists(self._base_dir): - os.makedirs(self._base_dir) + if not self.os_ops.path_exists(self._base_dir): + self.os_ops.makedirs(self._base_dir) return self._base_dir @@ -287,8 +294,8 @@ def logs_dir(self): path = os.path.join(self.base_dir, LOGS_DIR) # NOTE: it's safe to create a new dir - if not os.path.exists(path): - os.makedirs(path) + if not self.os_ops.path_exists(path): + self.os_ops.makedirs(path) return 
path @@ -365,9 +372,7 @@ def _create_recovery_conf(self, username, slot=None): # Since 12 recovery.conf had disappeared if self.version >= PgVer('12'): signal_name = os.path.join(self.data_dir, "standby.signal") - # cross-python touch(). It is vulnerable to races, but who cares? - with open(signal_name, 'a'): - os.utime(signal_name, None) + self.os_ops.touch(signal_name) else: line += "standby_mode=on\n" @@ -425,19 +430,14 @@ def _collect_special_files(self): for f, num_lines in files: # skip missing files - if not os.path.exists(f): + if not self.os_ops.path_exists(f): continue - with io.open(f, "rb") as _f: - if num_lines > 0: - # take last N lines of file - lines = b''.join(file_tail(_f, num_lines)).decode('utf-8') - else: - # read whole file - lines = _f.read().decode('utf-8') + file_lines = self.os_ops.readlines(f, num_lines, binary=True, encoding=None) + lines = b''.join(file_lines) - # fill list - result.append((f, lines)) + # fill list + result.append((f, lines)) return result @@ -456,9 +456,11 @@ def init(self, initdb_params=None, **kwargs): """ # initialize this PostgreSQL node - cached_initdb(data_dir=self.data_dir, - logfile=self.utils_log_file, - params=initdb_params) + cached_initdb( + data_dir=self.data_dir, + logfile=self.utils_log_file, + os_ops=self.os_ops, + params=initdb_params) # initialize default config files self.default_conf(**kwargs) @@ -489,43 +491,41 @@ def default_conf(self, hba_conf = os.path.join(self.data_dir, HBA_CONF_FILE) # filter lines in hba file - with io.open(hba_conf, "r+") as conf: - # get rid of comments and blank lines - lines = [ - s for s in conf.readlines() - if len(s.strip()) > 0 and not s.startswith('#') - ] - - # write filtered lines - conf.seek(0) - conf.truncate() - conf.writelines(lines) - - # replication-related settings - if allow_streaming: - # get auth method for host or local users - def get_auth_method(t): - return next((s.split()[-1] - for s in lines if s.startswith(t)), 'trust') - - # get auth methods - auth_local = get_auth_method('local') - auth_host = get_auth_method('host') - - new_lines = [ - u"local\treplication\tall\t\t\t{}\n".format(auth_local), - u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), - u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host) - ] # yapf: disable - - # write missing lines - for line in new_lines: - if line not in lines: - conf.write(line) + # get rid of comments and blank lines + hba_conf_file = self.os_ops.readlines(hba_conf) + lines = [ + s for s in hba_conf_file + if len(s.strip()) > 0 and not s.startswith('#') + ] + + # write filtered lines + self.os_ops.write(hba_conf, lines, truncate=True) + + # replication-related settings + if allow_streaming: + # get auth method for host or local users + def get_auth_method(t): + return next((s.split()[-1] + for s in lines if s.startswith(t)), 'trust') + + # get auth methods + auth_local = get_auth_method('local') + auth_host = get_auth_method('host') + subnet_base = ".".join(self.os_ops.host.split('.')[:-1] + ['0']) + + new_lines = [ + u"local\treplication\tall\t\t\t{}\n".format(auth_local), + u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), + u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host), + u"host\treplication\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host), + u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host) + ] # yapf: disable + + # write missing lines + self.os_ops.write(hba_conf, new_lines) # overwrite config file - with io.open(postgres_conf, "w") as conf: - conf.truncate() + 
self.os_ops.write(postgres_conf, '', truncate=True) self.append_conf(fsync=fsync, max_worker_processes=MAX_WORKER_PROCESSES, @@ -595,15 +595,17 @@ def append_conf(self, line='', filename=PG_CONF_FILE, **kwargs): value = 'on' if value else 'off' elif not str(value).replace('.', '', 1).isdigit(): value = "'{}'".format(value) - - # format a new config line - lines.append('{} = {}'.format(option, value)) + if value == '*': + lines.append("{} = '*'".format(option)) + else: + # format a new config line + lines.append('{} = {}'.format(option, value)) config_name = os.path.join(self.data_dir, filename) - with io.open(config_name, 'a') as conf: - for line in lines: - conf.write(text_type(line)) - conf.write(text_type('\n')) + conf_text = '' + for line in lines: + conf_text += text_type(line) + '\n' + self.os_ops.write(config_name, conf_text) return self @@ -621,7 +623,11 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - execute_utility(_params, self.utils_log_file) + status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True) + if 'does not exist' in err: + return NodeStatus.Uninitialized + elif 'no server running' in out: + return NodeStatus.Stopped return NodeStatus.Running except ExecUtilException as e: @@ -653,7 +659,7 @@ def get_control_data(self): return out_dict - def slow_start(self, replica=False, dbname='template1', username=default_username()): + def slow_start(self, replica=False, dbname='template1', username=default_username(), max_attempts=0): """ Starts the PostgreSQL instance and then polls the instance until it reaches the expected state (primary or replica). The state is checked @@ -664,6 +670,7 @@ def slow_start(self, replica=False, dbname='template1', username=default_usernam username: replica: If True, waits for the instance to be in recovery (i.e., replica mode). If False, waits for the instance to be in primary mode. Default is False. 
+ max_attempts: """ self.start() @@ -678,7 +685,8 @@ def slow_start(self, replica=False, dbname='template1', username=default_usernam suppress={InternalError, QueryException, ProgrammingError, - OperationalError}) + OperationalError}, + max_attempts=max_attempts) def start(self, params=[], wait=True): """ @@ -706,12 +714,13 @@ def start(self, params=[], wait=True): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) - except ExecUtilException as e: + exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + if 'does not exist' in error: + raise Exception + except Exception as e: msg = 'Cannot start node' files = self._collect_special_files() raise_from(StartNodeException(msg, files), e) - self._maybe_start_logger() self.is_started = True return self @@ -779,7 +788,9 @@ def restart(self, params=[]): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + if 'could not start server' in error: + raise ExecUtilException except ExecUtilException as e: msg = 'Cannot restart node' files = self._collect_special_files() @@ -895,7 +906,7 @@ def cleanup(self, max_attempts=3): else: rm_dir = self.data_dir # just data, save logs - rmtree(rm_dir, ignore_errors=True) + self.os_ops.rmdirs(rm_dir, ignore_errors=True) return self @@ -948,7 +959,10 @@ def psql(self, # select query source if query: - psql_params.extend(("-c", query)) + if self.os_ops.remote: + psql_params.extend(("-c", '"{}"'.format(query))) + else: + psql_params.extend(("-c", query)) elif filename: psql_params.extend(("-f", filename)) else: @@ -956,16 +970,20 @@ def psql(self, # should be the last one psql_params.append(dbname) + if not self.os_ops.remote: + # start psql process + process = subprocess.Popen(psql_params, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + # wait until it finishes and get stdout and stderr + out, err = process.communicate(input=input) + return process.returncode, out, err + else: + status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) - # start psql process - process = subprocess.Popen(psql_params, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - # wait until it finishes and get stdout and stderr - out, err = process.communicate(input=input) - return process.returncode, out, err + return status_code, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) def safe_psql(self, query=None, expect_error=False, **kwargs): @@ -989,15 +1007,19 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): # force this setting kwargs['ON_ERROR_STOP'] = 1 - - ret, out, err = self.psql(query=query, **kwargs) + try: + ret, out, err = self.psql(query=query, **kwargs) + except ExecUtilException as e: + ret = e.exit_code + out = e.out + err = e.message if ret: if expect_error: out = (err or b'').decode('utf-8') else: raise QueryException((err or b'').decode('utf-8'), query) elif expect_error: - assert False, f"Exception was expected, but query finished successfully: `{query}` " + assert False, "Exception was expected, but query finished successfully: `{}` ".format(query) return out @@ -1031,10 +1053,9 @@ def dump(self, # Generate tmpfile or tmpdir def tmpfile(): if format == DumpFormat.Directory: - fname = mkdtemp(prefix=TMP_DUMP) + fname = self.os_ops.mkdtemp(prefix=TMP_DUMP) else: - fd, fname = mkstemp(prefix=TMP_DUMP) - os.close(fd) + fname = 
self.os_ops.mkstemp(prefix=TMP_DUMP) return fname # Set default arguments @@ -1119,9 +1140,9 @@ def poll_query_until(self, # sanity checks assert max_attempts >= 0 assert sleep_time > 0 - attempts = 0 while max_attempts == 0 or attempts < max_attempts: + print(f"Pooling {attempts}") try: res = self.execute(dbname=dbname, query=query, @@ -1350,7 +1371,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = subprocess.Popen(_params, stdout=stdout, stderr=stderr) + proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True) return proc @@ -1523,18 +1544,16 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): # parse postgresql.auto.conf path = os.path.join(self.data_dir, config) - with open(path, 'r') as f: - raw_content = f.read() - + lines = self.os_ops.readlines(path) current_options = {} current_directives = [] - for line in raw_content.splitlines(): + for line in lines: # ignore comments if line.startswith('#'): continue - if line == '': + if line.strip() == '': continue if line.startswith('include'): @@ -1564,22 +1583,22 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): for directive in current_directives: auto_conf += directive + "\n" - with open(path, 'wt') as f: - f.write(auto_conf) + self.os_ops.write(path, auto_conf, truncate=True) class NodeApp: - def __init__(self, test_path, nodes_to_cleanup): + def __init__(self, test_path, nodes_to_cleanup, os_ops=LocalOperations()): self.test_path = test_path self.nodes_to_cleanup = nodes_to_cleanup + self.os_ops = os_ops def make_empty( self, base_dir=None): real_base_dir = os.path.join(self.test_path, base_dir) - shutil.rmtree(real_base_dir, ignore_errors=True) - os.makedirs(real_base_dir) + self.os_ops.rmdirs(real_base_dir, ignore_errors=True) + self.os_ops.makedirs(real_base_dir) node = PostgresNode(base_dir=real_base_dir) node.should_rm_dirs = True @@ -1602,27 +1621,24 @@ def make_simple( initdb_params=initdb_params, allow_streaming=set_replication) # set major version - with open(os.path.join(node.data_dir, 'PG_VERSION')) as f: - node.major_version_str = str(f.read().rstrip()) - node.major_version = float(node.major_version_str) - - # Sane default parameters - options = {} - options['max_connections'] = 100 - options['shared_buffers'] = '10MB' - options['fsync'] = 'off' - - options['wal_level'] = 'logical' - options['hot_standby'] = 'off' - - options['log_line_prefix'] = '%t [%p]: [%l-1] ' - options['log_statement'] = 'none' - options['log_duration'] = 'on' - options['log_min_duration_statement'] = 0 - options['log_connections'] = 'on' - options['log_disconnections'] = 'on' - options['restart_after_crash'] = 'off' - options['autovacuum'] = 'off' + pg_version_file = self.os_ops.read(os.path.join(node.data_dir, 'PG_VERSION')) + node.major_version_str = str(pg_version_file.rstrip()) + node.major_version = float(node.major_version_str) + + # Set default parameters + options = {'max_connections': 100, + 'shared_buffers': '10MB', + 'fsync': 'off', + 'wal_level': 'logical', + 'hot_standby': 'off', + 'log_line_prefix': '%t [%p]: [%l-1] ', + 'log_statement': 'none', + 'log_duration': 'on', + 'log_min_duration_statement': 0, + 'log_connections': 'on', + 'log_disconnections': 'on', + 'restart_after_crash': 'off', + 'autovacuum': 'off'} # Allow replication in pg_hba.conf if set_replication: diff --git a/testgres/operations/__init__.py b/testgres/operations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py new file mode 100644 index 00000000..89071282 --- /dev/null +++ b/testgres/operations/local_ops.py @@ -0,0 +1,269 @@ +import getpass +import os +import shutil +import stat +import subprocess +import tempfile + +import psutil + +from ..exceptions import ExecUtilException +from .os_ops import ConnectionParams, OsOperations +from .os_ops import pglib + +try: + from shutil import which as find_executable + from shutil import rmtree +except ImportError: + from distutils.spawn import find_executable + from distutils import rmtree + + +CMD_TIMEOUT_SEC = 60 +error_markers = [b'error', b'Permission denied', b'fatal'] + + +class LocalOperations(OsOperations): + def __init__(self, conn_params=None): + if conn_params is None: + conn_params = ConnectionParams() + super(LocalOperations, self).__init__(conn_params.username) + self.conn_params = conn_params + self.host = conn_params.host + self.ssh_key = None + self.remote = False + self.username = conn_params.username or self.get_user() + + # Command execution + def exec_command(self, cmd, wait_exit=False, verbose=False, + expect_error=False, encoding=None, shell=False, text=False, + input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): + """ + Execute a command in a subprocess. + + Args: + - cmd: The command to execute. + - wait_exit: Whether to wait for the subprocess to exit before returning. + - verbose: Whether to return verbose output. + - expect_error: Whether to raise an error if the subprocess exits with an error status. + - encoding: The encoding to use for decoding the subprocess output. + - shell: Whether to use shell when executing the subprocess. + - text: Whether to return str instead of bytes for the subprocess output. + - input: The input to pass to the subprocess. + - stdout: The stdout to use for the subprocess. + - stderr: The stderr to use for the subprocess. + - proc: The process to use for subprocess creation. + :return: The output of the subprocess. + """ + if os.name == 'nt': + with tempfile.NamedTemporaryFile() as buf: + process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT) + process.communicate() + buf.seek(0) + result = buf.read().decode(encoding) + return result + else: + process = subprocess.Popen( + cmd, + shell=shell, + stdout=stdout, + stderr=stderr, + ) + if proc: + return process + result, error = process.communicate(input) + exit_status = process.returncode + + error_found = exit_status != 0 or any(marker in error for marker in error_markers) + + if encoding: + result = result.decode(encoding) + error = error.decode(encoding) + + if expect_error: + raise Exception(result, error) + + if exit_status != 0 or error_found: + if exit_status == 0: + exit_status = 1 + raise ExecUtilException(message='Utility exited with non-zero code. 
Error `{}`'.format(error), + command=cmd, + exit_code=exit_status, + out=result) + if verbose: + return exit_status, result, error + else: + return result + + # Environment setup + def environ(self, var_name): + return os.environ.get(var_name) + + def find_executable(self, executable): + return find_executable(executable) + + def is_executable(self, file): + # Check if the file is executable + return os.stat(file).st_mode & stat.S_IXUSR + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + os.environ[var_name] = var_val + + # Get environment variables + def get_user(self): + return getpass.getuser() + + def get_name(self): + return os.name + + # Work with dirs + def makedirs(self, path, remove_existing=False): + if remove_existing: + shutil.rmtree(path, ignore_errors=True) + try: + os.makedirs(path) + except FileExistsError: + pass + + def rmdirs(self, path, ignore_errors=True): + return rmtree(path, ignore_errors=ignore_errors) + + def listdir(self, path): + return os.listdir(path) + + def path_exists(self, path): + return os.path.exists(path) + + @property + def pathsep(self): + os_name = self.get_name() + if os_name == "posix": + pathsep = ":" + elif os_name == "nt": + pathsep = ";" + else: + raise Exception("Unsupported operating system: {}".format(os_name)) + return pathsep + + def mkdtemp(self, prefix=None): + return tempfile.mkdtemp(prefix='{}'.format(prefix)) + + def mkstemp(self, prefix=None): + fd, filename = tempfile.mkstemp(prefix=prefix) + os.close(fd) # Close the file descriptor immediately after creating the file + return filename + + def copytree(self, src, dst): + return shutil.copytree(src, dst) + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + """ + Write data to a file locally + Args: + filename: The file path where the data will be written. + data: The data to be written to the file. + truncate: If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + binary: If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + read_and_write: If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + """ + # If it is a bytes str or list + if isinstance(data, bytes) or isinstance(data, list) and all(isinstance(item, bytes) for item in data): + binary = True + mode = "wb" if binary else "w" + if not truncate: + mode = "ab" if binary else "a" + if read_and_write: + mode = "r+b" if binary else "r+" + + with open(filename, mode) as file: + if isinstance(data, list): + file.writelines(data) + else: + file.write(data) + + def touch(self, filename): + """ + Create a new file or update the access and modification times of an existing file. + Args: + filename (str): The name of the file to touch. + + This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. + """ + # cross-python touch(). It is vulnerable to races, but who cares? + with open(filename, "a"): + os.utime(filename, None) + + def read(self, filename, encoding=None): + with open(filename, "r", encoding=encoding) as file: + return file.read() + + def readlines(self, filename, num_lines=0, binary=False, encoding=None): + """ + Read lines from a local file. 
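+        A usage sketch (assuming an instance ``ops = LocalOperations()``
+        and that the file already exists):
+            ops.write('/tmp/demo.txt', 'a\nb\nc\n', truncate=True)
+            ops.readlines('/tmp/demo.txt', num_lines=2)  # last two lines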
+ If num_lines is greater than 0, only the last num_lines lines will be read. + """ + assert num_lines >= 0 + mode = 'rb' if binary else 'r' + if num_lines == 0: + with open(filename, mode, encoding=encoding) as file: # open in binary mode + return file.readlines() + + else: + bufsize = 8192 + buffers = 1 + + with open(filename, mode, encoding=encoding) as file: # open in binary mode + file.seek(0, os.SEEK_END) + end_pos = file.tell() + + while True: + offset = max(0, end_pos - bufsize * buffers) + file.seek(offset, os.SEEK_SET) + pos = file.tell() + lines = file.readlines() + cur_lines = len(lines) + + if cur_lines >= num_lines or pos == 0: + return lines[-num_lines:] # get last num_lines from lines + + buffers = int( + buffers * max(2, int(num_lines / max(cur_lines, 1))) + ) # Adjust buffer size + + def isfile(self, remote_file): + return os.path.isfile(remote_file) + + def isdir(self, dirname): + return os.path.isdir(dirname) + + def remove_file(self, filename): + return os.remove(filename) + + # Processes control + def kill(self, pid, signal): + # Kill the process + cmd = "kill -{} {}".format(signal, pid) + return self.exec_command(cmd) + + def get_pid(self): + # Get current process id + return os.getpid() + + def get_process_children(self, pid): + return psutil.Process(pid).children() + + # Database control + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): + conn = pglib.connect( + host=host, + port=port, + database=dbname, + user=user, + password=password, + ) + return conn diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py new file mode 100644 index 00000000..9261cacf --- /dev/null +++ b/testgres/operations/os_ops.py @@ -0,0 +1,101 @@ +try: + import psycopg2 as pglib # noqa: F401 +except ImportError: + try: + import pg8000 as pglib # noqa: F401 + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") + + +class ConnectionParams: + def __init__(self, host='127.0.0.1', ssh_key=None, username=None): + self.host = host + self.ssh_key = ssh_key + self.username = username + + +class OsOperations: + def __init__(self, username=None): + self.ssh_key = None + self.username = username + + # Command execution + def exec_command(self, cmd, **kwargs): + raise NotImplementedError() + + # Environment setup + def environ(self, var_name): + raise NotImplementedError() + + def find_executable(self, executable): + raise NotImplementedError() + + def is_executable(self, file): + # Check if the file is executable + raise NotImplementedError() + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + raise NotImplementedError() + + # Get environment variables + def get_user(self): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + raise NotImplementedError() + + def rmdirs(self, path, ignore_errors=True): + raise NotImplementedError() + + def listdir(self, path): + raise NotImplementedError() + + def path_exists(self, path): + raise NotImplementedError() + + @property + def pathsep(self): + raise NotImplementedError() + + def mkdtemp(self, prefix=None): + raise NotImplementedError() + + def copytree(self, src, dst): + raise NotImplementedError() + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + raise NotImplementedError() + + def touch(self, filename): + raise NotImplementedError() + + def read(self, filename): + 
raise NotImplementedError() + + def readlines(self, filename): + raise NotImplementedError() + + def isfile(self, remote_file): + raise NotImplementedError() + + # Processes control + def kill(self, pid, signal): + # Kill the process + raise NotImplementedError() + + def get_pid(self): + # Get current process id + raise NotImplementedError() + + def get_process_children(self, pid): + raise NotImplementedError() + + # Database control + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): + raise NotImplementedError() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py new file mode 100644 index 00000000..6815c7f1 --- /dev/null +++ b/testgres/operations/remote_ops.py @@ -0,0 +1,448 @@ +import os +import tempfile +import time +from typing import Optional + +import sshtunnel + +import paramiko +from paramiko import SSHClient + +from ..exceptions import ExecUtilException + +from .os_ops import OsOperations, ConnectionParams +from .os_ops import pglib + +sshtunnel.SSH_TIMEOUT = 5.0 +sshtunnel.TUNNEL_TIMEOUT = 5.0 + + +error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] + + +class PsUtilProcessProxy: + def __init__(self, ssh, pid): + self.ssh = ssh + self.pid = pid + + def kill(self): + command = "kill {}".format(self.pid) + self.ssh.exec_command(command) + + def cmdline(self): + command = "ps -p {} -o cmd --no-headers".format(self.pid) + stdin, stdout, stderr = self.ssh.exec_command(command) + cmdline = stdout.read().decode('utf-8').strip() + return cmdline.split() + + +class RemoteOperations(OsOperations): + def __init__(self, conn_params: ConnectionParams): + super().__init__(conn_params.username) + self.conn_params = conn_params + self.host = conn_params.host + self.ssh_key = conn_params.ssh_key + self.ssh = self.ssh_connect() + self.remote = True + self.username = conn_params.username or self.get_user() + self.tunnel = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close_tunnel() + if getattr(self, 'ssh', None): + self.ssh.close() + + def __del__(self): + if getattr(self, 'ssh', None): + self.ssh.close() + + def close_tunnel(self): + if getattr(self, 'tunnel', None): + self.tunnel.stop(force=True) + start_time = time.time() + while self.tunnel.is_active: + if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT: + break + time.sleep(0.5) + + def ssh_connect(self) -> Optional[SSHClient]: + key = self._read_ssh_key() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + return ssh + + def _read_ssh_key(self): + try: + with open(self.ssh_key, "r") as f: + key_data = f.read() + if "BEGIN OPENSSH PRIVATE KEY" in key_data: + key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) + else: + key = paramiko.RSAKey.from_private_key_file(self.ssh_key) + return key + except FileNotFoundError: + raise ExecUtilException(message="No such file or directory: '{}'".format(self.ssh_key)) + except Exception as e: + ExecUtilException(message="An error occurred while reading the ssh key: {}".format(e)) + + def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, + encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, + stderr=None, proc=None): + """ + Execute a command in the SSH session. + Args: + - cmd (str): The command to be executed. 
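+        A usage sketch (assuming a valid ``ConnectionParams`` instance):
+            ops = RemoteOperations(conn_params)
+            out = ops.exec_command('ls /tmp', encoding='utf-8')
+            status, out, err = ops.exec_command('ls /tmp', verbose=True)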
+ """ + if self.ssh is None or not self.ssh.get_transport() or not self.ssh.get_transport().is_active(): + self.ssh = self.ssh_connect() + + if isinstance(cmd, list): + cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) + if input: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + stdin.write(input) + stdin.flush() + else: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + + if encoding: + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + else: + result = stdout.read() + error = stderr.read() + + if expect_error: + raise Exception(result, error) + + if encoding: + error_found = exit_status != 0 or any( + marker.decode(encoding) in error for marker in error_markers) + else: + error_found = exit_status != 0 or any( + marker in error for marker in error_markers) + + if error_found: + if exit_status == 0: + exit_status = 1 + if encoding: + message = "Utility exited with non-zero code. Error: {}".format(error.decode(encoding)) + else: + message = b"Utility exited with non-zero code. Error: " + error + raise ExecUtilException(message=message, + command=cmd, + exit_code=exit_status, + out=result) + + if verbose: + return exit_status, result, error + else: + return result + + # Environment setup + def environ(self, var_name: str) -> str: + """ + Get the value of an environment variable. + Args: + - var_name (str): The name of the environment variable. + """ + cmd = "echo ${}".format(var_name) + return self.exec_command(cmd, encoding='utf-8').strip() + + def find_executable(self, executable): + search_paths = self.environ("PATH") + if not search_paths: + return None + + search_paths = search_paths.split(self.pathsep) + for path in search_paths: + remote_file = os.path.join(path, executable) + if self.isfile(remote_file): + return remote_file + + return None + + def is_executable(self, file): + # Check if the file is executable + is_exec = self.exec_command("test -x {} && echo OK".format(file)) + return is_exec == b"OK\n" + + def set_env(self, var_name: str, var_val: str): + """ + Set the value of an environment variable. + Args: + - var_name (str): The name of the environment variable. + - var_val (str): The value to be set for the environment variable. + """ + return self.exec_command("export {}={}".format(var_name, var_val)) + + # Get environment variables + def get_user(self): + return self.exec_command("echo $USER", encoding='utf-8').strip() + + def get_name(self): + cmd = 'python3 -c "import os; print(os.name)"' + return self.exec_command(cmd, encoding='utf-8').strip() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + """ + Create a directory in the remote server. + Args: + - path (str): The path to the directory to be created. + - remove_existing (bool): If True, the existing directory at the path will be removed. + """ + if remove_existing: + cmd = "rm -rf {} && mkdir -p {}".format(path, path) + else: + cmd = "mkdir -p {}".format(path) + try: + exit_status, result, error = self.exec_command(cmd, verbose=True) + except ExecUtilException as e: + raise Exception("Couldn't create dir {} because of error {}".format(path, e.message)) + if exit_status != 0: + raise Exception("Couldn't create dir {} because of error {}".format(path, error)) + return result + + def rmdirs(self, path, verbose=False, ignore_errors=True): + """ + Remove a directory in the remote server. 
+ Args: + - path (str): The path to the directory to be removed. + - verbose (bool): If True, return exit status, result, and error. + - ignore_errors (bool): If True, do not raise error if directory does not exist. + """ + cmd = "rm -rf {}".format(path) + exit_status, result, error = self.exec_command(cmd, verbose=True) + if verbose: + return exit_status, result, error + else: + return result + + def listdir(self, path): + """ + List all files and directories in a directory. + Args: + path (str): The path to the directory. + """ + result = self.exec_command("ls {}".format(path)) + return result.splitlines() + + def path_exists(self, path): + result = self.exec_command("test -e {}; echo $?".format(path), encoding='utf-8') + return int(result.strip()) == 0 + + @property + def pathsep(self): + os_name = self.get_name() + if os_name == "posix": + pathsep = ":" + elif os_name == "nt": + pathsep = ";" + else: + raise Exception("Unsupported operating system: {}".format(os_name)) + return pathsep + + def mkdtemp(self, prefix=None): + """ + Creates a temporary directory in the remote server. + Args: + - prefix (str): The prefix of the temporary directory name. + """ + if prefix: + temp_dir = self.exec_command("mktemp -d {}XXXXX".format(prefix), encoding='utf-8') + else: + temp_dir = self.exec_command("mktemp -d", encoding='utf-8') + + if temp_dir: + if not os.path.isabs(temp_dir): + temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + return temp_dir + else: + raise ExecUtilException("Could not create temporary directory.") + + def mkstemp(self, prefix=None): + if prefix: + temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding='utf-8') + else: + temp_dir = self.exec_command("mktemp", encoding='utf-8') + + if temp_dir: + if not os.path.isabs(temp_dir): + temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + return temp_dir + else: + raise ExecUtilException("Could not create temporary directory.") + + def copytree(self, src, dst): + if not os.path.isabs(dst): + dst = os.path.join('~', dst) + if self.isdir(dst): + raise FileExistsError("Directory {} already exists.".format(dst)) + return self.exec_command("cp -r {} {}".format(src, dst)) + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'): + """ + Write data to a file on a remote host + + Args: + - filename (str): The file path where the data will be written. + - data (bytes or str): The data to be written to the file. + - truncate (bool): If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + - binary (bool): If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + - read_and_write (bool): If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). 
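+        A usage sketch (assuming a ``RemoteOperations`` instance ``ops``):
+            ops.write('/tmp/pg.conf', 'fsync = off\n')  # append text
+            ops.write('/tmp/blob', b'\x00\x01', binary=True, truncate=True)  # overwrite binary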
+ """ + mode = "wb" if binary else "w" + if not truncate: + mode = "ab" if binary else "a" + if read_and_write: + mode = "r+b" if binary else "r+" + + with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: + if not truncate: + with self.ssh_connect() as ssh: + sftp = ssh.open_sftp() + try: + sftp.get(filename, tmp_file.name) + tmp_file.seek(0, os.SEEK_END) + except FileNotFoundError: + pass # File does not exist yet, we'll create it + sftp.close() + if isinstance(data, bytes) and not binary: + data = data.decode(encoding) + elif isinstance(data, str) and binary: + data = data.encode(encoding) + if isinstance(data, list): + # ensure each line ends with a newline + data = [(s if isinstance(s, str) else s.decode('utf-8')).rstrip('\n') + '\n' for s in data] + tmp_file.writelines(data) + else: + tmp_file.write(data) + tmp_file.flush() + + with self.ssh_connect() as ssh: + sftp = ssh.open_sftp() + remote_directory = os.path.dirname(filename) + try: + sftp.stat(remote_directory) + except IOError: + sftp.mkdir(remote_directory) + sftp.put(tmp_file.name, filename) + sftp.close() + + os.remove(tmp_file.name) + + def touch(self, filename): + """ + Create a new file or update the access and modification times of an existing file on the remote server. + + Args: + filename (str): The name of the file to touch. + + This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. + """ + self.exec_command("touch {}".format(filename)) + + def read(self, filename, binary=False, encoding=None): + cmd = "cat {}".format(filename) + result = self.exec_command(cmd, encoding=encoding) + + if not binary and result: + result = result.decode(encoding or 'utf-8') + + return result + + def readlines(self, filename, num_lines=0, binary=False, encoding=None): + if num_lines > 0: + cmd = "tail -n {} {}".format(num_lines, filename) + else: + cmd = "cat {}".format(filename) + + result = self.exec_command(cmd, encoding=encoding) + + if not binary and result: + lines = result.decode(encoding or 'utf-8').splitlines() + else: + lines = result.splitlines() + + return lines + + def isfile(self, remote_file): + stdout = self.exec_command("test -f {}; echo $?".format(remote_file)) + result = int(stdout.strip()) + return result == 0 + + def isdir(self, dirname): + cmd = "if [ -d {} ]; then echo True; else echo False; fi".format(dirname) + response = self.exec_command(cmd) + return response.strip() == b"True" + + def remove_file(self, filename): + cmd = "rm {}".format(filename) + return self.exec_command(cmd) + + # Processes control + def kill(self, pid, signal): + # Kill the process + cmd = "kill -{} {}".format(signal, pid) + return self.exec_command(cmd) + + def get_pid(self): + # Get current process id + return int(self.exec_command("echo $$", encoding='utf-8')) + + def get_process_children(self, pid): + command = "pgrep -P {}".format(pid) + stdin, stdout, stderr = self.ssh.exec_command(command) + children = stdout.readlines() + return [PsUtilProcessProxy(self.ssh, int(child_pid.strip())) for child_pid in children] + + # Database control + def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None): + """ + Connects to a PostgreSQL database on the remote system. + Args: + - dbname (str): The name of the database to connect to. + - user (str): The username for the database connection. + - password (str, optional): The password for the database connection. Defaults to None. + - host (str, optional): The IP address of the remote system. 
Defaults to "localhost". + - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432. + + This function establishes a connection to a PostgreSQL database on the remote system using the specified + parameters. It returns a connection object that can be used to interact with the database. + """ + self.close_tunnel() + self.tunnel = sshtunnel.open_tunnel( + (host, 22), # Remote server IP and SSH port + ssh_username=user or self.username, + ssh_pkey=ssh_key or self.ssh_key, + remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port + local_bind_address=('localhost', port) # Local machine IP and available port + ) + + self.tunnel.start() + + try: + conn = pglib.connect( + host=host, # change to 'localhost' because we're connecting through a local ssh tunnel + port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel + database=dbname, + user=user or self.username, + password=password + ) + + return conn + except Exception as e: + self.tunnel.stop() + raise ExecUtilException("Could not create db tunnel. {}".format(e)) diff --git a/testgres/pubsub.py b/testgres/pubsub.py index da85caac..1be673bb 100644 --- a/testgres/pubsub.py +++ b/testgres/pubsub.py @@ -214,4 +214,4 @@ def catchup(self, username=None): username=username or self.pub.username, max_attempts=LOGICAL_REPL_MAX_CATCHUP_ATTEMPTS) except Exception as e: - raise_from(CatchUpException("Failed to catch up", query), e) + raise_from(CatchUpException("Failed to catch up"), e) diff --git a/testgres/utils.py b/testgres/utils.py index 9760908d..5e12eba9 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -3,24 +3,18 @@ from __future__ import division from __future__ import print_function -import io import os import port_for -import subprocess import sys -import tempfile from contextlib import contextmanager from packaging.version import Version, InvalidVersion import re -try: - from shutil import which as find_executable -except ImportError: - from distutils.spawn import find_executable + from six import iteritems -from .config import testgres_config from .exceptions import ExecUtilException +from .config import testgres_config as tconf # rows returned by PG_CONFIG _pg_config_data = {} @@ -58,7 +52,7 @@ def release_port(port): bound_ports.discard(port) -def execute_utility(args, logfile=None): +def execute_utility(args, logfile=None, verbose=False): """ Execute utility (pg_ctl, pg_dump etc). @@ -69,63 +63,28 @@ def execute_utility(args, logfile=None): Returns: stdout of executed utility. 
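+        With verbose=True, a (exit_status, out, error) tuple is returned instead.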
""" - - # run utility - if os.name == 'nt': - # using output to a temporary file in Windows - buf = tempfile.NamedTemporaryFile() - - process = subprocess.Popen( - args, # util + params - stdout=buf, - stderr=subprocess.STDOUT) - process.communicate() - - # get result - buf.file.flush() - buf.file.seek(0) - out = buf.file.read() - buf.close() - else: - process = subprocess.Popen( - args, # util + params - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - - # get result - out, _ = process.communicate() - + exit_status, out, error = tconf.os_ops.exec_command(args, verbose=True) # decode result - out = '' if not out else out.decode('utf-8') - - # format command - command = u' '.join(args) + out = '' if not out else out + if isinstance(out, bytes): + out = out.decode('utf-8') + if isinstance(error, bytes): + error = error.decode('utf-8') # write new log entry if possible if logfile: try: - with io.open(logfile, 'a') as file_out: - file_out.write(command) - - if out: - # comment-out lines - lines = ('# ' + line for line in out.splitlines(True)) - file_out.write(u'\n') - file_out.writelines(lines) - - file_out.write(u'\n') + tconf.os_ops.write(filename=logfile, data=args, truncate=True) + if out: + # comment-out lines + lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] + tconf.os_ops.write(filename=logfile, data=lines) except IOError: - pass - - exit_code = process.returncode - if exit_code: - message = 'Utility exited with non-zero code' - raise ExecUtilException(message=message, - command=command, - exit_code=exit_code, - out=out) - - return out + raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args)) + if verbose: + return exit_status, out, error + else: + return out def get_bin_path(filename): @@ -133,23 +92,25 @@ def get_bin_path(filename): Return absolute path to an executable using PG_BIN or PG_CONFIG. This function does nothing if 'filename' is already absolute. """ - # check if it's already absolute if os.path.isabs(filename): return filename + if tconf.os_ops.remote: + pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + else: + # try PG_CONFIG - get from local machine + pg_config = os.environ.get("PG_CONFIG") - # try PG_CONFIG - pg_config = os.environ.get("PG_CONFIG") if pg_config: bindir = get_pg_config()["BINDIR"] return os.path.join(bindir, filename) # try PG_BIN - pg_bin = os.environ.get("PG_BIN") + pg_bin = tconf.os_ops.environ("PG_BIN") if pg_bin: return os.path.join(pg_bin, filename) - pg_config_path = find_executable('pg_config') + pg_config_path = tconf.os_ops.find_executable('pg_config') if pg_config_path: bindir = get_pg_config(pg_config_path)["BINDIR"] return os.path.join(bindir, filename) @@ -160,11 +121,12 @@ def get_bin_path(filename): def get_pg_config(pg_config_path=None): """ Return output of pg_config (provided that it is installed). - NOTE: this fuction caches the result by default (see GlobalConfig). + NOTE: this function caches the result by default (see GlobalConfig). 
""" + def cache_pg_config_data(cmd): # execute pg_config and get the output - out = subprocess.check_output([cmd]).decode('utf-8') + out = tconf.os_ops.exec_command(cmd, encoding='utf-8') data = {} for line in out.splitlines(): @@ -179,7 +141,7 @@ def cache_pg_config_data(cmd): return data # drop cache if asked to - if not testgres_config.cache_pg_config: + if not tconf.cache_pg_config: global _pg_config_data _pg_config_data = {} @@ -188,7 +150,11 @@ def cache_pg_config_data(cmd): return _pg_config_data # try specified pg_config path or PG_CONFIG - pg_config = pg_config_path or os.environ.get("PG_CONFIG") + if tconf.os_ops.remote: + pg_config = pg_config_path or os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + else: + # try PG_CONFIG - get from local machine + pg_config = pg_config_path or os.environ.get("PG_CONFIG") if pg_config: return cache_pg_config_data(pg_config) @@ -209,7 +175,7 @@ def get_pg_version(): # get raw version (e.g. postgres (PostgreSQL) 9.5.7) _params = [get_bin_path('postgres'), '--version'] - raw_ver = subprocess.check_output(_params).decode('utf-8') + raw_ver = tconf.os_ops.exec_command(_params, encoding='utf-8') # cook version of PostgreSQL version = raw_ver.strip().split(' ')[-1] \ diff --git a/tests/README.md b/tests/README.md index a6d50992..d89efc7e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -27,3 +27,32 @@ export PYTHON_VERSION=3 # or 2 # Run tests ./run_tests.sh ``` + + +#### Remote host tests + +1. Start remote host or docker container +2. Make sure that you run ssh +```commandline +sudo apt-get install openssh-server +sudo systemctl start sshd +``` +3. You need to connect to the remote host at least once to add it to the known hosts file +4. Generate ssh keys +5. Set up params for tests + + +```commandline +conn_params = ConnectionParams( + host='remote_host', + username='username', + ssh_key=/path/to/your/ssh/key' +) +os_ops = RemoteOperations(conn_params) +``` +If you have different path to `PG_CONFIG` on your local and remote host you can set up `PG_CONFIG_REMOTE`, this value will be +using during work with remote host. + +`test_remote` - Tests for RemoteOperations class. + +`test_simple_remote` - Tests that create node and check it. The same as `test_simple`, but for remote node. \ No newline at end of file diff --git a/tests/test_remote.py b/tests/test_remote.py new file mode 100755 index 00000000..3794349c --- /dev/null +++ b/tests/test_remote.py @@ -0,0 +1,198 @@ +import os + +import pytest + +from testgres import ExecUtilException +from testgres import RemoteOperations +from testgres import ConnectionParams + + +class TestRemoteOperations: + + @pytest.fixture(scope="function", autouse=True) + def setup(self): + conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', + username='dev', + ssh_key=os.getenv( + 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') + self.operations = RemoteOperations(conn_params) + + yield + self.operations.__del__() + + def test_exec_command_success(self): + """ + Test exec_command for successful command execution. + """ + cmd = "python3 --version" + response = self.operations.exec_command(cmd, wait_exit=True) + + assert b'Python 3.' in response + + def test_exec_command_failure(self): + """ + Test exec_command for command execution failure. 
+ """ + cmd = "nonexistent_command" + try: + exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) + except ExecUtilException as e: + error = e.message + assert error == b'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' + + def test_is_executable_true(self): + """ + Test is_executable for an existing executable. + """ + cmd = "postgres" + response = self.operations.is_executable(cmd) + + assert response is True + + def test_is_executable_false(self): + """ + Test is_executable for a non-executable. + """ + cmd = "python" + response = self.operations.is_executable(cmd) + + assert response is False + + def test_makedirs_and_rmdirs_success(self): + """ + Test makedirs and rmdirs for successful directory creation and removal. + """ + cmd = "pwd" + pwd = self.operations.exec_command(cmd, wait_exit=True, encoding='utf-8').strip() + + path = "{}/test_dir".format(pwd) + + # Test makedirs + self.operations.makedirs(path) + assert self.operations.path_exists(path) + + # Test rmdirs + self.operations.rmdirs(path) + assert not self.operations.path_exists(path) + + def test_makedirs_and_rmdirs_failure(self): + """ + Test makedirs and rmdirs for directory creation and removal failure. + """ + # Try to create a directory in a read-only location + path = "/root/test_dir" + + # Test makedirs + with pytest.raises(Exception): + self.operations.makedirs(path) + + # Test rmdirs + try: + exit_status, result, error = self.operations.rmdirs(path, verbose=True) + except ExecUtilException as e: + error = e.message + assert error == b"Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" + + def test_listdir(self): + """ + Test listdir for listing directory contents. + """ + path = "/etc" + files = self.operations.listdir(path) + + assert isinstance(files, list) + + def test_path_exists_true(self): + """ + Test path_exists for an existing path. + """ + path = "/etc" + response = self.operations.path_exists(path) + + assert response is True + + def test_path_exists_false(self): + """ + Test path_exists for a non-existing path. + """ + path = "/nonexistent_path" + response = self.operations.path_exists(path) + + assert response is False + + def test_write_text_file(self): + """ + Test write for writing data to a text file. + """ + filename = "/tmp/test_file.txt" + data = "Hello, world!" + + self.operations.write(filename, data, truncate=True) + self.operations.write(filename, data) + + response = self.operations.read(filename) + + assert response == data + data + + def test_write_binary_file(self): + """ + Test write for writing data to a binary file. + """ + filename = "/tmp/test_file.bin" + data = b"\x00\x01\x02\x03" + + self.operations.write(filename, data, binary=True, truncate=True) + + response = self.operations.read(filename, binary=True) + + assert response == data + + def test_read_text_file(self): + """ + Test read for reading data from a text file. + """ + filename = "/etc/hosts" + + response = self.operations.read(filename) + + assert isinstance(response, str) + + def test_read_binary_file(self): + """ + Test read for reading data from a binary file. + """ + filename = "/usr/bin/python3" + + response = self.operations.read(filename, binary=True) + + assert isinstance(response, bytes) + + def test_touch(self): + """ + Test touch for creating a new file or updating access and modification times of an existing file. 
+ """ + filename = "/tmp/test_file.txt" + + self.operations.touch(filename) + + assert self.operations.isfile(filename) + + def test_isfile_true(self): + """ + Test isfile for an existing file. + """ + filename = "/etc/hosts" + + response = self.operations.isfile(filename) + + assert response is True + + def test_isfile_false(self): + """ + Test isfile for a non-existing file. + """ + filename = "/nonexistent_file.txt" + + response = self.operations.isfile(filename) + + assert response is False diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py new file mode 100755 index 00000000..e8386383 --- /dev/null +++ b/tests/test_simple_remote.py @@ -0,0 +1,996 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import re +import subprocess +import tempfile + +import testgres +import time +import six +import unittest +import psutil + +import logging.config + +from contextlib import contextmanager + +from testgres.exceptions import \ + InitNodeException, \ + StartNodeException, \ + ExecUtilException, \ + BackupException, \ + QueryException, \ + TimeoutException, \ + TestgresException + +from testgres.config import \ + TestgresConfig, \ + configure_testgres, \ + scoped_config, \ + pop_config, testgres_config + +from testgres import \ + NodeStatus, \ + ProcessType, \ + IsolationLevel, \ + get_remote_node, \ + RemoteOperations + +from testgres import \ + get_bin_path, \ + get_pg_config, \ + get_pg_version + +from testgres import \ + First, \ + Any + +# NOTE: those are ugly imports +from testgres import bound_ports +from testgres.utils import PgVer +from testgres.node import ProcessProxy, ConnectionParams + +conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', + username='dev', + ssh_key=os.getenv( + 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') +os_ops = RemoteOperations(conn_params) +testgres_config.set_os_ops(os_ops=os_ops) + + +def pg_version_ge(version): + cur_ver = PgVer(get_pg_version()) + min_ver = PgVer(version) + return cur_ver >= min_ver + + +def util_exists(util): + def good_properties(f): + return (os_ops.path_exists(f) and # noqa: W504 + os_ops.isfile(f) and # noqa: W504 + os_ops.is_executable(f)) # yapf: disable + + # try to resolve it + if good_properties(get_bin_path(util)): + return True + + # check if util is in PATH + for path in os_ops.environ("PATH").split(os_ops.pathsep): + if good_properties(os.path.join(path, util)): + return True + + +@contextmanager +def removing(f): + try: + yield f + finally: + if os_ops.isfile(f): + os_ops.remove_file(f) + + elif os_ops.isdir(f): + os_ops.rmdirs(f, ignore_errors=True) + + +class TestgresRemoteTests(unittest.TestCase): + + def test_node_repr(self): + with get_remote_node(conn_params=conn_params) as node: + pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" + self.assertIsNotNone(re.match(pattern, str(node))) + + def test_custom_init(self): + with get_remote_node(conn_params=conn_params) as node: + # enable page checksums + node.init(initdb_params=['-k']).start() + + with get_remote_node(conn_params=conn_params) as node: + node.init( + allow_streaming=True, + initdb_params=['--auth-local=reject', '--auth-host=reject']) + + hba_file = os.path.join(node.data_dir, 'pg_hba.conf') + lines = os_ops.readlines(hba_file) + + # check number of lines + self.assertGreaterEqual(len(lines), 6) + + # there should be no trust entries at all + self.assertFalse(any('trust' in s for s in lines)) + + def test_double_init(self): + with 
get_remote_node(conn_params=conn_params).init() as node: + # can't initialize node more than once + with self.assertRaises(InitNodeException): + node.init() + + def test_init_after_cleanup(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start().execute('select 1') + node.cleanup() + node.init().start().execute('select 1') + + @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') + @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') + def test_init_unique_system_id(self): + # FAIL + # this function exists in PostgreSQL 9.6+ + query = 'select system_identifier from pg_control_system()' + + with scoped_config(cache_initdb=False): + with get_remote_node(conn_params=conn_params).init().start() as node0: + id0 = node0.execute(query)[0] + + with scoped_config(cache_initdb=True, + cached_initdb_unique=True) as config: + self.assertTrue(config.cache_initdb) + self.assertTrue(config.cached_initdb_unique) + + # spawn two nodes; ids must be different + with get_remote_node(conn_params=conn_params).init().start() as node1, \ + get_remote_node(conn_params=conn_params).init().start() as node2: + id1 = node1.execute(query)[0] + id2 = node2.execute(query)[0] + + # ids must increase + self.assertGreater(id1, id0) + self.assertGreater(id2, id1) + + def test_node_exit(self): + with self.assertRaises(QueryException): + with get_remote_node(conn_params=conn_params).init() as node: + base_dir = node.base_dir + node.safe_psql('select 1') + + # we should save the DB for "debugging" + self.assertTrue(os_ops.path_exists(base_dir)) + os_ops.rmdirs(base_dir, ignore_errors=True) + + with get_remote_node(conn_params=conn_params).init() as node: + base_dir = node.base_dir + + # should have been removed by default + self.assertFalse(os_ops.path_exists(base_dir)) + + def test_double_start(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + # can't start node more than once + node.start() + self.assertTrue(node.is_started) + + def test_uninitialized_start(self): + with get_remote_node(conn_params=conn_params) as node: + # node is not initialized yet + with self.assertRaises(StartNodeException): + node.start() + + def test_restart(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + + # restart, ok + res = node.execute('select 1') + self.assertEqual(res, [(1,)]) + node.restart() + res = node.execute('select 2') + self.assertEqual(res, [(2,)]) + + # restart, fail + with self.assertRaises(StartNodeException): + node.append_conf('pg_hba.conf', 'DUMMY') + node.restart() + + def test_reload(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + + # change client_min_messages and save old value + cmm_old = node.execute('show client_min_messages') + node.append_conf(client_min_messages='DEBUG1') + + # reload config + node.reload() + + # check new value + cmm_new = node.execute('show client_min_messages') + self.assertEqual('debug1', cmm_new[0][0].lower()) + self.assertNotEqual(cmm_old, cmm_new) + + def test_pg_ctl(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + + status = node.pg_ctl(['status']) + self.assertTrue('PID' in status) + + def test_status(self): + self.assertTrue(NodeStatus.Running) + self.assertFalse(NodeStatus.Stopped) + self.assertFalse(NodeStatus.Uninitialized) + + # check statuses after each operation + with get_remote_node(conn_params=conn_params) as node: + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), 
NodeStatus.Uninitialized) + + node.init() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Stopped) + + node.start() + + self.assertNotEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Running) + + node.stop() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Stopped) + + node.cleanup() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Uninitialized) + + def test_psql(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + # check returned values (1 arg) + res = node.psql('select 1') + self.assertEqual(res, (0, b'1\n', b'')) + + # check returned values (2 args) + res = node.psql('postgres', 'select 2') + self.assertEqual(res, (0, b'2\n', b'')) + + # check returned values (named) + res = node.psql(query='select 3', dbname='postgres') + self.assertEqual(res, (0, b'3\n', b'')) + + # check returned values (1 arg) + res = node.safe_psql('select 4') + self.assertEqual(res, b'4\n') + + # check returned values (2 args) + res = node.safe_psql('postgres', 'select 5') + self.assertEqual(res, b'5\n') + + # check returned values (named) + res = node.safe_psql(query='select 6', dbname='postgres') + self.assertEqual(res, b'6\n') + + # check feeding input + node.safe_psql('create table horns (w int)') + node.safe_psql('copy horns from stdin (format csv)', + input=b"1\n2\n3\n\\.\n") + _sum = node.safe_psql('select sum(w) from horns') + self.assertEqual(_sum, b'6\n') + + # check psql's default args, fails + with self.assertRaises(QueryException): + node.psql() + + node.stop() + + # check psql on stopped node, fails + with self.assertRaises(QueryException): + node.safe_psql('select 1') + + def test_transactions(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + with node.connect() as con: + con.begin() + con.execute('create table test(val int)') + con.execute('insert into test values (1)') + con.commit() + + con.begin() + con.execute('insert into test values (2)') + res = con.execute('select * from test order by val asc') + self.assertListEqual(res, [(1,), (2,)]) + con.rollback() + + con.begin() + res = con.execute('select * from test') + self.assertListEqual(res, [(1,)]) + con.rollback() + + con.begin() + con.execute('drop table test') + con.commit() + + def test_control_data(self): + with get_remote_node(conn_params=conn_params) as node: + # node is not initialized yet + with self.assertRaises(ExecUtilException): + node.get_control_data() + + node.init() + data = node.get_control_data() + + # check returned dict + self.assertIsNotNone(data) + self.assertTrue(any('pg_control' in s for s in data.keys())) + + def test_backup_simple(self): + with get_remote_node(conn_params=conn_params) as master: + # enable streaming for backups + master.init(allow_streaming=True) + + # node must be running + with self.assertRaises(BackupException): + master.backup() + + # it's time to start node + master.start() + + # fill node with some data + master.psql('create table test as select generate_series(1, 4) i') + + with master.backup(xlog_method='stream') as backup: + with backup.spawn_primary().start() as slave: + res = slave.execute('select * from test order by i asc') + self.assertListEqual(res, [(1,), (2,), (3,), (4,)]) + + def test_backup_multiple(self): + with get_remote_node(conn_params=conn_params) as node: + node.init(allow_streaming=True).start() + + with node.backup(xlog_method='fetch') as backup1, \ + node.backup(xlog_method='fetch') as backup2: + 
self.assertNotEqual(backup1.base_dir, backup2.base_dir)
+
+            with node.backup(xlog_method='fetch') as backup:
+                with backup.spawn_primary('node1', destroy=False) as node1, \
+                        backup.spawn_primary('node2', destroy=False) as node2:
+                    self.assertNotEqual(node1.base_dir, node2.base_dir)
+
+    def test_backup_exhaust(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            with node.backup(xlog_method='fetch') as backup:
+                # exhaust backup by creating new node
+                with backup.spawn_primary():
+                    pass
+
+                # now let's try to create one more node
+                with self.assertRaises(BackupException):
+                    backup.spawn_primary()
+
+    def test_backup_wrong_xlog_method(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            with self.assertRaises(BackupException,
+                                   msg='Invalid xlog_method "wrong"'):
+                node.backup(xlog_method='wrong')
+
+    def test_pg_ctl_wait_option(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init().start(wait=False)
+            while True:
+                try:
+                    node.stop(wait=False)
+                    break
+                except ExecUtilException:
+                    # it's OK to get this exception here since the node
+                    # may not have started yet
+                    pass
+
+    def test_replicate(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            with node.replicate().start() as replica:
+                res = replica.execute('select 1')
+                self.assertListEqual(res, [(1,)])
+
+                node.execute('create table test (val int)', commit=True)
+
+                replica.catchup()
+
+                res = node.execute('select * from test')
+                self.assertListEqual(res, [])
+
+    @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+')
+    def test_synchronous_replication(self):
+        with get_remote_node(conn_params=conn_params) as master:
+            old_version = not pg_version_ge('9.6')
+
+            master.init(allow_streaming=True).start()
+
+            if not old_version:
+                master.append_conf('synchronous_commit = remote_apply')
+
+            # create standby
+            with master.replicate() as standby1, master.replicate() as standby2:
+                standby1.start()
+                standby2.start()
+
+                # check formatting
+                self.assertEqual(
+                    '1 ("{}", "{}")'.format(standby1.name, standby2.name),
+                    str(First(1, (standby1, standby2))))  # yapf: disable
+                self.assertEqual(
+                    'ANY 1 ("{}", "{}")'.format(standby1.name, standby2.name),
+                    str(Any(1, (standby1, standby2))))  # yapf: disable
+
+                # set synchronous_standby_names
+                master.set_synchronous_standbys(First(2, [standby1, standby2]))
+                master.restart()
+
+                # the following part of the test is only applicable to newer
+                # versions of PostgreSQL
+                if not old_version:
+                    master.safe_psql('create table abc(a int)')
+
+                    # Create a large transaction that will take some time to apply
+                    # on standby to check that it applies synchronously
+                    # (if synchronous_commit is set to 'on' or some other lower level,
+                    # the standby most likely won't catch up that fast and the test will fail)
+                    master.safe_psql(
+                        'insert into abc select generate_series(1, 1000000)')
+                    res = standby1.safe_psql('select count(*) from abc')
+                    self.assertEqual(res, b'1000000\n')
+
+    @unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
+    def test_logical_replication(self):
+        with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2:
+            node1.init(allow_logical=True)
+            node1.start()
+            node2.init().start()
+
+            create_table = 'create table test (a int, b int)'
+            node1.safe_psql(create_table)
+            node2.safe_psql(create_table)
+
+            # create publication / create subscription
+            pub = 
node1.publish('mypub')
+            sub = node2.subscribe(pub, 'mysub')
+
+            node1.safe_psql('insert into test values (1, 1), (2, 2)')
+
+            # wait until changes apply on subscriber and check them
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2)])
+
+            # disable and put some new data
+            sub.disable()
+            node1.safe_psql('insert into test values (3, 3)')
+
+            # enable and ensure that the data was successfully transferred
+            sub.enable()
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2), (3, 3)])
+
+            # Add new tables. Since we added "all tables" to publication
+            # (default behaviour of publish() method) we don't need
+            # to explicitly perform pub.add_tables()
+            create_table = 'create table test2 (c char)'
+            node1.safe_psql(create_table)
+            node2.safe_psql(create_table)
+            sub.refresh()
+
+            # put new data
+            node1.safe_psql('insert into test2 values (\'a\'), (\'b\')')
+            sub.catchup()
+            res = node2.execute('select * from test2')
+            self.assertListEqual(res, [('a',), ('b',)])
+
+            # drop subscription
+            sub.drop()
+            pub.drop()
+
+            # create new publication and subscription for a specific table
+            # (omitting data copying as it's already done)
+            pub = node1.publish('newpub', tables=['test'])
+            sub = node2.subscribe(pub, 'newsub', copy_data=False)
+
+            node1.safe_psql('insert into test values (4, 4)')
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2), (3, 3), (4, 4)])
+
+            # explicitly add table
+            with self.assertRaises(ValueError):
+                pub.add_tables([])  # fail
+            pub.add_tables(['test2'])
+            node1.safe_psql('insert into test2 values (\'c\')')
+            sub.catchup()
+            res = node2.execute('select * from test2')
+            self.assertListEqual(res, [('a',), ('b',)])
+
+    @unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
+    def test_logical_catchup(self):
+        """ Runs catchup 100 times to make sure it is consistent """
+        with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2:
+            node1.init(allow_logical=True)
+            node1.start()
+            node2.init().start()
+
+            create_table = 'create table test (key int primary key, val int); '
+            node1.safe_psql(create_table)
+            node1.safe_psql('alter table test replica identity default')
+            node2.safe_psql(create_table)
+
+            # create publication / create subscription
+            sub = node2.subscribe(node1.publish('mypub'), 'mysub')
+
+            for i in range(0, 100):
+                node1.execute('insert into test values ({0}, {0})'.format(i))
+                sub.catchup()
+                res = node2.execute('select * from test')
+                self.assertListEqual(res, [(
+                    i,
+                    i,
+                )])
+                node1.execute('delete from test')
+
+    @unittest.skipIf(pg_version_ge('10'), 'requires <10')
+    def test_logical_replication_fail(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            with self.assertRaises(InitNodeException):
+                node.init(allow_logical=True)
+
+    def test_replication_slots(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            with node.replicate(slot='slot1').start() as replica:
+                replica.execute('select 1')
+
+                # cannot create new slot with the same name
+                with self.assertRaises(TestgresException):
+                    node.replicate(slot='slot1')
+
+    def test_incorrect_catchup(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init(allow_streaming=True).start()
+
+            # node has no master, can't catch up
+            with self.assertRaises(TestgresException):
+                node.catchup()
+
+    def test_promotion(self):
+        with 
get_remote_node(conn_params=conn_params) as master:
+            master.init().start()
+            master.safe_psql('create table abc(id serial)')
+
+            with master.replicate().start() as replica:
+                master.stop()
+                replica.promote()
+
+                # make sure the standby has become a writable master
+                replica.safe_psql('insert into abc values (1)')
+                res = replica.safe_psql('select * from abc')
+                self.assertEqual(res, b'1\n')
+
+    def test_dump(self):
+        query_create = 'create table test as select generate_series(1, 2) as val'
+        query_select = 'select * from test order by val asc'
+
+        with get_remote_node(conn_params=conn_params).init().start() as node1:
+
+            node1.execute(query_create)
+            for format in ['plain', 'custom', 'directory', 'tar']:
+                with removing(node1.dump(format=format)) as dump:
+                    with get_remote_node(conn_params=conn_params).init().start() as node3:
+                        if format == 'directory':
+                            self.assertTrue(os_ops.isdir(dump))
+                        else:
+                            self.assertTrue(os_ops.isfile(dump))
+                        # restore dump
+                        node3.restore(filename=dump)
+                        res = node3.execute(query_select)
+                        self.assertListEqual(res, [(1,), (2,)])
+
+    def test_users(self):
+        with get_remote_node(conn_params=conn_params).init().start() as node:
+            node.psql('create role test_user login')
+            value = node.safe_psql('select 1', username='test_user')
+            self.assertEqual(b'1\n', value)
+
+    def test_poll_query_until(self):
+        with get_remote_node(conn_params=conn_params) as node:
+            node.init().start()
+
+            get_time = 'select extract(epoch from now())'
+            check_time = 'select extract(epoch from now()) - {} >= 5'
+
+            start_time = node.execute(get_time)[0][0]
+            node.poll_query_until(query=check_time.format(start_time))
+            end_time = node.execute(get_time)[0][0]
+
+            self.assertTrue(end_time - start_time >= 5)
+
+            # check 0 columns
+            with self.assertRaises(QueryException):
+                node.poll_query_until(
+                    query='select from pg_catalog.pg_class limit 1')
+
+            # check None, fail
+            with self.assertRaises(QueryException):
+                node.poll_query_until(query='create table abc (val int)')
+
+            # check None, ok
+            node.poll_query_until(query='create table def()',
+                                  expected=None)  # returns nothing
+
+            # check 0 rows equivalent to expected=None
+            node.poll_query_until(
+                query='select * from pg_catalog.pg_class where true = false',
+                expected=None)
+
+            # check arbitrary expected value, fail
+            with self.assertRaises(TimeoutException):
+                node.poll_query_until(query='select 3',
+                                      expected=1,
+                                      max_attempts=3,
+                                      sleep_time=0.01)
+
+            # check arbitrary expected value, ok
+            node.poll_query_until(query='select 2', expected=2)
+
+            # check timeout
+            with self.assertRaises(TimeoutException):
+                node.poll_query_until(query='select 1 > 2',
+                                      max_attempts=3,
+                                      sleep_time=0.01)
+
+            # check ProgrammingError, fail
+            with self.assertRaises(testgres.ProgrammingError):
+                node.poll_query_until(query='dummy1')
+
+            # check ProgrammingError, ok
+            with self.assertRaises(TimeoutException):
+                node.poll_query_until(query='dummy2',
+                                      max_attempts=3,
+                                      sleep_time=0.01,
+                                      suppress={testgres.ProgrammingError})
+
+            # check 1 arg, ok
+            node.poll_query_until('select true')
+
+    def test_logging(self):
+        # FAIL
+        logfile = tempfile.NamedTemporaryFile('w', delete=True)
+
+        log_conf = {
+            'version': 1,
+            'handlers': {
+                'file': {
+                    'class': 'logging.FileHandler',
+                    'filename': logfile.name,
+                    'formatter': 'base_format',
+                    'level': logging.DEBUG,
+                },
+            },
+            'formatters': {
+                'base_format': {
+                    'format': '%(node)-5s: %(message)s',
+                },
+            },
+            'root': {
+                'handlers': ('file',),
+                'level': 'DEBUG',
+            },
+        }
+
+        logging.config.dictConfig(log_conf)
+
+        with 
scoped_config(use_python_logging=True): + node_name = 'master' + + with get_remote_node(name=node_name) as master: + master.init().start() + + # execute a dummy query a few times + for i in range(20): + master.execute('select 1') + time.sleep(0.01) + + # let logging worker do the job + time.sleep(0.1) + + # check that master's port is found + with open(logfile.name, 'r') as log: + lines = log.readlines() + self.assertTrue(any(node_name in s for s in lines)) + + # test logger after stop/start/restart + master.stop() + master.start() + master.restart() + self.assertTrue(master._logger.is_alive()) + + @unittest.skipUnless(util_exists('pgbench'), 'might be missing') + def test_pgbench(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + # initialize pgbench DB and run benchmarks + node.pgbench_init(scale=2, foreign_keys=True, + options=['-q']).pgbench_run(time=2) + + # run TPC-B benchmark + out = node.pgbench(stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + options=['-T3']) + self.assertTrue(b'tps = ' in out) + + def test_pg_config(self): + # check same instances + a = get_pg_config() + b = get_pg_config() + self.assertEqual(id(a), id(b)) + + # save right before config change + c1 = get_pg_config() + # modify setting for this scope + with scoped_config(cache_pg_config=False) as config: + # sanity check for value + self.assertFalse(config.cache_pg_config) + + # save right after config change + c2 = get_pg_config() + + # check different instances after config change + self.assertNotEqual(id(c1), id(c2)) + + # check different instances + a = get_pg_config() + b = get_pg_config() + self.assertNotEqual(id(a), id(b)) + + def test_config_stack(self): + # no such option + with self.assertRaises(TypeError): + configure_testgres(dummy=True) + + # we have only 1 config in stack + with self.assertRaises(IndexError): + pop_config() + + d0 = TestgresConfig.cached_initdb_dir + d1 = 'dummy_abc' + d2 = 'dummy_def' + + with scoped_config(cached_initdb_dir=d1) as c1: + self.assertEqual(c1.cached_initdb_dir, d1) + + with scoped_config(cached_initdb_dir=d2) as c2: + stack_size = len(testgres.config.config_stack) + + # try to break a stack + with self.assertRaises(TypeError): + with scoped_config(dummy=True): + pass + + self.assertEqual(c2.cached_initdb_dir, d2) + self.assertEqual(len(testgres.config.config_stack), stack_size) + + self.assertEqual(c1.cached_initdb_dir, d1) + + self.assertEqual(TestgresConfig.cached_initdb_dir, d0) + + def test_unix_sockets(self): + with get_remote_node(conn_params=conn_params) as node: + node.init(unix_sockets=False, allow_streaming=True) + node.start() + + res_exec = node.execute('select 1') + res_psql = node.safe_psql('select 1') + self.assertEqual(res_exec, [(1,)]) + self.assertEqual(res_psql, b'1\n') + + with node.replicate().start() as r: + res_exec = r.execute('select 1') + res_psql = r.safe_psql('select 1') + self.assertEqual(res_exec, [(1,)]) + self.assertEqual(res_psql, b'1\n') + + def test_auto_name(self): + with get_remote_node(conn_params=conn_params).init(allow_streaming=True).start() as m: + with m.replicate().start() as r: + # check that nodes are running + self.assertTrue(m.status()) + self.assertTrue(r.status()) + + # check their names + self.assertNotEqual(m.name, r.name) + self.assertTrue('testgres' in m.name) + self.assertTrue('testgres' in r.name) + + def test_file_tail(self): + from testgres.utils import file_tail + + s1 = "the quick brown fox jumped over that lazy dog\n" + s2 = "abc\n" + s3 = "def\n" + + with 
tempfile.NamedTemporaryFile(mode='r+', delete=True) as f: + sz = 0 + while sz < 3 * 8192: + sz += len(s1) + f.write(s1) + f.write(s2) + f.write(s3) + + f.seek(0) + lines = file_tail(f, 3) + self.assertEqual(lines[0], s1) + self.assertEqual(lines[1], s2) + self.assertEqual(lines[2], s3) + + f.seek(0) + lines = file_tail(f, 1) + self.assertEqual(lines[0], s3) + + def test_isolation_levels(self): + with get_remote_node(conn_params=conn_params).init().start() as node: + with node.connect() as con: + # string levels + con.begin('Read Uncommitted').commit() + con.begin('Read Committed').commit() + con.begin('Repeatable Read').commit() + con.begin('Serializable').commit() + + # enum levels + con.begin(IsolationLevel.ReadUncommitted).commit() + con.begin(IsolationLevel.ReadCommitted).commit() + con.begin(IsolationLevel.RepeatableRead).commit() + con.begin(IsolationLevel.Serializable).commit() + + # check wrong level + with self.assertRaises(QueryException): + con.begin('Garbage').commit() + + def test_ports_management(self): + # check that no ports have been bound yet + self.assertEqual(len(bound_ports), 0) + + with get_remote_node(conn_params=conn_params) as node: + # check that we've just bound a port + self.assertEqual(len(bound_ports), 1) + + # check that bound_ports contains our port + port_1 = list(bound_ports)[0] + port_2 = node.port + self.assertEqual(port_1, port_2) + + # check that port has been freed successfully + self.assertEqual(len(bound_ports), 0) + + def test_exceptions(self): + str(StartNodeException('msg', [('file', 'lines')])) + str(ExecUtilException('msg', 'cmd', 1, 'out')) + str(QueryException('msg', 'query')) + + def test_version_management(self): + a = PgVer('10.0') + b = PgVer('10') + c = PgVer('9.6.5') + d = PgVer('15.0') + e = PgVer('15rc1') + f = PgVer('15beta4') + + self.assertTrue(a == b) + self.assertTrue(b > c) + self.assertTrue(a > c) + self.assertTrue(d > e) + self.assertTrue(e > f) + self.assertTrue(d > f) + + version = get_pg_version() + with get_remote_node(conn_params=conn_params) as node: + self.assertTrue(isinstance(version, six.string_types)) + self.assertTrue(isinstance(node.version, PgVer)) + self.assertEqual(node.version, PgVer(version)) + + def test_child_pids(self): + master_processes = [ + ProcessType.AutovacuumLauncher, + ProcessType.BackgroundWriter, + ProcessType.Checkpointer, + ProcessType.StatsCollector, + ProcessType.WalSender, + ProcessType.WalWriter, + ] + + if pg_version_ge('10'): + master_processes.append(ProcessType.LogicalReplicationLauncher) + + repl_processes = [ + ProcessType.Startup, + ProcessType.WalReceiver, + ] + + with get_remote_node(conn_params=conn_params).init().start() as master: + + # master node doesn't have a source walsender! 
+ with self.assertRaises(TestgresException): + master.source_walsender + + with master.connect() as con: + self.assertGreater(con.pid, 0) + + with master.replicate().start() as replica: + + # test __str__ method + str(master.child_processes[0]) + + master_pids = master.auxiliary_pids + for ptype in master_processes: + self.assertIn(ptype, master_pids) + + replica_pids = replica.auxiliary_pids + for ptype in repl_processes: + self.assertIn(ptype, replica_pids) + + # there should be exactly 1 source walsender for replica + self.assertEqual(len(master_pids[ProcessType.WalSender]), 1) + pid1 = master_pids[ProcessType.WalSender][0] + pid2 = replica.source_walsender.pid + self.assertEqual(pid1, pid2) + + replica.stop() + + # there should be no walsender after we've stopped replica + with self.assertRaises(TestgresException): + replica.source_walsender + + def test_child_process_dies(self): + # test for FileNotFound exception during child_processes() function + with subprocess.Popen(["sleep", "60"]) as process: + self.assertEqual(process.poll(), None) + # collect list of processes currently running + children = psutil.Process(os.getpid()).children() + # kill a process, so received children dictionary becomes invalid + process.kill() + process.wait() + # try to handle children list -- missing processes will have ptype "ProcessType.Unknown" + [ProcessProxy(p) for p in children] + + +if __name__ == '__main__': + if os_ops.environ('ALT_CONFIG'): + suite = unittest.TestSuite() + + # Small subset of tests for alternative configs (PG_BIN or PG_CONFIG) + suite.addTest(TestgresRemoteTests('test_pg_config')) + suite.addTest(TestgresRemoteTests('test_pg_ctl')) + suite.addTest(TestgresRemoteTests('test_psql')) + suite.addTest(TestgresRemoteTests('test_replicate')) + + print('Running tests for alternative config:') + for t in suite: + print(t) + print() + + runner = unittest.TextTestRunner() + runner.run(suite) + else: + unittest.main()
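
For readers adapting this patch to their own environment, here is a minimal sketch (not part of the diff) of how the two remote test modules above resolve their settings. The host address, username, key path, and pg_config location shown are placeholder assumptions; only the environment variable names and the `ConnectionParams`/`RemoteOperations`/`set_os_ops` calls come from the patch itself.

```python
# Minimal sketch: point the remote test suite at your own host.
# All concrete values below are placeholders -- adjust them.
import os

from testgres import ConnectionParams, RemoteOperations
from testgres.config import testgres_config

# The test modules fall back to hard-coded defaults when these
# variables are unset, so exporting them is the easiest override.
os.environ.setdefault('RDBMS_TESTPOOL1_HOST', '172.18.0.3')        # remote host
os.environ.setdefault('RDBMS_TESTPOOL_SSHKEY', '/path/to/id_ed25519')  # SSH key

# If pg_config lives elsewhere on the remote host, set PG_CONFIG_REMOTE;
# get_bin_path()/get_pg_config() prefer it over PG_CONFIG whenever a
# remote os_ops is active.
os.environ.setdefault('PG_CONFIG_REMOTE', '/usr/local/pgsql/bin/pg_config')

conn_params = ConnectionParams(
    host=os.environ['RDBMS_TESTPOOL1_HOST'],
    username='dev',  # the user baked into the test container above
    ssh_key=os.environ['RDBMS_TESTPOOL_SSHKEY'],
)

# Route all subsequent file and process operations through SSH for
# every node created afterwards.
testgres_config.set_os_ops(os_ops=RemoteOperations(conn_params))
```

Running this once before the suite starts (for example, from a small bootstrap module imported at the top of the test file) mirrors what `test_simple_remote.py` does at import time.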