#!/usr/bin/python3
# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

import argparse
import itertools
import os
import pathlib
import re
import subprocess
import sys
import typing


def positive_integer(decstr: str) -> int:
    """Convert a decimal string into an integer that is at least 1.

    The conversion itself is delegated to int(), so anything int() rejects
    raises a ValueError as well.

    >>> positive_integer("-1")
    Traceback (most recent call last):
        ...
    ValueError: integer must be positive
    >>> positive_integer("0")
    Traceback (most recent call last):
        ...
    ValueError: integer must be positive
    >>> positive_integer("1")
    1
    """
    number = int(decstr)  # a malformed number raises ValueError here
    if number >= 1:
        return number
    raise ValueError("integer must be positive")


def parse_size(expression: str) -> int:
    """Parse an expression representing a data size with an optional unit
    suffix into an integer number of bytes.

    The expression is matched case-insensitively. A trailing "b" is ignored
    and the remaining last character may be one of the binary unit suffixes
    k, m, g, t, p or e (2**10 up to 2**60). The numeric part may be
    fractional; the product is truncated towards zero.

    Raises a ValueError if the expression is malformed, not finite or
    results in a value smaller than 1.

    >>> parse_size("5")
    5
    >>> parse_size("4KB")
    4096
    >>> parse_size("0.9g")
    966367641
    >>> parse_size("1p") == 2**50
    True
    >>> parse_size("1e") == 2**60
    True
    >>> parse_size("-1")
    Traceback (most recent call last):
        ...
    ValueError: number must be positive
    >>> parse_size("inf")
    Traceback (most recent call last):
        ...
    ValueError: number must be finite
    """
    expression = expression.lower()
    if expression.endswith("b"):
        expression = expression[:-1]
    suffixes = {
        "k": 2**10,
        "m": 2**20,
        "g": 2**30,
        "t": 2**40,
        "p": 2**50,
        "e": 2**60,
    }
    factor = 1
    # Slicing (rather than indexing) keeps an empty expression from raising
    # an IndexError; float() rejects it with a ValueError below.
    suffix = expression[-1:]
    if suffix in suffixes:
        factor = suffixes[suffix]
        expression = expression[:-1]
    fval = float(expression)  # propagate ValueError
    try:
        value = int(fval * factor)  # NaN raises ValueError
    except OverflowError as err:
        # int() of an infinity raises OverflowError, which argparse would not
        # report as a usage error. Stick to the documented ValueError.
        raise ValueError("number must be finite") from err
    if value < 1:
        raise ValueError("number must be positive")
    return value


def guess_python() -> int | None:
    """Estimate the number of processors using Python's standard library
    functions.

    >>> guess_python() > 0
    True
    """
    process_cpu_count = getattr(os, "process_cpu_count", None)
    if process_cpu_count is not None:
        count = process_cpu_count()
        assert isinstance(count, int)  # make mypy happy
        return count
    # Fallback for Python < 3.13.
    return len(os.sched_getaffinity(0))


def guess_nproc() -> int:
    """Ask the external nproc program for the processor count.

    The program's standard output is decoded as ASCII and must hold a
    positive integer; otherwise a ValueError is raised. Failures to run the
    program surface as the exceptions of subprocess.check_output.

    >>> guess_nproc() > 0
    True
    """
    nproc_output = subprocess.check_output(["nproc"], encoding="ascii")
    return positive_integer(nproc_output)


def guess_cores() -> int | None:
    """Estimate the number of cores (not SMT threads).  This is done by
    counting the number of distinct "core id" values found in sysfs.
    """
    cpus = set(
        int(fn.read_text())
        for fn in pathlib.Path("/sys/devices/system/cpu").glob(
            "cpu*/topology/core_id"
        )
    )
    if not cpus:
        return None
    return len(cpus)


