#!/usr/bin/env python3
# Copyright (c) 2017, VMware, Inc. or its affiliates.

import os
import sys
import argparse
from functools import reduce


class ValidationException(Exception):
    """Error carrying a human-readable description of a failed validation.

    The description is available both as ``str(exc)`` and through the
    ``message`` attribute.
    """

    def __init__(self, message):
        # Keep the text on the instance so handlers can read it directly.
        self.message = message
        Exception.__init__(self, message)


class CgroupValidation(object):
    """Base class providing cgroup mount point detection."""

    @staticmethod
    def detect_cgroup_mount_point(proc_mounts_path="/proc/self/mounts"):
        """
        Return the mount point of the first cgroup v2 filesystem.

        proc_mounts_path: file to scan; each line is split on whitespace and
        is expected to carry the mount point in its second field and the
        filesystem type in its third field.

        Returns the mount point string, or "" when no "cgroup2" entry is
        found or when the file is missing or cannot be read.
        """
        try:
            with open(proc_mounts_path) as f:
                for line in f:
                    mount_specs = line.split()
                    # Skip blank or truncated lines instead of failing with
                    # an IndexError on the filesystem type lookup.
                    if len(mount_specs) < 3:
                        continue
                    if mount_specs[2] == "cgroup2":
                        return mount_specs[1]
        except OSError:
            # A missing or unreadable mounts file means nothing was detected.
            return ""
        return ""


class CgroupValidationVersionTwo(CgroupValidation):
    """
    Validate the permissions of the cgroup v2 directories and interface
    files used by gpdb.

    The cgroup v2 mount point is detected when the object is constructed;
    validate_all() then checks a fixed list of paths below it.
    """

    def __init__(self, cgroup_parent=None):
        """
        cgroup_parent: name of the gpdb parent cgroup directory, relative to
        the cgroup v2 mount point. Defaults to "gpdb.service" when None or
        empty.
        """
        self.mount_point = self.detect_cgroup_mount_point()
        # Maps each character of a mode string to its os.access() flag.
        self.tab = {"r": os.R_OK, "w": os.W_OK, "x": os.X_OK, "f": os.F_OK}
        self.cgroup_parent = cgroup_parent if cgroup_parent else "gpdb.service"

    def validate_all(self):
        """
        Check the permissions of the toplevel gpdb cgroup dirs.

        The checks should keep in sync with
        src/backend/utils/resgroup/cgroup-ops-v2.c

        Raises ValidationException on the first check that fails, or when no
        cgroup v2 mount point was detected.
        """

        if not self.mount_point:
            self.die("failed to detect cgroup v2 mount point.")

        parent = self.cgroup_parent
        # (path relative to the mount point, required permissions);
        # a trailing '/' marks a directory.
        checks = (
            ("cgroup.procs", "rw"),

            (parent + "/", "rwx"),
            (parent + "/cgroup.procs", "rw"),

            (parent + "/cpu.max", "rw"),
            (parent + "/cpu.weight", "rw"),
            (parent + "/cpu.weight.nice", "rw"),
            (parent + "/cpu.stat", "r"),

            (parent + "/cpuset.cpus", "rw"),
            (parent + "/cpuset.cpus.partition", "rw"),
            (parent + "/cpuset.mems", "rw"),
            (parent + "/cpuset.cpus.effective", "r"),
            (parent + "/cpuset.mems.effective", "r"),

            (parent + "/memory.current", "r"),

            (parent + "/io.max", "rw"),
        )
        for path, mode in checks:
            self.validate_permission(path, mode)

    def die(self, msg):
        """Raise ValidationException with msg appended to a fixed prefix."""
        raise ValidationException("cgroup is not properly configured: {}".format(msg))

    def validate_permission(self, path, mode):
        """
        Validate permission on path.
        If path is a dir it must end with '/'.

        path: location relative to the cgroup mount point.
        mode: string made of the characters 'r', 'w', 'x' and 'f'; all the
        corresponding os.access() flags must be granted.

        Raises ValidationException when the path does not exist, when the
        required access is not granted, or when the check itself fails.
        """
        # os.path.join() discards the mount point when the second component
        # is absolute, so strip leading slashes to stay inside the hierarchy.
        fullpath = os.path.join(self.mount_point, path.lstrip("/"))
        pathtype = "directory" if path.endswith("/") else "file"

        mode_bits = 0
        for flag in mode:
            mode_bits |= self.tab[flag]

        # Only the filesystem calls are guarded; die() is called outside
        # the try block.
        try:
            exists = os.path.exists(fullpath)
            permitted = exists and os.access(fullpath, mode_bits)
        except OSError as e:
            self.die("can't check permission on {} '{}': {}".format(pathtype, fullpath, str(e)))

        if not exists:
            self.die("{} '{}' does not exist".
                     format(pathtype, fullpath))
        if not permitted:
            self.die("{} '{}' permission denied: require permission '{}'".
                     format(pathtype, fullpath, mode))


if __name__ == '__main__':
    # Command-line entry point: run every cgroup v2 check and report the
    # first failure.
    parser = argparse.ArgumentParser(description='Validate cgroup v2 configuration for resource groups')
    parser.add_argument('--cgroup-parent',
                        dest='cgroup_parent',
                        default=None,
                        help='The cgroup parent directory name (gp_resource_group_cgroup_parent value)')

    args = parser.parse_args()

    try:
        CgroupValidationVersionTwo(cgroup_parent=args.cgroup_parent).validate_all()
    except ValidationException as e:
        # sys.exit() with a string prints it to stderr and exits with
        # status 1. The bare exit() helper is installed by the site module
        # and is not guaranteed to exist (e.g. under "python -S").
        sys.exit(e.message)
