#!/usr/bin/python3
# ----------------------------------------------------------------------
# File: eos-iam-mapfile.py
# Author: Manuel Reis - CERN
# ----------------------------------------------------------------------

# ************************************************************************
# * EOS - the CERN Disk Storage System                                   *
# * Copyright (C) 2021 CERN/Switzerland                                  *
# *                                                                      *
# * This program is free software: you can redistribute it and/or modify *
# * it under the terms of the GNU General Public License as published by *
# * the Free Software Foundation, either version 3 of the License, or    *
# * (at your option) any later version.                                  *
# *                                                                      *
# * This program is distributed in the hope that it will be useful,      *
# * but WITHOUT ANY WARRANTY; without even the implied warranty of       *
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *
# * GNU General Public License for more details.                         *
# *                                                                      *
# * You should have received a copy of the GNU General Public License    *
# * along with this program.  If not, see <http://www.gnu.org/licenses/>.*
# ************************************************************************

import os
import re
import sys
import json
import pickle
import logging
import argparse
from sys import exit
from os import getenv
from urllib import request, parse
from configparser import ConfigParser, DEFAULTSECT
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed

LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

class IAM_Server:
    """
    Client for a single IAM server.

    Authenticates with the client-credentials grant against the token
    endpoint and downloads the user list from the SCIM endpoint.
    Two instances hash and compare equal when they target the same server.
    """
    TOKEN_ENDPOINT = '/token'
    USER_ENDPOINT = '/scim/Users'
    # Number of result pages downloaded concurrently
    MAX_WORKERS = 8
    # Seconds before the announced expiry at which the token gets renewed
    TOKEN_EXPIRY_MARGIN = 10

    def __init__(self, server, client_id, client_secret, token_server=None, account=None, group_account_map=None):
        """
        :param server: hostname of the IAM server (used to build https URLs)
        :param client_id: client id used to request the token
        :param client_secret: client secret used to request the token
        :param token_server: hostname issuing tokens, defaults to `server`
        :param account: account the users of this server are mapped to
        :param group_account_map: optional dict of group name -> account
        """
        self.server = server
        self.client_id = client_id
        self.client_secret = client_secret
        # Assuming token server is the same as IAM's
        self.token_server = token_server or server
        self._token = None
        self.account = account
        # A fresh dict per instance: a dict() default in the signature would
        # be shared by every instance created without an explicit map
        self.group_account_map = {} if group_account_map is None else group_account_map

    def __hash__(self):
        return hash(self.server)

    def __eq__(self, other):
        if not isinstance(other, IAM_Server):
            return NotImplemented
        return self.server == other.server

    def __get_token(self):
        """
        Authenticates with the token server and stores the response
        (including the time of the request) in self._token.
        Raises RuntimeError when the response carries no access token.
        """
        request_data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "client_credentials",
            "scope": "scim:read"
        }
        now = datetime.now()

        url = f'https://{self.token_server}{self.TOKEN_ENDPOINT}'
        with request.urlopen(url, data=parse.urlencode(request_data).encode('utf-8')) as reply:
            response = json.loads(reply.read())

        if 'access_token' not in response:
            raise RuntimeError("Authentication Failed")
        response['request_time'] = now
        self._token = response

    @property
    def token(self):
        """
        Property that returns the bearer token, renewing it first if it is
        missing or (about to be) expired
        """
        if self._token is not None:
            # A response without 'expires_in' is treated as already expired
            lifetime = timedelta(seconds=self._token.get('expires_in', 0) - self.TOKEN_EXPIRY_MARGIN)
            if self._token['request_time'] + lifetime >= datetime.now():
                return self._token['access_token']

        self.__get_token()
        return self._token['access_token']

    @staticmethod
    def _page_starts(total, page_size, first=1):
        """
        Start indexes (1-based) of the pages needed to cover `total` results,
        beginning at index `first`.
        """
        return range(first, total + 1, page_size)

    @staticmethod
    def _fetch_json(req):
        """
        Performs the request and returns the decoded JSON body, closing the
        connection afterwards.
        """
        with request.urlopen(req) as reply:
            return json.loads(reply.read())

    def _build_request(self, start_index, count):
        """
        Builds the authenticated request for one page of the user list
        (the token is renewed here if it expired).
        """
        params = {"startIndex": start_index, "count": count}
        header = {"Authorization": f"Bearer {self.token}"}
        return request.Request(f"https://{self.server}{self.USER_ENDPOINT}?{parse.urlencode(params)}", headers=header)

    def get_users(self, start_index=0, count=1, filter_function=None, **kwargs):
        """
        Queries the server to get all users known to it.
        Results are paginated, so the first page is fetched to learn the
        total and the remaining pages are requested in parallel.

        :param start_index: unused, kept for backward compatibility
        :param count: number of users requested per page (must be >= 1)
        :param filter_function: optional callable invoked as
               filter_function(*users_of_a_page, **kwargs); what it returns
               is merged into the resulting set
        :param kwargs: forwarded untouched to filter_function
        :return: set built from the filter results, or the plain list of
                 user records when no filter_function is given
        :raises ValueError: if count is smaller than 1

        A page that fails to download or to be filtered is logged and
        skipped; a failure of the first request is raised to the caller.
        """
        if count < 1:
            raise ValueError(f"count must be >= 1, got {count}")

        # SCIM pagination is 1-based (RFC 7644, section 3.4.2.4): starting at
        # 0 makes pages overlap and loses the last user when the total is a
        # multiple of the page size.
        first_request = self._build_request(1, count)
        first_page = self._fetch_json(first_request)
        total = first_page['totalResults']
        # Step by what the server actually returned, in case it caps the
        # page size below the requested count
        page_size = len(first_page.get('Resources', ())) or count

        users = set()
        user_lst = []

        def consume(source, page):
            try:
                resources = page.get('Resources', [])
                if filter_function is not None:
                    users.update(filter_function(*resources, **kwargs))
                else:
                    user_lst.extend(resources)
            except Exception as e:
                logging.error(f'{source} generated an exception: {e}')

        # The with statement ensures threads are cleaned up promptly
        with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
            # Maps each future to the URL it downloads, for error reporting
            pending = {}
            for index in self._page_starts(total, page_size, first=1 + page_size):
                req = self._build_request(index, page_size)
                # Only the URL is logged: the header carries the bearer token
                logging.debug(f"Requesting {req.full_url}")
                pending[executor.submit(self._fetch_json, req)] = req.full_url

            consume(first_request.full_url, first_page)

            for future in as_completed(pending):
                try:
                    page = future.result()
                except Exception as e:
                    logging.error(f'{pending[future]} generated an exception: {e}')
                    continue
                consume(pending[future], page)

        if not filter_function:
            return user_lst

        return users

