#!/usr/bin/env python3

import os
from list_crates import list_crates
from collections import Counter

# Manual overrides for the heuristics below, keyed by crate name.
# Every entry is an (identifier, cfg-expression) tuple, in the same shape
# the extract_* functions produce.  main() compares the two sides as
# multisets (collections.Counter), so duplicates are significant: each
# entry offsets exactly one detected occurrence.

# This contains annotations the script thinks are missing,
# but that don't actually need to be there.
additional_provided = {}

# This contains annotations the script detected and thinks shouldn't be there,
# but that actually should.
additional_required = {}


# Not interested in the low-level interfaces we provide only for fuzzing
additional_provided["equix"] = [
    (
        "{BucketArray, BucketArrayMemory, BucketArrayPair, Count, Uninit}",
        'feature = "bucket-array"',
    ),
]

# PreferredRuntime has a somewhat more complex rule for existing
additional_provided["tor-rtcompat"] = [
    ("PreferredRuntime", 'feature = "native-tls"'),
    ("PreferredRuntime", 'feature = "native-tls"'),
    ("PreferredRuntime", 'all(feature = "rustls", not(feature = "native-tls"))'),
    ("PreferredRuntime", 'all(feature = "rustls", not(feature = "native-tls"))'),
    (
        "NativeTlsProvider",
        'all(feature = "native-tls", any(feature = "tokio", feature = "async-std"))',
    ),
    (
        "RustlsProvider",
        'all(feature = "rustls", any(feature = "tokio", feature = "async-std"))',
    ),
]
# "unix::SocketAddr" is present unconditionally,
# though it has different definitions.
additional_provided["tor-general-addr"] = [
    ("SocketAddr", "unix"),
]


# We're not very interested in the testing feature
additional_provided["tor-guardmgr"] = [
    ("TestConfig", 'any(test, feature = "testing")'),
]

# Sha1 is present both ways
additional_provided["tor-llcrypto"] = [
    ("Sha1", 'feature = "with-openssl"'),
    ("Sha1", 'not(feature = "with-openssl")'),
]

additional_required["tor-llcrypto"] = [
    ("aes", "all()"),
    ("aes", "all()"),
]

additional_required["tor-hsservice"] = [
    ("restricted_discovery", "all()"),
]

# This is a `*` (wildcard) include; the items the wildcard expands to
# must be listed in additional_required.
additional_provided["tor-proto"] = [
    ("*", 'feature = "testing"'),
]

additional_required["tor-proto"] = [
    ("CtrlMsg", 'feature = "testing"'),
    ("CreateResponse", 'feature = "testing"'),
    # I have no idea, but empirically this stops the CI complaining -Diziet
    ("Conversation", 'feature = "send-control-msg"'),
]

# These are detected twice, but a single cfg_attr(docsrs) is enough
additional_provided["tor-netdoc"] = [
    ("Nickname", 'feature = "dangerous-expose-struct-fields"'),
    ("NsConsensusRouterStatus", 'feature = "ns_consensus"'),
]


def extract_feature_pub_use(path):
    """Collect cfg-gated ``pub use`` re-exports from one Rust source file.

    The file is scanned line by line.  A line containing ``pub use `` is
    recorded only when a ``#[cfg(...)]`` attribute is pending from the
    preceding line.  A commented-out ``pub use`` is skipped without
    discarding the pending cfg.

    The identifier is the text after the last ``:`` on the line (or after
    the last space if there is no colon), cut at the first ``;``.

    Args:
        path: Path of the Rust source file to scan.

    Returns:
        A list of ``(ident, cfg)`` tuples, where ``cfg`` is the text found
        between ``#[cfg(`` and ``)]``.
    """
    START_CFG = "#[cfg("
    END_CFG = ")]"
    PUB_USE = "pub use "
    res = []
    cfg = None
    # Rust source files are UTF-8; decoding with the locale's default
    # encoding could fail or mangle non-ASCII text on some systems.
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            pubuse_pos = line.find(PUB_USE)
            if pubuse_pos != -1 and cfg:
                # last line was a #[cfg(..)] line and this is a pub use
                # ignore comments (the pending cfg is deliberately kept)
                if "//" in line[:pubuse_pos]:
                    continue
                # extract ident
                #
                # (BUG: this still doesn't handle `pub use {A,B,C} very well.)
                start = line.rfind(":")
                if start == -1:
                    start = line.rfind(" ")
                ident = line[start + 1 :]
                if (pos := ident.find(";")) != -1:
                    ident = ident[:pos]
                res.append((ident, cfg))
                cfg = None
                continue

            # check if we are on a #[cfg(..)] line, if so, remember it;
            # any other line forgets a previously remembered cfg
            start_cfg = line.find(START_CFG)
            end_cfg = line.find(END_CFG)
            if start_cfg == -1 or end_cfg == -1:
                cfg = None
            else:
                start_cfg += len(START_CFG)
                cfg = line[start_cfg:end_cfg]
    return res


def extract_cfg_attr(path):
    """Collect ``cfg_attr(docsrs, doc(cfg(...)))``-annotated items from a file.

    The file is scanned line by line.  When a
    ``#[cfg_attr(docsrs, doc(cfg(...)))]`` attribute is pending and a line
    contains one of the keywords ``struct``, ``enum``, ``mod`` or ``trait``
    followed by a space, the word after that keyword is recorded together
    with the pending cfg expression.  The pending cfg survives intermediate
    lines that contain ``#[`` (other attributes) and is dropped by any
    other line.

    Args:
        path: Path of the Rust source file to scan.

    Returns:
        A list of ``(ident, cfg)`` tuples, where ``cfg`` is the text found
        between ``#[cfg_attr(docsrs, doc(cfg(`` and ``)))]``.
    """
    START_CFG = "#[cfg_attr(docsrs, doc(cfg("
    END_CFG = ")))]"
    KEYWORDS = ("struct", "enum", "mod", "trait")
    # Characters that end the identifier.  The newline covers declarations
    # such as `pub struct Foo` whose `where` clause or body starts on the
    # next line; since it is always the last character of a line it never
    # wins over one of the other delimiters.
    DELIMITERS = " (<;\n"
    res = []
    cfg = None
    # Rust source files are UTF-8; don't depend on the locale's encoding.
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            # str.find returns -1 when absent, so max() is -1 only when
            # none of the keywords occur on this line
            pos = max(line.find(kw + " ") for kw in KEYWORDS)
            if pos != -1 and cfg:
                # last line was a cfg and this is a declaration
                subline = line[pos:]
                # drop the keyword itself, keep what follows it
                subline = subline[subline.find(" ") + 1 :]
                hits = [i for i in map(subline.find, DELIMITERS) if i != -1]
                # No delimiter at all (final line without a trailing
                # newline): take the rest of the line instead of letting
                # min() raise ValueError on an empty sequence.
                end = min(hits, default=len(subline))
                ident = subline[:end]
                res.append((ident, cfg))
                cfg = None
                continue

            # check if we are on a #[cfg_attr(docsrs, doc(cfg(..)))] line, if so, remember it
            start_cfg = line.find(START_CFG)
            end_cfg = line.find(END_CFG)
            if start_cfg != -1 and end_cfg != -1:
                start_cfg += len(START_CFG)
                cfg = line[start_cfg:end_cfg]
                # don't reset when it's a cfg_attr followed by some other #[something]
            elif "#[" not in line:
                cfg = None
    return res


def for_each_rs(path, fn):
    """Apply ``fn`` to every ``.rs`` file below ``path`` and merge the results.

    The directory tree is traversed with ``os.walk``.  ``fn`` is called
    with the joined path of each file whose name ends in ``.rs`` and must
    return an iterable; all returned items are concatenated into one list.

    Args:
        path: Root directory to walk.
        fn: Callable taking a file path and returning an iterable of items.

    Returns:
        A single list holding the items from every call to ``fn``.
    """
    collected = []
    for root, _subdirs, names in os.walk(path):
        rust_sources = (name for name in names if name.endswith(".rs"))
        for name in rust_sources:
            collected.extend(fn(os.path.join(root, name)))
    return collected


def main():
    """Check each crate's feature-gated re-exports against its docs.rs annotations.

    For every crate returned by ``list_crates()``, the sources under
    ``crates/<name>/src`` are scanned twice: once for cfg-gated ``pub use``
    lines ("required") and once for ``cfg_attr(docsrs, doc(cfg(...)))``
    annotations ("provided").  The manual override tables are appended to
    the matching side, and the two sides are compared as multisets.  Each
    unmatched entry is printed; processing stops at the first crate with
    any mismatch.

    Raises:
        SystemExit: With exit status 1 as soon as a crate has a mismatch.
    """
    for crate_info in list_crates():
        crate = crate_info.name
        print(f"processing {crate}")
        # relative path: the script expects to run from the directory
        # that contains `crates/`
        crate_path = f"crates/{crate}/src"
        required = for_each_rs(
            crate_path, extract_feature_pub_use
        ) + additional_required.get(crate, [])
        provided = for_each_rs(crate_path, extract_cfg_attr) + additional_provided.get(
            crate, []
        )
        # Counters act as multisets, so an item re-exported twice needs
        # two matching annotations (or override entries).
        req = Counter(required)
        prov = Counter(provided)
        ok = True
        for elem in (req - prov).elements():
            ok = False
            print(f"feature but no cfg_attr(docsrs): {elem}")
        for elem in (prov - req).elements():
            ok = False
            print(f"cfg_attr(docsrs) but no feature: {elem}")
        if not ok:
            print("Found error, exiting")
            # The bare `exit()` helper is injected by the `site` module for
            # interactive sessions and is missing under `python -S`;
            # raising SystemExit gives the same exit status without it.
            raise SystemExit(1)


if __name__ == "__main__":
    main()
