From adce3672f89ec63cc00a5b28256aeff44ba3a1a5 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 24 Apr 2024 09:57:19 -0600 Subject: [PATCH] simplify SymbolTable --- crates/red_knot/src/symbols.rs | 127 ++++++++++++++++----------------- 1 file changed, 63 insertions(+), 64 deletions(-) diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 15a7656364eec9..10be2e48fd18bf 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -101,41 +101,30 @@ pub(crate) struct ImportFromDefinition { level: Option, } +/// Table of all symbols in all scopes for a module. #[derive(Debug)] -pub(crate) struct Symbols { - pub(crate) table: SymbolTable, - pub(crate) defs: FxHashMap>, +pub(crate) struct SymbolTable { + scopes_by_id: IndexVec, + symbols_by_id: IndexVec, + defs: FxHashMap>, } -impl Symbols { +impl SymbolTable { pub(crate) fn from_ast(module: &ast::ModModule) -> Self { - let symbols = Symbols { - table: SymbolTable::new(), - defs: FxHashMap::default(), - }; let root_scope_id = SymbolTable::root_scope_id(); - let mut builder = SymbolsBuilder { - symbols, + let mut builder = SymbolTableBuilder { + table: SymbolTable::new(), scopes: vec![root_scope_id], }; builder.visit_body(&module.body); - builder.symbols + builder.table } -} -/// Table of all symbols in all scopes for a module. -/// Derived from module AST, but holds no references to it. -#[derive(Debug)] -pub(crate) struct SymbolTable { - scopes_by_id: IndexVec, - symbols_by_id: IndexVec, -} - -impl SymbolTable { pub(crate) fn new() -> Self { let mut table = SymbolTable { scopes_by_id: IndexVec::new(), symbols_by_id: IndexVec::new(), + defs: FxHashMap::default(), }; table.scopes_by_id.push(Scope { name: Name::new(""), @@ -218,7 +207,14 @@ impl SymbolTable { self.symbol_by_name(SymbolTable::root_scope_id(), name) } - pub(crate) fn add_symbol_to_scope(&mut self, scope_id: ScopeId, name: &str) -> SymbolId { + pub(crate) fn defs(&self, symbol_id: SymbolId) -> &[Definition] { + self.defs + .get(&symbol_id) + .map(|defs| defs.as_slice()) + .unwrap_or_default() + } + + fn add_symbol_to_scope(&mut self, scope_id: ScopeId, name: &str) -> SymbolId { let hash = SymbolTable::hash_name(name); let scope = &mut self.scopes_by_id[scope_id]; let name = Name::new(name); @@ -238,7 +234,7 @@ impl SymbolTable { } } - pub(crate) fn add_child_scope( + fn add_child_scope( &mut self, parent_scope_id: ScopeId, name: &str, @@ -331,21 +327,19 @@ where } } -struct SymbolsBuilder { - symbols: Symbols, +struct SymbolTableBuilder { + table: SymbolTable, scopes: Vec, } -impl SymbolsBuilder { +impl SymbolTableBuilder { fn add_symbol(&mut self, identifier: &str) -> SymbolId { - self.symbols - .table - .add_symbol_to_scope(self.cur_scope(), identifier) + self.table.add_symbol_to_scope(self.cur_scope(), identifier) } fn add_symbol_with_def(&mut self, identifier: &str, definition: Definition) -> SymbolId { let symbol_id = self.add_symbol(identifier); - self.symbols + self.table .defs .entry(symbol_id) .or_default() @@ -354,7 +348,7 @@ impl SymbolsBuilder { } fn push_scope(&mut self, child_of: ScopeId, name: &str, kind: ScopeKind) -> ScopeId { - let scope_id = self.symbols.table.add_child_scope(child_of, name, kind); + let scope_id = self.table.add_child_scope(child_of, name, kind); self.scopes.push(scope_id); scope_id } @@ -396,7 +390,7 @@ impl SymbolsBuilder { } } -impl PreorderVisitor<'_> for SymbolsBuilder { +impl PreorderVisitor<'_> for SymbolTableBuilder { fn visit_expr(&mut self, expr: &ast::Expr) { if let ast::Expr::Name(ast::ExprName { id, .. }) = expr { self.add_symbol(id); @@ -472,7 +466,7 @@ mod tests { use crate::parse::Parsed; use crate::symbols::ScopeKind; - use super::{SymbolId, SymbolIterator, SymbolTable, Symbols}; + use super::{SymbolId, SymbolIterator, SymbolTable}; mod from_ast { use super::*; @@ -493,25 +487,25 @@ mod tests { #[test] fn empty() { let parsed = parse(""); - let table = Symbols::from_ast(parsed.ast()).table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()).len(), 0); } #[test] fn simple() { let parsed = parse("x"); - let syms = Symbols::from_ast(parsed.ast()); - assert_eq!(names(syms.table.root_symbols()), vec!["x"]); - assert!(syms - .defs - .get(&syms.table.root_symbol_id_by_name("x").unwrap()) - .is_none()); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["x"]); + assert_eq!( + table.defs(table.root_symbol_id_by_name("x").unwrap()).len(), + 0 + ); } #[test] fn annotation_only() { let parsed = parse("x: int"); - let table = Symbols::from_ast(parsed.ast()).table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()), vec!["int", "x"]); // TODO record definition } @@ -519,10 +513,12 @@ mod tests { #[test] fn import() { let parsed = parse("import foo"); - let syms = Symbols::from_ast(parsed.ast()); - assert_eq!(names(syms.table.root_symbols()), vec!["foo"]); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["foo"]); assert_eq!( - syms.defs[&syms.table.root_symbol_id_by_name("foo").unwrap()].len(), + table + .defs(table.root_symbol_id_by_name("foo").unwrap()) + .len(), 1 ); } @@ -530,24 +526,26 @@ mod tests { #[test] fn import_sub() { let parsed = parse("import foo.bar"); - let table = Symbols::from_ast(parsed.ast()).table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()), vec!["foo"]); } #[test] fn import_as() { let parsed = parse("import foo.bar as baz"); - let table = Symbols::from_ast(parsed.ast()).table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()), vec!["baz"]); } #[test] fn import_from() { let parsed = parse("from bar import foo"); - let syms = Symbols::from_ast(parsed.ast()); - assert_eq!(names(syms.table.root_symbols()), vec!["foo"]); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["foo"]); assert_eq!( - syms.defs[&syms.table.root_symbol_id_by_name("foo").unwrap()].len(), + table + .defs(table.root_symbol_id_by_name("foo").unwrap()) + .len(), 1 ); } @@ -561,17 +559,16 @@ mod tests { y = 2 ", ); - let syms = Symbols::from_ast(parsed.ast()); - let table = &syms.table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()), vec!["C", "y"]); let scopes = table.root_child_scope_ids(); assert_eq!(scopes.len(), 1); - let c_scope = scopes[0].scope(table); + let c_scope = scopes[0].scope(&table); assert_eq!(c_scope.kind(), ScopeKind::Class); assert_eq!(c_scope.name(), "C"); assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); assert_eq!( - syms.defs[&syms.table.root_symbol_id_by_name("C").unwrap()].len(), + table.defs(table.root_symbol_id_by_name("C").unwrap()).len(), 1 ); } @@ -585,17 +582,18 @@ mod tests { y = 2 ", ); - let syms = Symbols::from_ast(parsed.ast()); - let table = &syms.table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()), vec!["func", "y"]); let scopes = table.root_child_scope_ids(); assert_eq!(scopes.len(), 1); - let func_scope = scopes[0].scope(table); + let func_scope = scopes[0].scope(&table); assert_eq!(func_scope.kind(), ScopeKind::Function); assert_eq!(func_scope.name(), "func"); assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); assert_eq!( - syms.defs[&syms.table.root_symbol_id_by_name("func").unwrap()].len(), + table + .defs(table.root_symbol_id_by_name("func").unwrap()) + .len(), 1 ); } @@ -610,13 +608,12 @@ mod tests { y = 2 ", ); - let syms = Symbols::from_ast(parsed.ast()); - let table = &syms.table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()), vec!["func"]); let scopes = table.root_child_scope_ids(); assert_eq!(scopes.len(), 2); - let func_scope_1 = scopes[0].scope(table); - let func_scope_2 = scopes[1].scope(table); + let func_scope_1 = scopes[0].scope(&table); + let func_scope_2 = scopes[1].scope(&table); assert_eq!(func_scope_1.kind(), ScopeKind::Function); assert_eq!(func_scope_1.name(), "func"); assert_eq!(func_scope_2.kind(), ScopeKind::Function); @@ -624,7 +621,9 @@ mod tests { assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); assert_eq!(names(table.symbols_for_scope(scopes[1])), vec!["y"]); assert_eq!( - syms.defs[&syms.table.root_symbol_id_by_name("func").unwrap()].len(), + table + .defs(table.root_symbol_id_by_name("func").unwrap()) + .len(), 2 ); } @@ -637,7 +636,7 @@ mod tests { x = 1 ", ); - let table = Symbols::from_ast(parsed.ast()).table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()), vec!["func"]); let scopes = table.root_child_scope_ids(); assert_eq!(scopes.len(), 1); @@ -663,7 +662,7 @@ mod tests { x = 1 ", ); - let table = Symbols::from_ast(parsed.ast()).table; + let table = SymbolTable::from_ast(parsed.ast()); assert_eq!(names(table.root_symbols()), vec!["C"]); let scopes = table.root_child_scope_ids(); assert_eq!(scopes.len(), 1);