from api_dnl.objects_history import BaseModel
#import api_dnl
#import falcon_rest
from falcon_rest.db import get_db#,BaseModel
from sqlalchemy.sql import func
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.sql.elements import Label
from sqlalchemy import func, distinct, select
from sqlalchemy.orm import lazyload
from sqlalchemy.dialects import postgresql
from falcon_rest.logger import log

LIST_DEBUG=False

def _q_str(q):
	"""Return the SQL of ORM query ``q`` rendered for PostgreSQL with bind values inlined.

	Intended for debug logging only; rendering literal binds can fail for
	column types without a literal renderer, so callers wrap it in try/except.
	"""
	compiled = q.statement.compile(
		dialect=postgresql.dialect(),
		compile_kwargs={"literal_binds": True},
	)
	return str(compiled)
def object_to_dict(obj, found=None, deep=False):
	"""Serialize the mapped column attributes of ``obj`` into a plain dict.

	``datetime`` values are converted with ``isoformat()``. With ``deep=True``
	every non-None relationship is serialized recursively (lists for
	``uselist`` relationships). ``found`` is shared across the recursion and
	records relationship properties already expanded, so each relationship
	is followed at most once per traversal, which also stops cycles.

	:param obj: SQLAlchemy-mapped instance.
	:param found: set of already-expanded relationship properties; a new set is used when None.
	:param deep: whether to expand relationships.
	:return: dict of attribute key -> value.
	"""
	from datetime import datetime
	from sqlalchemy.orm import class_mapper
	from falcon_rest.db.fields import (Numeric)

	visited = set() if found is None else found
	mapper = class_mapper(obj.__class__)

	result = {}
	for column in mapper.columns:
		key = column.key
		if not hasattr(obj, key):
			continue
		value = getattr(obj, key)
		if isinstance(value, datetime):
			value = value.isoformat()
		elif isinstance(value, Numeric):
			# NOTE(review): values loaded from Numeric columns are usually
			# decimal.Decimal, not instances of the column type - confirm this
			# branch ever fires.
			value = float(value)
		result[key] = value

	if not deep:
		return result

	for rel_name, relationship in mapper.relationships.items():
		if relationship in visited:
			continue
		visited.add(relationship)
		related = getattr(obj, rel_name)
		if related is None:
			continue
		if relationship.uselist:
			result[rel_name] = [object_to_dict(child, visited, deep) for child in related]
		else:
			result[rel_name] = object_to_dict(related, visited, deep)
	return result

def get_count(q):
	"""Return the number of rows query ``q`` would produce.

	Public entry point used by the list helpers. All counting is delegated
	to ``get_count_new``. The legacy implementation that used to follow the
	delegating ``return`` was unreachable and has been removed.

	:param q: SQLAlchemy ORM ``Query``.
	:return: the row count computed by ``get_count_new``.
	"""
	return get_count_new(q)