def guess_deb_build_parallel(
    environ: typing.Mapping[str, str] = os.environ
) -> int | None:
    """Parse a possible parallel= assignment in a DEB_BUILD_OPTIONS environment
    variable.

    >>> guess_deb_build_parallel({})
    >>> guess_deb_build_parallel({"DEB_BUILD_OPTIONS": "nocheck parallel=3"})
    3
    """
    try:
        options = environ["DEB_BUILD_OPTIONS"]
    except KeyError:
        return None
    for option in options.split():
        if option.startswith("parallel="):
            option = option.removeprefix("parallel=")
            try:
                return positive_integer(option)
            except ValueError:
                pass
    return None


def guess_from_environment(
    variable: str, environ: typing.Mapping[str, str] = os.environ
) -> int | None:
    """Look up the given environment variable and interpret it as a positive
    integer. Returns None if it is unset or not a positive integer.

    >>> guess_from_environment("CPUS", {"CPUS": "4"})
    4
    >>> guess_from_environment("CPUS", {"other": "3"})
    """
    try:
        count = positive_integer(environ[variable])
    except (KeyError, ValueError):
        # Unset variable or unusable value: there is no guess to offer.
        count = None
    return count


def guess_memavailable() -> int:
    """Return the MemAvailable figure of /proc/meminfo converted to bytes.

    The figure is multiplied by 1024 after dropping a trailing "kB". Raises
    a RuntimeError if the file has no MemAvailable line and a ValueError if
    the figure is not a positive integer.
    """
    key = "MemAvailable:"
    with open("/proc/meminfo", encoding="ascii") as meminfo:
        for entry in meminfo:
            if not entry.startswith(key):
                continue
            kilobytes = entry[len(key):].strip().removesuffix("kB")
            return 1024 * positive_integer(kilobytes)
    raise RuntimeError("no MemAvailable line found in /proc/meminfo")


def guess_cgroup_memory() -> int | None:
    """Determine the tightest memory limit that applies to this process via
    its cgroup or any of the cgroup's ancestors.

    Both "memory.max" and "memory.high" are consulted at every level below
    /sys/fs/cgroup. Files that do not exist or do not hold a positive integer
    are ignored. Returns None if no limit was found.
    """
    cgroup_text = pathlib.Path("/proc/self/cgroup").read_text(
        encoding=sys.getfilesystemencoding()
    )
    # The group path is whatever follows the second colon.
    group_path = cgroup_text.strip().split(":", 2)[2]
    own_group = pathlib.PurePath(group_path).relative_to("/")
    cgroup_root = pathlib.Path("/sys/fs/cgroup")

    limits: list[int] = []
    for group in (own_group, *own_group.parents):
        for limit_name in ("memory.max", "memory.high"):
            limit_file = cgroup_root / group / limit_name
            try:
                limits.append(
                    positive_integer(limit_file.read_text(encoding="ascii"))
                )
            except (FileNotFoundError, ValueError):
                continue
    return min(limits, default=None)


def parse_required_memory(expression: str) -> list[int]:
    """Turn a comma-separated list of size expressions into a list of
    integers. An empty item repeats the value before it; an empty first item
    raises a ValueError, as does any item that is not a valid size.

    >>> parse_required_memory("1k,9,,1")
    [1024, 9, 9, 1]
    """
    sizes: list[int] = []
    for item in expression.split(","):
        if item:
            sizes.append(parse_size(item))
        elif sizes:
            sizes.append(sizes[-1])
        else:
            raise ValueError("initial memory expression cannot be empty")
    return sizes


