#!/usr/bin/python3
#
# Copyright (c) CERN 2024
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import base64
import concurrent.futures
import configparser
import datetime
import json
import logging
import logging.handlers
import os
import pymysql
import requests
import socket
import time
import warnings


def get_log(log_file_path, log_program_name, log_level="INFO"):
    """Configure the root logger to write to a daily-rotated log file.

    A ``TimedRotatingFileHandler`` rotating at midnight and keeping 30 backups
    is added to the root logger.

    Args:
        log_file_path: Path of the log file.  A path without a directory
            component is taken relative to the current working directory.
        log_program_name: Program name embedded in every log line.
        log_level: Level set on the root logger.

    Returns:
        The root logger.

    Raises:
        Exception: If the directory of the log file does not exist, is not a
            directory or cannot be written to.
    """
    hostname = socket.gethostname()

    # NOTE(review): TID is rendered with %(process)d, i.e. the process ID, not
    # a thread ID.  Left unchanged because log consumers may rely on it.
    log_format = (
        "%(asctime)s.%(msecs)03d000 "
        + hostname
        + " %(levelname)s "
        + log_program_name
        + ':LVL="%(levelname)s" PID="%(process)d" TID="%(process)d" MSG="%(message)s"'
    )
    log_date_format = "%Y/%m/%d %H:%M:%S"
    log_formatter = logging.Formatter(fmt=log_format, datefmt=log_date_format)

    # os.path.dirname() returns "" for a bare file name; treat that as the
    # current directory instead of rejecting it as a non-existent directory.
    log_dir = os.path.dirname(log_file_path) or "."
    if not os.path.isdir(log_dir):
        raise Exception(
            "The logging directory {} is not a directory or does not exist".format(
                log_dir
            )
        )
    if not os.access(log_dir, os.W_OK):
        raise Exception("The logging directory {} cannot be written to".format(log_dir))

    log_handler = logging.handlers.TimedRotatingFileHandler(
        filename=log_file_path, when="midnight", backupCount=30
    )
    log_handler.setFormatter(log_formatter)

    log = logging.getLogger()
    log.setLevel(log_level)
    log.addHandler(log_handler)

    return log


def get_config(path, program_name):
    """Read and validate the configuration file of the token refresher.

    Args:
        path: Path of the INI configuration file.
        program_name: Name of the program; used in the error message and in
            the default log file path.

    Returns:
        A dict with the keys ``log_file``, ``max_tokens_per_execution``,
        ``sec_between_executions``, ``log_individual_tokens``,
        ``nb_get_token_threads``, ``db_user``, ``db_password``,
        ``db_database``, ``db_host`` and ``db_port``.  The options of the
        ``main`` section fall back to defaults when absent.

    Raises:
        SystemExit: With exit code 1 if the configuration file does not exist.
        Exception: If a mandatory option of the ``database`` section is
            missing.
    """
    if not os.path.isfile(path):
        print(f"The {program_name} configuration file does not exist: path={path}")
        raise SystemExit(1)

    config = configparser.ConfigParser()
    config.read(path)

    # Every option of the database section is mandatory, including the port
    for option in ("user", "password", "database", "host", "port"):
        if not config.has_option("database", option):
            raise Exception(
                "The database section of the configuration file does not contain "
                f"the {option} option: path={path}"
            )

    result = {
        "log_file": config.get(
            "main", "log_file", fallback=f"/var/log/fts3/{program_name}.log"
        ),
        "max_tokens_per_execution": config.getint(
            "main", "max_tokens_per_execution", fallback=1000
        ),
        "sec_between_executions": config.getint(
            "main", "sec_between_executions", fallback=300
        ),
        "log_individual_tokens": config.getboolean(
            "main", "log_individual_tokens", fallback=True
        ),
        "nb_get_token_threads": config.getint(
            "main", "nb_get_token_threads", fallback=10
        ),
        "db_user": config.get("database", "user"),
        "db_password": config.get("database", "password"),
        "db_database": config.get("database", "database"),
        "db_host": config.get("database", "host"),
        "db_port": config.getint("database", "port"),
    }

    return result


