#
# Copyright 2001 by Object Craft P/L, Melbourne, Australia.
#
# LICENCE - see LICENCE file distributed with this software for details.
#

import socket
import errno
import base64
import os.path

try:
    import Cookie
except ImportError:
    pass
try:
    import zlib
    have_zlib = True
except ImportError:
    have_zlib = False

from albatross.context import SessionBase
from albatross.common import *

class ServerDisconnect(ServerError):
    """Signals that the session server connection was dropped.

    Being a ServerError subclass, it is also caught by handlers written
    for ServerError.
    """

class SessionCookieMixin:
    """Mixin that carries the session id in a browser cookie.

    The cookie is named after ``self.app.ses_appid()``.  The class this is
    mixed into must supply ``self.request``, ``self.app``,
    ``absolute_base_url()``, ``parsed_request_uri()`` and ``set_header()``.
    """

    def _get_sesid_from_cookie(self):
        """Return the session id sent by the browser, or None.

        None is returned when the request has no Cookie header, when the
        header holds no cookie named after the application id, or when the
        header cannot be parsed at all.
        """
        hdr = self.request.get_header('Cookie')
        if not hdr:
            return None
        try:
            c = Cookie.SimpleCookie(hdr)
        except Cookie.CookieError:
            # The header comes from the browser and may contain cookies
            # (possibly set by other applications) whose names the Cookie
            # module rejects.  Treat an unparsable header as "no session"
            # rather than failing the whole request.
            return None
        try:
            return c[self.app.ses_appid()].value
        except KeyError:
            return None

    def _set_sesid_cookie(self, sesid):
        """Emit a Set-Cookie header carrying sesid.

        A sesid of None sends the cookie with an empty value.  The cookie
        path is the application's absolute base URL, and the cookie is
        flagged secure when the request URI scheme is 'https'.
        """
        appid = self.app.ses_appid()
        c = Cookie.SimpleCookie()
        if sesid is None:
            c[appid] = ''
        else:
            c[appid] = sesid
        morsel = c[appid]
        morsel['path'] = self.absolute_base_url()
        if self.parsed_request_uri()[0] == 'https':
            morsel['secure'] = True
        # OutputString() renders the cookie without the "Set-Cookie:"
        # header name, which set_header() supplies itself.
        self.set_header('Set-Cookie', morsel.OutputString())

# Simple session server
class SessionServerContextMixin(SessionCookieMixin, SessionBase):
    """Context mixin that stores the encoded session on a session server.

    The session id travels in a cookie (see SessionCookieMixin); the
    session data itself is fetched from and stored through ``self.app``
    (``get_session``, ``new_session``, ``put_session``, ``del_session``).
    Stored data is the output of ``encode_session()``, zlib-compressed
    when zlib is available, then base64 encoded.
    """

    def __init__(self):
        SessionBase.__init__(self)
        # Id of the session bound to this context; None until a session
        # is loaded or created, and again after remove_session().
        self.__sesid = None

    def sesid(self):
        """Return the current session id, or None if there is none."""
        return self.__sesid

    def load_session(self):
        """Load the session named by the request cookie, or start a new one.

        With no session cookie a new session id is requested from the
        application.  With a cookie whose session the application cannot
        return, the cookie is blanked and SessionExpired is raised.  If
        the stored data cannot be decoded the session is deleted and the
        decoding exception propagates.  On success the session cookie is
        (re)sent.
        """
        sesid = self._get_sesid_from_cookie()
        text = None
        if sesid:
            text = self.app.get_session(sesid)
            if not text:
                self._set_sesid_cookie(None)
                raise SessionExpired('Session expired or browser does not support cookies')
        else:
            sesid = self.app.new_session()
        self.__sesid = sesid
        if text:
            try:
                # Every decoding stage, including base64, is covered so
                # that a record corrupted at any layer is removed instead
                # of failing again on each subsequent request.
                text = base64.decodestring(text)
                if have_zlib:
                    text = zlib.decompress(text)
                self.decode_session(text)
            except:
                # Deliberately bare: this only cleans up and re-raises.
                self.app.del_session(sesid)
                raise
        self._set_sesid_cookie(self.__sesid)

    def new_session(self):
        """Obtain a fresh session id from the application and send its cookie."""
        self.__sesid = self.app.new_session()
        self._set_sesid_cookie(self.__sesid)

    def save_session(self):
        """Encode the session and store it under the current session id.

        Nothing is stored when should_save_session() is false or when
        there is no current session id.
        """
        if self.should_save_session() and self.__sesid is not None:
            text = self.encode_session()
            if have_zlib:
                text = zlib.compress(text)
            text = base64.encodestring(text)
            self.app.put_session(self.__sesid, text)

    def remove_session(self):
        """Discard the session locally and, if it has an id, on the server.

        When a session id was set it is deleted through the application,
        forgotten, and the session cookie is blanked.
        """
        SessionBase.remove_session(self)
        if self.__sesid is not None:
            self.app.del_session(self.__sesid)
            self.__sesid = None
            self._set_sesid_cookie(self.__sesid)


