from __future__ import annotations
import io
import json
import pathlib
import re
from typing import IO, TYPE_CHECKING, Any, Callable, cast
from urllib.parse import urljoin
from ... import experimental, fixups
from ...code_samples import CodeSampleStyle
from ...constants import DEFAULT_RESPONSE_TIMEOUT, NOT_SET, WAIT_FOR_SCHEMA_INTERVAL
from ...exceptions import SchemaError, SchemaErrorType
from ...filters import filter_set_from_components
from ...generation import (
DEFAULT_DATA_GENERATION_METHODS,
DataGenerationMethod,
DataGenerationMethodInput,
GenerationConfig,
)
from ...hooks import HookContext, dispatch
from ...internal.deprecation import warn_filtration_arguments
from ...internal.output import OutputConfig
from ...internal.validation import require_relative_url
from ...loaders import load_schema_from_url, load_yaml
from ...throttling import build_limiter
from ...transports.content_types import is_json_media_type, is_yaml_media_type
from ...transports.headers import setup_default_headers
from ...types import Filter, NotSet, PathLike, Specification
from . import definitions, validation
if TYPE_CHECKING:
import jsonschema
from pyrate_limiter import Limiter
from ...lazy import LazySchema
from ...transports.responses import GenericResponse
from .schemas import BaseOpenAPISchema
def _is_json_response(response: GenericResponse) -> bool:
"""Guess if the response contains JSON."""
content_type = response.headers.get("Content-Type")
if content_type is not None:
return is_json_media_type(content_type)
return False
def _has_suffix(path: PathLike, suffix: str) -> bool:
if isinstance(path, str):
return path.endswith(suffix)
return path.suffix == suffix
def _is_json_path(path: PathLike) -> bool:
    """Tell whether ``path`` points at a file with the ``.json`` extension."""
    json_extension = ".json"
    return _has_suffix(path, json_extension)
def _is_yaml_response(response: GenericResponse) -> bool:
"""Guess if the response contains YAML."""
content_type = response.headers.get("Content-Type")
if content_type is not None:
return is_yaml_media_type(content_type)
return False
def _is_yaml_path(path: PathLike) -> bool:
    """Tell whether ``path`` has one of the YAML extensions, ``.yaml`` or ``.yml``."""
    return any(_has_suffix(path, extension) for extension in (".yaml", ".yml"))
def from_path(
    path: PathLike,
    *,
    app: Any = None,
    base_url: str | None = None,
    method: Filter | None = None,
    endpoint: Filter | None = None,
    tag: Filter | None = None,
    operation_id: Filter | None = None,
    skip_deprecated_operations: bool | None = None,
    validate_schema: bool = False,
    force_schema_version: str | None = None,
    data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
    generation_config: GenerationConfig | None = None,
    output_config: OutputConfig | None = None,
    code_sample_style: str = CodeSampleStyle.default().name,
    rate_limit: str | None = None,
    encoding: str = "utf8",
    sanitize_output: bool = True,
) -> BaseOpenAPISchema:
    """Load Open API schema via a file from an OS path.

    The absolute ``file://`` URI of ``path`` is used as the schema location. The file extension
    (``.json`` / ``.yaml`` / ``.yml``) is forwarded to ``from_file`` as a hint about the expected format.

    :param path: A path to the schema file.
    :param encoding: The name of the encoding used to decode the file.
    """
    with open(path, encoding=encoding) as handle:
        location = pathlib.Path(path).absolute().as_uri()
        expects_json = _is_json_path(path)
        expects_yaml = _is_yaml_path(path)
        return from_file(
            handle,
            app=app,
            base_url=base_url,
            method=method,
            endpoint=endpoint,
            tag=tag,
            operation_id=operation_id,
            skip_deprecated_operations=skip_deprecated_operations,
            validate_schema=validate_schema,
            force_schema_version=force_schema_version,
            data_generation_methods=data_generation_methods,
            generation_config=generation_config,
            output_config=output_config,
            code_sample_style=code_sample_style,
            location=location,
            rate_limit=rate_limit,
            sanitize_output=sanitize_output,
            __expects_json=expects_json,
            __expects_yaml=expects_yaml,
        )
