# Copyright 2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime, timedelta
import logging
import time
import six
import warnings

from cassandra.cqlengine import UnicodeMixin
from cassandra.cqlengine.functions import QueryValue
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator

# Module logger; ListUpdateClause.update_context logs the list-prepend warning through it.
log = logging.getLogger(__name__)


class StatementException(Exception):
    """ Raised when a CQL statement or clause is built with arguments of the wrong kind. """


class ValueQuoter(UnicodeMixin):
    """
    Wraps a python value so that rendering it as text yields a CQL literal.

    Booleans become ``true``/``false``; lists and tuples become list literals
    ``[a, b]``; dicts become map literals ``{k:v, ...}``; sets and frozensets
    become set literals ``{a, b}``. Any other value, and every element of a
    collection, is rendered with ``cassandra.encoder.cql_quote``.
    """

    def __init__(self, value):
        self.value = value

    def __unicode__(self):
        from cassandra.encoder import cql_quote
        if isinstance(self.value, bool):
            return 'true' if self.value else 'false'
        elif isinstance(self.value, (list, tuple)):
            return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']'
        elif isinstance(self.value, dict):
            return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}'
        elif isinstance(self.value, (set, frozenset)):
            # frozenset is not a subclass of set, so it has to be listed explicitly
            # or it would fall through to cql_quote and be rendered via str()
            return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}'
        return cql_quote(self.value)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.value == other.value
        return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent
        # (BaseClause does the same)
        return not self.__eq__(other)


class InQuoter(ValueQuoter):
    """ Renders an iterable of values as a parenthesised CQL ``IN`` list, e.g. ``(1, 2, 3)``. """

    def __unicode__(self):
        from cassandra.encoder import cql_quote
        quoted_items = [cql_quote(item) for item in self.value]
        return '(' + ', '.join(quoted_items) + ')'


class BaseClause(UnicodeMixin):
    """
    Base class for a single piece (clause) of a CQL statement.

    A clause holds a column name (``field``), the python value bound to it and
    ``context_id``: the number of the first ``%(n)s`` placeholder it uses.
    Statements hand out context ids sequentially using ``get_context_size`` and
    later collect bound values through ``update_context``.
    """

    def __init__(self, field, value):
        self.field = field
        self.value = value
        self.context_id = None

    def __unicode__(self):
        raise NotImplementedError

    def __hash__(self):
        field_hash = hash(self.field)
        return field_hash ^ hash(self.value)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.field == other.field and self.value == other.value

    def __ne__(self, other):
        return not self.__eq__(other)

    def get_context_size(self):
        """ number of placeholders (context entries) this clause contributes """
        return 1

    def set_context_id(self, i):
        """ records the number of the first placeholder this clause renders """
        self.context_id = i

    def update_context(self, ctx):
        """ stores this clause's bound value in ``ctx`` under its placeholder key """
        assert isinstance(ctx, dict)
        placeholder = str(self.context_id)
        ctx[placeholder] = self.value


class WhereClause(BaseClause):
    """ a single where statement used in queries """

    def __init__(self, field, operator, value, quote_field=True):
        """
        :param field: the left-hand side of the comparison (normally a column name)
        :param operator: the comparison operator
        :type operator: BaseWhereOperator
        :param value: the right-hand value; wrapped in ``QueryValue`` unless it already is one
        :param quote_field: hack to get the token function rendering properly; when False
            the field is rendered without surrounding double quotes
        :raises StatementException: if ``operator`` is not a ``BaseWhereOperator``
        """
        if not isinstance(operator, BaseWhereOperator):
            raise StatementException(
                "operator must be of type {}, got {}".format(BaseWhereOperator, type(operator))
            )
        super(WhereClause, self).__init__(field, value)
        self.operator = operator
        self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value)
        self.quote_field = quote_field

    def __unicode__(self):
        field = ('"{}"' if self.quote_field else '{}').format(self.field)
        return u'{} {} {}'.format(field, self.operator, six.text_type(self.query_value))

    def __hash__(self):
        # __eq__ compares operators by class, so hash the class as well; hashing the
        # operator instance could give two equal clauses different hashes
        return super(WhereClause, self).__hash__() ^ hash(self.operator.__class__)

    def __eq__(self, other):
        if super(WhereClause, self).__eq__(other):
            return self.operator.__class__ == other.operator.__class__
        return False

    def get_context_size(self):
        return self.query_value.get_context_size()

    def set_context_id(self, i):
        super(WhereClause, self).set_context_id(i)
        self.query_value.set_context_id(i)

    def update_context(self, ctx):
        if isinstance(self.operator, InOperator):
            # IN needs the whole collection rendered as a parenthesised tuple
            ctx[str(self.context_id)] = InQuoter(self.value)
        else:
            self.query_value.update_context(ctx)


