tests/regression/tools/trigger/utils/Makefile
tests/regression/tools/trigger/name/Makefile
tests/regression/tools/trigger/hidden/Makefile
+ tests/regression/tools/context/Makefile
tests/regression/ust/Makefile
tests/regression/ust/nprocesses/Makefile
tests/regression/ust/high-throughput/Makefile
tests/unit/ini_config/Makefile
tests/perf/Makefile
tests/utils/Makefile
+ tests/utils/lttngtest/Makefile
tests/utils/tap/Makefile
tests/utils/testapp/Makefile
tests/utils/testapp/gen-ns-events/Makefile
tools/trigger/test_list_triggers_cli \
tools/trigger/test_remove_trigger_cli \
tools/trigger/name/test_trigger_name_backwards_compat \
- tools/trigger/hidden/test_hidden_trigger
+ tools/trigger/hidden/test_hidden_trigger \
+ tools/context/test_ust.py
# Only build kernel tests on Linux.
if IS_LINUX
SUBDIRS = base-path \
channel \
clear \
+ context \
crash \
exclusion \
filtering \
--- /dev/null
+# SPDX-License-Identifier: GPL-2.0-only
+
+noinst_SCRIPTS = test_ust.py
+EXTRA_DIST = test_ust.py
+
+all-local:
+ @if [ x"$(srcdir)" != x"$(builddir)" ]; then \
+ for script in $(EXTRA_DIST); do \
+ cp -f $(srcdir)/$$script $(builddir); \
+ done; \
+ fi
+
+clean-local:
+ @if [ x"$(srcdir)" != x"$(builddir)" ]; then \
+ for script in $(EXTRA_DIST); do \
+ rm -f $(builddir)/$$script; \
+ done; \
+ fi
--- /dev/null
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022 Jérémie Galarneau <jeremie.galarneau@efficios.com>
+#
+# SPDX-License-Identifier: GPL-2.0-only
+
+from cgi import test
+import pathlib
+import sys
+import os
+from typing import Any, Callable, Type
+
+"""
+Test the addition of various user space contexts.
+
+This test successively sets up a session with a certain context enabled, traces
+a test application, and then reads the resulting trace to determine if:
+ - the context field is present in the trace
+ - the context field has the expected value.
+
+The vpid, vuid, vgid and java application contexts are validated by this test.
+"""
+
+# Import in-tree test utils
+test_utils_import_path = pathlib.Path(__file__).absolute().parents[3] / "utils"
+sys.path.append(str(test_utils_import_path))
+
+import lttngtest
+import bt2
+
+
+def context_trace_field_name(context_type: Type[lttngtest.ContextType]) -> str:
+ if isinstance(context_type, lttngtest.VpidContextType):
+ return "vpid"
+ elif isinstance(context_type, lttngtest.VuidContextType):
+ return "vuid"
+ elif isinstance(context_type, lttngtest.VgidContextType):
+ return "vgid"
+ elif isinstance(context_type, lttngtest.JavaApplicationContextType):
+ # Depends on the trace format and will need to be adapted for CTF 2.
+ return "_app_{retriever}_{name}".format(
+ retriever=context_type.retriever_name, name=context_type.field_name
+ )
+ else:
+ raise NotImplementedError
+
+
+def trace_stream_class_has_context_field_in_event_context(
+ trace_location: pathlib.Path, context_field_name: str
+) -> bool:
+ iterator = bt2.TraceCollectionMessageIterator(str(trace_location))
+
+ # A bt2 message sequence is guaranteed to begin with a StreamBeginningMessage.
+ # Since we only have one channel (one stream class) and one trace, it is
+ # safe to use it to determine if the stream class contains the expected
+ # context field.
+ stream_begin_msg = next(iterator)
+
+ trace_class = stream_begin_msg.stream.trace.cls
+ # Ensure the trace class has only one stream class.
+ assert len(trace_class)
+
+ stream_class_id = next(iter(trace_class))
+ stream_class = trace_class[stream_class_id]
+ event_common_context_field_class = stream_class.event_common_context_field_class
+
+ return context_field_name in event_common_context_field_class
+
+
+def trace_events_have_context_value(
+ trace_location: pathlib.Path, context_field_name: str, value: Any
+) -> bool:
+ for msg in bt2.TraceCollectionMessageIterator(str(trace_location)):
+ if type(msg) is not bt2._EventMessageConst:
+ continue
+
+ if msg.event.common_context_field[context_field_name] != value:
+ print(msg.event.common_context_field[context_field_name])
+ return False
+ return True
+
+
+def test_static_context(
+ tap: lttngtest.TapGenerator,
+ test_env: lttngtest._Environment,
+ context_type: lttngtest.ContextType,
+ context_value_retriever: Callable[[lttngtest.WaitTraceTestApplication], Any],
+) -> None:
+ tap.diagnostic(
+ "Test presence and expected value of context `{context_name}`".format(
+ context_name=type(context_type).__name__
+ )
+ )
+
+ session_output_location = lttngtest.LocalSessionOutputLocation(
+ test_env.create_temporary_directory("trace")
+ )
+
+ client: lttngtest.Controller = lttngtest.LTTngClient(test_env, log=tap.diagnostic)
+
+ with tap.case("Create a session") as test_case:
+ session = client.create_session(output=session_output_location)
+ tap.diagnostic("Created session `{session_name}`".format(session_name=session.name))
+
+ with tap.case(
+ "Add a channel to session `{session_name}`".format(session_name=session.name)
+ ) as test_case:
+ channel = session.add_channel(lttngtest.TracingDomain.User)
+ tap.diagnostic("Created channel `{channel_name}`".format(channel_name=channel.name))
+
+ with tap.case(
+ "Add {context_type} context to channel `{channel_name}`".format(
+ context_type=type(context_type).__name__, channel_name=channel.name
+ )
+ ) as test_case:
+ channel.add_context(context_type)
+
+ test_app = test_env.launch_wait_trace_test_application(50)
+
+ # Only track the test application
+ session.user_vpid_process_attribute_tracker.track(test_app.vpid)
+ expected_context_value = context_value_retriever(test_app)
+
+ # Enable all user space events, the default for a user tracepoint event rule.
+ channel.add_recording_rule(lttngtest.UserTracepointEventRule())
+
+ session.start()
+ test_app.trace()
+ test_app.wait_for_exit()
+ session.stop()
+ session.destroy()
+
+ tap.test(
+ trace_stream_class_has_context_field_in_event_context(
+ session_output_location.path, context_trace_field_name(context_type)
+ ),
+ "Stream class contains field `{context_field_name}`".format(
+ context_field_name=context_trace_field_name(context_type)
+ ),
+ )
+
+ tap.test(
+ trace_events_have_context_value(
+ session_output_location.path,
+ context_trace_field_name(context_type),
+ expected_context_value,
+ ),
+ "Trace's events contain the expected `{context_field_name}` value `{expected_context_value}`".format(
+ context_field_name=context_trace_field_name(context_type),
+ expected_context_value=expected_context_value,
+ ),
+ )
+
+
+tap = lttngtest.TapGenerator(20)
+tap.diagnostic("Test user space context tracing")
+
+with lttngtest.test_environment(with_sessiond=True, log=tap.diagnostic) as test_env:
+ test_static_context(
+ tap, test_env, lttngtest.VpidContextType(), lambda test_app: test_app.vpid
+ )
+ test_static_context(
+ tap, test_env, lttngtest.VuidContextType(), lambda test_app: os.getuid()
+ )
+ test_static_context(
+ tap, test_env, lttngtest.VgidContextType(), lambda test_app: os.getgid()
+ )
+ test_static_context(
+ tap,
+ test_env,
+ lttngtest.JavaApplicationContextType("mayo", "ketchup"),
+ lambda test_app: {},
+ )
+
+sys.exit(0 if tap.is_successful else 1)
# SPDX-License-Identifier: GPL-2.0-only
-SUBDIRS = . tap testapp xml-utils
+SUBDIRS = . tap testapp xml-utils lttngtest
EXTRA_DIST = utils.sh test_utils.py babelstats.pl warn_processes.sh \
parse-callstack.py
--- /dev/null
+# SPDX-License-Identifier: GPL-2.0-only
+
+EXTRA_DIST = __init__.py \
+ environment.py \
+ logger.py \
+ lttngctl.py \
+ lttng.py \
+ tap_generator.py
+
+dist_noinst_SCRIPTS = __init__.py \
+ environment.py \
+ logger.py \
+ lttngctl.py \
+ lttng.py \
+ tap_generator.py
+
+all-local:
+ @if [ x"$(srcdir)" != x"$(builddir)" ]; then \
+ for script in $(EXTRA_DIST); do \
+ cp -f $(srcdir)/$$script $(builddir); \
+ done; \
+ fi
+
+clean-local:
+ @if [ x"$(srcdir)" != x"$(builddir)" ]; then \
+ for script in $(EXTRA_DIST); do \
+ rm -f $(builddir)/$$script; \
+ done; \
+ fi
--- /dev/null
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022 Jérémie Galarneau <jeremie.galarneau@efficios.com>
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+from .tap_generator import *
+from .environment import *
+from .environment import _Environment
+from .lttngctl import *
+from .lttng import *
--- /dev/null
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022 Jérémie Galarneau <jeremie.galarneau@efficios.com>
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+from types import FrameType
+from typing import Callable, Optional, Tuple, List
+import sys
+import pathlib
+import signal
+import subprocess
+import shlex
+import shutil
+import os
+import queue
+import tempfile
+from . import logger
+import time
+import threading
+import contextlib
+
+
+class TemporaryDirectory:
+ def __init__(self, prefix: str):
+ self._directory_path = tempfile.mkdtemp(prefix=prefix)
+
+ def __del__(self):
+ shutil.rmtree(self._directory_path, ignore_errors=True)
+
+ @property
+ def path(self) -> pathlib.Path:
+ return pathlib.Path(self._directory_path)
+
+
+class _SignalWaitQueue:
+ """
+ Utility class useful to wait for a signal before proceeding.
+
+ Simply register the `signal` method as the handler for the signal you are
+ interested in and call `wait_for_signal` to wait for its reception.
+
+ Registering a signal:
+ signal.signal(signal.SIGWHATEVER, queue.signal)
+
+ Waiting for the signal:
+ queue.wait_for_signal()
+ """
+
+ def __init__(self):
+ self._queue: queue.Queue = queue.Queue()
+
+ def signal(self, signal_number, frame: Optional[FrameType]):
+ self._queue.put_nowait(signal_number)
+
+ def wait_for_signal(self):
+ self._queue.get(block=True)
+
+
+class WaitTraceTestApplication:
+ """
+ Create an application that waits before tracing. This allows a test to
+ launch an application, get its PID, and get it to start tracing when it
+ has completed its setup.
+ """
+
+ def __init__(
+ self,
+ binary_path: pathlib.Path,
+ event_count: int,
+ environment: "Environment",
+ wait_time_between_events_us: int = 0,
+ ):
+ self._environment: Environment = environment
+ if event_count % 5:
+ # The test application currently produces 5 different events per iteration.
+ raise ValueError("event count must be a multiple of 5")
+ self._iteration_count: int = int(event_count / 5)
+ # File that the application will wait to see before tracing its events.
+ self._app_start_tracing_file_path: pathlib.Path = pathlib.Path(
+ tempfile.mktemp(
+ prefix="app_",
+ suffix="_start_tracing",
+ dir=environment.lttng_home_location,
+ )
+ )
+ self._has_returned = False
+
+ test_app_env = os.environ.copy()
+ test_app_env["LTTNG_HOME"] = str(environment.lttng_home_location)
+ # Make sure the app is blocked until it is properly registered to
+ # the session daemon.
+ test_app_env["LTTNG_UST_REGISTER_TIMEOUT"] = "-1"
+
+ # File that the application will create to indicate it has completed its initialization.
+ app_ready_file_path: str = tempfile.mktemp(
+ prefix="app_", suffix="_ready", dir=environment.lttng_home_location
+ )
+
+ test_app_args = [str(binary_path)]
+ test_app_args.extend(
+ shlex.split(
+ "--iter {iteration_count} --create-in-main {app_ready_file_path} --wait-before-first-event {app_start_tracing_file_path} --wait {wait_time_between_events_us}".format(
+ iteration_count=self._iteration_count,
+ app_ready_file_path=app_ready_file_path,
+ app_start_tracing_file_path=self._app_start_tracing_file_path,
+ wait_time_between_events_us=wait_time_between_events_us,
+ )
+ )
+ )
+
+ self._process: subprocess.Popen = subprocess.Popen(
+ test_app_args,
+ env=test_app_env,
+ )
+
+ # Wait for the application to create the file indicating it has fully
+ # initialized. Make sure the app hasn't crashed in order to not wait
+ # forever.
+ while True:
+ if os.path.exists(app_ready_file_path):
+ break
+
+ if self._process.poll() is not None:
+ # Application has unexepectedly returned.
+ raise RuntimeError(
+ "Test application has unexepectedly returned during its initialization with return code `{return_code}`".format(
+ return_code=self._process.returncode
+ )
+ )
+
+ time.sleep(0.1)
+
+ def trace(self) -> None:
+ if self._process.poll() is not None:
+ # Application has unexepectedly returned.
+ raise RuntimeError(
+ "Test application has unexepectedly before tracing with return code `{return_code}`".format(
+ return_code=self._process.returncode
+ )
+ )
+ open(self._app_start_tracing_file_path, mode="x")
+
+ def wait_for_exit(self) -> None:
+ if self._process.wait() != 0:
+ raise RuntimeError(
+ "Test application has exit with return code `{return_code}`".format(
+ return_code=self._process.returncode
+ )
+ )
+ self._has_returned = True
+
+ @property
+ def vpid(self) -> int:
+ return self._process.pid
+
+ def __del__(self):
+ if not self._has_returned:
+ # This is potentially racy if the pid has been recycled. However,
+ # we can't use pidfd_open since it is only available in python >= 3.9.
+ self._process.kill()
+ self._process.wait()
+
+
+class ProcessOutputConsumer(threading.Thread, logger._Logger):
+ def __init__(
+ self, process: subprocess.Popen, name: str, log: Callable[[str], None]
+ ):
+ threading.Thread.__init__(self)
+ self._prefix = name
+ logger._Logger.__init__(self, log)
+ self._process = process
+
+ def run(self) -> None:
+ while self._process.poll() is None:
+ assert self._process.stdout
+ line = self._process.stdout.readline().decode("utf-8").replace("\n", "")
+ if len(line) != 0:
+ self._log("{prefix}: {line}".format(prefix=self._prefix, line=line))
+
+
+# Generate a temporary environment in which to execute a test.
+class _Environment(logger._Logger):
+ def __init__(
+ self, with_sessiond: bool, log: Optional[Callable[[str], None]] = None
+ ):
+ super().__init__(log)
+ signal.signal(signal.SIGTERM, self._handle_termination_signal)
+ signal.signal(signal.SIGINT, self._handle_termination_signal)
+
+ # Assumes the project's hierarchy to this file is:
+ # tests/utils/python/this_file
+ self._project_root: pathlib.Path = pathlib.Path(__file__).absolute().parents[3]
+ self._lttng_home: Optional[TemporaryDirectory] = TemporaryDirectory(
+ "lttng_test_env_home"
+ )
+
+ self._sessiond: Optional[subprocess.Popen[bytes]] = (
+ self._launch_lttng_sessiond() if with_sessiond else None
+ )
+
+ @property
+ def lttng_home_location(self) -> pathlib.Path:
+ if self._lttng_home is None:
+ raise RuntimeError("Attempt to access LTTng home after clean-up")
+ return self._lttng_home.path
+
+ @property
+ def lttng_client_path(self) -> pathlib.Path:
+ return self._project_root / "src" / "bin" / "lttng" / "lttng"
+
+ def create_temporary_directory(self, prefix: Optional[str] = None) -> pathlib.Path:
+ # Simply return a path that is contained within LTTNG_HOME; it will
+ # be destroyed when the temporary home goes out of scope.
+ assert self._lttng_home
+ return pathlib.Path(
+ tempfile.mkdtemp(
+ prefix="tmp" if prefix is None else prefix,
+ dir=str(self._lttng_home.path),
+ )
+ )
+
+ # Unpack a list of environment variables from a string
+ # such as "HELLO=is_it ME='/you/are/looking/for'"
+ @staticmethod
+ def _unpack_env_vars(env_vars_string: str) -> List[Tuple[str, str]]:
+ unpacked_vars = []
+ for var in shlex.split(env_vars_string):
+ equal_position = var.find("=")
+ # Must have an equal sign and not end with an equal sign
+ if equal_position == -1 or equal_position == len(var) - 1:
+ raise ValueError(
+ "Invalid sessiond environment variable: `{}`".format(var)
+ )
+
+ var_name = var[0:equal_position]
+ var_value = var[equal_position + 1 :]
+ # Unquote any paths
+ var_value = var_value.replace("'", "")
+ var_value = var_value.replace('"', "")
+ unpacked_vars.append((var_name, var_value))
+
+ return unpacked_vars
+
+ def _launch_lttng_sessiond(self) -> Optional[subprocess.Popen]:
+ is_64bits_host = sys.maxsize > 2**32
+
+ sessiond_path = (
+ self._project_root / "src" / "bin" / "lttng-sessiond" / "lttng-sessiond"
+ )
+ consumerd_path_option_name = "--consumerd{bitness}-path".format(
+ bitness="64" if is_64bits_host else "32"
+ )
+ consumerd_path = (
+ self._project_root / "src" / "bin" / "lttng-consumerd" / "lttng-consumerd"
+ )
+
+ no_sessiond_var = os.environ.get("TEST_NO_SESSIOND")
+ if no_sessiond_var and no_sessiond_var == "1":
+ # Run test without a session daemon; the user probably
+ # intends to run one under gdb for example.
+ return None
+
+ # Setup the session daemon's environment
+ sessiond_env_vars = os.environ.get("LTTNG_SESSIOND_ENV_VARS")
+ sessiond_env = os.environ.copy()
+ if sessiond_env_vars:
+ self._log("Additional lttng-sessiond environment variables:")
+ additional_vars = self._unpack_env_vars(sessiond_env_vars)
+ for var_name, var_value in additional_vars:
+ self._log(" {name}={value}".format(name=var_name, value=var_value))
+ sessiond_env[var_name] = var_value
+
+ sessiond_env["LTTNG_SESSION_CONFIG_XSD_PATH"] = str(
+ self._project_root / "src" / "common"
+ )
+
+ assert self._lttng_home is not None
+ sessiond_env["LTTNG_HOME"] = str(self._lttng_home.path)
+
+ wait_queue = _SignalWaitQueue()
+ signal.signal(signal.SIGUSR1, wait_queue.signal)
+
+ self._log(
+ "Launching session daemon with LTTNG_HOME=`{home_dir}`".format(
+ home_dir=str(self._lttng_home.path)
+ )
+ )
+ process = subprocess.Popen(
+ [
+ str(sessiond_path),
+ consumerd_path_option_name,
+ str(consumerd_path),
+ "--sig-parent",
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ env=sessiond_env,
+ )
+
+ if self._logging_function:
+ self._sessiond_output_consumer: Optional[
+ ProcessOutputConsumer
+ ] = ProcessOutputConsumer(process, "lttng-sessiond", self._logging_function)
+ self._sessiond_output_consumer.daemon = True
+ self._sessiond_output_consumer.start()
+
+ # Wait for SIGUSR1, indicating the sessiond is ready to proceed
+ wait_queue.wait_for_signal()
+ signal.signal(signal.SIGUSR1, wait_queue.signal)
+
+ return process
+
+ def _handle_termination_signal(
+ self, signal_number: int, frame: Optional[FrameType]
+ ) -> None:
+ self._log(
+ "Killed by {signal_name} signal, cleaning-up".format(
+ signal_name=signal.strsignal(signal_number)
+ )
+ )
+ self._cleanup()
+
+ def launch_wait_trace_test_application(
+ self, event_count: int
+ ) -> WaitTraceTestApplication:
+ """
+ Launch an application that will wait before tracing `event_count` events.
+ """
+ return WaitTraceTestApplication(
+ self._project_root
+ / "tests"
+ / "utils"
+ / "testapp"
+ / "gen-ust-nevents"
+ / "gen-ust-nevents",
+ event_count,
+ self,
+ )
+
+ # Clean-up managed processes
+ def _cleanup(self) -> None:
+ if self._sessiond and self._sessiond.poll() is None:
+ # The session daemon is alive; kill it.
+ self._log(
+ "Killing session daemon (pid = {sessiond_pid})".format(
+ sessiond_pid=self._sessiond.pid
+ )
+ )
+
+ self._sessiond.terminate()
+ self._sessiond.wait()
+ if self._sessiond_output_consumer:
+ self._sessiond_output_consumer.join()
+ self._sessiond_output_consumer = None
+
+ self._log("Session daemon killed")
+ self._sessiond = None
+
+ self._lttng_home = None
+
+ def __del__(self):
+ self._cleanup()
+
+
+@contextlib.contextmanager
+def test_environment(with_sessiond: bool, log: Optional[Callable[[str], None]] = None):
+ env = _Environment(with_sessiond, log)
+ try:
+ yield env
+ finally:
+ env._cleanup()
--- /dev/null
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022 Jérémie Galarneau <jeremie.galarneau@efficios.com>
+#
+# SPDX-License-Identifier: GPL-2.0-only
+
+from typing import Callable, Optional
+
+
+class _Logger:
+ def __init__(self, log: Optional[Callable[[str], None]]):
+ self._logging_function: Optional[Callable[[str], None]] = log
+
+ def _log(self, msg: str) -> None:
+ if self._logging_function:
+ self._logging_function(msg)
+
+ @property
+ def logger(self) -> Optional[Callable[[str], None]]:
+ return self._logging_function
--- /dev/null
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022 Jérémie Galarneau <jeremie.galarneau@efficios.com>
+#
+# SPDX-License-Identifier: GPL-2.0-only
+
+from concurrent.futures import process
+from . import lttngctl, logger, environment
+import pathlib
+import os
+from typing import Callable, Optional, Type, Union
+import shlex
+import subprocess
+import enum
+
+"""
+Implementation of the lttngctl interface based on the `lttng` command line client.
+"""
+
+
+class Unsupported(lttngctl.ControlException):
+ def __init__(self, msg: str):
+ super().__init__(msg)
+
+
+def _get_domain_option_name(domain: lttngctl.TracingDomain) -> str:
+ if domain == lttngctl.TracingDomain.User:
+ return "userspace"
+ elif domain == lttngctl.TracingDomain.Kernel:
+ return "kernel"
+ elif domain == lttngctl.TracingDomain.Log4j:
+ return "log4j"
+ elif domain == lttngctl.TracingDomain.JUL:
+ return "jul"
+ elif domain == lttngctl.TracingDomain.Python:
+ return "python"
+ else:
+ raise Unsupported("Domain `{domain_name}` is not supported by the LTTng client")
+
+
+def _get_context_type_name(context: lttngctl.ContextType) -> str:
+ if isinstance(context, lttngctl.VgidContextType):
+ return "vgid"
+ elif isinstance(context, lttngctl.VuidContextType):
+ return "vuid"
+ elif isinstance(context, lttngctl.VpidContextType):
+ return "vpid"
+ elif isinstance(context, lttngctl.JavaApplicationContextType):
+ return "$app.{retriever}:{field}".format(
+ retriever=context.retriever_name, field=context.field_name
+ )
+ else:
+ raise Unsupported(
+ "Context `{context_name}` is not supported by the LTTng client".format(
+ type(context).__name__
+ )
+ )
+
+
+class _Channel(lttngctl.Channel):
+ def __init__(
+ self,
+ client: "LTTngClient",
+ name: str,
+ domain: lttngctl.TracingDomain,
+ session: "_Session",
+ ):
+ self._client: LTTngClient = client
+ self._name: str = name
+ self._domain: lttngctl.TracingDomain = domain
+ self._session: _Session = session
+
+ def add_context(self, context_type: lttngctl.ContextType) -> None:
+ domain_option_name = _get_domain_option_name(self.domain)
+ context_type_name = _get_context_type_name(context_type)
+ self._client._run_cmd(
+ "add-context --{domain_option_name} --type {context_type_name}".format(
+ domain_option_name=domain_option_name,
+ context_type_name=context_type_name,
+ )
+ )
+
+ def add_recording_rule(self, rule: Type[lttngctl.EventRule]) -> None:
+ client_args = (
+ "enable-event --session {session_name} --channel {channel_name}".format(
+ session_name=self._session.name, channel_name=self.name
+ )
+ )
+ if isinstance(rule, lttngctl.TracepointEventRule):
+ domain_option_name = (
+ "userspace"
+ if isinstance(rule, lttngctl.UserTracepointEventRule)
+ else "kernel"
+ )
+ client_args = client_args + " --{domain_option_name}".format(
+ domain_option_name=domain_option_name
+ )
+
+ if rule.name_pattern:
+ client_args = client_args + " " + rule.name_pattern
+ else:
+ client_args = client_args + " --all"
+
+ if rule.filter_expression:
+ client_args = client_args + " " + rule.filter_expression
+
+ if rule.log_level_rule:
+ if isinstance(rule.log_level_rule, lttngctl.LogLevelRuleAsSevereAs):
+ client_args = client_args + " --loglevel {log_level}".format(
+ log_level=rule.log_level_rule.level
+ )
+ elif isinstance(rule.log_level_rule, lttngctl.LogLevelRuleExactly):
+ client_args = client_args + " --loglevel-only {log_level}".format(
+ log_level=rule.log_level_rule.level
+ )
+ else:
+ raise Unsupported(
+ "Unsupported log level rule type `{log_level_rule_type}`".format(
+ log_level_rule_type=type(rule.log_level_rule).__name__
+ )
+ )
+
+ if rule.name_pattern_exclusions:
+ client_args = client_args + " --exclude "
+ for idx, pattern in enumerate(rule.name_pattern_exclusions):
+ if idx != 0:
+ client_args = client_args + ","
+ client_args = client_args + pattern
+ else:
+ raise Unsupported(
+ "event rule type `{event_rule_type}` is unsupported by LTTng client".format(
+ event_rule_type=type(rule).__name__
+ )
+ )
+
+ self._client._run_cmd(client_args)
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def domain(self) -> lttngctl.TracingDomain:
+ return self._domain
+
+
+class _ProcessAttribute(enum.Enum):
+ PID = (enum.auto(),)
+ VPID = (enum.auto(),)
+ UID = (enum.auto(),)
+ VUID = (enum.auto(),)
+ GID = (enum.auto(),)
+ VGID = (enum.auto(),)
+
+
+def _get_process_attribute_option_name(attribute: _ProcessAttribute) -> str:
+ return {
+ _ProcessAttribute.PID: "pid",
+ _ProcessAttribute.VPID: "vpid",
+ _ProcessAttribute.UID: "uid",
+ _ProcessAttribute.VUID: "vuid",
+ _ProcessAttribute.GID: "gid",
+ _ProcessAttribute.VGID: "vgid",
+ }[attribute]
+
+
+class _ProcessAttributeTracker(lttngctl.ProcessAttributeTracker):
+ def __init__(
+ self,
+ client: "LTTngClient",
+ attribute: _ProcessAttribute,
+ domain: lttngctl.TracingDomain,
+ session: "_Session",
+ ):
+ self._client: LTTngClient = client
+ self._tracked_attribute: _ProcessAttribute = attribute
+ self._domain: lttngctl.TracingDomain = domain
+ self._session: "_Session" = session
+ if attribute == _ProcessAttribute.PID or attribute == _ProcessAttribute.VPID:
+ self._allowed_value_types: list[type] = [int, str]
+ else:
+ self._allowed_value_types: list[type] = [int]
+
+ def _call_client(self, cmd_name: str, value: Union[int, str]) -> None:
+ if type(value) not in self._allowed_value_types:
+ raise TypeError(
+ "Value of type `{value_type}` is not allowed for process attribute {attribute_name}".format(
+ value_type=type(value).__name__,
+ attribute_name=self._tracked_attribute.name,
+ )
+ )
+
+ process_attribute_option_name = _get_process_attribute_option_name(
+ self._tracked_attribute
+ )
+ domain_name = _get_domain_option_name(self._domain)
+ self._client._run_cmd(
+ "{cmd_name} --session {session_name} --{domain_name} --{tracked_attribute_name} {value}".format(
+ cmd_name=cmd_name,
+ session_name=self._session.name,
+ domain_name=domain_name,
+ tracked_attribute_name=process_attribute_option_name,
+ value=value,
+ )
+ )
+
+ def track(self, value: Union[int, str]) -> None:
+ self._call_client("track", value)
+
+ def untrack(self, value: Union[int, str]) -> None:
+ self._call_client("untrack", value)
+
+
+class _Session(lttngctl.Session):
+ def __init__(
+ self,
+ client: "LTTngClient",
+ name: str,
+ output: Optional[Type[lttngctl.SessionOutputLocation]],
+ ):
+ self._client: LTTngClient = client
+ self._name: str = name
+ self._output: Optional[Type[lttngctl.SessionOutputLocation]] = output
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ def add_channel(
+ self, domain: lttngctl.TracingDomain, channel_name: Optional[str] = None
+ ) -> lttngctl.Channel:
+ channel_name = lttngctl.Channel._generate_name()
+ domain_option_name = _get_domain_option_name(domain)
+ self._client._run_cmd(
+ "enable-channel --{domain_name} {channel_name}".format(
+ domain_name=domain_option_name, channel_name=channel_name
+ )
+ )
+ return _Channel(self._client, channel_name, domain, self)
+
+ def add_context(self, context_type: lttngctl.ContextType) -> None:
+ pass
+
+ @property
+ def output(self) -> Optional[Type[lttngctl.SessionOutputLocation]]:
+ return self._output
+
+ def start(self) -> None:
+ self._client._run_cmd("start {session_name}".format(session_name=self.name))
+
+ def stop(self) -> None:
+ self._client._run_cmd("stop {session_name}".format(session_name=self.name))
+
+ def destroy(self) -> None:
+ self._client._run_cmd("destroy {session_name}".format(session_name=self.name))
+
+ @property
+ def kernel_pid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.ProcessIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.PID, lttngctl.TracingDomain.Kernel, self) # type: ignore
+
+ @property
+ def kernel_vpid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.VirtualProcessIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.VPID, lttngctl.TracingDomain.Kernel, self) # type: ignore
+
+ @property
+ def user_vpid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.VirtualProcessIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.VPID, lttngctl.TracingDomain.User, self) # type: ignore
+
+ @property
+ def kernel_gid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.GroupIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.GID, lttngctl.TracingDomain.Kernel, self) # type: ignore
+
+ @property
+ def kernel_vgid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.VirtualGroupIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.VGID, lttngctl.TracingDomain.Kernel, self) # type: ignore
+
+ @property
+ def user_vgid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.VirtualGroupIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.VGID, lttngctl.TracingDomain.User, self) # type: ignore
+
+ @property
+ def kernel_uid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.UserIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.UID, lttngctl.TracingDomain.Kernel, self) # type: ignore
+
+ @property
+ def kernel_vuid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.VirtualUserIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.VUID, lttngctl.TracingDomain.Kernel, self) # type: ignore
+
+ @property
+ def user_vuid_process_attribute_tracker(
+ self,
+ ) -> Type[lttngctl.VirtualUserIDProcessAttributeTracker]:
+ return _ProcessAttributeTracker(self._client, _ProcessAttribute.VUID, lttngctl.TracingDomain.User, self) # type: ignore
+
+
+class LTTngClientError(lttngctl.ControlException):
+ def __init__(self, command_args: str, error_output: str):
+ self._command_args: str = command_args
+ self._output: str = error_output
+
+
+class LTTngClient(logger._Logger, lttngctl.Controller):
+ """
+ Implementation of a LTTngCtl Controller that uses the `lttng` client as a back-end.
+ """
+
+ def __init__(
+ self,
+ test_environment: environment._Environment,
+ log: Optional[Callable[[str], None]],
+ ):
+ logger._Logger.__init__(self, log)
+ self._environment: environment._Environment = test_environment
+
+ def _run_cmd(self, command_args: str) -> None:
+ """
+ Invoke the `lttng` client with a set of arguments. The command is
+ executed in the context of the client's test environment.
+ """
+ args: list[str] = [str(self._environment.lttng_client_path)]
+ args.extend(shlex.split(command_args))
+
+ self._log("lttng {command_args}".format(command_args=command_args))
+
+ client_env: dict[str, str] = os.environ.copy()
+ client_env["LTTNG_HOME"] = str(self._environment.lttng_home_location)
+
+ process = subprocess.Popen(
+ args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=client_env
+ )
+
+ out = process.communicate()[0]
+
+ if process.returncode != 0:
+ decoded_output = out.decode("utf-8")
+ for error_line in decoded_output.splitlines():
+ self._log(error_line)
+ raise LTTngClientError(command_args, decoded_output)
+
+ def create_session(
+ self,
+ name: Optional[str] = None,
+ output: Optional[lttngctl.SessionOutputLocation] = None,
+ ) -> lttngctl.Session:
+ name = name if name else lttngctl.Session._generate_name()
+
+ if isinstance(output, lttngctl.LocalSessionOutputLocation):
+ output_option = "--output {output_path}".format(output_path=output.path)
+ elif output is None:
+ output_option = "--no-output"
+ else:
+ raise TypeError("LTTngClient only supports local or no output")
+
+ self._run_cmd(
+ "create {session_name} {output_option}".format(
+ session_name=name, output_option=output_option
+ )
+ )
+ return _Session(self, name, output)
--- /dev/null
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022 Jérémie Galarneau <jeremie.galarneau@efficios.com>
+#
+# SPDX-License-Identifier: GPL-2.0-only
+
+import abc
+import random
+import string
+import pathlib
+import enum
+from typing import Optional, Type, Union, List
+
+"""
+Defines an abstract interface to control LTTng tracing.
+
+The various control concepts are defined by this module. You can use them with a
+Controller to interact with a session daemon.
+
+This interface is not comprehensive; it currently provides a subset of the
+control functionality that is used by tests.
+"""
+
+
+def _generate_random_string(length: int) -> str:
+ return "".join(
+ random.choice(string.ascii_lowercase + string.digits) for _ in range(length)
+ )
+
+
+class ContextType(abc.ABC):
+ """Base class representing a tracing context field."""
+
+ pass
+
+
+class VpidContextType(ContextType):
+ """Application's virtual process id."""
+
+ pass
+
+
+class VuidContextType(ContextType):
+ """Application's virtual user id."""
+
+ pass
+
+
+class VgidContextType(ContextType):
+ """Application's virtual group id."""
+
+ pass
+
+
+class JavaApplicationContextType(ContextType):
+ """A java application-specific context field is a piece of state which the application provides."""
+
+ def __init__(self, retriever_name: str, field_name: str):
+ self._retriever_name: str = retriever_name
+ self._field_name: str = field_name
+
+ @property
+ def retriever_name(self) -> str:
+ return self._retriever_name
+
+ @property
+ def field_name(self) -> str:
+ return self._field_name
+
+
+class TracingDomain(enum.Enum):
+ """Tracing domain."""
+
+ User = enum.auto(), "User space tracing domain"
+ Kernel = enum.auto(), "Linux kernel tracing domain."
+ Log4j = enum.auto(), "Log4j tracing back-end."
+ JUL = enum.auto(), "Java Util Logging tracing back-end."
+ Python = enum.auto(), "Python logging module tracing back-end."
+
+
+class EventRule(abc.ABC):
+ """Event rule base class, see LTTNG-EVENT-RULE(7)."""
+
+ pass
+
+
+class LogLevelRule:
+ pass
+
+
+class LogLevelRuleAsSevereAs(LogLevelRule):
+ def __init__(self, level: int):
+ self._level = level
+
+ @property
+ def level(self) -> int:
+ return self._level
+
+
+class LogLevelRuleExactly(LogLevelRule):
+ def __init__(self, level: int):
+ self._level = level
+
+ @property
+ def level(self) -> int:
+ return self._level
+
+
+class TracepointEventRule(EventRule):
+ def __init__(
+ self,
+ name_pattern: Optional[str] = None,
+ filter_expression: Optional[str] = None,
+ log_level_rule: Optional[LogLevelRule] = None,
+ name_pattern_exclusions: Optional[List[str]] = None,
+ ):
+ self._name_pattern: Optional[str] = name_pattern
+ self._filter_expression: Optional[str] = filter_expression
+ self._log_level_rule: Optional[LogLevelRule] = log_level_rule
+ self._name_pattern_exclusions: Optional[List[str]] = name_pattern_exclusions
+
+ @property
+ def name_pattern(self) -> Optional[str]:
+ return self._name_pattern
+
+ @property
+ def filter_expression(self) -> Optional[str]:
+ return self._filter_expression
+
+ @property
+ def log_level_rule(self) -> Optional[LogLevelRule]:
+ return self._log_level_rule
+
+ @property
+ def name_pattern_exclusions(self) -> Optional[List[str]]:
+ return self._name_pattern_exclusions
+
+
+class UserTracepointEventRule(TracepointEventRule):
+ def __init__(
+ self,
+ name_pattern: Optional[str] = None,
+ filter_expression: Optional[str] = None,
+ log_level_rule: Optional[LogLevelRule] = None,
+ name_pattern_exclusions: Optional[List[str]] = None,
+ ):
+ TracepointEventRule.__init__(**locals())
+
+
+class KernelTracepointEventRule(TracepointEventRule):
+ def __init__(
+ self,
+ name_pattern: Optional[str] = None,
+ filter_expression: Optional[str] = None,
+ log_level_rule: Optional[LogLevelRule] = None,
+ name_pattern_exclusions: Optional[List[str]] = None,
+ ):
+ TracepointEventRule.__init__(**locals())
+
+
+class Channel(abc.ABC):
+ """
+ A channel is an object which is responsible for a set of ring buffers. It is
+ associated to a domain and
+ """
+
+ @staticmethod
+ def _generate_name() -> str:
+ return "channel_{random_id}".format(random_id=_generate_random_string(8))
+
+ @abc.abstractmethod
+ def add_context(self, context_type: ContextType) -> None:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def domain(self) -> TracingDomain:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def name(self) -> str:
+ pass
+
+ @abc.abstractmethod
+ def add_recording_rule(self, rule: Type[EventRule]) -> None:
+ pass
+
+
+class SessionOutputLocation(abc.ABC):
+ pass
+
+
+class LocalSessionOutputLocation(SessionOutputLocation):
+ def __init__(self, trace_path: pathlib.Path):
+ self._path = trace_path
+
+ @property
+ def path(self) -> pathlib.Path:
+ return self._path
+
+
+class ProcessAttributeTracker(abc.ABC):
+ """
+ Process attribute tracker used to filter before the evaluation of event
+ rules.
+
+ Note that this interface is currently limited as it doesn't allow changing
+ the tracking policy. For instance, it is not possible to set the tracking
+ policy back to "all" once it has transitioned to "include set".
+ """
+
+ class TrackingPolicy(enum.Enum):
+ INCLUDE_ALL = (
+ enum.auto(),
+ """
+ Track all possible process attribute value of a given type (i.e. no filtering).
+ This is the default state of a process attribute tracker.
+ """,
+ )
+ EXCLUDE_ALL = (
+ enum.auto(),
+ "Exclude all possible process attribute values of a given type.",
+ )
+ INCLUDE_SET = enum.auto(), "Track a set of specific process attribute values."
+
+ def __init__(self, policy: TrackingPolicy):
+ self._policy = policy
+
+ @property
+ def tracking_policy(self) -> TrackingPolicy:
+ return self._policy
+
+
+class ProcessIDProcessAttributeTracker(ProcessAttributeTracker):
+ @abc.abstractmethod
+ def track(self, pid: int) -> None:
+ pass
+
+ @abc.abstractmethod
+ def untrack(self, pid: int) -> None:
+ pass
+
+
+class VirtualProcessIDProcessAttributeTracker(ProcessAttributeTracker):
+ @abc.abstractmethod
+ def track(self, vpid: int) -> None:
+ pass
+
+ @abc.abstractmethod
+ def untrack(self, vpid: int) -> None:
+ pass
+
+
+class UserIDProcessAttributeTracker(ProcessAttributeTracker):
+ @abc.abstractmethod
+ def track(self, uid: Union[int, str]) -> None:
+ pass
+
+ @abc.abstractmethod
+ def untrack(self, uid: Union[int, str]) -> None:
+ pass
+
+
+class VirtualUserIDProcessAttributeTracker(ProcessAttributeTracker):
+ @abc.abstractmethod
+ def track(self, vuid: Union[int, str]) -> None:
+ pass
+
+ @abc.abstractmethod
+ def untrack(self, vuid: Union[int, str]) -> None:
+ pass
+
+
+class GroupIDProcessAttributeTracker(ProcessAttributeTracker):
+ @abc.abstractmethod
+ def track(self, gid: Union[int, str]) -> None:
+ pass
+
+ @abc.abstractmethod
+ def untrack(self, gid: Union[int, str]) -> None:
+ pass
+
+
+class VirtualGroupIDProcessAttributeTracker(ProcessAttributeTracker):
+ @abc.abstractmethod
+ def track(self, vgid: Union[int, str]) -> None:
+ pass
+
+ @abc.abstractmethod
+ def untrack(self, vgid: Union[int, str]) -> None:
+ pass
+
+
+class Session(abc.ABC):
+ @staticmethod
+ def _generate_name() -> str:
+ return "session_{random_id}".format(random_id=_generate_random_string(8))
+
+ @property
+ @abc.abstractmethod
+ def name(self) -> str:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def output(self) -> Optional[Type[SessionOutputLocation]]:
+ pass
+
+ @abc.abstractmethod
+ def add_channel(
+ self, domain: TracingDomain, channel_name: Optional[str] = None
+ ) -> Channel:
+ """Add a channel with default attributes to the session."""
+ pass
+
+ @abc.abstractmethod
+ def start(self) -> None:
+ pass
+
+ @abc.abstractmethod
+ def stop(self) -> None:
+ pass
+
+ @abc.abstractmethod
+ def destroy(self) -> None:
+ pass
+
+ @abc.abstractproperty
+ def kernel_pid_process_attribute_tracker(
+ self,
+ ) -> Type[ProcessIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def kernel_vpid_process_attribute_tracker(
+ self,
+ ) -> Type[VirtualProcessIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def user_vpid_process_attribute_tracker(
+ self,
+ ) -> Type[VirtualProcessIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def kernel_gid_process_attribute_tracker(
+ self,
+ ) -> Type[GroupIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def kernel_vgid_process_attribute_tracker(
+ self,
+ ) -> Type[VirtualGroupIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def user_vgid_process_attribute_tracker(
+ self,
+ ) -> Type[VirtualGroupIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def kernel_uid_process_attribute_tracker(
+ self,
+ ) -> Type[UserIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def kernel_vuid_process_attribute_tracker(
+ self,
+ ) -> Type[VirtualUserIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def user_vuid_process_attribute_tracker(
+ self,
+ ) -> Type[VirtualUserIDProcessAttributeTracker]:
+ raise NotImplementedError
+
+
+class ControlException(RuntimeError):
+ """Base type for exceptions thrown by a controller."""
+
+ def __init__(self, msg: str):
+ super().__init__(msg)
+
+
+class Controller(abc.ABC):
+ """
+ Interface of a top-level control interface. A control interface can be, for
+ example, the LTTng client or a wrapper around liblttng-ctl. It is used to
+ create and manage top-level objects of a session daemon instance.
+ """
+
+ @abc.abstractmethod
+ def create_session(
+ self, name: Optional[str] = None, output: Optional[SessionOutputLocation] = None
+ ) -> Session:
+ """
+ Create a session with an output. Don't specify an output
+ to create a session without an output.
+ """
+ pass
--- /dev/null
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022 Jérémie Galarneau <jeremie.galarneau@efficios.com>
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import contextlib
+import sys
+from typing import Optional
+
+
+class InvalidTestPlan(RuntimeError):
+ def __init__(self, msg: str):
+ super().__init__(msg)
+
+
+class BailOut(RuntimeError):
+ def __init__(self, msg: str):
+ super().__init__(msg)
+
+
+class TestCase:
+ def __init__(self, tap_generator: "TapGenerator", description: str):
+ self._tap_generator = tap_generator
+ self._result: Optional[bool] = None
+ self._description = description
+
+ @property
+ def result(self) -> Optional[bool]:
+ return self._result
+
+ @property
+ def description(self) -> str:
+ return self._description
+
+ def _set_result(self, result: bool) -> None:
+ if self._result is not None:
+ raise RuntimeError("Can't set test case result twice")
+
+ self._result = result
+ self._tap_generator.test(result, self._description)
+
+ def success(self) -> None:
+ self._set_result(True)
+
+ def fail(self) -> None:
+ self._set_result(False)
+
+
+# Produces a test execution report in the TAP format.
+class TapGenerator:
+ def __init__(self, total_test_count: int):
+ if total_test_count <= 0:
+ raise ValueError("Test count must be greater than zero")
+
+ self._total_test_count: int = total_test_count
+ self._last_test_case_id: int = 0
+ self._printed_plan: bool = False
+ self._has_failure: bool = False
+
+ def __del__(self):
+ if self.remaining_test_cases > 0:
+ self.bail_out(
+ "Missing {remaining_test_cases} test cases".format(
+ remaining_test_cases=self.remaining_test_cases
+ )
+ )
+
+ @property
+ def remaining_test_cases(self) -> int:
+ return self._total_test_count - self._last_test_case_id
+
+ def _print(self, msg: str) -> None:
+ if not self._printed_plan:
+ print(
+ "1..{total_test_count}".format(total_test_count=self._total_test_count),
+ flush=True,
+ )
+ self._printed_plan = True
+
+ print(msg, flush=True)
+
+ def skip_all(self, reason) -> None:
+ if self._last_test_case_id != 0:
+ raise RuntimeError("Can't skip all tests after running test cases")
+
+ if reason:
+ self._print("1..0 # Skip all: {reason}".format(reason=reason))
+
+ self._last_test_case_id = self._total_test_count
+
+ def skip(self, reason, skip_count: int = 1) -> None:
+ for i in range(skip_count):
+ self._last_test_case_id = self._last_test_case_id + 1
+ self._print(
+ "ok {test_number} # Skip: {reason}".format(
+ reason=reason, test_number=(i + self._last_test_case_id)
+ )
+ )
+
+ def bail_out(self, reason: str) -> None:
+ self._print("Bail out! {reason}".format(reason=reason))
+ self._last_test_case_id = self._total_test_count
+ raise BailOut(reason)
+
+ def test(self, result: bool, description: str) -> None:
+ if self._last_test_case_id == self._total_test_count:
+ raise InvalidTestPlan("Executing too many tests")
+
+ if result is False:
+ self._has_failure = True
+
+ result_string = "ok" if result else "not ok"
+ self._last_test_case_id = self._last_test_case_id + 1
+ self._print(
+ "{result_string} {case_id} - {description}".format(
+ result_string=result_string,
+ case_id=self._last_test_case_id,
+ description=description,
+ )
+ )
+
+ def ok(self, description: str) -> None:
+ self.test(True, description)
+
+ def fail(self, description: str) -> None:
+ self.test(False, description)
+
+ @property
+ def is_successful(self) -> bool:
+ return (
+ self._last_test_case_id == self._total_test_count and not self._has_failure
+ )
+
+ @contextlib.contextmanager
+ def case(self, description: str):
+ test_case = TestCase(self, description)
+ try:
+ yield test_case
+ except Exception as e:
+ self.diagnostic(
+ "Exception `{exception_type}` thrown during test case `{description}`, marking as failure.".format(
+ description=test_case.description, exception_type=type(e).__name__
+ )
+ )
+
+ if str(e) != "":
+ self.diagnostic(str(e))
+
+ test_case.fail()
+ finally:
+ if test_case.result is None:
+ test_case.success()
+
+ def diagnostic(self, msg) -> None:
+ print("# {msg}".format(msg=msg), file=sys.stderr, flush=True)