def from_uri(
    uri: str,
    *,
    app: Any = None,
    base_url: str | None = None,
    port: int | None = None,
    method: Filter | None = None,
    endpoint: Filter | None = None,
    tag: Filter | None = None,
    operation_id: Filter | None = None,
    skip_deprecated_operations: bool | None = None,
    validate_schema: bool = False,
    force_schema_version: str | None = None,
    data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
    generation_config: GenerationConfig | None = None,
    output_config: OutputConfig | None = None,
    code_sample_style: str = CodeSampleStyle.default().name,
    wait_for_schema: float | None = None,
    rate_limit: str | None = None,
    sanitize_output: bool = True,
    **kwargs: Any,
) -> BaseOpenAPISchema:
    """Load Open API schema from the network.

    :param str uri: Schema URL.
    :param port: If given, replaces the port in ``uri``. The resulting URL also becomes ``base_url``
        when no ``base_url`` was provided.
    :param wait_for_schema: If given, connection errors are retried at a constant interval until
        this much time has passed.

    Any remaining ``kwargs`` are forwarded to ``requests.get``; a default ``timeout`` is added
    when none is supplied.
    """
    import backoff
    import requests

    setup_default_headers(kwargs)
    if port:
        from yarl import URL

        uri = str(URL(uri).with_port(port))
        if not base_url:
            base_url = uri

    def fetch(target: str, **options: Any) -> requests.Response:
        return requests.get(target, **options)

    if wait_for_schema is not None:
        # Keep retrying while the server does not accept connections yet
        fetch = backoff.on_exception(  # type: ignore
            backoff.constant,
            requests.exceptions.ConnectionError,
            max_time=wait_for_schema,
            interval=WAIT_FOR_SCHEMA_INTERVAL,
        )(fetch)

    # `DEFAULT_RESPONSE_TIMEOUT` is presumably in milliseconds (hence `/ 1000`); `requests` expects seconds
    kwargs.setdefault("timeout", DEFAULT_RESPONSE_TIMEOUT / 1000)
    response = load_schema_from_url(lambda: fetch(uri, **kwargs))
    return from_file(
        response.text,
        app=app,
        base_url=base_url,
        method=method,
        endpoint=endpoint,
        tag=tag,
        operation_id=operation_id,
        skip_deprecated_operations=skip_deprecated_operations,
        validate_schema=validate_schema,
        force_schema_version=force_schema_version,
        data_generation_methods=data_generation_methods,
        generation_config=generation_config,
        output_config=output_config,
        code_sample_style=code_sample_style,
        location=uri,
        rate_limit=rate_limit,
        sanitize_output=sanitize_output,
        __expects_json=_is_json_response(response),
        __expects_yaml=_is_yaml_response(response),
    )
# Reported when the loaded document is not a dictionary or fails Open API schema validation
SCHEMA_INVALID_ERROR = "The provided API schema does not appear to be a valid OpenAPI schema"
# Reported when the payload cannot be parsed and no particular format was advertised for it
SCHEMA_LOADING_ERROR = "Received unsupported content while expecting a JSON or YAML payload for Open API"
# Reported when the payload cannot be parsed in the format (JSON or YAML) it was expected to be in
SCHEMA_SYNTAX_ERROR = "API schema does not appear syntactically valid"
def _load_yaml(data: str, include_details_on_error: bool = False) -> dict[str, Any]:
    """Parse ``data`` as YAML and turn parser failures into ``SchemaError``.

    :param data: Raw document content.
    :param include_details_on_error: When true, a failure is reported as a syntax error with the
        parser's non-empty message lines attached. Otherwise it is reported as unexpected content
        with no extras.
    :raises SchemaError: If the YAML parser rejects ``data``.
    """
    import yaml

    try:
        return load_yaml(data)
    except yaml.YAMLError as exc:
        if not include_details_on_error:
            raise SchemaError(SchemaErrorType.UNEXPECTED_CONTENT_TYPE, SCHEMA_LOADING_ERROR, extras=[]) from exc
        details = [line for line in str(exc).splitlines() if line]
        raise SchemaError(SchemaErrorType.SYNTAX_ERROR, SCHEMA_SYNTAX_ERROR, extras=details) from exc
def from_file(
    file: IO[str] | str,
    *,
    app: Any = None,
    base_url: str | None = None,
    method: Filter | None = None,
    endpoint: Filter | None = None,
    tag: Filter | None = None,
    operation_id: Filter | None = None,
    skip_deprecated_operations: bool | None = None,
    validate_schema: bool = False,
    force_schema_version: str | None = None,
    data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
    generation_config: GenerationConfig | None = None,
    output_config: OutputConfig | None = None,
    code_sample_style: str = CodeSampleStyle.default().name,
    location: str | None = None,
    rate_limit: str | None = None,
    sanitize_output: bool = True,
    __expects_json: bool = False,
    __expects_yaml: bool = False,
    **kwargs: Any,  # needed in the runner to have compatible API across all loaders
) -> BaseOpenAPISchema:
    """Load Open API schema from a file descriptor, string or bytes.

    When ``__expects_json`` is set, JSON is tried first and YAML is used as a fallback. Otherwise the
    content is parsed as YAML, and ``__expects_yaml`` controls whether YAML errors carry syntax details.

    :param file: Could be a file descriptor, string or bytes.
    """
    data = file.read() if hasattr(file, "read") else file  # type: ignore
    if not __expects_json:
        raw = _load_yaml(data, include_details_on_error=__expects_yaml)
    else:
        try:
            raw = json.loads(data)
        except json.JSONDecodeError as json_error:
            # Fall back to the slower YAML loader. This still loads schemas served with a wrong
            # `Content-Type` header, or YAML files with the `.json` extension. That case is rare, and
            # trying JSON first keeps the common case fast.
            try:
                raw = _load_yaml(data)
            except SchemaError:
                raise SchemaError(
                    SchemaErrorType.SYNTAX_ERROR,
                    SCHEMA_SYNTAX_ERROR,
                    extras=[line for line in str(json_error).splitlines() if line],
                ) from json_error
    return from_dict(
        raw,
        app=app,
        base_url=base_url,
        method=method,
        endpoint=endpoint,
        tag=tag,
        operation_id=operation_id,
        skip_deprecated_operations=skip_deprecated_operations,
        validate_schema=validate_schema,
        force_schema_version=force_schema_version,
        data_generation_methods=data_generation_methods,
        generation_config=generation_config,
        output_config=output_config,
        code_sample_style=code_sample_style,
        location=location,
        rate_limit=rate_limit,
        sanitize_output=sanitize_output,
    )
