from __future__ import annotations
import datetime
import inspect
import textwrap
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache, partial
from itertools import chain
from logging import LogRecord
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Generic,
Iterator,
NoReturn,
Optional,
Sequence,
Type,
TypeVar,
cast,
)
from urllib.parse import quote, unquote, urljoin, urlparse, urlsplit, urlunsplit
from . import serializers
from ._dependency_versions import IS_WERKZEUG_ABOVE_3
from .auths import AuthStorage
from .code_samples import CodeSampleStyle
from .constants import (
NOT_SET,
SCHEMATHESIS_TEST_CASE_HEADER,
SERIALIZERS_SUGGESTION_MESSAGE,
USER_AGENT,
)
from .exceptions import (
CheckFailed,
FailureContext,
OperationSchemaError,
SerializationNotPossible,
SkipTest,
deduplicate_failed_checks,
get_grouped_exception,
maybe_set_assertion_message,
prepare_response_payload,
)
from .generation import DataGenerationMethod, GenerationConfig, generate_random_case_id
from .hooks import GLOBAL_HOOK_DISPATCHER, HookContext, HookDispatcher, dispatch
from .internal.copy import fast_deepcopy
from .internal.deprecation import deprecated_property, deprecated_function
from .parameters import Parameter, ParameterSet, PayloadAlternatives
from .sanitization import sanitize_request, sanitize_response
from .serializers import Serializer
from .transports import ASGITransport, RequestsTransport, WSGITransport, serialize_payload
from .types import Body, Cookies, FormData, Headers, NotSet, PathParameters, Query
if TYPE_CHECKING:
import unittest
import requests.auth
import werkzeug
from hypothesis import strategies as st
from requests.structures import CaseInsensitiveDict
from .schemas import BaseSchema
from .stateful import Stateful, StatefulTest
from .transports.responses import GenericResponse, WSGIResponse
@dataclass
class CaseSource:
    """Data sources a test case was generated from."""

    # The test case stored as the source
    case: Case
    # The response stored together with `case`
    response: GenericResponse
    # Elapsed time value supplied by the caller
    elapsed: float

    def partial_deepcopy(self) -> CaseSource:
        """Return a new instance whose case is partially deep-copied.

        The response object and the elapsed value are passed through unchanged.
        """
        copied_case = self.case.partial_deepcopy()
        return self.__class__(case=copied_case, response=self.response, elapsed=self.elapsed)
def cant_serialize(media_type: str) -> NoReturn:  # type: ignore
    """Reject the current Hypothesis example because ``media_type`` can't be serialized.

    A note (with a suggestion about serializers) and an event are recorded before the example is rejected.
    """
    import hypothesis

    reason = f"Can't serialize data to `{media_type}`."
    hypothesis.note(f"{reason} {SERIALIZERS_SUGGESTION_MESSAGE}")
    hypothesis.event(reason)
    hypothesis.reject()  # type: ignore
@lru_cache
def get_request_signature() -> inspect.Signature:
    """Return the call signature of ``requests.Request``.

    The result is memoized by ``lru_cache``; ``requests`` is imported on the first call only.
    """
    from requests import Request

    signature = inspect.signature(Request)
    return signature
@dataclass
class PreparedRequestData:
    """Plain request components taken from a prepared request."""

    # HTTP method
    method: str
    # Request URL
    url: str
    # Request payload, if there is any
    body: str | bytes | None
    # Request headers
    headers: Headers
def prepare_request_data(kwargs: dict[str, Any]) -> PreparedRequestData:
    """Prepare request data for generating code samples.

    Keys that are not parameters of ``requests.Request`` are dropped before the request is prepared.
    """
    import requests

    accepted = get_request_signature().parameters
    request_kwargs = {}
    for name, value in kwargs.items():
        if name in accepted:
            request_kwargs[name] = value
    prepared = requests.Request(**request_kwargs).prepare()
    return PreparedRequestData(
        method=str(prepared.method),
        url=str(prepared.url),
        body=prepared.body,
        headers=dict(prepared.headers),
    )
