from typing import Optional

from psycopg import AsyncConnection
from psycopg.conninfo import make_conninfo

import config
import models

# libpq connection string assembled from the application config.
# make_conninfo quotes/escapes every value (so a password containing spaces
# or quotes stays valid) and omits parameters whose value is None instead of
# emitting the literal text "None".
conninfo = make_conninfo(
    host=config.PostgreSQL.host,
    port=config.PostgreSQL.port,
    user=config.PostgreSQL.user,
    password=config.PostgreSQL.password,
    dbname=config.PostgreSQL.dbname,
)


class NotFoundError(Exception):
    """Raised when a query that must return a row returns none."""


class Users:
    """Data-access helpers for the ``users`` table."""

    @staticmethod
    async def insert_or_update_user(
        telegram_user_id: int,
        first_name: str,
        last_name: Optional[str] = None,
        username: Optional[str] = None,
    ) -> models.User:
        """Insert a user, or refresh its names if the Telegram id already exists.

        Args:
            telegram_user_id: Telegram's user id; the conflict key of the upsert.
            first_name: First name to store.
            last_name: Last name to store (may be None).
            username: Telegram username to store (may be None).

        Returns:
            The stored user row, including its database id and role.
        """
        sql = '''
            insert into users (
                telegram_user_id,
                first_name,
                last_name,
                username
            ) values (
                %(telegram_user_id)s,
                %(first_name)s,
                %(last_name)s,
                %(username)s
            )
            on conflict (telegram_user_id) do update set
                first_name = excluded.first_name,
                last_name = excluded.last_name,
                username = excluded.username
            returning
                users.id,
                users.telegram_user_id,
                users.first_name,
                users.last_name,
                users.username,
                users.role;
        '''
        params = {
            'telegram_user_id': telegram_user_id,
            'first_name': first_name,
            'last_name': last_name,
            'username': username,
        }
        async with await AsyncConnection.connect(conninfo) as connection:
            async with connection.cursor() as cursor:
                await cursor.execute(sql, params)
                # The upsert always hits either the insert or the update
                # branch, so RETURNING yields exactly one row.
                row = await cursor.fetchone()

        user_id, stored_telegram_id, stored_first, stored_last, stored_username, role = row
        return models.User(
            id=user_id,
            telegram_id=stored_telegram_id,
            first_name=stored_first,
            last_name=stored_last,
            username=stored_username,
            role=role,
        )


class PollSchemas:
    """Data-access helpers for the ``poll_schemas`` table."""

    @staticmethod
    async def get_poll_schema_by_name(
        name: str,
    ) -> models.PollSchema:
        """Look up a poll schema by its name.

        Args:
            name: Value matched against ``poll_schemas.name``.

        Returns:
            The matching poll schema.

        Raises:
            NotFoundError: No poll schema has this name.
        """
        sql = '''
            select
                poll_schemas.id,
                poll_schemas.name,
                poll_schemas.question
            from poll_schemas
            where poll_schemas.name = %(name)s;
        '''
        async with await AsyncConnection.connect(conninfo) as connection:
            async with connection.cursor() as cursor:
                await cursor.execute(sql, {'name': name})
                row = await cursor.fetchone()

        # Check explicitly instead of relying on a TypeError from unpacking
        # None, which would also mask any unrelated TypeError.
        if row is None:
            raise NotFoundError(f'poll schema {name!r} not found')

        poll_schema_id, schema_name, question = row
        return models.PollSchema(
            id=poll_schema_id,
            name=schema_name,
            question=question,
        )
class PollOptions:
    """Data-access helpers for the ``poll_options`` table."""

    @staticmethod
    async def get_poll_options(
        poll_schema: models.PollSchema,
        ordinals: list[int] = None,
    ) -> list[models.PollOption]:
        """Return the options of a poll schema, ordered by ordinal.

        Args:
            poll_schema: Schema whose options are loaded.
            ordinals: If None, return every option with a non-null ordinal.
                Otherwise return only the options whose ordinal is in this list.

        Returns:
            The matching options, each linked to ``poll_schema``.
        """
        if ordinals is None:
            sql = '''
                select
                    poll_options.id,
                    poll_options.name,
                    poll_options.ordinal
                from poll_options
                where poll_options.poll_schema_id = %(poll_schema_id)s
                    and poll_options.ordinal is not null
                order by poll_options.ordinal;
            '''
        else:
            sql = '''
                select
                    poll_options.id,
                    poll_options.name,
                    poll_options.ordinal
                from poll_options
                where poll_options.poll_schema_id = %(poll_schema_id)s
                    and poll_options.ordinal = any(%(ordinals)s)
                order by poll_options.ordinal;
            '''
        params = {
            'poll_schema_id': poll_schema.id,
            'ordinals': ordinals,
        }
        async with await AsyncConnection.connect(conninfo) as connection:
            async with connection.cursor() as cursor:
                await cursor.execute(sql, params)
                records = await cursor.fetchall()

        return [
            models.PollOption(
                id=poll_option_id,
                poll_schema=poll_schema,
                name=option_name,
                ordinal=ordinal,
            )
            for poll_option_id, option_name, ordinal in records
        ]


