From c51de77269da8ee1a54902edadc874ebd71f8060 Mon Sep 17 00:00:00 2001 From: Michael Zhang Date: Wed, 6 Apr 2022 06:20:47 -0500 Subject: [PATCH] Giant overhaul, still doesn't work tho --- .gitignore | 2 + examples/conditions.e0 | 11 ++ src/ast/llvm.rs | 149 +++++++++++--- src/ast/mod.rs | 86 +++++--- src/ast/typed.rs | 437 +++++++++++++++++++++++++++++++++++++++++ src/main.rs | 11 +- src/parser.lalrpop | 46 ++++- src/utils.rs | 1 + 8 files changed, 679 insertions(+), 64 deletions(-) create mode 100644 examples/conditions.e0 create mode 100644 src/ast/typed.rs create mode 100644 src/utils.rs diff --git a/.gitignore b/.gitignore index 2359dec..c2dfea8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ result* *.bc *.ll + +a.out diff --git a/examples/conditions.e0 b/examples/conditions.e0 new file mode 100644 index 0000000..1e0d918 --- /dev/null +++ b/examples/conditions.e0 @@ -0,0 +1,11 @@ +fn main() -> int { + let x = 5; + + if x < 3 { + return 3; + } else if x > 10 { + return 10; + } else { + return x; + } +} diff --git a/src/ast/llvm.rs b/src/ast/llvm.rs index f53c3ba..6f7f409 100644 --- a/src/ast/llvm.rs +++ b/src/ast/llvm.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, marker::PhantomData}; +use std::collections::HashMap; use anyhow::Result; use inkwell::{ @@ -6,31 +6,72 @@ use inkwell::{ context::Context, module::Module, types::{BasicMetadataTypeEnum, BasicTypeEnum, FunctionType}, - values::{AnyValueEnum, BasicValueEnum, PointerValue}, + values::{BasicValueEnum, FunctionValue, IntValue, PointerValue}, + IntPredicate, }; -use super::{Decl, Expr, Stmt, Type}; +use super::{Decl, ElseClause, Expr, ExprKind, IfElse, Op, Stmt, Type}; -impl Expr { +impl Expr { fn into_llvm<'ctx>( &self, context: &'ctx Context, - builder: &Builder, + builder: &'ctx Builder, env: &Env<'_, 'ctx>, ) -> Result> { - match self { - Expr::Var(name) => { + let int_ty = context.i64_type(); + let bool_ty = context.bool_type(); + + match &self.kind { + ExprKind::Var(name) => { let (_, value) = match env.lookup(&name) { Some(v) => v, None => bail!("Unbound name {name:?}"), }; - Ok(BasicValueEnum::PointerValue(*value)) + Ok(builder.build_load(*value, "")) } - Expr::Int(n) => Ok(BasicValueEnum::IntValue( - context.i64_type().const_int(*n as u64, false), - )), + ExprKind::Int(n) => { + Ok(BasicValueEnum::IntValue(int_ty.const_int(*n as u64, false))) + } + + ExprKind::BinOp(left, op, right) => { + if !op.check_types(&left.ty, &right.ty) { + // TODO: detailed error message + bail!("Invalid types on operation."); + } + + match op { + Op::LessThan => { + let left_val = + left.into_llvm(context, builder, env)?.into_int_value(); + let right_val = + right.into_llvm(context, builder, env)?.into_int_value(); + let result: IntValue = builder.build_int_compare( + IntPredicate::SLT, + left_val, + right_val, + "", + ); + Ok(BasicValueEnum::IntValue(result)) + } + + Op::GreaterThan => { + let left_val = + left.into_llvm(context, builder, env)?.into_int_value(); + let right_val = + right.into_llvm(context, builder, env)?.into_int_value(); + let result: IntValue = builder.build_int_compare( + IntPredicate::SGT, + left_val, + right_val, + "", + ); + Ok(BasicValueEnum::IntValue(result)) + } + } + } } } } @@ -42,6 +83,8 @@ impl Type { ) -> BasicTypeEnum<'ctx> { match self { Type::Int => BasicTypeEnum::IntType(context.i64_type()), + Type::Bool => BasicTypeEnum::IntType(context.bool_type()), + _ => panic!("Tried to convert a function type into a LLVM basic type"), } } } @@ -65,7 +108,7 @@ type EnvValue<'a, 'ctx> = (&'a Type, PointerValue<'ctx>); #[derive(Default)] struct Env<'a, 'ctx> { - parent: Option>>, + parent: Option<&'a Env<'a, 'ctx>>, local_type_map: HashMap>, } @@ -81,23 +124,72 @@ impl<'a, 'ctx> Env<'a, 'ctx> { } } -fn convert_lexical_block( +fn convert_if_else( context: &Context, + module: &Module, builder: &Builder, - parent_env: Env, - stmts: impl AsRef<[Stmt]>, + function: &FunctionValue, + if_else: &IfElse, + env: &Env, +) -> Result<()> { + let success_branch = context.append_basic_block(*function, "success"); + let fail_branch = context.append_basic_block(*function, "fail"); + let exit_block = context.append_basic_block(*function, "exit"); + + let cond = if_else + .cond + .into_llvm(context, builder, env)? + .into_int_value(); + builder.build_conditional_branch(cond, success_branch, fail_branch); + + // build success branch + builder.position_at_end(success_branch); + convert_stmts(context, module, function, builder, env, &if_else.body)?; + // builder.build_unconditional_branch(exit_branch); + + // build fail branch + builder.position_at_end(fail_branch); + match &if_else.else_clause { + Some(ElseClause::If(if_else2)) => { + println!("SUB-IF CLAUSE {if_else2:?}"); + convert_if_else(context, module, builder, function, if_else2, env)?; + } + Some(ElseClause::Body(body)) => { + println!("STMTS {body:?}"); + convert_stmts(context, module, function, builder, env, body)?; + } + None => { + println!("NOTHING"); + } + } + builder.build_unconditional_branch(exit_block); + + builder.position_at_end(exit_block); + + Ok(()) +} + +fn convert_stmts( + context: &Context, + module: &Module, + function: &FunctionValue, + builder: &Builder, + parent_env: &Env, + stmts: impl AsRef<[Stmt]>, ) -> Result<()> { let stmts = stmts.as_ref(); let mut scope_env = Env { - parent: Some(Box::new(parent_env)), + parent: Some(parent_env), ..Default::default() }; for stmt in stmts.iter() { match stmt { - Stmt::Let(name, ty, expr) => { + Stmt::Let(name, _, expr) => { + let ty = &expr.ty; let llvm_ty = ty.into_llvm_basic_type(context); - let alloca = builder.build_alloca(llvm_ty, name); + // Empty variable name gets LLVM to generate a unique name + let alloca = builder.build_alloca(llvm_ty, ""); let expr_val = expr.into_llvm(context, builder, &scope_env)?; builder.build_store(alloca, expr_val); @@ -114,17 +206,30 @@ fn convert_lexical_block( } println!("Emitted return."); } + + Stmt::IfElse(if_else) => { + convert_if_else( + context, module, builder, function, if_else, &scope_env, + )?; + } } } Ok(()) } -pub fn convert(context: &mut Context, program: Vec) -> Result { - let module = context.create_module("program"); +pub fn convert( + file_name: String, + context: &Context, + program: Vec>, +) -> Result { + let module = context.create_module(&file_name); let builder = context.create_builder(); - for func in program.iter().filter_map(Decl::unwrap_func) { + for func in program.iter().filter_map(|decl| match decl { + Decl::Func(v) => Some(v), + _ => None, + }) { let return_ty = func.return_ty.into_llvm_basic_type(context); let llvm_func_ty = fn_type_basic(return_ty, &[], false); @@ -134,7 +239,7 @@ pub fn convert(context: &mut Context, program: Vec) -> Result { builder.position_at_end(entry_block); let env = Env::default(); - convert_lexical_block(context, &builder, env, &func.stmts)?; + convert_stmts(&context, &module, &llvm_func, &builder, &env, &func.stmts)?; } Ok(module) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 053ae4a..15260bd 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,39 +1,69 @@ pub mod llvm; +pub mod typed; #[derive(Debug)] -pub enum Decl { - Func(Func), +pub enum Decl { + Func(Func), } -impl Decl { - pub fn unwrap_func(&self) -> Option<&Func> { - match self { - Decl::Func(func) => Some(func), - _ => None, - } +#[derive(Debug)] +pub struct Func { + pub name: String, + pub return_ty: Type, + pub stmts: Vec>, +} + +#[derive(Debug)] +pub enum Stmt { + Let(String, Option, Expr), + Return(Option>), + IfElse(IfElse), +} + +#[derive(Debug)] +pub struct IfElse { + pub cond: Expr, + pub body: Vec>, + pub else_clause: Option>, +} + +#[derive(Debug)] +pub enum ElseClause { + Body(Vec>), + If(Box>), +} + +#[derive(Debug)] +pub struct Expr { + pub kind: ExprKind, + pub ty: T, +} + +#[derive(Debug)] +pub enum ExprKind { + Int(i64), + Var(String), + BinOp(Box>, Op, Box>), +} + +#[derive(Copy, Clone, Debug)] +pub enum Op { + LessThan, + GreaterThan, +} + +impl Op { + pub fn check_types(&self, left_ty: &Type, right_ty: &Type) -> bool { + // TODO: since only binops work on integers right now, just check that the two sides are both + // integers. this will have to change once && gets added. + matches!(left_ty, Type::Int) && matches!(right_ty, Type::Int) } } -#[derive(Debug)] -pub struct Func { - pub name: String, - pub return_ty: Type, - pub stmts: Vec, -} - -#[derive(Debug)] -pub enum Stmt { - Let(String, Type, Expr), - Return(Option), -} - -#[derive(Debug)] -pub enum Expr { - Int(i64), - Var(String), -} - -#[derive(Debug)] +#[derive(Clone, Debug, Hash, Eq, PartialEq)] pub enum Type { Int, + Bool, + + Func(Vec, Box), } diff --git a/src/ast/typed.rs b/src/ast/typed.rs new file mode 100644 index 0000000..fff96ab --- /dev/null +++ b/src/ast/typed.rs @@ -0,0 +1,437 @@ +use std::collections::{HashMap, HashSet}; + +use anyhow::Result; + +use crate::ast::Expr; + +use super::{Decl, ElseClause, ExprKind, Func, IfElse, Op, Stmt, Type}; + +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +pub enum Type_ { + Var(usize), + + Int, + Bool, + + Func(Vec, Box), +} + +impl Type_ { + fn from_type(ty: Type) -> Self { + match ty { + Type::Int => Type_::Int, + Type::Bool => Type_::Bool, + + Type::Func(args, ret) => { + let args = args.into_iter().map(|arg| Type_::from_type(arg)).collect(); + let ret = Type_::from_type(*ret); + Type_::Func(args, Box::new(ret)) + } + } + } + + fn convert(&self, assignments: &Assignments) -> Result { + Ok(match self { + Type_::Var(n) => match assignments.get(&n) { + Some(v) => v.convert(assignments)?, + None => bail!("Unsolved constraint variable {n}"), + }, + + Type_::Int => Type::Int, + Type_::Bool => Type::Bool, + + Type_::Func(args, ret) => { + let args = args + .into_iter() + .map(|arg| arg.convert(assignments)) + .collect::>()?; + let ret = ret.convert(assignments)?; + Type::Func(args, Box::new(ret)) + } + }) + } +} + +impl Op { + fn constraints( + &self, + ctx: &mut AnnotationContext, + left: &Type_, + right: &Type_, + output: &Type_, + ) { + match self { + Op::LessThan | Op::GreaterThan => { + ctx.constrain(left.clone(), Type_::Int); + ctx.constrain(right.clone(), Type_::Int); + ctx.constrain(output.clone(), Type_::Bool); + } + } + } +} + +#[derive(Debug, Hash, Eq, PartialEq)] +struct Constraint(Type_, Type_); + +#[derive(Default)] +struct Env { + parent: Option>, + local_type_map: HashMap, +} + +impl Env { + pub fn lookup(&self, name: impl AsRef) -> Option<&Type_> { + match self.local_type_map.get(name.as_ref()) { + Some(v) => Some(v), + None => match &self.parent { + Some(p) => p.lookup(name), + None => None, + }, + } + } +} + +struct AnnotationContext<'a> { + counter: usize, + constraints: &'a mut HashSet, + current_env: Option, +} + +impl<'a> AnnotationContext<'a> { + pub fn type_var(&mut self) -> Type_ { + Type_::Var(self.gen_int()) + } + + pub fn gen_int(&mut self) -> usize { + let id = self.counter; + self.counter += 1; + id + } + + pub fn constrain(&mut self, left: Type_, right: Type_) { + if left == right { + // No op, return now + return; + } + self.constraints.insert(Constraint(left, right)); + } + + pub fn lookup(&self, name: impl AsRef) -> Option<&Type_> { + self.current_env.as_ref().unwrap().lookup(name) + } + + pub fn define_var(&mut self, name: impl AsRef, ty: Type_) { + self + .current_env + .as_mut() + .unwrap() + .local_type_map + .insert(name.as_ref().to_string(), ty); + } + + pub fn push_scope(&mut self) { + self.current_env = Some(Env { + parent: Some(Box::new(self.current_env.take().unwrap())), + local_type_map: Default::default(), + }); + } + + pub fn pop_scope(&mut self) { + self.current_env = + Some(*self.current_env.take().unwrap().parent.take().unwrap()); + } +} + +fn annotate_stmts( + ctx: &mut AnnotationContext, + stmts: impl AsRef<[Stmt<()>]>, +) -> Vec> { + let stmts = stmts.as_ref(); + let mut new_stmts = Vec::new(); + ctx.push_scope(); + + for stmt in stmts.iter() { + match stmt { + Stmt::Return(ret_val) => { + let new_stmt = + Stmt::Return(ret_val.as_ref().map(|expr| annotate_expr(ctx, expr))); + new_stmts.push(new_stmt); + } + + Stmt::Let(name, ty, body) => { + let new_stmt = + Stmt::Let(name.clone(), ty.clone(), annotate_expr(ctx, body)); + let ty = match ty { + Some(v) => Type_::from_type(v.clone()), + None => ctx.type_var(), + }; + ctx.define_var(name, ty); + new_stmts.push(new_stmt); + } + + Stmt::IfElse(if_else) => { + let new_stmt = Stmt::IfElse(annotate_if_else(ctx, &if_else)); + new_stmts.push(new_stmt); + } + } + } + + ctx.pop_scope(); + new_stmts +} + +fn annotate_expr(ctx: &mut AnnotationContext, expr: &Expr<()>) -> Expr { + match &expr.kind { + ExprKind::Int(n) => Expr { + kind: ExprKind::Int(*n), + ty: Type_::Int, + }, + + ExprKind::Var(name) => { + let ty = match ctx.lookup(name) { + Some(v) => v.clone(), + None => ctx.type_var(), + }; + Expr { + kind: ExprKind::Var(name.clone()), + ty, + } + } + + ExprKind::BinOp(left, op, right) => { + let left = annotate_expr(ctx, left); + let right = annotate_expr(ctx, right); + let output = ctx.type_var(); + + op.constraints(ctx, &left.ty, &right.ty, &output); + + Expr { + kind: ExprKind::BinOp(Box::new(left), *op, Box::new(right)), + ty: output, + } + } + } +} + +fn annotate_if_else( + ctx: &mut AnnotationContext, + if_else: &IfElse<()>, +) -> IfElse { + let converted_cond = annotate_expr(ctx, &if_else.cond); + let converted_body = annotate_stmts(ctx, &if_else.body); + + let else_clause = match &if_else.else_clause { + Some(ElseClause::If(if_else2)) => { + Some(ElseClause::If(Box::new(annotate_if_else(ctx, &if_else2)))) + } + Some(ElseClause::Body(stmts)) => { + Some(ElseClause::Body(annotate_stmts(ctx, &stmts))) + } + None => None, + }; + + IfElse { + cond: converted_cond, + body: converted_body, + else_clause, + } +} + +fn collect_info(func: &Func<()>) -> (Func, HashSet) { + let mut constraints = HashSet::new(); + let mut ctx = AnnotationContext { + counter: 0, + constraints: &mut constraints, + current_env: Some(Env::default()), + }; + let new_stmts = annotate_stmts(&mut ctx, &func.stmts); + + let func = Func { + name: func.name.clone(), + return_ty: func.return_ty.clone(), + stmts: new_stmts, + }; + + (func, constraints) +} + +type Assignments = HashMap; + +fn substitute_types(assignments: &Assignments, ty: &Type_) -> Type_ { + match ty { + Type_::Var(n) => match assignments.get(&n) { + Some(ty2) => ty2.clone(), + None => ty.clone(), + }, + Type_::Func(args, ret) => { + let args = args + .into_iter() + .map(|arg| substitute_types(assignments, arg)) + .collect(); + let ret = substitute_types(assignments, &*ret); + Type_::Func(args, Box::new(ret)) + } + Type_::Int | Type_::Bool => ty.clone(), + } +} + +fn unify_constraints(constraints: &HashSet) -> Result { + let mut assignments = HashMap::new(); + + for Constraint(left, right) in constraints { + let left = substitute_types(&assignments, left); + let right = substitute_types(&assignments, right); + unify_single(&mut assignments, left, right)?; + } + + Ok(assignments) +} + +fn unify_single( + assignments: &mut Assignments, + left: Type_, + right: Type_, +) -> Result<()> { + match (left, right) { + (Type_::Int, Type_::Int) | (Type_::Bool, Type_::Bool) => {} + + (Type_::Var(n), o) | (o, Type_::Var(n)) => { + assignments.insert(n, o); + } + + (Type_::Func(left_args, left_ret), Type_::Func(right_args, right_ret)) => { + let mut new_constraints = HashSet::new(); + for (left_arg, right_arg) in + left_args.into_iter().zip(right_args.into_iter()) + { + new_constraints.insert(Constraint(left_arg, right_arg)); + } + new_constraints.insert(Constraint(*left_ret, *right_ret)); + assignments.extend(unify_constraints(&new_constraints)?); + } + + (left, right) => bail!("Mismatching types {left:?} vs. {right:?}"), + }; + + Ok(()) +} + +fn substitute_in_expr_kind( + assignments: &Assignments, + expr_kind: ExprKind, +) -> Result> { + Ok(match expr_kind { + ExprKind::Int(n) => ExprKind::Int(n), + ExprKind::Var(name) => ExprKind::Var(name), + ExprKind::BinOp(left, op, right) => { + let left = substitute_in_expr(assignments, *left)?; + let right = substitute_in_expr(assignments, *right)?; + ExprKind::BinOp(Box::new(left), op, Box::new(right)) + } + }) +} + +fn substitute_in_expr( + assignments: &Assignments, + expr: Expr, +) -> Result> { + Ok(Expr { + kind: substitute_in_expr_kind(assignments, expr.kind)?, + ty: expr.ty.convert(assignments)?, + }) +} + +fn substitute_in_if_else( + assignments: &Assignments, + if_else: IfElse, +) -> Result> { + let cond = substitute_in_expr(assignments, if_else.cond)?; + let body = substitute_in_stmts(assignments, if_else.body)?; + + let else_clause = match if_else.else_clause { + Some(ElseClause::If(if_else2)) => Some(ElseClause::If(Box::new( + substitute_in_if_else(assignments, *if_else2)?, + ))), + Some(ElseClause::Body(body)) => { + Some(ElseClause::Body(substitute_in_stmts(assignments, body)?)) + } + None => None, + }; + + Ok(IfElse { + cond, + body, + else_clause, + }) +} + +fn substitute_in_stmts( + assignments: &Assignments, + stmts: Vec>, +) -> Result>> { + stmts + .into_iter() + .map(|stmt| { + Ok(match stmt { + Stmt::Let(name, ty, body) => { + Stmt::Let(name, ty, substitute_in_expr(assignments, body)?) + } + + Stmt::Return(ret_val) => Stmt::Return(match ret_val { + Some(v) => Some(substitute_in_expr(assignments, v)?), + None => None, + }), + + Stmt::IfElse(if_else) => { + Stmt::IfElse(substitute_in_if_else(assignments, if_else)?) + } + }) + }) + .collect() +} + +fn substitute_in_func( + assignments: &Assignments, + func: Func, +) -> Result> { + Ok(Func { + name: func.name, + return_ty: func.return_ty, + stmts: substitute_in_stmts(assignments, func.stmts)?, + }) +} + +pub fn convert(ast: Vec>) -> Result>> { + // First pass, gather all of the type signatures in the top level + let mut top_level_signatures = HashMap::new(); + for decl in ast.iter() { + match decl { + super::Decl::Func(func) => { + let name = func.name.clone(); + let ty = Type::Func(Vec::new(), Box::new(func.return_ty.clone())); + top_level_signatures.insert(name, ty); + } + } + } + + // Now, type-check each function separately + let mut new_decl = Vec::new(); + for decl in ast.iter() { + match decl { + Decl::Func(func) => { + let (decorated_func, constraints) = collect_info(func); + println!("func: {:?}", decorated_func); + println!("constraints: {:?}", constraints); + + let assignments = unify_constraints(&constraints)?; + println!("assignments: {:?}", assignments); + + let typed_func = substitute_in_func(&assignments, decorated_func)?; + println!("typed: {:?}", typed_func); + new_decl.push(Decl::Func(typed_func)); + } + } + } + + Ok(new_decl) +} diff --git a/src/main.rs b/src/main.rs index 2a6467e..0d68e27 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ extern crate anyhow; lalrpop_mod!(parser); mod ast; +mod utils; use std::fs::{self, File}; use std::path::PathBuf; @@ -32,14 +33,16 @@ struct Opt { fn main() -> Result<()> { let opts = Opt::parse(); - let contents = fs::read_to_string(opts.path)?; + let contents = fs::read_to_string(&opts.path)?; let parser = ProgramParser::new(); let ast = parser.parse(&contents).unwrap(); - println!("AST: {ast:?}"); - let mut context = Context::create(); - let module = ast::llvm::convert(&mut context, ast)?; + let typed_ast = ast::typed::convert(ast)?; + + let context = Context::create(); + let module = + ast::llvm::convert(opts.path.display().to_string(), &context, typed_ast)?; { let file = File::create(&opts.out_path)?; diff --git a/src/parser.lalrpop b/src/parser.lalrpop index 2803230..8d43793 100644 --- a/src/parser.lalrpop +++ b/src/parser.lalrpop @@ -2,31 +2,57 @@ use crate::ast::*; grammar; -pub Program: Vec = Decl* => <>; +pub Program: Vec> = Decl* => <>; -Decl: Decl = { +Decl: Decl<()> = { Func => Decl::Func(<>), }; -Func: Func = { +Func: Func<()> = { "fn" "(" ")" "->" "{" "}" => Func { name, return_ty, stmts }, }; -Stmt: Stmt = { - "let" ":" "=" ";" => +Stmt: Stmt<()> = { + "let" "=" ";" => Stmt::Let(name, ty, expr), "return" ";" => Stmt::Return(expr), + IfElse => Stmt::IfElse(<>), }; -Expr: Expr = { - Expr1 => <>, +ColonType: Type = ":" => ty; + +IfElse: IfElse<()> = + "if" "{" "}" => + IfElse { cond, body, else_clause }; + +Else: ElseClause<()> = "else" => else_clause; +Else_: ElseClause<()> = { + IfElse => ElseClause::If(Box::new(<>)), + "{" "}" => ElseClause::Body(body), }; -Expr1: Expr = { - r"[0-9]+" => Expr::Int(<>.parse::().unwrap()), - Ident => Expr::Var(<>), +Expr: Expr<()> = { + #[precedence(level = "0")] "(" ")" => expr, + + #[precedence(level = "0")] + r"[0-9]+" => Expr { kind: ExprKind::Int(<>.parse::().unwrap()), ty: () }, + + #[precedence(level = "0")] + Ident => Expr { kind: ExprKind::Var(<>), ty: () }, + + #[precedence(level = "13")] + #[assoc(side = "none")] + => Expr { + kind: ExprKind::BinOp(Box::new(left), op, Box::new(right)), + ty: (), + }, +}; + +CompareOp: Op = { + "<" => Op::LessThan, + ">" => Op::GreaterThan, }; Type: Type = { diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..0444f0a --- /dev/null +++ b/src/utils.rs @@ -0,0 +1 @@ +// TODO: put layered environment here?