Skip to content

Commit 531aac3

Browse files
Decouple rule finding logic from LSP crate as requested
Co-authored-by: HerringtonDarkholme <[email protected]>
1 parent 683f20e commit 531aac3

File tree

4 files changed

+136
-90
lines changed

4 files changed

+136
-90
lines changed

crates/cli/src/config.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,21 @@ impl ProjectConfig {
8989
let global_rules = find_util_rules(self)?;
9090
read_directory_yaml(self, global_rules, rule_overwrite)
9191
}
92+
93+
/// Create a rule finding closure that can be used by LSP or other consumers
94+
/// This allows decoupling the rule finding logic from the specific consumers
95+
pub fn make_rule_finder<L>(self) -> impl Fn() -> anyhow::Result<RuleCollection<L>> + Send + Sync + 'static
96+
where
97+
L: ast_grep_core::Language + serde::de::DeserializeOwned + Clone + std::cmp::Eq + Send + Sync + 'static,
98+
{
99+
move || {
100+
let global_rules = find_util_rules_generic::<L>(&self.project_dir, &self.util_dirs)?;
101+
let configs = read_directory_yaml_generic::<L>(&self.project_dir, &self.rule_dirs, global_rules)?;
102+
let collection = RuleCollection::try_new(configs).context(EC::GlobPattern)?;
103+
Ok(collection)
104+
}
105+
}
106+
92107
/// returns a Result of Result.
93108
/// The inner Result is for configuration not found, or ProjectNotExist
94109
/// The outer Result is for definitely wrong config.
@@ -160,6 +175,39 @@ fn find_util_rules(config: &ProjectConfig) -> Result<GlobalRules> {
160175
Ok(ret)
161176
}
162177

