#!/usr/bin/python3.12
#
# Copyright(c) 2006, Chris PeBenito <pebenito@gentoo.org>
# Copyright(c) 2006, Gentoo Foundation
#
# Portage query code derived from gentoolkit:
# Copyright(c) 2004, Gentoo Foundation
# Copyright(c) 2004, Karl Trygve Kalleberg <karltk@gentoo.org>
#
# Licensed under the GNU General Public License, v2
#
# $Header: /var/cvsroot/gentoo-projects/hardened/policycoreutils-extra/scripts/rlpkg,v 1.11 2009/08/02 01:20:32 pebenito Exp $

import os
import sys
import string
import getopt
import selinux
import portage
import subprocess

# External command templates (argv lists).  Treat these as read-only:
# copy them before appending per-run options.
#
# scanelf: -t report text relocations, -q only print matching files,
# -R recurse, -E restrict to the given ELF type.  The option and its value
# must be separate argv words; "-E ET_DYN" as a single word would hand
# scanelf the value " ET_DYN" (with a leading space) when run without a shell.
SCANELF = ["/usr/bin/scanelf", "-tqR", "-E", "ET_DYN"]
SETFILES = ["/sbin/setfiles"]
# restorecon reading the list of files to relabel from stdin ("-f -").
RESTORECON = ["/sbin/restorecon", "-f", "-"]

# Filesystem types (third field of /etc/mtab) that a full relabel covers.
xattrfs = ["btrfs", "encfs", "ext2", "ext3", "ext4", "ext4dev", "f2fs", "gfs", "gfs2", "gpfs", "jffs2", "jfs", "lustre", "xfs", "zfs"]
# Portage configuration handed to every Package's dblink.
settings = portage.config(clone=portage.settings)

# Directories scanned for shared libraries with text relocations.
textrel_shlib_paths = ["/lib", "/usr/lib", "/opt"]
# Type assigned to such libraries.
textrel_type = "textrel_shlib_t"
# Only libraries currently carrying one of these types are relabeled.
textrel_ok_relabelfrom = ["lib_t", "shlib_t"]

# Directories scanned for PIE binaries with text relocations (report only).
textrel_bin_paths = ["/bin", "/sbin", "/usr/bin", "/usr/sbin"]


class Package:
    """Package descriptor. Contains convenience functions for querying the
    state of a package, its contents, name manipulation, ebuild info and
    similar."""

    def __init__(self, cpv):
        """Wrap a category/package-version string.

        Raises RuntimeError if portage.catpkgsplit() cannot parse cpv.
        """
        self._cpv = cpv
        # (category, name, version, revision) as returned by portage.
        self._scpv = portage.catpkgsplit(self._cpv)

        if not self._scpv:
            raise RuntimeError("invalid cpv: %s" % cpv)
        # portage.dblink for this package; created lazily by _initdb().
        self._db = None
        self._settings = settings

    def get_cpv(self):
        """Returns full Category/Package-Version string"""
        return self._cpv

    def get_name(self):
        """Returns base name of package, no category nor version"""
        return self._scpv[1]

    def get_version(self):
        """Returns version of package, with revision number.

        The revision suffix is omitted when it is the implicit "r0".
        """
        v = self._scpv[2]
        if self._scpv[3] != "r0":
            v += "-" + self._scpv[3]
        return v

    def get_category(self):
        """Returns category of package"""
        return self._scpv[0]

    def is_installed(self):
        """Returns true if this package is installed (merged)"""
        self._initdb()
        return os.path.exists(self._db.getpath())

    def get_contents(self):
        """Returns the full contents as a dictionary keyed by path, e.g.
        { '/bin/foo': [ 'obj', '1052505381', '45ca8b8975d5094cd75bdc61e9933691' ], ... }
        or an empty dictionary when the package is not installed."""
        # is_installed() already initializes the dblink.
        if self.is_installed():
            return self._db.getcontents()
        return {}

    def _initdb(self):
        """Internal helper function; loads package information from disk,
        when necessary"""
        if self._db is None:
            cat = self.get_category()
            pnv = self.get_name() + "-" + self.get_version()
            # Use the settings captured at construction time.
            self._db = portage.dblink(cat, pnv, "/", self._settings)


