|  | #!/usr/bin/env python3 | 
|  | """Create a text file listing all public API. This can be used to ensure that all | 
|  | functions are covered by our macros. | 
|  |  | 
|  | This file additionally does tidy-esque checks that all functions are listed where | 
|  | needed, or that lists are sorted. | 
|  | """ | 
|  |  | 
|  | import difflib | 
|  | import json | 
|  | import re | 
|  | import subprocess as sp | 
|  | import sys | 
|  | from dataclasses import dataclass | 
|  | from glob import glob | 
|  | from pathlib import Path | 
|  | from typing import Any, Callable, TypeAlias | 
|  |  | 
|  | SELF_PATH = Path(__file__) | 
|  | ETC_DIR = SELF_PATH.parent | 
|  | ROOT_DIR = ETC_DIR.parent | 
|  |  | 
|  | # These files do not trigger a retest. | 
|  | IGNORED_SOURCES = ["libm/src/libm_helper.rs", "libm/src/math/support/float_traits.rs"] | 
|  |  | 
|  | IndexTy: TypeAlias = dict[str, dict[str, Any]] | 
|  | """Type of the `index` item in rustdoc's JSON output""" | 
|  |  | 
|  |  | 
def eprint(*args, **kwargs):
    """Like `print`, but always writes to standard error.

    All positional and keyword arguments are forwarded to `print` unchanged.
    """
    print(*args, **kwargs, file=sys.stderr)