def select_worker_hostname(db_conn):
    """Select the host that should do the refreshing in this execution.

    The alphabetically first hostname of t_hosts that is not draining is
    chosen.

    Args:
        db_conn: Database connection providing ``cursor()``.

    Returns:
        A tuple ``(hostname, select_sec)`` where ``hostname`` is None if no
        row matched and ``select_sec`` is the duration of the query in
        seconds.
    """
    start_select = time.time()

    cursor = db_conn.cursor()
    # The previous predicate "drain is not null or drain = 0" reduced to
    # "drain is not null" and therefore also matched draining hosts.
    cursor.execute(
        "select"
        "  hostname "
        "from"
        "  t_hosts "
        "where"
        "  drain is null or drain = 0 "
        "order by"
        "  hostname asc "
        "limit 1"
    )
    row = cursor.fetchone()
    # No eligible host: report None instead of failing on row[0]
    hostname = row[0] if row else None

    select_sec = time.time() - start_select

    return hostname, select_sec


def get_major_db_schema_version(db_conn):
    """Return the largest ``major`` value stored in t_schema_vers.

    Args:
        db_conn: Database connection providing ``cursor()``.

    Returns:
        The first column of the first row returned by the query.
    """
    query = "select  max(major) from  t_schema_vers"
    cur = db_conn.cursor()
    cur.execute(query)
    return cur.fetchall()[0][0]


def get_db_tables(db_conn):
    """Return the names of the tables of the current database schema.

    Args:
        db_conn: Database connection providing ``cursor()``.

    Returns:
        A set holding the first column of every row returned by the query.
    """
    query = (
        "select"
        "  table_name "
        "from"
        "  information_schema.tables "
        "where"
        "  table_schema = database()"
    )
    cur = db_conn.cursor()
    cur.execute(query)
    return {table_row[0] for table_row in cur.fetchall()}


def create_t_token_refresher_failed_token_db_table(db_conn):
    """Create the t_token_refresher_failed_token table.

    Args:
        db_conn: Database connection providing ``cursor()``.
    """
    column_definitions = (
        "  token_id char(16) not null,",
        "  failed_timestamp timestamp not null,",
        "  error_message varchar(1024) not null,",
        "  primary key (token_id)",
    )
    ddl = (
        "create table t_token_refresher_failed_token("
        + "".join(column_definitions)
        + ")"
    )
    db_conn.cursor().execute(ddl)


def insert_failed_token(db_conn, token_id, error_message):
    """Record a token in t_token_refresher_failed_token.

    The statement is an ``insert ignore`` and stamps the row with ``now()``.

    Args:
        db_conn: Database connection providing ``cursor()``.
        token_id: Identifier of the token.
        error_message: Error message stored with the token.

    Returns:
        The duration of the insert in seconds.
    """
    statement = (
        "insert ignore into t_token_refresher_failed_token("
        "  token_id,"
        "  failed_timestamp,"
        "  error_message"
        ") values ("
        "  %(token_id)s,"
        "  now(),"
        "  %(error_message)s"
        ")"
    )
    parameters = dict(token_id=token_id, error_message=error_message)

    began = time.time()
    db_conn.cursor().execute(statement, parameters)
    return time.time() - began


def is_a_failed_token(db_conn, token_id):
    """Tell whether a token is listed in t_token_refresher_failed_token.

    Args:
        db_conn: Database connection providing ``cursor()``.
        token_id: Identifier of the token to look up.

    Returns:
        True if the query returned at least one row, otherwise False.
    """
    query = (
        "select"
        "  token_id "
        "from"
        "  t_token_refresher_failed_token "
        "where"
        "  token_id = %(token_id)s"
    )
    cur = db_conn.cursor()
    cur.execute(query, {"token_id": token_id})
    matching_rows = cur.fetchall()
    return len(matching_rows) != 0


def _b64url_decode_unpadded(segment):
    """Decode a base64url string whose trailing '=' padding was stripped."""
    # A base64 string must be a multiple of 4 characters long, so the number
    # of missing pad characters is (-len) % 4, not len % 4.
    padding = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + padding)


