@psteinroe
Created April 14, 2025 16:13
initialise your declarative schema
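This gist is a small Rust program that bootstraps a declarative schema layout: it reads a single `schema.sql`, splits and parses every statement with `pg_query`, and writes each object (schemas, tables, views, functions, triggers, policies, indexes, enums, composite types, foreign keys, aggregates, and comments) into a per-schema directory tree, chaining the files together with `-- atlas:import` directives in `index.sql` files. Related tables are bucketed into subdirectories via the glob patterns in `GROUPS` below. First, `Cargo.toml`: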
[package]
name = "move-schema"
version = "0.1.0"
edition = "2024"

[dependencies]
pg_query = "6.1.0"
globset = "0.4.16"
once_cell = "1.21.3"
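The rest of the gist is `main.rs`. The flow is: parse `../../supabase/schema.sql` into a list of `Node`s, wipe `../../supabase/schema`, then write one file per object.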
use once_cell::sync::Lazy;
use std::collections::HashSet;
use std::path::Path;
use std::{fs, io::prelude::*};
use globset::{Glob, GlobSetBuilder};
use pg_query::NodeEnum;
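
// A Group collects related tables into one directory. `patterns` are globs matched
// against table names, and `parent` nests the group inside another group's directory.
// Order in GROUPS matters: match_group returns the first group whose pattern matches.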
struct Group {
    name: &'static str,
    patterns: &'static [&'static str],
    parent: Option<&'static str>,
}
const GROUPS: [Group; 25] = [
    Group { name: "gallery", patterns: &["category", "goal"], parent: None },
    Group { name: "channel", patterns: &["channel*"], parent: None },
    Group {
        name: "common",
        patterns: &["timezone", "locale", "language", "industry", "country", "continent"],
        parent: None,
    },
    Group {
        name: "custom_field",
        patterns: &["custom_field", "deal_custom_field", "appointment_custom_field", "contact_custom_field"],
        parent: None,
    },
    Group { name: "appointment", patterns: &["appointment*"], parent: None },
    Group { name: "deal", patterns: &["deal*"], parent: None },
    Group { name: "form", patterns: &["form*"], parent: None },
    Group { name: "journey", patterns: &["journey*"], parent: None },
    Group { name: "rule", patterns: &["rule*"], parent: None },
    Group { name: "web_widget", patterns: &["web_widget*"], parent: None },
    Group { name: "review_channel", patterns: &["review*"], parent: None },
    Group { name: "conversation_tag", patterns: &["conversation_tag", "tag"], parent: Some("conversation") },
    Group { name: "template", patterns: &["template*", "provider_template_approval"], parent: None },
    Group { name: "campaign", patterns: &["campaign*"], parent: None },
    Group { name: "webhook", patterns: &["webhook*"], parent: None },
    Group { name: "message", patterns: &["message*", "recipient"], parent: Some("conversation") },
    Group { name: "sendout", patterns: &["sendout*"], parent: None },
    Group { name: "marketing_channel", patterns: &["marketing_channel", "marketing_subscription"], parent: None },
    Group { name: "contact_list", patterns: &["contact_list*"], parent: None },
    Group { name: "contact", patterns: &["contact*"], parent: None },
    Group { name: "comment", patterns: &["comment*"], parent: Some("conversation") },
    Group { name: "conversation", patterns: &["conversation*"], parent: None },
    Group { name: "inbox", patterns: &["inbox*"], parent: None },
    Group { name: "team", patterns: &["team*", "inbox_access"], parent: None },
    Group { name: "employee", patterns: &["employee", "favorite_template", "pinned*"], parent: None },
];
// Build matchers once using Lazy static initialization
static MATCHERS: Lazy<Vec<(&'static str, globset::GlobSet)>> = Lazy::new(|| {
    GROUPS
        .iter()
        .map(|group| {
            let mut builder = GlobSetBuilder::new();
            for pattern in group.patterns {
                builder.add(Glob::new(pattern).unwrap());
            }
            (group.name, builder.build().unwrap())
        })
        .collect()
});
/// Match a value to the corresponding group name based on glob patterns.
fn match_group(value: &str) -> Option<&'static str> {
    for (name, matcher) in MATCHERS.iter() {
        if matcher.is_match(value) {
            return Some(name);
        }
    }
    None
}
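
// For example, a (hypothetical) table named "deal_stage" matches the "deal" group
// via its "deal*" glob, "recipient" matches "message" via the explicit pattern,
// and a name that matches no pattern returns None.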
#[derive(Debug)]
struct SchemaNode {
    name: String,
    sql: String,
}

#[derive(Debug)]
struct TableNode {
    schema: String,
    name: String,
    sql: String,
}

#[derive(Debug)]
struct FunctionNode {
    schema: String,
    name: String,
    sql: String,
}

#[derive(Debug)]
struct EnablePolicyNode {
    schema: String,
    table: String,
    sql: String,
}

#[derive(Debug)]
struct PolicyNode {
    schema: String,
    name: String,
    table: String,
    sql: String,
}

#[derive(Debug)]
struct IndexNode {
    schema: String,
    name: String,
    table: String,
    sql: String,
}

#[derive(Debug)]
struct ViewNode {
    schema: String,
    name: String,
    sql: String,
}

#[derive(Debug)]
struct TriggerFunctionNode {
    schema: String,
    name: String,
    sql: String,
}

#[derive(Debug)]
struct TriggerNode {
    schema: String,
    name: String,
    table: String,
    function: String,
    sql: String,
}

#[derive(Debug)]
struct EnumNode {
    schema: String,
    name: String,
    sql: String,
}

#[derive(Debug)]
struct CompositeTypeNode {
    schema: String,
    name: String,
    sql: String,
}

#[derive(Debug)]
struct ForeignKeyNode {
    constraint_name: String,
    source_schema: String,
    source_table: String,
    target_schema: String,
    target_table: String,
    sql: String,
}

#[derive(Debug)]
struct AggregateNode {
    schema: String,
    name: String,
    sql: String,
}

#[derive(Debug)]
enum Node {
    Schema(SchemaNode),
    Table(TableNode),
    Function(FunctionNode),
    EnablePolicy(EnablePolicyNode),
    Policy(PolicyNode),
    Index(IndexNode),
    View(ViewNode),
    TriggerFunction(TriggerFunctionNode),
    Trigger(TriggerNode),
    Comment(pg_query::protobuf::CommentStmt),
    EnumNode(EnumNode),
    CompositeType(CompositeTypeNode),
    ForeignKey(ForeignKeyNode),
    Aggregate(AggregateNode),
}
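
// Entry point: read the declarative schema, parse it into Nodes, then write the
// split files under ../../supabase/schema.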
fn main() {
    let schema_path = Path::new("../../supabase").join("schema.sql");
    let schema = std::fs::read_to_string(schema_path).expect("Failed to read schema file");

    let mut nodes: Vec<Node> = Vec::new();
    pg_query::split_with_parser(&schema)
        .expect("error parsing")
        .iter()
        .for_each(|sql| {
            parse(sql, &mut nodes);
        });

    // structure:
    // /schema1
    //   index.sql
    //   /tables
    //   /functions
    // /schema2
    //   index.sql
    let out_dir = Path::new("../../supabase/schema");
    let _ = fs::remove_dir_all(out_dir);

    nodes.iter().for_each(|n| match n {
        Node::Schema(s) => {
            let setup_file_path = out_dir.join(&s.name).join("setup.sql");
            let index_file_path = out_dir.join("index.sql");
            // ensure it ends with ";"
            let content = if s.sql.ends_with(';') {
                s.sql.to_string()
            } else {
                format!("{};", s.sql)
            };
            append_to_file(&setup_file_path, &content);
            append_to_file(
                &index_file_path,
                format!("-- atlas:import {}/setup.sql", &s.name).as_str(),
            );
            append_to_file(
                &index_file_path,
                format!("-- atlas:import {}/index.sql", &s.name).as_str(),
            );
        }
        Node::Table(t) => {
            save_table_object(&t.name, &t.name, &t.schema, &t.sql, "tables", out_dir, &GROUPS);
        }
        Node::TriggerFunction(tf) => {
            // Collect the tables whose triggers call this function.
            let mut tables = HashSet::new();
            for node in nodes.iter() {
                if let Node::Trigger(t) = node {
                    if t.function == tf.name {
                        tables.insert(t.table.clone());
                    }
                }
            }
            let trigger_tables = tables.iter().collect::<Vec<_>>();
            if trigger_tables.len() != 1 {
                // Shared by several tables (or unused): store at the schema level.
                save_object(&tf.name, &tf.schema, &tf.sql, "triggers", out_dir);
            } else {
                // Used by exactly one table: store next to that table.
                let trigger_table = trigger_tables.first().unwrap();
                save_table_object(&tf.name, trigger_table, &tf.schema, &tf.sql, "triggers", out_dir, &GROUPS);
            }
        }
        Node::Trigger(t) => {
            save_table_object(&t.function, &t.table, &t.schema, &t.sql, "triggers", out_dir, &GROUPS);
        }
        Node::Function(f) => {
            save_object(&f.name, &f.schema, &f.sql, "functions", out_dir);
        }
        Node::EnablePolicy(p) => {
            save_table_object("enable", &p.table, &p.schema, &p.sql, "policies", out_dir, &GROUPS);
        }
        Node::Policy(p) => {
            save_table_object(&p.name, &p.table, &p.schema, &p.sql, "policies", out_dir, &GROUPS);
        }
        Node::Index(i) => {
            save_table_object(&i.name, &i.table, &i.schema, &i.sql, "indexes", out_dir, &GROUPS);
        }
        Node::View(v) => {
            save_table_object(&v.name, &v.name, &v.schema, &v.sql, "views", out_dir, &GROUPS);
        }
        Node::EnumNode(e) => {
            save_object(&e.name, &e.schema, &e.sql, "enums", out_dir);
        }
        Node::CompositeType(c) => {
            save_object(&c.name, &c.schema, &c.sql, "types", out_dir);
        }
        Node::ForeignKey(fk) => {
            save_table_object(
                &fk.constraint_name,
                &fk.source_table,
                &fk.source_schema,
                &fk.sql,
                "foreign_keys",
                out_dir,
                &GROUPS,
            );
        }
        Node::Aggregate(a) => {
            save_object(&a.name, &a.schema, &a.sql, "aggregates", out_dir);
        }
        Node::Comment(c) => match c.objtype() {
            pg_query::protobuf::ObjectType::ObjectColumn => {
                let list = &c.object.clone().map(|n| n.node.clone()).unwrap().unwrap();
                if let pg_query::NodeEnum::List(l) = list {
                    let items = l.items.iter().map(|n| get_sval(&n.node)).collect::<Vec<_>>();
                    if items.len() != 3 {
                        panic!("not three items in list");
                    }
                    let schema = items.first().unwrap();
                    let table_name = items.get(1).unwrap();
                    let column_name = items.get(2).unwrap();
                    // check if the comment target is a table or a view
                    let entity_type = nodes.iter().find_map(|n| {
                        if let Node::Table(t) = n {
                            if t.name == *table_name && t.schema == *schema {
                                Some("tables")
                            } else {
                                None
                            }
                        } else if let Node::View(v) = n {
                            if v.name == *table_name && v.schema == *schema {
                                Some("views")
                            } else {
                                None
                            }
                        } else {
                            None
                        }
                    });
                    save_table_object(
                        table_name,
                        table_name,
                        schema,
                        format!(
                            "COMMENT ON COLUMN \"{}\".\"{}\".\"{}\" IS '{}';",
                            schema, table_name, column_name, c.comment
                        )
                        .as_str(),
                        entity_type.expect("no table or view found"),
                        out_dir,
                        &GROUPS,
                    );
                }
            }
            pg_query::protobuf::ObjectType::ObjectFunction => {
                let list = &c.object.clone().map(|n| n.node.clone()).unwrap().unwrap();
                if let pg_query::NodeEnum::ObjectWithArgs(obj) = list {
                    let items = obj.objname.iter().map(|n| get_sval(&n.node)).collect::<Vec<_>>();
                    if items.len() != 2 {
                        panic!("not two items in list");
                    }
                    let schema = items.first().unwrap();
                    let function_name = items.last().unwrap();
                    save_object(
                        function_name,
                        schema,
                        format!(
                            "COMMENT ON FUNCTION \"{}\".\"{}\" IS '{}';",
                            schema, function_name, c.comment
                        )
                        .as_str(),
                        "functions",
                        out_dir,
                    );
                }
            }
            pg_query::protobuf::ObjectType::ObjectSchema => {
                let schema_name = get_sval(&c.object.clone().map(|n| n.node.clone()).unwrap());
                let setup_file_path = out_dir.join(&schema_name).join("setup.sql");
                append_to_file(
                    &setup_file_path,
                    format!("COMMENT ON SCHEMA \"{}\" IS '{}';", schema_name, c.comment).as_str(),
                );
            }
            pg_query::protobuf::ObjectType::ObjectTable => {
                let list = &c.object.clone().map(|n| n.node.clone()).unwrap().unwrap();
                if let pg_query::NodeEnum::List(l) = list {
                    let items = l.items.iter().map(|n| get_sval(&n.node)).collect::<Vec<_>>();
                    if items.len() != 2 {
                        panic!("not two items in list");
                    }
                    let schema = items.first().unwrap();
                    let table_name = items.last().unwrap();
                    // check if the comment target is a table or a view
                    let node = nodes.iter().find(|n| {
                        if let Node::Table(t) = n {
                            t.name == *table_name && t.schema == *schema
                        } else if let Node::View(v) = n {
                            v.name == *table_name && v.schema == *schema
                        } else {
                            false
                        }
                    });
                    if let Some(Node::Table(_t)) = node {
                        save_table_object(
                            table_name,
                            table_name,
                            schema,
                            format!(
                                "COMMENT ON TABLE \"{}\".\"{}\" IS '{}';",
                                schema, table_name, c.comment
                            )
                            .as_str(),
                            "tables",
                            out_dir,
                            &GROUPS,
                        );
                    } else if let Some(Node::View(_v)) = node {
                        save_table_object(
                            table_name,
                            table_name,
                            schema,
                            format!(
                                "COMMENT ON VIEW \"{}\".\"{}\" IS '{}';",
                                schema, table_name, c.comment
                            )
                            .as_str(),
                            "views",
                            out_dir,
                            &GROUPS,
                        );
                    } else {
                        panic!("no table or view found for {}.{}", schema, table_name);
                    }
                }
            }
            // Every other COMMENT ON target is unhandled; add arms as needed.
            other => todo!("unsupported comment object type: {:?}", other),
        },
    })
}
// Helper function to get the full group path for a matched group,
// e.g. get_group_path("message", &GROUPS) yields ["conversation", "message"].
fn get_group_path(group_name: &'static str, groups: &[Group]) -> Vec<&'static str> {
    let mut path = Vec::new();
    let mut current = group_name;
    // First add the matched group
    path.push(current);
    // Then walk up the chain of parents
    while let Some(group) = groups.iter().find(|g| g.name == current) {
        match group.parent {
            Some(parent) => {
                current = parent;
                path.push(current);
            }
            // No parent, we're done
            None => break,
        }
    }
    // Reverse to get parent -> child order
    path.reverse();
    path
}
fn save_object(name: &str, schema: &str, c: &str, entity_type: &str, out_dir: &Path) {
    let object_type_index = out_dir.join(schema).join("index.sql");
    append_to_file(
        &object_type_index,
        format!("-- atlas:import {}/index.sql", entity_type).as_str(),
    );
    // ensure it ends with ";"
    let content = if c.ends_with(';') {
        c.to_string()
    } else {
        format!("{};", c)
    };
    let file_name = format!("{}.sql", name);
    let base_dir = out_dir.join(schema).join(entity_type);
    let main_index_path = base_dir.join("index.sql");
    // Save the actual entity file
    let file_path = base_dir.join(&file_name);
    append_to_file(&file_path, &content);
    // Add import to the main index.sql
    append_to_file(
        &main_index_path,
        format!("-- atlas:import {}", &file_name).as_str(),
    );
}
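
// Example (hypothetical table name): a table "deal_stage" in schema "public" lands in
// public/tables/deal/deal_stage.sql, with "-- atlas:import" lines appended to
// public/index.sql, public/tables/index.sql, and public/tables/deal/index.sql.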
fn save_table_object(
    name: &str,
    table_name: &str,
    schema: &str,
    c: &str,
    entity_type: &str,
    out_dir: &Path,
    groups: &[Group],
) {
    // register the entity type in schema/index.sql
    let object_type_index = out_dir.join(schema).join("index.sql");
    append_to_file(
        &object_type_index,
        format!("-- atlas:import {}/index.sql", entity_type).as_str(),
    );
    // ensure it ends with ";"
    let content = if c.ends_with(';') {
        c.to_string()
    } else {
        format!("{}\n;", c)
    };
    let file_name = format!("{}.sql", name);
    let base_dir = out_dir.join(schema).join(entity_type);
    let main_index_path = base_dir.join("index.sql");
    // Try to match the entity to a group
    let group_match = match_group(table_name);
    if group_match.is_some() || entity_type != "tables" {
        let mut group_path = vec![];
        if let Some(group_name) = group_match {
            // Get the full path of groups (parent -> child)
            group_path.append(&mut get_group_path(group_name, groups));
        }
        if entity_type != "tables" && entity_type != "views" {
            group_path.push(table_name);
        }
        // Walk down the group hierarchy
        let mut current_dir = base_dir;
        let mut current_index_path = main_index_path;
        for &group in group_path.iter() {
            let next_dir = current_dir.join(group);
            let next_index_path = next_dir.join("index.sql");
            // Create reference from parent to child index
            append_to_file(
                &current_index_path,
                format!("-- atlas:import {}/index.sql", group).as_str(),
            );
            // Update current paths for next iteration
            current_dir = next_dir;
            current_index_path = next_index_path;
        }
        // Save the actual entity file in the deepest group directory
        let file_path = current_dir.join(&file_name);
        append_to_file(&file_path, &content);
        // Add import to the deepest group's index.sql
        append_to_file(
            &current_index_path,
            format!("-- atlas:import {}", &file_name).as_str(),
        );
    } else {
        // No group matched, save directly in the base directory
        let file_path = base_dir.join(&file_name);
        append_to_file(&file_path, &content);
        append_to_file(
            &main_index_path,
            format!("-- atlas:import {}", &file_name).as_str(),
        );
    }
}
fn append_to_file(path: &Path, content: &str) {
    // Create parent directories if they don't exist
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).expect("Failed to create parent directories");
    }
    // Check if the file exists and if the content is already in it
    let file_exists = path.exists();
    let content_exists = if file_exists {
        match fs::read_to_string(path) {
            Ok(existing_content) => existing_content.contains(content),
            Err(_) => false,
        }
    } else {
        false
    };
    // Only append if the content doesn't already exist
    if !content_exists {
        let mut file = std::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(path)
            .expect("Failed to open file");
        writeln!(file, "{}", content).expect("Failed to write to file");
    }
}
fn parse(sql: &str, nodes: &mut Vec<Node>) {
    let node = parse_sql(sql);
    match node {
        pg_query::NodeEnum::CreateSchemaStmt(n) => {
            nodes.push(Node::Schema(SchemaNode {
                name: n.schemaname.clone(),
                sql: sql.to_string(),
            }));
        }
        pg_query::NodeEnum::CommentStmt(n) => {
            nodes.push(Node::Comment(*n));
        }
        pg_query::NodeEnum::CreateExtensionStmt(_) => {
            // ignore CREATE EXTENSION statements
        }
        pg_query::NodeEnum::CreateEnumStmt(n) => {
            let names = n.type_name.iter().map(|n| get_sval(&n.node)).collect::<Vec<_>>();
            // an unqualified name means the type lives in "public"
            let schema = if names.len() > 1 {
                names.first().unwrap().to_string()
            } else {
                "public".to_string()
            };
            let type_name = names.last().unwrap().to_string();
            nodes.push(Node::EnumNode(EnumNode {
                schema,
                name: type_name,
                sql: sql.to_string(),
            }));
        }
        pg_query::NodeEnum::DefineStmt(n) => match n.kind() {
            pg_query::protobuf::ObjectType::ObjectAggregate => {
                let names = n.defnames.iter().map(|n| get_sval(&n.node)).collect::<Vec<_>>();
                let schema = if names.len() > 1 {
                    names.first().unwrap().to_string()
                } else {
                    "public".to_string()
                };
                let aggregate_name = names.last().unwrap().to_string();
                nodes.push(Node::Aggregate(AggregateNode {
                    schema,
                    name: aggregate_name,
                    sql: sql.to_string(),
                }));
            }
            _ => panic!("Unsupported define statement"),
        },
        pg_query::NodeEnum::CompositeTypeStmt(n) => {
            let name = n.typevar.expect("no typevar");
            nodes.push(Node::CompositeType(CompositeTypeNode {
                schema: name.schemaname.clone(),
                name: name.relname.clone(),
                sql: sql.to_string(),
            }));
        }
        pg_query::NodeEnum::ViewStmt(n) => {
            let rel = n.view.expect("no relation");
            nodes.push(Node::View(ViewNode {
                schema: rel.schemaname.clone(),
                name: rel.relname.clone(),
                sql: sql.to_string(),
            }));
        }
        pg_query::NodeEnum::CreatePolicyStmt(n) => {
            let table = n.table.expect("no table");
            nodes.push(Node::Policy(PolicyNode {
                schema: table.schemaname.clone(),
                name: n.policy_name.clone(),
                table: table.relname.clone(),
                sql: sql.to_string(),
            }));
        }
        pg_query::NodeEnum::CreateStmt(n) => {
            let rel = n.relation.expect("no relation");
            nodes.push(Node::Table(TableNode {
                schema: rel.schemaname.clone(),
                name: rel.relname.clone(),
                sql: sql.to_string(),
            }));
        }
        pg_query::NodeEnum::CreateTrigStmt(n) => {
            let rel = n.relation.expect("no relation");
            let func_name = n.funcname.iter().map(|n| get_sval(&n.node)).collect::<Vec<_>>();
            nodes.push(Node::Trigger(TriggerNode {
                schema: rel.schemaname.clone(),
                name: n.trigname.clone(),
                table: rel.relname.clone(),
                function: func_name.last().unwrap().to_string(),
                sql: sql.to_string(),
            }));
        }
        pg_query::NodeEnum::CreateFunctionStmt(n) => {
            let func_name = n.funcname.iter().map(|n| get_sval(&n.node)).collect::<Vec<_>>();
            let schema = if func_name.len() > 1 {
                func_name.first().unwrap().to_string()
            } else {
                "public".to_string()
            };
            let function_name = func_name.last().unwrap().to_string();
            // functions returning "trigger" are filed next to the tables they fire on
            let is_trigger = n
                .return_type
                .as_ref()
                .expect("no return type")
                .names
                .iter()
                .any(|n| get_sval(&n.node) == "trigger");
            if is_trigger {
                nodes.push(Node::TriggerFunction(TriggerFunctionNode {
                    schema,
                    name: function_name,
                    sql: sql.to_string(),
                }));
            } else {
                nodes.push(Node::Function(FunctionNode {
                    schema,
                    name: function_name,
                    sql: sql.to_string(),
                }));
            }
        }
        pg_query::NodeEnum::IndexStmt(n) => {
            let rel = n.relation.expect("no relation");
            nodes.push(Node::Index(IndexNode {
                schema: rel.schemaname.clone(),
                name: n.idxname.clone(),
                table: rel.relname.clone(),
                sql: sql.to_string(),
            }));
        }
        pg_query::NodeEnum::AlterTableStmt(n) => {
            let rel = n.relation.expect("no relation");
            let schema = rel.schemaname.clone();
            let table_name = rel.relname.clone();
            let number_of_commands = n.cmds.len();
            let cmd = n.cmds.first().expect("no command").node.clone().unwrap();
            match &cmd {
                pg_query::NodeEnum::AlterTableCmd(c) => match c.subtype() {
                    pg_query::protobuf::AlterTableType::AtEnableRowSecurity => {
                        nodes.push(Node::EnablePolicy(EnablePolicyNode {
                            schema,
                            table: table_name,
                            sql: sql.to_string(),
                        }));
                    }
                    pg_query::protobuf::AlterTableType::AtAddConstraint => {
                        if number_of_commands > 1 {
                            // Split a multi-command ALTER TABLE into one statement per
                            // ADD CONSTRAINT and re-parse each of them individually.
                            let commands = sql[sql.find("ADD CONSTRAINT").unwrap()..]
                                .split("ADD CONSTRAINT")
                                .collect::<Vec<_>>();
                            // everything from the beginning up to the first ADD CONSTRAINT
                            let begin = sql[sql.find("ALTER TABLE").unwrap()
                                ..sql.find("ADD CONSTRAINT").unwrap()]
                                .to_string();
                            commands.iter().for_each(|cmd| {
                                if cmd.is_empty() {
                                    return;
                                }
                                let full_sql = format!(
                                    "{}ADD CONSTRAINT{}",
                                    begin,
                                    cmd.trim_end().trim_end_matches(',')
                                );
                                parse(&full_sql, nodes);
                            });
                        } else if let Some(pg_query::protobuf::node::Node::Constraint(c)) =
                            c.def.clone().unwrap().node.as_ref()
                        {
                            let pktable = c.pktable.as_ref().expect("no target relation");
                            nodes.push(Node::ForeignKey(ForeignKeyNode {
                                constraint_name: c.conname.clone(),
                                source_schema: schema.clone(),
                                source_table: table_name.clone(),
                                target_schema: pktable.schemaname.clone(),
                                target_table: pktable.relname.clone(),
                                sql: sql.to_string(),
                            }));
                        } else {
                            panic!("no definition for foreign key");
                        }
                    }
                    _ => panic!("Unsupported subtype {:?} for \n '{}'", cmd, sql),
                },
                _ => panic!("Unsupported command {:?} for \n '{}'", cmd, sql),
            }
        }
        _ => panic!("Unsupported node:\n{:?} '{}'", node, sql),
    };
}
fn get_sval(n: &Option<pg_query::protobuf::node::Node>) -> String {
    match n {
        Some(pg_query::protobuf::node::Node::String(s)) => s.sval.clone(),
        _ => panic!("Unsupported node {:?}", n),
    }
}
pub fn parse_sql(sql: &str) -> NodeEnum {
    pg_query::parse(sql)
        .map(|parsed| {
            parsed
                .protobuf
                .nodes()
                .iter()
                // depth 1 is the root statement node of the parsed query
                .find(|n| n.1 == 1)
                .map(|n| n.0.to_enum())
                .expect("error parsing protobuf")
        })
        .expect("error parsing")
}
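
To sanity-check the grouping logic without running against a real schema, a minimal test module can be appended to `main.rs`. This is a sketch; the table names are hypothetical and only exercise `match_group` and `get_group_path` against the GROUPS table above:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn matches_expected_groups() {
        // Hypothetical table names; adjust to your own schema.
        assert_eq!(match_group("deal_stage"), Some("deal"));
        assert_eq!(match_group("recipient"), Some("message"));
        assert_eq!(match_group("some_unrelated_table"), None);
    }

    #[test]
    fn group_path_includes_parents() {
        // "message" is nested under "conversation" via its parent field.
        assert_eq!(get_group_path("message", &GROUPS), vec!["conversation", "message"]);
    }
}

Run the tests with `cargo test`. The binary itself is run with `cargo run` from a location where the relative paths `../../supabase/schema.sql` and `../../supabase/schema` resolve.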