Skip to content

Instantly share code, notes, and snippets.

@matutter
Created January 26, 2020 05:30

Revisions

  1. matutter created this gist Jan 26, 2020.
    68 changes: 68 additions & 0 deletions poly_delete.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,68 @@
    from sqlalchemy import create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy import Column, String, Integer, ForeignKey
    from sqlalchemy.orm import relationship
    from sqlalchemy.orm import sessionmaker

    engine = create_engine('sqlite:///:memory:', echo=False)
    session = sessionmaker(bind=engine)()
    Base = declarative_base()

    class Queue(Base):
    __tablename__ = 'queue'
    id = Column(Integer, primary_key=True)
    tasks = relationship("Task", cascade="all, delete-orphan")
    def __repr__(self):
    return f'{type(self).__name__}({self.id})'

    class Task(Base):
    __tablename__ = 'task'
    id = Column(Integer, primary_key=True)
    type = Column(String, server_default='base')

    queue_id = Column(Integer, ForeignKey('queue.id'))

    __mapper_args__ = {
    'polymorphic_identity': __tablename__,
    'polymorphic_on': type
    }
    def __repr__(self):
    return f'{type(self).__name__}({self.id}, {self.type})'

    class OtherTask(Task):
    __tablename__ = 'otask'
    id = Column(Integer, ForeignKey('task.id'), primary_key=True)
    level = Column(Integer)

    __mapper_args__ = {
    'polymorphic_identity': __tablename__
    }

    class NextTask(Task):
    __tablename__ = 'nexttask'
    id = Column(Integer, ForeignKey('task.id'), primary_key=True)
    level = Column(Integer)

    __mapper_args__ = {
    'polymorphic_identity': __tablename__
    }

    Base.metadata.create_all(engine)

    queue = Queue(id=0)
    queue.tasks.append(Task(id=0))
    queue.tasks.append(OtherTask(id=1, level=9))
    queue.tasks.append(NextTask(id=2, level=4))

    session.add(queue)
    session.commit()

    for q in session.query(Queue):
    print(q)
    for task in q.tasks:
    print(task)

    print(session.query(Task).all())
    session.delete(queue)
    session.commit()
    print(session.query(Task).all())