Created
April 14, 2025 16:13
-
-
Save psteinroe/f68b50a03dd320146c0c7b7e29103542 to your computer and use it in GitHub Desktop.
initialise your declarative schema
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "move-schema"
version = "0.1.0"
edition = "2024"

[dependencies]
pg_query = "6.1.0"
globset = "0.4.16"
once_cell = "1.21.3"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use once_cell::sync::Lazy; | |
use std::collections::HashSet; | |
use std::path::Path; | |
use std::{fs, io::prelude::*}; | |
use globset::{Glob, GlobSetBuilder}; | |
use pg_query::NodeEnum; | |
/// A named bucket of tables, selected by glob patterns on the table name.
///
/// Objects whose table matches a group are written into a directory named after the
/// group (nested below its `parent` group's directory, if any).
struct Group {
    /// Directory name used for matched objects.
    name: &'static str,
    /// Glob patterns matched against the table (or view) name.
    patterns: &'static [&'static str],
    /// Name of the enclosing group, if this group is nested inside another one.
    parent: Option<&'static str>,
}

/// Group table, tried in order: the FIRST group whose patterns match wins.
///
/// Ordering therefore matters. More specific groups must precede broader ones that
/// would also match, e.g. `custom_field` before `deal*`/`contact*`, `contact_list*`
/// before `contact*`, and `team` (which owns `inbox_access`) before `inbox*`.
/// Group names are unique so that parent lookups are unambiguous.
const GROUPS: [Group; 25] = [
    Group {
        name: "gallery",
        patterns: &["category", "goal"],
        parent: None,
    },
    Group {
        name: "channel",
        patterns: &["channel*"],
        parent: None,
    },
    Group {
        name: "common",
        patterns: &[
            "timezone",
            "locale",
            "language",
            "industry",
            "country",
            "continent",
        ],
        parent: None,
    },
    Group {
        name: "custom_field",
        patterns: &[
            "custom_field",
            "deal_custom_field",
            "appointment_custom_field",
            "contact_custom_field",
        ],
        parent: None,
    },
    Group {
        name: "appointment",
        patterns: &["appointment*"],
        parent: None,
    },
    Group {
        name: "deal",
        patterns: &["deal*"],
        parent: None,
    },
    Group {
        name: "form",
        patterns: &["form*"],
        parent: None,
    },
    Group {
        name: "journey",
        patterns: &["journey*"],
        parent: None,
    },
    Group {
        name: "rule",
        patterns: &["rule*"],
        parent: None,
    },
    Group {
        name: "web_widget",
        patterns: &["web_widget*"],
        parent: None,
    },
    Group {
        name: "review_channel",
        patterns: &["review*"],
        parent: None,
    },
    Group {
        name: "conversation_tag",
        patterns: &["conversation_tag", "tag"],
        parent: Some("conversation"),
    },
    Group {
        name: "template",
        patterns: &["template*", "provider_template_approval"],
        parent: None,
    },
    Group {
        name: "campaign",
        patterns: &["campaign*"],
        parent: None,
    },
    Group {
        name: "webhook",
        patterns: &["webhook*"],
        parent: None,
    },
    Group {
        // Previously split across two identically named entries; merged here.
        name: "message",
        patterns: &["message*", "recipient"],
        parent: Some("conversation"),
    },
    Group {
        name: "sendout",
        patterns: &["sendout*"],
        parent: None,
    },
    Group {
        name: "marketing_channel",
        patterns: &["marketing_channel", "marketing_subscription"],
        parent: None,
    },
    Group {
        name: "contact_list",
        patterns: &["contact_list*"],
        parent: None,
    },
    Group {
        name: "contact",
        patterns: &["contact*"],
        parent: None,
    },
    Group {
        name: "comment",
        patterns: &["comment*"],
        parent: Some("conversation"),
    },
    Group {
        name: "conversation",
        patterns: &["conversation*"],
        parent: None,
    },
    Group {
        // Must come before `inbox`: otherwise `inbox*` captures `inbox_access` first.
        name: "team",
        patterns: &["team*", "inbox_access"],
        parent: None,
    },
    Group {
        name: "inbox",
        patterns: &["inbox*"],
        parent: None,
    },
    Group {
        name: "employee",
        patterns: &["employee", "favorite_template", "pinned*"],
        parent: None,
    },
];
// Build matchers once using Lazy static initialization | |
static MATCHERS: Lazy<Vec<(&'static str, globset::GlobSet)>> = Lazy::new(|| { | |
GROUPS | |
.iter() | |
.map(|group| { | |
let mut builder = GlobSetBuilder::new(); | |
for pattern in group.patterns { | |
builder.add(Glob::new(pattern).unwrap()); | |
} | |
(group.name, builder.build().unwrap()) | |
}) | |
.collect() | |
}); | |
/// Match a value to the corresponding group name based on glob patterns. | |
fn match_group(value: &str) -> Option<&'static str> { | |
for (name, matcher) in MATCHERS.iter() { | |
if matcher.is_match(value) { | |
return Some(name); | |
} | |
} | |
None | |
} | |
// ---- Objects addressed by schema + own name -------------------------------------------

/// `CREATE SCHEMA` statement.
#[derive(Debug)]
struct SchemaNode {
    /// Schema name.
    name: String,
    /// Original statement text.
    sql: String,
}

/// Ordinary (non-trigger) function.
#[derive(Debug)]
struct FunctionNode {
    schema: String,
    name: String,
    sql: String,
}

/// Function whose return type is `trigger`.
#[derive(Debug)]
struct TriggerFunctionNode {
    schema: String,
    name: String,
    sql: String,
}

/// `CREATE TYPE ... AS ENUM`.
#[derive(Debug)]
struct EnumNode {
    schema: String,
    name: String,
    sql: String,
}

/// `CREATE TYPE ... AS (...)` composite type.
#[derive(Debug)]
struct CompositeTypeNode {
    schema: String,
    name: String,
    sql: String,
}

/// `CREATE AGGREGATE`.
#[derive(Debug)]
struct AggregateNode {
    schema: String,
    name: String,
    sql: String,
}

// ---- Objects tied to a table ----------------------------------------------------------

/// `CREATE TABLE`.
#[derive(Debug)]
struct TableNode {
    schema: String,
    name: String,
    sql: String,
}

/// `CREATE VIEW`.
#[derive(Debug)]
struct ViewNode {
    schema: String,
    name: String,
    sql: String,
}

/// `ALTER TABLE ... ENABLE ROW LEVEL SECURITY`.
#[derive(Debug)]
struct EnablePolicyNode {
    schema: String,
    /// Table the statement alters.
    table: String,
    sql: String,
}

/// `CREATE POLICY`.
#[derive(Debug)]
struct PolicyNode {
    schema: String,
    /// Policy name.
    name: String,
    /// Table the policy is defined on.
    table: String,
    sql: String,
}

/// `CREATE INDEX`.
#[derive(Debug)]
struct IndexNode {
    schema: String,
    /// Index name.
    name: String,
    /// Indexed table.
    table: String,
    sql: String,
}

/// `CREATE TRIGGER`.
#[derive(Debug)]
struct TriggerNode {
    schema: String,
    /// Trigger name.
    name: String,
    /// Table the trigger is attached to.
    table: String,
    /// Unqualified name of the function the trigger executes.
    function: String,
    sql: String,
}