def get_count_new(q):
	"""Return the number of rows ORM query ``q`` would produce.

	Strategy:

	* If ``q`` has GROUP BY, or any selected entity is a ``GenericFunction``
	  (directly or as the ``element`` of a wrapper such as ``Label``), the
	  query with ORDER BY removed is wrapped in a CTE named "subquery" and
	  ``count(*)`` is selected from it.
	* Otherwise the selected entities are replaced with
	  ``count(DISTINCT <first column of the model's table>)`` when such a
	  column is known and the query is not DISTINCT, else with ``count(*)``.
	  ORDER BY is removed.

	Only execution of the count query is guarded: failures there are logged
	at debug level and 0 is returned. Building the count query, and the
	sanity query run when the count is 1, are not guarded, so errors from
	them propagate to the caller.

	NOTE(review): relies on private ``Query`` attributes (``_entities``,
	``_group_by``, ``_distinct``) whose names differ between SQLAlchemy
	versions - confirm against the pinned version before upgrading.

	:param q: SQLAlchemy ORM ``Query``.
	:return: the count returned by the database, or 0 on None/failure.
	"""
	# Resolve the mapped class behind the first selected entity, if any.
	entity = q._entities[0]
	if hasattr(entity, 'entity_zero') and entity.entity_zero is not None:
		model = entity.entity_zero.class_
	else:
		mapper = getattr(entity, 'mapper', None)
		if mapper is not None and hasattr(mapper, 'class_'):
			model = mapper.class_
		else:
			model = None

	# Use the first column of the table if possible.
	# NOTE(review): this is the table's first column, not necessarily the
	# primary key; count(DISTINCT col) skips NULLs and duplicates, so it
	# undercounts if that column is not unique and non-null.
	col = None
	if model and hasattr(model, '__table__'):
		col = list(model.__table__.columns)[0]

	# Detect selected function expressions (SUM, COUNT, AVG, ...). Any
	# GenericFunction matches, including registered non-aggregates; those
	# simply take the (still correct) subquery path below. Functions that are
	# not GenericFunction subclasses are not detected.
	has_aggregate = False
	try:
		for entity in q._entities:
			# _ColumnEntity has 'expr' or 'column' attribute
			expr = None
			if hasattr(entity, 'expr'):
				expr = entity.expr
			elif hasattr(entity, 'column'):
				expr = entity.column
			elif hasattr(entity, 'element'):
				expr = entity.element

			if expr is not None:
				# Direct check for GenericFunction
				if isinstance(expr, GenericFunction):
					has_aggregate = True
					break

				# Check if it's a Label (or other wrapper) around a function
				if isinstance(expr, Label) or hasattr(expr, 'element'):
					inner = expr.element if hasattr(expr, 'element') else None
					# Use 'is not None' instead of truthiness to avoid SQLAlchemy boolean error
					if inner is not None and isinstance(inner, GenericFunction):
						has_aggregate = True
						break
	except Exception as e:
		# Best effort: if detection fails, fall through with has_aggregate=False.
		log.debug('Error detecting aggregates in get_count_new: {}'.format(str(e)))
		pass

	# With GROUP BY or aggregates, replacing the selected entities would
	# change the row count, so count the rows of the whole query instead.
	if q._group_by or has_aggregate:
		sub_q = q.order_by(None)
		count_q = q.session.query(func.count()).select_from(sub_q.cte("subquery"))
	else:
		# NOTE(review): for a DISTINCT query the DISTINCT flag is kept by
		# with_entities, so this likely yields count(*) over all rows rather
		# than distinct rows - verify.
		count_func = func.count(distinct(col)) if col is not None and not q._distinct else func.count()
		count_q = q.with_entities(count_func).order_by(None)

	log.debug("DEBUG SQL ISSUE. SEEMS failing")
	ret = 0
	try:
		log.debug(f"Generated SQL for count query: {str(count_q)}")
		ret = q.session.execute(count_q).scalar()
		log.debug(f"DEBUG SQL: Count query executed, result: {ret}")
	except Exception as e:
		# Failure is swallowed; ret stays 0.
		log.debug('get_count_new failed: {}'.format(str(e)))
		pass

	# A NULL scalar is reported as 0.
	if ret is None:
		return 0

	# Sanity check for a count of exactly 1: fetch up to 2 rows of the
	# original statement and return 0 if none exist.
	# NOTE(review): it is unclear which query shape produces a count of 1 with
	# no rows; this query is not wrapped in try/except.
	if ret == 1:
		additional_check_query = q.statement.limit(2)  # at most 2 rows are needed to prove presence
		additional_check_result = q.session.execute(additional_check_query).fetchall()

		# No rows: the count of 1 was spurious.
		if not additional_check_result:
			return 0

	return ret


