#!/usr/bin/python3
#
# Copyright (c) CERN 2024
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import base64
import configparser
import datetime
import json
import logging
import logging.handlers
import os
import pymysql
import requests
import socket
import time
import warnings


def get_log(log_file_path, log_program_name, log_level="INFO"):
    """Configure the root logger to write to a daily-rotated log file.

    A ``TimedRotatingFileHandler`` rotating at midnight and keeping 30
    backups is attached to the root logger, whose level is set to
    ``log_level``.

    Args:
        log_file_path: Path of the log file.  A bare file name (no directory
            component) is interpreted relative to the current directory.
        log_program_name: Program name embedded in every log line.
        log_level: Level applied to the root logger.

    Returns:
        The root logger.

    Raises:
        Exception: If the directory of ``log_file_path`` does not exist, is
            not a directory or is not writable.
    """
    hostname = socket.gethostname()

    # NOTE(review): TID is filled with the process id, exactly like PID;
    # presumably acceptable because the program is single threaded - confirm.
    log_format = (
        f"%(asctime)s.%(msecs)03d000 {hostname} %(levelname)s {log_program_name}"
        ':LVL="%(levelname)s" PID="%(process)d" TID="%(process)d" MSG="%(message)s"'
    )
    log_formatter = logging.Formatter(fmt=log_format, datefmt="%Y/%m/%d %H:%M:%S")

    # os.path.dirname() returns "" for a bare file name; fall back to the
    # current directory so that such a path is not rejected as "not a directory"
    log_dir = os.path.dirname(log_file_path) or "."
    if not os.path.isdir(log_dir):
        raise Exception(
            "The logging directory {} is not a directory or does not exist".format(
                log_dir
            )
        )
    if not os.access(log_dir, os.W_OK):
        raise Exception("The logging directory {} cannot be written to".format(log_dir))

    log_handler = logging.handlers.TimedRotatingFileHandler(
        filename=log_file_path, when="midnight", backupCount=30
    )
    log_handler.setFormatter(log_formatter)

    log = logging.getLogger()
    log.setLevel(log_level)
    log.addHandler(log_handler)

    return log


def get_config(path, program_name):
    """Read and validate the configuration file of the daemon.

    Args:
        path: Path of the ini-style configuration file.
        program_name: Program name, used in the error message and in the
            default log file path.

    Returns:
        A dict with the keys ``log_file``, ``sec_between_executions``,
        ``max_tokens_to_retire``, ``db_user``, ``db_password``,
        ``db_database``, ``db_host`` and ``db_port``.  The options of the
        ``main`` section fall back to defaults when absent.

    Raises:
        SystemExit: With exit code 1 if the configuration file does not exist.
        Exception: If a mandatory option of the ``database`` section is
            missing.
    """
    if not os.path.isfile(path):
        print(f"The {program_name} configuration file does not exist: path={path}")
        # Raise SystemExit directly: exit() is a helper injected by the site
        # module and is not guaranteed to be available
        raise SystemExit(1)

    config = configparser.ConfigParser()
    config.read(path)

    # Every option of the database section is mandatory, including the port
    for option in ("user", "password", "database", "host", "port"):
        if not config.has_option("database", option):
            raise Exception(
                f"The database section of the configuration file does not contain the {option} option: path={path}"
            )

    return {
        "log_file": config.get(
            "main", "log_file", fallback=f"/var/log/fts3/{program_name}.log"
        ),
        "sec_between_executions": config.getint(
            "main", "sec_between_executions", fallback=600
        ),
        "max_tokens_to_retire": config.getint(
            "main", "max_tokens_to_retire", fallback=600000
        ),
        "db_user": config.get("database", "user"),
        "db_password": config.get("database", "password"),
        "db_database": config.get("database", "database"),
        "db_host": config.get("database", "host"),
        "db_port": config.getint("database", "port"),
    }


def select_worker_hostname(db_conn):
    """Select the host that should act as the worker.

    The worker is the alphabetically first host of ``t_hosts`` that is not
    draining, i.e. whose ``drain`` column is NULL or 0.

    Args:
        db_conn: Open database connection.

    Returns:
        A ``(hostname, select_sec)`` tuple where ``hostname`` is the selected
        host name, or ``None`` if no host qualifies, and ``select_sec`` is the
        time in seconds spent on the query.
    """
    start_select = time.time()

    with db_conn.cursor() as cursor:
        cursor.execute(
            "select"
            "  hostname "
            "from"
            "  t_hosts "
            "where"
            # The previous predicate "drain is not null or drain = 0" reduced
            # to "drain is not null" and therefore also elected draining hosts
            "  drain is null or drain = 0 "
            "order by"
            "  hostname asc "
            "limit 1"
        )
        row = cursor.fetchone()

    # fetchone() returns None when no host qualifies
    hostname = row[0] if row is not None else None

    select_sec = time.time() - start_select

    return hostname, select_sec


