diff --git a/examples/functions.e0 b/examples/functions.e0 index f1caf58..0029516 100644 --- a/examples/functions.e0 +++ b/examples/functions.e0 @@ -1,7 +1,7 @@ -fn compute() -> int { - return 42; +fn compute(x : int) -> int { + return 120 + x; } fn main() -> int { - return compute(); + return compute(3); } diff --git a/src/ast/llvm.rs b/src/ast/llvm.rs index 519026d..793a98f 100644 --- a/src/ast/llvm.rs +++ b/src/ast/llvm.rs @@ -7,12 +7,16 @@ use inkwell::{ module::Module, types::{BasicMetadataTypeEnum, BasicTypeEnum, FunctionType}, values::{ - BasicMetadataValueEnum, BasicValueEnum, FunctionValue, IntValue, - PointerValue, + BasicMetadataValueEnum, BasicValue, BasicValueEnum, FunctionValue, + IntValue, PointerValue, }, IntPredicate, }; +use crate::utils::{ + convert_type_to_metadata_type, convert_value_to_metadata_value, +}; + use super::{Decl, ElseClause, Expr, ExprKind, IfElse, Op, Stmt, Type}; impl Expr { @@ -26,17 +30,16 @@ impl Expr { let bool_ty = context.bool_type(); match &self.kind { - ExprKind::Var(name) => { - let value = match env.lookup(&name) { - Some(v) => match v.kind { - EnvValueKind::Local(l) => l, - EnvValueKind::Func(f) => f.as_global_value().as_pointer_value(), - }, - None => bail!("Unbound name {name:?}"), - }; - - Ok(builder.build_load(value, "")) - } + ExprKind::Var(name) => Ok(match env.lookup(&name) { + Some(v) => match v.kind { + EnvValueKind::Local(l) => l, + EnvValueKind::Func(f) => { + let ptr = f.as_global_value().as_pointer_value(); + builder.build_load(ptr, "") + } + }, + None => bail!("Unbound name {name:?}"), + }), ExprKind::Int(n) => { Ok(BasicValueEnum::IntValue(int_ty.const_int(*n as u64, false))) @@ -49,6 +52,16 @@ impl Expr { } match op { + Op::Plus => { + let left_val = + left.into_llvm(context, builder, env)?.into_int_value(); + let right_val = + right.into_llvm(context, builder, env)?.into_int_value(); + let result: IntValue = + builder.build_int_add(left_val, right_val, ""); + Ok(BasicValueEnum::IntValue(result)) + } + Op::LessThan => { let left_val = left.into_llvm(context, builder, env)?.into_int_value(); @@ -84,23 +97,15 @@ impl Expr { ty: func_ty, kind: EnvValueKind::Func(func_ptr), }) => { - fn DUMB_CONVERT(a: BasicValueEnum) -> BasicMetadataValueEnum { - use BasicMetadataValueEnum as B; - use BasicValueEnum as A; - match a { - A::ArrayValue(a) => B::ArrayValue(a), - A::IntValue(a) => B::IntValue(a), - A::FloatValue(a) => B::FloatValue(a), - A::PointerValue(a) => B::PointerValue(a), - A::StructValue(a) => B::StructValue(a), - A::VectorValue(a) => B::VectorValue(a), - } - } - let args_llvm = args .iter() - .map(|arg| arg.into_llvm(context, builder, env).map(DUMB_CONVERT)) + .map(|arg| { + arg + .into_llvm(context, builder, env) + .map(convert_value_to_metadata_value) + }) .collect::>>()?; + println!("ARGS_LLVM: {args_llvm:?}"); let call_site = builder.build_call(*func_ptr, args_llvm.as_slice(), ""); @@ -143,17 +148,19 @@ fn fn_type_basic<'ctx>( } } +#[derive(Debug)] enum EnvValueKind<'ctx> { Func(FunctionValue<'ctx>), - Local(PointerValue<'ctx>), + Local(BasicValueEnum<'ctx>), } +#[derive(Debug)] struct EnvValue<'a, 'ctx> { ty: &'a Type, kind: EnvValueKind<'ctx>, } -#[derive(Default)] +#[derive(Debug, Default)] struct Env<'a, 'ctx> { parent: Option<&'a Env<'a, 'ctx>>, local_type_map: HashMap>, @@ -246,7 +253,7 @@ fn convert_stmts( name.clone(), EnvValue { ty, - kind: EnvValueKind::Local(alloca), + kind: EnvValueKind::Local(alloca.as_basic_value_enum()), }, ); } @@ -286,7 +293,14 @@ pub fn convert( _ => None, }) { let return_ty = func.return_ty.into_llvm_basic_type(context); - let llvm_func_ty = fn_type_basic(return_ty, &[], false); + let args_ty = func + .args + .iter() + .map(|arg| { + convert_type_to_metadata_type(arg.ty.into_llvm_basic_type(context)) + }) + .collect::>(); + let llvm_func_ty = fn_type_basic(return_ty, &args_ty, false); let llvm_func = module.add_function(&func.name, llvm_func_ty, None); env.local_type_map.insert( @@ -309,11 +323,37 @@ pub fn convert( }) => func.clone(), _ => unreachable!(), }; + let entry_block = context.append_basic_block(llvm_func, "entry"); builder.position_at_end(entry_block); - convert_stmts(&context, &module, &llvm_func, &builder, &env, &func.stmts)?; + let mut scoped_env = Env { + parent: Some(&env), + local_type_map: HashMap::new(), + }; + + for (arg, param) in func.args.iter().zip(llvm_func.get_params().into_iter()) + { + scoped_env.local_type_map.insert( + arg.name.clone(), + EnvValue { + ty: &arg.ty, + kind: EnvValueKind::Local(param), + }, + ); + } + println!("ARGS {:?}", func.args); + println!("ENV {:?}", scoped_env); + + convert_stmts( + &context, + &module, + &llvm_func, + &builder, + &scoped_env, + &func.stmts, + )?; } Ok(module) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 1bef14d..eeae732 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -9,11 +9,18 @@ pub enum Decl { #[derive(Debug)] pub struct Func { pub name: String, + pub args: Vec, pub return_ty: Type, pub stmts: Vec>, pub ty: T, } +#[derive(Debug, Clone)] +pub struct Arg { + pub name: String, + pub ty: Type, +} + #[derive(Debug)] pub enum Stmt { Let(String, Option, Expr), @@ -52,6 +59,7 @@ pub enum ExprKind { pub enum Op { LessThan, GreaterThan, + Plus, } impl Op { diff --git a/src/ast/typed.rs b/src/ast/typed.rs index 2adf97e..b7833fc 100644 --- a/src/ast/typed.rs +++ b/src/ast/typed.rs @@ -64,6 +64,11 @@ impl Op { output: &Type_, ) { match self { + Op::Plus => { + ctx.constrain(left.clone(), Type_::Int); + ctx.constrain(right.clone(), Type_::Int); + ctx.constrain(output.clone(), Type_::Int); + } Op::LessThan | Op::GreaterThan => { ctx.constrain(left.clone(), Type_::Int); ctx.constrain(right.clone(), Type_::Int); @@ -230,10 +235,14 @@ fn annotate_expr( None => bail!("Name not found"), }; + println!("ENV: {:?}", ctx.current_env); + println!("ANNOT CALL {:?} {:?} {:?}", func_name, args, func_args_ty); + ctx.constrain(ret_ty.clone(), func_ret_ty); - for arg in args { + for (arg, expected_ty) in args.iter().zip(func_args_ty.iter()) { let arg_annot = annotate_expr(ctx, arg)?; + ctx.constrain(arg_annot.ty.clone(), expected_ty.clone()); args_annot.push(arg_annot); } @@ -281,14 +290,18 @@ fn collect_info( }; let new_stmts = annotate_stmts(&mut ctx, &func.stmts)?; - // TODO: Args - let total_ty = Type_::Func( - Vec::new(), - Box::new(Type_::from_type(func.return_ty.clone())), - ); + let args_ = func + .args + .iter() + .cloned() + .map(|arg| Type_::from_type(arg.ty)) + .collect(); + let total_ty = + Type_::Func(args_, Box::new(Type_::from_type(func.return_ty.clone()))); let func = Func { name: func.name.clone(), + args: func.args.clone(), return_ty: func.return_ty.clone(), stmts: new_stmts, ty: total_ty, @@ -373,6 +386,7 @@ fn substitute_in_expr_kind( ExprKind::BinOp(Box::new(left), op, Box::new(right)) } ExprKind::Call(func, args) => { + println!("CALL {:?}", (&func, &args)); let args = args .into_iter() .map(|arg| substitute_in_expr(assignments, arg)) @@ -447,6 +461,7 @@ fn substitute_in_func( ) -> Result> { Ok(Func { name: func.name, + args: func.args, return_ty: func.return_ty, stmts: substitute_in_stmts(assignments, func.stmts)?, ty: func.ty.convert(assignments)?, @@ -460,7 +475,8 @@ pub fn convert(ast: Vec>) -> Result>> { match decl { super::Decl::Func(func) => { let name = func.name.clone(); - let ty = Type::Func(Vec::new(), Box::new(func.return_ty.clone())); + let args_ty = func.args.iter().map(|arg| arg.ty.clone()).collect(); + let ty = Type::Func(args_ty, Box::new(func.return_ty.clone())); top_level_env .local_type_map .insert(name, Type_::from_type(ty)); @@ -474,9 +490,22 @@ pub fn convert(ast: Vec>) -> Result>> { for decl in ast.iter() { match decl { Decl::Func(func) => { + let mut scoped_env = Env { + parent: Some(Box::new(env)), + local_type_map: HashMap::new(), + }; + + for arg in func.args.iter() { + scoped_env + .local_type_map + .insert(arg.name.clone(), Type_::from_type(arg.ty.clone())); + } + println!("Processing func {:?}", func); - let (decorated_func, constraints, env2) = collect_info(env, func)?; - env = env2; + let (decorated_func, constraints, env2) = + collect_info(scoped_env, func)?; + env = *env2.parent.unwrap(); + println!("func: {:?}", decorated_func); println!("constraints: {:?}", constraints); diff --git a/src/parser.lalrpop b/src/parser.lalrpop index 25a86af..a1db41b 100644 --- a/src/parser.lalrpop +++ b/src/parser.lalrpop @@ -9,10 +9,14 @@ Decl: Decl<()> = { }; Func: Func<()> = { - "fn" "(" ")" "->" "{" "}" => - Func { name, return_ty, stmts, ty: (), }, + "fn" "(" ")" "->" "{" "}" => + Func { name, args, return_ty, stmts, ty: (), }, }; +Args: Vec = Punct<",", Arg>? => <>.unwrap_or_else(|| Vec::new()); + +Arg: Arg = ":" => Arg { name, ty }; + Stmt: Stmt<()> = { "let" "=" ";" => Stmt::Let(name, ty, expr), @@ -46,6 +50,13 @@ Expr: Expr<()> = { "(" ?> ")" => Expr { kind: ExprKind::Call(func, args.unwrap_or_else(|| vec![])), ty: () }, + #[precedence(level = "8")] + #[assoc(side = "left")] + => Expr { + kind: ExprKind::BinOp(Box::new(left), op, Box::new(right)), + ty: (), + }, + #[precedence(level = "13")] #[assoc(side = "none")] => Expr { @@ -54,6 +65,10 @@ Expr: Expr<()> = { }, }; +AddOp: Op = { + "+" => Op::Plus, +}; + CompareOp: Op = { "<" => Op::LessThan, ">" => Op::GreaterThan, diff --git a/src/utils.rs b/src/utils.rs index 0444f0a..e5e26ed 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1 +1,36 @@ // TODO: put layered environment here? + +use inkwell::{ + types::{BasicMetadataTypeEnum, BasicTypeEnum}, + values::{BasicMetadataValueEnum, BasicValueEnum}, +}; + +pub fn convert_type_to_metadata_type( + a: BasicTypeEnum, +) -> BasicMetadataTypeEnum { + use BasicMetadataTypeEnum as B; + use BasicTypeEnum as A; + match a { + A::IntType(a) => B::IntType(a), + A::ArrayType(a) => B::ArrayType(a), + A::FloatType(a) => B::FloatType(a), + A::StructType(a) => B::StructType(a), + A::VectorType(a) => B::VectorType(a), + A::PointerType(a) => B::PointerType(a), + } +} + +pub fn convert_value_to_metadata_value( + a: BasicValueEnum, +) -> BasicMetadataValueEnum { + use BasicMetadataValueEnum as B; + use BasicValueEnum as A; + match a { + A::ArrayValue(a) => B::ArrayValue(a), + A::IntValue(a) => B::IntValue(a), + A::FloatValue(a) => B::FloatValue(a), + A::PointerValue(a) => B::PointerValue(a), + A::StructValue(a) => B::StructValue(a), + A::VectorValue(a) => B::VectorValue(a), + } +}