def guess_memory_concurrency(memory: int, usage: list[int]) -> int:
    """Work out how many cores the given amount of memory can sustain.

    The n-th entry of usage is the memory consumed by the n-th core. The
    final entry also applies to every core beyond the end of the list. The
    result is never smaller than 1, even if the first core does not fit.

    >>> guess_memory_concurrency(4, [1])
    4
    >>> guess_memory_concurrency(10, [5, 4, 3])
    2
    >>> guess_memory_concurrency(2, [3])
    1
    """
    remaining = memory
    cores = 0
    for need in usage[:-1]:
        if need > remaining:
            # Out of memory before reaching the last entry.
            return max(1, cores)
        remaining -= need
        cores += 1
    # Every explicitly listed core fits; spread the rest at the final rate.
    return max(1, cores + remaining // usage[-1])


def clamp(
    value: int, lower: int | None = None, upper: int | None = None
) -> int:
    """Return an adjusted value that does not exceed the lower or upper limits
    if any.

    >>> clamp(5, upper=4)
    4
    >>> clamp(9, 2)
    9
    """
    if upper is not None and upper < value:
        value = upper
    if lower is not None and lower > value:
        value = lower
    return value


def get_argparser() -> argparse.ArgumentParser:
    """Construct the ArgumentParser for the CLI.

    The parser accepts --detect (a string defaulting to "nproc"), --min and
    --max (positive integers defaulting to None) and --require-mem (a list
    of memory sizes defaulting to an empty list).
    """
    parser = argparse.ArgumentParser(
        description="""

Guess a suitable concurrency level given constraints from the runtime
environment. Environment variables such as DEB_BUILD_OPTIONS, RPM_BUILD_NCPUS
and CMAKE_BUILD_PARALLEL_LEVEL are consulted for initial guesses falling back
to querying the operating system in order to guess a suitable concurrency
level. A user may further restrict the emitted concurrency by specifying a
minimum or maximum and by requiring sufficient memory to support a level of
concurrency to be available.

""".strip(),
    )
    # The value is validated by the caller, because it may either be a number
    # or the name of a detection method.
    parser.add_argument(
        "--detect",
        action="store",
        default="nproc",
        metavar="METHOD",
        help="supply a processor count or select a detection method "
        "(nproc, python or cores)",
    )
    parser.add_argument(
        "--max",
        action="store",
        type=positive_integer,
        default=None,
        metavar="N",
        help="limit the number of detected cores to a given maximum",
    )
    parser.add_argument(
        "--min",
        action="store",
        type=positive_integer,
        default=None,
        metavar="N",
        help="limit the number of detected cores to a given minimum",
    )
    parser.add_argument(
        "--require-mem",
        action="store",
        type=parse_required_memory,
        default=[],
        metavar="MEMLIST",
        help="specify per-core required memory as a comma separated list",
    )
    return parser


def main() -> None:
    """Run the command line interface: parse the arguments, guess the
    concurrency level and print it to standard output.

    Exits via the parser's error handling for contradictory --min/--max
    values or an unknown --detect method. Exits with status 1 if no
    processor count could be guessed.
    """
    parser = get_argparser()
    args = parser.parse_args()
    both_limits = args.min is not None and args.max is not None
    if both_limits and args.min > args.max:
        parser.error("--min value larger than --max value")

    detectors = {
        "nproc": guess_nproc,
        "python": guess_python,
        "cores": guess_cores,
    }
    guess = None
    detectfunc = None
    # Each comma separated item is either a literal count or a method name.
    # Later items override earlier ones of the same kind.
    for detector in args.detect.split(","):
        try:
            guess = positive_integer(detector)
        except ValueError:
            detectfunc = detectors.get(detector)
            if detectfunc is None:
                parser.error("invalid argument to --detect")

    if guess is None:
        # Without a literal count, the loop above must have chosen a method.
        assert detectfunc is not None
        # Hints from the environment take precedence over detection.
        guess = (
            guess_from_environment("CMAKE_BUILD_PARALLEL_LEVEL")
            or guess_deb_build_parallel()
            or guess_from_environment("RPM_BUILD_NCPUS")
            or detectfunc()
        )
    if guess is None:
        print("failed to guess processor count", file=sys.stderr)
        sys.exit(1)

    if guess > 1 and args.require_mem:
        # Memory can only ever lower a guess, so a guess of 1 is left alone.
        available = clamp(guess_memavailable(), upper=guess_cgroup_memory())
        affordable = guess_memory_concurrency(available, args.require_mem)
        guess = clamp(guess, upper=affordable)
    print(clamp(guess, args.min, args.max))


# Run the command line interface only when this file is executed as a script,
# so that importing it as a module has no side effects.
if __name__ == "__main__":
    main()