class AssignmentClause(BaseClause):
    """ A single ``"column" = %(n)s`` assignment, used by SET and INSERT statements. """

    def __unicode__(self):
        template = u'"{}" = %({})s'
        return template.format(self.field, self.context_id)

    def insert_tuple(self):
        """ returns ``(column name, context id)`` for building an INSERT column/value list """
        return (self.field, self.context_id)


class TransactionClause(BaseClause):
    """ A single ``"column" = %(n)s`` condition, rendered inside an UPDATE's ``IF`` clause. """

    def __unicode__(self):
        template = u'"{}" = %({})s'
        return template.format(self.field, self.context_id)

    def insert_tuple(self):
        """ returns ``(column name, context id)`` """
        return (self.field, self.context_id)


class ContainerUpdateClause(AssignmentClause):
    """
    Abstract base for assignments that may expand into several fragments
    (collection and counter columns).

    :param operation: optional explicit operation name interpreted by subclasses
    :param previous: the column's previously known value, or None
    :param column: optional column object whose ``to_database`` converts bound values
    """

    def __init__(self, field, value, operation=None, previous=None, column=None):
        super(ContainerUpdateClause, self).__init__(field, value)
        self.previous = previous
        self._assignments = None
        self._operation = operation
        self._analyzed = False
        self._column = column

    def _to_database(self, val):
        # without a column object values are bound unchanged
        if self._column:
            return self._column.to_database(val)
        return val

    def _analyze(self):
        raise NotImplementedError

    def get_context_size(self):
        raise NotImplementedError

    def update_context(self, ctx):
        raise NotImplementedError


class SetUpdateClause(ContainerUpdateClause):
    """
    Updates a set collection column.

    Up to three comma-separated fragments are rendered, each with its own
    placeholder, in this order:

    * ``"f" = %(n)s``        -- full assignment (no previous value known)
    * ``"f" = "f" + %(n)s``  -- elements to add
    * ``"f" = "f" - %(n)s``  -- elements to remove

    With no previous value and nothing else to write, the column is assigned
    an empty collection using a single placeholder.
    """

    def __init__(self, field, value, operation=None, previous=None, column=None):
        super(SetUpdateClause, self).__init__(field, value, operation, previous, column=column)
        self._additions = None
        self._removals = None

    def _is_empty_reset(self):
        """ True when the column should just be overwritten with an empty collection """
        return (self.previous is None and
                self._assignments is None and
                self._additions is None and
                self._removals is None)

    def __unicode__(self):
        # analyze first, as ListUpdateClause/MapUpdateClause do, so rendering does not
        # depend on get_context_size having been called beforehand
        if not self._analyzed:
            self._analyze()
        qs = []
        ctx_id = self.context_id
        if self._is_empty_reset():
            qs += ['"{}" = %({})s'.format(self.field, ctx_id)]
        if self._assignments is not None:
            qs += ['"{}" = %({})s'.format(self.field, ctx_id)]
            ctx_id += 1
        if self._additions is not None:
            qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]
            ctx_id += 1
        if self._removals is not None:
            qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)]

        return ', '.join(qs)

    def _analyze(self):
        """ works out the updates to be performed """
        if self.value is None or self.value == self.previous:
            pass
        elif self._operation == "add":
            self._additions = self.value
        elif self._operation == "remove":
            self._removals = self.value
        elif self.previous is None:
            self._assignments = self.value
        else:
            # partial update time: only send the set differences
            self._additions = (self.value - self.previous) or None
            self._removals = (self.previous - self.value) or None
        self._analyzed = True

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        if self._is_empty_reset():
            return 1
        # Count exactly the fragments __unicode__/update_context emit (they test
        # `is not None`). Counting by truthiness under-reported an explicit
        # add/remove of an empty set and made placeholders overlap the next clause.
        return sum(part is not None for part in (self._assignments, self._additions, self._removals))

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        ctx_id = self.context_id
        if self._is_empty_reset():
            ctx[str(ctx_id)] = self._to_database({})
        if self._assignments is not None:
            ctx[str(ctx_id)] = self._to_database(self._assignments)
            ctx_id += 1
        if self._additions is not None:
            ctx[str(ctx_id)] = self._to_database(self._additions)
            ctx_id += 1
        if self._removals is not None:
            ctx[str(ctx_id)] = self._to_database(self._removals)


class ListUpdateClause(ContainerUpdateClause):
    """
    Updates a list collection column.

    Up to three comma-separated fragments are rendered, each with its own
    placeholder, in this order:

    * ``"f" = %(n)s``        -- full rewrite of the list
    * ``"f" = %(n)s + "f"``  -- prepend
    * ``"f" = "f" + %(n)s``  -- append
    """

    def __init__(self, field, value, operation=None, previous=None, column=None):
        super(ListUpdateClause, self).__init__(field, value, operation, previous, column=column)
        self._append = None
        self._prepend = None

    def __unicode__(self):
        if not self._analyzed:
            self._analyze()
        qs = []
        ctx_id = self.context_id
        if self._assignments is not None:
            qs += ['"{}" = %({})s'.format(self.field, ctx_id)]
            ctx_id += 1

        if self._prepend is not None:
            qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)]
            ctx_id += 1

        if self._append is not None:
            qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]

        return ', '.join(qs)

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        # One placeholder per fragment rendered above. Compare against None rather
        # than truthiness so an explicit append/prepend of an empty list is counted;
        # otherwise the slot it consumes overlaps the next clause's placeholder.
        return sum(part is not None for part in (self._assignments, self._prepend, self._append))

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        ctx_id = self.context_id
        if self._assignments is not None:
            ctx[str(ctx_id)] = self._to_database(self._assignments)
            ctx_id += 1
        if self._prepend is not None:
            msg = "Previous versions of cqlengine implicitly reversed prepended lists to account for CASSANDRA-8733. " \
                  "THIS VERSION DOES NOT. This warning will be removed in a future release."
            warnings.warn(msg)
            log.warning(msg)

            ctx[str(ctx_id)] = self._to_database(self._prepend)
            ctx_id += 1
        if self._append is not None:
            ctx[str(ctx_id)] = self._to_database(self._append)

    def _analyze(self):
        """ works out the updates to be performed """
        if self.value is None or self.value == self.previous:
            pass

        elif self._operation == "append":
            self._append = self.value

        elif self._operation == "prepend":
            self._prepend = self.value

        elif self.previous is None:
            self._assignments = self.value

        elif len(self.value) < len(self.previous):
            # if elements have been removed,
            # rewrite the whole list
            self._assignments = self.value

        elif len(self.previous) == 0:
            # if we're updating from an empty
            # list, do a complete insert
            self._assignments = self.value
        else:
            # Look for `previous` as a contiguous run inside `value`; whatever
            # precedes it is prepended and whatever follows it is appended.

            # the max start idx we want to compare
            search_space = len(self.value) - max(0, len(self.previous) - 1)

            # the size of the sub lists we want to look at
            search_size = len(self.previous)

            for i in range(search_space):
                # slice boundary
                j = i + search_size
                sub = self.value[i:j]
                # cheap first/last element checks before the full comparison
                idx_cmp = lambda idx: self.previous[idx] == sub[idx]
                if idx_cmp(0) and idx_cmp(-1) and self.previous == sub:
                    self._prepend = self.value[:i] or None
                    self._append = self.value[j:] or None
                    break

            # if both append and prepend are still None after looking
            # at both lists, an insert statement will be created
            if self._prepend is self._append is None:
                self._assignments = self.value

        self._analyzed = True


class MapUpdateClause(ContainerUpdateClause):
    """
    Updates a map collection column.

    Every key to write is rendered as ``"field"[%(k)s] = %(v)s`` and takes two
    placeholders (key, then value). With no previous value and nothing to
    update, the whole column is assigned an empty map using one placeholder.
    """

    def __init__(self, field, value, operation=None, previous=None, column=None):
        super(MapUpdateClause, self).__init__(field, value, operation, previous, column=column)
        self._updates = None

    def _analyze(self):
        """ decides which keys have to be written """
        if self._operation == "update":
            # explicit update: every supplied key is written, in the mapping's own order
            changed = self.value.keys()
        elif self.previous is None:
            changed = sorted(key for key, _ in self.value.items())
        else:
            before = self.previous
            changed = sorted(key for key, val in self.value.items() if val != before.get(key)) or None
        self._updates = changed
        self._analyzed = True

    def _whole_map_reset(self):
        """ True when the column is assigned an empty map instead of per-key writes """
        return self.previous is None and not self._updates

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        if self._whole_map_reset():
            return 1
        return 2 * len(self._updates or ())

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        slot = self.context_id
        if self._whole_map_reset():
            ctx[str(slot)] = {}
            return
        col = self._column
        for key in self._updates or ():
            val = self.value.get(key)
            ctx[str(slot)] = col.key_col.to_database(key) if col else key
            ctx[str(slot + 1)] = col.value_col.to_database(val) if col else val
            slot += 2

    def __unicode__(self):
        if not self._analyzed:
            self._analyze()
        if self._whole_map_reset():
            return '"{}" = %({})s'.format(self.field, self.context_id)
        first = self.context_id
        fragments = [
            '"{}"[%({})s] = %({})s'.format(self.field, first + 2 * pos, first + 2 * pos + 1)
            for pos, _ in enumerate(self._updates or ())
        ]
        return ', '.join(fragments)


class CounterUpdateClause(ContainerUpdateClause):
    """
    Increments or decrements a counter column by ``value - previous``.

    A falsy ``previous`` is treated as 0. The magnitude of the change is bound
    to the single placeholder; the sign is rendered into the CQL text.
    """

    def __init__(self, field, value, previous=None, column=None):
        super(CounterUpdateClause, self).__init__(field, value, previous=previous, column=column)
        if not self.previous:
            self.previous = 0

    def get_context_size(self):
        return 1

    def update_context(self, ctx):
        magnitude = abs(self.value - self.previous)
        ctx[str(self.context_id)] = self._to_database(magnitude)

    def __unicode__(self):
        delta = self.value - self.previous
        if delta < 0:
            op = '-'
        else:
            op = '+'
        return '"{0}" = "{0}" {1} %({2})s'.format(self.field, op, self.context_id)


class BaseDeleteClause(BaseClause):
    """ Marker base class for clauses listed between DELETE and FROM. """


class FieldDeleteClause(BaseDeleteClause):
    """ Deletes a whole column from a row; renders only the quoted column name and binds nothing. """

    def __init__(self, field):
        super(FieldDeleteClause, self).__init__(field, None)

    def __unicode__(self):
        return '"{}"'.format(self.field)

    def update_context(self, ctx):
        """ no placeholders are rendered, so there is nothing to bind """
        return None

    def get_context_size(self):
        return 0