class SessionServerAppMixin:
    """Application mixin that talks to a session server over TCP.

    Commands are single text lines terminated by CRLF (``get``, ``new``,
    ``put``, ``del``), each carrying the application id.  A response line
    starting with ``OK`` means success.  One connection is kept open and
    re-established on demand when the server drops it.
    """

    def __init__(self, appid, server = 'localhost', port = 34343, age = 1800):
        """Record the connection settings and try an initial connect.

        appid  -- application id sent with every command; also the name
                  returned by ses_appid()
        server -- session server host name
        port   -- session server TCP port
        age    -- value sent with the ``new`` command (presumably the
                  session lifetime in seconds -- confirm against server)
        """
        self.__appid = appid
        self.__server = server
        self.__port = port
        self.__sock = None
        self.__buf = ''
        self.__age = age
        try:
            self._server_connect()
        except ServerError:
            # Best effort only: every session operation calls
            # _server_connect() again, so an unreachable server is
            # reported when a session is actually needed.
            pass

    def ses_appid(self):
        """Return the application id."""
        return self.__appid

    def ses_age(self):
        """Return the age value passed to the constructor."""
        return self.__age

    def __socket_error_info(self, exc):
        """Return (errno, message) for a socket exception.

        Not every socket error carries an (errno, message) pair (a
        timeout has a single string argument), so errno is None when
        it is unavailable.
        """
        args = getattr(exc, 'args', ())
        if len(args) >= 2:
            return args[0], args[1]
        return None, str(exc)

    def _server_connect(self):
        """Connect to the session server unless already connected.

        Interrupted attempts (EINTR) are retried; any other socket error
        raises ServerError.
        """
        if self.__sock:
            return
        while True:
            sock = None
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.connect((self.__server, self.__port))
            except socket.error as exc:
                # Do not leak the descriptor of a failed attempt.
                if sock is not None:
                    sock.close()
                eno, estr = self.__socket_error_info(exc)
                if eno != errno.EINTR:
                    raise ServerError('could not connect to session server: %s' % estr)
            else:
                self.__sock = sock
                return

    def _server_close(self):
        """Close the connection, if any, and discard buffered input."""
        sock = self.__sock
        self.__sock = None
        self.__buf = ''
        if sock is not None:
            try:
                sock.close()
            except socket.error:
                # The connection is being abandoned anyway; a failure to
                # close it cleanly is of no interest to the caller.
                pass

    def _server_read(self, eol='\r\n'):
        """Return server input up to the next eol, without the eol.

        Data following the terminator stays buffered for the next call.
        Raises ServerDisconnect when the server closes the connection
        and ServerError on any socket error other than EINTR; in both
        cases the connection is closed first.
        """
        while True:
            n = self.__buf.find(eol)
            if n >= 0:
                line = self.__buf[:n]
                self.__buf = self.__buf[n+len(eol):]
                return line
            try:
                buf = self.__sock.recv(16384)
            except socket.error as exc:
                eno, estr = self.__socket_error_info(exc)
                if eno != errno.EINTR:
                    self._server_close()
                    raise ServerError('lost session server: %s' % estr)
                continue
            if not buf:
                # An empty read means the peer closed the connection.
                self._server_close()
                raise ServerDisconnect
            self.__buf += buf

    def _server_write(self, text):
        """Send all of text to the server.

        Raises ServerDisconnect on ECONNRESET/EPIPE and ServerError on
        any other socket error except EINTR (which is retried); the
        connection is closed before either is raised.
        """
        while text:
            try:
                count = self.__sock.send(text)
            except socket.error as exc:
                eno, estr = self.__socket_error_info(exc)
                if eno == errno.EINTR:
                    continue
                self._server_close()
                if eno in (errno.ECONNRESET, errno.EPIPE):
                    raise ServerDisconnect
                raise ServerError('lost session server: %s' % estr)
            # send() may accept only part of the data.
            text = text[count:]

    def _server_read_response(self):
        """Read one response line and return what follows its 'OK'.

        Raises ServerError when the line does not start with 'OK'.
        """
        resp = self._server_read()
        if not resp.startswith('OK'):
            raise ServerError('Session server returned: %s' % resp)
        return resp[2:].strip()

    # Each operation below restarts from scratch, on a new connection,
    # whenever the server disconnects part way through.  The retry is
    # unbounded; failure to connect surfaces as ServerError.

    def get_session(self, sesid):
        """Return the stored text for sesid, or '' if the server has none."""
        while True:
            try:
                self._server_connect()
                self._server_write('get %s %s\r\n' % (self.__appid, sesid))
                resp = self._server_read()
                if not resp.startswith('OK'):
                    return ''
                # The session text is terminated by an empty line.
                return self._server_read('\r\n\r\n')
            except ServerDisconnect:
                pass

    def new_session(self):
        """Ask the server for a new session and return its response text."""
        while True:
            try:
                self._server_connect()
                self._server_write('new %s %s\r\n' % (self.__appid, self.__age))
                return self._server_read_response()
            except ServerDisconnect:
                pass

    def put_session(self, sesid, text):
        """Store text as the data of session sesid.

        Newlines in text are sent as CRLF and a final CRLF is appended.
        """
        while True:
            try:
                self._server_connect()
                self._server_write('put %s %s\r\n' % (self.__appid, sesid))
                # The server acknowledges the command before the data.
                self._server_read_response()
                self._server_write(text.replace('\n', '\r\n') + '\r\n')
                return self._server_read_response()
            except ServerDisconnect:
                pass

    def del_session(self, sesid):
        """Delete session sesid on the server."""
        while True:
            try:
                self._server_connect()
                self._server_write('del %s %s\r\n' % (self.__appid, sesid))
                return self._server_read_response()
            except ServerDisconnect:
                pass