@dataclass(repr=False)
class Case:
    """Parameters of a single test case for an API operation."""

    operation: APIOperation
    # How long it took to generate this test case
    generation_time: float
    # Unique test case identifier; excluded from comparisons
    id: str = field(default_factory=generate_random_case_id, compare=False)
    path_parameters: PathParameters | None = None
    headers: CaseInsensitiveDict | None = None
    cookies: Cookies | None = None
    query: Query | None = None
    # `None` can't be the "no body" default because it clashes with `null`, which is a valid payload.
    body: Body | NotSet = NOT_SET
    # Media type of the payload, e.g. "application/json"
    media_type: str | None = None
    source: CaseSource | None = None
    # How the case was generated (None for manually crafted ones)
    data_generation_method: DataGenerationMethod | None = None
    _auth: requests.auth.AuthBase | None = None

    def __repr__(self) -> str:
        """Render only the components that are actually set."""
        rendered = []
        for name in ("path_parameters", "headers", "cookies", "query", "body"):
            value = getattr(self, name)
            if value is None or isinstance(value, NotSet):
                continue
            rendered.append(f"{name}={value!r}")
        return f"{self.__class__.__name__}({', '.join(rendered)})"

    def __hash__(self) -> int:
        # The test case header is pinned to a constant value here.
        # NOTE(review): presumably this keeps the per-case `id` out of the hashed command - confirm.
        curl_command = self.as_curl_command({SCHEMATHESIS_TEST_CASE_HEADER: "0"})
        return hash(curl_command)

    @deprecated_property(removed_in="4.0", replacement="operation")
    def endpoint(self) -> APIOperation:
        """Deprecated alias for ``operation``."""
        return self.operation

    @property
    def path(self) -> str:
        """Path of the underlying operation."""
        return self.operation.path

    @property
    def full_path(self) -> str:
        """Full path of the underlying operation."""
        return self.operation.full_path

    @property
    def method(self) -> str:
        """Upper-cased method of the underlying operation."""
        return self.operation.method.upper()

    @property
    def base_url(self) -> str | None:
        """Base URL of the underlying operation."""
        return self.operation.base_url

    @property
    def app(self) -> Any:
        """Application object of the underlying operation."""
        return self.operation.app

    def set_source(self, response: GenericResponse, case: Case, elapsed: float) -> None:
        """Remember the case / response pair this case originates from."""
        self.source = CaseSource(case=case, response=response, elapsed=elapsed)

    @property
    def formatted_path(self) -> str:
        """Path template with path parameters substituted.

        :raises OperationSchemaError: If a placeholder has no matching parameter or the template is malformed.
        """
        try:
            substitutions = self.path_parameters or {}
            return self.path.format(**substitutions)
        except KeyError as exc:
            # The template has a placeholder for variable "X", but parameter "X" is not defined
            # in the parameters list. Formatting `exc` gives the missing key name in quotes, e.g. 'id'
            raise OperationSchemaError(f"Path parameter {exc} is not defined") from exc
        except (IndexError, ValueError) as exc:
            # A single unmatched `}` inside the path template may cause this
            raise OperationSchemaError(f"Malformed path template: `{self.path}`\n\n {exc}") from exc

    def get_full_base_url(self) -> str | None:
        """Create a full base url, adding "localhost" for WSGI apps."""
        split_url = urlsplit(self.base_url)
        if split_url.hostname:
            return self.base_url
        path = cast(str, split_url.path or "")
        return urlunsplit(("http", "localhost", path or "", "", ""))

    def prepare_code_sample_data(self, headers: dict[str, Any] | None) -> PreparedRequestData:
        """Serialize the case via the ``requests`` transport and extract data for code samples."""
        full_base_url = self.get_full_base_url()
        request_kwargs = RequestsTransport().serialize_case(self, base_url=full_base_url, headers=headers)
        return prepare_request_data(request_kwargs)

    def get_code_to_reproduce(
        self,
        headers: dict[str, Any] | None = None,
        request: requests.PreparedRequest | None = None,
        verify: bool = True,
    ) -> str:
        """Construct a Python code to reproduce this case with `requests`."""
        if request is None:
            request_data = self.prepare_code_sample_data(headers)
        else:
            # An already prepared request takes precedence over the `headers` argument
            request_data = prepare_request_data(
                {
                    "method": request.method,
                    "url": request.url,
                    "headers": request.headers,
                    "data": request.body,
                }
            )
        own_headers = None if self.headers is None else dict(self.headers)
        return CodeSampleStyle.python.generate(
            method=request_data.method,
            url=request_data.url,
            body=request_data.body,
            headers=own_headers,
            verify=verify,
            extra_headers=request_data.headers,
        )

    def as_curl_command(self, headers: dict[str, Any] | None = None, verify: bool = True) -> str:
        """Construct a curl command for a given case."""
        request_data = self.prepare_code_sample_data(headers)
        own_headers = None if self.headers is None else dict(self.headers)
        return CodeSampleStyle.curl.generate(
            method=request_data.method,
            url=request_data.url,
            body=request_data.body,
            headers=own_headers,
            verify=verify,
            extra_headers=request_data.headers,
        )

    def _get_base_url(self, base_url: str | None = None) -> str:
        """Return the explicit ``base_url`` or fall back to the operation's one.

        :raises ValueError: If neither of them is available.
        """
        if base_url is not None:
            return base_url
        if self.base_url is None:
            raise ValueError(
                "Base URL is required as `base_url` argument in `call` or should be specified "
                "in the schema constructor as a part of Schema URL."
            )
        return self.base_url

    def _get_headers(self, headers: dict[str, str] | None = None) -> CaseInsensitiveDict:
        """Merge the case headers with ``headers`` and fill in the default ones."""
        from requests.structures import CaseInsensitiveDict

        if self.headers is None:
            merged = CaseInsensitiveDict()
        else:
            # Work on a copy, so the case's own headers stay untouched
            merged = self.headers.copy()
        if headers:
            merged.update(headers)
        merged.setdefault("User-Agent", USER_AGENT)
        merged.setdefault(SCHEMATHESIS_TEST_CASE_HEADER, self.id)
        return merged

    def _get_serializer(self) -> Serializer | None:
        """Get a serializer for the payload, if there is any."""
        if self.media_type is None:
            return None
        matched = serializers.get_first_matching_media_type(self.media_type)
        if matched is None:
            # This media type is set manually. Otherwise, it should have been rejected during the data generation
            raise SerializationNotPossible.for_media_type(self.media_type)
        # SAFETY: It is safe to assume that serializer will be found, because `matched` returned above
        # is registered. This intentionally ignores cases with concurrent serializers registry modification.
        serializer_cls = cast(Type[serializers.Serializer], serializers.get(matched))
        return serializer_cls()

    def _get_body(self) -> Body | NotSet:
        """Return the payload as is."""
        return self.body

    @deprecated_function(removed_in="4.0", replacement="Case.as_transport_kwargs")
    def as_requests_kwargs(self, base_url: str | None = None, headers: dict[str, str] | None = None) -> dict[str, Any]:
        """Convert the case into a dictionary acceptable by requests."""
        transport = RequestsTransport()
        return transport.serialize_case(self, base_url=base_url, headers=headers)

    def as_transport_kwargs(self, base_url: str | None = None, headers: dict[str, str] | None = None) -> dict[str, Any]:
        """Convert the test case into a dictionary acceptable by the underlying transport call."""
        transport = self.operation.schema.transport
        return transport.serialize_case(self, base_url=base_url, headers=headers)

    def call(
        self,
        base_url: str | None = None,
        session: requests.Session | None = None,
        headers: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
        cookies: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> GenericResponse:
        """Send the case via the schema's transport, dispatching `before_call` / `after_call` hooks around it."""
        context = HookContext(operation=self.operation)
        dispatch("before_call", context, self)
        # The transport is looked up only after the `before_call` hooks have run
        transport = self.operation.schema.transport
        response = transport.send(
            self, session=session, base_url=base_url, headers=headers, params=params, cookies=cookies, **kwargs
        )
        dispatch("after_call", context, self, response)
        return response

    @deprecated_function(removed_in="4.0", replacement="Case.as_transport_kwargs")
    def as_werkzeug_kwargs(self, headers: dict[str, str] | None = None) -> dict[str, Any]:
        """Convert the case into a dictionary acceptable by werkzeug.Client."""
        transport = WSGITransport(self.app)
        return transport.serialize_case(self, headers=headers)

    @deprecated_function(removed_in="4.0", replacement="Case.call")
    def call_wsgi(
        self,
        app: Any = None,
        headers: dict[str, str] | None = None,
        query_string: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> WSGIResponse:
        """Send the case to a WSGI application (the explicit ``app`` or the one from the operation)."""
        application = app or self.app
        if application is None:
            raise RuntimeError(
                "WSGI application instance is required. "
                "Please, set `app` argument in the schema constructor or pass it to `call_wsgi`"
            )
        context = HookContext(operation=self.operation)
        dispatch("before_call", context, self)
        transport = WSGITransport(application)
        response = transport.send(self, headers=headers, params=query_string, **kwargs)
        dispatch("after_call", context, self, response)
        return response

    @deprecated_function(removed_in="4.0", replacement="Case.call")
    def call_asgi(
        self,
        app: Any = None,
        base_url: str | None = None,
        headers: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> requests.Response:
        """Send the case to an ASGI application (the explicit ``app`` or the one from the operation)."""
        application = app or self.app
        if application is None:
            raise RuntimeError(
                "ASGI application instance is required. "
                "Please, set `app` argument in the schema constructor or pass it to `call_asgi`"
            )
        context = HookContext(operation=self.operation)
        dispatch("before_call", context, self)
        transport = ASGITransport(application)
        response = transport.send(self, base_url=base_url, headers=headers, **kwargs)
        dispatch("after_call", context, self, response)
        return response

    def validate_response(
        self,
        response: GenericResponse,
        checks: tuple[CheckFunction, ...] = (),
        additional_checks: tuple[CheckFunction, ...] = (),
        excluded_checks: tuple[CheckFunction, ...] = (),
        code_sample_style: str | None = None,
    ) -> None:
        """Run checks against the response and raise one grouped error if any of them fails.

        When ``checks`` is empty, all available checks are used.

        :param response: Application response.
        :param checks: A tuple of check functions that accept ``response`` and ``case``.
        :param additional_checks: A tuple of additional checks that will be executed after ones from the ``checks``
            argument.
        :param excluded_checks: Checks removed from both ``checks`` and ``additional_checks``.
        :param code_sample_style: Controls the style of code samples for failure reproduction.
        """
        __tracebackhide__ = True
        from .checks import ALL_CHECKS
        from .transports.responses import get_payload, get_reason

        checks = checks or ALL_CHECKS
        selected = tuple(check for check in checks if check not in excluded_checks)
        extra = tuple(check for check in additional_checks if check not in excluded_checks)
        failures = []
        for check in chain(selected, extra):
            # Every check receives its own copy of the case
            case_copy = self.partial_deepcopy()
            try:
                check(response, case_copy)
            except AssertionError as exc:
                maybe_set_assertion_message(exc, check.__name__)
                failures.append(exc)
        failures = list(deduplicate_failed_checks(failures))
        if not failures:
            return
        exception_cls = get_grouped_exception(self.operation.verbose_name, *failures)
        # Pieces of the final message; joined once at the end
        chunks = []
        for idx, failed in enumerate(failures, 1):
            if isinstance(failed, CheckFailed) and failed.context is not None:
                title = failed.context.title
                # An empty message is treated the same way as a missing one
                message = failed.context.message or None
            else:
                title, message = failed.args
            chunks.append(f"\n\n{idx}. {title}")
            if message is not None:
                chunks.append("\n\n" + textwrap.indent(message, prefix=" "))

        status_code = response.status_code
        reason = get_reason(status_code)
        chunks.append(f"\n\n[{response.status_code}] {reason}:")
        payload = get_payload(response)
        if payload:
            prepared_payload = prepare_response_payload(payload)
            indented_payload = textwrap.indent(f"\n`{prepared_payload}`", prefix=" ")
            chunks.append(f"\n{indented_payload}")
        else:
            chunks.append("\n\n <EMPTY>")
        formatted = "".join(chunks)

        if code_sample_style is not None:
            code_sample_style = CodeSampleStyle.from_str(code_sample_style)
        else:
            code_sample_style = self.operation.schema.code_sample_style
        verify = getattr(response, "verify", True)
        if self.operation.schema.sanitize_output:
            # Sanitization happens before the code sample is built from the request
            sanitize_request(response.request)
            sanitize_response(response)
        code_message = self._get_code_message(code_sample_style, response.request, verify=verify)
        raise exception_cls(
            f"{formatted}\n\n" f"{code_message}",
            causes=tuple(failures),
        )

    def _get_code_message(
        self, code_sample_style: CodeSampleStyle, request: requests.PreparedRequest, verify: bool
    ) -> str:
        """Build the "Reproduce with" part of the failure message in the given style.

        :raises ValueError: If the style is neither ``python`` nor ``curl``.
        """
        if code_sample_style == CodeSampleStyle.python:
            code = self.get_code_to_reproduce(request=request, verify=verify)
            return f"Reproduce with: \n\n {code}\n"
        if code_sample_style == CodeSampleStyle.curl:
            code = self.as_curl_command(headers=dict(request.headers), verify=verify)
            return f"Reproduce with: \n\n {code}\n"
        raise ValueError(f"Unknown code sample style: {code_sample_style.name}")

    def call_and_validate(
        self,
        base_url: str | None = None,
        session: requests.Session | None = None,
        headers: dict[str, Any] | None = None,
        checks: tuple[CheckFunction, ...] = (),
        code_sample_style: str | None = None,
        **kwargs: Any,
    ) -> requests.Response:
        """Call the API and validate the received response; the response is returned."""
        __tracebackhide__ = True
        received = self.call(base_url, session, headers, **kwargs)
        self.validate_response(received, checks, code_sample_style=code_sample_style)
        return received

    def _get_url(self, base_url: str | None) -> str:
        """Join the base URL with the formatted path."""
        root = self._get_base_url(base_url)
        relative = self.formatted_path.lstrip("/")
        # With a trailing slash on the base and no leading slash on the path,
        # `urljoin` appends the path instead of replacing the last base segment
        if not root.endswith("/"):
            root += "/"
        joined = urljoin(root, quote(relative))
        return unquote(joined)

    def get_full_url(self) -> str:
        """Make a full URL to the current API operation, including query parameters."""
        import requests

        root = self.base_url or "http://127.0.0.1"
        request_kwargs = RequestsTransport().serialize_case(self, base_url=root)
        unprepared = requests.Request(**request_kwargs)
        prepared = requests.Session().prepare_request(unprepared)  # type: ignore
        return cast(str, prepared.url)

    def partial_deepcopy(self) -> Case:
        """Copy the case: components are deep-copied, the operation and the source are partially copied.

        The new instance is built only from the arguments listed below.
        """
        source = self.source
        if source is not None:
            source = source.partial_deepcopy()
        return self.__class__(
            operation=self.operation.partial_deepcopy(),
            data_generation_method=self.data_generation_method,
            media_type=self.media_type,
            source=source,
            path_parameters=fast_deepcopy(self.path_parameters),
            headers=fast_deepcopy(self.headers),
            cookies=fast_deepcopy(self.cookies),
            query=fast_deepcopy(self.query),
            body=fast_deepcopy(self.body),
            generation_time=self.generation_time,
        )
def _merge_dict_to(data: dict[str, Any], data_key: str, new: dict[str, Any]) -> None:
original = data[data_key] or {}
for key, value in new.items():
original[key] = value
data[data_key] = original
def validate_vanilla_requests_kwargs(data: dict[str, Any]) -> None:
    """Check arguments for `requests.Session.request`.

    Some arguments can be valid for cases like ASGI integration, but they won't work for the regular
    `requests` calls. Failing early here avoids an obscure error message coming from `requests`.

    :raises RuntimeError: If ``data["url"]`` has no network location, i.e. it is not absolute.
    """
    url = data["url"]
    if urlparse(url).netloc:
        return
    raise RuntimeError(
        "The URL should be absolute, so Schemathesis knows where to send the data. \n"
        f"If you use the ASGI integration, please supply your test client "
        f"as the `session` argument to `call`.\nURL: {url}"
    )
@contextmanager
def cookie_handler(client: werkzeug.Client, cookies: Cookies | None) -> Generator[None, None, None]:
    """Set cookies required for a call and remove them afterwards.

    The cookies are set on ``client`` for the "localhost" domain before the ``with`` body runs and are
    deleted when it finishes, including when the body raises.

    :param client: Werkzeug test client to set the cookies on.
    :param cookies: Cookies to set; nothing is done when they are empty or ``None``.
    """
    if not cookies:
        yield
        return
    for key, value in cookies.items():
        # The `set_cookie` signature differs between Werkzeug versions
        if IS_WERKZEUG_ABOVE_3:
            client.set_cookie(key=key, value=value, domain="localhost")
        else:
            client.set_cookie("localhost", key=key, value=value)
    try:
        yield
    finally:
        # Clean up even if the wrapped call failed, so cookies don't leak into later calls on the same client
        for key in cookies:
            if IS_WERKZEUG_ABOVE_3:
                client.delete_cookie(key=key, domain="localhost")
            else:
                client.delete_cookie("localhost", key=key)
# Type variable for operation parameters; bound to `Parameter`
P = TypeVar("P", bound=Parameter)
# Type variable for raw / resolved operation definitions; bound to `dict`
D = TypeVar("D", bound=dict)
@dataclass
class OperationDefinition(Generic[P, D]):
    """Container for API operation definitions that are stored without resolving references.

    Definitions are kept unresolved to prevent recursion errors. An operation definition itself may sit behind
    a reference (a ``$ref`` in ``paths`` values), so the scope change is stored as well in order to resolve
    references properly later.
    """

    # Definition as it is written in the schema
    raw: D
    # Definition used for all lookups on this wrapper
    resolved: D
    # Resolution scope stored together with the definition
    scope: str
    parameters: Sequence[P]

    def __contains__(self, item: str | int) -> bool:
        """Check membership in the resolved definition."""
        resolved = self.resolved
        return item in resolved

    def __getitem__(self, item: str | int) -> None | bool | float | str | list | dict[str, Any]:
        """Look ``item`` up in the resolved definition."""
        resolved = self.resolved
        return resolved[item]

    def get(self, item: str | int, default: Any = None) -> None | bool | float | str | list | dict[str, Any]:
        """Look ``item`` up in the resolved definition, returning ``default`` if it is missing."""
        resolved = self.resolved
        return resolved.get(item, default)
# Type variable for the case class an operation creates; bound to `Case`
C = TypeVar("C", bound=Case)
@dataclass(eq=False)
class APIOperation(Generic[P, C]):
    """A single operation defined in an API.

    You can get one via a ``schema`` instance.

    .. code-block:: python

        # Get the POST /items operation
        operation = schema["/items"]["POST"]

    """

    # `path` does not contain `basePath`
    # Example <scheme>://<host>/<basePath>/users - "/users" is path
    # https://swagger.io/docs/specification/2-0/api-host-and-base-path/
    path: str
    method: str
    definition: OperationDefinition = field(repr=False)
    schema: BaseSchema
    # Filled in `__post_init__` when not passed explicitly
    verbose_name: str = None  # type: ignore
    app: Any = None
    base_url: str | None = None
    path_parameters: ParameterSet[P] = field(default_factory=ParameterSet)
    headers: ParameterSet[P] = field(default_factory=ParameterSet)
    cookies: ParameterSet[P] = field(default_factory=ParameterSet)
    query: ParameterSet[P] = field(default_factory=ParameterSet)
    body: PayloadAlternatives[P] = field(default_factory=PayloadAlternatives)
    # Class used to create test cases for this operation
    case_cls: type[C] = Case  # type: ignore

    def __post_init__(self) -> None:
        """Default ``verbose_name`` to ``"<METHOD> <full path>"``."""
        if self.verbose_name is None:
            self.verbose_name = f"{self.method.upper()} {self.full_path}"  # type: ignore

    @property
    def full_path(self) -> str:
        """Full path of this operation as computed by the schema."""
        return self.schema.get_full_path(self.path)

    @property
    def links(self) -> dict[str, dict[str, Any]]:
        """Links of this operation as provided by the schema."""
        return self.schema.get_links(self)

    @property
    def tags(self) -> list[str] | None:
        """Tags of this operation as provided by the schema."""
        return self.schema.get_tags(self)

    def iter_parameters(self) -> Iterator[P]:
        """Iterate over all operation's parameters (the body is not included)."""
        return chain(self.path_parameters, self.headers, self.cookies, self.query)

    def _lookup_container(self, location: str) -> ParameterSet[P] | PayloadAlternatives[P] | None:
        """Return the container that stores parameters for ``location`` or ``None`` if it is unknown."""
        return {
            "path": self.path_parameters,
            "header": self.headers,
            "cookie": self.cookies,
            "query": self.query,
            "body": self.body,
        }.get(location)

    def add_parameter(self, parameter: P) -> None:
        """Add a new processed parameter to an API operation.

        :param parameter: A parameter that will be used with this operation.
        :rtype: None
        """
        # If the parameter has a typo, then by default, there will be an error from `jsonschema` earlier.
        # But if the user wants to skip schema validation, we choose to ignore a malformed parameter.
        # In this case, we still might generate some tests for an API operation, but without this parameter,
        # which is better than skip the whole operation from testing.
        container = self._lookup_container(parameter.location)
        if container is not None:
            container.add(parameter)

    def get_parameter(self, name: str, location: str) -> P | None:
        """Find a parameter by its name within ``location``; ``None`` if the location is unknown."""
        container = self._lookup_container(location)
        if container is not None:
            return container.get(name)
        return None

    def as_strategy(
        self,
        hooks: HookDispatcher | None = None,
        auth_storage: AuthStorage | None = None,
        data_generation_method: DataGenerationMethod = DataGenerationMethod.default(),
        generation_config: GenerationConfig | None = None,
        **kwargs: Any,
    ) -> st.SearchStrategy:
        """Turn this API operation into a Hypothesis strategy.

        Hooks are applied in this order: global ones, then the schema's ones, then the ones passed via ``hooks``.
        """
        strategy = self.schema.get_case_strategy(
            self, hooks, auth_storage, data_generation_method, generation_config=generation_config, **kwargs
        )

        def _apply_hooks(dispatcher: HookDispatcher, _strategy: st.SearchStrategy[Case]) -> st.SearchStrategy[Case]:
            # Within one dispatcher: `before_generate_case`, `filter_case`, `map_case`, then `flatmap_case`
            context = HookContext(self)
            for hook in dispatcher.get_all_by_name("before_generate_case"):
                _strategy = hook(context, _strategy)
            for hook in dispatcher.get_all_by_name("filter_case"):
                hook = partial(hook, context)
                _strategy = _strategy.filter(hook)
            for hook in dispatcher.get_all_by_name("map_case"):
                hook = partial(hook, context)
                _strategy = _strategy.map(hook)
            for hook in dispatcher.get_all_by_name("flatmap_case"):
                hook = partial(hook, context)
                _strategy = _strategy.flatmap(hook)
            return _strategy

        strategy = _apply_hooks(GLOBAL_HOOK_DISPATCHER, strategy)
        strategy = _apply_hooks(self.schema.hooks, strategy)
        if hooks is not None:
            strategy = _apply_hooks(hooks, strategy)
        return strategy

    def get_security_requirements(self) -> list[str]:
        """Security requirements of this operation as provided by the schema."""
        return self.schema.get_security_requirements(self)

    def get_strategies_from_examples(self) -> list[st.SearchStrategy[Case]]:
        """Get examples from the API operation."""
        return self.schema.get_strategies_from_examples(self)

    def get_stateful_tests(self, response: GenericResponse, stateful: Stateful | None) -> Sequence[StatefulTest]:
        """Stateful tests for this operation and ``response`` as provided by the schema."""
        return self.schema.get_stateful_tests(response, self, stateful)

    def get_parameter_serializer(self, location: str) -> Callable | None:
        """Get a function that serializes parameters for the given location.

        It handles serializing data into various `collectionFormat` options and similar.
        Note that payload is not handled by this function - it is handled by serializers.
        """
        return self.schema.get_parameter_serializer(self, location)

    def prepare_multipart(self, form_data: FormData) -> tuple[list | None, dict[str, Any] | None]:
        """Delegate multipart preparation of ``form_data`` to the schema."""
        return self.schema.prepare_multipart(form_data, self)

    def get_request_payload_content_types(self) -> list[str]:
        """Request payload content types of this operation as provided by the schema."""
        return self.schema.get_request_payload_content_types(self)

    def partial_deepcopy(self) -> APIOperation:
        """Copy the operation: mutable components are deep-copied, the schema is cloned, the rest is shared."""
        return self.__class__(
            path=self.path,  # string, immutable
            method=self.method,  # string, immutable
            definition=fast_deepcopy(self.definition),
            schema=self.schema.clone(),  # shallow copy
            verbose_name=self.verbose_name,  # string, immutable
            app=self.app,  # not deepcopyable
            base_url=self.base_url,  # string, immutable
            path_parameters=fast_deepcopy(self.path_parameters),
            headers=fast_deepcopy(self.headers),
            cookies=fast_deepcopy(self.cookies),
            query=fast_deepcopy(self.query),
            body=fast_deepcopy(self.body),
            # Keep a custom case class; otherwise the copy would fall back to the default `Case`
            case_cls=self.case_cls,
        )

    def clone(self, **components: Any) -> APIOperation:
        """Create a new instance of this API operation with updated components.

        All of ``path_parameters``, ``query``, ``headers``, ``cookies`` and ``body`` must be passed.

        :raises KeyError: If any of these components is missing.
        """
        return self.__class__(
            path=self.path,
            method=self.method,
            verbose_name=self.verbose_name,
            definition=self.definition,
            schema=self.schema,
            app=self.app,
            base_url=self.base_url,
            path_parameters=components["path_parameters"],
            query=components["query"],
            headers=components["headers"],
            cookies=components["cookies"],
            body=components["body"],
            # Keep a custom case class; otherwise the clone would fall back to the default `Case`
            case_cls=self.case_cls,
        )

    def make_case(
        self,
        *,
        path_parameters: PathParameters | None = None,
        headers: Headers | None = None,
        cookies: Cookies | None = None,
        query: Query | None = None,
        body: Body | NotSet = NOT_SET,
        media_type: str | None = None,
    ) -> C:
        """Create a new example for this API operation.

        The main use case is constructing Case instances completely manually, without data generation.
        """
        return self.schema.make_case(
            case_cls=self.case_cls,
            operation=self,
            path_parameters=path_parameters,
            headers=headers,
            cookies=cookies,
            query=query,
            body=body,
            media_type=media_type,
        )

    @property
    def operation_reference(self) -> str:
        """Reference to this operation of the form ``#/paths/<escaped path>/<method>``."""
        # `~` and `/` inside the path are escaped as `~0` and `~1`; `~` goes first,
        # so the `~` introduced by `~1` is not escaped again
        path = self.path.replace("~", "~0").replace("/", "~1")
        return f"#/paths/{path}/{self.method}"

    def validate_response(self, response: GenericResponse) -> bool | None:
        """Validate API response for conformance.

        :raises CheckFailed: If the response does not conform to the API schema.
        """
        return self.schema.validate_response(self, response)

    def is_response_valid(self, response: GenericResponse) -> bool:
        """Validate API response for conformance, returning ``False`` instead of raising ``CheckFailed``."""
        try:
            self.validate_response(response)
            return True
        except CheckFailed:
            return False

    def get_raw_payload_schema(self, media_type: str) -> dict[str, Any] | None:
        """Payload schema for ``media_type`` taken from the raw definition."""
        return self.schema._get_payload_schema(self.definition.raw, media_type)

    def get_resolved_payload_schema(self, media_type: str) -> dict[str, Any] | None:
        """Payload schema for ``media_type`` taken from the resolved definition."""
        return self.schema._get_payload_schema(self.definition.resolved, media_type)
# Backward-compatibility alias: `Endpoint` refers to the very same class as `APIOperation`
Endpoint = APIOperation
class Status(str, Enum):
    """Status of an action or multiple actions.

    Members are strings as well, so they compare equal to their plain values.
    """

    success = "success"
    failure = "failure"
    error = "error"
    skip = "skip"
@dataclass(repr=False)
class Check:
    """Result of running a single check."""

    # Name of the check
    name: str
    # Outcome of the check
    value: Status
    response: GenericResponse | None
    elapsed: float
    # The test case the check was run against
    example: Case
    message: str | None = None
    # Failure-specific context
    context: FailureContext | None = None
    request: requests.PreparedRequest | None = None
@dataclass(repr=False)
class Request:
    """Request data extracted from `Case`."""

    method: str
    uri: str
    # Serialized payload or `None` when the request has no body
    body: str | None
    headers: Headers

    @classmethod
    def from_case(cls, case: Case, session: requests.Session) -> Request:
        """Create a new `Request` instance from `Case`.

        The case is serialized via the ``requests`` transport and prepared with ``session``.
        """
        import requests

        full_base_url = case.get_full_base_url()
        request_kwargs = RequestsTransport().serialize_case(case, base_url=full_base_url)
        unprepared = requests.Request(**request_kwargs)
        prepared = session.prepare_request(unprepared)  # type: ignore
        return cls.from_prepared_request(prepared)

    @classmethod
    def from_prepared_request(cls, prepared: requests.PreparedRequest) -> Request:
        """Create a new `Request` from a prepared request, e.g. the one stored in `requests.Response`."""
        payload = prepared.body
        if isinstance(payload, str):
            # can be a string for `application/x-www-form-urlencoded`
            payload = payload.encode("utf-8")
        serialized = None if payload is None else serialize_payload(payload)
        # Every header value is wrapped into a single-item list
        header_lists = {name: [value] for name, value in prepared.headers.items()}
        return cls(
            # these values have `str` type at this point
            uri=cast(str, prepared.url),
            method=cast(str, prepared.method),
            headers=header_lists,
            body=serialized,
        )
@dataclass(repr=False)
class Response:
    """Unified response data."""

    status_code: int
    message: str
    # Every header name is mapped to the list of its values
    headers: dict[str, list[str]]
    # Serialized payload or `None` for empty responses
    body: str | None
    encoding: str | None
    http_version: str
    elapsed: float
    verify: bool

    @classmethod
    def from_requests(cls, response: requests.Response) -> Response:
        """Create a response from requests.Response."""
        raw = response.raw
        header_source = {} if raw is None else raw.headers
        headers = {name: response.raw.headers.getlist(name) for name in header_source.keys()}
        # Similar to http.client:319 (HTTP version detection in stdlib's `http` package)
        version = 10 if raw is None else raw.version
        http_version = "1.0" if version == 10 else "1.1"
        # Assume the response is empty if:
        #  - no `Content-Length` header
        #  - no chunks when iterating over its content
        has_no_content = "Content-Length" not in headers and list(response.iter_content()) == []
        if has_no_content:
            body = None
        else:
            body = serialize_payload(response.content)
        return cls(
            status_code=response.status_code,
            message=response.reason,
            body=body,
            encoding=response.encoding,
            headers=headers,
            http_version=http_version,
            elapsed=response.elapsed.total_seconds(),
            verify=getattr(response, "verify", True),
        )

    @classmethod
    def from_wsgi(cls, response: WSGIResponse, elapsed: float) -> Response:
        """Create a response from WSGI response."""
        from .transports.responses import get_reason

        reason = get_reason(response.status_code)
        headers = {name: response.headers.getlist(name) for name in response.headers.keys()}
        # Note, this call ensures that `response.response` is a sequence, which is needed for comparison
        payload = response.get_data()
        body = None if response.response == [] else serialize_payload(payload)
        # Werkzeug <3.0 had `charset` attr, newer versions always have UTF-8
        encoding: str | None = (
            response.mimetype_params.get("charset", getattr(response, "charset", "utf-8"))
            if body is not None
            else None
        )
        return cls(
            status_code=response.status_code,
            message=reason,
            body=body,
            encoding=encoding,
            headers=headers,
            http_version="1.1",
            elapsed=elapsed,
            verify=True,
        )
@dataclass
class Interaction:
    """A single interaction with the target app."""

    request: Request
    response: Response
    checks: list[Check]
    status: Status
    data_generation_method: DataGenerationMethod
    # ISO 8601 timestamp taken from the local clock when the instance is created
    recorded_at: str = field(default_factory=lambda: datetime.datetime.now().isoformat())

    @classmethod
    def from_requests(cls, case: Case, response: requests.Response, status: Status, checks: list[Check]) -> Interaction:
        """Build an interaction from a `requests` response and the prepared request stored on it."""
        recorded_request = Request.from_prepared_request(response.request)
        recorded_response = Response.from_requests(response)
        generation_method = cast(DataGenerationMethod, case.data_generation_method)
        return cls(
            request=recorded_request,
            response=recorded_response,
            status=status,
            checks=checks,
            data_generation_method=generation_method,
        )

    @classmethod
    def from_wsgi(
        cls,
        case: Case,
        response: WSGIResponse,
        headers: dict[str, Any],
        elapsed: float,
        status: Status,
        checks: list[Check],
    ) -> Interaction:
        """Build an interaction from a WSGI response.

        The request side is created from ``case`` with a fresh `requests` session that carries ``headers``.
        """
        import requests

        session = requests.Session()
        session.headers.update(headers)
        recorded_request = Request.from_case(case, session)
        recorded_response = Response.from_wsgi(response, elapsed)
        generation_method = cast(DataGenerationMethod, case.data_generation_method)
        return cls(
            request=recorded_request,
            response=recorded_response,
            status=status,
            checks=checks,
            data_generation_method=generation_method,
        )
@dataclass(repr=False)
class TestResult:
    """Result of a single test."""

    # Tell pytest that this class is not a test, despite its `Test` prefix
    __test__ = False

    method: str
    path: str
    verbose_name: str
    data_generation_method: list[DataGenerationMethod]
    checks: list[Check] = field(default_factory=list)
    errors: list[Exception] = field(default_factory=list)
    interactions: list[Interaction] = field(default_factory=list)
    logs: list[LogRecord] = field(default_factory=list)
    is_errored: bool = False
    is_flaky: bool = False
    is_skipped: bool = False
    skip_reason: str | None = None
    is_executed: bool = False
    # DEPRECATED: Seed is the same per test run
    seed: int | None = None

    def mark_errored(self) -> None:
        """Flag the test as errored."""
        self.is_errored = True

    def mark_flaky(self) -> None:
        """Flag the test as flaky."""
        self.is_flaky = True

    def mark_skipped(self, exc: SkipTest | unittest.case.SkipTest | None) -> None:
        """Flag the test as skipped; the reason is taken from ``exc`` when it is given."""
        self.is_skipped = True
        if exc is None:
            return
        self.skip_reason = str(exc)

    def mark_executed(self) -> None:
        """Flag the test as executed."""
        self.is_executed = True

    @property
    def has_errors(self) -> bool:
        """If any error was stored."""
        return bool(self.errors)

    @property
    def has_failures(self) -> bool:
        """If any stored check has the failure status."""
        return any(check.value == Status.failure for check in self.checks)

    @property
    def has_logs(self) -> bool:
        """If any log record was stored."""
        return bool(self.logs)

    def add_success(self, name: str, example: Case, response: GenericResponse, elapsed: float) -> Check:
        """Store and return a successful check."""
        passed = Check(
            name=name, value=Status.success, response=response, elapsed=elapsed, example=example, request=None
        )
        self.checks.append(passed)
        return passed

    def add_failure(
        self,
        name: str,
        example: Case,
        response: GenericResponse | None,
        elapsed: float,
        message: str,
        context: FailureContext | None,
        request: requests.PreparedRequest | None = None,
    ) -> Check:
        """Store and return a failed check together with its message and context."""
        failed = Check(
            name=name,
            value=Status.failure,
            response=response,
            elapsed=elapsed,
            example=example,
            message=message,
            context=context,
            request=request,
        )
        self.checks.append(failed)
        return failed

    def add_error(self, exception: Exception) -> None:
        """Store an exception raised during the test."""
        self.errors.append(exception)

    def store_requests_response(
        self, case: Case, response: requests.Response, status: Status, checks: list[Check]
    ) -> None:
        """Record an interaction built from a `requests` response."""
        interaction = Interaction.from_requests(case, response, status, checks)
        self.interactions.append(interaction)

    def store_wsgi_response(
        self,
        case: Case,
        response: WSGIResponse,
        headers: dict[str, Any],
        elapsed: float,
        status: Status,
        checks: list[Check],
    ) -> None:
        """Record an interaction built from a WSGI response."""
        interaction = Interaction.from_wsgi(case, response, headers, elapsed, status, checks)
        self.interactions.append(interaction)
@dataclass(repr=False)
class TestResultSet:
    """Set of multiple test results."""

    # Tell pytest that this class is not a test, despite its `Test` prefix
    __test__ = False

    seed: int | None
    results: list[TestResult] = field(default_factory=list)
    generic_errors: list[OperationSchemaError] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    def __iter__(self) -> Iterator[TestResult]:
        """Iterate over the stored results."""
        return iter(self.results)

    @property
    def is_empty(self) -> bool:
        """If the result set contains neither results nor generic errors."""
        return len(self.results) == 0 and len(self.generic_errors) == 0

    @property
    def has_failures(self) -> bool:
        """If any result has any failures."""
        return any(result.has_failures for result in self)

    @property
    def has_errors(self) -> bool:
        """If any result has any errors."""
        return self.errored_count > 0

    @property
    def has_logs(self) -> bool:
        """If any result has any captured logs."""
        return any(result.has_logs for result in self)

    def _count(self, predicate: Callable) -> int:
        """Count the results for which ``predicate`` holds."""
        matched = 0
        for result in self:
            if predicate(result):
                matched += 1
        return matched

    @property
    def passed_count(self) -> int:
        """Number of results without errors and failures that were not skipped."""

        def is_passed(result: TestResult) -> bool:
            return not result.has_errors and not result.is_skipped and not result.has_failures

        return self._count(is_passed)

    @property
    def skipped_count(self) -> int:
        """Number of skipped results."""

        def is_skipped(result: TestResult) -> bool:
            return result.is_skipped

        return self._count(is_skipped)

    @property
    def failed_count(self) -> int:
        """Number of results with failures that are not marked as errored."""

        def is_failed(result: TestResult) -> bool:
            return result.has_failures and not result.is_errored

        return self._count(is_failed)

    @property
    def errored_count(self) -> int:
        """Number of results with errors (or marked as errored) plus the number of generic errors."""

        def is_errored(result: TestResult) -> bool:
            return result.has_errors or result.is_errored

        return self._count(is_errored) + len(self.generic_errors)

    @property
    def total(self) -> dict[str, dict[str | Status, int]]:
        """An aggregated statistic about test results.

        Maps every check name to the number of its runs per status and to their sum under the "total" key.
        """
        counters: dict[str, dict[str | Status, int]] = {}
        for result in self.results:
            for check in result.checks:
                counter = counters.setdefault(check.name, Counter())
                counter[check.value] += 1
                counter["total"] += 1
        # Avoid using Counter, since its behavior could harm in other places:
        # `if not total["unknown"]:` - this will lead to the branch execution
        # It is better to let it fail if there is a wrong key
        return {name: dict(counter) for name, counter in counters.items()}

    def append(self, item: TestResult) -> None:
        """Add a new item to the results list."""
        self.results.append(item)

    def add_warning(self, warning: str) -> None:
        """Add a new warning to the warnings list."""
        self.warnings.append(warning)
# Signature of a check function: it accepts a response and a case and returns an optional bool
CheckFunction = Callable[["GenericResponse", Case], Optional[bool]]