def decode_token(raw):
    """Split a raw token into its three parts and decode header and payload.

    The signature is not verified.

    Args:
        raw: The token as a string of three '.'-separated parts.

    Returns:
        A dict with the keys ``raw`` (the input string), ``header`` and
        ``payload`` (the decoded JSON values) and ``signature`` (the third
        part, left encoded).

    Raises:
        Exception: If the token does not consist of exactly three parts.
            Errors from base64 or JSON decoding propagate unchanged.
    """
    split_raw = raw.split(".")

    if len(split_raw) != 3:
        raise Exception(
            "Could not split token into three parts using '.' as the delimiter"
        )

    encoded_header, encoded_payload, signature = split_raw

    return {
        "raw": raw,
        "header": json.loads(_b64url_decode_unpadded(encoded_header)),
        "payload": json.loads(_b64url_decode_unpadded(encoded_payload)),
        "signature": signature,
    }


def get_token_providers(db_conn):
    """Load the token providers from t_token_provider.

    Args:
        db_conn: Database connection providing ``cursor()``.

    Returns:
        A dict keyed by issuer, where a '/' is appended to any issuer that
        does not already end with one.  Each value is a dict with the keys
        ``name``, ``client_id`` and ``client_secret``.
    """
    cur = db_conn.cursor()
    cur.execute(
        "select"
        "  issuer,"
        "  name,"
        "  client_id,"
        "  client_secret "
        "from"
        "  t_token_provider"
    )

    providers = {}
    for issuer, name, client_id, client_secret in cur.fetchall():
        key = issuer if issuer.endswith("/") else issuer + "/"
        providers[key] = {
            "name": name,
            "client_id": client_id,
            "client_secret": client_secret,
        }
    return providers


def select_token_rows_to_refresh(db_conn, max_tokens_to_refresh):
    """Select the non-retired tokens whose refresh time has been reached.

    Rows are ordered by ``access_token_refresh_after`` ascending and limited
    to ``max_tokens_to_refresh``.

    Args:
        db_conn: Database connection providing ``cursor()``.
        max_tokens_to_refresh: Maximum number of rows to select.

    Returns:
        A tuple ``(token_rows, select_sec)`` where ``token_rows`` is a list
        of dicts with the keys ``token_id``, ``access_token``,
        ``refresh_token`` and ``issuer``, and ``select_sec`` is the elapsed
        time in seconds.
    """
    began = time.time()

    cur = db_conn.cursor()
    cur.execute(
        "SELECT"
        "  token_id,"
        "  access_token,"
        "  refresh_token,"
        "  issuer "
        "FROM"
        "  t_token "
        "WHERE"
        "  retired = 0 "
        "AND"
        "  NOW() >= access_token_refresh_after "
        "ORDER BY"
        "  retired ASC,"
        "  access_token_refresh_after ASC "
        "LIMIT %(max_tokens_to_refresh)s",
        {"max_tokens_to_refresh": max_tokens_to_refresh},
    )

    token_rows = [
        {
            "token_id": db_row[0],
            "access_token": db_row[1],
            "refresh_token": db_row[2],
            "issuer": db_row[3],
        }
        for db_row in cur.fetchall()
    ]

    return token_rows, time.time() - began


def get_openid_configuration(issuer, timeout_sec=30):
    """Fetch the OpenID configuration document of a token issuer.

    Args:
        issuer: Base URL of the issuer, with or without a trailing '/'.
        timeout_sec: Timeout in seconds passed to ``requests.get`` so that an
            unresponsive issuer cannot block the caller indefinitely.

    Returns:
        The parsed JSON document, or None if the request failed or the
        response status was not 200.

    Raises:
        Exception: If the body of a 200 response cannot be parsed as JSON.
    """
    # Issuers may end with '/'; strip it so the URL does not contain "//"
    url = f"{issuer.rstrip('/')}/.well-known/openid-configuration"

    try:
        response = requests.get(url, timeout=timeout_sec)
    except Exception as e:
        # Best effort: an unreachable issuer is reported as "no configuration"
        logging.getLogger().warning(
            f"Failed to get OpenID configuration: url={url} error={e}"
        )
        return None

    if response.status_code != 200:
        return None

    try:
        return json.loads(response.text)
    except Exception as e:
        raise Exception(f"Failed to parse JSON response from {url}: {e}") from e