def get_major_db_schema_version(db_conn):
    """Return the highest ``major`` value stored in ``t_schema_vers``.

    Args:
        db_conn: Open database connection.

    Returns:
        The first column of the first row returned by the query.
    """
    query = "select  max(major) from  t_schema_vers"
    cur = db_conn.cursor()
    cur.execute(query)
    first_row = cur.fetchall()[0]
    return first_row[0]


def get_db_tables(db_conn):
    """Return the names of the tables of the currently selected database.

    Args:
        db_conn: Open database connection.

    Returns:
        A set of table names read from ``information_schema.tables``.
    """
    query = (
        "select"
        "  table_name "
        "from"
        "  information_schema.tables "
        "where"
        "  table_schema = database()"
    )
    cur = db_conn.cursor()
    cur.execute(query)
    return {record[0] for record in cur.fetchall()}


def create_t_token_refresher_failed_token_db_table(db_conn):
    """Create the ``t_token_refresher_failed_token`` table.

    The statement is a plain ``create table``; whatever the database driver
    raises when the statement fails propagates to the caller.

    Args:
        db_conn: Open database connection.
    """
    ddl = (
        "create table t_token_refresher_failed_token("
        "  token_id char(16) not null,"
        "  failed_timestamp timestamp not null,"
        "  error_message varchar(1024) not null,"
        "  primary key (token_id)"
        ")"
    )
    db_conn.cursor().execute(ddl)


def select_token_ids_to_retire(db_conn, max_tokens_to_retire):
    start_select = time.time()
    cursor = db_conn.cursor()
    cursor.execute(
        "SELECT"
        "  t_token.token_id "
        "FROM"
        "  t_token "
        "WHERE"
        "  t_token.retired = 0 "
        "AND"
        "  NOT EXISTS ("
        "    SELECT"
        "      t_file.src_token_id"
        "    FROM"
        "      t_file"
        "    WHERE"
        "      t_file.src_token_id = t_token.token_id"
        "    AND"
        "      t_file.file_state in ('TOKEN_PREP', 'SUBMITTED')"
        "  )"
        "AND"
        "  NOT EXISTS ("
        "    SELECT"
        "      t_file.dst_token_id"
        "    FROM"
        "      t_file"
        "    WHERE"
        "      t_file.dst_token_id = t_token.token_id"
        "    AND"
        "      t_file.file_state in ('TOKEN_PREP', 'SUBMITTED')"
        "  )"
        "ORDER BY"
        "  t_token.access_token_expiry "
        "LIMIT %(max_tokens_to_retire)s",
        {"max_tokens_to_retire": max_tokens_to_retire},
    )
    result = []
    rows = cursor.fetchall()
    for row in rows:
        result.append(row[0])
    select_sec = time.time() - start_select

    return result, select_sec


def retire_token(db_conn, token_id):
    """Mark a single token as retired.

    Args:
        db_conn: Open database connection.
        token_id: Identifier of the token to retire.

    Returns:
        A ``(token_retired, retire_sec)`` tuple: whether the update reported
        at least one affected row, and the time in seconds it took.
    """
    began = time.time()
    affected_rows = db_conn.cursor().execute(
        "UPDATE"
        "  t_token "
        "SET"
        "  retired = 1 "
        "WHERE"
        "  token_id = %(token_id)s ",
        {"token_id": token_id},
    )
    elapsed = time.time() - began

    return affected_rows != 0, elapsed


def retire_tokens(db_conn, token_ids):
    """Retire every token of ``token_ids``, one statement per token.

    Args:
        db_conn: Open database connection.
        token_ids: Iterable of token identifiers to retire.

    Returns:
        A ``(nb_retired_tokens, retire_tokens_sec)`` tuple: the number of
        tokens actually retired and the accumulated time in seconds.
    """
    outcomes = [retire_token(db_conn, token_id=token_id) for token_id in token_ids]

    retired_count = sum(1 for was_retired, _ in outcomes if was_retired)
    spent_sec = sum(token_sec for _, token_sec in outcomes)

    return retired_count, spent_sec


def select_all_token_ids(db_conn):
    """Select the identifiers of all tokens, ordered by access-token expiry.

    Args:
        db_conn: Open database connection.

    Returns:
        A ``(token_ids, select_sec)`` tuple: the list of token identifiers and
        the time in seconds spent selecting them.
    """
    began = time.time()

    cur = db_conn.cursor()
    cur.execute("SELECT  token_id FROM  t_token ORDER BY  access_token_expiry")
    token_ids = [record[0] for record in cur.fetchall()]

    return token_ids, time.time() - began