class DnlApiBaseModel(BaseModel):
	"""Abstract base for the DNL API models.

	Adds a readable ``__repr__``, column copying, dict serialization, generic
	filter/order/page/count helpers for list endpoints, and revision-history
	recording on top of ``api_dnl.objects_history.BaseModel``.

	``db_label`` (read by ``session()`` and the history code) is not defined
	here; presumably subclasses or ``BaseModel`` provide it.
	"""
	__abstract__ = True

	def __repr__(self):
		"""Show every field from ``get_fields()``; each ``name = value`` part is cut at 1024 chars."""
		parts = []
		for column_name in self.get_fields():
			part = '{} = {}'.format(column_name, getattr(self, column_name))
			if len(part) > 1024:
				part = part[:1024] + ' [...long field]'
			parts.append(part)

		return '<{}: {}>'.format(self.__class__.__name__, ', '.join(parts))

	def copy(self):
		"""Return a new, unsaved instance with every non-primary-key column value copied.

		Relationships are not copied. Unique columns are copied as well, so
		they may need changing before the copy is saved.

		:raises TypeError: if invoked on an object that is not a DnlApiBaseModel.
		"""
		from sqlalchemy import inspect
		if not isinstance(self, DnlApiBaseModel):
			raise TypeError('The given parameter with type {} is not '
							'mapped by SQLAlchemy.'.format(type(self)))

		mapper = inspect(type(self))
		new_obj = type(self)()
		for name, col in mapper.columns.items():
			# Primary keys are left unset so the copy gets its own identity.
			if not col.primary_key:
				setattr(new_obj, name, getattr(self, name))
		return new_obj

	def as_dict(self):
		"""Return the column values as a plain dict (datetimes as ISO strings, no relationships)."""
		return object_to_dict(self)

	@classmethod
	def session(cls):
		"""Return the session of the database registered under ``cls.db_label``."""
		return get_db(cls.db_label).session

	@staticmethod
	def _apply_paging(qs, paging):
		"""Apply ``paging['from']`` as OFFSET and ``paging['till'] - from`` as LIMIT (each optional)."""
		if 'from' in paging:
			qs = qs.offset(paging['from'])
		if 'till' in paging:
			qs = qs.limit(paging['till'] - paging.get('from', 0))
		return qs

	@staticmethod
	def _log_query(qs):
		"""When LIST_DEBUG is on, log the SQL of ``qs``, with literal binds if they can be rendered."""
		if not LIST_DEBUG:
			return
		try:
			log.debug(_q_str(qs))
		except Exception:
			log.debug(str(qs))

	@classmethod
	def get_objects_list(cls, query=None, filtering=None, paging=None,
						 ordering=None, query_only=False):
		"""Filter, order, count and page a query over this model.

		:param query: base query; ``cls.query()`` when not given.
		:param filtering: ``{attribute_name: value}``. String values containing
			``*`` or ``_`` are matched case-insensitively with LIKE (``*``
			becomes ``%``; ``_`` keeps its LIKE single-character meaning). If
			building that LIKE filter raises, an exact match on the original
			value is used instead. Other values are matched with ``==``.
		:param ordering: ``{'by': 'a,b', 'dir': 'desc,asc'}``. Names are
			comma-separated model attributes; ``desc`` sorts descending with
			NULLS LAST, anything else (including a missing entry) ascending.
		:param paging: ``{'from': offset, 'till': end}``, both optional;
			``till`` is an exclusive end index.
		:param query_only: return the paged query instead of its rows.
		:return: ``(rows_or_query, total_count)``. The count is taken before
			paging. If counting raises, the session is rolled back and the
			count falls back to the paging window size, or 10.
		"""
		from sqlalchemy import cast, String
		log.debug("qs = query or cls.query()")
		qs = query or cls.query()
		log.debug("GOT QS")
		if filtering:
			log.debug("STARTING FILTERING")
			for k, v in filtering.items():
				if isinstance(v, str) and ('*' in v or '_' in v):  # pragma: no cover
					try:
						pattern = v.upper().replace('*', '%')
						qs = qs.filter(func.upper(cast(getattr(cls, k), String)).like(pattern))
					except Exception:
						# Fall back to an exact match on the value as supplied,
						# not on the rewritten LIKE pattern.
						qs = qs.filter(getattr(cls, k) == v)
				else:
					qs = qs.filter(getattr(cls, k) == v)
			log.debug("FINISHED FILTERING")

		if ordering:  # pragma: no cover
			log.debug("ORDERING")
			fields = ordering['by'].split(',')
			directions = (ordering['dir'] or '').split(',')
			clauses = []
			for i, field in enumerate(fields):
				clause = getattr(cls, field.strip())
				# A shorter 'dir' list leaves the remaining fields ascending.
				if i < len(directions) and directions[i].strip() == 'desc':
					clause = clause.desc().nullslast()
				clauses.append(clause)
			qs = qs.order_by(*clauses)
			log.debug("FINISHED ORDERING")

		try:
			log.debug("GETTING COUNT")
			cnt = get_count(qs)
		except Exception as e:  # TODO: Fix this somehow
			log.debug(f"ERROR - {e}")
			cls.session().rollback()
			if paging and 'from' in paging and 'till' in paging:
				cnt = paging['till'] - paging['from']
			else:
				cnt = 10

		if paging:
			log.debug("STARTING PAGING")
			qs = cls._apply_paging(qs, paging)
			log.debug("FINISHED PAGING")
		cls._log_query(qs)
		if query_only:
			return qs, cnt
		return qs.all(), cnt

	@classmethod
	def get_objects_query_list(cls, query=None, filtering=None, paging=None,
							   ordering=None):
		"""Filter, order, count and page a query, resolving names against a table alias.

		Column names are looked up on ``cls._tab.c`` when the class defines
		``_tab``, otherwise on an alias of the query named ``cls.__tablename__``.
		Unknown names raise ``AttributeError``. String filter values containing
		``*`` use a case-insensitive LIKE (``*`` becomes ``%``). Ordering takes
		a single ``by`` column and ``dir`` (``desc`` sorts descending with
		NULLS LAST).

		:return: ``(rows, total_count)``. If counting raises, the count falls
			back to ``len(qs.all())``.
		"""
		from sqlalchemy import alias, cast, String
		qs = query or cls.query()
		if hasattr(cls, '_tab'):
			tab = cls._tab
		else:
			tab = alias(qs, cls.__tablename__)
		if filtering:
			for k, v in filtering.items():
				if isinstance(v, str) and '*' in v:  # pragma: no cover
					pattern = v.upper().replace('*', '%')
					qs = qs.filter(func.upper(cast(getattr(tab.c, k), String)).like(pattern))
				else:
					qs = qs.filter(getattr(tab.c, k) == v)

		if ordering:  # pragma: no cover
			order_by = getattr(tab.c, ordering['by'])
			if ordering['dir'] == 'desc':
				order_by = order_by.desc().nullslast()
			qs = qs.order_by(order_by)

		try:
			cnt = get_count(qs)
		except Exception:  # TODO: AssertionError Fix this somehow
			cnt = len(qs.all())

		if paging:
			qs = cls._apply_paging(qs, paging)
		cls._log_query(qs)
		return qs.all(), cnt

	@classmethod
	def get_objects_query(cls, query=None, filtering=None, paging=None,
						  ordering=None):
		"""Like ``get_objects_query_list`` but return the paged query, not its rows.

		Differences: a non-wildcard filter key that is not a column of the
		table alias raises ``Exception('wrong filter k=v')``, and a failing
		count yields ``None``.

		:return: ``(query, total_count_or_None)``.
		"""
		from sqlalchemy import alias, cast, String
		qs = query or cls.query()
		if hasattr(cls, '_tab'):
			tab = cls._tab
		else:
			tab = alias(qs, cls.__tablename__)
		if filtering:
			for k, v in filtering.items():
				if isinstance(v, str) and '*' in v:  # pragma: no cover
					pattern = v.upper().replace('*', '%')
					qs = qs.filter(func.upper(cast(getattr(tab.c, k), String)).like(pattern))
				elif hasattr(tab.c, k):
					qs = qs.filter(getattr(tab.c, k) == v)
				else:
					raise Exception('wrong filter {}={}'.format(k, v))

		if ordering:  # pragma: no cover
			order_by = getattr(tab.c, ordering['by'])
			if ordering['dir'] == 'desc':
				order_by = order_by.desc().nullslast()
			qs = qs.order_by(order_by)

		try:
			cnt = get_count(qs)
		except Exception:  # TODO: AssertionError Fix this somehow
			cnt = None

		if paging:
			qs = cls._apply_paging(qs, paging)
		cls._log_query(qs)
		return qs, cnt

	def add_history_record_on_create_or_update(self, **kwargs):
		"""Save a revision record for this object and hand its field changes to ``do_save_history``.

		kwargs:
			user (required): object whose ``get_id()`` becomes the revision's user_id.
			module: stored on the revision record.
			alias / name / _name / trunk_name: first present key supplies the
				object name; otherwise the same keys are looked up in
				``as_dict()``, then the name/alias of the ``Resource`` given by
				``resource_id`` (from the object or kwargs).

		Field rows hold ``(field_name, old_value, new_value)``. For a create, or
		when ``previous_data`` is empty, every field of ``get_current_data()``
		is recorded with ``old_value=None``. Otherwise the pairs from
		``get_diff()`` are recorded.

		:raises Exception: if ``user`` is not passed.
		"""
		if 'user' not in kwargs:  # pragma: no cover
			raise Exception(
				'No user passed to object save method while use_history == True and save_history=True passed'
			)

		entity_name = self.get_name()
		entity_pk = str(self.get_object_primary_key_value())[:64]
		revision_number = self.get_next_revision_number()
		action = self.get_action_create_or_update(revision_number, entity_name, entity_pk)
		module = kwargs.get('module', None)

		name_keys = ('alias', 'name', '_name', 'trunk_name')
		object_name = None
		for k in name_keys:
			if k in kwargs:
				object_name = kwargs[k]
				break
		if not object_name:
			# Snapshot the column values once instead of rebuilding the dict per lookup.
			current = self.as_dict()
			for k in name_keys:
				if k in current:
					object_name = current[k]
					break
			if not object_name and ('resource_id' in current or 'resource_id' in kwargs):
				from api_dnl.model import Resource
				if 'resource_id' in current:
					resource_id = current.get('resource_id')
				else:
					resource_id = kwargs.get('resource_id')
				resource = Resource.get(resource_id)
				if resource:
					object_name = resource.name or resource.alias

		revision_obj_id = self.object_revision_module.ObjectRevisionModel(
			user_id=kwargs['user'].get_id(),
			entity_name=entity_name,
			entity_pk=entity_pk,
			action=action,
			revision_number=revision_number,
			module=module,
			object_name=object_name
		).save()

		if action == 'create' or not self.previous_data:
			data = [
				dict(
					object_revision_id=revision_obj_id,
					field_name=field_name,
					old_value=None,
					new_value=value
				)
				for field_name, value in self.get_current_data().items()
			]
		else:
			data = [
				dict(
					object_revision_id=revision_obj_id,
					field_name=field_name,
					old_value=values[0],
					new_value=values[1]
				)
				for field_name, values in self.get_diff().items()
			]

		from api_dnl.tasks import do_save_history
		try:
			do_save_history.delay(data)
		except Exception:
			# Queueing failed (presumably the task broker is unavailable -
			# assumes a Celery-style .delay); save synchronously instead.
			do_save_history(data)


from api_dnl.settings import DB_CONN_STRING
def init_db(conn_string=DB_CONN_STRING, db_label='default', declarative_base=DnlApiBaseModel):
    """Create, register and return the database handle for ``db_label``.

    Two engines are created from ``conn_string``, each with a pool of
    ``DB_POOLSIZE`` connections:

    * ``db.tr_engine`` / ``db.tr_session`` use the default isolation level;
    * ``db.engine`` / ``db.session`` use ``AUTOCOMMIT``.

    The handle is stored in ``falcon_rest.db._databases[db_label]``.
    """
    from falcon_rest.db import _DB, _databases
    from api_dnl.settings import DB_POOLSIZE
    import sqlalchemy
    from sqlalchemy.orm import scoped_session, sessionmaker

    banner = 'DB [{}]: Initializing DB, label {}'.format(db_label, db_label)
    log.info(banner)
    print(banner)

    db = _DB(declarative_base, db_label)
    pool_size = int(DB_POOLSIZE)

    # Transactional engine/session pair.
    db.tr_engine = sqlalchemy.create_engine(conn_string, pool_size=pool_size)
    db.tr_session = scoped_session(sessionmaker(bind=db.tr_engine))

    # Autocommit engine/session pair (the one exposed as db.session).
    db.engine = sqlalchemy.create_engine(
        conn_string,
        pool_size=pool_size,
        isolation_level="AUTOCOMMIT",
    )
    db.session = scoped_session(sessionmaker(bind=db.engine))

    _databases[db_label] = db
    return _databases[db_label]