def _is_fast_api(app: Any) -> bool:
for cls in app.__class__.__mro__:
if f"{cls.__module__}.{cls.__qualname__}" == "fastapi.applications.FastAPI":
return True
return False
def from_dict(
    raw_schema: dict[str, Any],
    *,
    app: Any = None,
    base_url: str | None = None,
    method: Filter | None = None,
    endpoint: Filter | None = None,
    tag: Filter | None = None,
    operation_id: Filter | None = None,
    skip_deprecated_operations: bool | None = None,
    validate_schema: bool = False,
    force_schema_version: str | None = None,
    data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
    generation_config: GenerationConfig | None = None,
    output_config: OutputConfig | None = None,
    code_sample_style: str = CodeSampleStyle.default().name,
    location: str | None = None,
    rate_limit: str | None = None,
    sanitize_output: bool = True,
) -> BaseOpenAPISchema:
    """Load Open API schema from a Python dictionary.

    The Open API flavour comes from ``force_schema_version`` (``"20"`` or ``"30"``) when given.
    Otherwise it is taken from the document's ``swagger`` / ``openapi`` keys.

    :param dict raw_schema: A schema to load.
    :raises SchemaError: If ``raw_schema`` is not a dictionary, declares no Open API version or an
        unsupported one, or fails validation when ``validate_schema`` is enabled.
    """
    from ... import transports
    from .schemas import OpenApi30, SwaggerV20

    if not isinstance(raw_schema, dict):
        raise SchemaError(SchemaErrorType.OPEN_API_INVALID_SCHEMA, SCHEMA_INVALID_ERROR)
    _code_sample_style = CodeSampleStyle.from_str(code_sample_style)
    hook_context = HookContext()
    # The `openapi` value is not guaranteed to be a string. YAML parses an unquoted `openapi: 3.1` as a
    # float and an empty `openapi:` as None, so calling `.startswith` on it directly would crash.
    declared_version = raw_schema.get("openapi", "")
    is_openapi_31 = isinstance(declared_version, str) and declared_version.startswith("3.1")
    # For 3.1 documents an installed FastAPI fixup is removed; otherwise FastAPI apps get their schema adjusted
    is_fast_api_fixup_installed = fixups.is_installed("fast_api")
    if is_fast_api_fixup_installed and is_openapi_31:
        fixups.fast_api.uninstall()
    elif _is_fast_api(app):
        fixups.fast_api.adjust_schema(raw_schema)
    dispatch("before_load_schema", hook_context, raw_schema)
    rate_limiter: Limiter | None = None
    if rate_limit is not None:
        rate_limiter = build_limiter(rate_limit)
    # Warn about every filtration argument that was explicitly passed
    filtration_arguments = {
        "method": method,
        "endpoint": endpoint,
        "tag": tag,
        "operation_id": operation_id,
        "skip_deprecated_operations": skip_deprecated_operations,
    }
    for name, value in filtration_arguments.items():
        if value is not None:
            warn_filtration_arguments(name)
    filter_set = filter_set_from_components(
        include=True,
        method=method,
        endpoint=endpoint,
        tag=tag,
        operation_id=operation_id,
        skip_deprecated_operations=skip_deprecated_operations,
    )

    def _instantiate(schema_class: type[BaseOpenAPISchema]) -> BaseOpenAPISchema:
        # Construction shared by both Open API flavours, followed by the `after_load_schema` hook
        instance = schema_class(
            raw_schema,
            specification=Specification.OPENAPI,
            app=app,
            base_url=base_url,
            filter_set=filter_set,
            validate_schema=validate_schema,
            data_generation_methods=DataGenerationMethod.ensure_list(data_generation_methods),
            generation_config=generation_config or GenerationConfig(),
            output_config=output_config or OutputConfig(),
            code_sample_style=_code_sample_style,
            location=location,
            rate_limiter=rate_limiter,
            sanitize_output=sanitize_output,
            transport=transports.get(app),
        )
        dispatch("after_load_schema", hook_context, instance)
        return instance

    def init_openapi_2() -> BaseOpenAPISchema:
        _maybe_validate_schema(raw_schema, definitions.SWAGGER_20_VALIDATOR, validate_schema)
        return _instantiate(SwaggerV20)

    def init_openapi_3(forced: bool) -> BaseOpenAPISchema:
        # `.get` - when forcing 3.0 the document may lack the `openapi` key. The version is then only
        # needed for error messages, which the forced path never reaches.
        version = raw_schema.get("openapi")
        if (
            not (is_openapi_31 and experimental.OPEN_API_3_1.is_enabled)
            and not forced
            and not (isinstance(version, str) and OPENAPI_30_VERSION_RE.match(version))
        ):
            if is_openapi_31:
                raise SchemaError(
                    SchemaErrorType.OPEN_API_EXPERIMENTAL_VERSION,
                    f"The provided schema uses Open API {version}, which is currently not fully supported.",
                )
            raise SchemaError(
                SchemaErrorType.OPEN_API_UNSUPPORTED_VERSION,
                f"The provided schema uses Open API {version}, which is currently not supported.",
            )
        validator = definitions.OPENAPI_31_VALIDATOR if is_openapi_31 else definitions.OPENAPI_30_VALIDATOR
        _maybe_validate_schema(raw_schema, validator, validate_schema)
        return _instantiate(OpenApi30)

    if force_schema_version == "20":
        return init_openapi_2()
    if force_schema_version == "30":
        return init_openapi_3(forced=True)
    if "swagger" in raw_schema:
        return init_openapi_2()
    if "openapi" in raw_schema:
        return init_openapi_3(forced=False)
    raise SchemaError(
        SchemaErrorType.OPEN_API_UNSPECIFIED_VERSION,
        "Unable to determine the Open API version as it's not specified in the document.",
    )