|  |  | 
|  |  | 
@dataclass
class Crate:
    """Representation of public interfaces and function definition locations in
    `libm`.

    Constructing an instance runs `rustdoc` (see `get_rustdoc_json`) and fills in
    every field from its JSON output.
    """

    public_functions: list[str]
    """Sorted list of all public functions."""
    defs: dict[str, list[str]]
    """Map from `name->[source files]` to find all places that define a public
    function. We track this to know which tests need to be rerun when specific files
    get updated.
    """
    types: dict[str, str]
    """Map from `name->type`."""

    def __init__(self) -> None:
        """Query rustdoc and populate `public_functions`, `defs` and `types`."""
        self.public_functions = []
        self.defs = {}
        self.types = {}

        index: IndexTy = self.get_rustdoc_json()["index"]
        # Order matters: the definition and type maps are keyed by the function list.
        self._init_function_list(index)
        self._init_defs(index)
        self._init_types()

    @staticmethod
    def get_rustdoc_json() -> dict[Any, Any]:
        """Get rustdoc's JSON output for the `libm` crate.

        `rustdoc` is run from `ROOT_DIR` with the `f16_enabled` and `f128_enabled`
        cfgs set and writes its JSON to stdout (`-o-`), which is parsed and returned.

        Raises `subprocess.CalledProcessError` if `rustdoc` exits with a nonzero
        status.
        """
        raw = sp.check_output(
            [
                "rustdoc",
                "libm/src/lib.rs",
                "--edition=2021",
                "--document-private-items",
                "--output-format=json",
                "--cfg=f16_enabled",
                "--cfg=f128_enabled",
                "-Zunstable-options",
                "-o-",
            ],
            cwd=ROOT_DIR,
            text=True,
        )
        return json.loads(raw)

    def _init_function_list(self, index: IndexTy) -> None:
        """Get a list of public functions from rustdoc JSON output.

        Note that this only finds functions that are reexported in `lib.rs`, this will
        need to be adjusted if we need to account for functions that are defined there, or
        glob reexports in other locations.

        The names are appended to `self.public_functions`, which is then sorted.
        """
        reexport_files = ("libm/src/math/mod.rs", "libm/src/lib.rs")

        # IDs targeted by public `use` items in `lib.rs` or `mod math`. The conditions
        # short-circuit in this order so `span` is only read for public `use` items.
        reexported_ids = [
            item["inner"]["use"]["id"]
            for item in index.values()
            if item["visibility"] == "public"
            and "use" in item["inner"]
            and item["span"]["filename"] in reexport_files
        ]

        # Keep only the reexported items that are functions
        for item_id in reexported_ids:
            srcitem = index.get(str(item_id))
            # Not in the index: external crate
            if srcitem is None:
                continue

            # Skip if not a function
            if "function" not in srcitem["inner"]:
                continue

            self.public_functions.append(srcitem["name"])

        self.public_functions.sort()

    def _init_defs(self, index: IndexTy) -> None:
        """Build `self.defs`, the map from each public function to its source files.

        Sources are gathered from three places: the rustdoc index, a textual scan of
        the `arch` module, and any "generic" sources of the function's base name.
        Entries in `IGNORED_SOURCES` are removed at the end.
        """
        defs: dict[str, set[str]] = {name: set() for name in self.public_functions}

        # `defs` doubles as a constant-time membership test for public function names.
        for item in index.values():
            if "function" in item["inner"] and item["name"] in defs:
                defs[item["name"]].add(item["span"]["filename"])

        # A lot of the `arch` module is often configured out so doesn't show up in docs. Use
        # string matching as a fallback. The whole identifier after `fn ` is captured so
        # that e.g. `fn ceilf16` is not mistaken for a definition of `ceil`.
        fn_name_pat = re.compile(r"\bfn (\w+)")
        for fname in glob(
            "libm/src/math/arch/**/*.rs", root_dir=ROOT_DIR, recursive=True
        ):
            contents = ROOT_DIR.joinpath(fname).read_text()

            for name in fn_name_pat.findall(contents):
                if name in defs:
                    defs[name].add(fname)

        for name, sources in defs.items():
            # Every typed variant also depends on the generic sources of its base name.
            # The list is materialized first because `sources` may be the very set
            # being read (when `name` is its own base name).
            base_sources = defs[base_name(name)[0]]
            sources.update([src for src in base_sources if "generic" in src])

            sources.difference_update(IGNORED_SOURCES)

        # Sort each set so the output is deterministic
        self.defs = {name: sorted(sources) for (name, sources) in defs.items()}

    def _init_types(self) -> None:
        """Build `self.types`, mapping each public function to the type from `base_name`."""
        self.types = {name: base_name(name)[1] for name in self.public_functions}

    @staticmethod
    def _check_or_write(out_file: Path, output: str, name: str, check: bool) -> None:
        """Write `output` to `out_file`, or if `check` is set, compare it against the
        file's current contents via `diff_and_exit` without modifying anything.
        """
        if check:
            diff_and_exit(out_file.read_text(), output, name)
        else:
            out_file.write_text(output)

    def write_function_list(self, check: bool) -> None:
        """Collect the list of public functions to a simple text file.

        With `check` set, the existing `function-list.txt` is compared against the
        expected contents instead of being overwritten.
        """
        header = "# autogenerated by update-api-list.py\n"
        output = header + "".join(f"{name}\n" for name in self.public_functions)

        out_file = ETC_DIR.joinpath("function-list.txt")
        self._check_or_write(out_file, output, "function list", check)

    def write_function_defs(self, check: bool) -> None:
        """Collect the list of information about public functions to a JSON file.

        With `check` set, the existing `function-definitions.json` is compared against
        the expected contents instead of being overwritten.
        """
        comment = (
            "Autogenerated by update-api-list.py. "
            "List of files that define a function with a given name. "
            "This file is checked in to make it obvious if refactoring breaks things"
        )

        d = {"__comment": comment}
        d |= {
            name: {"sources": self.defs[name], "type": self.types[name]}
            for name in self.public_functions
        }

        out_file = ETC_DIR.joinpath("function-definitions.json")
        output = json.dumps(d, indent=4) + "\n"
        self._check_or_write(out_file, output, "source list", check)

    def tidy_lists(self) -> None:
        """In each file, check annotations indicating blocks of code should be sorted or should
        include all public API.

        The files checked are those reported by `git ls-files` in `ROOT_DIR`.
        """

        flist = sp.check_output(["git", "ls-files"], cwd=ROOT_DIR, text=True)

        for path in flist.splitlines():
            fpath = ROOT_DIR.joinpath(path)
            # Directories cannot be read as text, and this script contains the marker
            # strings itself, so both are skipped.
            if fpath.is_dir() or fpath == SELF_PATH:
                continue

            lines = fpath.read_text().splitlines()

            validate_delimited_block(
                fpath,
                lines,
                "verify-sorted-start",
                "verify-sorted-end",
                ensure_sorted,
            )

            validate_delimited_block(
                fpath,
                lines,
                "verify-apilist-start",
                "verify-apilist-end",
                self.ensure_contains_api,
            )

    def ensure_contains_api(self, fpath: Path, line_num: int, lines: list[str]) -> None:
        """Given a list of strings, ensure that each public function we have is named
        somewhere.

        `fpath` and `line_num` are only used to report where the block is. If any
        function is missing, the missing names are printed to stderr and the process
        exits with status 1.
        """
        not_found = []
        for func in self.public_functions:
            # The function name may be on its own or somewhere in a snake case string.
            pat = re.compile(rf"(\b|_){re.escape(func)}(\b|_)")
            if not any(pat.search(line) for line in lines):
                not_found.append(func)

        if not not_found:
            return

        relpath = fpath.relative_to(ROOT_DIR)
        eprint(f"functions not found at {relpath}:{line_num}: {not_found}")
        sys.exit(1)
