Skip to content

Commit a358cee

Browse files
authored
feat: annotations (#331)
1 parent 70f0c93 commit a358cee

File tree

7 files changed

+110
-5
lines changed

7 files changed

+110
-5
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/pgt_workspace/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pgt_configuration = { workspace = true }
2525
pgt_console = { workspace = true }
2626
pgt_diagnostics = { workspace = true }
2727
pgt_fs = { workspace = true, features = ["serde"] }
28+
pgt_lexer = { workspace = true }
2829
pgt_query_ext = { workspace = true }
2930
pgt_schema_cache = { workspace = true }
3031
pgt_statement_splitter = { workspace = true }

crates/pgt_workspace/src/workspace/server.rs

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use super::{
4343
pub use statement_identifier::StatementId;
4444

4545
mod analyser;
46+
mod annotation;
4647
mod async_helper;
4748
mod change;
4849
mod db_connection;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
use std::sync::Arc;
2+
3+
use dashmap::DashMap;
4+
use pgt_lexer::{SyntaxKind, WHITESPACE_TOKENS};
5+
6+
use super::statement_identifier::StatementId;
7+
8+
#[derive(Debug, Clone, PartialEq, Eq)]
9+
pub struct StatementAnnotations {
10+
ends_with_semicolon: bool,
11+
}
12+
13+
pub struct AnnotationStore {
14+
db: DashMap<StatementId, Option<Arc<StatementAnnotations>>>,
15+
}
16+
17+
impl AnnotationStore {
18+
pub fn new() -> AnnotationStore {
19+
AnnotationStore { db: DashMap::new() }
20+
}
21+
22+
#[allow(unused)]
23+
pub fn get_annotations(
24+
&self,
25+
statement: &StatementId,
26+
content: &str,
27+
) -> Option<Arc<StatementAnnotations>> {
28+
if let Some(existing) = self.db.get(statement).map(|x| x.clone()) {
29+
return existing;
30+
}
31+
32+
// we swallow the error here because the lexing within the document would have already
33+
// thrown and we wont even get here if that happened.
34+
let annotations = pgt_lexer::lex(content).ok().map(|tokens| {
35+
let ends_with_semicolon = tokens
36+
.iter()
37+
.rev()
38+
.find(|token| !WHITESPACE_TOKENS.contains(&token.kind))
39+
.is_some_and(|token| token.kind == SyntaxKind::Ascii59);
40+
41+
Arc::new(StatementAnnotations {
42+
ends_with_semicolon,
43+
})
44+
});
45+
46+
self.db.insert(statement.clone(), None);
47+
annotations
48+
}
49+
50+
pub fn clear_statement(&self, id: &StatementId) {
51+
self.db.remove(id);
52+
53+
if let Some(child_id) = id.get_child_id() {
54+
self.db.remove(&child_id);
55+
}
56+
}
57+
}
58+
59+
#[cfg(test)]
60+
mod tests {
61+
use crate::workspace::StatementId;
62+
63+
use super::AnnotationStore;
64+
65+
#[test]
66+
fn annotates_correctly() {
67+
let store = AnnotationStore::new();
68+
69+
let test_cases = [
70+
("SELECT * FROM foo", false),
71+
("SELECT * FROM foo;", true),
72+
("SELECT * FROM foo ;", true),
73+
("SELECT * FROM foo ; ", true),
74+
("SELECT * FROM foo ;\n", true),
75+
("SELECT * FROM foo\n", false),
76+
];
77+
78+
for (idx, (content, expected)) in test_cases.iter().enumerate() {
79+
let statement_id = StatementId::Root(idx.into());
80+
81+
let annotations = store.get_annotations(&statement_id, content);
82+
83+
assert!(annotations.is_some());
84+
assert_eq!(annotations.unwrap().ends_with_semicolon, *expected);
85+
}
86+
}
87+
}

crates/pgt_workspace/src/workspace/server/change.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ mod tests {
409409
use pgt_diagnostics::Diagnostic;
410410
use pgt_text_size::TextRange;
411411

412-
use crate::workspace::{ChangeFileParams, ChangeParams, server::statement_identifier::root_id};
412+
use crate::workspace::{ChangeFileParams, ChangeParams};
413413

414414
use pgt_fs::PgTPath;
415415

@@ -886,14 +886,14 @@ mod tests {
886886
assert_eq!(
887887
changed[2],
888888
StatementChange::Added(AddedStatement {
889-
stmt: StatementId::Root(root_id(2)),
889+
stmt: StatementId::Root(2.into()),
890890
text: "select id,test from users".to_string()
891891
})
892892
);
893893
assert_eq!(
894894
changed[3],
895895
StatementChange::Added(AddedStatement {
896-
stmt: StatementId::Root(root_id(3)),
896+
stmt: StatementId::Root(3.into()),
897897
text: "select 1;".to_string()
898898
})
899899
);

crates/pgt_workspace/src/workspace/server/parsed_document.rs

+6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use pgt_text_size::{TextRange, TextSize};
88
use crate::workspace::ChangeFileParams;
99

1010
use super::{
11+
annotation::AnnotationStore,
1112
change::StatementChange,
1213
document::{Document, StatementIterator},
1314
pg_query::PgQueryStore,
@@ -24,6 +25,7 @@ pub struct ParsedDocument {
2425
ast_db: PgQueryStore,
2526
cst_db: TreeSitterStore,
2627
sql_fn_db: SQLFunctionBodyStore,
28+
annotation_db: AnnotationStore,
2729
}
2830

2931
impl ParsedDocument {
@@ -33,6 +35,7 @@ impl ParsedDocument {
3335
let cst_db = TreeSitterStore::new();
3436
let ast_db = PgQueryStore::new();
3537
let sql_fn_db = SQLFunctionBodyStore::new();
38+
let annotation_db = AnnotationStore::new();
3639

3740
doc.iter().for_each(|(stmt, _, content)| {
3841
cst_db.add_statement(&stmt, content);
@@ -44,6 +47,7 @@ impl ParsedDocument {
4447
ast_db,
4548
cst_db,
4649
sql_fn_db,
50+
annotation_db,
4751
}
4852
}
4953

@@ -69,6 +73,7 @@ impl ParsedDocument {
6973
self.cst_db.remove_statement(s);
7074
self.ast_db.clear_statement(s);
7175
self.sql_fn_db.clear_statement(s);
76+
self.annotation_db.clear_statement(s);
7277
}
7378
StatementChange::Modified(s) => {
7479
tracing::debug!(
@@ -84,6 +89,7 @@ impl ParsedDocument {
8489
self.cst_db.modify_statement(s);
8590
self.ast_db.clear_statement(&s.old_stmt);
8691
self.sql_fn_db.clear_statement(&s.old_stmt);
92+
self.annotation_db.clear_statement(&s.old_stmt);
8793
}
8894
}
8995
}

crates/pgt_workspace/src/workspace/server/statement_identifier.rs

+11-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,17 @@ pub struct RootId {
77
}
88

99
#[cfg(test)]
10-
pub fn root_id(inner: usize) -> RootId {
11-
RootId { inner }
10+
impl From<RootId> for usize {
11+
fn from(val: RootId) -> Self {
12+
val.inner
13+
}
14+
}
15+
16+
#[cfg(test)]
17+
impl From<usize> for RootId {
18+
fn from(inner: usize) -> Self {
19+
RootId { inner }
20+
}
1221
}
1322

1423
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]

0 commit comments

Comments
 (0)