use std::collections::HashMap; use crate::{ast::{ASTType, AST}, ast_from_ast, collect_lines, error::{ErrorKind, KabelError}, extension::Extension, out_of_scope, out_of_scope_var}; pub struct Resolver { text: Vec, symbol_table: Vec>, // (Symbol, reference to locals) pub locals: Vec, // scope pub scope: usize, pub errors: Vec, } impl Resolver { pub fn new(text: String) -> Self { Self { text: text.lines().collect::>().iter().map(|s| s.to_string()).collect(), symbol_table: vec![HashMap::new()], locals: Vec::new(), scope: 0, errors: Vec::new(), } } pub fn visit(&mut self, ast: AST) -> AST { use ASTType::*; match ast.ast_type { Program(asts) => { let mut program = Vec::new(); for ast in asts { let ast = self.visit(ast.clone()); program.push(ast) } AST { ast_type: ASTType::Program(program), extensions: Vec::new(), start_line: 0, end_line: 0, start_column: 0, end_column: 0, } } Function(name, args, block) => { self.symbol_table.last_mut().unwrap().insert(name.name.clone(), (Symbol::Function(args.len()),0)); self.symbol_table.push(HashMap::new()); for arg in args.clone() { self.symbol_table.last_mut().unwrap().insert(arg.name, (Symbol::Var,0)); } let block = self.visit(*block); self.symbol_table.pop(); ast_from_ast!(Function(name, args, Box::new(block)), ast, ast) } Return(expr) => { if let Some(expr) = *expr { let expr = self.visit(expr); return ast_from_ast!(Return(Box::new(Some(expr))), ast, ast); } ast_from_ast!(Return(Box::new(None)), ast, ast) } Loop(block) => { let block = self.visit(*block); ast_from_ast!(Loop(Box::new(block)), ast, ast) } While(condition, block) => { let condition = self.visit(*condition); let block = self.visit(*block); ast_from_ast!(While(Box::new(condition), Box::new(block)), ast, ast) } Break => { ast_from_ast!(Break, ast, ast) } Continue => { ast_from_ast!(Continue, ast, ast) } For(expr1, expr2, expr3, block) => { let mut n_expr1 = None; let mut n_expr2 = None; let mut n_expr3 = None; if let Some(expr) = *expr1 { n_expr1 = Some(self.visit(expr)); } if let Some(expr) = *expr2 { n_expr2 = Some(self.visit(expr)); } if let Some(expr) = *expr3 { n_expr3 = Some(self.visit(expr)); } let block = self.visit(*block); ast_from_ast!(For(Box::new(n_expr1), Box::new(n_expr2), Box::new(n_expr3), Box::new(block)), ast, ast) } If(condition, block, else_expr) => { let condition = self.visit(*condition); let block = self.visit(*block); let mut n_else_expr = None; if let Some(else_expr) = *else_expr { n_else_expr = Some(self.visit(else_expr)); } ast_from_ast!(If(Box::new(condition), Box::new(block), Box::new(n_else_expr)), ast, ast) } Block(stmts) => { self.symbol_table.push(HashMap::new()); self.scope += 1; let mut n_stmts = Vec::new(); for stmt in stmts { n_stmts.push(self.visit(stmt)); } /*for (index, scope) in self.locals.clone().iter().enumerate() { if self.scope == *scope { self.locals.remove(index); } }*/ while let Some(scope) = self.locals.last() { if self.scope == *scope { self.locals.pop(); } else { break; } } self.scope -= 1; self.symbol_table.pop(); ast_from_ast!(Block(n_stmts), ast, ast) } Decl(name, expr) => { let expr = self.visit(*expr); self.locals.push(self.scope); self.symbol_table.last_mut().unwrap().insert(name.name.clone(), (Symbol::Var, self.locals.len()-1)); AST { ast_type: Decl(name, Box::new(expr)), extensions: vec![Extension::Resolution(self.scope, self.locals.len()-1)], start_line: ast.start_line, end_line: ast.end_line, start_column: ast.start_column, end_column: ast.end_column, } } Expr(expr) => { let expr = self.visit(*expr); ast_from_ast!(Expr(Box::new(expr)), ast, ast) } // REMOVE LATER Print(expr) => { let expr = self.visit(*expr); ast_from_ast!(Print(Box::new(expr)), ast, ast) } Assign(name, expr) => { let expr = self.visit(*expr); let resolution = self.resolve_var(&name.name); if !resolution.0 { self.errors.push(out_of_scope_var!(self, "Variable \"{}\" not in scope", name, ast)); } AST { ast_type: Assign(name, Box::new(expr)), extensions: vec![Extension::Resolution(self.scope, resolution.1)], start_line: ast.start_line, end_line: ast.end_line, start_column: ast.start_column, end_column: ast.end_column, } } Ternary(condition, true_expr, false_expr) => { let condition = self.visit(*condition); let true_expr = self.visit(*true_expr); let false_expr = self.visit(*false_expr); ast_from_ast!(Ternary(Box::new(condition), Box::new(true_expr), Box::new(false_expr)), ast, ast) } Subscript(array, index) => { let array = self.visit(*array); let index = self.visit(*index); ast_from_ast!(Subscript(Box::new(array), Box::new(index)), ast, ast) } Binary(left, oper, right) => { let left = self.visit(*left); let right = self.visit(*right); ast_from_ast!(Binary(Box::new(left), oper, Box::new(right)), ast, ast) } Unary(oper, right) => { let right = self.visit(*right); ast_from_ast!(Unary(oper, Box::new(right)), ast, ast) } Lit(ref lit) => { let lit = lit.clone(); match lit { crate::ast::Lit::Ident(ref name) => { let resolution = self.resolve_var(name); if !resolution.0 { self.errors.push(out_of_scope!(self, "Variable \"{}\" not in scope", name, ast)) } else { return AST { ast_type: Lit(lit), extensions: vec![Extension::Resolution(self.scope, resolution.1)], start_line: ast.start_line, end_line: ast.end_line, start_column: ast.start_column, end_column: ast.end_column, }; } } _ => {} } ast_from_ast!(Lit(lit), ast, ast) } Call(ident, args) => { if let Err(e) = self.resolve_function(&ident.name, args.len()) { match e { (ErrorKind::OutOfScope, _, _) => self.errors.push(out_of_scope!(self, "Function \"{}\" not in scope", ident.name, ast)), (ErrorKind::IncorrectArity, Some(f_arity), Some(arity)) => { self.errors.push( KabelError::new( ErrorKind::IncorrectArity, format!("Function {} has {} argument, provided {}", ident.name, f_arity, arity), ast.start_line, ast.start_column, collect_lines!(self.text[ast.start_line-1..ast.end_line-1]), ) ); } _ => { panic!("Returned invalid ErrorKind from resolve_function") }, } } let mut n_args = Vec::new(); for arg in args { n_args.push(self.visit(arg)); } ast_from_ast!(Call(ident, n_args), ast, ast) } /*Member(left, right) => { self.visit_member(*left, *right); }*/ _ => { panic!("not implemented") } // not implemented } } // TODO: make visit_member not throw out of scope errors /*pub fn visit_member(&mut self, left: AST, right: AST) { self.visit(left); self.visit(right); }*/ fn resolve_var(&self, name: &String) -> (bool, usize) { for scope in self.symbol_table.iter().rev() { if let Some((Symbol::Var, place)) = scope.get(name) { return (true, *place); } } (false, 0) } fn resolve_function(&mut self, name: &String, arity: usize) -> Result<(), (ErrorKind, Option, Option)>{ for scope in self.symbol_table.iter().rev() { if let Some((Symbol::Function(f_arity), _place)) = scope.get(name) { if *f_arity == arity { return Ok(()); } else { return Err((ErrorKind::IncorrectArity, Some(*f_arity), Some(arity))); } } } Err((ErrorKind::OutOfScope, None, None)) } } #[derive(Debug, Clone, Copy)] pub enum Symbol { Var, Function(usize), }