class MapDeleteClause(BaseDeleteClause):
    """
    Removes map entries whose keys are present in ``previous`` but missing from
    ``value``. Each removed key is rendered as ``"field"[%(n)s]`` with its own
    placeholder; keys are processed in sorted order.
    """

    def __init__(self, field, value, previous=None):
        super(MapDeleteClause, self).__init__(field, value)
        self.value = self.value or {}
        self.previous = previous or {}
        self._analyzed = False
        self._removals = None

    def _analyze(self):
        current = self.value
        self._removals = sorted(key for key in self.previous if key not in current)
        self._analyzed = True

    def _ensure_analyzed(self):
        if not self._analyzed:
            self._analyze()

    def update_context(self, ctx):
        self._ensure_analyzed()
        base = self.context_id
        for offset, key in enumerate(self._removals):
            ctx[str(base + offset)] = key

    def get_context_size(self):
        self._ensure_analyzed()
        return len(self._removals)

    def __unicode__(self):
        self._ensure_analyzed()
        base = self.context_id
        targets = ['"{}"[%({})s]'.format(self.field, base + offset) for offset in range(len(self._removals))]
        return ', '.join(targets)


class BaseCQLStatement(UnicodeMixin):
    """
    The base cql statement class.

    Holds the target table, consistency, an optional write timestamp and a list
    of ``WhereClause`` objects. Placeholders are numbered sequentially:
    ``context_id`` is the first number this statement uses and
    ``context_counter`` is the next unused number.
    """

    def __init__(self, table, consistency=None, timestamp=None, where=None):
        super(BaseCQLStatement, self).__init__()
        self.table = table
        self.consistency = consistency
        self.context_id = 0
        self.context_counter = self.context_id
        self.timestamp = timestamp

        self.where_clauses = []
        for clause in where or []:
            self.add_where_clause(clause)

    def add_where_clause(self, clause):
        """
        adds a where clause to this statement
        :param clause: the clause to add
        :type clause: WhereClause
        :raises StatementException: if ``clause`` is not a WhereClause
        """
        if not isinstance(clause, WhereClause):
            raise StatementException("only instances of WhereClause can be added to statements")
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.where_clauses.append(clause)

    def get_context(self):
        """
        returns the context dict for this statement
        :rtype: dict
        """
        ctx = {}
        for clause in self.where_clauses or []:
            clause.update_context(ctx)
        return ctx

    def get_context_size(self):
        return len(self.get_context())

    def update_context_id(self, i):
        """ renumbers every clause's placeholders starting at ``i`` """
        self.context_id = i
        self.context_counter = self.context_id
        for clause in self.where_clauses:
            clause.set_context_id(self.context_counter)
            self.context_counter += clause.get_context_size()

    @property
    def timestamp_normalized(self):
        """
        Converts ``self.timestamp`` to an integer count of microseconds since the epoch.

        ``self.timestamp`` is expected to be an int/long (returned unchanged), a
        ``datetime``, or a ``timedelta`` (added to the current local time). Naive
        datetimes are interpreted as local time; timezone-aware datetimes are
        converted using their own UTC offset.

        :return: int, or None when no (or a falsy) timestamp is set
        """
        if not self.timestamp:
            return None

        if isinstance(self.timestamp, six.integer_types):
            return self.timestamp

        if isinstance(self.timestamp, timedelta):
            tmp = datetime.now() + self.timestamp
        else:
            tmp = self.timestamp

        if getattr(tmp, 'tzinfo', None) is not None and tmp.utcoffset() is not None:
            # time.mktime ignores the offset and would treat the wall-clock
            # fields as local time, so go through UTC for aware datetimes
            import calendar
            seconds = calendar.timegm(tmp.utctimetuple())
        else:
            # naive datetime: interpreted as local time
            seconds = time.mktime(tmp.timetuple())

        return int(seconds * 1e+6 + tmp.microsecond)

    def __unicode__(self):
        raise NotImplementedError

    def __repr__(self):
        return self.__unicode__()

    @property
    def _where(self):
        return 'WHERE {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))


