diff --git a/Cargo.lock b/Cargo.lock index 79ec52f0..4a1f6ea5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2765,6 +2765,7 @@ dependencies = [ "pgt_console", "pgt_diagnostics", "pgt_fs", + "pgt_lexer", "pgt_query_ext", "pgt_schema_cache", "pgt_statement_splitter", diff --git a/crates/pgt_workspace/Cargo.toml b/crates/pgt_workspace/Cargo.toml index 7df42b19..5f598b2d 100644 --- a/crates/pgt_workspace/Cargo.toml +++ b/crates/pgt_workspace/Cargo.toml @@ -25,6 +25,7 @@ pgt_configuration = { workspace = true } pgt_console = { workspace = true } pgt_diagnostics = { workspace = true } pgt_fs = { workspace = true, features = ["serde"] } +pgt_lexer = { workspace = true } pgt_query_ext = { workspace = true } pgt_schema_cache = { workspace = true } pgt_statement_splitter = { workspace = true } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 27f5e8be..5e33bc27 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -43,6 +43,7 @@ use super::{ pub use statement_identifier::StatementId; mod analyser; +mod annotation; mod async_helper; mod change; mod db_connection; diff --git a/crates/pgt_workspace/src/workspace/server/annotation.rs b/crates/pgt_workspace/src/workspace/server/annotation.rs new file mode 100644 index 00000000..321dd3ac --- /dev/null +++ b/crates/pgt_workspace/src/workspace/server/annotation.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use dashmap::DashMap; +use pgt_lexer::{SyntaxKind, WHITESPACE_TOKENS}; + +use super::statement_identifier::StatementId; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StatementAnnotations { + ends_with_semicolon: bool, +} + +pub struct AnnotationStore { + db: DashMap>>, +} + +impl AnnotationStore { + pub fn new() -> AnnotationStore { + AnnotationStore { db: DashMap::new() } + } + + #[allow(unused)] + pub fn get_annotations( + &self, + statement: &StatementId, + content: &str, + ) -> Option> { + if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { + return existing; + } + + // we swallow the error here because the lexing within the document would have already + // thrown and we wont even get here if that happened. + let annotations = pgt_lexer::lex(content).ok().map(|tokens| { + let ends_with_semicolon = tokens + .iter() + .rev() + .find(|token| !WHITESPACE_TOKENS.contains(&token.kind)) + .is_some_and(|token| token.kind == SyntaxKind::Ascii59); + + Arc::new(StatementAnnotations { + ends_with_semicolon, + }) + }); + + self.db.insert(statement.clone(), None); + annotations + } + + pub fn clear_statement(&self, id: &StatementId) { + self.db.remove(id); + + if let Some(child_id) = id.get_child_id() { + self.db.remove(&child_id); + } + } +} + +#[cfg(test)] +mod tests { + use crate::workspace::StatementId; + + use super::AnnotationStore; + + #[test] + fn annotates_correctly() { + let store = AnnotationStore::new(); + + let test_cases = [ + ("SELECT * FROM foo", false), + ("SELECT * FROM foo;", true), + ("SELECT * FROM foo ;", true), + ("SELECT * FROM foo ; ", true), + ("SELECT * FROM foo ;\n", true), + ("SELECT * FROM foo\n", false), + ]; + + for (idx, (content, expected)) in test_cases.iter().enumerate() { + let statement_id = StatementId::Root(idx.into()); + + let annotations = store.get_annotations(&statement_id, content); + + assert!(annotations.is_some()); + assert_eq!(annotations.unwrap().ends_with_semicolon, *expected); + } + } +} diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs index afe0eb64..69c68189 100644 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ b/crates/pgt_workspace/src/workspace/server/change.rs @@ -409,7 +409,7 @@ mod tests { use pgt_diagnostics::Diagnostic; use pgt_text_size::TextRange; - use crate::workspace::{ChangeFileParams, ChangeParams, server::statement_identifier::root_id}; + use crate::workspace::{ChangeFileParams, ChangeParams}; use pgt_fs::PgTPath; @@ -886,14 +886,14 @@ mod tests { assert_eq!( changed[2], StatementChange::Added(AddedStatement { - stmt: StatementId::Root(root_id(2)), + stmt: StatementId::Root(2.into()), text: "select id,test from users".to_string() }) ); assert_eq!( changed[3], StatementChange::Added(AddedStatement { - stmt: StatementId::Root(root_id(3)), + stmt: StatementId::Root(3.into()), text: "select 1;".to_string() }) ); diff --git a/crates/pgt_workspace/src/workspace/server/parsed_document.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs index a110fb1f..2b64d24a 100644 --- a/crates/pgt_workspace/src/workspace/server/parsed_document.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_document.rs @@ -8,6 +8,7 @@ use pgt_text_size::{TextRange, TextSize}; use crate::workspace::ChangeFileParams; use super::{ + annotation::AnnotationStore, change::StatementChange, document::{Document, StatementIterator}, pg_query::PgQueryStore, @@ -24,6 +25,7 @@ pub struct ParsedDocument { ast_db: PgQueryStore, cst_db: TreeSitterStore, sql_fn_db: SQLFunctionBodyStore, + annotation_db: AnnotationStore, } impl ParsedDocument { @@ -33,6 +35,7 @@ impl ParsedDocument { let cst_db = TreeSitterStore::new(); let ast_db = PgQueryStore::new(); let sql_fn_db = SQLFunctionBodyStore::new(); + let annotation_db = AnnotationStore::new(); doc.iter().for_each(|(stmt, _, content)| { cst_db.add_statement(&stmt, content); @@ -44,6 +47,7 @@ impl ParsedDocument { ast_db, cst_db, sql_fn_db, + annotation_db, } } @@ -69,6 +73,7 @@ impl ParsedDocument { self.cst_db.remove_statement(s); self.ast_db.clear_statement(s); self.sql_fn_db.clear_statement(s); + self.annotation_db.clear_statement(s); } StatementChange::Modified(s) => { tracing::debug!( @@ -84,6 +89,7 @@ impl ParsedDocument { self.cst_db.modify_statement(s); self.ast_db.clear_statement(&s.old_stmt); self.sql_fn_db.clear_statement(&s.old_stmt); + self.annotation_db.clear_statement(&s.old_stmt); } } } diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 0739fb2f..8c02814d 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -7,8 +7,17 @@ pub struct RootId { } #[cfg(test)] -pub fn root_id(inner: usize) -> RootId { - RootId { inner } +impl From for usize { + fn from(val: RootId) -> Self { + val.inner + } +} + +#[cfg(test)] +impl From for RootId { + fn from(inner: usize) -> Self { + RootId { inner } + } } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]