SQLAlchemy Async Patterns
Modern SQLAlchemy 2.0 async patterns with FastAPI
PRINCIPLES
- Async by Default: Use AsyncSession for non-blocking I/O
- Session per Request: Create the session in a dependency and close it after the request
- Type Safety: Use mapped_column and Mapped for type hints
- Repository Pattern: Abstract database operations in a CRUD layer
- Eager Loading: Control N+1 queries with selectinload/joinedload
DATABASE SETUP
Async Engine Configuration
python
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from typing import AsyncGenerator
# Connection URL for the asyncpg driver (placeholder credentials).
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/dbname"

# Application-wide async engine.
engine = create_async_engine(
    DATABASE_URL,
    pool_pre_ping=True,  # Verify connections
    pool_size=5,
    max_overflow=10,
    echo=True,  # SQL logging (disable in production)
)

# Factory that produces AsyncSession objects bound to the engine above.
async_session_maker = async_sessionmaker(
    bind=engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
class Base(DeclarativeBase):
    """Declarative base class; its metadata is what init_db() creates."""
# FastAPI dependency
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one session, committing on success and rolling back on error.

    The commit runs after the consumer finishes with the yielded session.
    Any ``Exception`` raised by the consumer or by the commit itself
    triggers a rollback and is then re-raised unchanged.
    """
    db_session = async_session_maker()
    async with db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
# Initialize database
async def init_db():
    """Create every table registered on ``Base.metadata``."""
    async with engine.begin() as connection:
        # create_all is a synchronous API, so it is dispatched via run_sync.
        await connection.run_sync(Base.metadata.create_all)
MODEL DEFINITION
SQLAlchemy 2.0 Style
python
from sqlalchemy import ForeignKey, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from datetime import datetime
class Base(DeclarativeBase):
    """Declarative base class shared by the models defined below."""
class TimestampMixin:
    """Mixin contributing ``created_at`` and ``updated_at`` columns."""

    # Set on insert, both client-side (default) and via a server default.
    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        default=func.now(),
    )

    # Same defaults as created_at, plus an onupdate hook.
    updated_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        default=func.now(),
        onupdate=func.now(),
    )
class User(TimestampMixin, Base):
    """Row of the ``users`` table, with timestamp columns from the mixin."""

    __tablename__ = "users"

    # Columns
    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str] = mapped_column(String(255), index=True, unique=True)
    hashed_password: Mapped[str] = mapped_column(String(255))
    full_name: Mapped[str | None] = mapped_column(String(100))
    is_active: Mapped[bool] = mapped_column(default=True)

    # Relationships: counterpart of Post.author, loaded with "selectin".
    posts: Mapped[list["Post"]] = relationship(
        lazy="selectin",
        back_populates="author",
    )
class Post(TimestampMixin, Base):
    """Row of the ``posts`` table, owned by a user and linked to tags."""

    __tablename__ = "posts"

    # Columns
    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(200))
    content: Mapped[str] = mapped_column(Text)
    author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))

    # Relationships
    # NOTE(review): no ``lazy`` option is given here, so the default loader
    # strategy applies; presumably callers must eager-load ``author``
    # explicitly when using AsyncSession -- confirm.
    author: Mapped["User"] = relationship(back_populates="posts")

    # Many-to-many through the "post_tags" association table.
    tags: Mapped[list["Tag"]] = relationship(
        lazy="selectin",
        secondary="post_tags",
        back_populates="posts",
    )
class Tag(Base):
    """Row of the ``tags`` table; each tag name is unique."""

    __tablename__ = "tags"

    # Columns
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50), unique=True)

    # Many-to-many through "post_tags"; counterpart of Post.tags.
    posts: Mapped[list["Post"]] = relationship(
        back_populates="tags",
        secondary="post_tags",
    )
# Association table
from sqlalchemy import Table, Column
# Link table for the Post <-> Tag many-to-many relationship; the two
# foreign keys together form the composite primary key.
post_tags = Table(
    "post_tags",
    Base.metadata,
    Column("post_id", ForeignKey("posts.id"), primary_key=True),
    Column("tag_id", ForeignKey("tags.id"), primary_key=True),
)
CRUD OPERATIONS
Generic CRUD Base
python
from typing import Generic, TypeVar, Type, Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
# Type parameters for CRUDBase: the mapped model plus the pydantic schemas
# accepted by create() and update().
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic async CRUD helper for one mapped model class.

    The model is expected to expose an ``id`` attribute, which ``get``,
    ``get_multi`` and ``delete`` rely on. None of the methods commit; they
    only flush, so the caller that owns the session decides when to commit.
    """

    def __init__(self, model: Type[ModelType]):
        """Remember the mapped class that all operations act on."""
        self.model = model

    async def get(self, db: AsyncSession, id: int) -> ModelType | None:
        """Return the row whose ``id`` equals *id*, or ``None`` if absent."""
        result = await db.execute(
            select(self.model).where(self.model.id == id)
        )
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
    ) -> Sequence[ModelType]:
        """Return up to *limit* rows after skipping *skip*, ordered by ``id``."""
        result = await db.execute(
            select(self.model)
            # OFFSET/LIMIT without ORDER BY gives no stable page boundaries,
            # so rows could repeat or be skipped between pages.
            .order_by(self.model.id)
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()

    async def create(
        self,
        db: AsyncSession,
        *,
        obj_in: CreateSchemaType,
    ) -> ModelType:
        """Build a model instance from *obj_in*, add it, flush and refresh it."""
        obj_data = obj_in.model_dump()
        db_obj = self.model(**obj_data)
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType,
    ) -> ModelType:
        """Copy the explicitly-set fields of *obj_in* onto *db_obj*.

        Fields the client did not set are left untouched
        (``exclude_unset=True``). The object is flushed and refreshed
        before being returned.
        """
        update_data = obj_in.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def delete(self, db: AsyncSession, *, id: int) -> bool:
        """Delete the row with the given *id*.

        Returns ``True`` when a row was found and deleted, ``False`` when no
        such row exists.
        """
        obj = await self.get(db, id)
        if obj is None:
            return False
        await db.delete(obj)
        # Flush like create()/update() do, so database errors surface here
        # rather than at some later flush or commit.
        await db.flush()
        return True