def find_installed_packages(search_key):
    """Returns a list of Package objects that matched the search key.

    search_key is a portage package specification (e.g. "policycoreutils" or
    ">=sys-apps/policycoreutils-1.30").  When portage rejects the key with a
    ValueError whose first argument is a list (its "ambiguous package name"
    case, listing the candidates), every candidate is matched instead.
    Any other ValueError is re-raised unchanged.
    """
    vardb = portage.db["/"]["vartree"].dbapi
    try:
        matches = vardb.match(search_key)
    except ValueError as e:
        # Python 3 exceptions are not subscriptable: the candidate list lives
        # in e.args[0], not e[0].
        candidates = e.args[0] if e.args else None
        if not isinstance(candidates, list):
            raise
        matches = []
        for cp in candidates:
            matches += vardb.match(cp)
    return [Package(x) for x in matches]


def parse_mount_options(mount):
    """Turn a comma-separated mount option string into a dict.

    "key=value" options map key to the value string (split at the first "=").
    "ro" is recorded as rw=False; any other option starting with "no" maps
    the remainder of its name to False; every other flag maps to True.
    """
    parsed = {}
    for opt in mount.split(","):
        key, sep, value = opt.partition("=")
        if sep:
            parsed[key] = value
        elif opt == "ro":
            parsed["rw"] = False
        elif opt.startswith("no"):
            parsed[opt[2:]] = False
        else:
            parsed[opt] = True
    return parsed


def _unescape_mount_field(field):
    """Decode the three-digit octal escapes (e.g. "\\040" for a space) used in mtab fields."""
    import re
    return re.sub(r"\\([0-7]{3})", lambda m: chr(int(m.group(1), 8)), field)


def find_xattr_mounts():
    """Find mounted xattr filesystems.

    Returns the mount points listed in /etc/mtab whose filesystem type is in
    xattrfs, skipping read-only mounts, mounts with a "context=" option and
    bind mounts.
    """
    print("Relabeling filesystem types: %s" % " ".join(xattrfs))

    fs_matches = []
    with open("/etc/mtab", "r") as mounts:
        for line in mounts:
            fields = line.split()
            # Ignore blank or malformed lines instead of raising IndexError.
            if len(fields) < 4 or fields[2] not in xattrfs:
                continue
            options = parse_mount_options(fields[3])
            # mount(8) treats read-write as the default when neither "rw"
            # nor "ro" is listed.  Mounts with context= are skipped,
            # presumably because that option fixes the label for the whole
            # mount rather than per file.
            if not options.get("rw", True) or "context" in options or "bind" in options:
                continue
            fs_matches.append(_unescape_mount_field(fields[1]))

    return fs_matches


def full_relabel(reset, verbose):
    """Relabel all xattr filesystems with setfiles.

    reset:   append -F to the setfiles command line.
    verbose: append -vv to the setfiles command line.
    Returns 1 when no filesystem qualifies, otherwise setfiles' exit status
    as returned by subprocess.call (negative if killed by a signal).
    """
    mountpoints = find_xattr_mounts()
    if not mountpoints:
        print("No filesystems to relabel!  Are the filesystems mounted read-write?")
        return 1

    # Copy the template: appending to SETFILES itself would mutate the
    # module-level list for any later caller.
    cmdline = list(SETFILES)
    if reset:
        cmdline.append("-F")
    if verbose:
        cmdline.append("-vv")
    cmdline.append(selinux.selinux_file_context_path())
    cmdline += mountpoints

    print("Running %s" % " ".join(cmdline))
    return subprocess.call(cmdline, close_fds=True, shell=False)