def name_map_filter(*users, kwargs=None, **_ignored):
    """
    Collect user's id to build 'Mapfile format' rules:
    https://github.com/xrootd/xrootd/tree/master/src/XrdSciTokens

    :param users: user records (dicts); records without a non-None 'id'
                  are left out
    :param kwargs: unused, kept for backward compatibility
    :param _ignored: any other keyword argument is accepted and ignored, so
                     this filter can be called like the DN based filters
    :return: set with the ids of the given users
    """
    logging.debug(f"This request has {len(users)}")
    return {user['id'] for user in users if user.get('id') is not None}


def extract_grid_dn(user, pattern=None, prefer_cern=False):
    """
    Extract the certificate subjects of a user in grid DN notation.

    Each subjectDn ('CN=x,OU=y,DC=cern,DC=ch') has its components reversed
    and joined with '/' ('DC=ch/DC=cern/OU=y/CN=x'); no leading '/' is added.

    :param user: user record (dict) as returned by the SCIM endpoint
    :param pattern: optional compiled regex, only DNs it finds a match in
                    are returned
    :param prefer_cern: when True and the user has certificates whose DN
                        ends with 'DC=cern,DC=ch', only those are used;
                        users without such a certificate keep all of theirs
    :return: list of grid DNs, empty when the user has no certificates
    """
    try:
        certs = user['urn:indigo-dc:scim:schemas:IndigoUser']['certificates']
    except KeyError:
        logging.warning(f"User {user.get('id')} doesn't have certificate to extract info (skipping it)")
        return []

    # Is there a CERN certificate if prefered?
    if prefer_cern:
        cern_certs = [
            cert for cert in certs
            if (cert.get('subjectDn') or cert.get('issuerDn') or '').endswith('DC=cern,DC=ch')
        ]
        # It is a preference, not a requirement: fall back to every
        # certificate when none is a CERN one
        certs = cern_certs or certs

    grid_dns = []
    for cert in certs:
        subject_dn = cert.get('subjectDn')
        if not subject_dn:
            # Skip only this certificate, the remaining ones may be usable
            logging.warning(f"User {user.get('id')} has a certificate without subjectDn (skipping it)")
            continue

        # Revert subjectDn and replace , with / (making sure commas on the values aren't replaced)
        grid_dn = '/'.join(reversed(re.split(r',(?=\w+=)', subject_dn)))  # re: courtesy of Maarten Litmaath
        if pattern is None or pattern.search(grid_dn):
            grid_dns.append(grid_dn)

    return grid_dns

def dn_filter(*users, pattern=None, prefer_cern=False, **kwargs):
    """
    Build the set of grid DNs (each prefixed with '/') of the active users.

    Inactive users are skipped; pattern and prefer_cern are handed over to
    extract_grid_dn, any other keyword argument is ignored.
    """
    logging.debug(f"This request has {len(users)}")

    matching_dn = set()
    for user in users:
        if user.get('active'):
            dns_of_user = extract_grid_dn(user, pattern, prefer_cern)
            matching_dn.update(f'/{grid_dn}' for grid_dn in dns_of_user)
        else:
            logging.info(f"User {user['userName']}:{user['id']} is not active, skipping")

    logging.info(f"{len(matching_dn)} matching certificates")
    return matching_dn

def role_group_map_filter(*users, pattern=None, prefer_cern=None, group_name=None, **kwargs):
    """
    Build the set of grid DNs (each prefixed with '/') of the active users
    that are members of the group whose display name equals group_name.

    Returns an empty set when no group_name is given. Users lacking group
    information are skipped with a warning.
    """
    matching_dn = set()

    if group_name:
        for user in users:
            if not user.get('active'):
                logging.info(f"User {user['userName']}:{user['id']} is not active, skipping")
                continue

            try:
                for group in user['groups']:
                    if group['display'] != group_name:
                        continue
                    dns_of_user = extract_grid_dn(user, pattern, prefer_cern)
                    matching_dn.update(f'/{grid_dn}' for grid_dn in dns_of_user)
            except KeyError:
                logging.warning(f"User {user['id']} doesn't have group mapping, skipping")

    return matching_dn

def build_namemap_file(users_id, account, ifile, ofile):
    """
    Write the name map: a JSON list of rules {'sub': <user id>, 'result': <account>}.

    :param users_id: iterable of user ids to be mapped to `account`
    :param account: account the given ids are mapped to
    :param ifile: optional path to an existing JSON name map whose rules are
                  kept; the program exits with code 4 if it does not exist
    :param ofile: path to write the result to, printed to stdout when empty;
                  the program exits with code 4 if it cannot be written

    Identical rules are written only once, rules read from ifile first,
    followed by the new ones in the order given.
    """
    # Rules are dicts (unhashable), so duplicates are detected through a
    # canonical JSON rendering; sort_keys makes two equal rules match even
    # when their keys come in a different order
    rules = {}

    def add_rule(rule):
        rules.setdefault(json.dumps(rule, sort_keys=True), rule)

    if ifile:
        try:
            with open(ifile) as f:
                for entry in json.load(f):
                    add_rule(entry)
        except FileNotFoundError:
            logging.error(f"Unable to read {ifile}, ignoring it's content...")
            exit(4)

    for user_id in users_id:
        add_rule({'sub': user_id, 'result': account})

    name_map = list(rules.values())

    if ofile:
        try:
            with open(ofile, 'w') as f:
                json.dump(name_map, f)
        except Exception as e:
            logging.error(f'Unable to write to {ofile},raised exception {e}')
            exit(4)
    else:
        print(json.dumps(name_map))

def _split_gridmap_line(line):
    """
    Split a gridmap line into (dn, account) on its last run of whitespace,
    returns None when the line doesn't hold both fields.
    """
    fields = line.strip().rsplit(None, 1)
    if len(fields) != 2:
        return None
    return fields[0], fields[1]


def build_gridmap_file(users_dn, account, ifile, ofile, lgridmap, grid_map=None):
    """
    Build the gridmap ('<dn> <account>' per line) and write it out.

    Entries are applied in this order, later ones overriding earlier ones:
    the content of grid_map, the lines of ifile, the given users_dn (quoted)
    mapped to account and finally the lines of the local gridmap files.

    :param users_dn: iterable of DNs to be mapped to `account`
    :param account: account the given DNs are mapped to
    :param ifile: optional path to an existing gridmap file; the program
                  exits with code 4 if it does not exist
    :param ofile: path to write the result to, printed to stdout when empty;
                  the program exits with code 4 if it cannot be written
    :param lgridmap: optional comma separated list of local gridmap files,
                     blank lines and lines starting with '#' are ignored and
                     missing files are only reported
    :param grid_map: optional dict with already known entries, it is updated
                     in place; a new dict is used when omitted
    :return: the dict of entries written (grid_map itself when given)
    """
    # Not a default in the signature: that dict would be shared between calls
    if grid_map is None:
        grid_map = {}

    if ifile:
        try:
            # As some entries may be encoded in latin let's escape it as unicode
            with open(ifile, "r", encoding='unicode_escape') as igridmap_file:
                for line in igridmap_file:
                    entry = _split_gridmap_line(line)
                    if entry is None:
                        # Blank or incomplete line, nothing to map
                        continue
                    dn, acc = entry
                    grid_map[dn] = acc
        except FileNotFoundError:
            logging.error(f"Unable to read {ifile}, ignoring it's content...")
            exit(4)

    # Overwrite / append results
    for dn in users_dn:
        # Entries are stored quoted, so the lookup has to use the quoted form
        quoted_dn = f'"{dn}"'
        if quoted_dn in grid_map:
            logging.debug(f'Overwritting {dn}')
        grid_map[quoted_dn] = account

    # Override with local gridmap entries if any
    if lgridmap:
        for lfile in lgridmap.split(","):
            try:
                with open(lfile.strip(), "r", encoding='unicode_escape') as lgridmap_file:
                    logging.debug(f"Processing {lfile}")
                    for line in lgridmap_file:
                        if not line.strip() or line.startswith('#'):
                            continue
                        entry = _split_gridmap_line(line)
                        if entry is None:
                            logging.warning(f"Ignoring malformed line in {lfile}: {line.strip()}")
                            continue
                        dn, acc = entry
                        grid_map[dn] = acc
            except FileNotFoundError:
                logging.error(f"Unable to read {lfile}, ignoring it's content...")

    content = '\n'.join(f'{dn} {acc}' for dn, acc in grid_map.items())

    if ofile:
        try:
            with open(ofile, "w", encoding='utf-8') as ogridmap_file:
                ogridmap_file.write(content)
        except Exception as e:
            logging.error(f'Unable to write to {ofile}, raised exception {e}')
            exit(4)
    else:
        print(content)
    return grid_map

def parse_groupmap(group_map_str):
    """
    Parse a 'group1:account1,group2:account2' string into a dict mapping
    each group to its account (surrounding whitespace is removed).

    Empty entries (e.g. caused by a trailing comma) are ignored.

    :raises ValueError: if an entry doesn't have exactly one ':'
    """
    group_map = dict()
    for entry in group_map_str.split(","):
        if not entry.strip():
            continue

        fields = entry.split(":")
        if len(fields) != 2:
            raise ValueError(f"Invalid group mapping '{entry.strip()}', expected <group>:<account>")

        group, account = fields
        group_map[group.strip()] = account.strip()

    return group_map

