[+] add get_or_create for sqlalchemy

This commit is contained in:
Siarhei Siniak 2025-07-09 11:51:20 +03:00
parent 731b9d384a
commit 13e2bff324
2 changed files with 50 additions and 1 deletions

@ -8,7 +8,7 @@ async def ticker_store_multiple(
tickers: list[Ticker],
) -> None:
async with session() as active_session:
async with active_session.begin():
async with active_session.begin() as transaction:
active_session.add_all(
tickers,
)

@ -0,0 +1,49 @@
from typing import (TypeVar, Optional, Any, cast,)
from sqlalchemy.ext.asyncio import AsyncSessionTransaction, AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.exc import NoResultFound, IntegrityError
M = TypeVar('M', bound='DeclarativeBase')
async def get_or_create(
session: AsyncSession,
model: type[M],
create_method: Optional[str] = None,
create_method_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any
) -> tuple[M, bool]:
async def select_row() -> M:
res = await session.execute(
select(model).where(
*[
getattr(model, k) == v
for k, v in kwargs.items()
]
)
)
row = res.one()
assert isinstance(row, model)
return row
try:
res = await select_row()
return res, False
except NoResultFound:
if create_method_kwargs:
kwargs.update(create_method_kwargs)
if not create_method:
created = model(**kwargs)
else:
created = getattr(model, create_method)(**kwargs)
try:
session.add(created)
await session.flush()
return created, True
except IntegrityError:
await session.rollback()
return await select_row(), False