#
# Copyright 2006 by Object Craft P/L, Melbourne, Australia.
#
# LICENCE - see LICENCE file distributed with this software for details.
#

# XXX TODO
# Handle SIGPIPE (client abort) and SIGUSR1 (graceful)?
# Respond to FCGI_GET_VALUES  (can't test)
# should we be calling poll on every stdout or stderr write?
# Handle FCGI_ABORT_REQUEST

# SPEC - http://www.fastcgi.com/devkit/doc/fcgi-spec.html
# Apache mod_fcgi - http://www.fastcgi.com/mod_fastcgi/docs/mod_fastcgi.html
# App - http://localhost/cgi-bin/fcgitest/test.py

import sys
import os
import cgi
import errno
import socket
import select
import struct
from cStringIO import StringIO

from albatross import cgiapp

class FCGIError(Exception):
    """Base class for the errors raised by this module."""


class NotFCGI(FCGIError):
    """Raised when the process was not started as a FastCGI application."""

# I think the FastCGI spec made a mistake attempting to multiplex
# stderr over the socket: if you have a problem while starting up,
# error output goes into a void, making FastCGI very hard to debug. Apache
# allows us to continue to write to stderr - if your web server
# doesn't, then change this variable to True
multiplex_stderr = False

# FastCGI protocol constants - see section 8 of the specification:
# http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8

# File descriptor the listening socket is inherited on
FCGI_LISTENSOCK_FILENO = 0
# Size in bytes of the header that starts every record
FCGI_HEADER_LEN = 8
FCGI_VERSION = 1
# Largest content length that fits the header's 16 bit length field
FCGI_MAX_DATA = 0xffff

# Record (message) types
FCGI_BEGIN_REQUEST = 1
FCGI_ABORT_REQUEST = 2
FCGI_END_REQUEST = 3
FCGI_PARAMS = 4
FCGI_STDIN = 5
FCGI_STDOUT = 6
FCGI_STDERR = 7
FCGI_DATA = 8
FCGI_GET_VALUES = 9
FCGI_GET_VALUES_RESULT = 10
FCGI_UNKNOWN_TYPE = 11
FCGI_MAXTYPE = FCGI_UNKNOWN_TYPE

# Flag bits carried by a BEGIN_REQUEST record
FCGI_KEEP_CONN = 1

# Roles carried by a BEGIN_REQUEST record
FCGI_RESPONDER = 1
FCGI_AUTHORIZER = 2
FCGI_FILTER = 3

# Protocol status values carried by an END_REQUEST record
FCGI_REQUEST_COMPLETE = 0
FCGI_CANT_MPX_CONN = 1
FCGI_OVERLOADED = 2
FCGI_UNKNOWN_ROLE = 3

# This is not used, although it's retained to aid future debugging - some
# fastcgi environments do not have a usable fd 2 (stderr).
def redir_stderr(path='/tmp/debug'):
    """
    Debugging aid: point the process's stderr descriptor at a log file.

    path is opened for appending (and created if missing), then duplicated
    over the descriptor underlying sys.stderr. It defaults to the
    historical '/tmp/debug' location.
    """
    # Octal 0666 (rw-rw-rw-, subject to the umask), spelled via int() so
    # the literal parses on every Python version.
    mode = int('666', 8)
    fd = os.open(path, os.O_WRONLY|os.O_APPEND|os.O_CREAT, mode)
    stderr_fd = sys.stderr.fileno()
    # If the stderr descriptor was closed, os.open may already have handed
    # us that very descriptor - in which case there is nothing to dup and
    # it must not be closed.
    if fd != stderr_fd:
        os.dup2(fd, stderr_fd)
        os.close(fd)

# 
# The listening socket object once fcgi_init() has succeeded; -1 until then.
listen_sock = -1
# NOTE(review): this name is rebound by "def is_fcgi()" further down the
# module, so by the time fcgi_init() runs it refers to that function rather
# than to this boolean.
is_fcgi = True                          # Until we know otherwise
# If running via TCP, contains a list of acceptable source IP addresses.
web_server_addrs = None

