AgentSkillsCN

SQLAlchemy Async

SQLAlchemy 异步编程

SKILL.md

SQLAlchemy Async Patterns

Modern SQLAlchemy 2.0 async patterns with FastAPI

PRINCIPLES

  1. Async by Default: Use AsyncSession for non-blocking I/O
  2. Session per Request: Create the session in a FastAPI dependency and close it when the request ends
  3. Type Safety: Use mapped_column and Mapped for type hints
  4. Repository Pattern: Keep database operations behind a CRUD (repository) layer
  5. Eager Loading: Prevent N+1 queries with selectinload/joinedload (under AsyncSession, implicit lazy loads raise errors instead of querying)

DATABASE SETUP

Async Engine Configuration

python
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from typing import AsyncGenerator

DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/dbname"

# Pool of 5 persistent connections, plus up to 10 temporary extras under load.
# pool_pre_ping checks each connection on checkout and replaces stale ones.
# echo=True logs every emitted SQL statement -- turn it off in production.
engine = create_async_engine(
    DATABASE_URL,
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=True,
)

# Session factory. expire_on_commit=False keeps already-loaded attributes
# readable after commit instead of expiring them (a later read would
# otherwise need a refresh query).
async_session_maker = async_sessionmaker(
    bind=engine,
    expire_on_commit=False,
    class_=AsyncSession,
)

class Base(DeclarativeBase):
    """Declarative base class; every ORM model registers its table on ``Base.metadata``."""

# FastAPI dependency
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` for the lifetime of a request (FastAPI dependency).

    When the generator is resumed normally, the transaction is committed.
    If an exception is thrown into it -- or the commit itself fails -- the
    transaction is rolled back and the exception propagates. Leaving the
    ``async with`` block closes the session either way.
    """
    async with async_session_maker() as request_session:
        try:
            yield request_session
            await request_session.commit()
        except Exception:
            await request_session.rollback()
            raise

# Initialize database
async def init_db() -> None:
    """Create every table registered on ``Base.metadata`` that does not exist yet.

    ``create_all`` is a synchronous API, so it runs through ``run_sync`` on
    the async connection; ``engine.begin()`` commits when the block exits.
    """
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

MODEL DEFINITION

SQLAlchemy 2.0 Style

python
from sqlalchemy import ForeignKey, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from datetime import datetime

class Base(DeclarativeBase):
    """Declarative base for the models below.

    NOTE(review): this re-declares ``Base``. If the snippets live in separate
    modules, import the single ``Base`` from the database-setup module
    instead: ``init_db`` and Alembic's ``target_metadata`` only see tables
    registered on the ``Base.metadata`` they reference.
    """

class TimestampMixin:
    """Add ``created_at`` / ``updated_at`` columns to any model that mixes it in.

    Both columns get a database-side ``now()`` default (so rows inserted
    outside the ORM are covered) as well as an ORM-side ``now()`` default.
    ``updated_at`` is additionally reset to ``now()`` on every ORM UPDATE.
    """

    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        default=func.now(),
        onupdate=func.now(),
    )

class User(TimestampMixin, Base):
    """User row in the ``users`` table; timestamps come from TimestampMixin."""

    __tablename__ = "users"
    
    id: Mapped[int] = mapped_column(primary_key=True)
    # unique + index: one row per address, and fast lookups such as
    # CRUDUser.get_by_email.
    email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
    hashed_password: Mapped[str] = mapped_column(String(255))
    # The `| None` in the annotation is what makes this column nullable.
    full_name: Mapped[str | None] = mapped_column(String(100))
    # ORM-side default only (no server_default): raw INSERTs issued outside
    # the ORM must supply a value for this NOT NULL column.
    is_active: Mapped[bool] = mapped_column(default=True)
    
    # Relationships
    # lazy="selectin": every query that loads Users also loads their posts
    # with one extra SELECT ... IN, so `user.posts` can be read under
    # AsyncSession without an implicit lazy load. Post.tags is selectin too,
    # so loading users cascades into loading their posts' tags.
    posts: Mapped[list["Post"]] = relationship(back_populates="author", lazy="selectin")

class Post(TimestampMixin, Base):
    """Row in the ``posts`` table; belongs to one User and has many Tags."""

    __tablename__ = "posts"
    
    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(200))
    content: Mapped[str] = mapped_column(Text)
    author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
    
    # Relationships
    # No lazy= option, so this uses the default lazy loading. Under
    # AsyncSession, reading `post.author` when that User is not already in the
    # session can raise (MissingGreenlet) instead of querying -- eager-load it,
    # as search_posts does with selectinload(Post.author).
    author: Mapped["User"] = relationship(back_populates="posts")
    # Many-to-many through the "post_tags" association table. The string name
    # is resolved when mappers are configured, so the Table may be defined
    # after this class.
    tags: Mapped[list["Tag"]] = relationship(
        secondary="post_tags",
        back_populates="posts",
        lazy="selectin",
    )

class Tag(Base):
    """Row in the ``tags`` table; tag names are unique."""

    __tablename__ = "tags"
    
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50), unique=True)
    
    # Reverse side of Post.tags. No lazy= option, so under AsyncSession
    # `tag.posts` must be eager-loaded (e.g. selectinload(Tag.posts)) before it
    # is read; an implicit lazy load can raise instead of querying.
    posts: Mapped[list["Post"]] = relationship(
        secondary="post_tags",
        back_populates="tags",
    )

# Association table
from sqlalchemy import Table, Column

# Association table behind the Post <-> Tag many-to-many (referenced by name
# as secondary="post_tags"). The composite primary key (post_id, tag_id)
# prevents the same tag from being attached to the same post twice.
post_tags = Table(
    "post_tags",
    Base.metadata,
    Column("post_id", ForeignKey("posts.id"), primary_key=True),
    Column("tag_id", ForeignKey("tags.id"), primary_key=True),
)

CRUD OPERATIONS

Generic CRUD Base

python
from typing import Generic, TypeVar, Type, Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel

# Type parameters for CRUDBase: the mapped model class it manages, and the
# Pydantic schemas accepted by create() and update() respectively.
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)

class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic async CRUD repository for a single ORM model.

    Methods flush but never commit: the transaction belongs to the caller
    (for example the ``get_db`` dependency, which commits at the end of the
    request). Every method assumes the model has an ``id`` primary-key column.
    """

    def __init__(self, model: Type[ModelType]):
        """Bind the repository to ``model``, the mapped class it operates on."""
        self.model = model

    async def get(self, db: AsyncSession, id: int) -> ModelType | None:
        """Return the row whose primary key is ``id``, or ``None`` if absent."""
        result = await db.execute(
            select(self.model).where(self.model.id == id)
        )
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
    ) -> Sequence[ModelType]:
        """Return one page of rows: at most ``limit`` rows after skipping ``skip``.

        Rows are ordered by primary key. SQL gives no ordering guarantee
        without ORDER BY, so OFFSET/LIMIT pages could otherwise overlap or
        miss rows between requests.
        """
        result = await db.execute(
            select(self.model)
            .order_by(self.model.id)
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()

    async def create(
        self,
        db: AsyncSession,
        *,
        obj_in: CreateSchemaType,
    ) -> ModelType:
        """Insert a row built from ``obj_in``'s fields and return the new object.

        ``flush`` emits the INSERT so database-generated values (``id``,
        server defaults) exist; ``refresh`` loads them onto the object so they
        can be read without an implicit lazy load, which AsyncSession does not
        perform.
        """
        obj_data = obj_in.model_dump()
        db_obj = self.model(**obj_data)
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType,
    ) -> ModelType:
        """Apply the fields explicitly set on ``obj_in`` to ``db_obj`` and return it.

        ``exclude_unset=True`` gives PATCH semantics: fields the client did not
        send are left untouched. The object is flushed and refreshed so
        values changed by the database (e.g. ``onupdate`` columns) are current.
        """
        update_data = obj_in.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def delete(self, db: AsyncSession, *, id: int) -> bool:
        """Mark the row with primary key ``id`` for deletion.

        Returns ``True`` if the row existed, ``False`` otherwise. The DELETE
        statement itself is emitted on the next flush/commit.
        """
        obj = await self.get(db, id)
        if obj:
            await db.delete(obj)
            return True
        return False

# Usage
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
    """User repository: the generic CRUD operations plus lookup by e-mail."""

    async def get_by_email(
        self,
        db: AsyncSession,
        email: str,
    ) -> User | None:
        """Return the user with this ``email``, or ``None`` if there is none."""
        stmt = select(User).filter_by(email=email)
        return (await db.execute(stmt)).scalar_one_or_none()


# Shared repository instance for the User model.
user_crud = CRUDUser(User)

QUERY PATTERNS

Efficient Queries

python
from sqlalchemy import select, and_, or_, func
from sqlalchemy.orm import selectinload, joinedload

# Eager loading to avoid N+1
async def get_user_with_posts(db: AsyncSession, user_id: int) -> User | None:
    """Fetch the user with ``user_id`` and their ``posts``, or ``None`` if missing.

    ``selectinload`` fetches the collection with one extra ``SELECT ... IN``
    query, so ``user.posts`` can be read afterwards without per-row queries
    and without an implicit lazy load under AsyncSession.
    """
    stmt = (
        select(User)
        .where(User.id == user_id)
        .options(selectinload(User.posts))
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()

# Complex filters
async def search_posts(
    db: AsyncSession,
    *,
    keyword: str | None = None,
    author_id: int | None = None,
    tag_names: list[str] | None = None,
) -> Sequence[Post]:
    """Search posts, combining every supplied criterion with AND.

    Args:
        db: Active async session.
        keyword: Case-insensitive substring matched against title or content.
            Matched literally: ``%`` and ``_`` in the input are escaped rather
            than acting as LIKE wildcards. ``None`` or ``""`` disables it.
        author_id: Only posts by this author. ``None`` disables the filter
            (``0`` is a real filter value, not "no filter").
        tag_names: Only posts carrying at least one of these tag names.
            ``None`` or an empty list disables the filter.

    Returns:
        Matching posts, each exactly once, with ``author`` and ``tags``
        eagerly loaded.
    """
    query = select(Post).options(
        selectinload(Post.author),
        selectinload(Post.tags),
    )

    filters = []
    if keyword:
        # Escape the escape char first, then the LIKE wildcards, so user text
        # such as "50%" matches literally instead of "50" + anything.
        escaped = (
            keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
        )
        pattern = f"%{escaped}%"
        filters.append(
            or_(
                Post.title.ilike(pattern, escape="/"),
                Post.content.ilike(pattern, escape="/"),
            )
        )
    if author_id is not None:
        filters.append(Post.author_id == author_id)
    if tag_names:
        # EXISTS subquery rather than a JOIN: a post with several matching
        # tags is still returned once, so no result de-duplication is needed.
        filters.append(Post.tags.any(Tag.name.in_(tag_names)))

    if filters:
        query = query.where(and_(*filters))

    result = await db.execute(query)
    return result.scalars().all()

# Aggregation
async def get_post_stats(db: AsyncSession) -> dict:
    """Return the post count and the mean length of post content.

    Returns:
        ``{"total": ..., "avg_length": ...}``. ``avg_length`` is ``None``
        when there are no posts (SQL AVG over zero rows is NULL); its numeric
        type depends on the database driver.
    """
    stmt = select(
        func.count(Post.id).label("total"),
        func.avg(func.length(Post.content)).label("avg_length"),
    )
    stats = (await db.execute(stmt)).one()
    return {"total": stats.total, "avg_length": stats.avg_length}

ALEMBIC MIGRATIONS

Migration Setup

python
# alembic/env.py
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
import asyncio

from app.db.base import Base
from app.core.config import settings

config = context.config

# Set up Python logging from the logger/handler sections of alembic.ini.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Alembic's Config is backed by ConfigParser, where "%" starts interpolation
# syntax: a URL-encoded password (e.g. "p%40ss") would raise ValueError unless
# every "%" is doubled. Reading the option back un-doubles it.
config.set_main_option(
    "sqlalchemy.url", str(settings.database_url).replace("%", "%%")
)

target_metadata = Base.metadata

def run_migrations_offline() -> None:
    """Generate migration SQL as a script instead of executing it.

    Only a URL is configured (no Engine or DBAPI connection), and
    ``literal_binds=True`` renders parameter values inline in the output.
    """
    url = str(settings.database_url)
    context.configure(
        target_metadata=target_metadata,
        url=url,
        literal_binds=True,
    )
    with context.begin_transaction():
        context.run_migrations()

def do_run_migrations(connection):
    """Run the migrations on an already-open connection.

    Invoked via ``AsyncConnection.run_sync``, which passes in a synchronous
    ``Connection`` facade.
    """
    context.configure(target_metadata=target_metadata, connection=connection)
    with context.begin_transaction():
        context.run_migrations()

async def run_async_migrations() -> None:
    """Build a one-off async engine from the ini section and run migrations on it.

    NullPool means connections are not pooled; the engine is disposed once
    the migrations have run.
    """
    ini_options = config.get_section(config.config_ini_section)
    migration_engine = async_engine_from_config(
        ini_options,
        poolclass=pool.NullPool,
        prefix="sqlalchemy.",
    )
    async with migration_engine.connect() as conn:
        await conn.run_sync(do_run_migrations)
    await migration_engine.dispose()

def run_migrations_online() -> None:
    """Run migrations against the live database through the async engine.

    ``asyncio.run`` starts a new event loop, so this must be called from
    synchronous code (it raises if an event loop is already running).
    """
    asyncio.run(run_async_migrations())


# Alembic executes env.py as a script. Without this dispatch the functions
# above are only defined, never called, and no migration runs.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

ANTI-PATTERNS

AVOID:

python
# N+1 queries
for user in users:
    print(user.posts)  # Lazy relationship: one extra query per user (N+1); under AsyncSession an implicit lazy load raises instead

# Sync session in async context
Session()  # Wrong - a sync Session blocks the event loop on every query; use AsyncSession

# No session management
await db.execute(query)
# Missing commit/rollback: uncommitted changes are discarded when the session closes

PREFER:

python
# Eager loading
users = await db.execute(
    select(User).options(selectinload(User.posts))
)

# Proper async session
async with async_session_maker() as session:
    await session.execute(query)
    await session.commit()