# Matches Open API 3.0.x versions, with an optional pre-release suffix such as `3.0.0-rc1`.
# Any patch number is accepted: per the specification, tooling supporting 3.0 should work with all
# 3.0.* versions, and the patch number is not limited to a single digit.
OPENAPI_30_VERSION_RE = re.compile(r"^3\.0\.\d+(-.+)?$")
# API schemas stored as YAML often have numeric HTTP status codes (e.g. `200:` parses as an integer),
# while the Open API specification requires status codes to be strings
DOC_ENTRY = "https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#patterned-fields-1"
NUMERIC_STATUS_CODES_MESSAGE = f"""Numeric HTTP status codes detected in your YAML schema.
According to the Open API specification, status codes must be strings, not numbers.
For more details, check the Open API documentation: {DOC_ENTRY}
Please, stringify the following status codes:"""
NON_STRING_OBJECT_KEY_MESSAGE = (
    "The Open API specification requires all keys in the schema to be strings. "
    "You have some keys that are not strings."
)
def _format_status_codes(status_codes: list[tuple[int, list[str | int]]]) -> str:
buffer = io.StringIO()
for status_code, path in status_codes:
buffer.write(f" - {status_code} at schema['paths']")
for chunk in path:
buffer.write(f"[{chunk!r}]")
buffer.write("['responses']\n")
return buffer.getvalue().rstrip()
def _maybe_validate_schema(
    instance: dict[str, Any], validator: jsonschema.validators.Draft4Validator, validate_schema: bool
) -> None:
    """Validate ``instance`` with ``validator`` when ``validate_schema`` is set; otherwise do nothing.

    :raises SchemaError: If validation fails. Numeric HTTP status codes and other non-string keys are
        reported separately from ordinary validation errors.
    """
    from jsonschema import ValidationError

    if not validate_schema:
        return
    try:
        validator.validate(instance)
    except TypeError as exc:
        # NOTE(review): presumably raised when regex-based keywords meet non-string keys - see `is_pattern_error`
        if not validation.is_pattern_error(exc):
            raise SchemaError(SchemaErrorType.UNCLASSIFIED, "Unknown error") from exc
        status_codes = validation.find_numeric_http_status_codes(instance)
        if not status_codes:
            # Some other pattern error
            raise SchemaError(SchemaErrorType.YAML_NON_STRING_KEYS, NON_STRING_OBJECT_KEY_MESSAGE) from exc
        details = _format_status_codes(status_codes)
        raise SchemaError(
            SchemaErrorType.YAML_NUMERIC_STATUS_CODES, f"{NUMERIC_STATUS_CODES_MESSAGE}\n{details}"
        ) from exc
    except ValidationError as exc:
        raise SchemaError(
            SchemaErrorType.OPEN_API_INVALID_SCHEMA,
            SCHEMA_INVALID_ERROR,
            extras=[line for line in str(exc).splitlines() if line],
        ) from exc