178+
/// Generic version of find_util_rules that works with any language type
179+
fn find_util_rules_generic<L>(
180+
project_dir: &Path,
181+
util_dirs: &Option<Vec<PathBuf>>,
182+
) -> Result<GlobalRules>
183+
where
184+
L: ast_grep_core::Language + serde::de::DeserializeOwned + Clone + std::cmp::Eq + Send + Sync + 'static,
185+
{
186+
let Some(mut walker) = build_util_walker(project_dir, util_dirs) else {
187+
return Ok(GlobalRules::default());
188+
};
189+
let mut utils = vec![];
190+
let walker = walker.types(config_file_type()).build();
191+
for dir in walker {
192+
let config_file = dir.with_context(|| EC::WalkRuleDir(PathBuf::new()))?;
193+
// file_type is None only if it is stdin, safe to panic here
194+
if !config_file
195+
.file_type()
196+
.expect("file type should be available for non-stdin")
197+
.is_file()
198+
{
199+
continue;
200+
}
201+
let path = config_file.path();
202+
let file = read_to_string(path)?;
203+
let new_configs = from_str(&file)?;
204+
utils.push(new_configs);
205+
}
206+
207+
let ret = DeserializeEnv::<L>::parse_global_utils(utils).context(EC::InvalidGlobalUtils)?;
208+
Ok(ret)
209+
}
210+
163211
fn read_directory_yaml(
164212
config: &ProjectConfig,
165213
global_rules: GlobalRules,
@@ -204,6 +252,39 @@ fn read_directory_yaml(
204252
Ok((collection, trace))
205253
}
206254

255+
/// Generic version of read_directory_yaml that works with any language type
256+
fn read_directory_yaml_generic<L>(
257+
project_dir: &Path,
258+
rule_dirs: &[PathBuf],
259+
global_rules: GlobalRules,
260+
) -> Result<Vec<RuleConfig<L>>>
261+
where
262+
L: ast_grep_core::Language + serde::de::DeserializeOwned + Clone + std::cmp::Eq + Send + Sync + 'static,
263+
{
264+
let mut configs = vec![];
265+
for dir in rule_dirs {
266+
let dir_path = project_dir.join(dir);
267+
let walker = WalkBuilder::new(&dir_path)
268+
.types(config_file_type())
269+
.build();
270+
for dir in walker {
271+
let config_file = dir.with_context(|| EC::WalkRuleDir(dir_path.clone()))?;
272+
// file_type is None only if it is stdin, safe to panic here
273+
if !config_file
274+
.file_type()
275+
.expect("file type should be available for non-stdin")
276+
.is_file()
277+
{
278+
continue;
279+
}
280+
let path = config_file.path();
281+
let new_configs = read_rule_file_generic::<L>(path, Some(&global_rules))?;
282+
configs.extend(new_configs);
283+
}
284+
}
285+
Ok(configs)
286+
}
287+
207288
pub fn with_rule_stats(
208289
configs: Vec<RuleConfig<SgLang>>,
209290
) -> Result<(RuleCollection<SgLang>, RuleTrace)> {
@@ -231,6 +312,23 @@ pub fn read_rule_file(
231312
parsed.with_context(|| EC::ParseRule(path.to_path_buf()))
232313
}
233314

315+
/// Generic version of read_rule_file that works with any language type
316+
pub fn read_rule_file_generic<L>(
317+
path: &Path,
318+
global_rules: Option<&GlobalRules>,
319+
) -> Result<Vec<RuleConfig<L>>>
320+
where
321+
L: ast_grep_core::Language + serde::de::DeserializeOwned + Clone + std::cmp::Eq + Send + Sync + 'static,
322+
{
323+
let yaml = read_to_string(path).with_context(|| EC::ReadRule(path.to_path_buf()))?;
324+
let parsed = if let Some(globals) = global_rules {
325+
from_yaml_string(&yaml, globals)
326+
} else {
327+
from_yaml_string(&yaml, &Default::default())
328+
};
329+
parsed.with_context(|| EC::ParseRule(path.to_path_buf()))
330+
}
331+
234332
const CONFIG_FILE: &str = "sgconfig.yml";
235333

236334
/// return None if config file does not exist

crates/cli/src/lsp.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,14 @@ async fn run_language_server_impl(_arg: LspArg, project: Result<ProjectConfig>)
1313
let project_config = project?;
1414
let stdin = tokio::io::stdin();
1515
let stdout = tokio::io::stdout();
16-
let config_result = project_config.find_rules(Default::default());
17-
let config_result_std: std::result::Result<_, String> = config_result
18-
.map_err(|e| {
19-
// convert anyhow::Error to String with chain of causes
20-
e.chain()
21-
.map(|e| e.to_string())
22-
.collect::<Vec<_>>()
23-
.join(". ")
24-
})
25-
.map(|r| r.0);
26-
let config_base = project_config.project_dir;
16+
17+
let config_base = project_config.project_dir.clone();
18+
19+
// Create a rule finder closure that uses the CLI logic
20+
let rule_finder = project_config.make_rule_finder::<crate::lang::SgLang>();
21+
2722
let (service, socket) =
28-
LspService::build(|client| Backend::new(client, config_base, config_result_std)).finish();
23+
LspService::build(|client| Backend::new(client, config_base, rule_finder)).finish();
2924
Server::new(stdin, stdout, socket).serve(service).await;
3025
Ok(())
3126
}

crates/lsp/src/lib.rs

Lines changed: 14 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ pub struct Backend<L: LSPLang> {
4343
errors: Arc<RwLock<Option<String>>>,
4444
// interner for rule ids to note, to avoid duplication
4545
interner: DashMap<String, Arc<String>>,
46+
// rule finding closure to reload rules
47+
rule_finder: Box<dyn Fn() -> anyhow::Result<RuleCollection<L>> + Send + Sync>,
4648
}
4749

4850
const FALLBACK_CODE_ACTION_PROVIDER: Option<CodeActionProviderCapability> =
@@ -247,14 +249,17 @@ fn pos_tuple_to_range((line, character, end_line, end_character): (u32, u32, u32
247249
}
248250

249251
impl<L: LSPLang> Backend<L> {
250-
pub fn new(
252+
pub fn new<F>(
251253
client: Client,
252254
base: PathBuf,
253-
rules: std::result::Result<RuleCollection<L>, String>,
254-
) -> Self {
255-
let (rules, errors) = match rules {
255+
rule_finder: F,
256+
) -> Self
257+
where
258+
F: Fn() -> anyhow::Result<RuleCollection<L>> + Send + Sync + 'static,
259+
{
260+
let (rules, errors) = match rule_finder() {
256261
Ok(r) => (r, None),
257-
Err(e) => (RuleCollection::default(), Some(e)),
262+
Err(e) => (RuleCollection::default(), Some(e.to_string())),
258263
};
259264
Self {
260265
client,
@@ -263,6 +268,7 @@ impl<L: LSPLang> Backend<L> {
263268
map: DashMap::new(),
264269
errors: Arc::new(RwLock::new(errors)),
265270
interner: DashMap::new(),
271+
rule_finder: Box::new(rule_finder),
266272
}
267273
}
268274

@@ -657,8 +663,8 @@ impl<L: LSPLang> Backend<L> {
657663
.log_message(MessageType::INFO, "Starting rule reload...")
658664
.await;
659665

660-
// Try to reload rules from the file system
661-
let result = self.load_rules_from_filesystem().await;
666+
// Use the rule finder closure to reload rules
667+
let result = (self.rule_finder)();
662668

663669
match result {
664670
Ok(new_rules) => {
@@ -676,7 +682,7 @@ impl<L: LSPLang> Backend<L> {
676682

677683
self
678684
.client
679-
.log_message(MessageType::INFO, "Rules reloaded successfully from filesystem")
685+
.log_message(MessageType::INFO, "Rules reloaded successfully using CLI logic")
680686
.await;
681687
}
682688
Err(e) => {
@@ -702,64 +708,6 @@ impl<L: LSPLang> Backend<L> {
702708
Ok(())
703709
}
704710

705-
/// Load rules from the filesystem - simplified version of CLI config loading
706-
async fn load_rules_from_filesystem(&self) -> anyhow::Result<RuleCollection<L>> {
707-
use ast_grep_config::{from_yaml_string, GlobalRules};
708-
use ast_grep_language::config_file_type;
709-
use ignore::WalkBuilder;
710-
use std::fs::read_to_string;
711-
712-
// Look for sgconfig.yml in the base directory
713-
let config_path = self.base.join("sgconfig.yml");
714-
715-
let rule_dirs = if config_path.exists() {
716-
let config_content = read_to_string(&config_path)?;
717-
let config: serde_yaml::Value = serde_yaml::from_str(&config_content)?;
718-
719-
if let Some(rule_dirs) = config.get("ruleDirs").and_then(|v| v.as_sequence()) {
720-
rule_dirs
721-
.iter()
722-
.filter_map(|v| v.as_str())
723-
.map(|s| self.base.join(s))
724-
.collect::<Vec<_>>()
725-
} else {
726-
vec![self.base.join("rules")] // Default rules directory
727-
}
728-
} else {
729-
vec![self.base.join("rules")] // Default rules directory
730-
};
731-
732-
// Read all rule files
733-
let mut configs = Vec::new();
734-
let global_rules = GlobalRules::default(); // Simplified - no util rules for now
735-
736-
for rule_dir in rule_dirs {
737-
if !rule_dir.exists() {
738-
continue;
739-
}
740-
741-
let walker = WalkBuilder::new(&rule_dir)
742-
.types(config_file_type())
743-
.build();
744-
745-
for entry in walker {
746-
let entry = entry?;
747-
if !entry.file_type().unwrap().is_file() {
748-
continue;
749-
}
750-
751-
let path = entry.path();
752-
let yaml = read_to_string(path)?;
753-
let parsed = from_yaml_string(&yaml, &global_rules)?;
754-
configs.extend(parsed);
755-
}
756-
}
757-
758-
// Create the rule collection
759-
let collection = RuleCollection::try_new(configs)?;
760-
Ok(collection)
761-
}
762-
763711
/// Republish diagnostics for all currently open files
764712
async fn republish_all_diagnostics(&self) {
765713
// Get all currently open file URIs

crates/lsp/tests/basic.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,13 @@ fn req_resp_should_work() {
5656
}
5757

5858
pub fn create_lsp() -> (DuplexStream, DuplexStream) {
59-
let globals = GlobalRules::default();
60-
let config: RuleConfig<SupportLang> = from_yaml_string(
61-
r"
59+
let base = Path::new("./").to_path_buf();
60+
61+
// Create a rule finder closure that builds the rule collection from scratch
62+
let rule_finder = move || {
63+
let globals = GlobalRules::default();
64+
let config: RuleConfig<SupportLang> = from_yaml_string(
65+
r"
6266
id: no-console-rule
6367
message: No console.log
6468
severity: warning
@@ -69,16 +73,17 @@ note: no console.log
6973
fix: |
7074
alert($$$A)
7175
",
72-
&globals,
73-
)
74-
.unwrap()
75-
.pop()
76-
.unwrap();
77-
let base = Path::new("./").to_path_buf();
78-
let rc: RuleCollection<SupportLang> = RuleCollection::try_new(vec![config]).unwrap();
79-
let rc_result: std::result::Result<_, String> = Ok(rc);
76+
&globals,
77+
)
78+
.unwrap()
79+
.pop()
80+
.unwrap();
81+
let rc: RuleCollection<SupportLang> = RuleCollection::try_new(vec![config]).unwrap();
82+
Ok(rc)
83+
};
84+
8085
let (service, socket) =
81-
LspService::build(|client| Backend::new(client, base, rc_result)).finish();
86+
LspService::build(|client| Backend::new(client, base, rule_finder)).finish();
8287
let (req_client, req_server) = duplex(1024);
8388
let (resp_server, resp_client) = duplex(1024);
8489

0 commit comments

Comments
 (0)