#!/usr/bin/python

import os
import pwd
import sys
import socket
import logging
import optparse

# Default CONDOR_CONFIG to the HTCondor-CE configuration file unless the
# caller's environment already provides a value.
# NOTE(review): presumably this has to run before `import htcondor` below so
# that the bindings read the CE configuration -- confirm before moving it.
os.environ.setdefault('CONDOR_CONFIG', '/etc/condor-ce/condor_config')

import cherrypy
import cherrypy._cpwsgi_server
import cherrypy.wsgiserver

import htcondor
import htcondorce.web


# Timeout (in seconds) passed to htcondor.send_alive(); main() also derives
# the heartbeat period from this value.
ALIVE_HEARTBEAT = 60
def send_heartbeat():
    """Send a keepalive through htcondor.send_alive() and log any failure.

    RuntimeError and ValueError raised by the bindings are logged rather than
    propagated.  Failures whose message mentions 'CONDOR_INHERIT' or
    'location ClassAd' are logged at FullDebug level as a warning; every
    other failure is logged at Always level as an error.

    The effective UID observed on entry is put back on exit if it differs
    afterwards.
    """
    euid = os.geteuid()
    try:
        htcondor.send_alive(timeout=ALIVE_HEARTBEAT)
    except (RuntimeError, ValueError) as exc:
        # Match against the message text: an exception object itself does not
        # support substring membership tests.
        reason = str(exc)
        if 'CONDOR_INHERIT' in reason or 'location ClassAd' in reason:
            htcondor.log(htcondor.LogLevel.FullDebug,
                         'WARNING: Could not find location of HTCondor-CE master daemon to send keepalive')
        else:
            htcondor.log(htcondor.LogLevel.Always,
                         'ERROR: Failed to send keepalive to the HTCondor-CE master daemon with EUID {0}'.format(euid))
    finally:
        # Restore the original effective UID even if something unexpected
        # escapes from the block above.
        if euid != os.geteuid():
            os.seteuid(euid)


class HTCondorHandler(logging.Handler):
    """Logging handler that forwards formatted records to htcondor.log()."""

    def emit(self, record):
        """Format *record* and write it at the htcondor 'Always' log level.

        The effective UID seen just before the write is put back afterwards
        if it is found to have changed.
        """
        text = self.format(record)
        uid_before = os.geteuid()
        htcondor.log(htcondor.LogLevel.Always, text)
        if os.geteuid() != uid_before:
            os.seteuid(uid_before)


def condor_ids():
    """Return the (uid, gid) pair to use for the condor account.

    When the CONDOR_IDS configuration parameter is set, its value is split
    on the first '.' and both halves are converted to integers.  Otherwise
    the IDs come from the password database entry for the 'condor' user.
    """
    if 'CONDOR_IDS' not in htcondor.param:
        entry = pwd.getpwnam('condor')
        return entry.pw_uid, entry.pw_gid

    fields = htcondor.param['CONDOR_IDS'].split('.', 1)
    uid = int(fields[0])
    gid = int(fields[1])
    return uid, gid


# Spool directory inspected by check_multice(); assigned by main().
g_spooldir = None
def check_multice():
    """Refresh htcondorce.web.g_is_multice from the spool directory layout.

    Counts the sub-directories of g_spooldir, not counting 'vos' and
    'metrics', and sets the flag to True when more than one is found.
    Does nothing while g_spooldir is unset or empty.
    """
    if not g_spooldir:
        return

    skipped = ('vos', 'metrics')
    found = 0
    for entry in os.listdir(g_spooldir):
        if entry in skipped:
            continue
        if not os.path.isdir(os.path.join(g_spooldir, entry)):
            continue
        found += 1
        if found > 1:
            # Two is enough to decide; no need to scan the rest.
            break

    htcondorce.web.g_is_multice = found > 1


def setup_logging():
    """Route CherryPy's error and access logs through HTCondorHandler.

    Does nothing when the htcondor module has no `log` attribute.
    Otherwise htcondor.enable_debug() is called when file descriptor 1 is a
    terminal and htcondor.enable_log() when it is not, and the effective UID
    is restored if that call changed it.
    """
    if not hasattr(htcondor, 'log'):
        return

    # Drop CherryPy's own error-log handlers; the access log keeps its
    # existing handlers.
    cherrypy.log.error_log.handlers = []
    #cherrypy.log.access_log.handlers = []

    saved_euid = os.geteuid()
    enable = htcondor.enable_debug if os.isatty(1) else htcondor.enable_log
    enable()
    if os.geteuid() != saved_euid:
        os.seteuid(saved_euid)

    handler = HTCondorHandler()
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(cherrypy._cplogging.logfmt)

    # Make CherryPy emit an empty timestamp.
    # NOTE(review): presumably the HTCondor log adds its own -- confirm.
    cherrypy.log.time = lambda: ''

    for target in (cherrypy.log.error_log, cherrypy.log.access_log):
        target.addHandler(handler)