def get_openid_confs(issuers):
    """Fetch the OpenID configuration of every given issuer, one at a time.

    Args:
        issuers: Iterable of issuer URLs.

    Returns:
        A tuple ``(openid_configurations, get_sec)`` where
        ``openid_configurations`` maps each issuer to the value returned by
        ``get_openid_configuration`` and ``get_sec`` is the total elapsed
        time in seconds.
    """
    began = time.time()
    confs_by_issuer = {
        issuer: get_openid_configuration(issuer) for issuer in issuers
    }
    elapsed_sec = time.time() - began
    return confs_by_issuer, elapsed_sec


class TokenCanNeverBeRefreshed(Exception):
    """Exception type for a token that can never be refreshed.

    NOTE(review): nothing in the visible part of this file raises or catches
    this exception.
    """


def _build_refresh_request_data(old_access_token, refresh_token_raw):
    """Build the form data of a refresh request from the old access token.

    Raises:
        ValueError: If the old access token has no ``aud`` or ``scope`` claim
            or if ``aud`` is neither a string nor a list of strings.
    """
    payload = old_access_token["payload"]
    for claim in ("aud", "scope"):
        if claim not in payload:
            raise ValueError(f"old access-token has no {claim} claim")

    audience = payload["aud"]
    if isinstance(audience, list):
        if not all(isinstance(item, str) for item in audience):
            raise ValueError("aud of old access-token is not a list of strings")
        audience = " ".join(audience)
    elif not isinstance(audience, str):
        raise ValueError("aud of old access-token is neither a string nor a list")

    return {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token_raw,
        "audience": audience,
        "scope": payload["scope"],
    }


