Source code for schemathesis.schemas

"""Schema objects provide a convenient interface to raw schemas.

Their responsibilities:
  - Provide a unified way to work with different types of schemas
  - Give all paths / methods combinations that are available directly from the schema;

They give only static definitions of paths.
"""
from collections.abc import Mapping
from contextlib import nullcontext
from dataclasses import dataclass, field
from difflib import get_close_matches
from functools import lru_cache
from typing import (
    Any,
    Callable,
    ContextManager,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from urllib.parse import quote, unquote, urljoin, urlparse, urlsplit, urlunsplit

import hypothesis
from hypothesis.strategies import SearchStrategy
from pyrate_limiter import Limiter
from requests.structures import CaseInsensitiveDict

from ._hypothesis import create_test
from .auths import AuthStorage
from .constants import DEFAULT_DATA_GENERATION_METHODS, CodeSampleStyle, DataGenerationMethod
from .exceptions import InvalidSchema, UsageError
from .hooks import HookContext, HookDispatcher, HookScope, dispatch
from .models import APIOperation, Case
from .stateful import APIStateMachine, Stateful, StatefulTest
from .types import (
    Body,
    Cookies,
    DataGenerationMethodInput,
    Filter,
    FormData,
    GenericTest,
    Headers,
    NotSet,
    PathParameters,
    Query,
)
from .utils import NOT_SET, PARAMETRIZE_MARKER, GenericResponse, GivenInput, Ok, Result, given_proxy


class MethodsDict(CaseInsensitiveDict):
    """Container for accessing API operations.

    Provides a more specific error message if API operation is not found.
    """

    def __getitem__(self, item: Any) -> Any:
        try:
            return super().__getitem__(item)
        except KeyError as exc:
            available_methods = ", ".join(map(str.upper, self))
            message = f"Method `{item}` not found. Available methods: {available_methods}"
            raise KeyError(message) from exc


C = TypeVar("C", bound=Case)


@lru_cache()
def get_full_path(base_path: str, path: str) -> str:
    return unquote(urljoin(base_path, quote(path.lstrip("/"))))


[docs]@dataclass(eq=False) class BaseSchema(Mapping): raw_schema: Dict[str, Any] location: Optional[str] = None base_url: Optional[str] = None method: Optional[Filter] = None endpoint: Optional[Filter] = None tag: Optional[Filter] = None operation_id: Optional[Filter] = None app: Any = None hooks: HookDispatcher = field(default_factory=lambda: HookDispatcher(scope=HookScope.SCHEMA)) auth: AuthStorage = field(default_factory=AuthStorage) test_function: Optional[GenericTest] = None validate_schema: bool = True skip_deprecated_operations: bool = False data_generation_methods: List[DataGenerationMethod] = field( default_factory=lambda: list(DEFAULT_DATA_GENERATION_METHODS) ) code_sample_style: CodeSampleStyle = CodeSampleStyle.default() rate_limiter: Optional[Limiter] = None def __iter__(self) -> Iterator[str]: return iter(self.operations) def __getitem__(self, item: str) -> MethodsDict: try: return self.operations[item] except KeyError as exc: matches = get_close_matches(item, list(self.operations)) message = f"`{item}` not found" if matches: message += f". Did you mean `{matches[0]}`?" raise KeyError(message) from exc def __len__(self) -> int: return len(self.operations) def hook(self, hook: Union[str, Callable]) -> Callable: return self.hooks.register(hook) @property def verbose_name(self) -> str: raise NotImplementedError def get_full_path(self, path: str) -> str: """Compute full path for the given path.""" return get_full_path(self.base_path, path) @property def base_path(self) -> str: """Base path for the schema.""" # if `base_url` is specified, then it should include base path # Example: http://127.0.0.1:8080/api if self.base_url: path = urlsplit(self.base_url).path else: path = self._get_base_path() if not path.endswith("/"): path += "/" return path def _get_base_path(self) -> str: raise NotImplementedError def _build_base_url(self) -> str: path = self._get_base_path() parts = urlsplit(self.location or "")[:2] + (path, "", "") return urlunsplit(parts) def get_base_url(self) -> str: base_url = self.base_url if base_url is not None: return base_url.rstrip("/") return self._build_base_url() @property def operations(self) -> Dict[str, MethodsDict]: if not hasattr(self, "_operations"): operations = self.get_all_operations() self._operations = operations_to_dict(operations) return self._operations @property def operations_count(self) -> int: raise NotImplementedError def get_all_operations(self) -> Generator[Result[APIOperation, InvalidSchema], None, None]: raise NotImplementedError def get_strategies_from_examples(self, operation: APIOperation) -> List[SearchStrategy[Case]]: """Get examples from the API operation.""" raise NotImplementedError def get_security_requirements(self, operation: APIOperation) -> List[str]: """Get applied security requirements for the given API operation.""" raise NotImplementedError def get_stateful_tests( self, response: GenericResponse, operation: APIOperation, stateful: Optional[Stateful] ) -> Sequence[StatefulTest]: """Get a list of additional tests, that should be executed after this response from the API operation.""" raise NotImplementedError def get_parameter_serializer(self, operation: APIOperation, location: str) -> Optional[Callable]: """Get a function that serializes parameters for the given location.""" raise NotImplementedError def get_all_tests( self, func: Callable, settings: Optional[hypothesis.settings] = None, seed: Optional[int] = None, as_strategy_kwargs: Optional[Dict[str, Any]] = None, _given_kwargs: Optional[Dict[str, GivenInput]] = None, ) -> Generator[Result[Tuple[APIOperation, Callable], InvalidSchema], None, None]: """Generate all operations and Hypothesis tests for them.""" for result in self.get_all_operations(): if isinstance(result, Ok): test = create_test( operation=result.ok(), test=func, settings=settings, seed=seed, data_generation_methods=self.data_generation_methods, as_strategy_kwargs=as_strategy_kwargs, _given_kwargs=_given_kwargs, ) yield Ok((result.ok(), test)) else: yield result
[docs] def parametrize( self, method: Optional[Filter] = NOT_SET, endpoint: Optional[Filter] = NOT_SET, tag: Optional[Filter] = NOT_SET, operation_id: Optional[Filter] = NOT_SET, validate_schema: Union[bool, NotSet] = NOT_SET, skip_deprecated_operations: Union[bool, NotSet] = NOT_SET, data_generation_methods: Union[Iterable[DataGenerationMethod], NotSet] = NOT_SET, code_sample_style: Union[str, NotSet] = NOT_SET, ) -> Callable: """Mark a test function as a parametrized one.""" _code_sample_style = ( CodeSampleStyle.from_str(code_sample_style) if isinstance(code_sample_style, str) else code_sample_style ) def wrapper(func: GenericTest) -> GenericTest: if hasattr(func, PARAMETRIZE_MARKER): def wrapped_test(*_: Any, **__: Any) -> NoReturn: raise UsageError( f"You have applied `parametrize` to the `{func.__name__}` test more than once, which " "overrides the previous decorator. " "The `parametrize` decorator could be applied to the same function at most once." ) return wrapped_test HookDispatcher.add_dispatcher(func) cloned = self.clone( test_function=func, method=method, endpoint=endpoint, tag=tag, operation_id=operation_id, validate_schema=validate_schema, skip_deprecated_operations=skip_deprecated_operations, data_generation_methods=data_generation_methods, code_sample_style=_code_sample_style, # type: ignore ) setattr(func, PARAMETRIZE_MARKER, cloned) return func return wrapper
[docs] def given(self, *args: GivenInput, **kwargs: GivenInput) -> Callable: """Proxy Hypothesis strategies to ``hypothesis.given``.""" return given_proxy(*args, **kwargs)
def clone( self, *, base_url: Union[Optional[str], NotSet] = NOT_SET, test_function: Optional[GenericTest] = None, method: Optional[Filter] = NOT_SET, endpoint: Optional[Filter] = NOT_SET, tag: Optional[Filter] = NOT_SET, operation_id: Optional[Filter] = NOT_SET, app: Any = NOT_SET, hooks: Union[HookDispatcher, NotSet] = NOT_SET, auth: Union[AuthStorage, NotSet] = NOT_SET, validate_schema: Union[bool, NotSet] = NOT_SET, skip_deprecated_operations: Union[bool, NotSet] = NOT_SET, data_generation_methods: Union[DataGenerationMethodInput, NotSet] = NOT_SET, code_sample_style: Union[CodeSampleStyle, NotSet] = NOT_SET, rate_limiter: Optional[Limiter] = NOT_SET, ) -> "BaseSchema": if base_url is NOT_SET: base_url = self.base_url if method is NOT_SET: method = self.method if endpoint is NOT_SET: endpoint = self.endpoint if tag is NOT_SET: tag = self.tag if operation_id is NOT_SET: operation_id = self.operation_id if app is NOT_SET: app = self.app if validate_schema is NOT_SET: validate_schema = self.validate_schema if skip_deprecated_operations is NOT_SET: skip_deprecated_operations = self.skip_deprecated_operations if hooks is NOT_SET: hooks = self.hooks if auth is NOT_SET: auth = self.auth if data_generation_methods is NOT_SET: data_generation_methods = self.data_generation_methods if code_sample_style is NOT_SET: code_sample_style = self.code_sample_style if rate_limiter is NOT_SET: rate_limiter = self.rate_limiter return self.__class__( self.raw_schema, location=self.location, base_url=base_url, # type: ignore method=method, endpoint=endpoint, tag=tag, operation_id=operation_id, app=app, hooks=hooks, # type: ignore auth=auth, # type: ignore test_function=test_function, validate_schema=validate_schema, # type: ignore skip_deprecated_operations=skip_deprecated_operations, # type: ignore data_generation_methods=data_generation_methods, # type: ignore code_sample_style=code_sample_style, # type: ignore rate_limiter=rate_limiter, # type: ignore ) def get_local_hook_dispatcher(self) -> Optional[HookDispatcher]: """Get a HookDispatcher instance bound to the test if present.""" # It might be not present when it is used without pytest via `APIOperation.as_strategy()` if self.test_function is not None: # Might be missing it in case of `LazySchema` usage return getattr(self.test_function, "_schemathesis_hooks", None) # type: ignore return None def dispatch_hook(self, name: str, context: HookContext, *args: Any, **kwargs: Any) -> None: """Dispatch a hook via all available dispatchers.""" dispatch(name, context, *args, **kwargs) self.hooks.dispatch(name, context, *args, **kwargs) local_dispatcher = self.get_local_hook_dispatcher() if local_dispatcher is not None: local_dispatcher.dispatch(name, context, *args, **kwargs) def prepare_multipart( self, form_data: FormData, operation: APIOperation ) -> Tuple[Optional[List], Optional[Dict[str, Any]]]: """Split content of `form_data` into files & data. Forms may contain file fields, that we should send via `files` argument in `requests`. """ raise NotImplementedError def get_request_payload_content_types(self, operation: APIOperation) -> List[str]: raise NotImplementedError def make_case( self, *, case_cls: Type[C], operation: APIOperation, path_parameters: Optional[PathParameters] = None, headers: Optional[Headers] = None, cookies: Optional[Cookies] = None, query: Optional[Query] = None, body: Union[Body, NotSet] = NOT_SET, media_type: Optional[str] = None, ) -> C: raise NotImplementedError def get_case_strategy( self, operation: APIOperation, hooks: Optional[HookDispatcher] = None, auth_storage: Optional[AuthStorage] = None, data_generation_method: DataGenerationMethod = DataGenerationMethod.default(), **kwargs: Any, ) -> SearchStrategy: raise NotImplementedError
[docs] def as_state_machine(self) -> Type[APIStateMachine]: """Create a state machine class. Use it for stateful testing. """ raise NotImplementedError
def get_links(self, operation: APIOperation) -> Dict[str, Dict[str, Any]]: raise NotImplementedError def validate_response(self, operation: APIOperation, response: GenericResponse) -> None: raise NotImplementedError def prepare_schema(self, schema: Any) -> Any: raise NotImplementedError def ratelimit(self) -> ContextManager: """Limit the rate of sending generated requests.""" label = urlparse(self.base_url).netloc if self.rate_limiter is not None: return self.rate_limiter.ratelimit(label, delay=True, max_delay=0) return nullcontext() def _get_payload_schema(self, definition: Dict[str, Any], media_type: str) -> Optional[Dict[str, Any]]: raise NotImplementedError
def operations_to_dict( operations: Generator[Result[APIOperation, InvalidSchema], None, None] ) -> Dict[str, MethodsDict]: output: Dict[str, MethodsDict] = {} for result in operations: if isinstance(result, Ok): operation = result.ok() output.setdefault(operation.path, MethodsDict()) output[operation.path][operation.method] = operation return output