def from_pytest_fixture(
    fixture_name: str,
    *,
    app: Any = NOT_SET,
    base_url: str | None | NotSet = NOT_SET,
    method: Filter | None = NOT_SET,
    endpoint: Filter | None = NOT_SET,
    tag: Filter | None = NOT_SET,
    operation_id: Filter | None = NOT_SET,
    skip_deprecated_operations: bool | None = None,
    validate_schema: bool = False,
    data_generation_methods: DataGenerationMethodInput | NotSet = NOT_SET,
    generation_config: GenerationConfig | NotSet = NOT_SET,
    output_config: OutputConfig | NotSet = NOT_SET,
    code_sample_style: str = CodeSampleStyle.default().name,
    rate_limit: str | None = None,
    sanitize_output: bool = True,
) -> LazySchema:
    """Load schema from a ``pytest`` fixture.

    It is useful if you don't want to make network requests during module loading. With this loader you can defer it
    to a fixture.

    Note, the fixture should return a ``BaseSchema`` instance loaded with another loader.

    :param str fixture_name: The name of a fixture to load.
    """
    from ...lazy import LazySchema

    _code_sample_style = CodeSampleStyle.from_str(code_sample_style)
    _data_generation_methods: DataGenerationMethodInput | NotSet
    if data_generation_methods is not NOT_SET:
        data_generation_methods = cast(DataGenerationMethodInput, data_generation_methods)
        _data_generation_methods = DataGenerationMethod.ensure_list(data_generation_methods)
    else:
        _data_generation_methods = data_generation_methods
    rate_limiter: Limiter | None = None
    if rate_limit is not None:
        rate_limiter = build_limiter(rate_limit)
    filtration_arguments = {
        "method": method,
        "endpoint": endpoint,
        "tag": tag,
        "operation_id": operation_id,
        "skip_deprecated_operations": skip_deprecated_operations,
    }
    for name, value in filtration_arguments.items():
        # Most filters here default to `NOT_SET` rather than `None`. Only explicitly passed values
        # should trigger the warning; otherwise every call would warn.
        if value is not None and value is not NOT_SET:
            warn_filtration_arguments(name)
    filter_set = filter_set_from_components(
        include=True,
        method=method,
        endpoint=endpoint,
        tag=tag,
        operation_id=operation_id,
        skip_deprecated_operations=skip_deprecated_operations,
    )
    return LazySchema(
        fixture_name,
        app=app,
        base_url=base_url,
        filter_set=filter_set,
        validate_schema=validate_schema,
        data_generation_methods=_data_generation_methods,
        generation_config=generation_config,
        output_config=output_config,
        code_sample_style=_code_sample_style,
        rate_limiter=rate_limiter,
        sanitize_output=sanitize_output,
    )
def from_wsgi(
    schema_path: str,
    app: Any,
    *,
    base_url: str | None = None,
    method: Filter | None = None,
    endpoint: Filter | None = None,
    tag: Filter | None = None,
    operation_id: Filter | None = None,
    skip_deprecated_operations: bool | None = None,
    validate_schema: bool = False,
    force_schema_version: str | None = None,
    data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
    generation_config: GenerationConfig | None = None,
    output_config: OutputConfig | None = None,
    code_sample_style: str = CodeSampleStyle.default().name,
    rate_limit: str | None = None,
    sanitize_output: bool = True,
    **kwargs: Any,
) -> BaseOpenAPISchema:
    """Load Open API schema from a WSGI app.

    The schema is requested in-process through a ``werkzeug`` test client. Remaining ``kwargs`` are
    forwarded to the client's ``get`` call.

    :param str schema_path: An in-app relative URL to the schema.
    :param app: A WSGI app instance.
    """
    from werkzeug.test import Client

    from ...transports.responses import WSGIResponse

    require_relative_url(schema_path)
    setup_default_headers(kwargs)
    client = Client(app, WSGIResponse)
    response = load_schema_from_url(lambda: client.get(schema_path, **kwargs))
    return from_file(
        response.data,
        app=app,
        base_url=base_url,
        method=method,
        endpoint=endpoint,
        tag=tag,
        operation_id=operation_id,
        skip_deprecated_operations=skip_deprecated_operations,
        validate_schema=validate_schema,
        force_schema_version=force_schema_version,
        data_generation_methods=data_generation_methods,
        generation_config=generation_config,
        output_config=output_config,
        code_sample_style=code_sample_style,
        location=schema_path,
        rate_limit=rate_limit,
        sanitize_output=sanitize_output,
        __expects_json=_is_json_response(response),
        # Like `from_uri`: a YAML `Content-Type` makes YAML parse failures report syntax details
        __expects_yaml=_is_yaml_response(response),
    )