def configure_servers(credentials, servers, targets, account):
    """
    Build the list of IAM servers to be queried.

    Servers given on the command line (with their account) take precedence;
    only when there are none, the sections of the credentials file are used:
        [<iam server hostname>]
        client-id = <id>
        client-secret = <key>
    with the optional keys token-server, account and group_account_map.
    When targets is given only sections containing it are considered.

    Exits with code 3 if no server could be configured.
    """
    iam_servers = []

    # Command line definitions come first
    if servers:
        if not account:
            logging.error('servers configured via cli, but no account configured, exiting!')
            exit(3)
        for server, client_id, client_secret in servers:
            logging.debug(f'Adding iam_server {server} from cli')
            iam_servers.append(IAM_Server(server, client_id, client_secret, server, account))

    # Configuration file is only consulted when the cli didn't define any
    if credentials and not iam_servers:
        config = ConfigParser()
        if config.read(credentials):
            # Credentials file should have IAM server on the section
            sections = config.sections()
            if targets is not None:
                sections = [name for name in sections if targets in name]

            for section in sections:
                if section == DEFAULTSECT:
                    continue

                server = section
                client_id = config.get(section, 'client-id')
                client_secret = config.get(section, 'client-secret')
                # Assuming IAM server is token server if not defined
                token_server = config.get(section, 'token-server', fallback=server)
                account = config.get(section, 'account', fallback=None)

                group_map_str = config.get(section, 'group_account_map', fallback=None)
                group_map = parse_groupmap(group_map_str) if group_map_str else dict()

                logging.debug(f'Adding iam_server {server} mapping to {account} from config')
                iam_servers.append(IAM_Server(server, client_id, client_secret, token_server, account, group_map))
        else:
            logging.warning("Credentials couldn't be loaded from configuration file")

    if not iam_servers:
        logging.error('Configuration problem! Configuration file not loaded (correctly?) or arguments not passed.')
        exit(3)

    return iam_servers


def cleanup_files(fname, count):
    """
    Remove the numbered files '<fname>.0' ... '<fname>.<count-1>'.

    Missing files are reported as a warning and any other failure as an
    error; in both cases the remaining files are still processed.
    """
    for i in range(count):
        f = "{}.{}".format(fname, i)
        try:
            os.remove(f)
            logging.debug(f"Cleaned up {f}")
        except FileNotFoundError:
            # logging.warn is a deprecated alias of logging.warning
            logging.warning(f"Skipping non existent {f}")
        except Exception as e:
            logging.error(f"Failed deletion {f}: {e}")

def setup_filelogging(log_level, logfile):
    """
    Configure logging to append to logfile.

    Returns True when file logging was configured, False when no logfile
    was given or the directory it should live in doesn't exist.
    """
    if logfile:
        parent = os.path.dirname(logfile)
        # A bare file name (no directory part) is always acceptable
        if not parent or os.path.exists(parent):
            logging.basicConfig(
                filename=logfile,
                filemode='a',
                format=LOG_FORMAT,
                datefmt=DATE_FORMAT,
                level=log_level,
            )
            return True

    return False


def setup_logging(log_level=logging.WARNING, logfile=None):
    """
    Configure logging to the given logfile, falling back to stdout when
    file logging could not be set up.
    """
    if setup_filelogging(log_level, logfile):
        return

    console_settings = dict(
        stream=sys.stdout,
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        level=log_level,
    )
    logging.basicConfig(**console_settings)