def delete_ignore_token(db_conn, token_id):
    """Issue a ``DELETE IGNORE`` for a single token.

    Args:
        db_conn: Open database connection.
        token_id: Identifier of the token to delete.

    Returns:
        A ``(nb_deleted, delete_sec)`` tuple: the row count reported by the
        statement and the time in seconds it took.
    """
    began = time.time()
    deleted_rows = db_conn.cursor().execute(
        "DELETE IGNORE FROM  t_token WHERE  token_id = %(token_id)s",
        {"token_id": token_id},
    )
    elapsed = time.time() - began

    return deleted_rows, elapsed


def delete_orphan_tokens(db_conn, token_ids):
    """Attempt to delete every token of ``token_ids``, one statement each.

    NOTE(review): presumably the ``DELETE IGNORE`` issued per token leaves
    tokens that are still referenced untouched, so that only orphans are
    removed - confirm against the database schema.

    Args:
        db_conn: Open database connection.
        token_ids: Iterable of token identifiers to try to delete.

    Returns:
        A ``(total_nb_deleted, total_delete_sec)`` tuple: the summed row
        counts and the accumulated time in seconds.
    """
    outcomes = [delete_ignore_token(db_conn, token_id) for token_id in token_ids]

    deleted_total = sum(deleted for deleted, _ in outcomes)
    spent_sec = sum(token_sec for _, token_sec in outcomes)

    return deleted_total, spent_sec


def select_redundant_failed_token_ids(db_conn):
    """Select the failed-token entries whose token no longer exists.

    These are the ``token_id`` values of ``t_token_refresher_failed_token``
    that are absent from ``t_token``.

    Args:
        db_conn: Open database connection.

    Returns:
        A ``(token_ids, select_sec)`` tuple: the set of redundant token
        identifiers and the time in seconds spent selecting them.
    """
    began = time.time()
    query = (
        "SELECT"
        "  token_id "
        "FROM"
        "  t_token_refresher_failed_token "
        "WHERE"
        "  token_id NOT IN ("
        "    SELECT token_id FROM t_token"
        "  )"
    )

    cur = db_conn.cursor()
    cur.execute(query)
    redundant_ids = {record[0] for record in cur.fetchall()}

    return redundant_ids, time.time() - began


def delete_ignore_redundant_failed_token(db_conn, token_id):
    """Issue a ``DELETE IGNORE`` for a single failed-token entry.

    Args:
        db_conn: Open database connection.
        token_id: Identifier of the entry to remove from
            ``t_token_refresher_failed_token``.

    Returns:
        A ``(nb_deleted, delete_sec)`` tuple: the row count reported by the
        statement and the time in seconds it took.
    """
    began = time.time()
    statement = (
        "DELETE IGNORE FROM"
        "  t_token_refresher_failed_token "
        "WHERE"
        "  token_id = %(token_id)s"
    )
    deleted_rows = db_conn.cursor().execute(statement, {"token_id": token_id})
    elapsed = time.time() - began

    return deleted_rows, elapsed


def delete_ignore_redundant_failed_tokens(db_conn, redundant_failed_token_ids):
    """Delete every listed failed-token entry, one statement per entry.

    Args:
        db_conn: Open database connection.
        redundant_failed_token_ids: Iterable of token identifiers to remove
            from the failed-token table.

    Returns:
        A ``(total_nb_deleted, total_delete_sec)`` tuple: the summed row
        counts and the accumulated time in seconds.
    """
    outcomes = [
        delete_ignore_redundant_failed_token(db_conn, token_id)
        for token_id in redundant_failed_token_ids
    ]

    deleted_total = sum(deleted for deleted, _ in outcomes)
    spent_sec = sum(entry_sec for _, entry_sec in outcomes)

    return deleted_total, spent_sec