def _is_listening_socket(sock):
    """
    Return True if sock is a socket that is not connected to a peer.

    The FCGI socket is bound and listening, but not connected - if we're
    accidentally started as CGI, then we get a connected socket or a pipe
    instead.
    """
    try:
        sock.getpeername()
    except socket.error:
        # Read the errno from the exception's args rather than unpacking
        # the exception in the except clause: unpacking fails outright for
        # errors that do not carry exactly two args.
        args = sys.exc_info()[1].args
        return bool(args) and args[0] == errno.ENOTCONN
    # getpeername() succeeded, so the socket is connected
    return False

# Set once fcgi_init() has established that we were not started as a FastCGI
# application, so that later calls fail fast instead of probing again.  This
# is kept under its own name because the module-level name "is_fcgi" is bound
# to the is_fcgi() function.
_fcgi_probe_failed = False

def fcgi_init():
    """
    Locate the FastCGI listening socket and prepare the standard descriptors.

    On success the global listen_sock holds the listening socket, and
    descriptors 0 and 1 refer to /dev/null. If FCGI_WEB_SERVER_ADDRS is set
    in the environment, the global web_server_addrs becomes the list of
    addresses it names. Calling this again after a success does nothing.

    Raises NotFCGI if descriptor FCGI_LISTENSOCK_FILENO is not an
    unconnected socket; that verdict is remembered for later calls.
    """
    global listen_sock, web_server_addrs, _fcgi_probe_failed
    if _fcgi_probe_failed:
        raise NotFCGI
    if listen_sock != -1:
        # Already initialised.  Probing again would examine fd 0 after it
        # has been pointed at /dev/null and throw away the working socket.
        return
    try:
        addrs = os.environ['FCGI_WEB_SERVER_ADDRS']
    except KeyError:
        pass
    else:
        # Tolerate whitespace around the comma separated addresses
        web_server_addrs = [addr.strip() for addr in addrs.split(',')]
    # NB: /dev/null is opened before the socket is dup'ed so that, when fd 1
    # and 2 start out closed, the dup'ed socket does not land on fd 1 and
    # get overwritten by the dup2 calls below.
    null = os.open('/dev/null', os.O_RDWR)
    keep_null = False
    try:
        # socket.fromfd dups the socket (sigh)
        try:
            sock = socket.fromfd(FCGI_LISTENSOCK_FILENO,
                                 socket.AF_UNIX, socket.SOCK_STREAM)
        except socket.error:
            _fcgi_probe_failed = True
            raise NotFCGI
        if not _is_listening_socket(sock):
            sock.close()
            _fcgi_probe_failed = True
            raise NotFCGI
        listen_sock = sock
        # the fcgi spec says we're to close stdin, stdout and stderr, but 3rd
        # party code often expects fd 0, 1 and 2 to be open, so we point 0 and
        # 1 (stdin and stdout) at /dev/null.  Apache's mod_fcgi deviates from
        # the spec by forwarding writes to fd 2 (stderr) to the web server
        # error log - this is handy, so we leave fd 2 alone.
        os.dup2(null, 0)
        os.dup2(null, 1)
        # If fd 1 was closed when we started, /dev/null was opened *as* fd 1;
        # closing it below would leave stdout closed again.
        keep_null = (null == 1)
    finally:
        if not keep_null:
            os.close(null)


def is_fcgi():
    """
    Report whether this process is running as a FastCGI application.

    If the listening socket has not been set up yet, fcgi_init() is called
    to do so; a NotFCGI error from it yields False.
    """
    needs_init = listen_sock < 0
    if not needs_init:
        return True
    try:
        fcgi_init()
    except NotFCGI:
        return False
    else:
        return True