class CLIConfig:
    """
    Effective settings of the tool: values given on the command line win
    over the ones in the DEFAULT section of the configuration file.
    """

    LOG_LEVELS = {
        'CRITICAL': logging.CRITICAL,
        'ERROR': logging.ERROR,
        'WARNING': logging.WARNING,
        'INFO': logging.INFO,
        'DEBUG': logging.DEBUG,
        'NOTSET': logging.NOTSET,
    }

    def __init__(self, conf_file):
        """
        :param conf_file: path of the configuration file, may be None when
                          everything is given on the command line
        """
        self.config = ConfigParser()
        # ConfigParser.read() iterates over its argument, so None (no config
        # file given) must not be handed over
        if conf_file:
            self.config.read(conf_file)

        self.lgridmap = None
        self.inputfile = None
        self.outputfile = None
        self.type_of_format = "GRIDMAP"
        self.cleanup = False
        self.prefer_cern = False
        self.pattern = None

    def get_option(self, option, value, fallback = None):
        """
        Return value unless it is None, else the option from the DEFAULT
        section of the configuration (or fallback if not configured)
        """
        if value is not None:
            return value

        return self.config.get(DEFAULTSECT, option, fallback = fallback)

    # Underlying assumption that store-true is used, ie. args.<value> will evaluate to True if set
    def get_bool(self, option, value, fallback=False):
        """
        Return True if value is truthy, else the boolean option from the
        DEFAULT section of the configuration (or fallback if not configured)
        """
        if value:
            return True

        return self.config.getboolean(DEFAULTSECT, option, fallback=fallback)

    def _compile_pattern(self, pattern, cli_flags):
        """
        Compile the DN search pattern, exits with code 1 if it is invalid.

        :param pattern: regular expression (string)
        :param cli_flags: regex flags coming from the command line: 0 asks
               for a case sensitive match, anything else (or None) leaves
               the decision to the 'case_sensitive' configuration option
        :return: compiled pattern, case insensitive unless requested otherwise
        """
        cli_case_sensitive = cli_flags is not None and not cli_flags
        case_sensitive = self.get_bool("case_sensitive", cli_case_sensitive)
        flags = 0 if case_sensitive else re.IGNORECASE

        try:
            return re.compile(pattern, flags)
        except Exception:
            # Whatever makes the pattern unusable is fatal for the run
            logging.critical(f'Pattern provided cannot be compiled: {pattern}')
            exit(1)

    def setup(self, args):
        """
        Setup the CLI configuration, overriding any config option from CLI flags
        """
        log_file = self.get_option("log_file", args.log_file)
        log_level = self.get_option("log_level",args.debug, "WARNING")
        debug_level = self.LOG_LEVELS.get(log_level.upper(), logging.WARNING)
        setup_logging(debug_level, log_file)

        self.lgridmap = self.get_option("localgridmap", args.lgridmap)
        self.inputfile = self.get_option("inputfile", args.ifile)
        self.outputfile = self.get_option("outputfile", args.ofile)
        # Upper-cased so a value from the config file is treated like the CLI one
        self.type_of_format = self.get_option("format", args.type_of_format, "GRIDMAP").upper()
        self.cleanup = self.get_bool("cleanup", args.cleanup)
        self.prefer_cern = self.get_bool("prefer_cern", args.prefer_cern)

        pattern = self.get_option("pattern", args.pattern)
        if pattern:
            # Always compiled: case sensitivity only changes the flags used
            self.pattern = self._compile_pattern(pattern, args.sensitive)


def gen_temp_file(fname, count):
    """
    Name of the numbered temporary file '<fname>.<count>', None when no
    file name is given.
    """
    return f"{fname}.{count}" if fname else None

def main(conf_file = None, args = None):
    """
    Configure IAM servers to be queried, update/write gridmap file format

    :param conf_file: path of the configuration file (may be None)
    :param args: parsed command line arguments

    In GRIDMAP format every query result is written to a numbered temporary
    file next to the output file; the last one is renamed to the output file
    and, if cleanup is enabled, the previous ones are removed.
    In MAPFILE format the name map is written directly to the output file.
    """

    conf = CLIConfig(conf_file)
    conf.setup(args)

    iam_servers = configure_servers(conf_file, args.server, args.targets, args.account)

    # Query IAM server
    count = 0
    temp_outfile = None
    grid_map = dict()
    ifile = conf.inputfile
    ofile = conf.outputfile
    lgridmap = conf.lgridmap
    type_of_format = conf.type_of_format
    for iam in iam_servers:
        # start with an empty user set for each IAM endpoint, we retain the
        # grid map which will maintain the correct account mapping
        if not iam.account:
            logging.warning(f"Skipping {iam.server} as no explicit mapping via args or config")
            continue

        users = set()

        if type_of_format == "GRIDMAP":
            temp_outfile = gen_temp_file(ofile, count)
            users.update(iam.get_users(count=100, filter_function=dn_filter,
                                           pattern=conf.pattern,
                                           prefer_cern=conf.prefer_cern))
            grid_map = build_gridmap_file(users, iam.account, ifile, temp_outfile, lgridmap, grid_map)
            count += 1
            # Group specific mappings are applied after (and override) the
            # server wide one
            for group_name, account in iam.group_account_map.items():
                temp_outfile = gen_temp_file(ofile, count)
                iam_users = iam.get_users(count=100, filter_function=role_group_map_filter,
                                              pattern=conf.pattern,
                                              prefer_cern=conf.prefer_cern,
                                              group_name = group_name)
                grid_map = build_gridmap_file(iam_users, account, ifile, temp_outfile, lgridmap, grid_map)
                count += 1

        elif type_of_format == "MAPFILE":
            users.update(iam.get_users(count=100, filter_function=name_map_filter))
            # The output path is `ofile`; `outfile` was never defined
            build_namemap_file(users, iam.account, ifile, ofile)

    if temp_outfile and ofile:
        logging.debug(f"Renaming {temp_outfile} to {ofile}")
        os.rename(temp_outfile, ofile)

        if conf.cleanup:
            # The last temporary file was just renamed, remove the others
            cleanup_files(ofile, count-1)

