diff --git a/crates/pgt_statement_splitter/src/lib.rs b/crates/pgt_statement_splitter/src/lib.rs index 9e92d3af..68f5daaf 100644 --- a/crates/pgt_statement_splitter/src/lib.rs +++ b/crates/pgt_statement_splitter/src/lib.rs @@ -142,6 +142,30 @@ mod tests { .expect_statements(vec!["insert into tbl (id) select 1", "select 3"]); } + #[test] + fn with_check() { + Tester::from("create policy employee_insert on journey_execution for insert to authenticated with check ((select private.organisation_id()) = organisation_id);") + .expect_statements(vec!["create policy employee_insert on journey_execution for insert to authenticated with check ((select private.organisation_id()) = organisation_id);"]); + } + + #[test] + fn nested_parenthesis() { + Tester::from( + "create table if not exists journey_node_execution ( + id uuid default gen_random_uuid() not null primary key, + + constraint uq_node_exec unique (journey_execution_id, journey_node_id) +);", + ) + .expect_statements(vec![ + "create table if not exists journey_node_execution ( + id uuid default gen_random_uuid() not null primary key, + + constraint uq_node_exec unique (journey_execution_id, journey_node_id) +);", + ]); + } + #[test] fn with_cte() { Tester::from("with test as (select 1 as id) select * from test;") diff --git a/crates/pgt_statement_splitter/src/parser/common.rs b/crates/pgt_statement_splitter/src/parser/common.rs index ec5f93a6..af3dc6cc 100644 --- a/crates/pgt_statement_splitter/src/parser/common.rs +++ b/crates/pgt_statement_splitter/src/parser/common.rs @@ -65,11 +65,20 @@ pub(crate) fn statement(p: &mut Parser) { pub(crate) fn parenthesis(p: &mut Parser) { p.expect(SyntaxKind::Ascii40); + let mut depth = 1; + loop { match p.peek().kind { + SyntaxKind::Ascii40 => { + p.advance(); + depth += 1; + } SyntaxKind::Ascii41 | SyntaxKind::Eof => { p.advance(); - break; + depth -= 1; + if depth == 0 { + break; + } } _ => { p.advance(); @@ -174,6 +183,8 @@ pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) { if [ // WITH ORDINALITY should not start a new statement SyntaxKind::Ordinality, + // WITH CHECK should not start a new statement + SyntaxKind::Check, ] .iter() .all(|x| Some(x) != next.as_ref())