def get_fresh_access_token(
    token_providers, openid_configurations, token_row, timeout_sec=60
):
    """Ask the issuer of a token for a fresh access token.

    Problems with an individual token are reported in the returned dict
    instead of being raised, so that one bad token cannot abort the
    processing of the others.

    Args:
        token_providers: Dict mapping issuer to a dict holding ``client_id``
            and ``client_secret``.
        openid_configurations: Dict mapping issuer to its OpenID
            configuration, or to None if it could not be obtained.
        token_row: Dict with the keys ``token_id``, ``access_token``,
            ``refresh_token`` and ``issuer``.
        timeout_sec: Timeout in seconds of the HTTP request to the token
            endpoint.

    Returns:
        A dict that always contains ``token_id`` and additionally:
          * on success: ``old_access_token``, ``new_access_token``,
            ``new_refresh_token_raw``, ``same_refresh_token`` and
            ``post_sec``;
          * on failure: ``recoverable_error`` or ``non_recoverable_error``
            holding an error message.
    """
    token_id = token_row["token_id"]
    old_refresh_token_raw = token_row["refresh_token"]

    result = {"token_id": token_id}

    try:
        old_access_token = decode_token(token_row["access_token"])
    except Exception as e:
        result["non_recoverable_error"] = (
            "Failed to get fresh access token: "
            f"Failed to decode old access token: {e}"
        )
        return result
    result["old_access_token"] = old_access_token

    # get_token_providers() stores issuers with a trailing '/', so also accept
    # a token whose issuer is stored without one
    token_issuer = token_row["issuer"]
    issuer_key = token_issuer
    if (
        isinstance(issuer_key, str)
        and issuer_key not in token_providers
        and not issuer_key.endswith("/")
    ):
        issuer_key = issuer_key + "/"

    token_provider = token_providers.get(issuer_key)
    if not token_provider:
        result[
            "non_recoverable_error"
        ] = f"Token issuer has no token provider: token_issuer={token_issuer}"
        return result
    client_id = token_provider["client_id"]
    client_secret = token_provider["client_secret"]

    openid_configuration = openid_configurations.get(issuer_key)
    if not openid_configuration:
        result[
            "non_recoverable_error"
        ] = f"Token issuer has no OpenID configuration: token_issuer={token_issuer}"
        return result
    token_endpoint = openid_configuration.get("token_endpoint")
    if not token_endpoint:
        result["non_recoverable_error"] = (
            "OpenID configuration of token issuer has no token_endpoint: "
            f"token_issuer={token_issuer}"
        )
        return result

    try:
        data = _build_refresh_request_data(old_access_token, old_refresh_token_raw)
    except ValueError as e:
        result["non_recoverable_error"] = f"get_fresh_access_token failed: {e}"
        return result
    headers = {"Content-Type": "application/x-www-form-urlencoded"}

    start_post = time.time()
    try:
        # The context manager closes the session and its connections
        with requests.Session() as session:
            response = session.post(
                token_endpoint,
                headers=headers,
                data=data,
                auth=(client_id, client_secret),
                timeout=timeout_sec,
            )
    except Exception as e:
        result["post_sec"] = time.time() - start_post
        result[
            "recoverable_error"
        ] = f"Failed to get fresh access token: requests.post raised an exception: {e}"
        # The root logger is the logger this program configures in get_log()
        logging.getLogger().warning(result["recoverable_error"])
        return result
    result["post_sec"] = time.time() - start_post

    if response.status_code == 401:
        result["non_recoverable_error"] = (
            "Failed to get fresh access token: "
            "Token can never be refreshed: "
            "Failed to get fresh access token: "
            f"status_code={response.status_code}"
        )
        return result

    if response.status_code != 200:
        result[
            "recoverable_error"
        ] = f"Failed to get fresh access token: status_code={response.status_code}"
        return result

    try:
        json_response = json.loads(response.text)
    except Exception as e:
        result[
            "recoverable_error"
        ] = f"Failed to get fresh access token: Failed to parse response.text as JSON: {e}"
        return result

    if not isinstance(json_response, dict):
        result[
            "recoverable_error"
        ] = "Failed to get fresh access token: Response is not a JSON object"
        return result

    if "access_token" not in json_response:
        result[
            "recoverable_error"
        ] = "Failed to get fresh access token: No access token in the response"
        return result

    if "refresh_token" not in json_response:
        result[
            "recoverable_error"
        ] = "Failed to get fresh access token: No refresh token in the response"
        return result

    try:
        result["new_access_token"] = decode_token(json_response["access_token"])
    except Exception as e:
        result["recoverable_error"] = (
            "Failed to get fresh access token: "
            f"Failed to decode new access token: {e}"
        )
        return result
    result["new_refresh_token_raw"] = json_response["refresh_token"]
    result["same_refresh_token"] = (
        old_refresh_token_raw == json_response["refresh_token"]
    )

    return result


def get_fresh_access_tokens(
    refresher_pool, token_providers, openid_configurations, token_rows
):
    """Get fresh access tokens for the given token rows using a worker pool.

    Args:
        refresher_pool: Executor providing ``submit()``; one task is
            submitted per token row.
        token_providers: Passed through to ``get_fresh_access_token``.
        openid_configurations: Passed through to ``get_fresh_access_token``.
        token_rows: List of token rows to be refreshed.

    Returns:
        A list of result dicts in the same order as ``token_rows``.  If a
        task raised an exception, its entry holds the ``token_id`` of the
        row and a ``recoverable_error`` message instead.
    """
    submitted = [
        (
            token_row,
            refresher_pool.submit(
                get_fresh_access_token,
                token_providers,
                openid_configurations,
                token_row,
            ),
        )
        for token_row in token_rows
    ]

    # Wait for and collect all of the results
    results = []
    for token_row, future_result in submitted:
        try:
            results.append(future_result.result())
        except Exception as e:
            # One failing task must not discard the results of the others
            token_id = token_row.get("token_id")
            error = (
                "Failed to get fresh access token: "
                f"worker raised an exception: token_id={token_id} error={e}"
            )
            logging.getLogger().warning(error)
            results.append({"token_id": token_id, "recoverable_error": error})

    return results


