diff --git a/README.md b/README.md index afb65843..30c35ac8 100644 --- a/README.md +++ b/README.md @@ -368,16 +368,16 @@ program = SyntaxTree.parse("1 + 1") puts program.construct_keys # SyntaxTree::Program[ -# statements: SyntaxTree::Statements[ -# body: [ -# SyntaxTree::Binary[ -# left: SyntaxTree::Int[value: "1"], -# operator: :+, -# right: SyntaxTree::Int[value: "1"] -# ] -# ] -# ] -# ] +# statements: SyntaxTree::Statements[ +# body: [ +# SyntaxTree::Binary[ +# left: SyntaxTree::Int[value: "1"], +# operator: :+, +# right: SyntaxTree::Int[value: "1"] +# ] +# ] +# ] +# ] ``` ## Visitor @@ -447,6 +447,28 @@ end The visitor defined above will error out unless it's only visiting a `SyntaxTree::Int` node. This is useful in a couple of ways, e.g., if you're trying to define a visitor to handle the whole tree but it's currently a work-in-progress. +### WithEnvironment + +The `WithEnvironment` module can be included in visitors to automatically keep track of local variables and arguments +defined inside each environment. A `current_environment` accessor is made availble to the request, allowing it to find +all usages and definitions of a local. + +```ruby +class MyVisitor < Visitor + include WithEnvironment + + def visit_ident(node) + # find_local will return a Local for any local variables or arguments present in the current environment or nil if + # the identifier is not a local + local = current_environment.find_local(node) + + puts local.type # print the type of the local (:variable or :argument) + puts local.definitions # print the array of locations where this local is defined + puts local.usages # print the array of locations where this local occurs + end +end +``` + ## Language server Syntax Tree additionally ships with a language server conforming to the [language server protocol](https://microsoft.github.io/language-server-protocol/). It can be invoked through the CLI by running: diff --git a/lib/syntax_tree.rb b/lib/syntax_tree.rb index 88c66369..7861ddcd 100644 --- a/lib/syntax_tree.rb +++ b/lib/syntax_tree.rb @@ -19,6 +19,8 @@ require_relative "syntax_tree/visitor/json_visitor" require_relative "syntax_tree/visitor/match_visitor" require_relative "syntax_tree/visitor/pretty_print_visitor" +require_relative "syntax_tree/visitor/environment" +require_relative "syntax_tree/visitor/with_environment" # Syntax Tree is a suite of tools built on top of the internal CRuby parser. It # provides the ability to generate a syntax tree from source, as well as the diff --git a/lib/syntax_tree/visitor/environment.rb b/lib/syntax_tree/visitor/environment.rb new file mode 100644 index 00000000..dfcf0a80 --- /dev/null +++ b/lib/syntax_tree/visitor/environment.rb @@ -0,0 +1,81 @@ +# frozen_string_literal: true + +module SyntaxTree + # The environment class is used to keep track of local variables and arguments + # inside a particular scope + class Environment + # [Array[Local]] The local variables and arguments defined in this + # environment + attr_reader :locals + + # This class tracks the occurrences of a local variable or argument + class Local + # [Symbol] The type of the local (e.g. :argument, :variable) + attr_reader :type + + # [Array[Location]] The locations of all definitions and assignments of + # this local + attr_reader :definitions + + # [Array[Location]] The locations of all usages of this local + attr_reader :usages + + # initialize: (Symbol type) -> void + def initialize(type) + @type = type + @definitions = [] + @usages = [] + end + + # add_definition: (Location location) -> void + def add_definition(location) + @definitions << location + end + + # add_usage: (Location location) -> void + def add_usage(location) + @usages << location + end + end + + # initialize: (Environment | nil parent) -> void + def initialize(parent = nil) + @locals = {} + @parent = parent + end + + # Adding a local definition will either insert a new entry in the locals + # hash or append a new definition location to an existing local. Notice that + # it's not possible to change the type of a local after it has been + # registered + # add_local_definition: (Ident | Label identifier, Symbol type) -> void + def add_local_definition(identifier, type) + name = identifier.value.delete_suffix(":") + + @locals[name] ||= Local.new(type) + @locals[name].add_definition(identifier.location) + end + + # Adding a local usage will either insert a new entry in the locals + # hash or append a new usage location to an existing local. Notice that + # it's not possible to change the type of a local after it has been + # registered + # add_local_usage: (Ident | Label identifier, Symbol type) -> void + def add_local_usage(identifier, type) + name = identifier.value.delete_suffix(":") + + @locals[name] ||= Local.new(type) + @locals[name].add_usage(identifier.location) + end + + # Try to find the local given its name in this environment or any of its + # parents + # find_local: (String name) -> Local | nil + def find_local(name) + local = @locals[name] + return local unless local.nil? + + @parent&.find_local(name) + end + end +end diff --git a/lib/syntax_tree/visitor/with_environment.rb b/lib/syntax_tree/visitor/with_environment.rb new file mode 100644 index 00000000..62e59c98 --- /dev/null +++ b/lib/syntax_tree/visitor/with_environment.rb @@ -0,0 +1,141 @@ +# frozen_string_literal: true + +module SyntaxTree + # WithEnvironment is a module intended to be included in classes inheriting + # from Visitor. The module overrides a few visit methods to automatically keep + # track of local variables and arguments defined in the current environment. + # Example usage: + # class MyVisitor < Visitor + # include WithEnvironment + # + # def visit_ident(node) + # # Check if we're visiting an identifier for an argument, a local + # variable or something else + # local = current_environment.find_local(node) + # + # if local.type == :argument + # # handle identifiers for arguments + # elsif local.type == :variable + # # handle identifiers for variables + # else + # # handle other identifiers, such as method names + # end + # end + module WithEnvironment + def current_environment + @current_environment ||= Environment.new + end + + def with_new_environment + previous_environment = @current_environment + @current_environment = Environment.new(previous_environment) + yield + ensure + @current_environment = previous_environment + end + + # Visits for nodes that create new environments, such as classes, modules + # and method definitions + def visit_class(node) + with_new_environment { super } + end + + def visit_module(node) + with_new_environment { super } + end + + def visit_method_add_block(node) + with_new_environment { super } + end + + def visit_def(node) + with_new_environment { super } + end + + def visit_defs(node) + with_new_environment { super } + end + + def visit_def_endless(node) + with_new_environment { super } + end + + # Visit for keeping track of local arguments, such as method and block + # arguments + def visit_params(node) + node.requireds.each do |param| + @current_environment.add_local_definition(param, :argument) + end + + node.posts.each do |param| + @current_environment.add_local_definition(param, :argument) + end + + node.keywords.each do |param| + @current_environment.add_local_definition(param.first, :argument) + end + + node.optionals.each do |param| + @current_environment.add_local_definition(param.first, :argument) + end + + super + end + + def visit_rest_param(node) + name = node.name + @current_environment.add_local_definition(name, :argument) if name + + super + end + + def visit_kwrest_param(node) + name = node.name + @current_environment.add_local_definition(name, :argument) if name + + super + end + + def visit_blockarg(node) + name = node.name + @current_environment.add_local_definition(name, :argument) if name + + super + end + + # Visit for keeping track of local variable definitions + def visit_var_field(node) + value = node.value + + if value.is_a?(SyntaxTree::Ident) + @current_environment.add_local_definition(value, :variable) + end + + super + end + + alias visit_pinned_var_ref visit_var_field + + # Visits for keeping track of variable and argument usages + def visit_aref_field(node) + name = node.collection.value + @current_environment.add_local_usage(name, :variable) if name + + super + end + + def visit_var_ref(node) + value = node.value + + if value.is_a?(SyntaxTree::Ident) + definition = @current_environment.find_local(value.value) + + if definition + @current_environment.add_local_usage(value, definition.type) + end + end + + super + end + end +end diff --git a/test/visitor_with_environment_test.rb b/test/visitor_with_environment_test.rb new file mode 100644 index 00000000..915b2143 --- /dev/null +++ b/test/visitor_with_environment_test.rb @@ -0,0 +1,410 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class VisitorWithEnvironmentTest < Minitest::Test + class Collector < Visitor + include WithEnvironment + + attr_reader :variables, :arguments + + def initialize + @variables = {} + @arguments = {} + end + + def visit_ident(node) + local = current_environment.find_local(node.value) + return unless local + + value = node.value.delete_suffix(":") + + case local.type + when :argument + @arguments[value] = local + when :variable + @variables[value] = local + end + end + + def visit_label(node) + value = node.value.delete_suffix(":") + local = current_environment.find_local(value) + return unless local + + @arguments[value] = node if local.type == :argument + end + end + + def test_collecting_simple_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + a = 1 + a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.variables.length) + + variable = visitor.variables["a"] + assert_equal(1, variable.definitions.length) + assert_equal(1, variable.usages.length) + + assert_equal(2, variable.definitions[0].start_line) + assert_equal(3, variable.usages[0].start_line) + end + + def test_collecting_aref_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + a = [] + a[1] + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.variables.length) + + variable = visitor.variables["a"] + assert_equal(1, variable.definitions.length) + assert_equal(1, variable.usages.length) + + assert_equal(2, variable.definitions[0].start_line) + assert_equal(3, variable.usages[0].start_line) + end + + def test_collecting_multi_assign_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + a, b = [1, 2] + puts a + puts b + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(2, visitor.variables.length) + + variable_a = visitor.variables["a"] + assert_equal(1, variable_a.definitions.length) + assert_equal(1, variable_a.usages.length) + + assert_equal(2, variable_a.definitions[0].start_line) + assert_equal(3, variable_a.usages[0].start_line) + + variable_b = visitor.variables["b"] + assert_equal(1, variable_b.definitions.length) + assert_equal(1, variable_b.usages.length) + + assert_equal(2, variable_b.definitions[0].start_line) + assert_equal(4, variable_b.usages[0].start_line) + end + + def test_collecting_pattern_matching_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + case [1, 2] + in Integer => a, Integer + puts a + end + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + # There are two occurrences, one on line 3 for pinning and one on line 4 + # for reference + assert_equal(1, visitor.variables.length) + + variable = visitor.variables["a"] + + # Assignment a + assert_equal(3, variable.definitions[0].start_line) + assert_equal(4, variable.usages[0].start_line) + end + + def test_collecting_pinned_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + a = 18 + case [1, 2] + in ^a, *rest + puts a + puts rest + end + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(2, visitor.variables.length) + + variable_a = visitor.variables["a"] + assert_equal(2, variable_a.definitions.length) + assert_equal(1, variable_a.usages.length) + + assert_equal(2, variable_a.definitions[0].start_line) + assert_equal(4, variable_a.definitions[1].start_line) + assert_equal(5, variable_a.usages[0].start_line) + + variable_rest = visitor.variables["rest"] + assert_equal(1, variable_rest.definitions.length) + assert_equal(4, variable_rest.definitions[0].start_line) + + # Rest is considered a vcall by the parser instead of a var_ref + # assert_equal(1, variable_rest.usages.length) + # assert_equal(6, variable_rest.usages[0].start_line) + end + + if RUBY_VERSION >= "3.1" + def test_collecting_one_line_pattern_matching_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + [1] => a + puts a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.variables.length) + + variable = visitor.variables["a"] + assert_equal(1, variable.definitions.length) + assert_equal(1, variable.usages.length) + + assert_equal(2, variable.definitions[0].start_line) + assert_equal(3, variable.usages[0].start_line) + end + + def test_collecting_endless_method_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo(a) = puts a + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["a"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + + assert_equal(1, argument.definitions[0].start_line) + assert_equal(1, argument.usages[0].start_line) + end + end + + def test_collecting_method_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo(a) + puts a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["a"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + + assert_equal(1, argument.definitions[0].start_line) + assert_equal(2, argument.usages[0].start_line) + end + + def test_collecting_singleton_method_arguments + tree = SyntaxTree.parse(<<~RUBY) + def self.foo(a) + puts a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["a"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + + assert_equal(1, argument.definitions[0].start_line) + assert_equal(2, argument.usages[0].start_line) + end + + def test_collecting_method_arguments_all_types + tree = SyntaxTree.parse(<<~RUBY) + def foo(a, b = 1, *c, d, e: 1, **f, &block) + puts a + puts b + puts c + puts d + puts e + puts f + block.call + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(7, visitor.arguments.length) + + argument_a = visitor.arguments["a"] + assert_equal(1, argument_a.definitions.length) + assert_equal(1, argument_a.usages.length) + assert_equal(1, argument_a.definitions[0].start_line) + assert_equal(2, argument_a.usages[0].start_line) + + argument_b = visitor.arguments["b"] + assert_equal(1, argument_b.definitions.length) + assert_equal(1, argument_b.usages.length) + assert_equal(1, argument_b.definitions[0].start_line) + assert_equal(3, argument_b.usages[0].start_line) + + argument_c = visitor.arguments["c"] + assert_equal(1, argument_c.definitions.length) + assert_equal(1, argument_c.usages.length) + assert_equal(1, argument_c.definitions[0].start_line) + assert_equal(4, argument_c.usages[0].start_line) + + argument_d = visitor.arguments["d"] + assert_equal(1, argument_d.definitions.length) + assert_equal(1, argument_d.usages.length) + assert_equal(1, argument_d.definitions[0].start_line) + assert_equal(5, argument_d.usages[0].start_line) + + argument_e = visitor.arguments["e"] + assert_equal(1, argument_e.definitions.length) + assert_equal(1, argument_e.usages.length) + assert_equal(1, argument_e.definitions[0].start_line) + assert_equal(6, argument_e.usages[0].start_line) + + argument_f = visitor.arguments["f"] + assert_equal(1, argument_f.definitions.length) + assert_equal(1, argument_f.usages.length) + assert_equal(1, argument_f.definitions[0].start_line) + assert_equal(7, argument_f.usages[0].start_line) + + argument_block = visitor.arguments["block"] + assert_equal(1, argument_block.definitions.length) + assert_equal(1, argument_block.usages.length) + assert_equal(1, argument_block.definitions[0].start_line) + assert_equal(8, argument_block.usages[0].start_line) + end + + def test_collecting_block_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo + [].each do |i| + puts i + end + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["i"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + assert_equal(2, argument.definitions[0].start_line) + assert_equal(3, argument.usages[0].start_line) + end + + def test_collecting_one_line_block_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo + [].each { |i| puts i } + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["i"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + assert_equal(2, argument.definitions[0].start_line) + assert_equal(2, argument.usages[0].start_line) + end + + def test_collecting_shadowed_block_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo + i = "something" + + [].each do |i| + puts i + end + + i + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + assert_equal(1, visitor.variables.length) + + argument = visitor.arguments["i"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + assert_equal(4, argument.definitions[0].start_line) + assert_equal(5, argument.usages[0].start_line) + + variable = visitor.variables["i"] + assert_equal(1, variable.definitions.length) + assert_equal(1, variable.usages.length) + assert_equal(2, variable.definitions[0].start_line) + assert_equal(8, variable.usages[0].start_line) + end + + def test_collecting_shadowed_local_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo(a) + puts a + a = 123 + a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + # All occurrences are considered arguments, despite overriding the + # argument value + assert_equal(1, visitor.arguments.length) + assert_equal(0, visitor.variables.length) + + argument = visitor.arguments["a"] + assert_equal(2, argument.definitions.length) + assert_equal(2, argument.usages.length) + + assert_equal(1, argument.definitions[0].start_line) + assert_equal(3, argument.definitions[1].start_line) + assert_equal(2, argument.usages[0].start_line) + assert_equal(4, argument.usages[1].start_line) + end + end +end