diff --git a/crates/pgt_statement_splitter/src/lib.rs b/crates/pgt_statement_splitter/src/lib.rs
index 68f5daaf..63e68cd2 100644
--- a/crates/pgt_statement_splitter/src/lib.rs
+++ b/crates/pgt_statement_splitter/src/lib.rs
@@ -4,10 +4,10 @@
 pub mod diagnostics;
 mod parser;
 
-use parser::{Parse, Parser, source};
+use parser::{Parser, ParserResult, source};
 use pgt_lexer::diagnostics::ScanError;
 
-pub fn split(sql: &str) -> Result<Parse, Vec<ScanError>> {
+pub fn split(sql: &str) -> Result<ParserResult, Vec<ScanError>> {
     let tokens = pgt_lexer::lex(sql)?;
 
     let mut parser = Parser::new(tokens);
@@ -28,7 +28,7 @@ mod tests {
 
     struct Tester {
         input: String,
-        parse: Parse,
+        parse: ParserResult,
     }
 
     impl From<&str> for Tester {
diff --git a/crates/pgt_statement_splitter/src/parser.rs b/crates/pgt_statement_splitter/src/parser.rs
index 05de8cb0..c94fe245 100644
--- a/crates/pgt_statement_splitter/src/parser.rs
+++ b/crates/pgt_statement_splitter/src/parser.rs
@@ -13,24 +13,24 @@ use crate::diagnostics::SplitDiagnostic;
 /// Main parser that exposes the `cstree` api, and collects errors and statements
 /// It is modelled after a Pratt Parser. For a gentle introduction to Pratt Parsing, see https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html
 pub struct Parser {
-    /// The ranges of the statements
-    ranges: Vec<(usize, usize)>,
+    /// The statement ranges are defined by the indices of the start/end tokens
+    stmt_ranges: Vec<(usize, usize)>,
+
     /// The syntax errors accumulated during parsing
     errors: Vec<SplitDiagnostic>,
-    /// The start of the current statement, if any
+
     current_stmt_start: Option<usize>,
-    /// The tokens to parse
-    pub tokens: Vec<Token>,
+
+    tokens: Vec<Token>,
 
     eof_token: Token,
 
-    next_pos: usize,
+    current_pos: usize,
 }
 
-/// Result of Building
 #[derive(Debug)]
-pub struct Parse {
-    /// The ranges of the errors
+pub struct ParserResult {
+    /// The ranges of the parsed statements
     pub ranges: Vec<TextRange>,
     /// The syntax errors accumulated during parsing
     pub errors: Vec<SplitDiagnostic>,
@@ -41,40 +41,34 @@ impl Parser {
         let eof_token = Token::eof(usize::from(
             tokens
                 .last()
-                .map(|t| t.span.start())
+                .map(|t| t.span.end())
                 .unwrap_or(TextSize::from(0)),
         ));
 
-        // next_pos should be the initialised with the first valid token already
-        let mut next_pos = 0;
-        loop {
-            let token = tokens.get(next_pos).unwrap_or(&eof_token);
-
-            if is_irrelevant_token(token) {
-                next_pos += 1;
-            } else {
-                break;
-            }
+        // Place `current_pos` on the first relevant token
+        let mut current_pos = 0;
+        while is_irrelevant_token(tokens.get(current_pos).unwrap_or(&eof_token)) {
+            current_pos += 1;
         }
 
         Self {
-            ranges: Vec::new(),
+            stmt_ranges: Vec::new(),
             eof_token,
             errors: Vec::new(),
             current_stmt_start: None,
             tokens,
-            next_pos,
+            current_pos,
         }
     }
 
-    pub fn finish(self) -> Parse {
-        Parse {
+    pub fn finish(self) -> ParserResult {
+        ParserResult {
             ranges: self
-                .ranges
+                .stmt_ranges
                 .iter()
-                .map(|(start, end)| {
-                    let from = self.tokens.get(*start);
-                    let to = self.tokens.get(*end).unwrap_or(&self.eof_token);
+                .map(|(start_token_pos, end_token_pos)| {
+                    let from = self.tokens.get(*start_token_pos);
+                    let to = self.tokens.get(*end_token_pos).unwrap_or(&self.eof_token);
 
                     TextRange::new(from.unwrap().span.start(), to.span.end())
                 })
@@ -83,124 +77,87 @@ impl Parser {
         }
     }
 
-    /// Start statement
     pub fn start_stmt(&mut self) {
         assert!(
             self.current_stmt_start.is_none(),
             "cannot start statement within statement at {:?}",
             self.tokens.get(self.current_stmt_start.unwrap())
         );
-        self.current_stmt_start = Some(self.next_pos);
+        self.current_stmt_start = Some(self.current_pos);
     }
 
-    /// Close statement
     pub fn close_stmt(&mut self) {
-        assert!(self.next_pos > 0);
-
-        // go back the positions until we find the first relevant token
-        let mut end_token_pos = self.next_pos - 1;
-        loop {
-            let token = self.tokens.get(end_token_pos);
+        assert!(
+            self.current_stmt_start.is_some(),
+            "Must start statement before closing it."
+        );
 
-            if end_token_pos == 0 || token.is_none() {
-                break;
-            }
+        let start_token_pos = self.current_stmt_start.unwrap();
 
-            if !is_irrelevant_token(token.unwrap()) {
-                break;
-            }
+        assert!(
+            self.current_pos > start_token_pos,
+            "Must close the statement on a token that's later than the start token."
+        );
 
-            end_token_pos -= 1;
-        }
+        let (end_token_pos, _) = self.find_last_relevant().unwrap();
 
-        self.ranges.push((
-            self.current_stmt_start.expect("Expected active statement"),
-            end_token_pos,
-        ));
+        self.stmt_ranges.push((start_token_pos, end_token_pos));
 
         self.current_stmt_start = None;
     }
 
-    fn advance(&mut self) -> &Token {
-        let mut first_relevant_token = None;
-        loop {
-            let token = self.tokens.get(self.next_pos).unwrap_or(&self.eof_token);
-
-            // we need to continue with next_pos until the next relevant token after we already
-            // found the first one
-            if !is_irrelevant_token(token) {
-                if let Some(t) = first_relevant_token {
-                    return t;
-                }
-                first_relevant_token = Some(token);
-            }
-
-            self.next_pos += 1;
-        }
-    }
-
-    fn peek(&self) -> &Token {
-        match self.tokens.get(self.next_pos) {
+    fn current(&self) -> &Token {
+        match self.tokens.get(self.current_pos) {
             Some(token) => token,
             None => &self.eof_token,
         }
     }
 
-    /// Look ahead to the next relevant token
-    /// Returns `None` if we are already at the last relevant token
-    fn look_ahead(&self) -> Option<&Token> {
-        // we need to look ahead to the next relevant token
-        let mut look_ahead_pos = self.next_pos + 1;
-        loop {
-            let token = self.tokens.get(look_ahead_pos)?;
-
-            if !is_irrelevant_token(token) {
-                return Some(token);
-            }
+    fn advance(&mut self) -> &Token {
+        // can't reuse any `find_next_relevant` logic because of Mr. Borrow Checker
+        let (pos, token) = self
+            .tokens
+            .iter()
+            .enumerate()
+            .skip(self.current_pos + 1)
+            .find(|(_, t)| is_relevant(t))
+            .unwrap_or((self.tokens.len(), &self.eof_token));
+
+        self.current_pos = pos;
+        token
+    }
 
-            look_ahead_pos += 1;
-        }
+    fn look_ahead(&self) -> Option<&Token> {
+        self.tokens
+            .iter()
+            .skip(self.current_pos + 1)
+            .find(|t| is_relevant(t))
     }
 
     /// Returns `None` if there are no previous relevant tokens
     fn look_back(&self) -> Option<&Token> {
-        // we need to look back to the last relevant token
-        let mut look_back_pos = self.next_pos - 1;
-        loop {
-            let token = self.tokens.get(look_back_pos);
-
-            if look_back_pos == 0 || token.is_none() {
-                return None;
-            }
-
-            if !is_irrelevant_token(token.unwrap()) {
-                return token;
-            }
-
-            look_back_pos -= 1;
-        }
+        self.find_last_relevant().map(|it| it.1)
     }
 
-    /// checks if the current token is of `kind` and advances if true
-    /// returns true if the current token is of `kind`
-    pub fn eat(&mut self, kind: SyntaxKind) -> bool {
-        if self.peek().kind == kind {
+    /// Will advance if the `kind` matches the current token.
+    /// Otherwise, will add a diagnostic to the internal `errors`.
+    pub fn expect(&mut self, kind: SyntaxKind) {
+        if self.current().kind == kind {
             self.advance();
-            true
         } else {
-            false
+            self.errors.push(SplitDiagnostic::new(
+                format!("Expected {:#?}", kind),
+                self.current().span,
+            ));
         }
     }
 
-    pub fn expect(&mut self, kind: SyntaxKind) {
-        if self.eat(kind) {
-            return;
-        }
-
-        self.errors.push(SplitDiagnostic::new(
-            format!("Expected {:#?}", kind),
-            self.peek().span,
-        ));
+    fn find_last_relevant(&self) -> Option<(usize, &Token)> {
+        self.tokens
+            .iter()
+            .enumerate()
+            .take(self.current_pos)
+            .rfind(|(_, t)| is_relevant(t))
     }
 }
 
@@ -219,3 +176,57 @@ fn is_irrelevant_token(t: &Token) -> bool {
     WHITESPACE_TOKENS.contains(&t.kind)
         && (t.kind != SyntaxKind::Newline || t.text.chars().count() == 1)
 }
+
+fn is_relevant(t: &Token) -> bool {
+    !is_irrelevant_token(t)
+}
+
+#[cfg(test)]
+mod tests {
+    use pgt_lexer::SyntaxKind;
+
+    use crate::parser::Parser;
+
+    #[test]
+    fn advance_works_as_expected() {
+        let sql = r#"
+        create table users (
+            id serial primary key,
+            name text,
+            email text
+        );
+        "#;
+        let tokens = pgt_lexer::lex(sql).unwrap();
+        let total_num_tokens = tokens.len();
+
+        let mut parser = Parser::new(tokens);
+
+        let expected = vec![
+            (SyntaxKind::Create, 2),
+            (SyntaxKind::Table, 4),
+            (SyntaxKind::Ident, 6),
+            (SyntaxKind::Ascii40, 8),
+            (SyntaxKind::Ident, 11),
+            (SyntaxKind::Ident, 13),
+            (SyntaxKind::Primary, 15),
+            (SyntaxKind::Key, 17),
+            (SyntaxKind::Ascii44, 18),
+            (SyntaxKind::NameP, 21),
+            (SyntaxKind::TextP, 23),
+            (SyntaxKind::Ascii44, 24),
+            (SyntaxKind::Ident, 27),
+            (SyntaxKind::TextP, 29),
+            (SyntaxKind::Ascii41, 32),
+            (SyntaxKind::Ascii59, 33),
+        ];
+
+        for (kind, pos) in expected {
+            assert_eq!(parser.current().kind, kind);
+            assert_eq!(parser.current_pos, pos);
+            parser.advance();
+        }
+
+        assert_eq!(parser.current().kind, SyntaxKind::Eof);
+        assert_eq!(parser.current_pos, total_num_tokens);
+    }
+}
diff --git a/crates/pgt_statement_splitter/src/parser/common.rs b/crates/pgt_statement_splitter/src/parser/common.rs
index d145018d..1a355f08 100644
--- a/crates/pgt_statement_splitter/src/parser/common.rs
+++ b/crates/pgt_statement_splitter/src/parser/common.rs
@@ -9,7 +9,7 @@ use super::{
 
 pub fn source(p: &mut Parser) {
     loop {
-        match p.peek() {
+        match p.current() {
             Token {
                 kind: SyntaxKind::Eof,
                 ..
@@ -33,7 +33,7 @@ pub fn source(p: &mut Parser) {
 
 pub(crate) fn statement(p: &mut Parser) {
     p.start_stmt();
-    match p.peek().kind {
+    match p.current().kind {
         SyntaxKind::With => {
             cte(p);
         }
@@ -68,7 +68,7 @@ pub(crate) fn parenthesis(p: &mut Parser) {
     let mut depth = 1;
 
     loop {
-        match p.peek().kind {
+        match p.current().kind {
             SyntaxKind::Ascii40 => {
                 p.advance();
                 depth += 1;
@@ -91,7 +91,7 @@ pub(crate) fn case(p: &mut Parser) {
     p.expect(SyntaxKind::Case);
 
     loop {
-        match p.peek().kind {
+        match p.current().kind {
             SyntaxKind::EndP => {
                 p.advance();
                 break;
@@ -105,7 +105,7 @@ pub(crate) fn case(p: &mut Parser) {
 
 pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
     loop {
-        match p.peek() {
+        match p.current() {
             Token {
                 kind: SyntaxKind::Ascii59,
                 ..
diff --git a/crates/pgt_statement_splitter/src/parser/dml.rs b/crates/pgt_statement_splitter/src/parser/dml.rs
index a45f6c40..015c50b6 100644
--- a/crates/pgt_statement_splitter/src/parser/dml.rs
+++ b/crates/pgt_statement_splitter/src/parser/dml.rs
@@ -13,7 +13,9 @@ pub(crate) fn cte(p: &mut Parser) {
         p.expect(SyntaxKind::As);
         parenthesis(p);
 
-        if !p.eat(SyntaxKind::Ascii44) {
+        if p.current().kind == SyntaxKind::Ascii44 {
+            p.advance();
+        } else {
             break;
         }
     }