class Polls:
    """Data-access helpers for the ``polls`` table."""

    @staticmethod
    async def get_polls() -> list[models.Poll]:
        """Return every poll together with its poll schema."""
        sql = '''
            select
                polls.id,
                polls.telegram_message_id,
                polls.telegram_poll_id,
                poll_schemas.id,
                poll_schemas.name,
                poll_schemas.question,
                polls.created_at,
                polls.is_complete
            from polls
            inner join poll_schemas on polls.poll_schema_id = poll_schemas.id;
        '''
        async with await AsyncConnection.connect(conninfo) as connection:
            async with connection.cursor() as cursor:
                await cursor.execute(sql)
                records = await cursor.fetchall()

        return [
            models.Poll(
                id=poll_id,
                telegram_message_id=telegram_message_id,
                telegram_poll_id=telegram_poll_id,
                poll_schema=models.PollSchema(
                    id=poll_schema_id,
                    name=schema_name,
                    question=question,
                ),
                created_at=created_at,
                is_complete=is_complete,
            )
            for (
                poll_id,
                telegram_message_id,
                telegram_poll_id,
                poll_schema_id,
                schema_name,
                question,
                created_at,
                is_complete,
            ) in records
        ]

    @staticmethod
    async def insert_poll(
        telegram_message_id: int,
        telegram_poll_id: str,
        poll_schema: models.PollSchema,
    ) -> models.Poll:
        """Insert a new poll and return it with its generated columns.

        Args:
            telegram_message_id: Telegram message id that carries the poll.
            telegram_poll_id: Telegram's id of the poll.
            poll_schema: Schema the poll was created from.

        Returns:
            The new poll, including the database-generated id, created_at and
            is_complete values.

        Raises:
            NotFoundError: The INSERT returned no row (kept for compatibility
                with the original contract).
        """
        sql = '''
            insert into polls (
                telegram_message_id,
                telegram_poll_id,
                poll_schema_id
            ) values (
                %(telegram_message_id)s,
                %(telegram_poll_id)s,
                %(poll_schema_id)s
            )
            returning
                polls.id,
                polls.created_at,
                polls.is_complete;
        '''
        params = {
            'telegram_message_id': telegram_message_id,
            'telegram_poll_id': telegram_poll_id,
            'poll_schema_id': poll_schema.id,
        }
        async with await AsyncConnection.connect(conninfo) as connection:
            async with connection.cursor() as cursor:
                await cursor.execute(sql, params)
                row = await cursor.fetchone()

        if row is None:
            raise NotFoundError('insert into polls returned no row')

        poll_id, created_at, is_complete = row
        return models.Poll(
            id=poll_id,
            telegram_message_id=telegram_message_id,
            telegram_poll_id=telegram_poll_id,
            poll_schema=poll_schema,
            created_at=created_at,
            is_complete=is_complete,
        )


class PollAnswers:
    """Data-access helpers for the ``poll_answers`` table."""

    @staticmethod
    async def insert_or_update_poll_answer(
        poll: models.Poll,
        user: models.User,
        poll_options: list[models.PollOption],
    ) -> list[models.PollAnswer]:
        """Record that ``user`` chose each of ``poll_options`` in ``poll``.

        Rows that already exist (same poll, user and option) are left
        untouched because of ``on conflict ... do nothing``. Such rows produce
        no RETURNING output, so they are absent from the result.

        Args:
            poll: Poll being answered.
            user: User who answered.
            poll_options: Options the user selected.

        Returns:
            One answer per newly inserted row, each linked to the single
            option it refers to.
        """
        # With no parameter sets, executemany runs nothing, and fetching would
        # fail because there is no result to read.
        if not poll_options:
            return []

        options_by_id = {poll_option.id: poll_option for poll_option in poll_options}
        sql = '''
            insert into poll_answers (
                poll_id,
                user_id,
                poll_option_id
            ) values (
                %(poll_id)s,
                %(user_id)s,
                %(poll_option_id)s
            )
            on conflict (poll_id, user_id, poll_option_id) do nothing
            returning poll_answers.id, poll_answers.poll_option_id;
        '''
        params_seq = [
            {
                'poll_id': poll.id,
                'user_id': user.id,
                'poll_option_id': poll_option.id,
            }
            for poll_option in poll_options
        ]
        records = []
        async with await AsyncConnection.connect(conninfo) as connection:
            async with connection.cursor() as cursor:
                await cursor.executemany(sql, params_seq, returning=True)
                # executemany(returning=True) keeps one result set per
                # executed statement; fetchall() reads only the current one,
                # so walk all of them with nextset().
                while True:
                    records.extend(await cursor.fetchall())
                    if not cursor.nextset():
                        break

        return [
            models.PollAnswer(
                id=poll_answer_id,
                poll=poll,
                user=user,
                poll_option=options_by_id[poll_option_id],
            )
            for poll_answer_id, poll_option_id in records
        ]