def perform_housekeeping(db_conn, select_worker_hostname_sec, max_tokens_to_retire):
    """Run one housekeeping pass and log a one-line summary of it.

    The pass retires the tokens selected for retirement, then attempts to
    delete every token, and finally removes the failed-token entries whose
    token no longer exists.  Counters and timings of every step are written
    to the module-level ``log`` as a single INFO record.

    Args:
        db_conn: Open database connection.
        select_worker_hostname_sec: Time in seconds the caller spent selecting
            the worker host; only reported in the summary.
        max_tokens_to_retire: Maximum number of tokens retired in this pass.
    """
    began = time.time()

    # Step 1: retire the tokens selected for retirement
    retire_candidates, retire_select_sec = select_token_ids_to_retire(
        db_conn, max_tokens_to_retire
    )
    retired_count, retire_sec = retire_tokens(db_conn, retire_candidates)

    # Step 2: attempt to delete every token
    every_token_id, all_select_sec = select_all_token_ids(db_conn)
    deleted_count, token_delete_sec = delete_orphan_tokens(
        db_conn, token_ids=every_token_id
    )

    # Step 3: drop failed-token entries whose token no longer exists
    stale_failed_ids, stale_select_sec = select_redundant_failed_token_ids(db_conn)
    stale_deleted_count, stale_delete_sec = delete_ignore_redundant_failed_tokens(
        db_conn, redundant_failed_token_ids=stale_failed_ids
    )

    elapsed = time.time() - began
    began_str = datetime.datetime.fromtimestamp(began).strftime(
        "%Y/%m/%d_%H:%M:%S.%f"
    )

    summary_fields = [
        f"start={began_str}",
        f"total_sec={elapsed:.3f}",
        f"select_worker_hostname_sec={select_worker_hostname_sec:.3f}",
        f"nb_tokens_to_retire={len(retire_candidates)}",
        f"select_token_ids_to_retire_sec={retire_select_sec:.3f}",
        f"nb_retired_tokens={retired_count}",
        f"retire_tokens_sec={retire_sec:.3f}",
        f"select_token_ids_to_delete_sec={all_select_sec:.3f}",
        f"nb_deleted_tokens={deleted_count}",
        f"delete_sec={token_delete_sec:.3f}",
        f"select_redundant_failed_token_ids_sec={stale_select_sec:.3f}",
        f"nb_deleted_redundant={stale_deleted_count}",
        f"delete_redundant_sec={stale_delete_sec:.3f}",
    ]
    log.info(" ".join(summary_fields))


# Silence all Python warnings for the lifetime of the daemon
warnings.simplefilter("ignore")

program_name = "ftstokenhousekeeperd"

parser = argparse.ArgumentParser(
    description="Performs housekeeping on token-related database tables"
)
parser.add_argument(
    "-c",
    "--config",
    default=f"/etc/fts3/{program_name}.ini",
    help="Path of the configuration file",
)

cmd_line = parser.parse_args()

config = get_config(cmd_line.config, program_name)

# Module-level logger, also used by perform_housekeeping()
log = get_log(log_file_path=config["log_file"], log_program_name=program_name)

db_conn = pymysql.connect(
    user=config["db_user"],
    password=config["db_password"],
    database=config["db_database"],
    host=config["db_host"],
    port=config["db_port"],
    autocommit=True,
)

# Refuse to run against a database schema that is too old.  The version is
# None when the schema version table is empty, which is treated as too old
# instead of crashing on the comparison.
min_allowed_major_db_schema_version = 9
actual_major_db_schema_version = get_major_db_schema_version(db_conn)
if (
    actual_major_db_schema_version is None
    or actual_major_db_schema_version < min_allowed_major_db_schema_version
):
    print(
        "Aborting: "
        "Database major version number is less than required: "
        f"min_allowed={min_allowed_major_db_schema_version} "
        f"actual={actual_major_db_schema_version}"
    )
    raise SystemExit(1)

# The token tables must already exist
db_tables = get_db_tables(db_conn)
if "t_token" not in db_tables:
    print("The t_token table does not exist")
    raise SystemExit(1)
if "t_token_provider" not in db_tables:
    print("The t_token_provider table does not exist")
    raise SystemExit(1)

# The failed-token table is created on demand
if "t_token_refresher_failed_token" not in db_tables:
    try:
        create_t_token_refresher_failed_token_db_table(db_conn)
    except Exception as create_error:
        # Maybe another script was quicker at creating the table.  The check
        # below decides whether this is fatal; the error is only recorded.
        log.warning(
            f"Failed to create the t_token_refresher_failed_token table: {create_error}"
        )
    db_tables = get_db_tables(db_conn)
    if "t_token_refresher_failed_token" not in db_tables:
        print(
            "The t_token_refresher_failed_token table does not exist and cannot be created"
        )
        raise SystemExit(1)

# Main loop: one pass every sec_between_executions seconds.  Only the host
# selected as worker performs the housekeeping.
while True:
    execution_start = time.time()

    try:
        worker_hostname, select_worker_hostname_sec = select_worker_hostname(db_conn)
        if worker_hostname == socket.gethostname():
            perform_housekeeping(
                db_conn, select_worker_hostname_sec, config["max_tokens_to_retire"]
            )
    except Exception:
        # Leave a trace in the log file before the daemon terminates
        log.exception("Aborting: Unexpected error during housekeeping")
        raise

    execution_duration = time.time() - execution_start

    # Sleep only for what remains of the period after the work done above
    secs_to_next_execution = config["sec_between_executions"] - execution_duration
    if secs_to_next_execution > 0:
        time.sleep(secs_to_next_execution)