def _scan_textrels(paths):
    """Run SCANELF over paths and return the file names it reports.

    Each output line is expected to be "<textrel field> <path>"; the line is
    split only once so paths containing spaces stay intact.  Lines with
    fewer than two columns are ignored.
    """
    # subprocess.run() drains stdout while waiting for the child, so a large
    # report cannot fill the pipe and deadlock (as Popen.wait() before
    # reading could).  Text mode yields str paths for the selinux bindings.
    result = subprocess.run(SCANELF + paths, stdout=subprocess.PIPE,
                            close_fds=True, text=True, errors="surrogateescape")
    files = []
    for line in result.stdout.splitlines():
        parts = line.split(None, 1)
        if len(parts) == 2:
            files.append(parts[1])
    return files


def _relabel_textrel_libs(verbose):
    """Relabel shared libraries with text relocations to textrel_type.

    Only files whose current type is in textrel_ok_relabelfrom are changed.
    Returns the number of libraries left alone because of their type.
    """
    print("Scanning for shared libraries with text relocations...")

    libs = _scan_textrels(textrel_shlib_paths)
    notok = 0
    for filename in libs:
        try:
            (ret, context) = selinux.getfilecon(filename)
        except OSError:
            # The bindings may raise instead of returning a negative code.
            ret, context = -1, None
        if ret < 0:
            print("Error getting context of %s" % filename)
            continue

        ctx = context.split(":")

        if len(ctx) < 3:
            print('Debug: getfilecon on "%s" returned a context of "%s" which split incorrectly (%s).' % (filename, context, ctx))
            continue

        # ctx[2] is the type field of user:role:type[:range].
        if ctx[2] in textrel_ok_relabelfrom:
            if verbose:
                print("Relabeling %s to %s." % (filename, textrel_type))
            ctx[2] = textrel_type
            try:
                failed = selinux.setfilecon(filename, ":".join(ctx)) < 0
            except OSError:
                failed = True
            if failed:
                print("Failed to relabel %s." % filename)
        elif ctx[2] == textrel_type:
            if verbose:
                print('Skipping %s because it is already %s.' % (filename, textrel_type))
        else:
            print('Not relabeling %s because it is %s.' % (filename, ctx[2]))
            notok += 1

    print("%d libraries with text relocations, %d not relabeled." % (len(libs), notok))

    if notok > 0:
        print("\nSome files were not relabeled!  This is not necessarily bad,")
        print("but may indicate a labeling problem, since what is detected as")
        print("a library is not already labeled with a library type.")
        print("If you just relabeled the entire filesystem, please report")
        print("this in the #gentoo-hardened IRC channel, the")
        print("gentoo-hardened mail list, or Gentoo bugzilla.\n")

    return notok


def _report_textrel_bins():
    """Report PIE binaries with text relocations; returns how many were found."""
    print("Scanning for PIE binaries with text relocations...")

    bins = _scan_textrels(textrel_bin_paths)
    for filename in bins:
        print("PIE executable %s has text relocations!" % filename)

    print("%d binaries with text relocations detected." % len(bins))

    if bins:
        print("\nPIE binaries with text relocations have been detected!")
        print("This is not supported by stock policy.  The best solution")
        print("is to fix the text relocations.  Please consult hardened")
        print("compiler developers in the #gentoo-hardened IRC channel,")
        print("the gentoo-hardened mail list, or Gentoo bugzilla for")
        print("more help.")

    return len(bins)


def relabel_textrel_shlib(verbose):
    """Relabel text-relocation libraries and report text-relocation PIEs.

    Returns the number of libraries not relabeled plus the number of PIE
    binaries with text relocations.
    """
    notok = _relabel_textrel_libs(verbose)
    textrel_bins = _report_textrel_bins()
    return notok + textrel_bins


def relabel_packages(packages, reset, verbose):
    """Relabel specified packages.

    Every file listed in the contents of each installed package matching one
    of the specifications in packages is fed to restorecon on stdin.
    Exits the program with status 1 if nothing matches.  Returns 0 on
    success, otherwise restorecon's exit status (negative if it was killed
    by a signal).
    """
    # build package list
    pkglist = []
    for pkgname in packages:
        pkglist += find_installed_packages(pkgname)

    if not pkglist:
        print("No packages found to relabel.")
        sys.exit(1)

    # Build an argv list (no shell).  Splitting the joined template accepts
    # RESTORECON written either as separate words or as one command string,
    # and produces a fresh list so the module-level template is not mutated.
    cmdline = " ".join(RESTORECON).split()
    if reset:
        cmdline.append("-F")
    if verbose:
        cmdline.append("-vv")

    # Use Popen rather than os.popen: os.popen's close() returns a
    # wait()-encoded status (exit code << 8), which the caller would pass to
    # sys.exit() and which wraps to 0 for an exit code of 1.
    child = subprocess.Popen(cmdline, stdin=subprocess.PIPE, close_fds=True, text=True)
    try:
        for pkg in pkglist:
            print("Relabeling: %s" % pkg.get_cpv())
            for path in pkg.get_contents():
                child.stdin.write(path + '\n')
    except BrokenPipeError:
        # restorecon exited early; its exit status is reported below.
        pass
    finally:
        try:
            child.stdin.close()
        except BrokenPipeError:
            pass

    rc = child.wait()
    if rc != 0:
        print("Error relabeling: %d" % rc)

    return rc


def usage(message=""):
    """Print the command-line help and terminate the program.

    Unless message is the empty string, it is printed after the help text,
    prefixed with the program name, and the exit status is 1; otherwise the
    exit status is 0.
    """
    prog = os.path.basename(sys.argv[0])

    help_lines = (
        "Usage: %s [OPTIONS] {<pkg1> [<pkg2> ...]}" % prog,
        "",
        "  -a, --all      Relabel the entire filesystem instead of individual packages.",
        "  -r, --reset    Force reset of context if the file's selinux identity is",
        "                     different or the file's type is customizable.",
        "  -t, --textrels Scan for libraries with text relocations and relabel them.",
        "                     Implied by -a.",
        "  -v, --verbose  Enable more verbose output.",
        "  -h, --help     Display this help and exit",
        "",
        "Packages can be specified with a portage package specification, for example,",
        '"policycoreutils" or ">=sys-apps/policycoreutils-1.30".',
        "",
    )
    for text in help_lines:
        print(text)

    if message == "":
        sys.exit(0)
    print('%s: %s' % (prog, message))
    sys.exit(1)


def main():
    """Command-line entry point: parse options and run the requested relabels.

    -a relabels all xattr filesystems and implies -t; -t relabels libraries
    with text relocations; otherwise the named packages are relabeled.
    Exits 0 on success, otherwise with the summed problem counts / failing
    exit statuses of the steps, capped at 255.
    """
    reset = False
    relabel = False
    textrels = False
    verbose = False

    if len(sys.argv) < 2:
        usage("At least one argument required.")

    try:
        # "textrels" is accepted as well because that is the spelling the
        # usage text advertises; "textrel" is kept for compatibility.
        opts, packages = getopt.getopt(sys.argv[1:], "ahrtv",
                                       ["all", "help", "reset", "textrel", "textrels", "verbose"])
    except getopt.GetoptError as error:
        usage(error.msg)

    for o, _ in opts:
        if o in ("-a", "--all"):
            relabel = True
            textrels = True
        if o in ("-h", "--help"):
            usage()
        if o in ("-r", "--reset"):
            reset = True
        if o in ("-t", "--textrel", "--textrels"):
            textrels = True
        if o in ("-v", "--verbose"):
            verbose = True

    # abs(): a child killed by a signal reports a negative status, which must
    # not cancel out positive failure counts from other steps.
    rc = 0
    if relabel:
        rc += abs(full_relabel(reset, verbose))

    if textrels:
        rc += abs(relabel_textrel_shlib(verbose))

    if not relabel and not textrels:
        if len(packages) == 0:
            usage("No packages specified.")

        rc += abs(relabel_packages(packages, reset, verbose))

    # Exit statuses are truncated to 8 bits; without the cap, e.g. 256
    # problems would be reported as success.
    sys.exit(min(rc, 255))

# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