def update_token_row(db_conn, token_id, new_access_token, new_refresh_token_raw):
    """Store a fresh access token and refresh token in a row of t_token.

    The next refresh is scheduled half way through the lifetime of the new
    access token, where a non-positive lifetime counts as 0.

    Args:
        db_conn: Database connection providing ``cursor()``.
        token_id: Identifier of the row to update.
        new_access_token: Decoded access token; its ``raw`` value and the
            ``nbf`` and ``exp`` claims of its ``payload`` are used.
        new_refresh_token_raw: The refresh token to store.

    Returns:
        The duration of the update in seconds.

    Raises:
        Exception: If the update did not affect exactly one row.
    """
    claims = new_access_token["payload"]
    expiry = claims["exp"]
    not_before = claims["nbf"]

    # Refresh a token half way through its lifetime
    if expiry > not_before:
        lifetime_sec = expiry - not_before
    else:
        lifetime_sec = 0
    refresh_after = not_before + lifetime_sec * 0.5

    parameters = {
        "access_token": new_access_token["raw"],
        "access_token_not_before": not_before,
        "access_token_expiry": expiry,
        "access_token_refresh_after": refresh_after,
        "refresh_token": new_refresh_token_raw,
        "token_id": token_id,
    }

    began = time.time()
    nb_affected_rows = db_conn.cursor().execute(
        "UPDATE"
        "  t_token "
        "SET"
        "  access_token = %(access_token)s,"
        "  access_token_not_before = from_unixtime(%(access_token_not_before)s),"
        "  access_token_expiry = from_unixtime(%(access_token_expiry)s),"
        "  access_token_refresh_after = from_unixtime(%(access_token_refresh_after)s),"
        "  refresh_token = %(refresh_token)s "
        "WHERE"
        "  token_id = %(token_id)s",
        parameters,
    )
    elapsed_sec = time.time() - began

    if nb_affected_rows == 1:
        return elapsed_sec

    raise Exception(
        f"Failed to update token row: expected_nb_effected_rows=1 actual={nb_affected_rows}"
    )


def process_fresh_access_token(config, log, db_conn, fresh_access_token_result, stats):
    """Process the result of one attempt to get a fresh access token.

    A recoverable error is only counted.  A non-recoverable error is counted
    and the token is recorded with ``insert_failed_token``.  On success the
    token row is updated in the database and the statistics are updated.

    Args:
        config: Configuration dict; ``log_individual_tokens`` is read.
        log: Logger used for the per-token messages.
        db_conn: Database connection.
        fresh_access_token_result: Result dict as produced by
            ``get_fresh_access_token``.
        stats: Dict of counters and timings, updated in place.
    """
    token_id = fresh_access_token_result["token_id"]

    if "recoverable_error" in fresh_access_token_result:
        stats["nb_recoverable_error"] += 1
        return

    if "non_recoverable_error" in fresh_access_token_result:
        stats["nb_non_recoverable_error"] += 1
        stats["insert_failed_sec"] += insert_failed_token(
            db_conn=db_conn,
            token_id=token_id,
            error_message=fresh_access_token_result["non_recoverable_error"],
        )
        return

    stats["nb_get_access_token"] += 1
    stats["get_access_token_sec"] += fresh_access_token_result["post_sec"]
    if fresh_access_token_result["same_refresh_token"]:
        stats["nb_same_refresh_tokens"] += 1

    old_access_token = fresh_access_token_result["old_access_token"]
    new_access_token = fresh_access_token_result["new_access_token"]
    new_refresh_token_raw = fresh_access_token_result["new_refresh_token_raw"]

    old_access_token_lifetime = (
        old_access_token["payload"]["exp"] - old_access_token["payload"]["nbf"]
    )
    new_access_token_lifetime = (
        new_access_token["payload"]["exp"] - new_access_token["payload"]["nbf"]
    )

    if config["log_individual_tokens"]:
        log.info(
            "Got fresh access-token: "
            f"token_id={token_id} "
            f"iss={new_access_token['payload']['iss']} "
            f"scope={new_access_token['payload']['scope']} "
            f"aud={new_access_token['payload']['aud']} "
            f"iat={new_access_token['payload']['iat']} "
            f"nbf={new_access_token['payload']['nbf']} "
            f"exp={new_access_token['payload']['exp']} "
            f"lifetime_sec={new_access_token_lifetime}"
        )

    # Lifetimes are allowed to differ by up to one second
    if abs(old_access_token_lifetime - new_access_token_lifetime) > 1:
        stats["nb_diff_lifetime"] += 1
        log.warning(
            "Lifetime of old and new acces-tokens differ: "
            f"old_lifetime_sec={old_access_token_lifetime} "
            f"new_lifetime_sec={new_access_token_lifetime}"
        )

    try:
        update_sec = update_token_row(
            db_conn=db_conn,
            token_id=token_id,
            new_access_token=new_access_token,
            new_refresh_token_raw=new_refresh_token_raw,
        )
    except Exception as e:
        # Count the failure and say why, instead of swallowing it silently
        stats["nb_failed_update"] += 1
        log.warning(f"Failed to update token row: token_id={token_id} error={e}")
    else:
        stats["update_access_token_sec"] += update_sec
        stats["nb_updated"] += 1


