Source code for schemathesis.stateful.state_machine

from __future__ import annotations

import re
import time
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Any, ClassVar

from hypothesis.errors import InvalidDefinition
from hypothesis.stateful import RuleBasedStateMachine

from .._dependency_versions import HYPOTHESIS_HAS_STATEFUL_NAMING_IMPROVEMENTS
from ..constants import NO_LINKS_ERROR_MESSAGE, NOT_SET
from ..exceptions import UsageError
from ..internal.checks import CheckFunction
from ..models import APIOperation, Case
from .config import _default_hypothesis_settings_factory
from .runner import StatefulTestRunner, StatefulTestRunnerConfig
from .sink import StateMachineSink

if TYPE_CHECKING:
    import hypothesis
    from requests.structures import CaseInsensitiveDict

    from ..schemas import BaseSchema
    from ..transports.responses import GenericResponse
    from .statistic import TransitionStats


@dataclass
class StepResult:
    """Output from a single transition of a state machine."""

    response: GenericResponse
    case: Case
    elapsed: float


def _normalize_name(name: str) -> str:
    return re.sub(r"\W|^(?=\d)", "_", name).replace("__", "_")


[docs]class APIStateMachine(RuleBasedStateMachine): """The base class for state machines generated from API schemas. Exposes additional extension points in the testing process. """ # This is a convenience attribute, which happened to clash with `RuleBasedStateMachine` instance level attribute # They don't interfere, since it is properly overridden on the Hypothesis side, but it is likely that this # attribute will be renamed in the future bundles: ClassVar[dict[str, CaseInsensitiveDict]] # type: ignore schema: BaseSchema # A template for transition statistics that can be filled with data from the state machine during its execution _transition_stats_template: ClassVar[TransitionStats] def __init__(self) -> None: try: super().__init__() # type: ignore except InvalidDefinition as exc: if "defines no rules" in str(exc): raise UsageError(NO_LINKS_ERROR_MESSAGE) from None raise self.setup() @classmethod @lru_cache def _to_test_case(cls) -> type: from . import run_state_machine_as_test class StateMachineTestCase(RuleBasedStateMachine.TestCase): settings = _default_hypothesis_settings_factory() def runTest(self) -> None: run_state_machine_as_test(cls, settings=self.settings) runTest.is_hypothesis_test = True # type: ignore[attr-defined] StateMachineTestCase.__name__ = cls.__name__ + ".TestCase" StateMachineTestCase.__qualname__ = cls.__qualname__ + ".TestCase" return StateMachineTestCase def _pretty_print(self, value: Any) -> str: if isinstance(value, Case): # State machines suppose to be reproducible, hence it is OK to get kwargs here kwargs = self.get_call_kwargs(value) return _print_case(value, kwargs) return super()._pretty_print(value) # type: ignore if HYPOTHESIS_HAS_STATEFUL_NAMING_IMPROVEMENTS: def _new_name(self, target: str) -> str: target = _normalize_name(target) return super()._new_name(target) # type: ignore def _get_target_for_result(self, result: StepResult) -> str | None: raise NotImplementedError def _add_result_to_targets(self, targets: tuple[str, ...], result: StepResult | None) -> None: if result is None: return target = self._get_target_for_result(result) if target is not None: super()._add_result_to_targets((target,), result) @classmethod def format_rules(cls) -> str: raise NotImplementedError @classmethod def run(cls, *, settings: hypothesis.settings | None = None) -> None: """Run state machine as a test.""" from . import run_state_machine_as_test return run_state_machine_as_test(cls, settings=settings) @classmethod def runner(cls, *, config: StatefulTestRunnerConfig | None = None) -> StatefulTestRunner: """Create a runner for this state machine.""" from .runner import StatefulTestRunnerConfig return StatefulTestRunner(cls, config=config or StatefulTestRunnerConfig()) @classmethod def sink(cls) -> StateMachineSink: """Create a sink to collect events into.""" return StateMachineSink(transitions=cls._transition_stats_template.copy())
[docs] def setup(self) -> None: """Hook method that runs unconditionally in the beginning of each test scenario. Does nothing by default. """
[docs] def teardown(self) -> None: pass
# To provide the return type in the rendered documentation teardown.__doc__ = RuleBasedStateMachine.teardown.__doc__ def transform(self, result: StepResult, direction: Direction, case: Case) -> Case: raise NotImplementedError def _step(self, case: Case, previous: StepResult | None = None, link: Direction | None = None) -> StepResult | None: # This method is a proxy that is used under the hood during the state machine initialization. # The whole point of having it is to make it possible to override `step`; otherwise, custom "step" is ignored. # It happens because, at the point of initialization, the final class is not yet created. __tracebackhide__ = True if previous is not None and link is not None: return self.step(case, (previous, link)) return self.step(case, None) def step(self, case: Case, previous: tuple[StepResult, Direction] | None = None) -> StepResult | None: """A single state machine step. :param Case case: Generated test case data that should be sent in an API call to the tested API operation. :param previous: Optional result from the previous step and the direction in which this step should be done. Schemathesis prepares data, makes a call and validates the received response. It is the most high-level point to extend the testing process. You probably don't need it in most cases. """ from ..specs.openapi.checks import use_after_free __tracebackhide__ = True if previous is not None: result, direction = previous case = self.transform(result, direction, case) self.before_call(case) kwargs = self.get_call_kwargs(case) start = time.monotonic() response = self.call(case, **kwargs) elapsed = time.monotonic() - start self.after_call(response, case) self.validate_response(response, case, additional_checks=(use_after_free,)) return self.store_result(response, case, elapsed)
[docs] def before_call(self, case: Case) -> None: """Hook method for modifying the case data before making a request. :param Case case: Generated test case data that should be sent in an API call to the tested API operation. Use it if you want to inject static data, for example, a query parameter that should always be used in API calls: .. code-block:: python class APIWorkflow(schema.as_state_machine()): def before_call(self, case): case.query = case.query or {} case.query["test"] = "true" You can also modify data only for some operations: .. code-block:: python class APIWorkflow(schema.as_state_machine()): def before_call(self, case): if case.method == "PUT" and case.path == "/items": case.body["is_fake"] = True """
[docs] def after_call(self, response: GenericResponse, case: Case) -> None: """Hook method for additional actions with case or response instances. :param response: Response from the application under test. :param Case case: Generated test case data that should be sent in an API call to the tested API operation. For example, you can log all response statuses by using this hook: .. code-block:: python import logging logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) class APIWorkflow(schema.as_state_machine()): def after_call(self, response, case): logger.info( "%s %s -> %d", case.method, case.path, response.status_code, ) # POST /users/ -> 201 # GET /users/{user_id} -> 200 # PATCH /users/{user_id} -> 200 # GET /users/{user_id} -> 200 # PATCH /users/{user_id} -> 500 """
[docs] def call(self, case: Case, **kwargs: Any) -> GenericResponse: """Make a request to the API. :param Case case: Generated test case data that should be sent in an API call to the tested API operation. :param kwargs: Keyword arguments that will be passed to the appropriate ``case.call_*`` method. :return: Response from the application under test. Note that WSGI/ASGI applications are detected automatically in this method. Depending on the result of this detection the state machine will call the ``call`` method. Usually, you don't need to override this method unless you are building a different state machine on top of this one and want to customize the transport layer itself. """ return case.call(**kwargs)
[docs] def get_call_kwargs(self, case: Case) -> dict[str, Any]: """Create custom keyword arguments that will be passed to the :meth:`Case.call` method. Mostly they are proxied to the :func:`requests.request` call. :param Case case: Generated test case data that should be sent in an API call to the tested API operation. .. code-block:: python class APIWorkflow(schema.as_state_machine()): def get_call_kwargs(self, case): return {"verify": False} The above example disables the server's TLS certificate verification. """ return {}
[docs] def validate_response( self, response: GenericResponse, case: Case, additional_checks: tuple[CheckFunction, ...] = () ) -> None: """Validate an API response. :param response: Response from the application under test. :param Case case: Generated test case data that should be sent in an API call to the tested API operation. :param additional_checks: A list of checks that will be run together with the default ones. :raises CheckFailed: If any of the supplied checks failed. If you need to change the default checks or provide custom validation rules, you can do it here. .. code-block:: python def my_check(response, case): ... # some assertions class APIWorkflow(schema.as_state_machine()): def validate_response(self, response, case): case.validate_response(response, checks=(my_check,)) The state machine from the example above will execute only the ``my_check`` check instead of all available checks. Each check function should accept ``response`` as the first argument and ``case`` as the second one and raise ``AssertionError`` if the check fails. **Note** that it is preferred to pass check functions as an argument to ``case.validate_response``. In this case, all checks will be executed, and you'll receive a grouped exception that contains results from all provided checks rather than only the first encountered exception. """ __tracebackhide__ = True case.validate_response(response, additional_checks=additional_checks)
def store_result(self, response: GenericResponse, case: Case, elapsed: float) -> StepResult: return StepResult(response, case, elapsed)
def _print_case(case: Case, kwargs: dict[str, Any]) -> str: from requests.structures import CaseInsensitiveDict operation = f"state.schema['{case.operation.path}']['{case.operation.method.upper()}']" headers = case.headers or CaseInsensitiveDict() headers.update(kwargs.get("headers", {})) case.headers = headers data = [ f"{name}={getattr(case, name)!r}" for name in ("path_parameters", "headers", "cookies", "query", "body", "media_type") if getattr(case, name) not in (None, NOT_SET) ] return f"{operation}.make_case({', '.join(data)})" class Direction: name: str status_code: str operation: APIOperation def set_data(self, case: Case, elapsed: float, **kwargs: Any) -> None: raise NotImplementedError