class FCGIFileOut:
    """
    Minimal write-only file-like object that forwards everything written to
    it to a protocol object, as records of one fixed type.
    """

    def __init__(self, protocol, f_type):
        self.__protocol = protocol
        self.__type = f_type

    def write(self, data):
        """
        Queue data on the protocol, at most FCGI_MAX_DATA per record,
        polling the protocol after each record is queued.
        """
        remaining = data
        while remaining:
            chunk = remaining[:FCGI_MAX_DATA]
            remaining = remaining[FCGI_MAX_DATA:]
            self.__protocol.send(self.__type, chunk)
            self.__protocol.poll()

    def flush(self):
        """Delegate to the protocol's flush()."""
        self.__protocol.flush()

class FCGIProtocol:
    """
    The application end of one FastCGI connection carrying one request.

    Creating an instance accepts a connection on the listening socket,
    installs FCGIFileOut objects as sys.stdout (and as sys.stderr when
    multiplex_stderr is set), and then polls the connection until the web
    server has sent the empty FCGI_STDIN record that ends the request
    input. After that, self.params holds the name/value pairs received in
    FCGI_PARAMS records and self.stdin the list of FCGI_STDIN fragments.

    Multiplexing several requests over one connection is not supported.
    """
    # Record header: version, type, requestId, contentLength,
    # paddingLength, reserved
    hdrfmt = '>BBHHBB'
    assert struct.calcsize(hdrfmt) == FCGI_HEADER_LEN
    # Socket errors that mean the web server end of the connection is gone
    _disconnect_errnos = (errno.ECONNRESET, errno.EPIPE)

    def __init__(self):
        """
        Accept a connection and read one complete request from it.

        Raises FCGIError if FCGI_WEB_SERVER_ADDRS was given and the peer is
        not listed, or if the connection closes before the request has been
        received in full. NotFCGI propagates from fcgi_init().
        """
        if listen_sock < 0:
            fcgi_init()
        self.sock, addr = listen_sock.accept()
        if web_server_addrs:
            # TCP peers are reported as (host, port, ...) tuples, while the
            # configured list holds bare address strings - compare the host.
            if isinstance(addr, tuple):
                peer = addr[0]
            else:
                peer = addr
            if peer not in [a.strip() for a in web_server_addrs]:
                # Don't leak the accepted connection
                self.sock.close()
                self.sock = None
                raise FCGIError('Received request from bad IP address: %s '
                                '(expected %s)' % 
                                (addr, ', or '.join(web_server_addrs)))
        self.sock.setblocking(0)
        self.stdin = []
        self.params = {}
        self.recv_buf = ''
        self.send_buf = ''
        self.in_fds = [self.sock]
        self.out_fds = []
        self.current_requestId = 0
        self.server_request_complete = False
        sys.stdout = FCGIFileOut(self, FCGI_STDOUT)
        if multiplex_stderr:
            sys.stderr = FCGIFileOut(self, FCGI_STDERR)
        while not self.server_request_complete:
            if not self.sock:
                # poll() is a no-op once the connection has been closed, so
                # without this check we would spin here forever.
                raise FCGIError('connection closed before a complete '
                                'request was received')
            self.poll()

    def getFieldStorage(self):
        """
        Return a cgi.FieldStorage built from the received request body and
        parameters.
        """
        return cgi.FieldStorage(fp=StringIO(''.join(self.stdin)),
                                environ=self.params, keep_blank_values=1)

    def flush(self):
        """Poll until everything queued for sending has been written."""
        while self.send_buf:
            self.poll()
        
    def close(self):
        """Discard all buffered data and close the connection, if open."""
        self.recv_buf = ''
        self.send_buf = ''
        self.in_fds = []
        self.out_fds = []
        if self.sock:
            self.sock.close()
            self.sock = None

    def end(self):
        """Send the END_REQUEST record, flush it out and close."""
        self.send_end_request(FCGI_REQUEST_COMPLETE, 0)
        self.flush()
        self.close()

    def send(self, f_type, content, f_requestId=None):
        """
        Queue one record of type f_type carrying content.

        content must be no longer than FCGI_MAX_DATA. f_requestId defaults
        to the id of the current request. Nothing is queued once the
        connection has been closed.
        """
        if self.sock:
            if f_requestId is None:
                f_requestId = self.current_requestId
            f_length = len(content)
            assert f_length <= FCGI_MAX_DATA
            hdr = struct.pack(self.hdrfmt, FCGI_VERSION, f_type, f_requestId,
                              f_length, 0, 0)
            self.send_buf = ''.join((self.send_buf, hdr, content))
            if not self.out_fds:
                # Ask poll() to start watching for writability
                self.out_fds = [self.sock]

    def send_unknown_type(self, f_type):
        """Tell the web server we do not handle records of type f_type."""
        self.send(FCGI_UNKNOWN_TYPE, struct.pack('>B7x', f_type), 0)

    def send_end_request(self, f_status, app_status, f_requestId=None):
        """
        Queue an END_REQUEST record.

        f_status is the protocol status (one of the END_REQUEST status
        constants), app_status the application's exit status.
        """
        # The body is a 4 byte appStatus followed by a 1 byte
        # protocolStatus and 3 reserved bytes - in that order.
        data = struct.pack('>LB3x', app_status, f_status)
        self.send(FCGI_END_REQUEST, data, f_requestId)

    def process_incoming(self):
        """
        Digest any complete records that have been received, dispatching
        each to its handler in type_dispatch.

        Raises FCGIError for a record with an unexpected protocol version,
        or one that belongs to a request other than the current one.
        """
        while len(self.recv_buf) >= FCGI_HEADER_LEN:
            f_version, f_type, f_requestId, f_contentLength, \
                f_paddingLength, f_reserved = \
                    struct.unpack(self.hdrfmt, self.recv_buf[:FCGI_HEADER_LEN])
            if f_version != FCGI_VERSION:
                raise FCGIError('Received message version %d, expect %d' %
                                (f_version, FCGI_VERSION))
            content_end = FCGI_HEADER_LEN + f_contentLength
            padding_end = content_end + f_paddingLength
            if len(self.recv_buf) < padding_end:
                # The rest of this record has not arrived yet
                break
            content = self.recv_buf[FCGI_HEADER_LEN:content_end]
            self.recv_buf = self.recv_buf[padding_end:]
            try:
                type_handler = self.type_dispatch[f_type]
            except IndexError:
                type_handler = None
            if type_handler is None:
                self.send_unknown_type(f_type)
            else:
                if f_type == FCGI_BEGIN_REQUEST:
                    if self.current_requestId:
                        # Already busy with a request on this connection
                        self.send_end_request(FCGI_CANT_MPX_CONN, 0,
                                              f_requestId=f_requestId)
                        continue
                    else:
                        self.current_requestId = f_requestId
                elif f_requestId and f_requestId != self.current_requestId:
                    raise FCGIError('unsupported concurrent request received')
                type_handler(self, content)

    def _peer_gone(self, exc):
        """Return True if socket error exc means the peer disconnected."""
        args = getattr(exc, 'args', ())
        return bool(args) and args[0] in self._disconnect_errnos

    # Poll the fcgi socket, process any incoming messages, and send any
    # outgoing data.
    def poll(self):
        """
        Wait for the connection to become readable (or writable, when data
        is queued), then read what is available, process complete records
        and write as much queued data as the socket accepts.

        The connection is closed if the peer has gone away. Does nothing
        once the connection is closed.
        """
        if not self.sock:
            return
        readable, writable, _ = select.select(self.in_fds, self.out_fds, [])
        # Read any new data
        if readable:
            try:
                buf = self.sock.recv(16384)
            except socket.error:
                if not self._peer_gone(sys.exc_info()[1]):
                    raise
                self.close()
            else:
                if not buf:
                    # Orderly shutdown by the peer
                    self.close()
                else:
                    self.recv_buf += buf
        self.process_incoming()
        # Send as much data as we can
        if writable:
            if self.sock and self.send_buf:
                try:
                    sent = self.sock.send(self.send_buf)
                except socket.error:
                    if not self._peer_gone(sys.exc_info()[1]):
                        raise
                    self.close()
                else:
                    self.send_buf = self.send_buf[sent:]
            if not self.send_buf:
                # Nothing left to write - stop watching for writability
                self.out_fds = []

    def _nvlen(self, offs, buf):
        """
        Decode the length field at buf[offs]; return (length, next offset).
        """
        length = ord(buf[offs])
        if length & 0x80:
            length = struct.unpack('>L', buf[offs:offs+4])[0] & 0x7fffffff
            return length, offs + 4
        return length, offs + 1

    def nvpair(self, offs, buf):
        """
        Decode the name-value pair starting at buf[offs].

        Returns a tuple of the offset just past the pair, the name and the
        value.
        """
        # FCGI encodes name-value pairs as two lengths, then two strings
        # of the specified lengths. The lengths are encoded either as a single
        # byte less than 128, or, if the high bit is set, as a 4 byte big
        # endian value with that high bit masked off.
        name_len, offs = self._nvlen(offs, buf)
        value_len, offs = self._nvlen(offs, buf)
        name_end = offs + name_len
        value_end = name_end + value_len
        return value_end, buf[offs:name_end], buf[name_end:value_end]

    # Handlers indexed by record type; None marks a type we do not handle.
    type_dispatch = [None] * (FCGI_MAXTYPE+1)

    def fcgi_begin_request(self, content):
        """Reject a request for any role other than responder."""
        role, flags = struct.unpack('>HB', content[:3])
        if role != FCGI_RESPONDER:
            self.send_end_request(FCGI_UNKNOWN_ROLE,0)
    type_dispatch[FCGI_BEGIN_REQUEST] = fcgi_begin_request

    def fcgi_abort_request(self, content):
        """Accepted but ignored."""
        pass
    type_dispatch[FCGI_ABORT_REQUEST] = fcgi_abort_request

    def fcgi_end_request(self, content):
        """Accepted but ignored."""
        pass
    type_dispatch[FCGI_END_REQUEST] = fcgi_end_request

    def fcgi_params(self, content):
        """Add the name-value pairs in content to self.params."""
        offs = 0
        while offs < len(content):
            offs, name, value = self.nvpair(offs, content)
            self.params[name] = value
    type_dispatch[FCGI_PARAMS] = fcgi_params

    def fcgi_stdin(self, content):
        """
        Collect a fragment of the request body; an empty record marks the
        request as completely received.
        """
        if content:
            self.stdin.append(content)
        else:
            self.server_request_complete = True
    type_dispatch[FCGI_STDIN] = fcgi_stdin

    def fcgi_get_values(self, content):
        """Not implemented - see the TODO list at the top of the module."""
        raise NotImplementedError
    type_dispatch[FCGI_GET_VALUES] = fcgi_get_values


if not is_fcgi():
    # Not FastCGI: use the stock cgiapp request, and have running() report
    # True on its first call only.
    Request = cgiapp.Request
    is_running = True

    def running():
        global is_running
        state, is_running = is_running, False
        return state
else:
    class Request(cgiapp.Request):
        """A cgiapp request whose input and output travel over FastCGI."""

        def __init__(self, fields = None):
            # Creating the protocol object receives the next request
            protocol = FCGIProtocol()
            self.__protocol = protocol
            if fields is None:
                fields = protocol.getFieldStorage()
            cgiapp.Request.__init__(self, fields)

        def get_param(self, name, default=None):
            """Look name up in the parameters sent by the web server."""
            params = self.__protocol.params
            return params.get(name, default)
        
        def return_code(self):
            """Finish the request by ending the FastCGI protocol exchange."""
            self.__protocol.end()

    def running():
        # Under FastCGI there is always another request to wait for
        return True