|  |  | 
|  |  | 
def validate_delimited_block(
    fpath: Path,
    lines: list[str],
    start: str,
    end: str,
    validate: Callable[[Path, int, list[str]], None],
) -> None:
    """Identify blocks of code wrapped within `start` and `end`, collect their contents
    to a list of strings, and call `validate` for each of those lists.

    A line counts as a delimiter if it contains the marker anywhere. `validate` is
    called with `fpath`, the 1-based line number of the `start` marker, and the lines
    strictly between the two markers.

    If an `end` appears without a preceding `start`, or a `start` is never closed, an
    error is printed to stderr and the process exits with status 1.
    """
    block_lines: list[str] = []
    # Line number of the currently open `start` marker, or `None` outside a block.
    block_start_line: int | None = None

    for line_num, line in enumerate(lines, start=1):
        if start in line:
            block_start_line = line_num
            continue

        if end in line:
            if block_start_line is None:
                # The display path is only computed when there is an error to report.
                relpath = fpath.relative_to(ROOT_DIR)
                eprint(f"`{end}` without `{start}` at {relpath}:{line_num}")
                sys.exit(1)

            validate(fpath, block_start_line, block_lines)
            block_lines = []
            block_start_line = None
            continue

        if block_start_line is not None:
            block_lines.append(line)

    if block_start_line is not None:
        relpath = fpath.relative_to(ROOT_DIR)
        eprint(f"`{start}` without `{end}` at {relpath}:{block_start_line}")
        sys.exit(1)
|  |  | 
|  |  | 
def ensure_sorted(fpath: Path, block_start_line: int, lines: list[str]) -> None:
    """Check that `lines` are already in sorted order.

    The lines are compared against their sorted form with `diff_and_exit`, labelled
    with the block's path (relative to `ROOT_DIR`) and starting line number.
    """
    location = fpath.relative_to(ROOT_DIR)
    actual = "\n".join(lines)
    expected = "\n".join(sorted(lines))
    label = f"sorted block at {location}:{block_start_line}"
    diff_and_exit(actual, expected, label)
|  |  | 
|  |  | 
def diff_and_exit(actual: str, expected: str, name: str) -> None:
    """If the two strings are different, print a diff between them and then exit
    with an error.

    When they are equal, a success message mentioning `name` is printed and the
    function returns normally. Otherwise a unified diff and a "mismatched" message
    are written to stdout and the process exits with status 1.
    """
    if actual == expected:
        print(f"{name} output matches expected; success")
        return

    # `unified_diff` expects lines that carry their own trailing newline.
    a = [f"{line}\n" for line in actual.splitlines()]
    b = [f"{line}\n" for line in expected.splitlines()]

    sys.stdout.writelines(difflib.unified_diff(a, b, "actual", "expected"))
    print(f"mismatched {name}")
    # `sys.exit` rather than the `exit` helper, which the `site` module only
    # provides for interactive use.
    sys.exit(1)
|  |  | 
|  |  | 
def base_name(name: str) -> tuple[str, str]:
    """Return the basename and type from a full function name. Keep in sync with Rust's
    `fn base_name`.

    The type is one of "f16", "f32", "f64" or "f128". Names without a recognized
    suffix are treated as the `f64` version, e.g. "sin" gives ("sin", "f64") and
    "sinf" gives ("sin", "f32").
    """
    # Names where stripping a trailing "f" would give the wrong answer, either because
    # the base name itself ends in "f" or because the suffix is not at the end.
    known_mappings = {
        "erff": ("erf", "f32"),
        "erf": ("erf", "f64"),
        "modff": ("modf", "f32"),
        "modf": ("modf", "f64"),
        "lgammaf_r": ("lgamma_r", "f32"),
        "lgamma_r": ("lgamma_r", "f64"),
    }

    found = known_mappings.get(name)
    if found is not None:
        return found

    # `removesuffix` drops exactly one occurrence of the suffix. `rstrip` would treat
    # its argument as a character set and also eat trailing digits or "f"s that belong
    # to the base name (e.g. turning "exp2f128" into "exp").
    for suffix, ty in (("f", "f32"), ("f16", "f16"), ("f128", "f128")):
        if name.endswith(suffix):
            return (name.removesuffix(suffix), ty)

    return (name, "f64")
|  |  | 
|  |  | 
def ensure_updated_list(check: bool) -> None:
    """Regenerate the function list and the JSON definitions, or with `check` set,
    verify that both are already up to date. The tidy checks run in either mode.
    """
    libm = Crate()

    for emit in (libm.write_function_list, libm.write_function_defs):
        emit(check)

    libm.tidy_lists()
|  |  | 
|  |  | 
|  | def main(): | 
|  | """By default overwrite the file. If `--check` is passed, print a diff instead and | 
|  | error if the files are different. | 
|  | """ | 
|  | match sys.argv: | 
|  | case [_]: | 
|  | ensure_updated_list(False) | 
|  | case [_, "--check"]: | 
|  | ensure_updated_list(True) | 
|  | case _: | 
|  | print("unrecognized arguments") | 
|  | exit(1) | 
|  |  | 
|  |  | 
# Only run the update/check when executed as a script, not when imported.
if __name__ == "__main__":
    main()