# Script entry point: parse the command line and hand over to main()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GRID Map file generation from IAM Server', epilog='''
    Supports a config file via -c option, cli arguments override any configuration

    Takes a config file having the following syntax, cli switches override config settings, DEFAULT section is optional
    Example config:
    [DEFAULT]
    cleanup = True
    localgridmap = /etc/localgridmap.conf,/etc/localgridmap2.conf
    # Can be repeated as many times with different servers, we strictly override in the order of config files
    log_file = /var/log/eos/grid/gridmap.log
    # Same options as verbose CLI option
    log_level = WARNING

    [myiamserver1]
    client-id = 1223
    client-secret = 1234
    account = acc4usermap
    [..]
    examples:
$ echo -e '[myiamserver.cern.ch]\\nclient-id = 1234567890\\nclient-secret = *******\\naccount=acc4usermap' > iam.conf
$ eos-iam-mapfile -c iam.conf''',formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-v', '--verbose', type = str.upper, nargs='?', const="DEBUG", default=None, choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"), dest = 'debug', help = 'Control log verbosity')
    parser.add_argument('-s', '--server', dest = 'server', nargs=3,action='append', help = 'IAM server to query with respective client key and secret (space separated)', metavar=('SERVER', 'CLIENT_ID','CLIENT_KEY'))
    parser.add_argument('-c', '--config', dest = 'config', help = r'Client credentials file (for API access) in the following format: `[<iam server hostname>]\nclient-id = <id>\nclient-secret = <key>`')
    parser.add_argument('-t', '--targets', dest = 'targets', help = 'Target specific IAM servers defined in the configuration file (must be used together with -c)')
    parser.add_argument('-i', '--inputfile', dest = 'ifile', default=None, help = "Path to existing gridmapfile to be updated (matching DN's will be overwritten)")
    parser.add_argument('-o', '--outfile', dest = 'ofile', default=None, help = 'Path to dump gridmapfile, will dump to stdout otherwise')
    parser.add_argument('-l', '--localgridmap', dest = 'lgridmap', default = None, help = 'Path to local gridmap file, supports comma separated (no space) list of files (Will override any previous mappings encountered in order)')
    parser.add_argument('-a', '--account', dest = 'account', help = 'Account to which the result from the match should be mapped to')
    parser.add_argument('-p', '--pattern', type=str, dest = 'pattern', default=None, help = 'Pattern to search on user certificates `subject DN` field')
    # Holds regex flags: 0 when -C is given, re.IGNORECASE otherwise
    parser.add_argument('-C','--case-sensitive', dest='sensitive',action='store_const', const=0, default=re.IGNORECASE, help = 'Makes the regex pattern (-p) to be case sensitive')
    parser.add_argument('-u', '--prefer-cern-certs', dest = 'prefer_cern',action='store_true', help = 'Prefers CERN.CH certificates (if any) to map user (uniquely)')
    # No default here: a value would always win over the `format` option of
    # the config file, GRIDMAP is applied later when neither is set
    parser.add_argument('-f', '--format',  type = str.upper, nargs='?', const="MAPFILE", default=None, choices=("MAPFILE","GRIDMAP"),dest='type_of_format', help = 'Choose file format, using DN or ID (defaults to ID if used, else DN unless the config file says otherwise)')
    parser.add_argument('--cleanup', action='store_true')
    parser.add_argument('--log-file', type=str, help='The log file for logging')

    args = parser.parse_args()

    main(args.config, args)