def process_fresh_access_tokens(config, log, db_conn, fresh_access_token_results):
    """Process all results of one refresh round and gather statistics.

    Args:
        config: Configuration dict, passed on for each result.
        log: Logger, passed on for each result.
        db_conn: Database connection, passed on for each result.
        fresh_access_token_results: Iterable of result dicts.

    Returns:
        A dict of counters and timings.  ``single_thread_get_access_token_hz``
        is ``nb_get_access_token`` divided by ``get_access_token_sec``, or 0
        when no time was accumulated.
    """
    stat_names = (
        "nb_get_access_token",
        "get_access_token_sec",
        "single_thread_get_access_token_hz",
        "update_access_token_sec",
        "nb_diff_lifetime",
        "nb_recoverable_error",
        "nb_non_recoverable_error",
        "nb_updated",
        "nb_failed_update",
        "insert_failed_sec",
        "nb_same_refresh_tokens",
    )
    stats = dict.fromkeys(stat_names, 0)

    for one_result in fresh_access_token_results:
        process_fresh_access_token(config, log, db_conn, one_result, stats)

    total_get_sec = stats["get_access_token_sec"]
    if total_get_sec > 0:
        stats["single_thread_get_access_token_hz"] = (
            stats["nb_get_access_token"] / total_get_sec
        )

    return stats


def refresh_tokens(refresher_pool, config, log, db_conn, select_worker_hostname_sec):
    """Run one refresh round and log a one-line summary of it.

    The round loads the token providers, fetches their OpenID
    configurations, selects the tokens due for refresh, requests fresh
    access tokens through the pool and processes the results.

    Args:
        refresher_pool: Executor used to get the fresh access tokens.
        config: Configuration dict; ``max_tokens_per_execution`` and
            ``nb_get_token_threads`` are read here.
        log: Logger receiving the summary line.
        db_conn: Database connection.
        select_worker_hostname_sec: Time the caller spent selecting the
            worker hostname; only reported in the summary.
    """
    round_start = time.time()

    token_providers = get_token_providers(db_conn)
    openid_configurations, get_openid_confs_sec = get_openid_confs(
        token_providers.keys()
    )
    token_rows, select_tokens_to_refresh_sec = select_token_rows_to_refresh(
        db_conn, max_tokens_to_refresh=config["max_tokens_per_execution"]
    )

    fetch_start = time.time()
    fetch_results = get_fresh_access_tokens(
        refresher_pool, token_providers, openid_configurations, token_rows
    )
    fetch_sec = time.time() - fetch_start
    if fetch_sec:
        multi_thread_get_access_token_hz = len(token_rows) / fetch_sec
    else:
        multi_thread_get_access_token_hz = 0

    stats = process_fresh_access_tokens(
        config=config,
        log=log,
        db_conn=db_conn,
        fresh_access_token_results=fetch_results,
    )

    total_sec = time.time() - round_start

    start_str = datetime.datetime.fromtimestamp(round_start).strftime(
        "%Y/%m/%d_%H:%M:%S.%f"
    )

    # One space-separated key=value summary line per round
    summary_fields = [
        f"start={start_str}",
        f"total_sec={total_sec:.3f}",
        f"nb_get_token_threads={config['nb_get_token_threads']}",
        f"select_worker_hostname_sec={select_worker_hostname_sec:.3f}",
        f"get_openid_confs_sec={get_openid_confs_sec:.3f}",
        f"nb_tokens_to_refresh={len(token_rows)}",
        f"select_tokens_to_refresh_sec={select_tokens_to_refresh_sec:.3f}",
        f"nb_get_access_token={stats['nb_get_access_token']}",
        f"get_access_token_sec={stats['get_access_token_sec']:.3f}",
        f"single_thread_get_access_token_hz={stats['single_thread_get_access_token_hz']:.3f}",
        f"multi_thread_get_access_token_hz={multi_thread_get_access_token_hz:.3f}",
        f"update_access_token_sec={stats['update_access_token_sec']:.3f}",
        f"nb_diff_lifetime={stats['nb_diff_lifetime']}",
        f"nb_recoverable_error={stats['nb_recoverable_error']}",
        f"nb_non_recoverable_error={stats['nb_non_recoverable_error']}",
        f"nb_updated={stats['nb_updated']}",
        f"nb_failed_update={stats['nb_failed_update']}",
        f"insert_failed_sec={stats['insert_failed_sec']:.3f}",
        f"nb_same_refresh_tokens={stats['nb_same_refresh_tokens']}",
    ]
    log.info(" ".join(summary_fields))