# Usage
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
    """CRUD helper for ``User`` with an additional lookup by e-mail."""

    async def get_by_email(
        self,
        db: AsyncSession,
        email: str,
    ) -> User | None:
        """Return the user whose ``email`` equals *email*, or ``None``."""
        stmt = select(User).where(User.email == email)
        rows = await db.execute(stmt)
        return rows.scalar_one_or_none()


# Shared instance used by callers.
user_crud = CRUDUser(User)
QUERY PATTERNS
Efficient Queries
python
from sqlalchemy import select, and_, or_, func
from sqlalchemy.orm import selectinload, joinedload
# Eager loading to avoid N+1
async def get_user_with_posts(db: AsyncSession, user_id: int) -> User | None:
    """Return the user with id *user_id*, requesting ``posts`` via selectinload.

    Returns ``None`` when no such user exists.
    """
    stmt = (
        select(User)
        .where(User.id == user_id)
        .options(selectinload(User.posts))
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
# Complex filters
async def search_posts(
    db: AsyncSession,
    *,
    keyword: str | None = None,
    author_id: int | None = None,
    tag_names: list[str] | None = None,
) -> Sequence[Post]:
    """Search posts by optional keyword, author and tag names.

    Args:
        db: Session used to run the query.
        keyword: When non-empty, matched case-insensitively as a substring
            of the title or the content.
        author_id: When not ``None``, restricts results to that author.
        tag_names: When non-empty, restricts results to posts carrying at
            least one of the named tags.

    Returns:
        The matching posts, de-duplicated, with ``author`` and ``tags``
        requested via ``selectinload``.
    """
    query = select(Post).options(
        selectinload(Post.author),
        selectinload(Post.tags),
    )
    filters = []
    if keyword:
        # NOTE(review): "%" and "_" inside keyword are passed through
        # unescaped and therefore act as LIKE wildcards.
        pattern = f"%{keyword}%"
        filters.append(
            or_(
                Post.title.ilike(pattern),
                Post.content.ilike(pattern),
            )
        )
    # Compare against None: a plain truthiness test would silently drop the
    # filter for an explicit author_id of 0.
    if author_id is not None:
        filters.append(Post.author_id == author_id)
    if tag_names:
        query = query.join(Post.tags).where(Tag.name.in_(tag_names))
    if filters:
        query = query.where(and_(*filters))
    result = await db.execute(query)
    # The tag join can yield one row per matching tag; unique() collapses them.
    return result.scalars().unique().all()
# Aggregation
async def get_post_stats(db: AsyncSession) -> dict:
    """Return the post count and the average content length.

    The result is a dict with the keys ``"total"`` and ``"avg_length"``.
    """
    total_col = func.count(Post.id).label("total")
    avg_length_col = func.avg(func.length(Post.content)).label("avg_length")
    outcome = await db.execute(select(total_col, avg_length_col))
    stats = outcome.one()
    return {"total": stats.total, "avg_length": stats.avg_length}
ALEMBIC MIGRATIONS
Migration Setup
python
# alembic/env.py
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
import asyncio
from app.db.base import Base
from app.core.config import settings
# Alembic configuration object for the current invocation.
config = context.config

# Metadata passed to context.configure() by the migration runners below.
target_metadata = Base.metadata

# Point Alembic at the URL taken from the application settings.
config.set_main_option("sqlalchemy.url", str(settings.database_url))
def run_migrations_offline() -> None:
    """Run migrations with the context configured from a URL, not a connection."""
    offline_options = {
        "url": str(settings.database_url),
        "target_metadata": target_metadata,
        "literal_binds": True,
    }
    context.configure(**offline_options)
    with context.begin_transaction():
        context.run_migrations()
def do_run_migrations(connection):
    """Configure the context with *connection* and run the migrations.

    Synchronous on purpose: run_async_migrations() invokes it through
    ``connection.run_sync``.
    """
    context.configure(target_metadata=target_metadata, connection=connection)
    with context.begin_transaction():
        context.run_migrations()
async def run_async_migrations() -> None:
    """Create an async engine from the Alembic config and run migrations.

    The engine is built with ``NullPool`` and is always disposed, even when
    a migration raises.
    """
    connectable = async_engine_from_config(
        # Fall back to an empty mapping so a missing ini section does not
        # hand None to the engine factory.
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    try:
        async with connectable.connect() as connection:
            # Alembic's migration API is synchronous; run_sync bridges it.
            await connection.run_sync(do_run_migrations)
    finally:
        await connectable.dispose()
def run_migrations_online() -> None:
    """Run the async migration routine to completion on a new event loop."""
    migration_coro = run_async_migrations()
    asyncio.run(migration_coro)
ANTI-PATTERNS
❌ AVOID:
python
# N+1 queries: a relationship is touched once per row inside a loop
for user in users:
print(user.posts) # Each access = new query
# Sync session in async context
Session() # Wrong - use AsyncSession
# No session management: the statement runs outside any
# "async with" session block
await db.execute(query)
# Missing commit/rollback
✅ PREFER:
python
# Eager loading
users = await db.execute(
select(User).options(selectinload(User.posts))
)
# Proper async session
async with async_session_maker() as session:
await session.execute(query)
await session.commit()