def get_loader_for_app(app: Any) -> Callable:
    """Pick the schema loader matching ``app``: ASGI, aiohttp (by its class module), or WSGI as the fallback."""
    from ...transports.asgi import is_asgi_app

    if is_asgi_app(app):
        loader = from_asgi
    elif app.__class__.__module__.startswith("aiohttp."):
        loader = from_aiohttp
    else:
        loader = from_wsgi
    return loader
def from_aiohttp(
    schema_path: str,
    app: Any,
    *,
    base_url: str | None = None,
    method: Filter | None = None,
    endpoint: Filter | None = None,
    tag: Filter | None = None,
    operation_id: Filter | None = None,
    skip_deprecated_operations: bool | None = None,
    validate_schema: bool = False,
    force_schema_version: str | None = None,
    data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
    generation_config: GenerationConfig | None = None,
    output_config: OutputConfig | None = None,
    code_sample_style: str = CodeSampleStyle.default().name,
    rate_limit: str | None = None,
    sanitize_output: bool = True,
    **kwargs: Any,
) -> BaseOpenAPISchema:
    """Load Open API schema from an AioHTTP app.

    The app is served locally via ``run_server``, which returns the port. The schema is then fetched
    over HTTP from ``127.0.0.1`` with ``from_uri``. ``app`` itself is not forwarded to ``from_uri``.

    :param str schema_path: An in-app relative URL to the schema.
    :param app: An AioHTTP app instance.
    """
    from ...extra._aiohttp import run_server

    port = run_server(app)
    schema_url = urljoin(f"http://127.0.0.1:{port}/", schema_path)
    return from_uri(
        schema_url,
        base_url=base_url,
        method=method,
        endpoint=endpoint,
        tag=tag,
        operation_id=operation_id,
        skip_deprecated_operations=skip_deprecated_operations,
        validate_schema=validate_schema,
        force_schema_version=force_schema_version,
        data_generation_methods=data_generation_methods,
        generation_config=generation_config,
        output_config=output_config,
        code_sample_style=code_sample_style,
        rate_limit=rate_limit,
        sanitize_output=sanitize_output,
        **kwargs,
    )
def from_asgi(
    schema_path: str,
    app: Any,
    *,
    base_url: str | None = None,
    method: Filter | None = None,
    endpoint: Filter | None = None,
    tag: Filter | None = None,
    operation_id: Filter | None = None,
    skip_deprecated_operations: bool | None = None,
    validate_schema: bool = False,
    force_schema_version: str | None = None,
    data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
    generation_config: GenerationConfig | None = None,
    output_config: OutputConfig | None = None,
    code_sample_style: str = CodeSampleStyle.default().name,
    rate_limit: str | None = None,
    sanitize_output: bool = True,
    **kwargs: Any,
) -> BaseOpenAPISchema:
    """Load Open API schema from an ASGI app.

    The schema is requested in-process through ``starlette_testclient.TestClient``. Remaining
    ``kwargs`` are forwarded to the client's ``get`` call.

    :param str schema_path: An in-app relative URL to the schema.
    :param app: An ASGI app instance.
    """
    from starlette_testclient import TestClient as ASGIClient

    require_relative_url(schema_path)
    setup_default_headers(kwargs)
    client = ASGIClient(app)
    response = load_schema_from_url(lambda: client.get(schema_path, **kwargs))
    return from_file(
        response.text,
        app=app,
        base_url=base_url,
        method=method,
        endpoint=endpoint,
        tag=tag,
        operation_id=operation_id,
        skip_deprecated_operations=skip_deprecated_operations,
        validate_schema=validate_schema,
        force_schema_version=force_schema_version,
        data_generation_methods=data_generation_methods,
        generation_config=generation_config,
        output_config=output_config,
        code_sample_style=code_sample_style,
        location=schema_path,
        rate_limit=rate_limit,
        sanitize_output=sanitize_output,
        __expects_json=_is_json_response(response),
        # Like `from_uri`: a YAML `Content-Type` makes YAML parse failures report syntax details
        __expects_yaml=_is_yaml_response(response),
    )