Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] remove layer of indirection from SymbolTable #11132

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 63 additions & 64 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,41 +101,30 @@ pub(crate) struct ImportFromDefinition {
level: Option<u32>,
}

/// Table of all symbols in all scopes for a module.
#[derive(Debug)]
pub(crate) struct Symbols {
pub(crate) table: SymbolTable,
pub(crate) defs: FxHashMap<SymbolId, Vec<Definition>>,
pub(crate) struct SymbolTable {
scopes_by_id: IndexVec<ScopeId, Scope>,
symbols_by_id: IndexVec<SymbolId, Symbol>,
defs: FxHashMap<SymbolId, Vec<Definition>>,
}

impl Symbols {
impl SymbolTable {
pub(crate) fn from_ast(module: &ast::ModModule) -> Self {
let symbols = Symbols {
table: SymbolTable::new(),
defs: FxHashMap::default(),
};
let root_scope_id = SymbolTable::root_scope_id();
let mut builder = SymbolsBuilder {
symbols,
let mut builder = SymbolTableBuilder {
table: SymbolTable::new(),
scopes: vec![root_scope_id],
};
builder.visit_body(&module.body);
builder.symbols
builder.table
}
}

/// Table of all symbols in all scopes for a module.
/// Derived from module AST, but holds no references to it.
#[derive(Debug)]
pub(crate) struct SymbolTable {
scopes_by_id: IndexVec<ScopeId, Scope>,
symbols_by_id: IndexVec<SymbolId, Symbol>,
}

impl SymbolTable {
pub(crate) fn new() -> Self {
let mut table = SymbolTable {
scopes_by_id: IndexVec::new(),
symbols_by_id: IndexVec::new(),
defs: FxHashMap::default(),
};
table.scopes_by_id.push(Scope {
name: Name::new("<module>"),
Expand Down Expand Up @@ -218,7 +207,14 @@ impl SymbolTable {
self.symbol_by_name(SymbolTable::root_scope_id(), name)
}

pub(crate) fn add_symbol_to_scope(&mut self, scope_id: ScopeId, name: &str) -> SymbolId {
pub(crate) fn defs(&self, symbol_id: SymbolId) -> &[Definition] {
self.defs
.get(&symbol_id)
.map(|defs| defs.as_slice())
.unwrap_or_default()
}

fn add_symbol_to_scope(&mut self, scope_id: ScopeId, name: &str) -> SymbolId {
let hash = SymbolTable::hash_name(name);
let scope = &mut self.scopes_by_id[scope_id];
let name = Name::new(name);
Expand All @@ -238,7 +234,7 @@ impl SymbolTable {
}
}

pub(crate) fn add_child_scope(
fn add_child_scope(
&mut self,
parent_scope_id: ScopeId,
name: &str,
Expand Down Expand Up @@ -331,21 +327,19 @@ where
}
}

struct SymbolsBuilder {
symbols: Symbols,
struct SymbolTableBuilder {
table: SymbolTable,
scopes: Vec<ScopeId>,
}

impl SymbolsBuilder {
impl SymbolTableBuilder {
fn add_symbol(&mut self, identifier: &str) -> SymbolId {
self.symbols
.table
.add_symbol_to_scope(self.cur_scope(), identifier)
self.table.add_symbol_to_scope(self.cur_scope(), identifier)
}

fn add_symbol_with_def(&mut self, identifier: &str, definition: Definition) -> SymbolId {
let symbol_id = self.add_symbol(identifier);
self.symbols
self.table
.defs
.entry(symbol_id)
.or_default()
Expand All @@ -354,7 +348,7 @@ impl SymbolsBuilder {
}

fn push_scope(&mut self, child_of: ScopeId, name: &str, kind: ScopeKind) -> ScopeId {
let scope_id = self.symbols.table.add_child_scope(child_of, name, kind);
let scope_id = self.table.add_child_scope(child_of, name, kind);
self.scopes.push(scope_id);
scope_id
}
Expand Down Expand Up @@ -396,7 +390,7 @@ impl SymbolsBuilder {
}
}

impl PreorderVisitor<'_> for SymbolsBuilder {
impl PreorderVisitor<'_> for SymbolTableBuilder {
fn visit_expr(&mut self, expr: &ast::Expr) {
if let ast::Expr::Name(ast::ExprName { id, .. }) = expr {
self.add_symbol(id);
Expand Down Expand Up @@ -472,7 +466,7 @@ mod tests {
use crate::parse::Parsed;
use crate::symbols::ScopeKind;

use super::{SymbolId, SymbolIterator, SymbolTable, Symbols};
use super::{SymbolId, SymbolIterator, SymbolTable};

mod from_ast {
use super::*;
Expand All @@ -493,61 +487,65 @@ mod tests {
#[test]
fn empty() {
let parsed = parse("");
let table = Symbols::from_ast(parsed.ast()).table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()).len(), 0);
}

#[test]
fn simple() {
let parsed = parse("x");
let syms = Symbols::from_ast(parsed.ast());
assert_eq!(names(syms.table.root_symbols()), vec!["x"]);
assert!(syms
.defs
.get(&syms.table.root_symbol_id_by_name("x").unwrap())
.is_none());
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["x"]);
assert_eq!(
table.defs(table.root_symbol_id_by_name("x").unwrap()).len(),
0
);
}

#[test]
fn annotation_only() {
let parsed = parse("x: int");
let table = Symbols::from_ast(parsed.ast()).table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["int", "x"]);
// TODO record definition
}

#[test]
fn import() {
let parsed = parse("import foo");
let syms = Symbols::from_ast(parsed.ast());
assert_eq!(names(syms.table.root_symbols()), vec!["foo"]);
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["foo"]);
assert_eq!(
syms.defs[&syms.table.root_symbol_id_by_name("foo").unwrap()].len(),
table
.defs(table.root_symbol_id_by_name("foo").unwrap())
.len(),
1
);
}

#[test]
fn import_sub() {
let parsed = parse("import foo.bar");
let table = Symbols::from_ast(parsed.ast()).table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["foo"]);
}

#[test]
fn import_as() {
let parsed = parse("import foo.bar as baz");
let table = Symbols::from_ast(parsed.ast()).table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["baz"]);
}

#[test]
fn import_from() {
let parsed = parse("from bar import foo");
let syms = Symbols::from_ast(parsed.ast());
assert_eq!(names(syms.table.root_symbols()), vec!["foo"]);
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["foo"]);
assert_eq!(
syms.defs[&syms.table.root_symbol_id_by_name("foo").unwrap()].len(),
table
.defs(table.root_symbol_id_by_name("foo").unwrap())
.len(),
1
);
}
Expand All @@ -561,17 +559,16 @@ mod tests {
y = 2
",
);
let syms = Symbols::from_ast(parsed.ast());
let table = &syms.table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["C", "y"]);
let scopes = table.root_child_scope_ids();
assert_eq!(scopes.len(), 1);
let c_scope = scopes[0].scope(table);
let c_scope = scopes[0].scope(&table);
assert_eq!(c_scope.kind(), ScopeKind::Class);
assert_eq!(c_scope.name(), "C");
assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]);
assert_eq!(
syms.defs[&syms.table.root_symbol_id_by_name("C").unwrap()].len(),
table.defs(table.root_symbol_id_by_name("C").unwrap()).len(),
1
);
}
Expand All @@ -585,17 +582,18 @@ mod tests {
y = 2
",
);
let syms = Symbols::from_ast(parsed.ast());
let table = &syms.table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["func", "y"]);
let scopes = table.root_child_scope_ids();
assert_eq!(scopes.len(), 1);
let func_scope = scopes[0].scope(table);
let func_scope = scopes[0].scope(&table);
assert_eq!(func_scope.kind(), ScopeKind::Function);
assert_eq!(func_scope.name(), "func");
assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]);
assert_eq!(
syms.defs[&syms.table.root_symbol_id_by_name("func").unwrap()].len(),
table
.defs(table.root_symbol_id_by_name("func").unwrap())
.len(),
1
);
}
Expand All @@ -610,21 +608,22 @@ mod tests {
y = 2
",
);
let syms = Symbols::from_ast(parsed.ast());
let table = &syms.table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["func"]);
let scopes = table.root_child_scope_ids();
assert_eq!(scopes.len(), 2);
let func_scope_1 = scopes[0].scope(table);
let func_scope_2 = scopes[1].scope(table);
let func_scope_1 = scopes[0].scope(&table);
let func_scope_2 = scopes[1].scope(&table);
assert_eq!(func_scope_1.kind(), ScopeKind::Function);
assert_eq!(func_scope_1.name(), "func");
assert_eq!(func_scope_2.kind(), ScopeKind::Function);
assert_eq!(func_scope_2.name(), "func");
assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]);
assert_eq!(names(table.symbols_for_scope(scopes[1])), vec!["y"]);
assert_eq!(
syms.defs[&syms.table.root_symbol_id_by_name("func").unwrap()].len(),
table
.defs(table.root_symbol_id_by_name("func").unwrap())
.len(),
2
);
}
Expand All @@ -637,7 +636,7 @@ mod tests {
x = 1
",
);
let table = Symbols::from_ast(parsed.ast()).table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["func"]);
let scopes = table.root_child_scope_ids();
assert_eq!(scopes.len(), 1);
Expand All @@ -663,7 +662,7 @@ mod tests {
x = 1
",
);
let table = Symbols::from_ast(parsed.ast()).table;
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["C"]);
let scopes = table.root_child_scope_ids();
assert_eq!(scopes.len(), 1);
Expand Down
Loading