# Source code for schemathesis.auths

"""Support for custom API authentication mechanisms."""
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, List, Optional, Type, TypeVar, Union

import requests.auth
from typing_extensions import Protocol, runtime_checkable

from .exceptions import UsageError
from .filters import FilterSet, FilterValue, MatcherFunc, attach_filter_chain
from .types import GenericTest

if TYPE_CHECKING:
    from .models import APIOperation, Case

# Default cache duration (in seconds) for auth data when no `refresh_interval` is given.
DEFAULT_REFRESH_INTERVAL = 300
# Name of the attribute under which an `AuthStorage` instance is attached to a test function.
AUTH_STORAGE_ATTRIBUTE_NAME = "_schemathesis_auth"
# Type variable for the authentication data returned by a provider's `get` and passed to its `set`.
Auth = TypeVar("Auth")


@dataclass
class AuthContext:
    """Holds state relevant for the authentication process.

    :ivar APIOperation operation: API operation that is currently being processed.
    :ivar app: Optional Python application if the WSGI / ASGI integration is used.
    """

    # API operation that is currently being processed.
    operation: "APIOperation"
    # Python application object for the WSGI / ASGI integration; `None` otherwise.
    app: Optional[Any]
@runtime_checkable
class AuthProvider(Protocol):
    """Get authentication data for an API and set it on the generated test cases.

    This is a structural (duck-typed) interface: any class that defines both
    ``get`` and ``set`` satisfies it. It is runtime-checkable, so it can be used
    with ``isinstance`` / ``issubclass``.
    """

    def get(self, context: AuthContext) -> Optional[Auth]:
        """Get the authentication data.

        :param AuthContext context: Holds state relevant for the authentication process.
        :return: Any authentication data you find useful for your use case.
            For example, it could be an access token.
        """

    def set(self, case: "Case", data: Auth, context: AuthContext) -> None:
        """Set authentication data on a generated test case.

        :param Case case: Generated test case.
        :param Auth data: Authentication data you got from the ``get`` method.
        :param AuthContext context: Holds state relevant for the authentication process.
        """
@dataclass
class CacheEntry(Generic[Auth]):
    """Cached auth data."""

    # The cached authentication data.
    data: Auth
    # Timer value at (or after) which the entry is treated as expired.
    expires: float


@dataclass
class RequestsAuth(Generic[Auth]):
    """Provider that sets auth data via `requests` auth instance."""

    auth: requests.auth.AuthBase

    def get(self, _: AuthContext) -> Optional[Auth]:
        """Return the wrapped `requests` auth instance; the context is ignored."""
        return self.auth  # type: ignore[return-value]

    def set(self, case: "Case", _: Auth, __: AuthContext) -> None:
        """Store the wrapped `requests` auth instance on ``case._auth``."""
        case._auth = self.auth


@dataclass
class CachingAuthProvider(Generic[Auth]):
    """Caches the underlying auth provider."""

    provider: AuthProvider
    refresh_interval: int = DEFAULT_REFRESH_INTERVAL
    cache_entry: Optional[CacheEntry[Auth]] = None
    # The timer exists here to simplify testing
    timer: Callable[[], float] = time.monotonic
    _refresh_lock: threading.Lock = field(default_factory=threading.Lock)

    def get(self, context: AuthContext) -> Optional[Auth]:
        """Get cached auth value.

        If there is no cache entry or it has expired, the underlying provider is
        called while holding the refresh lock and the result is cached for
        ``refresh_interval`` timer units.

        :param AuthContext context: Passed through to the underlying provider.
        :return: Cached or freshly obtained auth data.
        """
        if self.cache_entry is None or self.timer() >= self.cache_entry.expires:
            with self._refresh_lock:
                # Re-check after acquiring the lock
                if not (self.cache_entry is None or self.timer() >= self.cache_entry.expires):
                    # Another thread updated the cache
                    return self.cache_entry.data
                # We know that optional auth is possible only inside a higher-level wrapper
                data: Auth = self.provider.get(context)  # type: ignore[assignment]
                self.cache_entry = CacheEntry(data=data, expires=self.timer() + self.refresh_interval)
                return data
        return self.cache_entry.data

    def set(self, case: "Case", data: Auth, context: AuthContext) -> None:
        """Set auth data on the `Case` instance.

        This implementation delegates this to the actual provider.
        """
        self.provider.set(case, data, context)


class FilterableRegisterAuth(Protocol):
    """Protocol that adds filters to the return value of `register`."""

    def __call__(self, provider_class: Type[AuthProvider]) -> Type[AuthProvider]:
        pass

    def apply_to(
        self,
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[str] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[str] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[str] = None,
    ) -> "FilterableRegisterAuth":
        pass

    def skip_for(
        self,
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[str] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[str] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[str] = None,
    ) -> "FilterableRegisterAuth":
        pass


class FilterableApplyAuth(Protocol):
    """Protocol that adds filters to the return value of `apply`."""

    def __call__(self, test: GenericTest) -> GenericTest:
        pass

    def apply_to(
        self,
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[str] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[str] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[str] = None,
    ) -> "FilterableApplyAuth":
        pass

    def skip_for(
        self,
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[str] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[str] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[str] = None,
    ) -> "FilterableApplyAuth":
        pass


class FilterableRequestsAuth(Protocol):
    """Protocol that adds filters to the return value of `set_from_requests`."""

    def apply_to(
        self,
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[str] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[str] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[str] = None,
    ) -> "FilterableRequestsAuth":
        pass

    def skip_for(
        self,
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[str] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[str] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[str] = None,
    ) -> "FilterableRequestsAuth":
        pass


@dataclass
class SelectiveAuthProvider(Generic[Auth]):
    """Applies auth depending on the configured filters."""

    provider: AuthProvider
    filter_set: FilterSet

    def get(self, context: AuthContext) -> Optional[Auth]:
        """Return the wrapped provider's data if the filters match the context, else ``None``."""
        if self.filter_set.match(context):
            return self.provider.get(context)
        return None

    def set(self, case: "Case", data: Auth, context: AuthContext) -> None:
        """Delegate setting auth data to the wrapped provider."""
        self.provider.set(case, data, context)


@dataclass
class AuthStorage(Generic[Auth]):
    """Store and manage API authentication."""

    providers: List[AuthProvider] = field(default_factory=list)

    @property
    def is_defined(self) -> bool:
        """Whether there is an auth provider set."""
        return bool(self.providers)

    def __call__(
        self,
        provider_class: Optional[Type[AuthProvider]] = None,
        *,
        refresh_interval: Optional[int] = DEFAULT_REFRESH_INTERVAL,
    ) -> Union[FilterableRegisterAuth, FilterableApplyAuth]:
        """Dispatch to `apply` or `register` depending on the arguments.

        With ``provider_class`` given, the result of `apply` is returned (a decorator
        for a test function); without it, the result of `register` is returned
        (a decorator for a provider class).

        :param provider_class: Authentication provider class, if any.
        :param refresh_interval: Cache duration in seconds; ``None`` disables caching.
        """
        if provider_class is not None:
            return self.apply(provider_class, refresh_interval=refresh_interval)
        return self.register(refresh_interval=refresh_interval)

    def set_from_requests(self, auth: requests.auth.AuthBase) -> FilterableRequestsAuth:
        """Use `requests` auth instance as an auth provider."""
        filter_set = FilterSet()
        self.providers.append(SelectiveAuthProvider(provider=RequestsAuth(auth), filter_set=filter_set))

        # Carrier for the `apply_to` / `skip_for` filter methods attached below
        class _FilterableRequestsAuth:
            pass

        attach_filter_chain(_FilterableRequestsAuth, "apply_to", filter_set.include)
        attach_filter_chain(_FilterableRequestsAuth, "skip_for", filter_set.exclude)

        return _FilterableRequestsAuth  # type: ignore[return-value]

    def _set_provider(
        self,
        *,
        provider_class: Type[AuthProvider],
        refresh_interval: Optional[int] = DEFAULT_REFRESH_INTERVAL,
        filter_set: FilterSet,
    ) -> None:
        """Instantiate ``provider_class``, wrap it as configured, and append it to ``providers``.

        :param provider_class: Class that must satisfy the `AuthProvider` protocol.
        :param refresh_interval: Cache duration in seconds; ``None`` disables caching.
        :param filter_set: Filters deciding where the provider applies; ignored if empty.
        :raises TypeError: If ``provider_class`` is not a valid auth provider.
        """
        if not issubclass(provider_class, AuthProvider):
            raise TypeError(
                f"`{provider_class.__name__}` is not a valid auth provider. "
                "Check `schemathesis.auths.AuthProvider` documentation for examples."
            )
        provider: AuthProvider
        # Apply caching if desired
        if refresh_interval is not None:
            provider = CachingAuthProvider(provider_class(), refresh_interval=refresh_interval)
        else:
            provider = provider_class()
        # Store filters if any
        if not filter_set.is_empty():
            provider = SelectiveAuthProvider(provider, filter_set)
        self.providers.append(provider)

    def register(self, *, refresh_interval: Optional[int] = DEFAULT_REFRESH_INTERVAL) -> FilterableRegisterAuth:
        """Register a new auth provider.

        :param Optional[int] refresh_interval: Cache duration in seconds. ``None`` disables caching.

        .. code-block:: python

            @schemathesis.auth()
            class TokenAuth:
                def get(self, context):
                    # This is a real endpoint, try it out!
                    response = requests.post(
                        "https://example.schemathesis.io/api/token/",
                        json={"username": "demo", "password": "test"},
                    )
                    data = response.json()
                    return data["access_token"]

                def set(self, case, data, context):
                    # Modify `case` the way you need
                    case.headers = {"Authorization": f"Bearer {data}"}
        """
        filter_set = FilterSet()

        def wrapper(provider_class: Type[AuthProvider]) -> Type[AuthProvider]:
            self._set_provider(provider_class=provider_class, refresh_interval=refresh_interval, filter_set=filter_set)
            return provider_class

        attach_filter_chain(wrapper, "apply_to", filter_set.include)
        attach_filter_chain(wrapper, "skip_for", filter_set.exclude)

        return wrapper  # type: ignore[return-value]

    def unregister(self) -> None:
        """Unregister all currently registered auth providers.

        No-op if there is no auth provider registered.
        """
        self.providers = []

    def apply(
        self, provider_class: Type[AuthProvider], *, refresh_interval: Optional[int] = DEFAULT_REFRESH_INTERVAL
    ) -> FilterableApplyAuth:
        """Register auth provider only on one test function.

        :param Type[AuthProvider] provider_class: Authentication provider class.
        :param Optional[int] refresh_interval: Cache duration in seconds.

        .. code-block:: python

            class Auth:
                ...


            @schema.auth(Auth)
            @schema.parametrize()
            def test_api(case):
                ...

        """
        filter_set = FilterSet()

        def wrapper(test: GenericTest) -> GenericTest:
            auth_storage = self.add_auth_storage(test)
            auth_storage._set_provider(
                provider_class=provider_class, refresh_interval=refresh_interval, filter_set=filter_set
            )
            return test

        attach_filter_chain(wrapper, "apply_to", filter_set.include)
        attach_filter_chain(wrapper, "skip_for", filter_set.exclude)

        return wrapper  # type: ignore[return-value]

    @classmethod
    def add_auth_storage(cls, test: GenericTest) -> "AuthStorage":
        """Attach a new auth storage instance to the test if it is not already present.

        :raises UsageError: If the test already has an auth storage attached.
        """
        if hasattr(test, AUTH_STORAGE_ATTRIBUTE_NAME):
            raise UsageError(f"`{test.__name__}` has already been decorated with `apply`.")
        storage = cls()
        setattr(test, AUTH_STORAGE_ATTRIBUTE_NAME, storage)
        return storage

    def set(self, case: "Case", context: AuthContext) -> None:
        """Set authentication data on a generated test case.

        Providers are tried in registration order; the first one whose ``get``
        returns a value other than ``None`` is used and the rest are skipped.

        :raises UsageError: If no auth provider is defined.
        """
        if not self.is_defined:
            raise UsageError("No auth provider is defined.")
        for provider in self.providers:
            data: Optional[Auth] = provider.get(context)
            if data is not None:
                provider.set(case, data, context)
                break


def set_on_case(case: "Case", context: AuthContext, auth_storage: Optional[AuthStorage]) -> None:
    """Set authentication data on this case.

    The explicitly passed ``auth_storage`` takes precedence; otherwise the schema-level
    storage is used if it has providers, and then the global one.

    If there is no auth defined, then this function is no-op.
    """
    if auth_storage is not None:
        auth_storage.set(case, context)
    elif case.operation.schema.auth.is_defined:
        case.operation.schema.auth.set(case, context)
    elif GLOBAL_AUTH_STORAGE.is_defined:
        GLOBAL_AUTH_STORAGE.set(case, context)


def get_auth_storage_from_test(test: GenericTest) -> Optional[AuthStorage]:
    """Extract the currently attached auth storage from a test function."""
    return getattr(test, AUTH_STORAGE_ATTRIBUTE_NAME, None)


# Global auth API
GLOBAL_AUTH_STORAGE: AuthStorage = AuthStorage()
register = GLOBAL_AUTH_STORAGE.register
unregister = GLOBAL_AUTH_STORAGE.unregister