/// `ALTER TABLE ... ADD CONSTRAINT ... FOREIGN KEY`.
#[derive(Debug)]
struct ForeignKeyNode {
    constraint_name: String,
    /// Referencing side.
    source_schema: String,
    source_table: String,
    /// Referenced side.
    target_schema: String,
    target_table: String,
    sql: String,
}
#[derive(Debug)] | |
enum Node { | |
Schema(SchemaNode), | |
Table(TableNode), | |
Function(FunctionNode), | |
EnablePolicy(EnablePolicyNode), | |
Policy(PolicyNode), | |
Index(IndexNode), | |
View(ViewNode), | |
TriggerFunction(TriggerFunctionNode), | |
Trigger(TriggerNode), | |
Comment(pg_query::protobuf::CommentStmt), | |
EnumNode(EnumNode), | |
CompositeType(CompositeTypeNode), | |
ForeignKey(ForeignKeyNode), | |
Aggregate(AggregateNode), | |
} | |
fn main() { | |
let schema_path = Path::new("../../supabase").join("schema.sql"); | |
let schema = std::fs::read_to_string(schema_path).expect("Failed to read schema file"); | |
let mut nodes: Vec<Node> = Vec::new(); | |
pg_query::split_with_parser(&schema) | |
.expect("error parsing") | |
.iter() | |
.for_each(|sql| { | |
parse(sql, &mut nodes); | |
}); | |
// structure: | |
// /schema1 | |
// index.sql | |
// /tables | |
// /functions | |
// /schema2 | |
// index.sql | |
let out_dir = Path::new("../../supabase/schema"); | |
let _ = fs::remove_dir_all(out_dir); | |
nodes.iter().for_each(|n| match n { | |
Node::Schema(s) => { | |
let setup_file_path = out_dir.join(&s.name).join("setup.sql"); | |
let index_file_path = out_dir.join("index.sql"); | |
// ensure it sends with ";" | |
let content = if s.sql.ends_with(';') { | |
s.sql.to_string() | |
} else { | |
format!("{};", s.sql) | |
}; | |
append_to_file(&setup_file_path, &content); | |
append_to_file( | |
&index_file_path, | |
format!("-- atlas:import {}/setup.sql", &s.name).as_str(), | |
); | |
append_to_file( | |
&index_file_path, | |
format!("-- atlas:import {}/index.sql", &s.name).as_str(), | |
); | |
} | |
Node::Table(t) => { | |
save_table_object( | |
&t.name, &t.name, &t.schema, &t.sql, "tables", out_dir, &GROUPS, | |
); | |
} | |
Node::TriggerFunction(tf) => { | |
let mut tables = HashSet::new(); | |
for node in nodes.iter() { | |
if let Node::Trigger(t) = node { | |
if t.function == tf.name { | |
tables.insert(t.table.clone()); | |
} | |
} | |
} | |
let triggers_tables = tables.iter().collect::<Vec<_>>(); | |
if triggers_tables.is_empty() || triggers_tables.len() > 1 { | |
save_object(&tf.name, &tf.schema, &tf.sql, "triggers", out_dir); | |
} else { | |
let trigger_table = triggers_tables.first().unwrap(); | |
save_table_object( | |
&tf.name, trigger_table, &tf.schema, &tf.sql, "triggers", out_dir, &GROUPS, | |
); | |
} | |
} | |
Node::Trigger(t) => { | |
save_table_object( | |
&t.function, | |
&t.table, | |
&t.schema, | |
&t.sql, | |
"triggers", | |
out_dir, | |
&GROUPS, | |
); | |
} | |
Node::Function(f) => { | |
save_object(&f.name, &f.schema, &f.sql, "functions", out_dir); | |
} | |
Node::EnablePolicy(p) => { | |
save_table_object( | |
"enable", &p.table, &p.schema, &p.sql, "policies", out_dir, &GROUPS, | |
); | |
} | |
Node::Policy(p) => { | |
save_table_object( | |
&p.name, &p.table, &p.schema, &p.sql, "policies", out_dir, &GROUPS, | |
); | |
} | |
Node::Index(i) => { | |
save_table_object( | |
&i.name, &i.table, &i.schema, &i.sql, "indexes", out_dir, &GROUPS, | |
); | |
} | |
Node::View(v) => { | |
save_table_object( | |
&v.name, &v.name, &v.schema, &v.sql, "views", out_dir, &GROUPS, | |
); | |
} | |
Node::EnumNode(e) => { | |
save_object(&e.name, &e.schema, &e.sql, "enums", out_dir); | |
} | |
Node::CompositeType(c) => { | |
save_object(&c.name, &c.schema, &c.sql, "types", out_dir); | |
} | |
Node::ForeignKey(fk) => { | |
save_table_object( | |
&fk.constraint_name, | |
&fk.source_table, | |
&fk.source_schema, | |
&fk.sql, | |
"foreign_keys", | |
out_dir, | |
&GROUPS, | |
); | |
} | |
Node::Aggregate(a) => { | |
save_object(&a.name, &a.schema, &a.sql, "aggregates", out_dir); | |
} | |
Node::Comment(c) => match c.objtype() { | |
pg_query::protobuf::ObjectType::Undefined => todo!(), | |
pg_query::protobuf::ObjectType::ObjectAccessMethod => todo!(), | |
pg_query::protobuf::ObjectType::ObjectAggregate => todo!(), | |
pg_query::protobuf::ObjectType::ObjectAmop => todo!(), | |
pg_query::protobuf::ObjectType::ObjectAmproc => todo!(), | |
pg_query::protobuf::ObjectType::ObjectAttribute => todo!(), | |
pg_query::protobuf::ObjectType::ObjectCast => todo!(), | |
pg_query::protobuf::ObjectType::ObjectColumn => { | |
let list = &c.object.clone().map(|n| n.node.clone()).unwrap().unwrap(); | |
if let pg_query::NodeEnum::List(l) = list { | |
let items = l | |
.items | |
.iter() | |
.map(|n| get_sval(&n.node)) | |
.collect::<Vec<_>>(); | |
if items.len() != 3 { | |
panic!("not three items in list"); | |
} | |
let schema = items.first().unwrap(); | |
let table_name = items.get(1).unwrap(); | |
let column_name = items.get(2).unwrap(); | |
// check if table or view | |
let entity_type = nodes.iter().find_map(|n| { | |
if let Node::Table(t) = n { | |
if t.name == *table_name && t.schema == *schema { | |
Some("tables") | |
} else { | |
None | |
} | |
} else if let Node::View(v) = n { | |
if v.name == *table_name && v.schema == *schema { | |
Some("views") | |
} else { | |
None | |
} | |
} else { | |
None | |
} | |
}); | |
save_table_object( | |
table_name, | |
table_name, | |
schema, | |
format!( | |
"COMMENT ON COLUMN \"{}\".\"{}\".\"{}\" IS '{}';", | |
schema, table_name, column_name, c.comment | |
) | |
.as_str(), | |
entity_type.expect("no table or view found"), | |
out_dir, | |
&GROUPS, | |
); | |
} | |
} | |
pg_query::protobuf::ObjectType::ObjectCollation => todo!(), | |
pg_query::protobuf::ObjectType::ObjectConversion => todo!(), | |
pg_query::protobuf::ObjectType::ObjectDatabase => todo!(), | |
pg_query::protobuf::ObjectType::ObjectDefault => todo!(), | |
pg_query::protobuf::ObjectType::ObjectDefacl => todo!(), | |
pg_query::protobuf::ObjectType::ObjectDomain => todo!(), | |
pg_query::protobuf::ObjectType::ObjectDomconstraint => todo!(), | |
pg_query::protobuf::ObjectType::ObjectEventTrigger => todo!(), | |
pg_query::protobuf::ObjectType::ObjectExtension => todo!(), | |
pg_query::protobuf::ObjectType::ObjectFdw => todo!(), | |
pg_query::protobuf::ObjectType::ObjectForeignServer => todo!(), | |
pg_query::protobuf::ObjectType::ObjectForeignTable => todo!(), | |
pg_query::protobuf::ObjectType::ObjectFunction => { | |
let list = &c.object.clone().map(|n| n.node.clone()).unwrap().unwrap(); | |
if let pg_query::NodeEnum::ObjectWithArgs(obj) = list { | |
let items = obj | |
.objname | |
.iter() | |
.map(|n| get_sval(&n.node)) | |
.collect::<Vec<_>>(); | |
if items.len() != 2 { | |
panic!("not two items in list"); | |
} | |
let schema = items.first().unwrap(); | |
let function_name = items.last().unwrap(); | |
save_object( | |
function_name, | |
schema, | |
format!( | |
"COMMENT ON FUNCTION \"{}\".\"{}\" IS '{}';", | |
schema, function_name, c.comment | |
) | |
.as_str(), | |
"functions", | |
out_dir, | |
); | |
} | |
} | |
pg_query::protobuf::ObjectType::ObjectIndex => todo!(), | |
pg_query::protobuf::ObjectType::ObjectLanguage => todo!(), | |
pg_query::protobuf::ObjectType::ObjectLargeobject => todo!(), | |
pg_query::protobuf::ObjectType::ObjectMatview => todo!(), | |
pg_query::protobuf::ObjectType::ObjectOpclass => todo!(), | |
pg_query::protobuf::ObjectType::ObjectOperator => todo!(), | |
pg_query::protobuf::ObjectType::ObjectOpfamily => todo!(), | |
pg_query::protobuf::ObjectType::ObjectParameterAcl => todo!(), | |
pg_query::protobuf::ObjectType::ObjectPolicy => todo!(), | |
pg_query::protobuf::ObjectType::ObjectProcedure => todo!(), | |
pg_query::protobuf::ObjectType::ObjectPublication => todo!(), | |
pg_query::protobuf::ObjectType::ObjectPublicationNamespace => todo!(), | |
pg_query::protobuf::ObjectType::ObjectPublicationRel => todo!(), | |
pg_query::protobuf::ObjectType::ObjectRole => todo!(), | |
pg_query::protobuf::ObjectType::ObjectRoutine => todo!(), | |
pg_query::protobuf::ObjectType::ObjectRule => todo!(), | |
pg_query::protobuf::ObjectType::ObjectSchema => { | |
let schema_name = get_sval(&c.object.clone().map(|n| n.node.clone()).unwrap()); | |
let setup_file_path = out_dir.join(&schema_name).join("setup.sql"); | |
append_to_file( | |
&setup_file_path, | |
format!("COMMENT ON SCHEMA \"{}\" IS '{}';", schema_name, c.comment).as_str(), | |
); | |
} | |
pg_query::protobuf::ObjectType::ObjectSequence => todo!(), | |
pg_query::protobuf::ObjectType::ObjectSubscription => todo!(), | |
pg_query::protobuf::ObjectType::ObjectStatisticExt => todo!(), | |
pg_query::protobuf::ObjectType::ObjectTabconstraint => todo!(), | |
pg_query::protobuf::ObjectType::ObjectTable => { | |
let list = &c.object.clone().map(|n| n.node.clone()).unwrap().unwrap(); | |
if let pg_query::NodeEnum::List(l) = list { | |
let items = l | |
.items | |
.iter() | |
.map(|n| get_sval(&n.node)) | |
.collect::<Vec<_>>(); | |
if items.len() != 2 { | |
panic!("not two items in list"); | |
} | |
let schema = items.first().unwrap(); | |
let table_name = items.last().unwrap(); | |
// check if table or view | |
let node = nodes.iter().find(|n| { | |
if let Node::Table(t) = n { | |
t.name == *table_name && t.schema == *schema | |
} else if let Node::View(v) = n { | |
v.name == *table_name && v.schema == *schema | |
} else { | |
false | |
} | |
}); | |
if let Some(Node::Table(_t)) = node { | |
save_table_object( | |
table_name, | |
table_name, | |
schema, | |
format!( | |
"COMMENT ON TABLE \"{}\".\"{}\" IS '{}';", | |
schema, table_name, c.comment | |
) | |
.as_str(), | |
"tables", | |
out_dir, | |
&GROUPS, | |
); | |
} else if let Some(Node::View(_v)) = node { | |
save_table_object( | |
table_name, | |
table_name, | |
schema, | |
format!( | |
"COMMENT ON VIEW \"{}\".\"{}\" IS '{}';", | |
schema, table_name, c.comment | |
) | |
.as_str(), | |
"views", | |
out_dir, | |
&GROUPS, | |
); | |
} else { | |
panic!("no table or view found for {}.{}", schema, table_name); | |
} | |
} | |
} | |
pg_query::protobuf::ObjectType::ObjectTablespace => todo!(), | |
pg_query::protobuf::ObjectType::ObjectTransform => todo!(), | |
pg_query::protobuf::ObjectType::ObjectTrigger => todo!(), | |
pg_query::protobuf::ObjectType::ObjectTsconfiguration => todo!(), | |
pg_query::protobuf::ObjectType::ObjectTsdictionary => todo!(), | |
pg_query::protobuf::ObjectType::ObjectTsparser => todo!(), | |
pg_query::protobuf::ObjectType::ObjectTstemplate => todo!(), | |
pg_query::protobuf::ObjectType::ObjectType => todo!(), | |
pg_query::protobuf::ObjectType::ObjectUserMapping => todo!(), | |
pg_query::protobuf::ObjectType::ObjectView => todo!(), | |
}, | |
}) | |
} | |
// Helper function to get the full group path for a matched group | |
fn get_group_path(group_name: &'static str, groups: &[Group]) -> Vec<&'static str> { | |
let mut path = Vec::new(); | |
let mut current = group_name; | |
// First add the matched group | |
path.push(current); | |
// Then add all parents | |
while let Some(group) = groups.iter().find(|g| g.name == current) { | |
// If it has a parent, add it and continue | |
match group.parent { | |
Some(parent) => { | |
current = parent; | |
path.push(current); | |
} | |
None => { | |
// No parent, we're done | |
break; | |
} | |
} | |
} | |
// Reverse to get parent -> child order | |
path.reverse(); | |
path | |
} | |
fn save_object(name: &str, schema: &str, c: &str, entity_type: &str, out_dir: &Path) { | |
let object_type_index = out_dir.join(schema).join("index.sql"); | |
append_to_file( | |
&object_type_index, | |
format!("-- atlas:import {}/index.sql", entity_type).as_str(), | |
); | |
// ensure it sends with ";" | |
let content = if c.ends_with(';') { | |
c.to_string() | |
} else { | |
format!("{};", c) | |
}; | |
let file_name = format!("{}.sql", name); | |
let base_dir = out_dir.join(schema).join(entity_type); | |
let main_index_path = base_dir.join("index.sql"); | |
// Save the actual entity file | |
let file_path = base_dir.join(&file_name); | |
append_to_file(&file_path, &content); | |
// Add import to the main index.sql | |
append_to_file( | |
&main_index_path, | |
format!("-- atlas:import {}", &file_name).as_str(), | |
); | |
} | |
fn save_table_object( | |
name: &str, | |
table_name: &str, | |
schema: &str, | |
c: &str, | |
entity_type: &str, | |
out_dir: &Path, | |
groups: &[Group], | |
) { | |
// write to schema/objtype/index.sql | |
let object_type_index = out_dir.join(schema).join("index.sql"); | |
append_to_file( | |
&object_type_index, | |
format!("-- atlas:import {}/index.sql", entity_type).as_str(), | |
); | |
// ensure it sends with ";" | |
let content = if c.ends_with(';') { | |
c.to_string() | |
} else { | |
format!("{}\n;", c) | |
}; | |
let file_name = format!("{}.sql", name); | |
let base_dir = out_dir.join(schema).join(entity_type); | |
let main_index_path = base_dir.join("index.sql"); | |
let group_match = match_group(table_name); | |
// Try to match the entity to a group | |
if group_match.is_some() || entity_type != "tables" { | |
let mut group_path = vec![]; | |
if let Some(group_name) = group_match { | |
// Get the full path of groups (parent -> child) | |
group_path.append(&mut get_group_path(group_name, groups)); | |
} | |
if entity_type != "tables" && entity_type != "views" { | |
group_path.push(table_name); | |
} | |
// Handle nested groups | |
let mut current_dir = base_dir; | |
let mut current_index_path = main_index_path; | |
// Process each level of the group hierarchy | |
for &group in group_path.iter() { | |
let next_dir = current_dir.join(group); | |
let next_index_path = next_dir.join("index.sql"); | |
// Create reference from parent to child index | |
append_to_file( | |
¤t_index_path, | |
format!("-- atlas:import {}/index.sql", group).as_str(), | |
); | |
// Update current paths for next iteration | |
current_dir = next_dir; | |
current_index_path = next_index_path; | |
} | |
// Save the actual entity file in the deepest group directory | |
let file_path = current_dir.join(&file_name); | |
append_to_file(&file_path, &content); | |
// Add import to the deepest group's index.sql | |
append_to_file( | |
¤t_index_path, | |
format!("-- atlas:import {}", &file_name).as_str(), | |
); | |
} else { | |
// No group matched, save directly in the base directory | |
let file_path = base_dir.join(&file_name); | |
append_to_file(&file_path, &content); | |
append_to_file( | |
&main_index_path, | |
format!("-- atlas:import {}", &file_name).as_str(), | |
); | |
} | |
} | |
/// Appends `content` and a newline to `path`, creating the file and any missing
/// parent directories first.
///
/// Nothing is written if the file already contains `content` anywhere (substring check).
/// This is what stops the repeated `-- atlas:import` lines from being duplicated.
///
/// # Panics
/// If directories cannot be created, or the file cannot be opened or written.
fn append_to_file(path: &Path, content: &str) {
    if let Some(dir) = path.parent() {
        std::fs::create_dir_all(dir).expect("Failed to create parent directories");
    }

    // A missing or unreadable file is treated as "does not contain it yet".
    let already_present = path.exists()
        && fs::read_to_string(path).is_ok_and(|existing| existing.contains(content));
    if already_present {
        return;
    }

    let mut file = std::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(path)
        .expect("Failed to open file");
    writeln!(file, "{}", content).expect("Failed to write to file");
}
fn parse(sql: &str, nodes: &mut Vec<Node>) { | |
let node = parse_sql(sql); | |
match node { | |
pg_query::NodeEnum::CreateSchemaStmt(n) => { | |
let schema_name = n.schemaname.to_string(); | |
nodes.push(Node::Schema(SchemaNode { | |
name: schema_name.clone(), | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::NodeEnum::CommentStmt(n) => { | |
nodes.push(Node::Comment(*n)); | |
} | |
pg_query::NodeEnum::CreateExtensionStmt(_) => { | |
// | |
} | |
pg_query::NodeEnum::CreateEnumStmt(n) => { | |
let names = n | |
.type_name | |
.iter() | |
.map(|n| get_sval(&n.node)) | |
.collect::<Vec<_>>(); | |
let schema = if names.len() > 1 { | |
names.first().unwrap() | |
} else { | |
&"public".to_string() | |
}; | |
let type_name = names.last().unwrap().to_string(); | |
nodes.push(Node::EnumNode(EnumNode { | |
schema: schema.to_string(), | |
name: type_name, | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::NodeEnum::DefineStmt(n) => match n.kind() { | |
pg_query::protobuf::ObjectType::ObjectAggregate => { | |
let names = n | |
.defnames | |
.iter() | |
.map(|n| get_sval(&n.node)) | |
.collect::<Vec<_>>(); | |
let schema = if names.len() > 1 { | |
names.first().unwrap() | |
} else { | |
&"public".to_string() | |
}; | |
let type_name = names.last().unwrap().to_string(); | |
nodes.push(Node::Aggregate(AggregateNode { | |
schema: schema.to_string(), | |
name: type_name, | |
sql: sql.to_string(), | |
})); | |
} | |
_ => panic!("Unsupported define statement"), | |
}, | |
pg_query::NodeEnum::CompositeTypeStmt(n) => { | |
let name = n.typevar.expect("no typevar"); | |
let schema = name.schemaname; | |
let type_name = name.relname; | |
nodes.push(Node::CompositeType(CompositeTypeNode { | |
schema: schema.to_string(), | |
name: type_name, | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::NodeEnum::ViewStmt(n) => { | |
let rel = n.view.expect("no relation"); | |
let schema = rel.schemaname; | |
let view_name = rel.relname; | |
nodes.push(Node::View(ViewNode { | |
schema: schema.clone(), | |
name: view_name, | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::NodeEnum::CreatePolicyStmt(n) => { | |
let name = n.policy_name; | |
let table = n.table.expect("no table"); | |
let schema = table.schemaname; | |
let relation_name = table.relname; | |
nodes.push(Node::Policy(PolicyNode { | |
schema: schema.clone(), | |
name, | |
table: relation_name, | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::NodeEnum::CreateStmt(n) => { | |
let rel = n.relation.expect("no relation"); | |
let schema = rel.schemaname; | |
let table_name = rel.relname; | |
nodes.push(Node::Table(TableNode { | |
schema: schema.clone(), | |
name: table_name, | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::NodeEnum::CreateTrigStmt(n) => { | |
let rel = n.relation.expect("no relation"); | |
let schema = rel.schemaname; | |
let table_name = rel.relname; | |
let func_name = n | |
.funcname | |
.iter() | |
.map(|n| get_sval(&n.node)) | |
.collect::<Vec<_>>(); | |
let function_name = func_name.last().unwrap().to_string(); | |
let trigger_name = n.trigname.clone(); | |
nodes.push(Node::Trigger(TriggerNode { | |
schema: schema.clone(), | |
name: trigger_name, | |
table: table_name, | |
function: function_name, | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::NodeEnum::CreateFunctionStmt(n) => { | |
let func_name = n | |
.funcname | |
.iter() | |
.map(|n| get_sval(&n.node)) | |
.collect::<Vec<_>>(); | |
let schema = if func_name.len() > 1 { | |
func_name.first().unwrap() | |
} else { | |
&"public".to_string() | |
}; | |
let function_name = func_name.last().unwrap().to_string(); | |
let is_trigger = n.return_type.as_ref().expect("").names.iter().any(|n| { | |
let type_name = get_sval(&n.node); | |
type_name == "trigger" | |
}); | |
if is_trigger { | |
nodes.push(Node::TriggerFunction(TriggerFunctionNode { | |
schema: schema.to_string(), | |
name: function_name, | |
sql: sql.to_string(), | |
})); | |
} else { | |
nodes.push(Node::Function(FunctionNode { | |
schema: schema.to_string(), | |
name: function_name, | |
sql: sql.to_string(), | |
})); | |
} | |
} | |
pg_query::NodeEnum::IndexStmt(n) => { | |
let rel = n.relation.expect("no relation"); | |
let schema = rel.schemaname; | |
let index_name = n.idxname; | |
let table_name = rel.relname; | |
nodes.push(Node::Index(IndexNode { | |
schema: schema.clone(), | |
name: index_name, | |
table: table_name, | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::NodeEnum::AlterTableStmt(n) => { | |
let rel = n.relation.expect("no relation"); | |
let schema = rel.schemaname; | |
let table_name = rel.relname; | |
let number_of_commands = n.cmds.len(); | |
let cmd = n.cmds.first().expect("no command").node.clone().unwrap(); | |
match &cmd { | |
pg_query::NodeEnum::AlterTableCmd(c) => match c.subtype() { | |
pg_query::protobuf::AlterTableType::AtEnableRowSecurity => { | |
nodes.push(Node::EnablePolicy(EnablePolicyNode { | |
schema, | |
table: table_name, | |
sql: sql.to_string(), | |
})); | |
} | |
pg_query::protobuf::AlterTableType::AtAddConstraint => { | |
if number_of_commands > 1 { | |
let commands = sql[sql.find("ADD CONSTRAINT").unwrap()..] | |
.split("ADD CONSTRAINT") | |
.collect::<Vec<_>>(); | |
// get from beginning to fist ADD CONSTRAINT | |
let begin = sql[sql.find("ALTER TABLE").unwrap() | |
..sql.find("ADD CONSTRAINT").unwrap()] | |
.to_string(); | |
commands.iter().for_each(|cmd| { | |
if cmd.is_empty() { | |
return; | |
} | |
let full_sql = format!( | |
"{}ADD CONSTRAINT{}", | |
begin, | |
cmd.trim_end().trim_end_matches(',') | |
); | |
parse(&full_sql, nodes); | |
}); | |
} else if let Some(pg_query::protobuf::node::Node::Constraint(c)) = | |
c.def.clone().unwrap().node.as_ref() | |
{ | |
let constraint_name = c.conname.clone(); | |
let source_schema = schema.clone(); | |
let source_table = table_name.clone(); | |
let pktable = c.pktable.as_ref().expect("no target relation"); | |
let target_schema = pktable.schemaname.clone(); | |
let target_table = pktable.relname.clone(); | |
nodes.push(Node::ForeignKey(ForeignKeyNode { | |
constraint_name, | |
source_schema, | |
source_table, | |
target_schema, | |
target_table, | |
sql: sql.to_string(), | |
})); | |
} else { | |
panic!("no definition for foreign key"); | |
} | |
} | |
_ => panic!("Unsupported subtype {:?} for \n '{}'", cmd, sql), | |
}, | |
_ => panic!("Unsupported command {:?} for \n '{}'", cmd, sql), | |
} | |
} | |
_ => panic!("Unsupported node:\n{:?} '{}'", node, sql), | |
}; | |
} | |
fn get_sval(n: &Option<pg_query::protobuf::node::Node>) -> String { | |
match n { | |
Some(pg_query::protobuf::node::Node::String(s)) => s.sval.clone(), | |
_ => panic!("Unsupported node {:?}", n), | |
} | |
} | |
pub fn parse_sql(sql: &str) -> NodeEnum { | |
pg_query::parse(sql) | |
.map(|parsed| { | |
parsed | |
.protobuf | |
.nodes() | |
.iter() | |
.find(|n| n.1 == 1) | |
.map(|n| n.0.to_enum()) | |
.expect("error parsing protobuf") | |
}) | |
.expect("error parsing") | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment