| #!/usr/bin/env python3 |
| """Utilities for CI. |
| |
| This dynamically prepares a list of routines that had a source file change based on |
| git history. |
| """ |
| |
| import json |
| import os |
| import pprint |
| import re |
| import subprocess as sp |
| import sys |
| from dataclasses import dataclass |
| from functools import cache |
| from glob import glob |
| from inspect import cleandoc |
| from os import getenv |
| from pathlib import Path |
| from typing import TypedDict, Self |
| |
| USAGE = cleandoc( |
| """ |
| usage: |
| |
| ./ci/ci-util.py <COMMAND> [flags] |
| |
| COMMAND: |
| generate-matrix |
            Calculate a matrix of which functions had source changes and print it
            as a JSON object.
| |
| locate-baseline [--download] [--extract] [--tag TAG] |
            Locate the most recent benchmark baseline available in CI and, if the
            corresponding flags are passed, download and extract it. This never
            exits with a nonzero status, even if downloading fails.
| |
| `--tag` can be specified to look for artifacts with a specific tag, such as |
| for a specific architecture. |
| |
| Note that `--extract` will overwrite files in `iai-home`. |
| |
| handle-bench-regressions PR_NUMBER |
| Exit with success if the pull request contains a line starting with |
| `ci: allow-regressions`, indicating that regressions in benchmarks should |
| be accepted. Otherwise, exit 1. |
| """ |
| ) |
| |
| REPO_ROOT = Path(__file__).parent.parent |
| GIT = ["git", "-C", REPO_ROOT] |
| DEFAULT_BRANCH = "master" |
| WORKFLOW_NAME = "CI" # Workflow that generates the benchmark artifacts |
| ARTIFACT_PREFIX = "baseline-icount*" |
| |
# Don't run exhaustive tests if these files change, even if they contain a function
| # definition. |
| IGNORE_FILES = [ |
| "libm/src/math/support/", |
| "libm/src/libm_helper.rs", |
| "libm/src/math/arch/intrinsics.rs", |
| ] |
| |
| # libm PR CI takes a long time and doesn't need to run unless relevant files have been |
| # changed. Anything matching this regex pattern will trigger a run. |
| TRIGGER_LIBM_CI_FILE_PAT = ".*(libm|musl).*" |
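# e.g. `libm/src/math/sin.rs` or `crates/musl-math-sys/build.rs` (illustrative
# paths) would match and trigger a run.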
| |
| TYPES = ["f16", "f32", "f64", "f128"] |
| |
| |
| def eprint(*args, **kwargs): |
| """Print to stderr.""" |
| print(*args, file=sys.stderr, **kwargs) |
| |
| |
| @dataclass(init=False) |
| class PrCfg: |
| """Directives that we allow in the commit body to control test behavior. |
| |
| These are of the form `ci: foo`, at the start of a line. |
| """ |
| |
    # Skip benchmark regression checks.
| allow_regressions: bool = False |
| # Don't run extensive tests |
| skip_extensive: bool = False |
| |
| # Allow running a large number of extensive tests. If not set, this script |
| # will error out if a threshold is exceeded in order to avoid accidentally |
| # spending huge amounts of CI time. |
| allow_many_extensive: bool = False |
| |
    # Maximum number of extensive tests to run without `allow-many-extensive`.
| MANY_EXTENSIVE_THRESHOLD: int = 20 |
| |
| # Run tests for `libm` that may otherwise be skipped due to no changed files. |
| always_test_libm: bool = False |
| |
| # String values of directive names |
| DIR_ALLOW_REGRESSIONS: str = "allow-regressions" |
| DIR_SKIP_EXTENSIVE: str = "skip-extensive" |
| DIR_ALLOW_MANY_EXTENSIVE: str = "allow-many-extensive" |
| DIR_TEST_LIBM: str = "test-libm" |
| |
| def __init__(self, body: str): |
| directives = re.finditer(r"^\s*ci:\s*(?P<dir_name>\S*)", body, re.MULTILINE) |
        for directive in directives:
            name = directive.group("dir_name")
| if name == self.DIR_ALLOW_REGRESSIONS: |
| self.allow_regressions = True |
| elif name == self.DIR_SKIP_EXTENSIVE: |
| self.skip_extensive = True |
| elif name == self.DIR_ALLOW_MANY_EXTENSIVE: |
| self.allow_many_extensive = True |
| elif name == self.DIR_TEST_LIBM: |
| self.always_test_libm = True |
| else: |
| eprint(f"Found unexpected directive `{name}`") |
| exit(1) |
| |
| pprint.pp(self) |
| |
| |
| @dataclass |
| class PrInfo: |
| """GitHub response for PR query""" |
| |
| body: str |
| commits: list[str] |
| created_at: str |
| number: int |
| cfg: PrCfg |
| |
| @classmethod |
| def from_env(cls) -> Self | None: |
| """Create a PR object from the PR_NUMBER environment if set, `None` otherwise.""" |
| pr_env = os.environ.get("PR_NUMBER") |
| if pr_env is not None and len(pr_env) > 0: |
| return cls.from_pr(pr_env) |
| |
| return None |
| |
| @classmethod |
| @cache # Cache so we don't print info messages multiple times |
| def from_pr(cls, pr_number: int | str) -> Self: |
| """For a given PR number, query the body and commit list.""" |
| pr_info = sp.check_output( |
| [ |
| "gh", |
| "pr", |
| "view", |
| str(pr_number), |
| "--json=number,commits,body,createdAt", |
                # Flatten the commit list to only hashes and rename a key to snake case
| "--jq=.commits |= map(.oid) | .created_at = .createdAt | del(.createdAt)", |
| ], |
| text=True, |
| ) |
| pr_json = json.loads(pr_info) |
| eprint("PR info:", json.dumps(pr_json, indent=4)) |
        return cls(**pr_json, cfg=PrCfg(pr_json["body"]))
| |
| |
| class FunctionDef(TypedDict): |
| """Type for an entry in `function-definitions.json`""" |
| |
| sources: list[str] |
| type: str |
| |
| |
| class Context: |
| gh_ref: str | None |
| changed: list[Path] |
| defs: dict[str, FunctionDef] |
| |
| def __init__(self) -> None: |
| self.gh_ref = getenv("GITHUB_REF") |
| self.changed = [] |
| self._init_change_list() |
| |
| with open(REPO_ROOT.joinpath("etc/function-definitions.json")) as f: |
| defs = json.load(f) |
| |
| defs.pop("__comment", None) |
| self.defs = defs |
| |
| def _init_change_list(self): |
| """Create a list of files that have been changed. This uses GITHUB_REF if |
| available, otherwise a diff between `HEAD` and `master`. |
| """ |
| |
| # For pull requests, GitHub creates a ref `refs/pull/1234/merge` (1234 being |
| # the PR number), and sets this as `GITHUB_REF`. |
| ref = self.gh_ref |
| eprint(f"using ref `{ref}`") |
| if not self.is_pr(): |
| # If the ref is not for `merge` then we are not in PR CI |
| eprint("No diff available for ref") |
| return |
| |
| # The ref is for a dummy merge commit. We can extract the merge base by |
| # inspecting all parents (`^@`). |
| merge_sha = sp.check_output( |
| GIT + ["show-ref", "--hash", ref], text=True |
| ).strip() |
| merge_log = sp.check_output(GIT + ["log", "-1", merge_sha], text=True) |
| eprint(f"Merge:\n{merge_log}\n") |
| |
| parents = ( |
| sp.check_output(GIT + ["rev-parse", f"{merge_sha}^@"], text=True) |
| .strip() |
| .splitlines() |
| ) |
| assert len(parents) == 2, f"expected two-parent merge but got:\n{parents}" |
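        # For GitHub's synthetic merge commits, the first parent is the tip of the
        # target branch and the second is the PR head.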
| base = parents[0].strip() |
| incoming = parents[1].strip() |
| |
| eprint(f"base: {base}, incoming: {incoming}") |
| textlist = sp.check_output( |
| GIT + ["diff", base, incoming, "--name-only"], text=True |
| ) |
| self.changed = [Path(p) for p in textlist.splitlines()] |
| |
| def is_pr(self) -> bool: |
| """Check if we are looking at a PR rather than a push.""" |
| return self.gh_ref is not None and "merge" in self.gh_ref |
| |
| @staticmethod |
| def _ignore_file(fname: str) -> bool: |
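        # e.g. (illustrative path) anything under `libm/src/math/support/` is ignored.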
| return any(fname.startswith(pfx) for pfx in IGNORE_FILES) |
| |
| def changed_routines(self) -> dict[str, list[str]]: |
| """Create a list of routines for which one or more files have been updated, |
| separated by type. |
| """ |
| routines = set() |
| for name, meta in self.defs.items(): |
| # Don't update if changes to the file should be ignored |
| sources = (f for f in meta["sources"] if not self._ignore_file(f)) |
| |
| # Select changed files |
| changed = [f for f in sources if Path(f) in self.changed] |
| |
| if len(changed) > 0: |
| eprint(f"changed files for {name}: {changed}") |
| routines.add(name) |
| |
| ret: dict[str, list[str]] = {} |
| for r in sorted(routines): |
| ret.setdefault(self.defs[r]["type"], []).append(r) |
| |
| return ret |
| |
| def may_skip_libm_ci(self) -> bool: |
| """If this is a PR and no libm files were changed, allow skipping libm |
| jobs.""" |
| |
| # Always run on merge CI |
| if not self.is_pr(): |
| return False |
| |
| pr = PrInfo.from_env() |
| assert pr is not None, "Is a PR but couldn't load PrInfo" |
| |
| # Allow opting in to libm tests |
| if pr.cfg.always_test_libm: |
| return False |
| |
| # By default, run if there are any changed files matching the pattern |
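        # e.g. (illustrative paths) a PR touching only `crates/util/README.md` may
        # skip, while one touching `libm/src/math/sin.rs` must run.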
| return all(not re.match(TRIGGER_LIBM_CI_FILE_PAT, str(f)) for f in self.changed) |
| |
| def emit_workflow_output(self): |
| """Create a JSON object a list items for each type's changed files, if any |
| did change, and the routines that were affected by the change. |
| """ |
| |
| skip_tests = False |
| error_on_many_tests = False |
| |
| pr = PrInfo.from_env() |
| if pr is not None: |
| skip_tests = pr.cfg.skip_extensive |
| error_on_many_tests = not pr.cfg.allow_many_extensive |
| |
| if skip_tests: |
| eprint("Skipping all extensive tests") |
| |
| changed = self.changed_routines() |
| matrix = [] |
| total_to_test = 0 |
| |
| # Figure out which extensive tests need to run |
| for ty in TYPES: |
| ty_changed = changed.get(ty, []) |
| ty_to_test = [] if skip_tests else ty_changed |
| total_to_test += len(ty_to_test) |
| |
| item = { |
| "ty": ty, |
| "changed": ",".join(ty_changed), |
| "to_test": ",".join(ty_to_test), |
| } |
| |
| matrix.append(item) |
| |
| ext_matrix = json.dumps({"extensive_matrix": matrix}, separators=(",", ":")) |
| may_skip = str(self.may_skip_libm_ci()).lower() |
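        # Example stdout (contents illustrative); the workflow presumably maps
        # these `key=value` lines into `$GITHUB_OUTPUT`:
        #   extensive_matrix={"extensive_matrix":[{"ty":"f16","changed":"","to_test":""},...]}
        #   may_skip_libm_ci=true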
| print(f"extensive_matrix={ext_matrix}") |
| print(f"may_skip_libm_ci={may_skip}") |
| eprint(f"total extensive tests: {total_to_test}") |
| |
| if error_on_many_tests and total_to_test > PrCfg.MANY_EXTENSIVE_THRESHOLD: |
| eprint( |
| f"More than {PrCfg.MANY_EXTENSIVE_THRESHOLD} tests would be run; add" |
| f" `{PrCfg.DIR_ALLOW_MANY_EXTENSIVE}` to the PR body if this is" |
| " intentional. If this is refactoring that happens to touch a lot of" |
| f" files, `{PrCfg.DIR_SKIP_EXTENSIVE}` can be used instead." |
| ) |
| exit(1) |
| |
| |
| def locate_baseline(flags: list[str]) -> None: |
| """Find the most recent baseline from CI, download it if specified. |
| |
| This returns rather than erroring, even if the `gh` commands fail. This is to avoid |
| erroring in CI if the baseline is unavailable (artifact time limit exceeded, first |
    run on the branch, etc.).
| """ |
| |
| download = False |
| extract = False |
| tag = "" |
| |
| while len(flags) > 0: |
| match flags[0]: |
| case "--download": |
| download = True |
| case "--extract": |
| extract = True |
| case "--tag": |
| tag = flags[1] |
| flags = flags[1:] |
| case _: |
| eprint(USAGE) |
| exit(1) |
| flags = flags[1:] |
| |
| if extract and not download: |
| eprint("cannot extract without downloading") |
| exit(1) |
| |
| try: |
| # Locate the most recent job to complete with success on our branch |
| latest_job = sp.check_output( |
| [ |
| "gh", |
| "run", |
| "list", |
| "--status=success", |
| f"--branch={DEFAULT_BRANCH}", |
| "--json=databaseId,url,headSha,conclusion,createdAt," |
| "status,workflowDatabaseId,workflowName", |
                # Return the first array element matching our workflow name. NB: we
                # cannot just use `--limit=1` because jq filtering happens after
                # limiting, and we cannot use `--workflow` because GH gets confused
                # by the different file names in history.
| f'--jq=[.[] | select(.workflowName == "{WORKFLOW_NAME}")][0]', |
| ], |
| text=True, |
| ) |
| except sp.CalledProcessError as e: |
| eprint(f"failed to run github command: {e}") |
| return |
| |
| try: |
| latest = json.loads(latest_job) |
| eprint("latest job: ", json.dumps(latest, indent=4)) |
| except json.JSONDecodeError as e: |
| eprint(f"failed to decode json '{latest_job}', {e}") |
| return |
| |
| if not download: |
| eprint("--download not specified, returning") |
| return |
| |
| job_id = latest.get("databaseId") |
| if job_id is None: |
| eprint("skipping download step") |
| return |
| |
    artifact_glob = f"{ARTIFACT_PREFIX}{f'-{tag}' if tag else ''}*"
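    # e.g. with `--tag x86_64` (illustrative), the glob is `baseline-icount*-x86_64*`.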
| |
| sp.run( |
| ["gh", "run", "download", str(job_id), f"--pattern={artifact_glob}"], |
| check=False, |
| ) |
| |
| if not extract: |
| eprint("skipping extraction step") |
| return |
| |
| # Find the baseline with the most recent timestamp. GH downloads the files to e.g. |
| # `some-dirname/some-dirname.tar.xz`, so just glob the whole thing together. |
| candidate_baselines = glob(f"{artifact_glob}/{artifact_glob}") |
| if len(candidate_baselines) == 0: |
| eprint("no possible baseline directories found") |
| return |
| |
| candidate_baselines.sort(reverse=True) |
| baseline_archive = candidate_baselines[0] |
| eprint(f"extracting {baseline_archive}") |
| sp.run(["tar", "xJvf", baseline_archive], check=True) |
| eprint("baseline extracted successfully") |
| |
| |
| def handle_bench_regressions(args: list[str]): |
| """Exit with error unless the PR message contains an ignore directive.""" |
| |
| match args: |
        case [pr_number]:
            pass
| case _: |
| eprint(USAGE) |
| exit(1) |
| |
| pr = PrInfo.from_pr(pr_number) |
| if pr.cfg.allow_regressions: |
| eprint("PR allows regressions") |
| return |
| |
| eprint("Regressions were found; benchmark failed") |
| exit(1) |
| |
| |
| def main(): |
| match sys.argv[1:]: |
| case ["generate-matrix"]: |
| ctx = Context() |
| ctx.emit_workflow_output() |
| case ["locate-baseline", *flags]: |
| locate_baseline(flags) |
| case ["handle-bench-regressions", *args]: |
| handle_bench_regressions(args) |
| case ["--help" | "-h"]: |
| print(USAGE) |
| exit() |
| case _: |
| eprint(USAGE) |
| exit(1) |
| |
| |
| if __name__ == "__main__": |
| main() |