# Suppress all Python warnings for the lifetime of the daemon
warnings.simplefilter("ignore")

program_name = "ftstokenrefresherd"

parser = argparse.ArgumentParser(
    description="Refreshes access tokens in the FTS database"
)
parser.add_argument(
    "-c",
    "--config",
    default=f"/etc/fts3/{program_name}.ini",
    help="Path of the configuration file",
)

cmd_line = parser.parse_args()

config = get_config(cmd_line.config, program_name)

# NOTE: log must stay a module-level name; get_fresh_access_token() reads it
log = get_log(log_file_path=config["log_file"], log_program_name=program_name)

db_conn = pymysql.connect(
    user=config["db_user"],
    password=config["db_password"],
    database=config["db_database"],
    host=config["db_host"],
    port=config["db_port"],
    autocommit=True,
)

min_allowed_major_db_schema_version = 9
actual_major_db_schema_version = get_major_db_schema_version(db_conn)
# max(major) is NULL (None) when t_schema_vers is empty; treat that as too old
# instead of failing on the comparison of None with an int
if (
    actual_major_db_schema_version is None
    or actual_major_db_schema_version < min_allowed_major_db_schema_version
):
    print(
        "Aborting: "
        "Database major version number is less than required: "
        f"min_allowed={min_allowed_major_db_schema_version} "
        f"actual={actual_major_db_schema_version}"
    )
    raise SystemExit(1)

db_tables = get_db_tables(db_conn)
if "t_token" not in db_tables:
    print("The t_token table does not exist")
    raise SystemExit(1)
if "t_token_provider" not in db_tables:
    print("The t_token_provider table does not exist")
    raise SystemExit(1)

if "t_token_refresher_failed_token" not in db_tables:
    try:
        create_t_token_refresher_failed_token_db_table(db_conn)
    except Exception:
        # Maybe another script was quicker at creating the table; the check
        # below decides whether the table is really there
        pass
    db_tables = get_db_tables(db_conn)
    if "t_token_refresher_failed_token" not in db_tables:
        print(
            "The t_token_refresher_failed_token table does not exists and cannot be created"
        )
        raise SystemExit(1)

refresher_pool = concurrent.futures.ThreadPoolExecutor(
    max_workers=config["nb_get_token_threads"]
)

try:
    while True:
        execution_start = time.time()

        # Only the host selected from the database refreshes tokens in this
        # execution; all other hosts just sleep until the next one
        worker_hostname, select_worker_hostname_sec = select_worker_hostname(db_conn)
        if worker_hostname == socket.gethostname():
            refresh_tokens(
                refresher_pool, config, log, db_conn, select_worker_hostname_sec
            )

        execution_duration = time.time() - execution_start

        secs_to_next_execution = config["sec_between_executions"] - execution_duration
        if secs_to_next_execution > 0:
            time.sleep(secs_to_next_execution)
except Exception:
    # Record the reason of the crash in the log file before terminating
    log.exception("Aborting: Unhandled exception in the main loop")
    raise