Skip to content

Instantly share code, notes, and snippets.

@imankulov
Last active April 29, 2025 00:17
Show Gist options
  • Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
Using pydantic models as SQLAlchemy JSON fields (convert between JSON and pydantic.BaseModel subclasses)
#!/usr/bin/env ipython -i
import datetime
import json
from typing import Optional
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB
from pydantic import BaseModel, Field, parse_obj_as
from pydantic.json import pydantic_encoder
# --------------------------------------------------------------------------------------
# Define pydantic-alchemy specific types (once per application)
# --------------------------------------------------------------------------------------
class PydanticType(sa.types.TypeDecorator):
    """Store a pydantic model (or any pydantic-parsable type) in a JSON column.

    SAVING:
    - Uses the SQLAlchemy JSON type under the hood (JSONB on PostgreSQL).
    - Accepts the pydantic model and converts it to a dict on save.
    - The SQLAlchemy engine JSON-encodes the dict to a string (configure the
      engine's ``json_serializer`` so values such as datetimes can be encoded).

    RETRIEVING:
    - Pulls the string from the database.
    - The SQLAlchemy engine JSON-decodes the string to a dict.
    - Uses the dict to create a pydantic model.

    Only ``None`` is treated as "no value" in either direction. Falsy but
    valid data, e.g. the ``{}`` produced by a model without fields or an
    empty list for ``List[SomeModel]``, is stored and parsed like any other
    value instead of collapsing to ``None``.
    """

    # If you work with PostgreSQL, you can consider using
    # sqlalchemy.dialects.postgresql.JSONB instead of a
    # generic sa.types.JSON (load_dialect_impl below already does this).
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    # Opt in to SQLAlchemy's statement cache (1.4+ warns when this is unset).
    # The only constructor argument, ``pydantic_type``, is a class or typing
    # construct and therefore hashable, which the cache key requires.
    cache_ok = True

    def __init__(self, pydantic_type):
        """Create the column type.

        Args:
            pydantic_type: Target type for loaded values, e.g. a ``BaseModel``
                subclass or ``List[SomeModel]``; passed to ``parse_obj_as``
                when reading rows.
        """
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        """Use JSONB for PostgreSQL and the generic JSON type elsewhere."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        """Convert the Python value into JSON-ready data before saving."""
        # Compare against None explicitly: a truthiness test would also
        # discard valid falsy values such as an empty list.
        if value is None:
            return None
        if isinstance(value, BaseModel):
            # ``.dict()`` may still contain rich values (datetimes, ...); the
            # engine's ``json_serializer`` is responsible for encoding them.
            #
            # If you use FastAPI, you can replace ``value.dict()`` with their
            # jsonable_encoder(), e.g.:
            #   from fastapi.encoders import jsonable_encoder
            #   return jsonable_encoder(value)
            return value.dict()
        # Non-model values (e.g. a list of models when pydantic_type is
        # ``List[SomeModel]``) go to the engine's ``json_serializer`` as-is.
        return value

    def process_result_value(self, value, dialect):
        """Parse the JSON-decoded value back into ``pydantic_type``."""
        # ``{}`` (a model with no fields, or all fields defaulted) and ``[]``
        # are valid payloads and must not be turned into None.
        if value is None:
            return None
        return parse_obj_as(self.pydantic_type, value)
def json_serializer(*args, **kwargs) -> str:
    """Drop-in replacement for ``json.dumps`` used as the engine's serializer.

    Whatever the standard JSON encoder cannot handle itself (for example
    the rich values left in by ``BaseModel.dict()``) is delegated to
    ``pydantic_encoder`` through the ``default`` hook. Passing your own
    ``default`` keyword is not supported and raises ``TypeError``.
    """
    options = dict(default=pydantic_encoder, **kwargs)
    return json.dumps(*args, **options)
# --------------------------------------------------------------------------------------
# Configure the SQLAlchemy engine, session factory and declarative base (once per
# application). The important part is passing ``json_serializer`` when creating the
# engine, so the values produced by PydanticType are encoded with pydantic_encoder.
# --------------------------------------------------------------------------------------
engine = sa.create_engine(
    "sqlite:///:memory:",
    json_serializer=json_serializer,
)
Session = sessionmaker(
    bind=engine,
    future=True,
    # Do not expire loaded attributes on commit, so objects stay readable afterwards.
    expire_on_commit=False,
)
Base = declarative_base()
# --------------------------------------------------------------------------------------
# Define your Pydantic and SQLAlchemy models (as many as needed)
# --------------------------------------------------------------------------------------
# Per-user settings, stored as JSON in the ``users.settings`` column.
# (Documented with a comment rather than a docstring: pydantic copies a class
# docstring into the generated JSON schema's "description".)
class UserSettings(BaseModel):
    # When no value is supplied, defaults to the current local (naive) time
    # at the moment the model is instantiated.
    notify_at: datetime.datetime = Field(
        default_factory=datetime.datetime.now,
    )
class User(Base):
    """ORM mapping for rows of the ``users`` table."""

    __tablename__ = "users"

    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(
        sa.String,
        doc="User name",
        comment="User name",
    )
    # A UserSettings instance in Python, JSON in the database; may be NULL.
    settings: Optional[UserSettings] = sa.Column(
        PydanticType(UserSettings),
        nullable=True,
    )
# --------------------------------------------------------------------------------------
# Create all tables registered on Base.metadata (once per application)
# --------------------------------------------------------------------------------------
Base.metadata.create_all(bind=engine)
# --------------------------------------------------------------------------------------
# Usage example, written in the 2.0 querying style with select().
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style
# The session is left open; the shebang runs this file under ``ipython -i`` so the
# objects below can be explored interactively.
# --------------------------------------------------------------------------------------
session = Session()

# Persist one user whose settings column holds a pydantic model.
user = User(name="user", settings=UserSettings())
session.add(user)
session.commit()

# Load the first user back with a select() statement.
same_user = session.execute(sa.select(User)).scalars().first()
@pdmtt
Copy link

pdmtt commented Apr 28, 2025

To avoid a deprecation warning as of Pydantic 2.5, you will need to use the following:

def process_result_value(self, value, dialect):
    """Rebuild the pydantic model from the JSON value read from the database.

    Only ``None`` means "no value". An empty dict is valid model data (for
    example, a model whose fields all have defaults) and is turned into a
    model instead of collapsing to ``None``.
    """
    if value is None:
        return None
    return self.pydantic_type(**value)

Building on top of the comment above, you will need to use the following to avoid a deprecation warning as of Pydantic V2:

   @override
    def process_bind_param(
        self,
        value: "BaseModel | None",
        dialect: "Dialect",
    ) -> "dict[str, Any] | None":
        if value is None:
            return None

        if not isinstance(value, BaseModel):
            raise TypeError(f'Value "{value!r}" is not a pydantic model.') 

        return value.model_dump(mode="json", exclude_unset=True)

    @override
    def process_result_value(
        self,
        value: "dict[str, Any] | None",
        dialect: "Dialect",
    ) -> "BaseModel | None":
        """Rebuild the pydantic model from the JSON value read from the database.

        Only ``None`` means "no value". An empty dict is valid model data
        (every field defaulted) and is turned into a model rather than
        collapsing to ``None``.
        """
        if value is None:
            return None
        return self.pydantic_type(**value)

I also added type hints, and I check that the value really is an instance of a BaseModel subclass, because Python does not enforce type hints at runtime.
Model.dict() is deprecated. Use Model.model_dump(mode="json") instead: it returns JSON-compatible values (e.g. ISO-formatted strings for datetimes), so you no longer need a custom json serializer when creating the engine.

Edit: I forked the gist and updated the code to use the current syntax and to add typehints.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment