use std::collections::HashMap;
use crate::{ast::{ASTType, AST}, ast_from_ast, collect_lines, error::{ErrorKind, KabelError}, extension::Extension, out_of_scope, out_of_scope_var};
/// Name-resolution pass state.
///
/// Declarations are tracked in a stack of lexical scopes and every resolved
/// variable is associated with an index into `locals`. Problems found while
/// walking the tree are collected in `errors`.
pub struct Resolver {
    /// The program source split into lines; quoted back in error messages.
    text: Vec<String>,
    /// Scope stack, innermost scope last. Each entry maps a name to its
    /// [`Symbol`] kind and an index into `locals`.
    symbol_table: Vec<HashMap<String, (Symbol, usize)>>,
    /// Local slots allocated during resolution (reference to stack).
    pub locals: Vec<usize>,
    /// Errors accumulated during resolution.
    pub errors: Vec<KabelError>,
}
impl Resolver {
    /// Creates a resolver for the program whose source is `text`.
    ///
    /// The source is split into lines for error reporting. Resolution starts
    /// with a single (global) scope and no allocated local slots.
    pub fn new(text: String) -> Self {
        Self {
            text: text.lines().map(String::from).collect(),
            symbol_table: vec![HashMap::new()],
            locals: Vec::new(),
            errors: Vec::new(),
        }
    }

    /// Resolves names in `ast` and returns the rebuilt tree.
    ///
    /// - Function declarations and `Block`s each open a scope that is closed
    ///   when they finish.
    /// - `Decl`s and function parameters get a fresh local slot; the slot is
    ///   attached as `Extension::Resolution` to the `Decl` and to every
    ///   identifier literal that resolves to a variable.
    /// - Unresolved variables/functions and call-arity mismatches are pushed
    ///   onto `self.errors`; the walk continues.
    ///
    /// # Panics
    /// Panics on node kinds this pass does not handle yet (e.g. `Member`), or
    /// if `resolve_function` reports an unexpected error shape.
    pub fn visit(&mut self, ast: AST) -> AST {
        use ASTType::*;
        match ast.ast_type {
            Program(asts) => {
                let program: Vec<AST> = asts.into_iter().map(|node| self.visit(node)).collect();
                AST {
                    ast_type: ASTType::Program(program),
                    extensions: Vec::new(),
                    start_line: 0,
                    end_line: 0,
                    start_column: 0,
                    end_column: 0,
                }
            }
            Function(name, args, block) => {
                // Declare the function in the enclosing scope *before* visiting the
                // body so that recursive calls resolve.
                self.current_scope()
                    .insert(name.name.clone(), (Symbol::Function(args.len()), 0));
                self.symbol_table.push(HashMap::new());
                // Each parameter gets its own local slot, just like a `Decl`, so
                // parameters do not alias one another (or slot 0).
                for arg in &args {
                    self.declare_var(arg.name.clone());
                }
                let block = self.visit(*block);
                self.symbol_table.pop();
                ast_from_ast!(Function(name, args, Box::new(block)), ast, ast)
            }
            Return(expr) => {
                let expr = (*expr).map(|expr| self.visit(expr));
                ast_from_ast!(Return(Box::new(expr)), ast, ast)
            }
            Loop(block) => {
                let block = self.visit(*block);
                ast_from_ast!(Loop(Box::new(block)), ast, ast)
            }
            While(condition, block) => {
                let condition = self.visit(*condition);
                let block = self.visit(*block);
                ast_from_ast!(While(Box::new(condition), Box::new(block)), ast, ast)
            }
            For(expr1, expr2, expr3, block) => {
                // NOTE(review): no scope is pushed for the loop header, so if `expr1`
                // is a declaration its name stays visible after the loop — confirm
                // this is intended.
                let n_expr1 = (*expr1).map(|expr| self.visit(expr));
                let n_expr2 = (*expr2).map(|expr| self.visit(expr));
                let n_expr3 = (*expr3).map(|expr| self.visit(expr));
                let block = self.visit(*block);
                ast_from_ast!(
                    For(Box::new(n_expr1), Box::new(n_expr2), Box::new(n_expr3), Box::new(block)),
                    ast,
                    ast
                )
            }
            If(condition, block, else_expr) => {
                let condition = self.visit(*condition);
                let block = self.visit(*block);
                let n_else_expr = (*else_expr).map(|else_expr| self.visit(else_expr));
                ast_from_ast!(If(Box::new(condition), Box::new(block), Box::new(n_else_expr)), ast, ast)
            }
            Block(stmts) => {
                self.symbol_table.push(HashMap::new());
                let n_stmts: Vec<AST> = stmts.into_iter().map(|stmt| self.visit(stmt)).collect();
                self.symbol_table.pop();
                ast_from_ast!(Block(n_stmts), ast, ast)
            }
            Decl(name, expr) => {
                // Visit the initializer first: the new name is not yet in scope, so a
                // reference to the same name resolves to any outer binding.
                let expr = self.visit(*expr);
                let place = self.declare_var(name.name.clone());
                AST {
                    ast_type: Decl(name, Box::new(expr)),
                    extensions: vec![Extension::Resolution(place)],
                    start_line: ast.start_line,
                    end_line: ast.end_line,
                    start_column: ast.start_column,
                    end_column: ast.end_column,
                }
            }
            Expr(expr) => {
                let expr = self.visit(*expr);
                ast_from_ast!(Expr(Box::new(expr)), ast, ast)
            }
            // REMOVE LATER
            Print(expr) => {
                let expr = self.visit(*expr);
                ast_from_ast!(Print(Box::new(expr)), ast, ast)
            }
            Assign(name, expr) => {
                let expr = self.visit(*expr);
                // Search every enclosing scope (not just the innermost one), and only
                // accept names bound as variables.
                // NOTE(review): unlike `Decl` and identifier literals, no
                // `Extension::Resolution` is attached here — confirm the consumer of
                // this tree does not need one.
                if self.resolve_var(&name.name).is_none() {
                    self.errors.push(out_of_scope_var!(self, "Variable \"{}\" not in scope", name, ast));
                }
                ast_from_ast!(Assign(name, Box::new(expr)), ast, ast)
            }
            Ternary(condition, true_expr, false_expr) => {
                let condition = self.visit(*condition);
                let true_expr = self.visit(*true_expr);
                let false_expr = self.visit(*false_expr);
                ast_from_ast!(Ternary(Box::new(condition), Box::new(true_expr), Box::new(false_expr)), ast, ast)
            }
            Subscript(array, index) => {
                let array = self.visit(*array);
                let index = self.visit(*index);
                ast_from_ast!(Subscript(Box::new(array), Box::new(index)), ast, ast)
            }
            Binary(left, oper, right) => {
                let left = self.visit(*left);
                let right = self.visit(*right);
                ast_from_ast!(Binary(Box::new(left), oper, Box::new(right)), ast, ast)
            }
            Unary(oper, right) => {
                let right = self.visit(*right);
                ast_from_ast!(Unary(oper, Box::new(right)), ast, ast)
            }
            Lit(ref lit) => {
                let lit = lit.clone();
                if let crate::ast::Lit::Ident(ref name) = lit {
                    match self.resolve_var(name) {
                        Some(place) => {
                            return AST {
                                ast_type: Lit(lit),
                                extensions: vec![Extension::Resolution(place)],
                                start_line: ast.start_line,
                                end_line: ast.end_line,
                                start_column: ast.start_column,
                                end_column: ast.end_column,
                            };
                        }
                        None => self.errors.push(out_of_scope!(self, "Variable \"{}\" not in scope", name, ast)),
                    }
                }
                ast_from_ast!(Lit(lit), ast, ast)
            }
            Call(ident, args) => {
                if let Err(e) = self.resolve_function(&ident.name, args.len()) {
                    match e {
                        (ErrorKind::OutOfScope, _, _) => self.errors.push(out_of_scope!(self, "Function \"{}\" not in scope", ident.name, ast)),
                        (ErrorKind::IncorrectArity, Some(f_arity), Some(arity)) => {
                            // NOTE(review): when start_line == end_line this slice is
                            // empty — confirm whether end_line is meant to be exclusive.
                            self.errors.push(
                                KabelError::new(
                                    ErrorKind::IncorrectArity,
                                    format!("Function {} has {} argument, provided {}", ident.name, f_arity, arity),
                                    ast.start_line,
                                    ast.start_column,
                                    collect_lines!(self.text[ast.start_line-1..ast.end_line-1]),
                                )
                            );
                        }
                        _ => { panic!("Returned invalid ErrorKind from resolve_function") },
                    }
                }
                let n_args: Vec<AST> = args.into_iter().map(|arg| self.visit(arg)).collect();
                ast_from_ast!(Call(ident, n_args), ast, ast)
            }
            /*Member(left, right) => {
                self.visit_member(*left, *right);
            }*/
            _ => { panic!("not implemented") } // not implemented
        }
    }
    // TODO: make visit_member not throw out of scope errors
    /*pub fn visit_member(&mut self, left: AST, right: AST) {
        self.visit(left);
        self.visit(right);
    }*/

    /// Innermost scope. The global scope created in `new` is never popped,
    /// because every push in `visit` is paired with a pop.
    fn current_scope(&mut self) -> &mut HashMap<String, (Symbol, usize)> {
        self.symbol_table
            .last_mut()
            .expect("resolver scope stack always holds the global scope")
    }

    /// Allocates a fresh local slot, binds `name` to it as a variable in the
    /// innermost scope, and returns the slot index.
    fn declare_var(&mut self, name: String) -> usize {
        self.locals.push(0);
        let place = self.locals.len() - 1;
        self.current_scope().insert(name, (Symbol::Var, place));
        place
    }

    /// Looks up `name` as a variable, innermost scope first, returning its
    /// local slot. Scopes where `name` is bound as a function are skipped.
    fn resolve_var(&self, name: &str) -> Option<usize> {
        self.symbol_table.iter().rev().find_map(|scope| match scope.get(name) {
            Some((Symbol::Var, place)) => Some(*place),
            _ => None,
        })
    }

    /// Looks up `name` as a function, innermost scope first, and checks that it
    /// accepts `arity` arguments. Scopes where `name` is bound as a variable are
    /// skipped.
    ///
    /// # Errors
    /// - `(ErrorKind::IncorrectArity, Some(declared), Some(arity))` when the
    ///   nearest function named `name` takes a different number of arguments.
    /// - `(ErrorKind::OutOfScope, None, None)` when no function named `name` is
    ///   visible.
    fn resolve_function(&self, name: &str, arity: usize) -> Result<(), (ErrorKind, Option<usize>, Option<usize>)> {
        let declared = self.symbol_table.iter().rev().find_map(|scope| match scope.get(name) {
            Some((Symbol::Function(f_arity), _place)) => Some(*f_arity),
            _ => None,
        });
        match declared {
            Some(f_arity) if f_arity == arity => Ok(()),
            Some(f_arity) => Err((ErrorKind::IncorrectArity, Some(f_arity), Some(arity))),
            None => Err((ErrorKind::OutOfScope, None, None)),
        }
    }
}
/// Kind of name recorded in a resolver scope.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Symbol {
    /// A variable (a declaration or a function parameter).
    Var,
    /// A function, carrying its declared parameter count (arity).
    Function(usize),
}