def parse_opts(localname=None, pidfile=None):
    """Parse the command line and initialise the subsystem name and logging.

    localname -- subsystem name passed to htcondor.set_subsystem() when the
                 bindings provide it; 'CEVIEW' is used when it is empty.
    pidfile   -- accepted for call compatibility; not used here.

    Returns the optparse values object (positional arguments are discarded).
    """
    parser = optparse.OptionParser()

    # Unimplemented, DaemonCore short-name options: registered only so that
    # the parser accepts them.
    placeholder_options = (
        ("-a", {}),
        ("-b", {"action": "store_true"}),
        ("-c", {}),
        ("-d", {"action": "store_true"}),
        ("-f", {"action": "store_true"}),
        ("-k", {}),
        ("-l", {}),
        ("-q", {"action": "store_true"}),
        ("-r", {"type": "int"}),
        ("-t", {"action": "store_true"}),
        ("-v", {"action": "store_true"}),
    )
    for flag, settings in placeholder_options:
        parser.add_option(flag, **settings)

    # Implemented options
    parser.add_option("--pool", dest="pool",
                      help="HTCondor-CE pool to consider.")
    parser.add_option("-n", "--name", dest="name",
                      help="HTCondor-CE schedd to consider.")
    parser.add_option("--spool", dest="spool",
                      help="Spool directory to use.")
    parser.add_option("-p", "--port", dest="port",
                      help="Port to use for webapp.")

    opts, _unused_args = parser.parse_args()

    if hasattr(htcondor, 'set_subsystem'):
        htcondor.set_subsystem(localname or 'CEVIEW')
    setup_logging()

    return opts


class MyWSGIGateway(cherrypy.wsgiserver.WSGIGateway_10):
    """WSGI gateway that injects extra keys into every request environ."""

    # Shared, class-level mapping of extra environ entries; main() fills it
    # in before the servers are started.
    env_override = {}

    def get_environ(self):
        """Return the base gateway's environ with env_override applied on top."""
        request_environ = cherrypy.wsgiserver.WSGIGateway_10.get_environ(self)
        for key, value in self.env_override.items():
            request_environ[key] = value
        return request_environ


class WSGILogging(object):
    """WSGI middleware that logs one access line per request via htcondor.log().

    The line records the remote address, request line, status code, response
    size, referer and user agent, laid out according to FORMAT.
    """

    FORMAT = '{host} "{request}" {status} {size} "{referer}" "{agent}"'

    def __init__(self, app):
        """Wrap *app*, the WSGI application whose requests are logged."""
        self.app = app

    @staticmethod
    def _buffer(iterable):
        """Read *iterable* fully into a list and close it if it has close()."""
        try:
            return list(iterable)
        finally:
            # The server will only ever see the buffered list, so the
            # original iterable has to be closed here.
            close = getattr(iterable, 'close', None)
            if close is not None:
                close()

    def __call__(self, environ, start_response):
        """Call the wrapped application, log the request, return the body.

        The size comes from the Content-Length response header when the
        application sent one; otherwise it is the total length of the body
        chunks.
        """
        status_codes = []
        content_lengths = []

        def custom_start_response(status, response_headers, exc_info=None):
            # Remember the numeric status and any declared Content-Length,
            # then hand over to the real start_response.
            status_codes.append(int(status.partition(' ')[0]))
            for name, value in response_headers:
                if name.lower() == 'content-length':
                    content_lengths.append(int(value))
                    break
            return start_response(status, response_headers, exc_info)

        retval = self.app(environ, custom_start_response)

        if not content_lengths:
            # Either no Content-Length was sent or the application is lazy
            # and has not called start_response yet.  Measuring the body
            # means iterating it, and a one-shot iterable (e.g. a generator)
            # would then reach the client empty, so buffer it and return the
            # buffered copy instead.
            retval = self._buffer(retval)

        if content_lengths:
            content_length = content_lengths[0]
        else:
            content_length = sum(len(chunk) for chunk in retval)

        msg = {
            'host': environ.get('REMOTE_ADDR', ''),
            'request': "%s %s %s" % (
                environ.get('REQUEST_METHOD', ''),
                environ.get('PATH_INFO', ''),
                environ.get('SERVER_PROTOCOL', ''),
            ),
            'size': content_length,
            # '-' only if the application never called start_response.
            'status': status_codes[0] if status_codes else '-',
            'referer': environ.get('HTTP_REFERER', ''),
            'agent': environ.get('HTTP_USER_AGENT', ''),
        }
        htcondor.log(htcondor.LogLevel.Always, self.FORMAT.format(**msg))
        return retval