class SelectStatement(BaseCQLStatement):
    """
    A CQL SELECT statement.

    ``fields`` and ``order_by`` accept either a single string or a list.
    ``count=True`` selects ``COUNT(*)`` and suppresses ORDER BY.
    """

    def __init__(self,
                 table,
                 fields=None,
                 count=False,
                 consistency=None,
                 where=None,
                 order_by=None,
                 limit=None,
                 allow_filtering=False):

        """
        :param where
        :type where list of cqlengine.statements.WhereClause
        """
        super(SelectStatement, self).__init__(
            table,
            consistency=consistency,
            where=where
        )

        if isinstance(fields, six.string_types):
            self.fields = [fields]
        else:
            self.fields = fields or []
        self.count = count
        if isinstance(order_by, six.string_types):
            self.order_by = [order_by]
        else:
            self.order_by = order_by
        self.limit = limit
        self.allow_filtering = allow_filtering

    def __unicode__(self):
        if self.count:
            selection = 'COUNT(*)'
        elif self.fields:
            selection = ', '.join(['"{}"'.format(f) for f in self.fields])
        else:
            selection = '*'

        parts = ['SELECT', selection, 'FROM', self.table]

        if self.where_clauses:
            parts.append(self._where)

        if self.order_by and not self.count:
            parts.append('ORDER BY {}'.format(', '.join(six.text_type(o) for o in self.order_by)))

        if self.limit:
            parts.append('LIMIT {}'.format(self.limit))

        if self.allow_filtering:
            parts.append('ALLOW FILTERING')

        return ' '.join(parts)


class AssignmentStatement(BaseCQLStatement):
    """
    Base for statements that assign column values (InsertStatement, UpdateStatement).

    Assignment placeholders are numbered after the where-clause placeholders.
    """

    def __init__(self,
                 table,
                 assignments=None,
                 consistency=None,
                 where=None,
                 ttl=None,
                 timestamp=None):
        super(AssignmentStatement, self).__init__(
            table,
            consistency=consistency,
            where=where,
        )
        self.ttl = ttl
        self.timestamp = timestamp

        self.assignments = []
        for clause in assignments or []:
            self.add_assignment_clause(clause)

    def update_context_id(self, i):
        super(AssignmentStatement, self).update_context_id(i)
        for clause in self.assignments:
            clause.set_context_id(self.context_counter)
            self.context_counter += clause.get_context_size()

    def add_assignment_clause(self, clause):
        """
        Appends an assignment clause and reserves its placeholders.

        :param clause: the clause to add
        :type clause: AssignmentClause
        :raises StatementException: if ``clause`` is not an AssignmentClause
        """
        if not isinstance(clause, AssignmentClause):
            raise StatementException("only instances of AssignmentClause can be added to statements")
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.assignments.append(clause)

    @property
    def is_empty(self):
        """ True when no assignments have been added """
        return not self.assignments

    def get_context(self):
        context = super(AssignmentStatement, self).get_context()
        for clause in self.assignments:
            clause.update_context(context)
        return context


class InsertStatement(AssignmentStatement):
    """ a cql insert statement """

    def __init__(self,
                 table,
                 assignments=None,
                 consistency=None,
                 where=None,
                 ttl=None,
                 timestamp=None,
                 if_not_exists=False):
        super(InsertStatement, self).__init__(table,
                                              assignments=assignments,
                                              consistency=consistency,
                                              where=where,
                                              ttl=ttl,
                                              timestamp=timestamp)

        self.if_not_exists = if_not_exists

    def add_where_clause(self, clause):
        raise StatementException("Cannot add where clauses to insert statements")

    def __unicode__(self):
        qs = ['INSERT INTO {}'.format(self.table)]

        # get column names and context placeholders
        fields = [a.insert_tuple() for a in self.assignments]
        columns, values = zip(*fields)

        qs += ["({})".format(', '.join(['"{}"'.format(c) for c in columns]))]
        qs += ['VALUES']
        qs += ["({})".format(', '.join(['%({})s'.format(v) for v in values]))]

        if self.if_not_exists:
            qs += ["IF NOT EXISTS"]

        # CQL takes a single USING clause; several options are joined with AND
        # (two separate USING clauses would be invalid), as in UpdateStatement
        using_options = []

        if self.ttl:
            using_options += ["TTL {}".format(self.ttl)]

        if self.timestamp:
            using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)]

        if using_options:
            qs += ["USING {}".format(" AND ".join(using_options))]

        return ' '.join(qs)


class UpdateStatement(AssignmentStatement):
    """
    A CQL UPDATE statement.

    Placeholders are numbered where clauses first, then assignments, then the
    ``IF`` (transaction) conditions.
    """

    def __init__(self,
                 table,
                 assignments=None,
                 consistency=None,
                 where=None,
                 ttl=None,
                 timestamp=None,
                 transactions=None):
        super(UpdateStatement, self).__init__(table,
                                              assignments=assignments,
                                              consistency=consistency,
                                              where=where,
                                              ttl=ttl,
                                              timestamp=timestamp)

        # Add iff statements
        self.transactions = []
        for transaction in transactions or []:
            self.add_transaction_clause(transaction)

    def __unicode__(self):
        qs = ['UPDATE', self.table]

        using_options = []

        if self.ttl:
            using_options += ["TTL {}".format(self.ttl)]

        if self.timestamp:
            using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)]

        if using_options:
            qs += ["USING {}".format(" AND ".join(using_options))]

        qs += ['SET']
        qs += [', '.join([six.text_type(c) for c in self.assignments])]

        if self.where_clauses:
            qs += [self._where]

        if len(self.transactions) > 0:
            qs += [self._get_transactions()]

        return ' '.join(qs)

    def add_transaction_clause(self, clause):
        """
        Adds a iff clause to this statement

        :param clause: The clause that will be added to the iff statement
        :type clause: TransactionClause
        :raises StatementException: if ``clause`` is not a TransactionClause
        """
        if not isinstance(clause, TransactionClause):
            # the message previously named AssignmentClause, which is not what is checked
            raise StatementException('only instances of TransactionClause can be added to statements')
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.transactions.append(clause)

    def get_context(self):
        ctx = super(UpdateStatement, self).get_context()
        for clause in self.transactions or []:
            clause.update_context(ctx)
        return ctx

    def _get_transactions(self):
        return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))

    def update_context_id(self, i):
        super(UpdateStatement, self).update_context_id(i)
        for transaction in self.transactions:
            transaction.set_context_id(self.context_counter)
            self.context_counter += transaction.get_context_size()


class DeleteStatement(BaseCQLStatement):
    """
    A CQL DELETE statement.

    ``fields`` may be a single column name, a list of names, or delete clauses;
    plain names are wrapped in ``FieldDeleteClause``. With no fields the whole
    row is deleted.
    """

    def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None):
        super(DeleteStatement, self).__init__(
            table,
            consistency=consistency,
            where=where,
            timestamp=timestamp
        )
        self.fields = []
        if isinstance(fields, six.string_types):
            fields = [fields]
        for field in fields or []:
            self.add_field(field)

    def update_context_id(self, i):
        super(DeleteStatement, self).update_context_id(i)
        for field in self.fields:
            field.set_context_id(self.context_counter)
            self.context_counter += field.get_context_size()

    def get_context(self):
        ctx = super(DeleteStatement, self).get_context()
        for field in self.fields:
            field.update_context(ctx)
        return ctx

    def add_field(self, field):
        """
        Adds a column (name or clause) to delete.

        :raises StatementException: if ``field`` is neither a string nor a BaseClause
        """
        if isinstance(field, six.string_types):
            field = FieldDeleteClause(field)
        if not isinstance(field, BaseClause):
            # the message previously named AssignmentClause, which is not what is checked
            raise StatementException("only field names or instances of BaseClause can be added to delete statements")
        field.set_context_id(self.context_counter)
        self.context_counter += field.get_context_size()
        self.fields.append(field)

    def __unicode__(self):
        qs = ['DELETE']
        if self.fields:
            qs += [', '.join(['{}'.format(f) for f in self.fields])]
        qs += ['FROM', self.table]

        delete_option = []

        if self.timestamp:
            delete_option += ["TIMESTAMP {}".format(self.timestamp_normalized)]

        if delete_option:
            qs += [" USING {} ".format(" AND ".join(delete_option))]

        if self.where_clauses:
            qs += [self._where]

        return ' '.join(qs)