def _bindv6only_enabled(path='/proc/sys/net/ipv6/bindv6only'):
    """Return True when the file at *path* contains "1" (ignoring whitespace)."""
    try:
        with open(path) as handle:
            # The content normally ends with a newline, so strip before
            # comparing.
            return handle.read().strip() == "1"
    except (IOError, OSError):
        return False


def _subscribe_server(host, port):
    """Create a CherryPy server on host:port using MyWSGIGateway and subscribe it."""
    server = cherrypy._cpserver.Server()
    server.socket_host = host
    server.socket_port = port
    server.thread_pool = 5
    wsgi_server = cherrypy._cpwsgi_server.CPWSGIServer(server)
    wsgi_server.gateway = MyWSGIGateway
    server.instance = wsgi_server
    server.subscribe()


def main():
    """Configure and run the web application until the CherryPy engine exits.

    Steps: strip the DaemonCore long options from sys.argv, parse the
    remaining options, pick the spool directory and port, publish the
    settings through MyWSGIGateway.env_override, start the IPv6 and/or IPv4
    listeners, drop privileges when running as root, and start the
    background heartbeat and multi-CE checks.
    """
    # optparse cannot parse single-dash long options, so pull '-local-name'
    # and '-pidfile' (with their values) out of sys.argv by hand.
    dc_long_opts = {}
    for long_opt in ['-local-name', '-pidfile']:
        try:
            long_opt_index = sys.argv.index(long_opt)
        except ValueError:
            continue
        value = sys.argv[long_opt_index + 1:long_opt_index + 2]
        del sys.argv[long_opt_index:long_opt_index + 2]
        if value:
            # An option given last with no value is dropped and ignored.
            dc_long_opts[long_opt.replace('-', '')] = value[0]

    opts = parse_opts(**dc_long_opts)

    # Spool directory: command line, then configuration, then ./tmp.
    spooldir = htcondor.param.get("HTCONDORCE_VIEW_SPOOL")
    if opts.spool:
        spooldir = opts.spool
    if not spooldir:
        if not os.path.exists("tmp"):
            os.mkdir("tmp")
        spooldir = "tmp"
    global g_spooldir
    g_spooldir = spooldir
    check_multice()

    if opts.port:
        port = int(opts.port)
    else:
        port = int(htcondor.param.get('HTCONDORCE_VIEW_PORT', 8080))

    MyWSGIGateway.env_override['htcondorce.spool'] = spooldir
    if os.path.exists('templates'):
        MyWSGIGateway.env_override['htcondorce.templates'] = 'templates'
    if opts.pool:
        MyWSGIGateway.env_override['htcondorce.pool'] = opts.pool
    if opts.name:
        MyWSGIGateway.env_override['htcondorce.name'] = opts.name

    app = WSGILogging(htcondorce.web.application)

    cherrypy.tree.graft(app, "/")
    cherrypy.server.unsubscribe()

    # Do the IPv6 / v4 dance:
    #   - If bindv6only is enabled, then we must bind to v4 and v6 separately.
    #   - Otherwise, only bind via v6
    addrs = socket.getaddrinfo(socket.getfqdn(), 9618, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, 0)
    families = [addr[0] for addr in addrs]
    have_ipv6 = socket.AF_INET6 in families
    have_ipv4 = socket.AF_INET in families

    bind_ipv6_only = False
    if have_ipv6:
        bind_ipv6_only = _bindv6only_enabled()
        _subscribe_server("::", port)

    if have_ipv4 and (bind_ipv6_only or not have_ipv6):
        _subscribe_server("0.0.0.0", port)

    if os.geteuid() == 0:
        uid, gid = condor_ids()
        cherrypy.process.plugins.DropPrivileges(cherrypy.engine, uid=uid, gid=gid).subscribe()

    cherrypy.engine.start()

    if not os.isatty(1) and hasattr(htcondor, 'send_alive'):
        # Send the keepalive roughly three times per timeout period.
        heartbeat_interval = max(ALIVE_HEARTBEAT // 3 - 1, 1)
        heartbeat_task = cherrypy.process.plugins.BackgroundTask(heartbeat_interval, send_heartbeat)
        heartbeat_task.start()

    # Periodically update to make sure we're still a multi-ce (or not)
    multice_task = cherrypy.process.plugins.BackgroundTask(5, check_multice)
    multice_task.start()

    cherrypy.engine.block()


# Start the web server only when this file is executed as a script.
if __name__ == '__main__':
    main()

