diff --git a/cli/src/commands/eval.rs b/cli/src/commands/eval.rs index 3891104..f8af8ca 100644 --- a/cli/src/commands/eval.rs +++ b/cli/src/commands/eval.rs @@ -57,6 +57,14 @@ pub fn eval_all( if asts.is_empty() { return false; } + log::debug!( + "Expressions ({}):\n\t{}", + asts.len(), + asts.iter() + .map(|ast| ast.print_expr()) + .collect::>() + .join("\n\t") + ); 'exprs: for (i, ast) in asts.iter().enumerate() { let checked_ast = match checker.check_expr(ast) { Ok(ast) => ast, diff --git a/cli/src/commands/repl.rs b/cli/src/commands/repl.rs index 48910fc..f233d6f 100644 --- a/cli/src/commands/repl.rs +++ b/cli/src/commands/repl.rs @@ -1,14 +1,14 @@ -use core::str; -use std::io::{Read, Write}; - +use crate::{commands::eval::eval_all, logger::init_logger_str, CLI_VERSION}; use clap::{ArgMatches, Command}; use colorful::Colorful; +use core::str; use lento_core::{ interpreter::env::global_env, lexer::lexer::InputSource, parser::parser, stdlib::init::stdlib, type_checker::checker::TypeChecker, }; +use std::io::{Read, Write}; -use crate::{commands::eval::eval_all, logger::init_logger_str, CLI_VERSION}; +const PROMPT: &str = "►"; pub fn handle_command_repl(args: &ArgMatches, _arg_parser: &mut Command) { // Set the Ctrl-C handler to exit the program @@ -51,7 +51,7 @@ pub fn handle_command_repl(args: &ArgMatches, _arg_parser: &mut Command) { let mut env = global_env(); std.init_environment(&mut env); loop { - print!("> "); + print!("{} ", PROMPT); std::io::stdout().flush().unwrap(); let found_expr = eval_all( &mut parser, diff --git a/cli/src/error.rs b/cli/src/error.rs index b73dda6..1e90099 100644 --- a/cli/src/error.rs +++ b/cli/src/error.rs @@ -54,14 +54,6 @@ pub fn print_error_report(kind: &str, base: BaseError, content: &str, source: &I report = report.with_help(hint); } - if base.info.end.eof { - report = report.with_label( - Label::new((source.name(), base.info.end.index..base.info.end.index)) - .with_message(format!("end of {}", source.human_readable())) - .with_color(ariadne::Color::Yellow), - ); - } - report .finish() .print((source.name(), Source::from(content))) diff --git a/core/src/compiler/backends/cranelift/tests.rs b/core/src/compiler/backends/cranelift/tests.rs index db7c658..745ceea 100644 --- a/core/src/compiler/backends/cranelift/tests.rs +++ b/core/src/compiler/backends/cranelift/tests.rs @@ -1,3 +1,4 @@ +/* #[cfg(test)] mod tests { use std::io::Write; @@ -103,3 +104,4 @@ mod tests { assert!(!program_output.is_empty()); } } +*/ diff --git a/core/src/interpreter/eval.rs b/core/src/interpreter/eval.rs index 6487e25..e244c1e 100644 --- a/core/src/interpreter/eval.rs +++ b/core/src/interpreter/eval.rs @@ -2,11 +2,16 @@ use std::borrow::Borrow; use crate::{ interpreter::value::NativeFunction, + parser::pattern::BindPattern, type_checker::{ - checked_ast::{CheckedAst, CheckedBindPattern}, + checked_ast::CheckedAst, types::{GetType, Type}, }, - util::{error::LineInfo, str::Str}, + util::{ + error::{BaseErrorExt, LineInfo}, + failable::Failable, + str::Str, + }, }; use super::{ @@ -44,16 +49,11 @@ pub fn eval_expr(ast: &CheckedAst, env: &mut Environment) -> InterpretResult { CheckedAst::Identifier { name, .. } => match env.lookup_identifier(name) { (Some(v), _) => v.clone(), (_, Some(f)) => Value::Function(Box::new(f.clone())), - (_, _) => unreachable!("This should have been checked by the type checker"), + (_, _) => unreachable!("Undefined identifier: {}", name), }, CheckedAst::Assignment { target, expr, .. } => { - let info = target.info(); - let target = match target.to_owned() { - CheckedBindPattern::Variable { name, .. } => name, - _ => unreachable!("This should have been checked by the type checker"), - }; let value = eval_expr(expr, env)?; - env.add_value(Str::String(target), value.clone(), info)?; + eval_assignment(target, &value, env)?; value } CheckedAst::List { exprs, ty, .. } => Value::List( @@ -86,7 +86,7 @@ pub fn eval_expr(ast: &CheckedAst, env: &mut Environment) -> InterpretResult { } } } - CheckedAst::FunctionDef { + CheckedAst::Lambda { param, body, return_type, @@ -109,7 +109,7 @@ pub fn eval_expr(ast: &CheckedAst, env: &mut Environment) -> InterpretResult { if !matches!(ast, CheckedAst::Literal { .. }) { log::trace!( "Eval: {} -> {}", - ast.print_sexpr(), + ast.print_expr(), result.pretty_print_color() ); } @@ -160,7 +160,10 @@ fn eval_call( let arg = eval_expr(arg, env)?; let mut closure = env.new_child(Str::Str("")); // Bind the argument to the parameter of the function variation - closure.add_value(Str::String(param.name.clone()), arg.clone(), info)?; + let BindPattern::Variable { name, info } = ¶m.pattern else { + unreachable!("Other bind patterns are not supported yet"); + }; + closure.add_value(Str::String(name.clone()), arg.clone(), info)?; eval_expr(body, &mut closure) } Function::Native { .. } => { @@ -216,3 +219,69 @@ fn eval_tuple(exprs: &[CheckedAst], env: &mut Environment) -> InterpretResult { Ok(Value::Tuple(values, Type::Tuple(types))) } + +/// Evaluate an assignment expression to a bind pattern +fn eval_assignment( + pattern: &BindPattern, + value: &Value, + env: &mut Environment, +) -> Failable { + match pattern { + // BindPattern::Function { name, params, .. } => { + // let mut closure = env.new_child(Str::Str("")); + // closure.add_value(name.clone(), value.clone(), pattern.info())?; + // for param in params { + // if let BindPattern::Variable { name, .. } = param { + // closure.add_value(name.clone(), value.clone(), pattern.info())?; + // } + // } + // } + BindPattern::Variable { name, .. } => { + env.add_value(Str::String(name.clone()), value.clone(), pattern.info())? + } + BindPattern::Tuple { elements, .. } => { + let Value::Tuple(values, _) = value else { + unreachable!("This should have been checked by the type checker"); + }; + for (pattern, value) in elements.iter().zip(values) { + eval_assignment(pattern, value, env)?; + } + } + BindPattern::Record { fields, .. } => { + let Value::Record(values, _) = value else { + unreachable!("This should have been checked by the type checker"); + }; + for (key, pattern) in fields { + let Some((_, value)) = values.iter().find(|(k, _)| k == key) else { + unreachable!("This should have been checked by the type checker"); + }; + eval_assignment(pattern, value, env)?; + } + } + BindPattern::List { elements, .. } => { + let Value::List(values, _) = value else { + unreachable!("This should have been checked by the type checker"); + }; + for (element, value) in elements.iter().zip(values) { + eval_assignment(element, value, env)?; + } + } + BindPattern::Wildcard => {} + BindPattern::Literal { value: lit, .. } => { + if *value != lit.as_value() { + return Err(RuntimeError::new( + format!( + "Literal pattern match failed: expected {}, found {}", + lit.as_value().pretty_print(), + value.pretty_print() + ), + pattern.info().clone(), + )); + } + } + BindPattern::Rest { .. } => { + // Handle rest pattern if needed + } + } + Ok(()) +} diff --git a/core/src/interpreter/number.rs b/core/src/interpreter/number.rs index d1442ec..4c32daf 100644 --- a/core/src/interpreter/number.rs +++ b/core/src/interpreter/number.rs @@ -69,7 +69,7 @@ trait NumberMethods { } /// An unsigned integer number type that can be represented in 1, 8, 16, 32, 64, 128 bits or arbitrary precision. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum UnsignedInteger { UInt1(u8), // Bit UInt8(u8), // Byte @@ -525,7 +525,7 @@ impl ArithmeticOperations for UnsignedInteger { } /// A signed integer number type that can be represented in 8, 16, 32, 64, 128 bits or arbitrary precision. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum SignedInteger { Int8(i8), Int16(i16), diff --git a/core/src/interpreter/tests.rs b/core/src/interpreter/tests.rs index ae9d0bb..6bdeda7 100644 --- a/core/src/interpreter/tests.rs +++ b/core/src/interpreter/tests.rs @@ -7,7 +7,10 @@ mod tests { number::{FloatingPoint, Number, UnsignedInteger}, value::Value, }, - parser::parser::{self, ParseResult, ParseResults}, + parser::{ + parser::{self, ParseResult, ParseResults}, + pattern::BindPattern, + }, stdlib::init::{stdlib, Initializer}, type_checker::{ checked_ast::{CheckedAst, CheckedParam}, @@ -56,21 +59,18 @@ mod tests { info: LineInfo::default(), }), arg: Box::new(rhs), - return_type: std_types::NUM(), + ret_ty: std_types::NUM(), info: LineInfo::default(), }), arg: Box::new(lhs), - return_type: std_types::NUM(), + ret_ty: std_types::NUM(), info: LineInfo::default(), } } fn fn_unit() -> CheckedAst { - CheckedAst::function_def( - CheckedParam { - name: "ignore".to_string(), - ty: std_types::UNIT, - }, + CheckedAst::lambda( + CheckedParam::from_str("ignore", std_types::UNIT), CheckedAst::Block { exprs: vec![], ty: std_types::UNIT, @@ -117,7 +117,7 @@ mod tests { info: LineInfo::default(), }, ], - expr_types: Type::Tuple(vec![std_types::UINT8; 3]), + ty: Type::Tuple(vec![std_types::UINT8; 3]), info: LineInfo::default(), }; let result = eval_expr(&ast, &mut std_env()); @@ -165,7 +165,7 @@ mod tests { value: Value::Unit, info: LineInfo::default(), }), - return_type: std_types::UNIT, + ret_ty: std_types::UNIT, info: LineInfo::default(), }; let result = eval_expr(&ast, &mut global_env()); @@ -175,25 +175,10 @@ mod tests { assert_eq!(result, Value::Unit); } - #[test] - fn invalid_function() { - let ast = CheckedAst::FunctionCall { - expr: Box::new(fn_unit()), - arg: Box::new(CheckedAst::Literal { - value: make_u8(1), - info: LineInfo::default(), - }), - return_type: std_types::UNIT, - info: LineInfo::default(), - }; - let result = eval_expr(&ast, &mut global_env()); - assert!(result.is_err()); - } - #[test] fn assignment() { let ast = CheckedAst::Assignment { - target: crate::type_checker::checked_ast::CheckedBindPattern::Variable { + target: BindPattern::Variable { name: "x".into(), info: LineInfo::default(), }, @@ -244,7 +229,7 @@ mod tests { y = 2; z = x + y; "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); @@ -262,11 +247,11 @@ mod tests { fn function_decl_paren_explicit_args_and_ret() { let module = parse_str_all( r#" - add(x: u8, y: u8, z: u8) -> u8 { + u8 add(u8 x, u8 y, u8 z) { x + y + z } "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); @@ -283,11 +268,11 @@ mod tests { fn function_decl_explicit_args_and_ret() { let module = parse_str_all( r#" - add x: u8 y: u8 z: u8 -> u8 { + u8 add u8 x, u8 y, u8 z { x + y + z } "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); @@ -304,11 +289,11 @@ mod tests { fn function_decl_explicit_args() { let module = parse_str_all( r#" - add x: u8 y: u8 z: u8 { + add u8 x, u8 y, u8 z { x + y + z } "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); @@ -329,7 +314,7 @@ mod tests { x + y + z } "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); @@ -350,7 +335,7 @@ mod tests { x + y + z } "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); @@ -367,11 +352,11 @@ mod tests { fn function_decl_implicit_random_parens() { let module = parse_str_all( r#" - add x y (z) a (b) (c) -> u8 { + u8 add x y (z) a (b) (c) { x + y + z + a + b + c } "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); @@ -388,9 +373,9 @@ mod tests { fn function_decl_paren_explicit_signature_oneline() { let module = parse_str_all( r#" - add(x: u8, y: u8, z: u8) -> u8 = x + y + z; + u8 add(u8 x, u8 y, u8 z) = x + y + z; "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); @@ -407,9 +392,9 @@ mod tests { fn function_decl_explicit_signature_oneline() { let module = parse_str_all( r#" - add x: u8 y: u8 z: u8 -> u8 = x + y + z; + u8 add u8 x, u8 y, u8 z = x + y + z; "#, - None, + Some(&stdlib()), ) .expect("Failed to parse module"); let mut checker = TypeChecker::default(); diff --git a/core/src/interpreter/value.rs b/core/src/interpreter/value.rs index a3c194c..8357163 100644 --- a/core/src/interpreter/value.rs +++ b/core/src/interpreter/value.rs @@ -17,17 +17,17 @@ use super::{env::Environment, eval::InterpretResult, number::Number}; /// ```ignore /// record = { "key": 1, 2: 3.0, 'c': "value", 4.0: 'd' } /// ``` -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum RecordKey { String(String), - Number(Number), + // Number(Number), // TODO: Support numbers as keys } impl Display for RecordKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { RecordKey::String(s) => write!(f, "{}", s), - RecordKey::Number(n) => write!(f, "{}", n), + // RecordKey::Number(n) => write!(f, "{}", n), } } } diff --git a/core/src/lexer/lexer.rs b/core/src/lexer/lexer.rs index 4b2c74d..4597476 100644 --- a/core/src/lexer/lexer.rs +++ b/core/src/lexer/lexer.rs @@ -92,6 +92,30 @@ impl NumInfo { } } +//--------------------------------------------------------------------------------------// +// Lexer Factory Functions // +//--------------------------------------------------------------------------------------// + +pub fn from_str(input: &str) -> Lexer> { + Lexer::new(BytesReader::from(input)) +} + +pub fn from_string(input: String) -> Lexer> { + Lexer::new(Cursor::new(input)) +} + +pub fn from_file(file: File) -> Lexer> { + Lexer::new(BufReader::new(file)) +} + +pub fn from_stream(reader: R) -> Lexer { + Lexer::new_stream(reader) +} + +pub fn from_stdin() -> Lexer { + from_stream(StdinReader::default()) +} + //--------------------------------------------------------------------------------------// // Lexer // //--------------------------------------------------------------------------------------// @@ -118,7 +142,7 @@ where content: Vec, index: usize, line_info: LineInfo, - pub operators: HashSet, + operators: HashSet, peeked_tokens: Vec, // Queue of peeked tokens (FIFO) /// The buffer size of the lexer. \ /// This is the size of the buffer used to read from the source code. @@ -199,6 +223,10 @@ impl Lexer { self.content } + pub fn add_operator(&mut self, op: String) { + self.operators.insert(op); + } + pub fn current_index(&self) -> usize { self.index } @@ -462,7 +490,6 @@ impl Lexer { ']' => TokenKind::RightBracket, ';' => TokenKind::SemiColon, ':' => TokenKind::Colon, - ',' => TokenKind::Comma, '/' if self.peek_char(0) == Some('/') => return self.read_comment(), _ => return self.read_operator(c), }); @@ -1168,7 +1195,6 @@ impl Lexer { ops = new_ops; self.next_char(); // Eat the peeked character } - log::trace!("longest_match: {}", longest_match); let Some(op) = self.operators.get(&longest_match) else { return Err(LexerError::new( format!("Unknown operator {}", longest_match), @@ -1178,23 +1204,3 @@ impl Lexer { self.new_token_info(TokenKind::Op(op.clone())) } } - -pub fn from_str(input: &str) -> Lexer> { - Lexer::new(BytesReader::from(input)) -} - -pub fn from_string(input: String) -> Lexer> { - Lexer::new(Cursor::new(input)) -} - -pub fn from_file(file: File) -> Lexer> { - Lexer::new(BufReader::new(file)) -} - -pub fn from_stream(reader: R) -> Lexer { - Lexer::new_stream(reader) -} - -pub fn from_stdin() -> Lexer { - from_stream(StdinReader::default()) -} diff --git a/core/src/lexer/tests.rs b/core/src/lexer/tests.rs index ddaaf38..285e14a 100644 --- a/core/src/lexer/tests.rs +++ b/core/src/lexer/tests.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod tests { - use std::io::{Read, Seek}; + + use std::io::Read; use crate::{ interpreter::number::{FloatingPoint, Number, UnsignedInteger}, @@ -8,39 +9,72 @@ mod tests { lexer::{from_str, Lexer}, token::TokenKind, }, + parser::op::intrinsic_operators, stdlib::init::stdlib, }; - fn assert_next_token_eq(lexer: &mut Lexer, token: TokenKind) { - assert_eq!(lexer.next_token().unwrap().token, token); + fn init(lexer: &mut Lexer) { + intrinsic_operators() + .iter() + .for_each(|op| lexer.add_operator(op.symbol.clone())); + stdlib().init_lexer(lexer); } #[test] fn function() { let mut lexer = from_str("add a b = a + b"); - stdlib().init_lexer(&mut lexer); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("add".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("b".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Op("=".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Op("+".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("b".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + init(&mut lexer); + + let token = TokenKind::Identifier("add".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("b".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Op("=".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Op("+".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("b".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn assign() { - let mut lexer = from_str("x = 1;"); - stdlib().init_lexer(&mut lexer); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("x".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Op("=".to_string())); - assert_next_token_eq( - &mut lexer, - TokenKind::Number(Number::UnsignedInteger(UnsignedInteger::UInt1(1))), - ); - assert_next_token_eq(&mut lexer, TokenKind::SemiColon); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + let mut lexer = from_str("x = 1 + 2;"); + init(&mut lexer); + + let token = TokenKind::Identifier("x".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Op("=".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Number(Number::UnsignedInteger(UnsignedInteger::UInt1(1))); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Op("+".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(2))); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::SemiColon; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] @@ -49,173 +83,267 @@ mod tests { let equals = "==".to_string(); let assignment = "=".to_string(); let strict_equals = "===".to_string(); - lexer.operators.insert(equals.clone()); - lexer.operators.insert(strict_equals.clone()); - lexer.operators.insert(assignment.clone()); + lexer.add_operator(equals.clone()); + lexer.add_operator(strict_equals.clone()); + lexer.add_operator(assignment.clone()); // == - assert_next_token_eq(&mut lexer, TokenKind::Op(equals)); + + let token = TokenKind::Op(equals); + assert_eq!(lexer.next_token().unwrap().token, token); // = - assert_next_token_eq(&mut lexer, TokenKind::Op(assignment.clone())); + + let token = TokenKind::Op(assignment.clone()); + assert_eq!(lexer.next_token().unwrap().token, token); // === - assert_next_token_eq(&mut lexer, TokenKind::Op(strict_equals.clone())); + + let token = TokenKind::Op(strict_equals.clone()); + assert_eq!(lexer.next_token().unwrap().token, token); // ==== - assert_next_token_eq(&mut lexer, TokenKind::Op(strict_equals)); - assert_next_token_eq(&mut lexer, TokenKind::Op(assignment)); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + let token = TokenKind::Op(strict_equals); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Op(assignment); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn string() { let mut lexer = from_str(r#""Hello, World!""#); - stdlib().init_lexer(&mut lexer); - assert_next_token_eq(&mut lexer, TokenKind::String("Hello, World!".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + init(&mut lexer); + + let token = TokenKind::String("Hello, World!".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn string_escape() { let mut lexer = from_str(r#""Hello, \"World\"!""#); - stdlib().init_lexer(&mut lexer); - assert_next_token_eq( - &mut lexer, - TokenKind::String("Hello, \\\"World\\\"!".to_string()), - ); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + init(&mut lexer); + + let token = TokenKind::String("Hello, \\\"World\\\"!".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn char() { let mut lexer = from_str(r#"'a'"#); - stdlib().init_lexer(&mut lexer); - assert_next_token_eq(&mut lexer, TokenKind::Char('a')); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + init(&mut lexer); + + let token = TokenKind::Char('a'); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn char_escape() { let mut lexer = from_str(r#"'\\'"#); - stdlib().init_lexer(&mut lexer); - assert_next_token_eq(&mut lexer, TokenKind::Char('\\')); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + init(&mut lexer); + + let token = TokenKind::Char('\\'); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn float() { let mut lexer = from_str("123.456"); - assert_next_token_eq( - &mut lexer, - TokenKind::Number(Number::FloatingPoint(FloatingPoint::Float32(123.456))), - ); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::Number(Number::FloatingPoint(FloatingPoint::Float32(123.456))); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn integer() { let mut lexer = from_str("123"); - assert_next_token_eq( - &mut lexer, - TokenKind::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(123))), - ); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(123))); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn identifier() { let mut lexer = from_str("abc_123"); stdlib().init_lexer(&mut lexer); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("abc_123".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::Identifier("abc_123".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn keywords() { let mut lexer = from_str("true false"); - stdlib().init_lexer(&mut lexer); - assert_next_token_eq(&mut lexer, TokenKind::Boolean(true)); - assert_next_token_eq(&mut lexer, TokenKind::Boolean(false)); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + init(&mut lexer); + + let token = TokenKind::Boolean(true); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Boolean(false); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn comment() { let mut lexer = from_str("// This is a comment"); - stdlib().init_lexer(&mut lexer); - assert_next_token_eq( - &mut lexer, - TokenKind::Comment(" This is a comment".to_string()), - ); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::Comment(" This is a comment".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn commas() { let mut lexer = from_str("a, b, c"); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Comma); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("b".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Comma); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("c".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + init(&mut lexer); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Op(",".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("b".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Op(",".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("c".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn colon() { let mut lexer = from_str("a: b"); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Colon); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("b".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Colon; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("b".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn parens() { let mut lexer = from_str("(a)"); - assert_next_token_eq( - &mut lexer, - TokenKind::LeftParen { - is_function_call: false, - }, - ); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::RightParen); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::LeftParen { + is_function_call: false, + }; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::RightParen; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn braces() { let mut lexer = from_str("{a}"); - assert_next_token_eq(&mut lexer, TokenKind::LeftBrace); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::RightBrace); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::LeftBrace; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::RightBrace; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn brackets() { let mut lexer = from_str("[a]"); - assert_next_token_eq(&mut lexer, TokenKind::LeftBracket); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::RightBracket); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::LeftBracket; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::RightBracket; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn operators() { let mut lexer = from_str("a + b"); - lexer.operators.insert("+".to_string()); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Op("+".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("b".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + lexer.add_operator("+".to_string()); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Op("+".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("b".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } #[test] fn semicolon() { let mut lexer = from_str("a; b"); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("a".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::SemiColon); - assert_next_token_eq(&mut lexer, TokenKind::Identifier("b".to_string())); - assert_next_token_eq(&mut lexer, TokenKind::EndOfFile); + + let token = TokenKind::Identifier("a".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::SemiColon; + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::Identifier("b".to_string()); + assert_eq!(lexer.next_token().unwrap().token, token); + + let token = TokenKind::EndOfFile; + assert_eq!(lexer.next_token().unwrap().token, token); } } diff --git a/core/src/lexer/token.rs b/core/src/lexer/token.rs index ae6142a..1cc1530 100644 --- a/core/src/lexer/token.rs +++ b/core/src/lexer/token.rs @@ -10,7 +10,6 @@ pub enum TokenKind { Newline, SemiColon, Colon, - Comma, // Literals Identifier(String), Number(Number), @@ -54,7 +53,6 @@ impl TokenKind { matches!( self, TokenKind::EndOfFile - | TokenKind::Newline | TokenKind::SemiColon | TokenKind::RightParen | TokenKind::RightBrace @@ -100,7 +98,6 @@ impl Display for TokenKind { Self::Newline => write!(f, "newline"), Self::SemiColon => write!(f, ";"), Self::Colon => write!(f, ":"), - Self::Comma => write!(f, ","), Self::Identifier(s) => write!(f, "{}", s), Self::Number(s) => write!(f, "{}", s), Self::String(s) => write!(f, "\"{}\"", s), diff --git a/core/src/parser/ast.rs b/core/src/parser/ast.rs index 7d84dae..7dbc49b 100644 --- a/core/src/parser/ast.rs +++ b/core/src/parser/ast.rs @@ -1,52 +1,121 @@ +use std::fmt::Debug; + use crate::{ interpreter::value::{RecordKey, Value}, util::error::LineInfo, }; -use super::op::OperatorInfo; +use super::{op::OpInfo, pattern::BindPattern}; #[derive(Debug, Clone)] pub struct ParamAst { - pub name: String, pub ty: Option, - pub info: LineInfo, + pub pattern: BindPattern, } impl PartialEq for ParamAst { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.ty == other.ty + self.pattern == other.pattern && self.ty == other.ty } } -#[derive(Debug, Clone)] +#[derive(Clone)] pub enum TypeAst { Identifier { - /// The name of the type. name: String, - /// The line information for the type. info: LineInfo, }, Constructor { expr: Box, - arg: Box, + params: Vec, + info: LineInfo, + }, + Record { + fields: Vec<(RecordKey, TypeAst)>, info: LineInfo, }, } -impl TypeAst { - pub fn print_sexpr(&self) -> String { +impl Debug for TypeAst { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - TypeAst::Identifier { name, .. } => name.clone(), - TypeAst::Constructor { expr, arg, .. } => { - format!("({} {})", expr.print_sexpr(), arg.print_sexpr()) + Self::Identifier { name, .. } => { + f.debug_struct("Identifier").field("name", name).finish() + } + Self::Constructor { expr, params, .. } => f + .debug_struct("Constructor") + .field("expr", expr) + .field("params", params) + .finish(), + Self::Record { fields, .. } => { + f.debug_struct("Record").field("fields", fields).finish() } } } +} +impl TypeAst { pub fn info(&self) -> &LineInfo { match self { TypeAst::Identifier { info, .. } => info, TypeAst::Constructor { info, .. } => info, + TypeAst::Record { info, .. } => info, + } + } + + pub fn print_expr(&self) -> String { + match self { + TypeAst::Identifier { name, .. } => name.clone(), + TypeAst::Constructor { + expr, params: args, .. + } => { + format!( + "{}({})", + expr.print_expr(), + args.iter() + .map(|a| a.print_expr()) + .collect::>() + .join(", ") + ) + } + TypeAst::Record { fields, .. } => { + format!( + "{{ {} }}", + fields + .iter() + .map(|(k, v)| format!("{}: {}", k, v.print_expr())) + .collect::>() + .join(", ") + ) + } + } + } + + pub fn pretty_print(&self) -> String { + match self { + TypeAst::Identifier { name, .. } => name.clone(), + TypeAst::Constructor { + expr, params: args, .. + } => { + format!( + "{}({})", + expr.pretty_print(), + args.iter() + .map(|a| a.pretty_print()) + .collect::>() + .join(", ") + ) + } + TypeAst::Record { fields, .. } => { + format!( + "{{ {} }}", + fields + .iter() + .map(|(k, v)| format!("{}: {}", k, v.pretty_print())) + .collect::>() + .join(", ") + ) + } } } } @@ -57,23 +126,23 @@ impl PartialEq for TypeAst { (Self::Identifier { name: l0, .. }, Self::Identifier { name: r0, .. }) => l0 == r0, ( Self::Constructor { - expr: l_expr, - arg: l_arg, + expr: l0, + params: l1, info: _, }, Self::Constructor { - expr: r_expr, - arg: r_arg, + expr: r0, + params: r1, info: _, }, - ) => l_expr == r_expr && l_arg == r_arg, + ) => l0 == r0 && l1 == r1, _ => false, } } } /// **Expressions** in the program source code. -#[derive(Debug, Clone)] +#[derive(Clone)] pub enum Ast { /// A literal is a constant value that is directly represented in the source code. Literal { value: Value, info: LineInfo }, @@ -91,8 +160,8 @@ pub enum Ast { fields: Vec<(RecordKey, Ast)>, info: LineInfo, }, - /// A field access expression is a reference to a field in a record. - FieldAccess { + /// A member field access expression is a reference to a field in a record. + MemderAccess { /// The record expression to access the field from. expr: Box, /// The field key to access. @@ -106,13 +175,13 @@ pub enum Ast { /// Any type annotation for the target expression. annotation: Option, /// The target expression to assign to. - target: Box, + target: BindPattern, /// The source expression to assign to the target. expr: Box, info: LineInfo, }, - /// A function definition is a named function with a list of parameters and a body expression. - FunctionDef { + /// A lambda expression is an anonymous function that can be passed as a value. + Lambda { param: ParamAst, body: Box, return_type: Option, @@ -124,22 +193,16 @@ pub enum Ast { arg: Box, info: LineInfo, }, - /// An accumulate expression is an operation with multiple operands. - Accumulate { - op_info: OperatorInfo, - exprs: Vec, - info: LineInfo, - }, /// A binary expression is an operation with two operands. Binary { lhs: Box, - op_info: OperatorInfo, + op_info: OpInfo, rhs: Box, info: LineInfo, }, /// A unary expression is an operation with one operand. Unary { - op_info: OperatorInfo, + op_info: OpInfo, expr: Box, info: LineInfo, }, @@ -147,19 +210,90 @@ pub enum Ast { Block { exprs: Vec, info: LineInfo }, } +impl Debug for Ast { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Literal { value, .. } => f.debug_struct("Literal").field("value", value).finish(), + Self::LiteralType { expr } => { + f.debug_struct("LiteralType").field("expr", expr).finish() + } + Self::Tuple { exprs, .. } => f.debug_struct("Tuple").field("exprs", exprs).finish(), + Self::List { exprs, .. } => f.debug_struct("List").field("exprs", exprs).finish(), + Self::Record { fields, .. } => { + f.debug_struct("Record").field("fields", fields).finish() + } + Self::MemderAccess { expr, field, .. } => f + .debug_struct("MemderAccess") + .field("expr", expr) + .field("field", field) + .finish(), + Self::Identifier { name, .. } => { + f.debug_struct("Identifier").field("name", name).finish() + } + Self::Assignment { + annotation, + target, + expr, + .. + } => f + .debug_struct("Assignment") + .field("annotation", annotation) + .field("target", target) + .field("expr", expr) + .finish(), + Self::Lambda { + param, + body, + return_type, + .. + } => f + .debug_struct("Lambda") + .field("param", param) + .field("body", body) + .field("return_type", return_type) + .finish(), + Self::FunctionCall { expr, arg, .. } => f + .debug_struct("FunctionCall") + .field("expr", expr) + .field("arg", arg) + .finish(), + Self::Binary { + lhs, op_info, rhs, .. + } => f + .debug_struct("Binary") + .field("lhs", lhs) + .field("op_info", op_info) + .field("rhs", rhs) + .finish(), + Self::Unary { op_info, expr, .. } => f + .debug_struct("Unary") + .field("op_info", op_info) + .field("expr", expr) + .finish(), + Self::Block { exprs, .. } => f.debug_struct("Block").field("exprs", exprs).finish(), + } + } +} + impl Ast { + pub fn unit(info: LineInfo) -> Self { + Ast::Tuple { + exprs: vec![], + info, + } + } + pub fn info(&self) -> &LineInfo { match self { Ast::Literal { info, .. } => info, Ast::Tuple { info, .. } => info, Ast::List { info, .. } => info, Ast::Record { info, .. } => info, - Ast::FieldAccess { info, .. } => info, + Ast::MemderAccess { info, .. } => info, Ast::LiteralType { expr, .. } => expr.info(), Ast::Identifier { info, .. } => info, Ast::FunctionCall { info, .. } => info, - Ast::FunctionDef { info, .. } => info, - Ast::Accumulate { info, .. } => info, + Ast::Lambda { info, .. } => info, Ast::Binary { info, .. } => info, Ast::Unary { info, .. } => info, Ast::Assignment { info, .. } => info, @@ -174,17 +308,17 @@ impl Ast { } } - pub fn print_sexpr(&self) -> String { + pub fn print_expr(&self) -> String { match self { Ast::Literal { value, .. } => value.pretty_print(), - Ast::LiteralType { expr, .. } => expr.print_sexpr(), + Ast::LiteralType { expr, .. } => expr.print_expr(), Ast::Tuple { exprs: elements, .. } => format!( "({})", elements .iter() - .map(|e| e.print_sexpr()) + .map(|e| e.print_expr()) .collect::>() .join(", ") ), @@ -194,7 +328,7 @@ impl Ast { "[{}]", elements .iter() - .map(|e| e.print_sexpr()) + .map(|e| e.print_expr()) .collect::>() .join(", ") ), @@ -202,60 +336,46 @@ impl Ast { "{{ {} }}", fields .iter() - .map(|(k, v)| format!("{}: {}", k, v.print_sexpr())) + .map(|(k, v)| format!("{}: {}", k, v.print_expr())) .collect::>() .join(", ") ), - Ast::FieldAccess { expr, field, .. } => format!("({}.{})", expr.print_sexpr(), field), + Ast::MemderAccess { expr, field, .. } => format!("({}.{})", expr.print_expr(), field), Ast::Identifier { name, .. } => name.clone(), Ast::FunctionCall { expr, arg, info: _ } => { - format!("{}({})", expr.print_sexpr(), arg.print_sexpr()) + format!("({} {})", expr.print_expr(), arg.print_expr()) } - Ast::FunctionDef { + Ast::Lambda { param: params, body, .. } => { if let Some(ty) = ¶ms.ty { format!( - "({} {} -> {})", - params.name, - ty.print_sexpr(), - body.print_sexpr() + "({} {} => {})", + ty.print_expr(), + params.pattern.print_expr(), + body.print_expr() ) } else { - format!("(unknown {} -> {})", params.name, body.print_sexpr()) + format!("({} => {})", params.pattern.print_expr(), body.print_expr()) } } - Ast::Accumulate { - op_info: op, - exprs: operands, - .. - } => { - format!( - "({} {})", - op.symbol.clone(), - operands - .iter() - .map(|e| e.print_sexpr()) - .collect::>() - .join(", ") - ) - } + Ast::Binary { lhs, op_info, rhs, .. } => format!( "({} {} {})", + lhs.print_expr(), op_info.symbol.clone(), - lhs.print_sexpr(), - rhs.print_sexpr() + rhs.print_expr() ), Ast::Unary { op_info: op, expr: operand, .. } => { - format!("({} {})", op.symbol.clone(), operand.print_sexpr()) + format!("({} {})", op.symbol.clone(), operand.print_expr()) } Ast::Assignment { annotation: ty, @@ -265,22 +385,22 @@ impl Ast { } => { if let Some(ty) = ty { format!( - "(= {} {} {})", - ty.print_sexpr(), - lhs.print_sexpr(), - rhs.print_sexpr() + "({} {} = {})", + ty.print_expr(), + lhs.print_expr(), + rhs.print_expr() ) } else { - format!("(= {} {})", lhs.print_sexpr(), rhs.print_sexpr()) + format!("({} = {})", lhs.print_expr(), rhs.print_expr()) } } Ast::Block { exprs, .. } => format!( - "({})", + "{{ {} }}", exprs .iter() - .map(|e| e.print_sexpr()) + .map(|e| e.print_expr()) .collect::>() - .join(" ") + .join("; ") ), } } @@ -305,31 +425,19 @@ impl PartialEq for Ast { }, ) => l0 == r0 && l1 == r1, ( - Self::FunctionDef { + Self::Lambda { param: l_param, body: l_body, return_type: l_return_type, .. }, - Self::FunctionDef { + Self::Lambda { param: r_param, body: r_body, return_type: r_return_type, .. }, ) => l_param == r_param && l_body == r_body && l_return_type == r_return_type, - ( - Self::Accumulate { - op_info: l0, - exprs: l1, - .. - }, - Self::Accumulate { - op_info: r0, - exprs: r1, - .. - }, - ) => l0 == r0 && l1 == r1, ( Self::Binary { rhs: rhs1, diff --git a/core/src/parser/error.rs b/core/src/parser/error.rs index bdc04fa..a2bf317 100644 --- a/core/src/parser/error.rs +++ b/core/src/parser/error.rs @@ -38,12 +38,8 @@ impl BaseErrorExt for ParseError { /// Errors for when an operator is inserted into the parser /// operator table. #[derive(Debug, PartialEq)] -pub enum ParseOperatorError { - PositionForSymbolExists, - /// Any operator cannot override an existing static operator. - CannotOverrideStaticOperator, - /// When adding a static operator, no other operator with the same symbol can exist. - NonStaticOperatorExists, +pub enum ParserOpError { + AlreadyExists, } /// Errors for when a type is inserted into the parser diff --git a/core/src/parser/mod.rs b/core/src/parser/mod.rs index 0620f23..c151952 100644 --- a/core/src/parser/mod.rs +++ b/core/src/parser/mod.rs @@ -1,5 +1,6 @@ pub mod ast; -pub mod parser; pub mod error; pub mod op; +pub mod parser; +pub mod pattern; mod tests; diff --git a/core/src/parser/op.rs b/core/src/parser/op.rs index 175094f..c804915 100644 --- a/core/src/parser/op.rs +++ b/core/src/parser/op.rs @@ -12,73 +12,64 @@ use super::parser::ParseResult; // Execution Agnostic Data // //--------------------------------------------------------------------------------------// +/// The position of the operator in the expression. +/// - Prefix: `-x` +/// - Infix: `x + y` +/// - Postfix: `x!` #[derive(Clone, Debug, PartialEq, Hash, Eq)] -pub enum OperatorPosition { +pub enum OpPos { Prefix, // Unary operator Infix, // Binary operator Postfix, // Unary operator - /// If the binary operator should accumulate the arguments - /// when there are more than two arguments in the expression - /// - /// ## Example - /// The comma operator (`,`) accumulates the arguments into a tuple - /// ```ignore - /// a, b, c // Accumulate into a tuple - /// ``` - /// Gives a tuple `(a, b, c)`, not `(a, (b, c))` - /// - /// The assignment operator (`=`) does not accumulate the arguments - /// ```ignore - /// a = b // Does not accumulate - /// ``` - /// Gives a single assignment `a = b` - InfixAccumulate, } -impl OperatorPosition { +impl OpPos { pub fn is_prefix(&self) -> bool { - matches!(self, OperatorPosition::Prefix) + matches!(self, OpPos::Prefix) } pub fn is_infix(&self) -> bool { - matches!( - self, - OperatorPosition::Infix | OperatorPosition::InfixAccumulate - ) + matches!(self, OpPos::Infix) } pub fn is_postfix(&self) -> bool { - matches!(self, OperatorPosition::Postfix) - } - - pub fn is_accumulate(&self) -> bool { - matches!(self, OperatorPosition::InfixAccumulate) + matches!(self, OpPos::Postfix) } } +/// Associativity of the operator. #[derive(Clone, Debug, PartialEq, Hash, Eq)] -pub enum OperatorAssociativity { +pub enum OpAssoc { Left, Right, } -pub type OperatorPrecedence = u16; - -pub mod default_operator_precedence { - use super::OperatorPrecedence; - - pub const ASSIGNMENT: OperatorPrecedence = 100; - pub const CONDITIONAL: OperatorPrecedence = 200; - pub const LOGICAL_OR: OperatorPrecedence = 300; - pub const LOGICAL_AND: OperatorPrecedence = 400; - pub const EQUALITY: OperatorPrecedence = 500; - pub const TUPLE: OperatorPrecedence = 600; - pub const ADDITIVE: OperatorPrecedence = 700; - pub const MULTIPLICATIVE: OperatorPrecedence = 800; - pub const EXPONENTIAL: OperatorPrecedence = 900; - pub const PREFIX: OperatorPrecedence = 1000; - pub const POSTFIX: OperatorPrecedence = 1100; - pub const MEMBER_ACCESS: OperatorPrecedence = 1200; +/// The precedence of the operator. +pub type OpPrec = u16; + +/// Default precedence for operators used to define the order of operations. +/// Higher precedence operators are evaluated first. +pub mod prec { + use super::OpPrec; + + pub const SEMICOLON_PREC: OpPrec = 200; + pub const ASSIGNMENT_PREC: OpPrec = 300; + pub const COMMA_PREC: OpPrec = 400; + pub const CONDITIONAL_PREC: OpPrec = 500; + pub const LOGICAL_OR_PREC: OpPrec = 600; + pub const LOGICAL_AND_PREC: OpPrec = 700; + pub const EQUALITY_PREC: OpPrec = 800; + pub const TUPLE_PREC: OpPrec = 900; + pub const ADDITIVE_PREC: OpPrec = 1000; + pub const MULTIPLICATIVE_PREC: OpPrec = 1100; + pub const EXPONENTIAL_PREC: OpPrec = 1200; + pub const PREFIX_PREC: OpPrec = 1300; + pub const POSTFIX_PREC: OpPrec = 1400; + pub const MEMBER_ACCESS_PREC: OpPrec = 1500; + + /// Function application precedence. + /// Stronger than any other default operator. + pub const FUNCTION_APP_PREC: OpPrec = 2000; } //--------------------------------------------------------------------------------------// @@ -86,20 +77,19 @@ pub mod default_operator_precedence { //--------------------------------------------------------------------------------------// #[derive(Clone, Debug)] -pub enum StaticOperatorAst { +pub enum StaticOpAst { Prefix(Ast), Infix(Ast, Ast), Postfix(Ast), - Accumulate(Vec), } #[derive(Clone, Debug)] -pub struct OperatorSignature { +pub struct OpSignature { pub params: Vec, pub ret: Type, } -impl OperatorSignature { +impl OpSignature { pub fn new(params: Vec, ret: Type) -> Self { Self { params, ret } } @@ -140,21 +130,15 @@ impl OperatorSignature { //--------------------------------------------------------------------------------------// #[derive(Clone, Debug, PartialEq, Hash, Eq)] -pub struct OperatorInfo { - /// Descriptive name of the operator. - /// Used for: - /// - handler function definition - /// - error messages - /// - introspection. - pub name: String, +pub struct OpInfo { /// The symbol of the operator pub symbol: String, /// The position of the operator - pub position: OperatorPosition, + pub position: OpPos, /// The precedence of the operator - pub precedence: OperatorPrecedence, + pub precedence: OpPrec, /// The associativity of the operator - pub associativity: OperatorAssociativity, + pub associativity: OpAssoc, /// If the operator allows trailing arguments /// /// ## Note @@ -172,71 +156,63 @@ pub struct OperatorInfo { } #[derive(Clone, Debug)] -pub struct RuntimeOperatorHandler { +pub struct RuntimeOpHandler { pub function_name: String, - pub signature: OperatorSignature, + pub signature: OpSignature, } -pub type ParseOperatorHandler = fn(StaticOperatorAst) -> Ast; - #[derive(Clone, Debug)] -pub struct StaticOperatorHandler { - pub signature: OperatorSignature, - pub handler: fn(StaticOperatorAst) -> ParseResult, +pub struct StaticOpHandler { + pub signature: OpSignature, + pub handler: fn(StaticOpAst) -> ParseResult, } #[derive(Clone, Debug)] -pub enum OperatorHandler { +pub enum OpHandler { /// Runtime operators (functions) - Runtime(RuntimeOperatorHandler), - /// Parse-time operators (macros) - Parse(ParseOperatorHandler), + Runtime(RuntimeOpHandler), /// The compile-time handler for the operator /// (macros or syntax extensions/sugar) /// 1. The signature of the operator. This is used for type checking and inference on the operator in expressions. /// 2. The native handler function for the operator called at compile-time - Static(StaticOperatorHandler), + Static(StaticOpHandler), } #[derive(Clone, Debug)] pub struct Operator { /// Basic information about the operator /// required for parsing and type checking. - pub info: OperatorInfo, + pub info: OpInfo, /// The handler for the operator - pub handler: OperatorHandler, + pub handler: OpHandler, } impl Operator { - pub fn signature(&self) -> OperatorSignature { + pub fn signature(&self) -> OpSignature { match &self.handler { - OperatorHandler::Runtime(RuntimeOperatorHandler { signature, .. }) => signature.clone(), - OperatorHandler::Parse(_) => { - panic!("Parse operator does not have a signature") - } - OperatorHandler::Static(StaticOperatorHandler { signature, .. }) => signature.clone(), + OpHandler::Runtime(RuntimeOpHandler { signature, .. }) => signature.clone(), + OpHandler::Static(StaticOpHandler { signature, .. }) => signature.clone(), } } pub fn new_runtime( function_name: String, symbol: String, - position: OperatorPosition, - precedence: OperatorPrecedence, - associativity: OperatorAssociativity, + position: OpPos, + precedence: OpPrec, + associativity: OpAssoc, allow_trailing: bool, - signature: OperatorSignature, + signature: OpSignature, ) -> Self { Self { - info: OperatorInfo { - name: function_name.clone(), + info: OpInfo { symbol, position, precedence, associativity, allow_trailing, }, - handler: OperatorHandler::Runtime(RuntimeOperatorHandler { + handler: OpHandler::Runtime(RuntimeOpHandler { function_name, signature, }), @@ -245,47 +221,23 @@ impl Operator { #[allow(clippy::too_many_arguments)] pub fn new_static( - name: String, - symbol: String, - position: OperatorPosition, - precedence: OperatorPrecedence, - associativity: OperatorAssociativity, - allow_trailing: bool, - signature: OperatorSignature, - handler: fn(StaticOperatorAst) -> ParseResult, - ) -> Self { - Self { - info: OperatorInfo { - name, - symbol, - position, - precedence, - associativity, - allow_trailing, - }, - handler: OperatorHandler::Static(StaticOperatorHandler { signature, handler }), - } - } - - pub fn new_parse( - name: String, symbol: String, - position: OperatorPosition, - precedence: OperatorPrecedence, - associativity: OperatorAssociativity, + position: OpPos, + precedence: OpPrec, + associativity: OpAssoc, allow_trailing: bool, - handler: ParseOperatorHandler, + signature: OpSignature, + handler: fn(StaticOpAst) -> ParseResult, ) -> Self { Self { - info: OperatorInfo { - name, + info: OpInfo { symbol, position, precedence, associativity, allow_trailing, }, - handler: OperatorHandler::Parse(handler), + handler: OpHandler::Static(StaticOpHandler { signature, handler }), } } } diff --git a/core/src/parser/parser.rs b/core/src/parser/parser.rs index 87b851a..6b7c47d 100644 --- a/core/src/parser/parser.rs +++ b/core/src/parser/parser.rs @@ -10,11 +10,11 @@ use colorful::Colorful; use crate::{ interpreter::value::{RecordKey, Value}, lexer::{ - lexer::{self, LexResult}, + lexer, readers::{bytes_reader::BytesReader, stdin::StdinReader}, token::{TokenInfo, TokenKind}, }, - parser::ast::{ParamAst, TypeAst}, + parser::op::{prec, ASSIGNMENT_SYM, MEMBER_ACCESS_SYM}, util::{ error::{BaseErrorExt, LineInfo}, failable::Failable, @@ -25,11 +25,12 @@ use crate::lexer::lexer::Lexer; use super::{ ast::Ast, - error::{ParseError, ParseOperatorError}, + error::{ParseError, ParserOpError}, op::{ - Operator, OperatorAssociativity, OperatorHandler, OperatorInfo, OperatorPosition, - OperatorPrecedence, StaticOperatorAst, + prec::{COMMA_PREC, FUNCTION_APP_PREC}, + OpAssoc, OpInfo, OpPos, OpPrec, COMMA_SYM, }, + specialize, }; /// Token predicates for parsing @@ -40,19 +41,80 @@ mod pred { matches!(t, TokenKind::EndOfFile) } - /// Check if the token is an ignored token. + /// Check if the token is an ignore token. /// These include: /// - `Newline` /// - `Comment` - pub fn ignored(t: &TokenKind) -> bool { + pub fn ignore(t: &TokenKind) -> bool { matches!(t, TokenKind::Comment(_) | TokenKind::Newline) } } +//--------------------------------------------------------------------------------------// +// Parser Factory Functions // +//--------------------------------------------------------------------------------------// + +pub fn from_file(file: File) -> Parser> { + Parser::new(lexer::from_file(file)) +} + +pub fn from_string(source: String) -> Parser> { + Parser::new(lexer::from_string(source)) +} + +pub fn from_str(source: &str) -> Parser> { + Parser::new(lexer::from_str(source)) +} + +pub fn from_stdin() -> Parser { + Parser::new(lexer::from_stdin()) +} + +pub fn from_stream(reader: R) -> Parser { + Parser::new(lexer::from_stream(reader)) +} + //--------------------------------------------------------------------------------------// // Parser // //--------------------------------------------------------------------------------------// +pub(super) const COMMA_SYM: &str = ","; +pub(super) const ASSIGNMENT_SYM: &str = "="; +pub(super) const MEMBER_ACCESS_SYM: &str = "."; + +/// Default operators used in the language grammar and required for parsing. \ +/// These operators are defined in the parser and are required to produce valid ASTs. \ +/// The binary operators are replaced with `Ast` nodes by `syntax_sugar::specialize` after parsing a `parse_top_expr` expression. +/// - `semicolon`: `;` - Used to separate statements becomes an `Ast::Block` node. +/// - `comma`: `,` - Used to separate expressions in tuples and lists becomes an `Ast::Tuple` or `Ast::List` node. +/// - `assignment`: `=` - Used to assign values to variables becomes an `Ast::Assignment` node. +/// - `member access`: `.` - Used to access members of records becomes an `Ast::MemberAccess` node. +pub fn intrinsic_operators() -> Vec { + vec![ + OpInfo { + symbol: COMMA_SYM.to_string(), + position: OpPos::Infix, + precedence: prec::COMMA_PREC, + associativity: OpAssoc::Left, + allow_trailing: true, + }, + OpInfo { + symbol: ASSIGNMENT_SYM.to_string(), + position: OpPos::Infix, + precedence: prec::ASSIGNMENT_PREC, + associativity: OpAssoc::Right, + allow_trailing: false, + }, + OpInfo { + symbol: MEMBER_ACCESS_SYM.to_string(), + position: OpPos::Infix, + precedence: prec::MEMBER_ACCESS_PREC, + associativity: OpAssoc::Left, + allow_trailing: false, + }, + ] +} + /// A parse results is a list of AST nodes or a parse error. pub type ParseResults = Result, ParseError>; @@ -73,8 +135,7 @@ where /// - They have different signatures /// - They have different positions /// - The symbol is a built-in operator that is overloadable - operators: HashMap>, - types: HashSet, + operators: HashMap>, } impl Parser { @@ -82,8 +143,8 @@ impl Parser { Self { lexer, operators: HashMap::new(), - types: HashSet::new(), } + .init_default_operators() } pub fn get_lexer(&mut self) -> &mut Lexer { @@ -102,47 +163,41 @@ impl Parser { self.lexer.move_content() } - /// Add a new type to the parser. - pub fn add_type(&mut self, name: String) { - self.types.insert(name); - } - - /// Get a type by its name. - pub fn is_type(&self, name: &str) -> bool { - self.types.contains(name) + /// Initialize the parser with default operators. + fn init_default_operators(mut self) -> Self { + intrinsic_operators().into_iter().for_each(|op| { + self.define_op(op) + .expect("Failed to define default operator") + }); + self } /// Define an operator in the parser. /// If the operator already exists with the same signature, - pub fn define_op(&mut self, op: Operator) -> Failable { - if let Some(existing) = self.get_op(&op.info.symbol) { - if existing.iter().any(|e| e.info.position == op.info.position) { - // TODO: Compare signatures instead of positions - return Err(ParseOperatorError::PositionForSymbolExists); + pub fn define_op(&mut self, op: OpInfo) -> Failable { + if let Some(existing) = self.get_op(&op.symbol) { + if existing.iter().any(|e| e.position == op.position) { + return Err(ParserOpError::AlreadyExists); } } - self.lexer.operators.insert(op.info.symbol.clone()); + self.lexer.add_operator(op.symbol.clone()); self.operators - .entry(op.info.symbol.clone()) + .entry(op.symbol.clone()) .or_default() .push(op); Ok(()) } - pub fn get_op(&self, symbol: &str) -> Option<&Vec> { + pub fn get_op(&self, symbol: &str) -> Option<&Vec> { self.operators.get(symbol) } - pub fn find_operator( - &self, - symbol: &str, - pred: impl Fn(&OperatorInfo) -> bool, - ) -> Option<&Operator> { + pub fn find_operator(&self, symbol: &str, pred: impl Fn(&OpInfo) -> bool) -> Option<&OpInfo> { self.get_op(symbol) - .and_then(|ops| ops.iter().find(|op| pred(&op.info))) + .and_then(|ops| ops.iter().find(|op| pred(op))) } - pub fn find_operator_pos(&self, symbol: &str, pos: OperatorPosition) -> Option<&Operator> { + pub fn find_operator_pos(&self, symbol: &str, pos: OpPos) -> Option<&OpInfo> { self.find_operator(symbol, |op| op.position == pos) } @@ -164,45 +219,41 @@ impl Parser { Ok(ast) } - /// Parse a global AST from the stream of tokens. - /// A global AST is a list of **all** top-level AST nodes (expressions). - pub fn parse_all(&mut self) -> ParseResults { - let mut asts = Vec::new(); - loop { - if let Ok(t) = self.lexer.peek_token_not(pred::ignored, 0) { - if pred::eof(&t.token) { - break; - } - } - match self.parse_top_expr() { - Ok(expr) => asts.push(expr), - Err(e) => return Err(e), - } + fn parse_expected( + &mut self, + condition: impl FnOnce(&TokenKind) -> bool, + symbol: &'static str, + ) -> Result { + match self.lexer.expect_next_token_not(pred::ignore) { + Ok(t) if condition(&t.token) => Ok(t), + Ok(t) => Err(ParseError::new( + format!( + "Expected {} but found {}", + symbol.yellow(), + t.token.to_string().light_red() + ), + t.info.clone(), + ) + .with_label(format!("This should be a {}", symbol), t.info)), + Err(err) => Err(ParseError::new( + format!("Expected {}", symbol.yellow(),), + err.info().clone(), + ) + .with_label(err.message().to_owned(), err.info().clone())), } - Ok(asts) } - /// Parse **a single** expression from the stream of tokens. - /// Returns an AST node or an error. - /// If the first token is an EOF, then the parser will return an empty unit expression. - /// - /// # Note - /// The parser will not necessarily consume all tokens from the stream. - /// It will **ONLY** consume a whole complete expression. - /// There may be remaining tokens in the stream after the expression is parsed. - pub fn parse_one(&mut self) -> ParseResult { - // Check if the next token is an EOF, then return an empty unit top-level expression - if let Ok(t) = self.lexer.peek_token_not(pred::ignored, 0) { - if pred::eof(&t.token) { - return Ok(Ast::Literal { - value: Value::Unit, - info: t.info, - }); - } - } - self.parse_top_expr() + fn parse_expected_eq( + &mut self, + expected_token: TokenKind, + symbol: &'static str, + ) -> Result { + self.parse_expected(|t| t == &expected_token, symbol) } + // ======================================== EXPRESSION PARSING ======================================== // + + /// Parse a literal `Value` from the lexer. fn parse_literal(&mut self, token: &TokenKind, info: LineInfo) -> ParseResult { Ok(Ast::Literal { value: match token { @@ -225,278 +276,62 @@ impl Parser { }) } - /// Parse a parenthesized function call. - /// This function is called when the parser encounters an identifier followed by a left parenthesis. - /// The parser will then attempt to parse the function call. - /// ```lento - /// func(a, b, c) - /// ``` - /// If the next token is an assignment operator `=`, then the parser will attempt to parse a function definition. - /// ```lento - /// func(a, b, c) = expr + /// Parse a tuple from the lexer. + /// + /// ## Examples + /// ```ignore + /// () + /// (1, 2) + /// (1, 2, 3) /// ``` - fn parse_paren_call(&mut self, id: String, info: LineInfo) -> ParseResult { - log::trace!("Parsing parenthesized function call: {}", id); - let mut args = Vec::new(); - while let Ok(end) = self.lexer.peek_token(0) { - if end.token == TokenKind::RightParen { - break; - } - args.push(self.parse_top_expr()?); - if let Ok(nt) = self.lexer.peek_token(0) { - if nt.token == TokenKind::Comma { - self.lexer.next_token().unwrap(); - continue; - } else if nt.token == TokenKind::RightParen { - break; - } else if let TokenKind::Identifier(param_name) = nt.token { - // Found (..., ty id) in an argument list. - // Check if `ty` is an identifier - if args.iter().all(|arg| matches!(arg, Ast::Identifier { .. })) { - let Ast::Identifier { - name: param_type, - info: param_info, - } = args.pop().unwrap() - else { - unreachable!("All arguments should be identifiers"); - }; - // Found a type identifier in a function call. - // Try to parse the rest as a function definition if all args are identifiers. - self.lexer.next_token().unwrap(); // Consume the `param_name` identifier - // Remove the last argument - // Expect the next token to be a comma or right paren - if let Ok(nt) = self.lexer.peek_token(0) { - if nt.token == TokenKind::Comma { - self.lexer.next_token().unwrap(); - } else if nt.token == TokenKind::RightParen { - // Continue parsing the function definition - // The ) will be consumed by the `parse_func_def` function - } else { - return Err(ParseError::new( - format!( - "Expected {} or {}, but found {}", - ",".yellow(), - ")".yellow(), - nt.token.to_string().light_red() - ), - info, - ) - .with_label( - "This should be either a comma or a right parenthesis" - .to_string(), - nt.info, - )); - } - } - return self.parse_func_def( - None, - id, - info, - args.iter() - .map(|arg| match arg { - Ast::Identifier { name: id, info } => ParamAst { - name: id.clone(), - ty: None, - info: info.clone(), - }, - _ => unreachable!(), - }) - // Also add the current parameter - .chain([ParamAst { - name: param_name, - ty: Some(TypeAst::Identifier { - name: param_type, - info: param_info.clone(), - }), - info: param_info, - }]) - .collect(), - ); - } - } - } - if let Ok(nt) = self.lexer.peek_token(0) { - return Err(ParseError::new( - format!( - "Expected {} or {}, but found {}", - ",".yellow(), - ")".yellow(), - nt.token.to_string().light_red() - ), - info, - ) - .with_label( - "This should be either a comma or a right parenthesis".to_string(), - nt.info, - )); - } else { - return Err(ParseError::new( - "Unexpected end of program".to_string(), - LineInfo::eof(info.start, self.lexer.current_index()), - )); + fn parse_tuple(&mut self) -> ParseResult { + // Check if the next token is a right parenthesis `)`, then return an empty tuple + if let Ok(t) = self.lexer.peek_token(0) { + if t.token == TokenKind::RightParen { + self.lexer.next_token().unwrap(); + return Ok(Ast::unit(t.info)); } } - self.parse_expected(TokenKind::RightParen, ")")?; - - if let Some(func_def) = self.try_parse_func_def_no_types(&id, &info, &args) { - return func_def; - } - - log::trace!( - "Parsed function call: {}({})", - id, - args.iter() - .map(Ast::print_sexpr) - .collect::>() - .join(", ") - ); - Ok(syntax_sugar::roll_function_call(id, args, info)) - // Ok(Ast::Call(id, args)) + log::trace!("Parsing elements..."); + let tuple = self.parse_top_expr()?; + log::trace!("Parsed tuple elements: {:?}", tuple); + self.parse_expected_eq(TokenKind::RightParen, ")")?; + Ok(tuple) } - /// If we encountered an expression like: - /// ```lento - /// func(a, b, c) - /// ``` - /// Check if the subsequent token is an assignment operator `=`. - /// If it is, then we have a function definition of the form: - /// ```lento - /// func(a, b, c) = expr - /// ``` - fn try_parse_func_def_no_types( - &mut self, - func_name: &str, - info: &LineInfo, - args: &[Ast], - ) -> Option { - if let Ok(ref nt) = self.lexer.peek_token(0) { - if let TokenKind::Op(ref op) = nt.token { - if op == "=" { - // If all arguments are identifiers, then this is a function definition - if args.iter().all(|arg| matches!(arg, Ast::Identifier { .. })) { - log::trace!( - "Parsed function definition: {}({:?}) -> {:?}", - func_name, - args, - nt - ); - self.lexer.next_token().unwrap(); - let body = match self.parse_top_expr() { - Ok(body) => body, - Err(err) => { - log::warn!("Failed to parse function body: {}", err.message()); - return Some(Err(err)); - } - }; - let params = args - .iter() - .map(|arg| match arg { - Ast::Identifier { name: id, info } => ParamAst { - name: id.clone(), - ty: None, - info: info.clone(), - }, - _ => unreachable!(), - }) - .collect::>(); - // Roll functions into single param definitions - let function = syntax_sugar::roll_function_definition(params, body); - let assign_info = info.join(function.info()); - return Some(Ok(Ast::Assignment { - annotation: None, - target: Box::new(Ast::Identifier { - name: func_name.to_string(), - info: info.clone(), - }), - expr: Box::new(function), - info: assign_info, - })); - } + fn parse_record_or_block(&mut self, start_info: LineInfo) -> ParseResult { + // Try to parse as record + if let Some(res) = self.parse_record_fields(&start_info) { + res + } else { + // Parse as block + let mut exprs = Vec::new(); + while let Ok(end) = self.lexer.peek_token(0) { + if end.token == TokenKind::RightBrace { + break; } + exprs.push(self.parse_top_expr()?); } + let last = self.parse_expected_eq(TokenKind::RightBrace, "}")?; + Ok(Ast::Block { + exprs, + info: start_info.join(&last.info), + }) } - None } - /// If we encountered an expression like: - /// ```lento - /// int func( - /// ``` - /// or - /// ```lento - /// func(int a - /// ``` - /// Then we have a function definition of the form: - /// ```lento - /// int func(...) = expr - /// ``` - /// or - /// ```lento - /// func(int a, ...) = expr - /// ``` - fn parse_func_def( - &mut self, - ret_type: Option, - name: String, - info: LineInfo, - parsed_params: Vec, - ) -> ParseResult { - log::trace!("Parsing function definition: {} -> {:?}", name, ret_type); - let mut params = parsed_params; + fn parse_list(&mut self, start_info: LineInfo) -> ParseResult { + let mut exprs = Vec::new(); while let Ok(end) = self.lexer.peek_token(0) { - if end.token == TokenKind::RightParen { + if end.token == TokenKind::RightBracket { break; } - let nt = self - .lexer - .expect_next_token_not(pred::ignored) - .map_err(|err| { - ParseError::new("Expected type expression".to_string(), err.info().clone()) - })?; - let ty = match self.try_parse_type(&nt) { - Some(t) => t, - None => { - return Err(ParseError::new( - format!( - "Expected type expression, but found {}", - end.token.to_string().light_red() - ), - info, - )); - } - }; - let (param_name, param_info) = match self.lexer.next_token() { - Ok(t) => match t.token { - TokenKind::Identifier(id) => (id, t.info), - _ => { - return Err(ParseError::new( - format!( - "Expected parameter name, but found {}", - t.token.to_string().light_red() - ), - info, - ) - .with_label("This is not a valid identifier".to_string(), t.info)); - } - }, - Err(err) => { - return Err(ParseError::new( - "Failed to parse parameter name".to_string(), - LineInfo::eof(info.end, self.lexer.current_index()), - ) - .with_label(err.message().to_owned(), err.info().clone())); - } - }; - params.push(ParamAst { - name: param_name, - ty: Some(ty), - info: param_info, - }); + exprs.push(self.parse_expr(COMMA_PREC)?); if let Ok(nt) = self.lexer.peek_token(0) { - if nt.token == TokenKind::Comma { + if nt.token == TokenKind::Op(COMMA_SYM.to_string()) { self.lexer.next_token().unwrap(); continue; - } else if nt.token == TokenKind::RightParen { + } else if nt.token == TokenKind::RightBracket { break; } } @@ -505,41 +340,26 @@ impl Parser { format!( "Expected {} or {}, but found {}", ",".yellow(), - ")".yellow(), + "]".yellow(), nt.token.to_string().light_red() ), - info, + nt.info.clone(), ) .with_label( - "This should be either a comma or a right parenthesis".to_string(), + "This should be either a comma or a right bracket".to_string(), nt.info, )); } else { return Err(ParseError::new( "Unexpected end of program".to_string(), - LineInfo::eof(info.end, self.lexer.current_index()), + LineInfo::eof(start_info.end, self.lexer.current_index()), )); } } - self.parse_expected(TokenKind::RightParen, ")")?; - self.parse_expected(TokenKind::Op("=".into()), "=")?; - let body = self.parse_top_expr()?; - log::trace!( - "Parsed function definition: {}({:?}) -> {:?}", - name, - params, - body - ); - let function = syntax_sugar::roll_function_definition(params, body); - let assign_info = info.join(function.info()); - Ok(Ast::Assignment { - annotation: ret_type, - target: Box::new(Ast::Identifier { - name, - info: info.clone(), - }), - expr: Box::new(function), - info: assign_info, + let last = self.parse_expected_eq(TokenKind::RightBracket, "]")?; + Ok(Ast::List { + exprs, + info: start_info.join(&last.info), }) } @@ -559,7 +379,7 @@ impl Parser { /// the expected tokens. /// /// # Examples - /// ``` + /// ```d /// let mut parser = Parser::new(lexer); /// if let Some(result) = parser.parse_record_fields() { /// match result { @@ -571,7 +391,8 @@ impl Parser { /// } /// ``` #[allow(clippy::type_complexity)] - fn parse_record_fields(&mut self, first_info: &LineInfo) -> Option> { + fn parse_record_fields(&mut self, start_info: &LineInfo) -> Option> { + let mut last_info = LineInfo::default(); let mut fields = Vec::new(); // Initial soft parse to check if the record is empty // Or if it is a block @@ -585,24 +406,31 @@ impl Parser { })); // Empty record } TokenKind::Identifier(id) => RecordKey::String(id), - TokenKind::Number(n) => RecordKey::Number(n), + // TokenKind::Number(n) => RecordKey::Number(n), TokenKind::String(s) => RecordKey::String(s), TokenKind::Char(c) => RecordKey::String(c.to_string()), _ => return None, // Not a record }; - if let Ok(t) = self.lexer.peek_token(1) { - if t.token != TokenKind::Colon { - return None; // Not a record - } + let Ok(t) = self.lexer.peek_token(1) else { + return None; + }; + if t.token != TokenKind::Colon { + return None; // Not a record } + // If we found both a valid key and a colon, we found a record! - self.lexer.next_token().unwrap(); - self.parse_expected(TokenKind::Colon, ":").ok()?; - let value = self.parse_top_expr().ok()?; + self.lexer.next_token().unwrap(); // Consume the key + self.lexer.next_token().unwrap(); // Consume the colon + let value = match self.parse_expr(COMMA_PREC) { + Ok(value) => value, + Err(err) => return Some(Err(err)), + }; fields.push((key, value)); if let Ok(t) = self.lexer.next_token() { match t.token { - TokenKind::Comma => (), // Continue parsing + TokenKind::Op(op) if op == COMMA_SYM => { + last_info = t.info; + } TokenKind::RightBrace => { return Some(Ok(Ast::Record { fields, @@ -617,7 +445,7 @@ impl Parser { "}".yellow(), t.token.to_string().light_red() ), - first_info.join(&t.info), + start_info.join(&t.info), ) .with_label( "This should be either a comma or a right brace".to_string(), @@ -627,7 +455,6 @@ impl Parser { } } } - let mut last_info = LineInfo::default(); // Parse the rest of the fields more strictly while let Ok(t) = self.lexer.next_token() { if t.token == TokenKind::RightBrace { @@ -636,7 +463,7 @@ impl Parser { } let key = match t.token { TokenKind::Identifier(id) => RecordKey::String(id), - TokenKind::Number(n) => RecordKey::Number(n), + // TokenKind::Number(n) => RecordKey::Number(n), TokenKind::String(s) => RecordKey::String(s), TokenKind::Char(c) => RecordKey::String(c.to_string()), _ => { @@ -645,22 +472,22 @@ impl Parser { "Expected record key, but found {}", t.token.to_string().light_red() ), - first_info.join(&t.info), + start_info.join(&t.info), ) .with_label("This is not a valid record key".to_string(), t.info))); } }; - if let Err(err) = self.parse_expected(TokenKind::Colon, ":") { + if let Err(err) = self.parse_expected_eq(TokenKind::Colon, ":") { return Some(Err(err)); } - let value = match self.parse_top_expr() { + let value = match self.parse_expr(COMMA_PREC) { Ok(value) => value, Err(err) => return Some(Err(err)), }; fields.push((key, value)); if let Ok(t) = self.lexer.next_token() { match t.token { - TokenKind::Comma => continue, + TokenKind::Op(op) if op == COMMA_SYM => continue, TokenKind::RightBrace => { last_info = t.info; break; @@ -673,7 +500,7 @@ impl Parser { "}".yellow(), t.token.to_string().light_red() ), - first_info.join(&t.info), + start_info.join(&t.info), ) .with_label( "This should be either a comma or a right brace".to_string(), @@ -685,106 +512,14 @@ impl Parser { } Some(Ok(Ast::Record { fields, - info: first_info.join(&last_info), + info: start_info.join(&last_info), })) } - fn try_parse_type(&mut self, t: &TokenInfo) -> Option { - match &t.token { - TokenKind::Identifier(id) => { - if self.is_type(id) { - return Some(TypeAst::Identifier { - name: id.clone(), - info: t.info.clone(), - }); - } - } - _ => (), - } - - None - } - // fn parse_assignment( - // &mut self, - // annotation: Option, - // target: Ast, - // info: LineInfo, - // ) -> ParseResult { - // log::trace!("Parsing assignment: {:?}", target); - // self.parse_expected(TokenKind::Op("=".into()), "=")?; - // let expr = self.parse_top_expr()?; - // log::trace!("Parsed assignment: {:?} = {:?}", target, expr); - // Ok(Ast::Assignment { - // annotation, - // target: Box::new(target), - // expr: Box::new(expr), - // info, - // }) - // } - - /// Parses a typed definition, such as a variable assignment with a type annotation. - /// - /// # Parameters - /// - `annotation`: The type annotation for the definition. - /// - `nt`: The current token being processed. - /// - `skip`: The number of tokens to skip when peeking into the lexer to check for an assignment operator. - fn parse_typed_def( - &mut self, - annotation: TypeAst, - nt: &TokenInfo, - mut skip: usize, - ) -> Option { - match &nt.token { - TokenKind::Identifier(name) => { - log::trace!( - "Looking for variable assignment: {} {} = ...", - annotation.print_sexpr(), - name, - ); - // Check if the token after the identifier is an assignment operator - skip += 1; - let Ok(nt) = self.lexer.peek_token_not(pred::ignored, skip) else { - return None; // Not a valid assignment - }; - if nt.token != TokenKind::Op("=".into()) { - log::trace!("Not an assignment: {:?}", nt.token); - return None; // Not a valid assignment - } - // Consume all skipped tokens - for _ in 0..skip { - self.lexer.next_token().unwrap(); - } - - //? From now we are sure that we are parsing an assignment! - //? No more peeking! Only consume and assert tokens! - - // Parse the assignment expression - let expr = match self.parse_top_expr() { - Ok(expr) => expr, - Err(err) => return Some(Err(err)), - }; - - log::trace!("Parsed assignment: {:?} = {:?}", &name, &expr); - - Some(Ok(Ast::Assignment { - info: nt.info.join(expr.info()), - annotation: Some(annotation), - target: Box::new(Ast::Identifier { - name: name.clone(), - info: nt.info.clone(), - }), - expr: Box::new(expr), - })) - } - // TODO: Support for other types of typed definitions (e.g. tuple, lists, records, etc.) - _ => None, // Not a valid assignment - } - } - fn parse_primary(&mut self) -> ParseResult { let t = self .lexer - .expect_next_token_not(pred::ignored) + .expect_next_token_not(pred::ignore) .map_err(|err| { ParseError::new( "Expected primary expression".to_string(), @@ -792,258 +527,73 @@ impl Parser { ) .with_label(err.message().to_owned(), err.info().clone()) })?; - - if let Some(ty) = self.try_parse_type(&t) { - log::trace!("Parsed type: {:?}", ty); - let mut skip = 0; - let nt = self - .lexer - .peek_token_not(pred::ignored, skip) - .unwrap_or(TokenInfo { - token: TokenKind::EndOfFile, - info: LineInfo::default(), - }); - if nt.token.is_terminator() { - return Ok(Ast::LiteralType { expr: ty }); - } - skip += 1; // Skip the next token when parsing next steps... - - // Try parse type construction (`List int` or `List`) - if let Some(app) = self.try_parse_type(&nt) { - // TODO: Support multiple arguments - return Ok(Ast::LiteralType { - expr: TypeAst::Constructor { - expr: Box::new(ty), - arg: Box::new(app), - info: t.info.join(&nt.info), - }, - }); - } - // Try parse typed definition - else if let Some(res) = self.parse_typed_def(ty, &nt, skip) { - return res; - } - // No other expression should include a type identifier - return Err(ParseError::new( - format!( - "Expected variable assignment, function definition or type construction after {}, but found {}", - t.token.to_string().yellow(), - nt.token.to_string().light_red() - ), - nt.info.clone(), - ) - .with_label("This is not a valid type".to_string(), nt.info)); - } - + log::trace!("Parsing primary: {:?}", t.token); match t.token { lit if lit.is_literal() => self.parse_literal(&lit, t.info), - TokenKind::Identifier(id) => { - // Check if function call - if let Ok(t) = self.lexer.peek_token(0) { - if t.token - == (TokenKind::LeftParen { - is_function_call: true, - }) - { - self.lexer.next_token().unwrap(); // Consume the left paren - return self.parse_paren_call(id, t.info); - } - } - // Check if function definition - if let Ok(nt) = self.lexer.peek_token(0) { - if let TokenKind::Identifier(name) = nt.token { - if let Ok(nnt) = self.lexer.peek_token(1) { - if nnt.token - == (TokenKind::LeftParen { - is_function_call: true, - }) - { - self.lexer.next_token().unwrap(); // Consume the name identifier - self.lexer.next_token().unwrap(); // Consume the left paren - // Start parsing the params and body of the function - return self.parse_func_def( - Some(TypeAst::Identifier { - name: id.clone(), - info: t.info.clone(), - }), - name, - nnt.info, - vec![], - ); - } - } - } else if let TokenKind::LeftParen { .. } = nt.token { - self.lexer.next_token().unwrap(); // Consume the left paren - return self.parse_paren_call(id, nt.info); - } - } - Ok(Ast::Identifier { - name: id, - info: t.info, - }) - } + TokenKind::Identifier(id) => Ok(Ast::Identifier { + name: id, + info: t.info, + }), TokenKind::Op(op) => { - // TODO: Don't lookup operators in the parser, do this in the type checker! - if let Some(op) = self.find_operator_pos(&op, OperatorPosition::Prefix) { - let op_info = op.info.clone(); - let rhs = self.parse_primary()?; + if let Some(op) = self.find_operator_pos(&op, OpPos::Prefix) { + log::trace!("Parsing prefix operator: {:?}", op); Ok(Ast::Unary { - op_info, - expr: Box::new(rhs), + op_info: op.clone(), + expr: Box::new(self.parse_term()?), info: t.info, }) } else { - return Err(ParseError::new( + Err(ParseError::new( format!("Expected prefix operator, but found {}", op.light_red()), t.info.clone(), ) - .with_label("This is not a valid prefix operator".to_string(), t.info)); + .with_label("This is not a valid prefix operator".to_string(), t.info)) } } + start if start.is_grouping_start() => { match start { - // Tuples, Units and Parentheses: () TokenKind::LeftParen { is_function_call: false, - } => { - // Tuples are defined by a comma-separated list of expressions - let mut explicit_single = false; - let mut exprs = Vec::new(); - while let Ok(end) = self.lexer.peek_token(0) { - if end.token == TokenKind::RightParen { - break; - } - exprs.push(self.parse_top_expr()?); - if let Ok(nt) = self.lexer.peek_token(0) { - if nt.token == TokenKind::Comma { - self.lexer.next_token().unwrap(); - if self.lexer.peek_token(0).unwrap().token - == TokenKind::RightParen - { - explicit_single = true; - // Break in the next iteration - } - continue; - } else if nt.token == TokenKind::RightParen { - break; - } - } - if let Ok(nt) = self.lexer.peek_token(0) { - return Err(ParseError::new( - format!( - "Expected {} or {}, but found {}", - ",".yellow(), - ")".yellow(), - nt.token.to_string().light_red() - ), - nt.info.clone(), - ) - .with_label( - "This should be either a comma or a right parenthesis" - .to_string(), - nt.info, - )); - } else { - return Err(ParseError::new( - "Unexpected end of program".to_string(), - LineInfo::eof(t.info.end, self.lexer.current_index()), - )); - } - } - let end = self.parse_expected(TokenKind::RightParen, ")")?; - if exprs.len() == 1 && !explicit_single { - exprs.pop().ok_or(ParseError::new( - "Expected a single expression".to_string(), - t.info.clone(), - )) - } else { - Ok(Ast::Tuple { - exprs, - info: t.info.join(&end.info), - }) - } - } - // Records and Blocks: {} - TokenKind::LeftBrace => { - // Try to parse as record - if let Some(res) = self.parse_record_fields(&t.info) { - res - } else { - // Parse as block - let mut exprs = Vec::new(); - while let Ok(end) = self.lexer.peek_token(0) { - if end.token == TokenKind::RightBrace { - break; - } - exprs.push(self.parse_top_expr()?); - } - let last = self.parse_expected(TokenKind::RightBrace, "}")?; - Ok(Ast::Block { - exprs, - info: t.info.join(&last.info), - }) - } - } - // Lists: [] - TokenKind::LeftBracket => { - let mut exprs = Vec::new(); - while let Ok(end) = self.lexer.peek_token(0) { - if end.token == TokenKind::RightBracket { - break; - } - exprs.push(self.parse_top_expr()?); - if let Ok(nt) = self.lexer.peek_token(0) { - if nt.token == TokenKind::Comma { - self.lexer.next_token().unwrap(); - continue; - } else if nt.token == TokenKind::RightBracket { - break; - } - } - if let Ok(nt) = self.lexer.peek_token(0) { - return Err(ParseError::new( - format!( - "Expected {} or {}, but found {}", - ",".yellow(), - "]".yellow(), - nt.token.to_string().light_red() - ), - nt.info.clone(), - ) - .with_label( - "This should be either a comma or a right bracket".to_string(), - nt.info, - )); - } else { - return Err(ParseError::new( - "Unexpected end of program".to_string(), - LineInfo::eof(t.info.end, self.lexer.current_index()), - )); - } - } - let last = self.parse_expected(TokenKind::RightBracket, "]")?; - Ok(Ast::List { - exprs, - info: t.info.join(&last.info), - }) - } + } => self.parse_tuple(), // Tuples, Units and Parentheses: () + TokenKind::LeftBrace => self.parse_record_or_block(t.info), // Records and Blocks: {} + TokenKind::LeftBracket => self.parse_list(t.info), // Lists: [] _ => unreachable!(), } } - _ => { - return Err(ParseError::new( - format!( - "Expected primary expression, but found {}", - t.token.to_string().light_red() - ), - t.info.clone(), - ) - .with_label( - format!("This {} is invalid here", t.token.to_string().yellow()), - t.info, - )); + _ => Err(ParseError::new( + format!( + "Expected primary expression, but found {}", + t.token.to_string().light_red() + ), + t.info.clone(), + ) + .with_label( + format!("The {} is invalid here", t.token.to_string().yellow()), + t.info, + )), + } + } + + fn parse_term(&mut self) -> ParseResult { + let primary = self.parse_primary()?; + // Check if function call with parentheses like `f(5, 6, 7)`, **NOT** `f (5, 6, 7)` + if let Ok(nt) = self.lexer.peek_token(0) { + if matches!( + &nt.token, + TokenKind::LeftParen { + is_function_call: true + } + ) { + self.lexer.next_token().unwrap(); + let args = match self.parse_tuple()? { + Ast::Tuple { exprs, .. } => exprs, + single_expr => vec![single_expr], + }; + return Ok(specialize::roll_function_call(primary, args)); } } + Ok(primary) } /// Check if to continue parsing the next expression in the sequence @@ -1058,35 +608,28 @@ impl Parser { /// - **not an infix operator** /// - its **precedence is lower than** `min_prec` /// - it is a **terminator** - fn check_op( - &self, - nt: &LexResult, - min_prec: OperatorPrecedence, - allow_eq: bool, - ) -> Option { - let t = nt.as_ref().ok()?; - let op = if let TokenKind::Op(op) = &t.token { - op - } else { - return None; - }; - let op = self.find_operator(op, |op| { - op.position == OperatorPosition::Infix - || op.position == OperatorPosition::InfixAccumulate - })?; - let is_infix = op.info.position.is_infix(); - let is_greater = op.info.precedence > min_prec; - let is_right_assoc = op.info.associativity == OperatorAssociativity::Right; - let is_equal = op.info.precedence == min_prec; - if is_infix && (is_greater || ((is_right_assoc || allow_eq) && is_equal)) { + fn check_binary_op(&self, min_prec: OpPrec, op: &str) -> Option { + let op = self.find_operator(op, |op| op.position == OpPos::Infix)?; + let is_greater = op.precedence > min_prec; + let is_right_assoc = op.associativity == OpAssoc::Right; + let is_equal = op.precedence == min_prec; + if is_greater || (is_right_assoc && is_equal) { Some(op.clone()) } else { None } } - fn next_prec(curr_op: &OperatorInfo, next_op: &OperatorInfo) -> OperatorPrecedence { - curr_op.precedence + (next_op.precedence > curr_op.precedence) as OperatorPrecedence + fn check_postfix_op(&self, min_prec: OpPrec, op: &str) -> Option { + let op = self.find_operator(op, |op| op.position == OpPos::Postfix)?; + let is_greater = op.precedence > min_prec; + let is_right_assoc = op.associativity == OpAssoc::Right; + let is_equal = op.precedence == min_prec; + if is_greater || (is_right_assoc && is_equal) { + Some(op.clone()) + } else { + None + } } /// Parse an expression with a given left-hand side and minimum precedence level @@ -1100,258 +643,144 @@ impl Parser { /// The parsed expression or a parse error if the expression could not be parsed /// /// ## Algorithm - /// See: - /// - https://en.wikipedia.org/wiki/Operator-precedence_parser - /// - https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html - /// - https://eli.thegreenplace.net/2012/08/02/parsing-expressions-by-precedence-climbing - /// - https://www.engr.mun.ca/~theo/Misc/exp_parsing.htm - /// - https://crockford.com/javascript/tdop/tdop.html - fn parse_expr(&mut self, lhs: Ast, min_prec: OperatorPrecedence) -> ParseResult { - let mut expr = lhs; - while let Some(curr_op) = { - let nt = self.lexer.peek_token(0); - self.check_op(&nt, min_prec, false) - } { - let curr_info = self.lexer.next_token().unwrap().info; // Consume the operator token - if curr_op.info.position.is_accumulate() { - expr = self.parse_expr_accum(&curr_op, expr, curr_info)?; - continue; - } - let mut rhs = self.parse_primary()?; - while let Some(next_op) = { - let nt = self.lexer.peek_token(0); - self.check_op(&nt, curr_op.info.precedence, false) - } { - rhs = self.parse_expr(rhs, Self::next_prec(&curr_op.info, &next_op.info))?; - } - if let Some(desugar) = syntax_sugar::try_binary(&expr, &curr_op.info, &rhs) { - expr = desugar; - } else { - let info = expr.info().join(rhs.info()); - expr = if let OperatorHandler::Parse(handler) = curr_op.handler { - handler(StaticOperatorAst::Infix(expr, rhs)) - } else { - Ast::Binary { - lhs: Box::new(expr), - op_info: curr_op.info.clone(), - rhs: Box::new(rhs), - info, - } - }; - } - } - Ok(expr) - } - - /// Expect the parser state to be at the end of [expr, op] sequence. - /// Next token should be a new expression or a terminator. - fn parse_expr_accum(&mut self, op: &Operator, first: Ast, info: LineInfo) -> ParseResult { - let mut exprs = vec![first]; - while let Ok(t) = self.lexer.peek_token_not(pred::ignored, 0) { - if op.info.allow_trailing && t.token.is_terminator() { - break; + /// See: https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html + fn parse_expr(&mut self, min_prec: OpPrec) -> ParseResult { + let mut expr = self.parse_term()?; + // println!("Parsed term: {:?}", expr); + while let Ok(nt) = self.lexer.peek_token(0) { + if nt.token.is_terminator() { + break; // Stop parsing on expression terminators } + if let TokenKind::Op(op) = &nt.token { + if let Some(op) = self.check_postfix_op(min_prec, op) { + log::trace!("Parsing postfix operator: {:?}", op); + self.lexer.next_token().unwrap(); + expr = Ast::Unary { + info: expr.info().join(&nt.info), + op_info: op.clone(), + expr: Box::new(expr), + }; + continue; + } else if let Some(op) = self.check_binary_op(min_prec, op) { + log::trace!("Parsing infix operator: {:?}", op); + self.lexer.next_token().unwrap(); + let rhs = self.parse_expr(op.precedence)?; - // Parse the next nested expression in the sequence - let lhs = self.parse_primary()?; - exprs.push(self.parse_expr(lhs, op.info.precedence)?); - - // Expect the next token to be the same operator or another expression - let nt = self.lexer.peek_token_not(pred::ignored, 0); - if let Some(next_op) = self.check_op(&nt, op.info.precedence, true) { - if !next_op.info.eq(&op.info) { + let info = expr.info().join(rhs.info()); + expr = match op.symbol.as_str() { + ASSIGNMENT_SYM => { + // Allow all definitions in the parser, even if they are not valid in the current context + specialize::assignment(expr, rhs, info, None)? + } + MEMBER_ACCESS_SYM => specialize::member_access(expr, rhs, info)?, + _ => Ast::Binary { + lhs: Box::new(expr), + op_info: op.clone(), + rhs: Box::new(rhs), + info, + }, + }; + continue; + } else { break; } - self.lexer.read_next_token_not(pred::ignored).unwrap(); // Consume the operator token + } + if FUNCTION_APP_PREC > min_prec { + let call_info = expr.info().join(&nt.info); + // Allow all definitions in the parser, even if they are not valid in the current context + expr = specialize::call(expr, self.parse_term()?, call_info, &self.types, None)?; + continue; + } + if nt.token.is_terminator() { + break; // Stop parsing on expression terminators } else { - break; + return Err(ParseError::new( + format!( + "Expected operator or funciton application, but found {}", + nt.token.to_string().light_red() + ), + nt.info.clone(), + ) + .with_label("Not valid in this context".to_string(), nt.info)); } } - let expr = if let OperatorHandler::Parse(handler) = &op.handler { - handler(StaticOperatorAst::Accumulate(exprs)) - } else { - let info = info.join(exprs.last().unwrap().info()); - Ast::Accumulate { - op_info: op.info.clone(), - exprs, - info, - } - }; - Ok(expr) } /// Parse a top-level expression. fn parse_top_expr(&mut self) -> ParseResult { - let lhs = self.parse_primary()?; - let expr = self.parse_expr(lhs, 0); - let nt = self.lexer.peek_token(0); - if let Ok(t) = nt { - if t.token.is_top_level_terminal(false) { - self.lexer.next_token().unwrap(); + match self.parse_expr(0) { + Ok(expr) => { + self.skip_terminal_and_ignored(); + // Allow all definitions in the parser, even if they are not valid in the current context + specialize::top(expr, &self.types, None) } + Err(err) => Err(err), } - expr } - fn parse_expected( - &mut self, - expected_token: TokenKind, - symbol: &'static str, - ) -> Result { - match self.lexer.expect_next_token_not(pred::ignored) { - Ok(t) if t.token == expected_token => Ok(t), - Ok(t) => Err(ParseError::new( - format!( - "Expected {} but found {}", - symbol.yellow(), - t.token.to_string().light_red() - ), - t.info.clone(), - ) - .with_label(format!("This should be a {}", symbol), t.info)), - Err(err) => Err(ParseError::new( - format!("Expected {}", symbol.yellow(),), - err.info().clone(), - ) - .with_label(err.message().to_owned(), err.info().clone())), - } - } -} - -mod syntax_sugar { - use malachite::{num::basic::traits::Zero, Integer, Rational}; - - use crate::interpreter::number::{FloatingPoint, Number, NumberCasting}; - - use super::*; - - pub fn try_literal_fraction(lhs: &Number, rhs: &Number, info: LineInfo) -> Option { - let lhs = match lhs { - Number::UnsignedInteger(lhs) => lhs.to_signed(), - Number::SignedInteger(lhs) => lhs.clone(), - _ => return None, - } - .to_bigint(); - let rhs = match rhs { - Number::UnsignedInteger(rhs) => rhs.to_signed(), - Number::SignedInteger(rhs) => rhs.clone(), - _ => return None, + /// Skip all ignored tokens and the next top-level terminal token (`;`, `\n`, `EOF`). + fn skip_terminal_and_ignored(&mut self) { + // Remove all ignored tokens after the expression + while let Ok(t) = self.lexer.peek_token(0) { + if pred::ignore(&t.token) { + self.lexer.next_token().unwrap(); + } else { + break; + } } - .to_bigint(); - if rhs.cmp(&Integer::ZERO) == std::cmp::Ordering::Equal { - return None; + // If the next token is a top-level terminal, consume it + if let Ok(t) = self.lexer.peek_token(0) { + if t.token.is_top_level_terminal(false) { + self.lexer.next_token().unwrap(); + } } - Some(Ast::Literal { - value: Value::Number(Number::FloatingPoint( - FloatingPoint::FloatBig(Rational::from_integers(lhs, rhs)).optimize(), - )), - info, - }) - } - - pub fn try_binary(lhs: &Ast, op: &OperatorInfo, rhs: &Ast) -> Option { - match (lhs, op, rhs) { - ( - Ast::Literal { - value: Value::Number(lhs), - info: left_info, - }, - OperatorInfo { name, symbol, .. }, - Ast::Literal { - value: Value::Number(rhs), - info: right_info, - }, - ) if name == "div" && symbol == "/" => { - try_literal_fraction(lhs, rhs, left_info.join(right_info)) + // Continue to ignore any remaining ignored tokens + while let Ok(t) = self.lexer.peek_token(0) { + if pred::ignore(&t.token) { + self.lexer.next_token().unwrap(); + } else { + break; } - _ => None, } } - /// Takes a function name, a list of parameters and a body and rolls them into a single assignment expression. - /// Parameters are rolled into a nested function definition. - /// All parameters are sorted like: - /// ```lento - /// func(a, b, c) = expr - /// ``` - /// becomes: - /// ```lento - /// func = a -> b -> c -> expr - /// ``` + /// Parse **a single** expression from the stream of tokens. + /// Returns an AST node or an error. + /// If the first token is an EOF, then the parser will return an empty unit expression. /// - /// # Arguments - /// - `func_name` The name of the function - /// - `params` A list of parameters in left-to-right order: `a, b, c` - /// - `body` The body of the function - pub fn roll_function_definition(params: Vec, body: Ast) -> Ast { - assert!(!params.is_empty(), "Expected at least one parameter"); - let info = body.info().join(params.last().map(|p| &p.info).unwrap()); - let mut params = params.iter().rev(); - let mut function = Ast::FunctionDef { - param: params.next().unwrap().clone(), - body: Box::new(body), - return_type: None, - info, - }; - for param in params { - function = Ast::FunctionDef { - info: function.info().join(¶m.info), - param: param.clone(), - body: Box::new(function), - return_type: None, - }; + /// # Note + /// The parser will not necessarily consume all tokens from the stream. + /// It will **ONLY** consume a whole complete expression. + /// There may be remaining tokens in the stream after the expression is parsed. + pub fn parse_one(&mut self) -> ParseResult { + // Check if the next token is an EOF, then return an empty unit top-level expression + if let Ok(t) = self.lexer.peek_token_not(pred::ignore, 0) { + if pred::eof(&t.token) { + return Ok(Ast::Literal { + value: Value::Unit, + info: t.info, + }); + } } - function + self.parse_top_expr() } - pub fn roll_function_call(name: String, args: Vec, start_info: LineInfo) -> Ast { - let last_info = args - .last() - .map(|a| a.info().clone()) - .unwrap_or(start_info.clone()); - let call_info = start_info.join(&last_info); - let mut args = args.into_iter(); - let mut call = Ast::FunctionCall { - expr: Box::new(Ast::Identifier { - name, - info: call_info.clone(), - }), - arg: Box::new(args.next().unwrap()), - info: call_info.clone(), - }; - for arg in args { - let arg_info = call_info.join(arg.info()); - call = Ast::FunctionCall { - expr: Box::new(call), - arg: Box::new(arg), - info: arg_info, - }; + /// Parse a global AST from the stream of tokens. + /// A global AST is a list of **all** top-level AST nodes (expressions). + pub fn parse_all(&mut self) -> ParseResults { + let mut asts = Vec::new(); + loop { + if let Ok(t) = self.lexer.peek_token_not(pred::ignore, 0) { + if pred::eof(&t.token) { + break; + } + } + match self.parse_top_expr() { + Ok(expr) => asts.push(expr), + Err(e) => return Err(e), + } } - call + Ok(asts) } } - -//--------------------------------------------------------------------------------------// -// Parser Factory Functions // -//--------------------------------------------------------------------------------------// - -pub fn from_file(file: File) -> Parser> { - Parser::new(lexer::from_file(file)) -} - -pub fn from_string(source: String) -> Parser> { - Parser::new(lexer::from_string(source)) -} - -pub fn from_str(source: &str) -> Parser> { - Parser::new(lexer::from_str(source)) -} - -pub fn from_stdin() -> Parser { - Parser::new(lexer::from_stdin()) -} - -pub fn from_stream(reader: R) -> Parser { - Parser::new(lexer::from_stream(reader)) -} diff --git a/core/src/parser/pattern.rs b/core/src/parser/pattern.rs new file mode 100644 index 0000000..2e1da98 --- /dev/null +++ b/core/src/parser/pattern.rs @@ -0,0 +1,254 @@ +use std::hash::Hash; + +use crate::{ + interpreter::{ + number::{Number, SignedInteger, UnsignedInteger}, + value::{RecordKey, Value}, + }, + type_checker::types::TypeJudgements, + util::error::LineInfo, +}; + +/// A pattern used for binding variables in: +/// - Variable assignments +/// - Function definitions +/// - Function arguments +/// - Destructuring assignments +#[derive(Debug, Clone, Eq)] +pub enum BindPattern { + /// A variable binding pattern. + Variable { + /// The name of the variable. + name: String, + /// The type annotation for the variable. + info: LineInfo, + }, + /// A tuple binding pattern. + Tuple { + /// The elements of the tuple. + elements: Vec, + info: LineInfo, + }, + /// A record binding pattern. + Record { + /// The fields of the record. + fields: Vec<(RecordKey, BindPattern)>, + info: LineInfo, + }, + /// A list binding pattern. + List { + /// The elements of the list. + elements: Vec, + info: LineInfo, + }, + /// A wildcard pattern that matches any value. + Wildcard, + /// A literal pattern that matches a specific value. + Literal { + /// The value to match. + value: LiteralPattern, + info: LineInfo, + }, + /// A rest of a collection pattern that matches the rest of a list. + Rest { + /// The name of the variable to bind the rest of the list. + name: String, + /// The type annotation for the variable. + info: LineInfo, + }, +} + +impl BindPattern { + pub fn info(&self) -> &LineInfo { + match self { + BindPattern::Variable { info, .. } => info, + BindPattern::Tuple { info, .. } => info, + BindPattern::Record { info, .. } => info, + BindPattern::List { info, .. } => info, + BindPattern::Wildcard => panic!("Wildcard pattern has no line info"), + BindPattern::Literal { info, .. } => info, + BindPattern::Rest { info, .. } => info, + } + } + + pub fn specialize(&mut self, _judgements: &TypeJudgements, _changed: &mut bool) { + match self { + BindPattern::Variable { .. } => (), + BindPattern::Tuple { elements, .. } => { + for element in elements { + element.specialize(_judgements, _changed); + } + } + BindPattern::Record { fields, .. } => { + for (_, element) in fields { + element.specialize(_judgements, _changed); + } + } + BindPattern::List { elements, .. } => { + for element in elements { + element.specialize(_judgements, _changed); + } + } + BindPattern::Wildcard => (), + BindPattern::Literal { .. } => (), + BindPattern::Rest { .. } => (), + } + } + + pub fn print_expr(&self) -> String { + match self { + BindPattern::Variable { name, .. } => name.clone(), + + BindPattern::Tuple { elements, .. } => format!( + "({})", + elements + .iter() + .map(|e| e.print_expr()) + .collect::>() + .join(", ") + ), + BindPattern::Record { fields, .. } => format!( + "{{ {} }}", + fields + .iter() + .map(|(k, v)| format!("{}: {}", k, v.print_expr())) + .collect::>() + .join(", ") + ), + BindPattern::List { elements, .. } => format!( + "[{}]", + elements + .iter() + .map(|e| e.print_expr()) + .collect::>() + .join(", ") + ), + BindPattern::Wildcard => "_".to_string(), + BindPattern::Literal { value, .. } => value.as_value().pretty_print(), + BindPattern::Rest { name, .. } => format!("...{}", name), + } + } + + pub fn pretty_print(&self) -> String { + match self { + BindPattern::Variable { name, .. } => name.clone(), + BindPattern::Tuple { elements, .. } => { + let mut result = "(".to_string(); + for (i, v) in elements.iter().enumerate() { + result.push_str(&v.pretty_print()); + if i < elements.len() - 1 { + result.push_str(", "); + } + } + result.push(')'); + result + } + BindPattern::Record { fields, .. } => { + let mut result = "{ ".to_string(); + for (i, (k, v)) in fields.iter().enumerate() { + result.push_str(&format!("{}: {}", k, v.pretty_print())); + if i < fields.len() - 1 { + result.push_str(", "); + } + } + result.push_str(" }"); + result + } + BindPattern::List { elements, .. } => { + let mut result = "[".to_string(); + for (i, v) in elements.iter().enumerate() { + result.push_str(&v.pretty_print()); + if i < elements.len() - 1 { + result.push_str(", "); + } + } + result.push(']'); + result + } + BindPattern::Wildcard => "_".to_string(), + BindPattern::Literal { value, .. } => value.as_value().pretty_print(), + BindPattern::Rest { name, .. } => format!("...{}", name), + } + } + + pub fn is_wildcard(&self) -> bool { + matches!(self, BindPattern::Wildcard) + } +} + +impl PartialEq for BindPattern { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Variable { name: l0, .. }, Self::Variable { name: r0, .. }) => l0 == r0, + (Self::Tuple { elements: l0, .. }, Self::Tuple { elements: r0, .. }) => l0 == r0, + (Self::Record { fields: l0, .. }, Self::Record { fields: r0, .. }) => l0 == r0, + (Self::List { elements: l0, .. }, Self::List { elements: r0, .. }) => l0 == r0, + (Self::Wildcard, Self::Wildcard) => true, + (Self::Literal { value: l0, .. }, Self::Literal { value: r0, .. }) => l0 == r0, + (Self::Rest { name: l0, .. }, Self::Rest { name: r0, .. }) => l0 == r0, + _ => false, + } + } +} + +impl Hash for BindPattern { + fn hash(&self, state: &mut H) { + match self { + BindPattern::Variable { name, .. } => name.hash(state), + BindPattern::Tuple { elements, .. } => { + for element in elements { + element.hash(state); + } + } + BindPattern::Record { fields, .. } => { + for (key, value) in fields { + key.hash(state); + value.hash(state); + } + } + BindPattern::List { elements, .. } => { + for element in elements { + element.hash(state); + } + } + BindPattern::Wildcard => "Wildcard".hash(state), + BindPattern::Literal { value, .. } => value.hash(state), + BindPattern::Rest { name, .. } => name.hash(state), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum LiteralPattern { + UnsignedInteger(UnsignedInteger), + SignedInteger(SignedInteger), + String(String), + Char(char), + Boolean(bool), +} + +impl LiteralPattern { + pub fn into_value(self) -> Value { + match self { + LiteralPattern::UnsignedInteger(value) => Value::Number(Number::UnsignedInteger(value)), + LiteralPattern::SignedInteger(value) => Value::Number(Number::SignedInteger(value)), + LiteralPattern::String(value) => Value::String(value), + LiteralPattern::Char(value) => Value::Char(value), + LiteralPattern::Boolean(value) => Value::Boolean(value), + } + } + + pub fn as_value(&self) -> Value { + match self { + LiteralPattern::UnsignedInteger(value) => { + Value::Number(Number::UnsignedInteger(value.clone())) + } + LiteralPattern::SignedInteger(value) => { + Value::Number(Number::SignedInteger(value.clone())) + } + LiteralPattern::String(value) => Value::String(value.clone()), + LiteralPattern::Char(value) => Value::Char(*value), + LiteralPattern::Boolean(value) => Value::Boolean(*value), + } + } +} diff --git a/core/src/parser/tests.rs b/core/src/parser/tests.rs index 9405217..39e4573 100644 --- a/core/src/parser/tests.rs +++ b/core/src/parser/tests.rs @@ -9,6 +9,7 @@ mod tests { parser::{ ast::{Ast, TypeAst}, parser::from_string, + pattern::BindPattern, }, stdlib::init::{stdlib, Initializer}, util::error::LineInfo, @@ -51,10 +52,20 @@ mod tests { parser.parse_all() } + #[test] + fn unit() { + let result = parse_str_one("()", None); + let result = result.unwrap(); + + assert!(matches!(result, Ast::Tuple { .. })); + if let Ast::Tuple { exprs, info: _ } = &result { + assert_eq!(exprs.len(), 0); + } + } + #[test] fn number() { let result = parse_str_one("1", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(result == lit(make_u1(1))); @@ -62,8 +73,7 @@ mod tests { #[test] fn number_many() { - let result = parse_str_all("1 2 3 4 5", None); - assert!(result.is_ok()); + let result = parse_str_all("1 \n 2 \n 3 \n 4 \n 5", None); let result = result.unwrap(); assert!(result.len() == 5); assert!(result[0] == lit(make_u1(1))); @@ -76,7 +86,6 @@ mod tests { #[test] fn number_many_semicolon() { let result = parse_str_all("1; 2; 3;", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(result.len() == 3); assert!(result[0] == lit(make_u1(1))); @@ -87,29 +96,14 @@ mod tests { #[test] fn number_par() { let result = parse_str_one("(1)", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(result == lit(make_u1(1))); } - #[test] - fn tuple_1_trailing() { - let result = parse_str_one("(1,)", None); - assert!(result.is_ok()); - let result = result.unwrap(); - - assert!(matches!(result, Ast::Tuple { .. })); - if let Ast::Tuple { exprs, .. } = &result { - assert_eq!(exprs.len(), 1); - assert_eq!(exprs[0], lit(make_u1(1))); - } - } - #[test] fn tuple_2() { let result = parse_str_one("(1, 2)", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Tuple { .. })); @@ -123,22 +117,6 @@ mod tests { #[test] fn tuple_3() { let result = parse_str_one("(1, 2, 3)", None); - assert!(result.is_ok()); - let result = result.unwrap(); - - assert!(matches!(result, Ast::Tuple { .. })); - if let Ast::Tuple { exprs, .. } = &result { - assert_eq!(exprs.len(), 3); - assert_eq!(exprs[0], lit(make_u1(1))); - assert_eq!(exprs[1], lit(make_u8(2))); - assert_eq!(exprs[2], lit(make_u8(3))); - } - } - - #[test] - fn tuple_3_trailing() { - let result = parse_str_one("(1, 2, 3,)", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Tuple { .. })); @@ -153,7 +131,6 @@ mod tests { #[test] fn tuple_addition() { let result = parse_str_one("(1, 2) + (3, 4)", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Binary { .. })); @@ -161,7 +138,7 @@ mod tests { lhs, op_info, rhs, .. } = &result { - assert_eq!(op_info.name, "add".to_string()); + assert_eq!(&op_info.symbol, "+"); assert!(matches!(*lhs.to_owned(), Ast::Tuple { .. })); assert!(matches!(*rhs.to_owned(), Ast::Tuple { .. })); } @@ -170,7 +147,6 @@ mod tests { #[test] fn list_3() { let result = parse_str_one("[1, 2, 3]", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::List { .. })); @@ -197,7 +173,6 @@ mod tests { arg: Box::new(lit(Value::String("Hello, World!".to_string()))), info: LineInfo::default(), }; - assert!(result.is_ok()); let result = result.unwrap(); assert!(result == expected); @@ -214,7 +189,6 @@ mod tests { arg: Box::new(lit(Value::String("Hello, World!".to_string()))), info: LineInfo::default(), }; - assert!(result.is_ok()); let result = result.unwrap(); assert!(result == expected); @@ -231,7 +205,6 @@ mod tests { arg: Box::new(lit(Value::String("Hello, World!".to_string()))), info: LineInfo::default(), }; - assert!(result.is_ok()); let result = result.unwrap(); assert!(result == expected); @@ -251,7 +224,6 @@ mod tests { arg: Box::new(lit(Value::String("Hello, World!".to_string()))), info: LineInfo::default(), }; - assert!(result.is_ok()); let result = result.unwrap(); assert!(result.len() == 3); // All three should be the same @@ -263,13 +235,12 @@ mod tests { #[test] fn arithmetic() { let result = parse_str_one("1 + 2", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Binary { .. })); // Assert "add" if let Ast::Binary { op_info, .. } = &result { - assert_eq!(op_info.name, "add".to_string()); + assert_eq!(&op_info.symbol, "+"); } if let Ast::Binary { lhs, rhs, .. } = &result { // Always true @@ -293,7 +264,6 @@ mod tests { #[test] fn arithmetic_tree() { let result = parse_str_one("1 + 2 + 3", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Binary { .. })); @@ -307,7 +277,6 @@ mod tests { #[test] fn literal_type_identifier() { let result = parse_str_one("int", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::LiteralType { .. })); if let Ast::LiteralType { expr, .. } = &result { @@ -322,9 +291,7 @@ mod tests { #[test] fn typed_assignment() { let result = parse_str_one("int x = 1", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); - println!("result: {:?}", result); assert!(matches!(result, Ast::Assignment { .. })); if let Ast::Assignment { @@ -334,7 +301,7 @@ mod tests { .. } = &result { - assert!(matches!(*target.to_owned(), Ast::Identifier { .. })); + assert!(matches!(target, BindPattern::Variable { .. })); assert!(matches!(*expr.to_owned(), Ast::Literal { .. })); assert!(annotation.to_owned().is_some()); assert!(matches!( @@ -354,9 +321,7 @@ mod tests { #[test] fn untyped_assignment() { let result = parse_str_one("x = 1", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); - println!("result: {:?}", result); assert!(matches!(result, Ast::Assignment { .. })); } @@ -364,13 +329,11 @@ mod tests { #[test] fn assign_add() { let result = parse_str_one("x = 1 + 2", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); - println!("result: {:?}", result); assert!(matches!(result, Ast::Assignment { .. })); if let Ast::Assignment { target, expr, .. } = &result { - assert!(matches!(*target.to_owned(), Ast::Identifier { .. })); + assert!(matches!(target, BindPattern::Variable { .. })); assert!(matches!(*expr.to_owned(), Ast::Binary { .. })); } } @@ -378,7 +341,6 @@ mod tests { #[test] fn comment() { let result = parse_str_all("1; // This is a comment", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(result.len() == 1); assert!(matches!(result[0], Ast::Literal { .. })); @@ -396,7 +358,6 @@ mod tests { "#, None, ); - assert!(result.is_ok()); let result = result.unwrap(); assert!(result.len() == 2); assert!(result[0] == lit(make_u1(1))); @@ -406,7 +367,6 @@ mod tests { #[test] fn arithmetic_complex() { let result = parse_str_one("5 * (10 - 2) / 2 + 1", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Binary { .. })); @@ -426,7 +386,6 @@ mod tests { #[test] fn record_literal_empty() { let result = parse_str_one("{}", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Record { .. })); @@ -438,7 +397,6 @@ mod tests { #[test] fn record_literal_one() { let result = parse_str_one("{ x: 1 }", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Record { .. })); @@ -446,9 +404,8 @@ mod tests { assert_eq!(fields.len(), 1); let fields = fields.iter().collect::>(); assert!(matches!(fields[0].0, RecordKey::String(_))); - if let RecordKey::String(key) = &fields[0].0 { - assert_eq!(key, "x"); - } + let RecordKey::String(key) = &fields[0].0; + assert_eq!(key, "x"); assert!(matches!(fields[0].1, Ast::Literal { .. })); assert_eq!(fields[0].1, lit(make_u1(1))); } @@ -457,7 +414,6 @@ mod tests { #[test] fn record_literal_two() { let result = parse_str_one("{ x: 1, y: 2 }", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Record { .. })); @@ -466,12 +422,10 @@ mod tests { let fields = fields.iter().collect::>(); assert!(matches!(fields[0].0, RecordKey::String(_))); assert!(matches!(fields[1].0, RecordKey::String(_))); - if let RecordKey::String(key) = &fields[0].0 { - assert_eq!(key, "x"); - } - if let RecordKey::String(key) = &fields[1].0 { - assert_eq!(key, "y"); - } + let RecordKey::String(key) = &fields[0].0; + assert_eq!(key, "x"); + let RecordKey::String(key) = &fields[1].0; + assert_eq!(key, "y"); assert!(matches!(fields[0].1, Ast::Literal { .. })); assert!(matches!(fields[1].1, Ast::Literal { .. })); assert_eq!(fields[0].1, lit(make_u1(1))); @@ -482,7 +436,6 @@ mod tests { #[test] fn record_literal_nested() { let result = parse_str_one("{ x: { y: 1 } }", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Record { .. })); @@ -490,9 +443,8 @@ mod tests { assert_eq!(fields.len(), 1); let fields = fields.iter().collect::>(); assert!(matches!(fields[0].0, RecordKey::String(_))); - if let RecordKey::String(key) = &fields[0].0 { - assert_eq!(key, "x"); - } + let RecordKey::String(key) = &fields[0].0; + assert_eq!(key, "x"); assert!(matches!(fields[0].1, Ast::Record { .. })); if let Ast::Record { fields: inner_fields, @@ -502,9 +454,8 @@ mod tests { assert_eq!(inner_fields.len(), 1); let inner_fields = inner_fields.iter().collect::>(); assert!(matches!(inner_fields[0].0, RecordKey::String(_))); - if let RecordKey::String(key) = &inner_fields[0].0 { - assert_eq!(key, "y"); - } + let RecordKey::String(key) = &inner_fields[0].0; + assert_eq!(key, "y"); assert!(matches!(inner_fields[0].1, Ast::Literal { .. })); assert_eq!(inner_fields[0].1, lit(make_u1(1))); } @@ -514,7 +465,6 @@ mod tests { #[test] fn record_nested_block() { let result = parse_str_one("{ x: { 1 + 2 } }", Some(&stdlib())); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Record { .. })); @@ -522,9 +472,8 @@ mod tests { assert_eq!(fields.len(), 1); let fields = fields.iter().collect::>(); assert!(matches!(fields[0].0, RecordKey::String(_))); - if let RecordKey::String(key) = &fields[0].0 { - assert_eq!(key, "x"); - } + let RecordKey::String(key) = &fields[0].0; + assert_eq!(key, "x"); assert!(matches!(fields[0].1, Ast::Block { .. })); if let Ast::Block { exprs: inner, .. } = &fields[0].1 { assert_eq!(inner.len(), 1); @@ -536,7 +485,6 @@ mod tests { #[test] fn block_one() { let result = parse_str_one("{ 1 }", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Block { .. })); @@ -549,7 +497,6 @@ mod tests { #[test] fn block_two() { let result = parse_str_one("{ 1; 2 }", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Block { .. })); @@ -563,7 +510,6 @@ mod tests { #[test] fn block_three() { let result = parse_str_one("{ 1; 2; 3 }", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Block { .. })); @@ -577,8 +523,7 @@ mod tests { #[test] fn block_three_no_semicolon() { - let result = parse_str_one("{ 1 2 3 }", None); - assert!(result.is_ok()); + let result = parse_str_one("{ 1 \n 2 \n 3 }", None); let result = result.unwrap(); assert!(matches!(result, Ast::Block { .. })); @@ -593,7 +538,6 @@ mod tests { #[test] fn block_nested() { let result = parse_str_one("{ { 1 } }", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Block { .. })); @@ -606,7 +550,6 @@ mod tests { #[test] fn block_nested_two() { let result = parse_str_one("{ { 1; 2 } }", None); - assert!(result.is_ok()); let result = result.unwrap(); assert!(matches!(result, Ast::Block { .. })); @@ -623,4 +566,679 @@ mod tests { } } } + + #[test] + fn function_def_paren_explicit_args_and_ret() { + parse_str_one("u8 add(u8 x, u8 y, u8 z) = { x + y + z }", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_no_paren_explicit_args_and_ret() { + parse_str_one("u8 add u8 x, u8 y, u8 z = { x + y + z }", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_no_paren_explicit_args() { + parse_str_one("add u8 x, u8 y, u8 z = { x + y + z }", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_paren_implicit_args_and_ret() { + parse_str_one("add(x, y, z) = { x + y + z }", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_no_paren_implicit_args_and_ret() { + parse_str_one("add x, y, z = { x + y + z }", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_mixed_parens() { + parse_str_one( + "u8 add x, y, (z), a, (b), (c) = { x + y + z + a + b + c }", + Some(&stdlib()), + ) + .unwrap(); + } + + #[test] + fn function_def_paren_explicit_oneline() { + parse_str_one("u8 add(u8 x, u8 y, u8 z) = x + y + z;", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_no_paren_explicit_oneline() { + parse_str_one("u8 add u8 x, u8 y, u8 z = x + y + z;", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_paren_implicit_oneline() { + parse_str_one("add(x, y, z) = x + y + z;", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_no_paren_implicit_oneline() { + parse_str_one("add x, y, z = x + y + z;", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_with_return_type() { + parse_str_one("int add(int x, int y) = x + y;", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_with_return_type_no_parens() { + parse_str_one("int add int x, int y = x + y;", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_with_return_type_block() { + parse_str_one("int add(int x, int y) = { x + y }", Some(&stdlib())).unwrap(); + } + + #[test] + fn function_def_multiple_statements() { + parse_str_one( + "int add(int x, int y) = { + let z = x + y; + z + }", + Some(&stdlib()), + ) + .unwrap(); + } + + #[test] + fn function_def_nested() { + parse_str_one( + "int outer(int x) = { + int inner(int y) = x + y; + inner(x) + }", + Some(&stdlib()), + ) + .unwrap(); + } + + #[test] + fn assignment_with_type() { + let result = parse_str_one("int x = 123", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "x"); + } + assert!(annotation.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = &annotation { + assert_eq!(name, "int"); + } + assert!(matches!(*expr, Ast::Literal { .. })); + } else { + panic!("Expected assignment"); + } + } + + #[test] + fn function_def_with_type_and_paren_arg() { + let result = parse_str_one("int f(int x) = x + 5", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(annotation.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = annotation { + assert_eq!(name, "int"); + } + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Literal { .. })); + if let Ast::Literal { value, .. } = *rhs { + assert_eq!( + value, + Value::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(5))) + ); + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_paren_arg() { + let result = parse_str_one("f(int x) = x + 5", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(annotation.is_none()); + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Literal { .. })); + if let Ast::Literal { value, .. } = *rhs { + assert_eq!( + value, + Value::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(5))) + ); + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_type_and_parenless_arg() { + let result = parse_str_one("int f(x) = x + 5", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(annotation.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = annotation { + assert_eq!(name, "int"); + } + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_none()); + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Literal { .. })); + if let Ast::Literal { value, .. } = *rhs { + assert_eq!( + value, + Value::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(5))) + ); + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_parenless_arg() { + let result = parse_str_one("f(x) = x + 5", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(annotation.is_none()); + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_none()); + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Literal { .. })); + if let Ast::Literal { value, .. } = *rhs { + assert_eq!( + value, + Value::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(5))) + ); + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_type_and_explicit_arg() { + let result = parse_str_one("int f int x = x + 5", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(annotation.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = annotation { + assert_eq!(name, "int"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Literal { .. })); + if let Ast::Literal { value, .. } = *rhs { + assert_eq!( + value, + Value::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(5))) + ); + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_explicit_arg() { + let result = parse_str_one("f x = x + 5", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(annotation.is_none()); + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(annotation.is_none()); + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_none()); + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Literal { .. })); + if let Ast::Literal { value, .. } = *rhs { + assert_eq!( + value, + Value::Number(Number::UnsignedInteger(UnsignedInteger::UInt8(5))) + ); + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_multiple_explicit_args() { + let result = parse_str_one("f int x, int y = x + y", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(annotation.is_none()); + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *body { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "y"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *rhs { + assert_eq!(name, "y"); + } + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_type_and_paren_args_block() { + let result = parse_str_one( + "int f(int x, int y) = { + x + y + }", + Some(&stdlib()), + ); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(annotation.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = annotation { + assert_eq!(name, "int"); + } + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *body { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "y"); + } + assert!(matches!(*body, Ast::Block { .. })); + if let Ast::Block { exprs, .. } = *body { + assert_eq!(exprs.len(), 1); + assert!(matches!(exprs[0], Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = &exprs[0] { + assert!(matches!(**lhs, Ast::Identifier { .. })); + if let Ast::Identifier { ref name, .. } = **lhs { + assert_eq!(name, "x"); + } + assert!(matches!(**rhs, Ast::Identifier { .. })); + if let Ast::Identifier { ref name, .. } = **rhs { + assert_eq!(name, "y"); + } + } + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_type_and_explicit_args_block() { + let result = parse_str_one( + "int f int x, int y = { + x + y + }", + Some(&stdlib()), + ); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(annotation.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = annotation { + assert_eq!(name, "int"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *body { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "y"); + } + assert!(matches!(*body, Ast::Block { .. })); + if let Ast::Block { exprs, .. } = *body { + assert_eq!(exprs.len(), 1); + assert!(matches!(exprs[0], Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = &exprs[0] { + assert!(matches!(**lhs, Ast::Identifier { .. })); + if let Ast::Identifier { ref name, .. } = **lhs { + assert_eq!(name, "x"); + } + assert!(matches!(**rhs, Ast::Identifier { .. })); + if let Ast::Identifier { ref name, .. } = **rhs { + assert_eq!(name, "y"); + } + } + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_type_and_paren_args_oneline() { + let result = parse_str_one("int f(int x, int y) = x + y;", Some(&stdlib())); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result.unwrap() + { + assert!(annotation.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = annotation { + assert_eq!(name, "int"); + } + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *body { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "y"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *rhs { + assert_eq!(name, "y"); + } + } + } + } + } else { + panic!("Expected function definition"); + } + } + + #[test] + fn function_def_with_type_and_explicit_args_multiline() { + let result = parse_str_one( + "int f + int x, + int y + = x + y;", + Some(&stdlib()), + ); + let result = result.unwrap(); + if let Ast::Assignment { + target, + expr, + annotation, + .. + } = result + { + assert!(annotation.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = annotation { + assert_eq!(name, "int"); + } + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *expr { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, Ast::Lambda { .. })); + if let Ast::Lambda { param, body, .. } = *body { + assert!(param.ty.is_some()); + if let Some(TypeAst::Identifier { name, .. }) = param.ty { + assert_eq!(name, "int"); + } + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "y"); + } + assert!(matches!(*body, Ast::Binary { .. })); + if let Ast::Binary { lhs, rhs, .. } = *body { + assert!(matches!(*lhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *lhs { + assert_eq!(name, "x"); + } + assert!(matches!(*rhs, Ast::Identifier { .. })); + if let Ast::Identifier { name, .. } = *rhs { + assert_eq!(name, "y"); + } + } + } + } + } else { + dbg!(result); + panic!("Expected function definition"); + } + } } diff --git a/core/src/stdlib/arithmetic.rs b/core/src/stdlib/arithmetic.rs index 76e6f86..d276e7c 100644 --- a/core/src/stdlib/arithmetic.rs +++ b/core/src/stdlib/arithmetic.rs @@ -49,14 +49,8 @@ pub fn add() -> Function { } }, vec![ - CheckedParam { - name: "lhs".to_string(), - ty: std_types::NUM(), - }, - CheckedParam { - name: "rhs".to_string(), - ty: std_types::NUM(), - }, + CheckedParam::from_str("lhs", std_types::NUM()), + CheckedParam::from_str("rhs", std_types::NUM()), ], std_types::NUM(), ) @@ -96,14 +90,8 @@ pub fn sub() -> Function { } }, vec![ - CheckedParam { - name: "lhs".to_string(), - ty: std_types::NUM(), - }, - CheckedParam { - name: "rhs".to_string(), - ty: std_types::NUM(), - }, + CheckedParam::from_str("lhs", std_types::NUM()), + CheckedParam::from_str("rhs", std_types::NUM()), ], std_types::NUM(), ) @@ -143,14 +131,8 @@ pub fn mul() -> Function { } }, vec![ - CheckedParam { - name: "lhs".to_string(), - ty: std_types::NUM(), - }, - CheckedParam { - name: "rhs".to_string(), - ty: std_types::NUM(), - }, + CheckedParam::from_str("lhs", std_types::NUM()), + CheckedParam::from_str("rhs", std_types::NUM()), ], std_types::NUM(), ) @@ -192,14 +174,8 @@ pub fn div() -> Function { } }, vec![ - CheckedParam { - name: "lhs".to_string(), - ty: std_types::NUM(), - }, - CheckedParam { - name: "rhs".to_string(), - ty: std_types::NUM(), - }, + CheckedParam::from_str("lhs", std_types::NUM()), + CheckedParam::from_str("rhs", std_types::NUM()), ], std_types::NUM(), ) diff --git a/core/src/stdlib/init.rs b/core/src/stdlib/init.rs index 0e3f48e..114de19 100644 --- a/core/src/stdlib/init.rs +++ b/core/src/stdlib/init.rs @@ -4,26 +4,19 @@ use colorful::Colorful; use crate::{ interpreter::{ - self, env::Environment, number::{FloatingPoint, Number}, - value::{Function, RecordKey, Value}, + value::{Function, Value}, }, lexer::lexer::Lexer, parser::{ - ast::Ast, - error::ParseError, - op::{ - default_operator_precedence, Operator, OperatorAssociativity, OperatorHandler, - OperatorPosition, OperatorSignature, RuntimeOperatorHandler, StaticOperatorAst, - }, + op::{prec, OpAssoc, OpHandler, OpPos, OpSignature, Operator, RuntimeOpHandler}, parser::Parser, }, stdlib::arithmetic, type_checker::{ - checked_ast::CheckedParam, checker::TypeChecker, - types::{std_types, Type, TypeTrait}, + types::{std_types, GetType, Type, TypeTrait}, }, util::{ error::{BaseErrorExt, LineInfo}, @@ -36,7 +29,7 @@ use super::{logical, system}; pub struct Initializer { operators: Vec, types: HashMap, - values: Vec<(Str, Value)>, + constants: Vec<(Str, Value)>, functions: Vec<(&'static str, Function)>, } @@ -44,10 +37,7 @@ impl Initializer { pub fn init_lexer(&self, lexer: &mut Lexer) { log::trace!("Initializing lexer with {} operators", self.operators.len()); for op in &self.operators { - // TODO: Why does this only add static operators to the lexer? - if let OperatorHandler::Static(_) = &op.handler { - lexer.operators.insert(op.info.symbol.clone()); - } + lexer.add_operator(op.info.symbol.clone()); } } @@ -58,7 +48,7 @@ impl Initializer { self.types.len() ); for op in &self.operators { - if let Err(e) = parser.define_op(op.clone()) { + if let Err(e) = parser.define_op(op.info.clone()) { panic!( "Parser initialization failed when adding operator '{:?}': {:?}", op, e @@ -68,6 +58,7 @@ impl Initializer { for ty in self.types.values() { match ty { Type::Literal(ref name) => parser.add_type(name.to_string()), + Type::Constructor(ref name, _, _) => parser.add_type(name.to_string()), Type::Alias(ref name, _) => parser.add_type(name.to_string()), _ => panic!("Expected literal or alias type but got {:?}", ty), } @@ -81,6 +72,9 @@ impl Initializer { self.functions.len(), self.operators.len() ); + for (name, val) in &self.constants { + type_checker.add_variable(name.to_string(), val.get_type().clone()); + } for op in &self.operators { type_checker.add_operator(op.clone()); } @@ -96,12 +90,12 @@ impl Initializer { pub fn init_environment(&self, env: &mut Environment) { log::trace!( "Initializing environment with {} values, {} functions, {} operators and {} types", - self.values.len(), + self.constants.len(), self.functions.len(), self.operators.len(), self.types.len() ); - for (name, val) in &self.values { + for (name, val) in &self.constants { if let Err(e) = env.add_value(name.clone(), val.clone(), &LineInfo::default()) { panic!( "Environment initialization failed when adding value {}: {}", @@ -125,7 +119,7 @@ impl Initializer { } for op in &self.operators { match &op.handler { - OperatorHandler::Runtime(RuntimeOperatorHandler { + OpHandler::Runtime(RuntimeOpHandler { function_name, signature, }) => { @@ -135,7 +129,7 @@ impl Initializer { if !signature.function_type().equals(func.get_fn_type()).success { panic!( "Function type mismatch for operator {}: expected {}, but got {}", - op.info.name.clone().yellow(), + op.info.symbol.clone().yellow(), signature.function_type().pretty_print(), func.get_type().pretty_print() ); @@ -143,8 +137,7 @@ impl Initializer { } } // Skip static operators - OperatorHandler::Parse(_) => {} - OperatorHandler::Static(_) => {} + OpHandler::Static(_) => {} } } for (name, ty) in &self.types { @@ -157,148 +150,76 @@ impl Initializer { } } } + + pub fn extend(&mut self, other: Initializer) { + log::trace!( + "Extending initializer with additional {} operators, {} types, {} constants, and {} functions", + other.operators.len(), + other.types.len(), + other.constants.len(), + other.functions.len() + ); + self.operators.extend(other.operators); + self.types.extend(other.types); + self.constants.extend(other.constants); + self.functions.extend(other.functions); + } } +//--------------------------------------------------------------------------------------// +// Standard Library // +//--------------------------------------------------------------------------------------// + pub fn stdlib() -> Initializer { Initializer { - //--------------------------------------------------------------------------------------// - // Operators // - //--------------------------------------------------------------------------------------// operators: vec![ - // Assignment operator, native to the language - // TODO: Implement this operator statically in the parser instead of using an operator handler - Operator::new_parse( - "assign".into(), - "=".into(), - OperatorPosition::Infix, - default_operator_precedence::ASSIGNMENT, - OperatorAssociativity::Right, - false, - |op| { - if let StaticOperatorAst::Infix(lhs, rhs) = op { - let info = lhs.info().join(rhs.info()); - Ast::Assignment { - annotation: None, - target: Box::new(lhs), - expr: Box::new(rhs), - info, - } - } else { - panic!("assign expects an infix operator"); - } - }, - ), - // Field access operator - Operator::new_static( - "field_access".into(), - ".".into(), - OperatorPosition::Infix, - default_operator_precedence::MEMBER_ACCESS, - OperatorAssociativity::Left, - false, - OperatorSignature::new( - vec![ - CheckedParam::from_str("record", std_types::ANY), - CheckedParam::from_str("field", std_types::ANY), - ], - std_types::ANY.clone(), - ), - |op| { - if let StaticOperatorAst::Infix(lhs, rhs) = op { - let key = if let Ast::Identifier { name, .. } = &rhs { - RecordKey::String(name.to_string()) - } else if let Ast::Literal { - value: Value::Number(interpreter::number::Number::UnsignedInteger(n)), - .. - } = &rhs - { - RecordKey::Number(Number::UnsignedInteger(n.clone())) - } else { - return Err(ParseError::new( - format!( - "Field access via {} requires a identifier or {} literal", - ".".yellow(), - std_types::UINT().pretty_print_color() - ), - rhs.info().clone(), - ) - .with_label( - format!( - "This is not an identifier or {}", - std_types::UINT().pretty_print_color() - ), - rhs.info().clone(), - ) - .with_hint(format!( - "Did you mean to use indexing via {} instead?", - "[]".yellow() - ))); - }; - let info = lhs.info().join(rhs.info()); - Ok(Ast::FieldAccess { - expr: Box::new(lhs), - field: key, - info, - }) - } else { - panic!("field_access expects an infix operator"); - } - }, - ), - // Addition operator Operator::new_runtime( "add".into(), "+".into(), - OperatorPosition::Infix, - default_operator_precedence::ADDITIVE, - OperatorAssociativity::Left, + OpPos::Infix, + prec::ADDITIVE_PREC, + OpAssoc::Left, false, - OperatorSignature::from_function(arithmetic::add().get_fn_type()), + OpSignature::from_function(arithmetic::add().get_fn_type()), ), Operator::new_runtime( "sub".into(), "-".into(), - OperatorPosition::Infix, - default_operator_precedence::ADDITIVE, - OperatorAssociativity::Left, + OpPos::Infix, + prec::ADDITIVE_PREC, + OpAssoc::Left, false, - OperatorSignature::from_function(arithmetic::sub().get_fn_type()), + OpSignature::from_function(arithmetic::sub().get_fn_type()), ), Operator::new_runtime( "mul".into(), "*".into(), - OperatorPosition::Infix, - default_operator_precedence::MULTIPLICATIVE, - OperatorAssociativity::Left, + OpPos::Infix, + prec::MULTIPLICATIVE_PREC, + OpAssoc::Left, false, - OperatorSignature::from_function(arithmetic::mul().get_fn_type()), + OpSignature::from_function(arithmetic::mul().get_fn_type()), ), Operator::new_runtime( "div".into(), "/".into(), - OperatorPosition::Infix, - default_operator_precedence::MULTIPLICATIVE, - OperatorAssociativity::Left, + OpPos::Infix, + prec::MULTIPLICATIVE_PREC, + OpAssoc::Left, false, - OperatorSignature::from_function(arithmetic::div().get_fn_type()), + OpSignature::from_function(arithmetic::div().get_fn_type()), ), - // Comparison operators Operator::new_runtime( "eq".into(), "==".into(), - OperatorPosition::Infix, - default_operator_precedence::EQUALITY, - OperatorAssociativity::Left, + OpPos::Infix, + prec::EQUALITY_PREC, + OpAssoc::Left, false, - OperatorSignature::from_function(logical::eq().get_fn_type()), + OpSignature::from_function(logical::eq().get_fn_type()), ), ], - - //--------------------------------------------------------------------------------------// - // Built-in Types // - //--------------------------------------------------------------------------------------// types: vec![ - std_types::ANY, std_types::TYPE, std_types::UNIT, std_types::STRING, @@ -334,13 +255,11 @@ pub fn stdlib() -> Initializer { ("int".into(), std_types::INT()), ("float".into(), std_types::FLOAT()), ("num".into(), std_types::NUM()), + ("List".into(), std_types::LIST()), + ("Map".into(), std_types::MAP()), ]) .collect(), - - //--------------------------------------------------------------------------------------// - // Constants // - //--------------------------------------------------------------------------------------// - values: vec![ + constants: vec![ ( Str::String("pi".to_string()), Value::Number(Number::FloatingPoint(FloatingPoint::Float64( @@ -392,10 +311,6 @@ pub fn stdlib() -> Initializer { Value::Number(Number::FloatingPoint(FloatingPoint::Float64(f64::NAN))), ), ], - - //--------------------------------------------------------------------------------------// - // Functions // - //--------------------------------------------------------------------------------------// functions: vec![ ("add", arithmetic::add()), ("sub", arithmetic::sub()), diff --git a/core/src/stdlib/logical.rs b/core/src/stdlib/logical.rs index 72a8b40..b99f584 100644 --- a/core/src/stdlib/logical.rs +++ b/core/src/stdlib/logical.rs @@ -20,14 +20,8 @@ pub fn eq() -> Function { Ok(Value::Boolean(values[0] == values[1])) }, vec![ - CheckedParam { - name: "lhs".to_string(), - ty: std_types::ANY, - }, - CheckedParam { - name: "rhs".to_string(), - ty: std_types::ANY, - }, + CheckedParam::from_str("lhs", std_types::ANY), + CheckedParam::from_str("rhs", std_types::ANY), ], std_types::BOOL, ) diff --git a/core/src/stdlib/system.rs b/core/src/stdlib/system.rs index d213aa6..ab6e4c4 100644 --- a/core/src/stdlib/system.rs +++ b/core/src/stdlib/system.rs @@ -23,10 +23,7 @@ pub fn print() -> Function { } Ok(Value::Unit) }, - vec![CheckedParam { - name: "values".to_string(), - ty: Type::Variable("T".into()), - }], + vec![CheckedParam::from_str("values", Type::Variable("T".into()))], std_types::UNIT, ) } @@ -46,10 +43,7 @@ pub fn dbg() -> Function { ); Ok(value) }, - vec![CheckedParam { - name: "values".to_string(), - ty: Type::Variable("T".into()), - }], + vec![CheckedParam::from_str("values", Type::Variable("T".into()))], Type::Variable("T".into()), ) } @@ -64,10 +58,7 @@ pub fn type_of() -> Function { } Ok(Value::Type(values[0].get_type().clone())) }, - vec![CheckedParam { - name: "value".to_string(), - ty: std_types::ANY, - }], + vec![CheckedParam::from_str("value", std_types::ANY)], std_types::TYPE, ) } @@ -99,10 +90,7 @@ pub fn exit() -> Function { ), } }, - vec![CheckedParam { - name: "code".to_string(), - ty: std_types::INT32, - }], + vec![CheckedParam::from_str("code", std_types::INT32)], std_types::UNIT, ) } diff --git a/core/src/stdlib/tests.rs b/core/src/stdlib/tests.rs index 02be010..4a9a8d3 100644 --- a/core/src/stdlib/tests.rs +++ b/core/src/stdlib/tests.rs @@ -35,7 +35,6 @@ mod tests { assert!(env.lookup_function("div").is_some()); assert!(env.lookup_function("eq").is_some()); // Types - assert!(env.get_type("any").is_some()); assert!(env.get_type("unit").is_some()); assert!(env.get_type("str").is_some()); assert!(env.get_type("char").is_some()); diff --git a/core/src/type_checker/checked_ast.rs b/core/src/type_checker/checked_ast.rs index 2105d14..2a98f26 100644 --- a/core/src/type_checker/checked_ast.rs +++ b/core/src/type_checker/checked_ast.rs @@ -1,5 +1,6 @@ use crate::{ interpreter::value::{Function, RecordKey, Value}, + parser::pattern::BindPattern, type_checker::types::Type, util::error::LineInfo, }; @@ -13,19 +14,25 @@ pub struct CheckedOperator { pub handler: Function, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct CheckedParam { - pub name: String, pub ty: Type, + pub pattern: BindPattern, } impl CheckedParam { - pub fn new(name: String, ty: Type) -> CheckedParam { - CheckedParam { name, ty } + pub fn new(pattern: BindPattern, ty: Type) -> CheckedParam { + CheckedParam { pattern, ty } } pub fn from_str>(name: S, ty: Type) -> CheckedParam { - CheckedParam::new(name.into(), ty) + CheckedParam::new( + BindPattern::Variable { + name: name.into(), + info: LineInfo::default(), + }, + ty, + ) } } @@ -43,7 +50,7 @@ pub enum CheckedAst { exprs: Vec, /// The type of the tuple, made up of the types of the elements. /// Each element's type is listed in the same order as the elements. - expr_types: Type, + ty: Type, info: LineInfo, }, /// A dynamic list of elements. @@ -85,20 +92,24 @@ pub enum CheckedAst { /// The argument to the function call arg: Box, /// The return type of the function call - return_type: Type, + ret_ty: Type, info: LineInfo, }, - /// A function definition is a named function with a list of parameters and a body expression - FunctionDef { + /// A lambda expression is an anonymous function that can be passed as a value. + Lambda { + /// The parameter of the lambda function param: CheckedParam, + /// The body of the lambda function body: Box, + /// The return type of the lambda function return_type: Type, + /// The type of the lambda function, which is a function type ty: Type, info: LineInfo, }, /// An assignment expression assigns a value to a variable via a matching pattern (identifier, destructuring of a tuple, record, etc.). Assignment { - target: CheckedBindPattern, + target: BindPattern, expr: Box, info: LineInfo, }, @@ -117,13 +128,13 @@ impl GetType for CheckedAst { match self { CheckedAst::Literal { value: v, info: _ } => v.get_type(), CheckedAst::LiteralType { .. } => &std_types::TYPE, - CheckedAst::Tuple { expr_types: ty, .. } => ty, + CheckedAst::Tuple { ty, .. } => ty, CheckedAst::List { ty, .. } => ty, CheckedAst::Record { ty, .. } => ty, CheckedAst::FieldAccess { ty, .. } => ty, CheckedAst::Identifier { ty, .. } => ty, - CheckedAst::FunctionCall { return_type, .. } => return_type, - CheckedAst::FunctionDef { ty, .. } => ty, + CheckedAst::FunctionCall { ret_ty, .. } => ret_ty, + CheckedAst::Lambda { ty, .. } => ty, CheckedAst::Assignment { .. } => &std_types::UNIT, CheckedAst::Block { exprs: _, ty, .. } => ty, } @@ -131,13 +142,21 @@ impl GetType for CheckedAst { } impl CheckedAst { - pub fn function_def( + pub fn unit(info: LineInfo) -> CheckedAst { + CheckedAst::Tuple { + exprs: vec![], + ty: std_types::UNIT, + info, + } + } + + pub fn lambda( param: CheckedParam, body: CheckedAst, return_type: Type, info: LineInfo, ) -> CheckedAst { - CheckedAst::FunctionDef { + CheckedAst::Lambda { ty: Type::Function(Box::new(FunctionType::new( param.clone(), return_type.clone(), @@ -159,7 +178,7 @@ impl CheckedAst { CheckedAst::FieldAccess { info, .. } => info, CheckedAst::Identifier { info, .. } => info, CheckedAst::FunctionCall { info, .. } => info, - CheckedAst::FunctionDef { info, .. } => info, + CheckedAst::Lambda { info, .. } => info, CheckedAst::Assignment { info, .. } => info, CheckedAst::Block { info, .. } => info, } @@ -171,7 +190,7 @@ impl CheckedAst { CheckedAst::LiteralType { .. } => (), CheckedAst::Tuple { exprs: elements, - expr_types: ty, + ty, .. } => { for element in elements { @@ -214,14 +233,14 @@ impl CheckedAst { CheckedAst::FunctionCall { expr: function, arg, - return_type, + ret_ty: return_type, .. } => { function.specialize(judgements, changed); arg.specialize(judgements, changed); *return_type = return_type.specialize(judgements, changed); } - CheckedAst::FunctionDef { + CheckedAst::Lambda { param, body, return_type, @@ -254,7 +273,7 @@ impl CheckedAst { } } - pub fn print_sexpr(&self) -> String { + pub fn print_expr(&self) -> String { match self { CheckedAst::Literal { value, info: _ } => value.pretty_print(), CheckedAst::LiteralType { value, info: _ } => value.pretty_print(), @@ -264,7 +283,7 @@ impl CheckedAst { "({})", elements .iter() - .map(|e| e.print_sexpr()) + .map(|e| e.print_expr()) .collect::>() .join(", ") ), @@ -274,7 +293,7 @@ impl CheckedAst { "[{}]", elements .iter() - .map(|e| e.print_sexpr()) + .map(|e| e.print_expr()) .collect::>() .join(", ") ), @@ -282,7 +301,7 @@ impl CheckedAst { "{{ {} }}", fields .iter() - .map(|(k, v)| format!("{}: {}", k, v.print_sexpr())) + .map(|(k, v)| format!("{}: {}", k, v.print_expr())) .collect::>() .join(", ") ), @@ -290,7 +309,7 @@ impl CheckedAst { expr: record, field, .. - } => format!("({}.{})", record.print_sexpr(), field), + } => format!("({}.{})", record.print_expr(), field), CheckedAst::Identifier { name, .. } => name.clone(), CheckedAst::FunctionCall { expr: function, @@ -312,22 +331,27 @@ impl CheckedAst { } format!( "{}({})", - function.print_sexpr(), + function.print_expr(), args.iter() - .map(|a| a.print_sexpr()) + .map(|a| a.print_expr()) .collect::>() .join(", ") ) } - CheckedAst::FunctionDef { param, body, .. } => { - format!("({} {} -> {})", param.ty, param.name, body.print_sexpr()) + CheckedAst::Lambda { param, body, .. } => { + format!( + "({} {} -> {})", + param.ty, + param.pattern.print_expr(), + body.print_expr() + ) } CheckedAst::Assignment { target: lhs, expr: rhs, .. } => { - format!("({} = {})", lhs.print_sexpr(), rhs.print_sexpr()) + format!("({} = {})", lhs.print_expr(), rhs.print_expr()) } CheckedAst::Block { exprs: expressions, .. @@ -335,7 +359,7 @@ impl CheckedAst { "{{ {} }}", expressions .iter() - .map(|e| e.print_sexpr()) + .map(|e| e.print_expr()) .collect::>() .join("; ") ), @@ -392,8 +416,12 @@ impl CheckedAst { } => { format!("{}({})", function.pretty_print(), arg.pretty_print()) } - Self::FunctionDef { param, body, .. } => { - format!("{} -> {}", param.name, body.pretty_print()) + Self::Lambda { param, body, .. } => { + format!( + "{} -> {}", + param.pattern.pretty_print(), + body.pretty_print() + ) } Self::Assignment { target: lhs, @@ -420,179 +448,3 @@ impl CheckedAst { } } } - -#[derive(Debug, Clone)] -pub enum CheckedBindPattern { - /// A variable binding pattern. - Variable { - /// The name of the variable. - name: String, - /// The type annotation for the variable. - info: LineInfo, - }, - /// A tuple binding pattern. - Tuple { - /// The elements of the tuple. - elements: Vec, - info: LineInfo, - }, - /// A record binding pattern. - Record { - /// The fields of the record. - fields: Vec<(RecordKey, CheckedBindPattern)>, - info: LineInfo, - }, - /// A list binding pattern. - List { - /// The elements of the list. - elements: Vec, - info: LineInfo, - }, - /// A wildcard pattern that matches any value. - Wildcard, - /// A literal pattern that matches a specific value. - Literal { - /// The value to match. - value: Value, - info: LineInfo, - }, - /// A rest of a collection pattern that matches the rest of a list. - Rest { - /// The name of the variable to bind the rest of the list. - name: String, - /// The type annotation for the variable. - info: LineInfo, - }, -} - -impl CheckedBindPattern { - pub fn info(&self) -> &LineInfo { - match self { - CheckedBindPattern::Variable { info, .. } => info, - CheckedBindPattern::Tuple { info, .. } => info, - CheckedBindPattern::Record { info, .. } => info, - CheckedBindPattern::List { info, .. } => info, - CheckedBindPattern::Wildcard => panic!("Wildcard pattern has no line info"), - CheckedBindPattern::Literal { info, .. } => info, - CheckedBindPattern::Rest { info, .. } => info, - } - } - - pub fn specialize(&mut self, _judgements: &TypeJudgements, _changed: &mut bool) { - match self { - CheckedBindPattern::Variable { name: _, info: _ } => (), - CheckedBindPattern::Tuple { elements, info: _ } => { - for element in elements { - element.specialize(_judgements, _changed); - } - } - CheckedBindPattern::Record { fields, info: _ } => { - for (_, element) in fields { - element.specialize(_judgements, _changed); - } - } - CheckedBindPattern::List { elements, info: _ } => { - for element in elements { - element.specialize(_judgements, _changed); - } - } - CheckedBindPattern::Wildcard => (), - CheckedBindPattern::Literal { value: _, info: _ } => (), - CheckedBindPattern::Rest { name: _, info: _ } => (), - } - } - - pub fn print_sexpr(&self) -> String { - match self { - CheckedBindPattern::Variable { name, .. } => name.clone(), - CheckedBindPattern::Tuple { elements, .. } => format!( - "({})", - elements - .iter() - .map(|e| e.print_sexpr()) - .collect::>() - .join(", ") - ), - CheckedBindPattern::Record { fields, .. } => format!( - "{{ {} }}", - fields - .iter() - .map(|(k, v)| format!("{}: {}", k, v.print_sexpr())) - .collect::>() - .join(", ") - ), - CheckedBindPattern::List { elements, .. } => format!( - "[{}]", - elements - .iter() - .map(|e| e.print_sexpr()) - .collect::>() - .join(", ") - ), - CheckedBindPattern::Wildcard => "_".to_string(), - CheckedBindPattern::Literal { value, .. } => value.pretty_print(), - CheckedBindPattern::Rest { name, .. } => format!("...{}", name), - } - } - - pub fn pretty_print(&self) -> String { - match self { - CheckedBindPattern::Variable { name, .. } => name.clone(), - CheckedBindPattern::Tuple { elements, .. } => { - let mut result = "(".to_string(); - for (i, v) in elements.iter().enumerate() { - result.push_str(&v.pretty_print()); - if i < elements.len() - 1 { - result.push_str(", "); - } - } - result.push(')'); - result - } - CheckedBindPattern::Record { fields, .. } => { - let mut result = "{ ".to_string(); - for (i, (k, v)) in fields.iter().enumerate() { - result.push_str(&format!("{}: {}", k, v.pretty_print())); - if i < fields.len() - 1 { - result.push_str(", "); - } - } - result.push_str(" }"); - result - } - CheckedBindPattern::List { elements, .. } => { - let mut result = "[".to_string(); - for (i, v) in elements.iter().enumerate() { - result.push_str(&v.pretty_print()); - if i < elements.len() - 1 { - result.push_str(", "); - } - } - result.push(']'); - result - } - CheckedBindPattern::Wildcard => "_".to_string(), - CheckedBindPattern::Literal { value, .. } => value.pretty_print(), - CheckedBindPattern::Rest { name, .. } => format!("...{}", name), - } - } - - pub fn is_wildcard(&self) -> bool { - matches!(self, CheckedBindPattern::Wildcard) - } -} - -impl PartialEq for CheckedBindPattern { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Variable { name: l0, .. }, Self::Variable { name: r0, .. }) => l0 == r0, - (Self::Tuple { elements: l0, .. }, Self::Tuple { elements: r0, .. }) => l0 == r0, - (Self::Record { fields: l0, .. }, Self::Record { fields: r0, .. }) => l0 == r0, - (Self::List { elements: l0, .. }, Self::List { elements: r0, .. }) => l0 == r0, - (Self::Wildcard, Self::Wildcard) => true, - (Self::Literal { value: l0, .. }, Self::Literal { value: r0, .. }) => l0 == r0, - (Self::Rest { name: l0, .. }, Self::Rest { name: r0, .. }) => l0 == r0, - _ => false, - } - } -} diff --git a/core/src/type_checker/checker.rs b/core/src/type_checker/checker.rs index 0442e4f..c78a15c 100644 --- a/core/src/type_checker/checker.rs +++ b/core/src/type_checker/checker.rs @@ -1,22 +1,24 @@ -use std::{borrow::Borrow, collections::HashMap}; +use std::{ + borrow::Borrow, + collections::{HashMap, HashSet}, +}; use colorful::Colorful; use crate::{ interpreter::value::{RecordKey, Value}, parser::{ + self, ast::{Ast, ParamAst, TypeAst}, error::ParseError, - op::{ - Operator, OperatorHandler, OperatorInfo, RuntimeOperatorHandler, StaticOperatorAst, - StaticOperatorHandler, - }, + op::{OpHandler, OpInfo, Operator, RuntimeOpHandler, StaticOpAst, StaticOpHandler}, + pattern::BindPattern, }, util::error::{BaseError, BaseErrorExt, LineInfo}, }; use super::{ - checked_ast::{CheckedAst, CheckedBindPattern, CheckedParam}, + checked_ast::{CheckedAst, CheckedParam}, types::{std_types, FunctionType, GetType, Type, TypeTrait}, }; @@ -78,7 +80,7 @@ impl From for TypeErrorVariant { } // The result of the type checker stage -pub type TypeResult = Result; +pub type TypeCheckerResult = Result; /// The type environment contains all the types and functions in the program. /// It is used to check the types of expressions and functions. @@ -121,8 +123,8 @@ impl TypeEnv { } // Add a variable to the type environment - pub fn add_variable(&mut self, name: &str, ty: Type) { - self.variables.insert(name.to_string(), ty); + pub fn add_variable(&mut self, name: String, ty: Type) { + self.variables.insert(name, ty); } // Add an operator to the type environment @@ -164,6 +166,10 @@ impl TypeChecker<'_> { self.env.add_function(name.to_string(), variation); } + pub fn add_variable(&mut self, name: String, ty: Type) { + self.env.add_variable(name, ty); + } + fn new_scope(&self) -> TypeChecker { TypeChecker { env: TypeEnv::default(), @@ -224,10 +230,10 @@ impl TypeChecker<'_> { operators } - fn lookup_static_operator(&self, symbol: &str) -> Option<&StaticOperatorHandler> { - let operator: Option<&StaticOperatorHandler> = self.env.operators.iter().find_map(|o| { + fn lookup_static_operator(&self, symbol: &str) -> Option<&StaticOpHandler> { + let operator: Option<&StaticOpHandler> = self.env.operators.iter().find_map(|o| { if o.info.symbol == symbol { - if let OperatorHandler::Static(op) = &o.handler { + if let OpHandler::Static(op) = &o.handler { return Some(op); } } @@ -240,14 +246,14 @@ impl TypeChecker<'_> { // ================== Scanning functions ================== - fn scan_forward(&mut self, expr: &[Ast]) -> TypeResult<()> { + fn scan_forward(&mut self, expr: &[Ast]) -> TypeCheckerResult<()> { for e in expr { if let Ast::Assignment { target, expr, .. } = e { - let Ast::Identifier { name, .. } = target.borrow() else { + let BindPattern::Variable { name, .. } = target else { continue; }; match expr.borrow() { - Ast::FunctionDef { + Ast::Lambda { param, body, return_type, @@ -255,7 +261,7 @@ impl TypeChecker<'_> { } => { let checked_param = self.check_param(param)?; let checked = - self.check_function(checked_param.clone(), body, return_type, info)?; + self.check_lambda(checked_param.clone(), body, return_type, info)?; let variation = FunctionType { param: checked_param, return_type: checked.get_type().clone(), @@ -269,7 +275,8 @@ impl TypeChecker<'_> { } _ => { let checked = self.check_expr(expr)?; - self.env.add_variable(name, checked.get_type().clone()); + self.env + .add_variable(name.clone(), checked.get_type().clone()); } } } @@ -279,7 +286,7 @@ impl TypeChecker<'_> { // ================== Type checking functions ================== - pub fn check_top_exprs(&mut self, exprs: &[Ast]) -> TypeResult> { + pub fn check_top_exprs(&mut self, exprs: &[Ast]) -> TypeCheckerResult> { self.scan_forward(exprs)?; let mut res = vec![]; for e in exprs { @@ -289,14 +296,14 @@ impl TypeChecker<'_> { } /// Check the type of an expression - pub fn check_expr(&mut self, expr: &Ast) -> TypeResult { + pub fn check_expr(&mut self, expr: &Ast) -> TypeCheckerResult { Ok(match expr { - Ast::FunctionDef { + Ast::Lambda { param, body, return_type, info, - } => self.check_function(self.check_param(param)?, body, return_type, info)?, + } => self.check_lambda(self.check_param(param)?, body, return_type, info)?, Ast::Literal { value, info } => CheckedAst::Literal { value: value.clone(), info: info.clone(), @@ -305,22 +312,24 @@ impl TypeChecker<'_> { Ast::Tuple { exprs, info } => self.check_tuple(exprs, info)?, Ast::List { exprs: elems, info } => self.check_list(elems, info)?, Ast::Record { fields, info } => self.check_record(fields, info)?, - Ast::FieldAccess { + Ast::MemderAccess { expr: record, field, info, } => self.check_field_access(record, field, info)?, Ast::Identifier { name, info } => self.check_identifier(name, info)?, - Ast::FunctionCall { - expr, - arg: args, - info, - } => self.check_call(expr, args, info)?, - Ast::Accumulate { - op_info: op, - exprs: operands, - info, - } => self.check_accumulate(op, operands, info)?, + Ast::FunctionCall { expr, arg, info } => { + // Try to specialize it + let types = self.env.types.keys().cloned().collect::>(); + let variables = self.env.variables.keys().cloned().collect::>(); + if let Some(res) = + parser::specialize::block_def_call(expr, arg, info, &types, Some(&variables)) + { + self.check_expr(&res?)? + } else { + self.check_call(expr, arg, info)? + } + } Ast::Binary { lhs, op_info, @@ -342,7 +351,7 @@ impl TypeChecker<'_> { }) } - fn check_type_expr(&self, expr: &TypeAst) -> TypeResult { + fn check_type_expr(&self, expr: &TypeAst) -> TypeCheckerResult { Ok(match expr { TypeAst::Identifier { name, info } => { self.lookup_type(name).cloned().ok_or_else(|| { @@ -350,29 +359,39 @@ impl TypeChecker<'_> { .with_label("This type is not defined".to_string(), info.clone()) })? } - TypeAst::Constructor { expr, arg, info } => { + TypeAst::Constructor { expr, params, info } => { let expr_info = expr.info(); let Type::Alias(base_name, base_type) = self.check_type_expr(expr)? else { return Err(TypeError::new( format!( "Cannot use constructor on non-constructor type {}", - expr.print_sexpr() + expr.print_expr() ), info.clone(), ) .with_label( - format!("This is not a constructable type {}", expr.print_sexpr()), + format!("This is not a constructable type {}", expr.print_expr()), expr_info.clone(), ) .into()); }; - let arg = self.check_type_expr(arg)?; - Type::Constructor(base_name, vec![arg], base_type) + let args = params + .iter() + .map(|a| self.check_type_expr(a)) + .collect::>>()?; + Type::Constructor(base_name, args, base_type) + } + TypeAst::Record { fields, .. } => { + let fields = fields + .iter() + .map(|(k, v)| Ok((k.clone(), self.check_type_expr(v)?))) + .collect::>>()?; + Type::Record(fields) } }) } - fn check_literal_type(&self, expr: &TypeAst) -> TypeResult { + fn check_literal_type(&self, expr: &TypeAst) -> TypeCheckerResult { let info = expr.info(); let ty = self.check_type_expr(expr)?; Ok(CheckedAst::LiteralType { @@ -381,35 +400,32 @@ impl TypeChecker<'_> { }) } - fn check_param(&self, param: &ParamAst) -> TypeResult { + fn check_param(&self, param: &ParamAst) -> TypeCheckerResult { let param_ty = if let Some(ty) = ¶m.ty { self.check_type_expr(ty)? } else { - return Err(TypeError::new( - format!("Missing parameter type for {}", param.name.clone().yellow()), - param.info.clone(), - ) - .with_label( - "Add a type to this parameter".to_string(), - param.info.clone(), - ) - .into()); + std_types::ANY // TODO: Infer a more specific type }; let param = CheckedParam { - name: param.name.clone(), + pattern: param.pattern.clone(), ty: param_ty, }; Ok(param) } - fn check_function( + fn check_lambda( &mut self, param: CheckedParam, body: &Ast, return_type: &Option, info: &LineInfo, - ) -> TypeResult { - let checked_body = self.new_scope().check_expr(body)?; + ) -> TypeCheckerResult { + let mut body_scope = self.new_scope(); + let pattern_names = binding_typed_names(¶m.pattern, ¶m.ty); + for (name, ty) in pattern_names.into_iter() { + body_scope.add_variable(name, ty); + } + let checked_body = body_scope.check_expr(body)?; let body_type = checked_body.get_type().clone(); let return_type = if let Some(ty) = &return_type { let ty = self.check_type_expr(ty)?; @@ -434,7 +450,7 @@ impl TypeChecker<'_> { body_type }; - Ok(CheckedAst::function_def( + Ok(CheckedAst::lambda( param, checked_body, return_type, @@ -442,13 +458,9 @@ impl TypeChecker<'_> { )) } - fn check_tuple(&mut self, elems: &[Ast], info: &LineInfo) -> TypeResult { + fn check_tuple(&mut self, elems: &[Ast], info: &LineInfo) -> TypeCheckerResult { if elems.is_empty() { - return Ok(CheckedAst::Tuple { - exprs: vec![], - expr_types: std_types::UNIT, - info: info.clone(), - }); + return Ok(CheckedAst::unit(info.clone())); } let checked_elems = self.check_top_exprs(elems)?; let elem_types = checked_elems @@ -458,12 +470,12 @@ impl TypeChecker<'_> { .collect::>(); Ok(CheckedAst::Tuple { exprs: checked_elems, - expr_types: Type::Tuple(elem_types), + ty: Type::Tuple(elem_types), info: info.clone(), }) } - fn check_list(&mut self, elems: &[Ast], info: &LineInfo) -> TypeResult { + fn check_list(&mut self, elems: &[Ast], info: &LineInfo) -> TypeCheckerResult { let checked_elems = self.check_top_exprs(elems)?; let elem_types = checked_elems .iter() @@ -494,11 +506,11 @@ impl TypeChecker<'_> { &mut self, pairs: &[(RecordKey, Ast)], info: &LineInfo, - ) -> TypeResult { + ) -> TypeCheckerResult { let pairs = pairs .iter() .map(|(k, v)| Ok((k.clone(), self.check_expr(v)?))) - .collect::>>()?; + .collect::>>()?; let record_type = Type::Record( pairs .iter() @@ -517,7 +529,7 @@ impl TypeChecker<'_> { record: &Ast, field: &RecordKey, info: &LineInfo, - ) -> TypeResult { + ) -> TypeCheckerResult { let record = self.check_expr(record)?; let record_ty = record.get_type(); if let Type::Record(fields) = record_ty { @@ -567,7 +579,7 @@ impl TypeChecker<'_> { } } - fn check_identifier(&self, name: &str, info: &LineInfo) -> TypeResult { + fn check_identifier(&self, name: &str, info: &LineInfo) -> TypeCheckerResult { Ok(match self.lookup_identifier(name) { Some(IdentifierType::Variable(ty)) => CheckedAst::Identifier { name: name.to_string(), @@ -611,79 +623,65 @@ impl TypeChecker<'_> { fn check_assignment( &mut self, annotation: &Option, - target: &Ast, + target: &BindPattern, expr: &Ast, info: &LineInfo, - ) -> TypeResult { - let target_name = match target { - Ast::Identifier { name, .. } => name, - _ => { - return Err(TypeError::new( - "Assignment expects an identifier".to_string(), - info.clone(), - ) - .with_label( - "This is not an identifier".to_string(), - target.info().clone(), - ) - .with_hint("Did you mean to assign to an identifier?".to_string()) - .into()); - } - }; - if let Some(existing) = self.lookup_local_identifier(target_name) { - let ty_name = match existing { - IdentifierType::Variable(_) => "Variable", - IdentifierType::Type(_) => "Type", - IdentifierType::Function(_) => "Function", - }; - return Err(TypeError::new( - format!( - "{} {} already exists", - ty_name, - target_name.clone().yellow() - ), - info.clone(), - ) - .with_hint("Use a different name for the variable".to_string()) - .into()); - } - let expr = self.check_expr(expr)?; - let body_ty = expr.get_type().clone(); - if let Some(ann) = annotation { - let ann_ty = self.check_type_expr(ann)?; - if !body_ty.subtype(&ann_ty).success { - return Err(TypeError::new( - format!( - "{} is not a valid subtype of {}", - expr.pretty_print(), - ann_ty.pretty_print_color(), - ), - info.clone(), - ) - .with_label( - format!( - "This should be of type {} but is {}", - ann_ty.pretty_print_color(), - body_ty.pretty_print_color() - ), - expr.info().clone(), - ) - .into()); + ) -> TypeCheckerResult { + match target { + BindPattern::Variable { name, .. } => { + if self.lookup_local_identifier(name).is_some() { + return Err(TypeError::new( + format!("{} is already defined", name.clone().yellow()), + info.clone(), + ) + .with_label( + "This already exists in the current scope".to_string(), + target.info().clone(), + ) + .into()); + } + let expr = self.check_expr(expr)?; + let ty = expr.get_type().clone(); + if let Some(ty_ast) = annotation { + let expected_ty = self.check_type_expr(ty_ast)?; + if !ty.subtype(&expected_ty).success { + return Err(TypeError::new( + format!( + "Cannot assign {} to {}", + ty.pretty_print_color(), + expected_ty.pretty_print_color() + ), + info.clone(), + ) + .with_label( + format!("This is of type {}", ty.pretty_print_color()), + expr.info().clone(), + ) + .with_label( + format!("This expected type {}", expected_ty.pretty_print_color()), + info.clone(), + ) + .into()); + } + } + self.add_variable(name.clone(), ty.clone()); + Ok(CheckedAst::Assignment { + target: BindPattern::Variable { + name: name.clone(), + info: info.clone(), + }, + expr: Box::new(expr), + info: info.clone(), + }) } + _ => Err(TypeErrorVariant::ParseError(ParseError::new( + format!("Invalid assignment target: {}", target.print_expr()), + target.info().clone(), + ))), } - let assign_info = info.join(expr.info()); - self.env.add_variable(target_name, body_ty.clone()); - Ok(CheckedAst::Assignment { - target: CheckedBindPattern::Variable { - name: target_name.to_string(), - info: target.info().clone(), - }, - expr: Box::new(expr), - info: assign_info, - }) } - fn check_block(&mut self, exprs: &[Ast], info: &LineInfo) -> TypeResult { + fn check_block(&mut self, exprs: &[Ast], info: &LineInfo) -> TypeCheckerResult { let mut scope = self.new_scope(); let exprs = scope.check_top_exprs(exprs)?; let ty = if let Some(expr) = exprs.last() { @@ -698,7 +696,12 @@ impl TypeChecker<'_> { }) } - fn check_call(&mut self, expr: &Ast, arg: &Ast, info: &LineInfo) -> TypeResult { + fn check_call( + &mut self, + expr: &Ast, + arg: &Ast, + info: &LineInfo, + ) -> TypeCheckerResult { // TODO: Add support for multiple function variants // TODO: This job should be done in the type checker // TODO: so that the interpreter can just call the function @@ -716,7 +719,7 @@ impl TypeChecker<'_> { if tr.success { log::info!( "Function call: {} : {} -> {} with argument {} : {}", - expr.print_sexpr(), + expr.print_expr(), param.ty.pretty_print_color(), ret.pretty_print_color(), arg.pretty_print(), @@ -730,7 +733,7 @@ impl TypeChecker<'_> { if changed { log::trace!( "Specialized call: {} : {} -> {}", - expr.print_sexpr(), + expr.print_expr(), param_ty.pretty_print_color(), ret_ty.pretty_print_color() ); @@ -738,7 +741,7 @@ impl TypeChecker<'_> { } Ok(CheckedAst::FunctionCall { - return_type: ret.clone(), + ret_ty: ret.clone(), expr: Box::new(expr), arg: Box::new(arg), info: info.clone(), @@ -772,84 +775,6 @@ impl TypeChecker<'_> { } } - /// Check the type of an accumulate expression. - /// An accumulate expression results in a function variation-specific call. - fn check_accumulate( - &mut self, - op_info: &OperatorInfo, - operands: &[Ast], - info: &LineInfo, - ) -> TypeResult { - let checked_operands = operands - .iter() - .map(|a| self.check_expr(a)) - .collect::>>()?; - let alternatives = self.lookup_operator(&op_info.symbol); - if let Some(op) = alternatives.iter().find(|op| { - checked_operands - .iter() - .zip(op.signature().params.iter()) - .all(|(operand, param)| operand.get_type().subtype(¶m.ty).success) - }) { - match &op.handler { - OperatorHandler::Runtime(RuntimeOperatorHandler { function_name, .. }) => { - // Construct a function call expression - let mut operands = operands.iter(); - let mut call = Ast::FunctionCall { - expr: Box::new(Ast::Identifier { - name: function_name.clone(), - info: info.clone(), - }), - arg: Box::new(operands.next().unwrap().clone()), - info: info.clone(), - }; - for arg in operands { - call = Ast::FunctionCall { - expr: Box::new(call), - arg: Box::new(arg.clone()), - info: info.clone(), - }; - } - self.check_expr(&call) - } - OperatorHandler::Parse(_) => { - unreachable!("Parse operators are not supported in accumulate expressions"); - } - OperatorHandler::Static(StaticOperatorHandler { handler, .. }) => { - // Evaluate the handler at compile-time - self.check_expr(&handler(StaticOperatorAst::Accumulate(operands.to_vec()))?) - } - } - } else { - let mut err = TypeError::new( - format!( - "Unknown accumulate operator {}", - op_info.symbol.clone().yellow() - ), - info.clone(), - ); - - if let Some(op) = alternatives.first() { - let params = op - .signature() - .params - .iter() - .map(|p| p.ty.pretty_print_color()) - .collect::>() - .join(", "); - let ret = op.signature().ret.pretty_print_color(); - err = err.with_hint(format!( - "Did you mean {} {} {}?", - params, - op_info.symbol.clone().yellow(), - ret - )); - } - - Err(err.into()) - } - } - /// Check the type of a binary expression. /// A binary expression results in a function variation-specific call on the form: /// @@ -866,16 +791,13 @@ impl TypeChecker<'_> { fn check_binary( &mut self, lhs: &Ast, - op_info: &OperatorInfo, + op_info: &OpInfo, rhs: &Ast, info: &LineInfo, - ) -> TypeResult { + ) -> TypeCheckerResult { if let Some(op) = self.lookup_static_operator(&op_info.symbol) { log::trace!("Found static operator: {}", op_info.symbol); - return self.check_expr(&(op.handler)(StaticOperatorAst::Infix( - lhs.clone(), - rhs.clone(), - ))?); + return self.check_expr(&(op.handler)(StaticOpAst::Infix(lhs.clone(), rhs.clone()))?); } let checked_lhs = self.check_expr(lhs)?; let checked_rhs = self.check_expr(rhs)?; @@ -915,7 +837,7 @@ impl TypeChecker<'_> { op.signature().ret.pretty_print_color() ); match &op.handler { - OperatorHandler::Runtime(RuntimeOperatorHandler { function_name, .. }) => { + OpHandler::Runtime(RuntimeOpHandler { function_name, .. }) => { let function_ty = self .lookup_function(function_name) .ok_or_else(|| { @@ -969,32 +891,27 @@ impl TypeChecker<'_> { info: info.clone(), }), arg: Box::new(checked_lhs), - return_type: inner_ret.clone(), + ret_ty: inner_ret.clone(), info: info.clone(), }), arg: Box::new(checked_rhs), - return_type: outer_ret.clone(), + ret_ty: outer_ret.clone(), info: info.clone(), }; log::trace!( "Binary: {} : {}", - result.print_sexpr(), + result.print_expr(), result.get_type().pretty_print_color() ); return Ok(result); } - OperatorHandler::Parse(_) => { - unreachable!("Parse operators are not supported in binary expressions"); - } - OperatorHandler::Static(StaticOperatorHandler { handler, .. }) => { + OpHandler::Static(StaticOpHandler { handler, .. }) => { log::trace!("Static operator: {}", op_info.symbol); // Evaluate the handler at compile-time - return self.check_expr(&handler(StaticOperatorAst::Infix( - lhs.clone(), - rhs.clone(), - ))?); + return self + .check_expr(&handler(StaticOpAst::Infix(lhs.clone(), rhs.clone()))?); } } } @@ -1035,10 +952,10 @@ impl TypeChecker<'_> { fn check_unary( &mut self, - op_info: &OperatorInfo, + op_info: &OpInfo, operand: &Ast, info: &LineInfo, - ) -> TypeResult { + ) -> TypeCheckerResult { let checked_operand = self.check_expr(operand)?; let operand_type = checked_operand.get_type(); let mut closest_match = None; @@ -1048,7 +965,7 @@ impl TypeChecker<'_> { continue; } match &op.handler { - OperatorHandler::Runtime(RuntimeOperatorHandler { function_name, .. }) => { + OpHandler::Runtime(RuntimeOpHandler { function_name, .. }) => { let call = Ast::FunctionCall { expr: Box::new(Ast::Identifier { name: function_name.clone(), @@ -1059,12 +976,9 @@ impl TypeChecker<'_> { }; return self.check_expr(&call); } - OperatorHandler::Parse(_) => { - unreachable!("Parse operators are not supported in unary expressions"); - } - OperatorHandler::Static(StaticOperatorHandler { handler, .. }) => { + OpHandler::Static(StaticOpHandler { handler, .. }) => { // Evaluate the handler at compile-time - return self.check_expr(&handler(StaticOperatorAst::Prefix(operand.clone()))?); + return self.check_expr(&handler(StaticOpAst::Prefix(operand.clone()))?); } } } @@ -1083,5 +997,173 @@ impl TypeChecker<'_> { Err(err.into()) } + /// TODO: Check the binding pattern against the expression type. + /// TODO: Create a new CheckedBindPattern that contains the type information for each variable in the pattern. + fn _check_binding_pattern( + &mut self, + pattern: &BindPattern, + expr_ty: &Type, + info: &LineInfo, + ) -> TypeCheckerResult<()> { + match pattern { + BindPattern::Variable { name, .. } => { + self.add_variable(name.clone(), expr_ty.clone()); + Ok(()) + } + BindPattern::Tuple { elements, .. } => { + if let Type::Tuple(types) = expr_ty { + if elements.len() != types.len() { + return Err(TypeError::new( + format!( + "Tuple pattern has {} elements, but the type has {} elements", + elements.len(), + types.len() + ), + info.clone(), + ) + .with_label( + format!("This pattern has {} elements", elements.len()), + pattern.info().clone(), + ) + .with_label( + format!("This type has {} elements", types.len()), + info.clone(), + ) + .into()); + } + for (element, ty) in elements.iter().zip(types) { + self._check_binding_pattern(element, ty, info)?; + } + Ok(()) + } else { + Err(TypeError::new( + format!( + "Cannot match tuple pattern with non-tuple type {}", + expr_ty.pretty_print_color() + ), + info.clone(), + ) + .with_label("This is not a tuple type".to_string(), info.clone()) + .into()) + } + } + BindPattern::Record { fields, .. } => { + if let Type::Record(types) = expr_ty { + for (key, pattern) in fields { + if let Some((_, ty)) = types.iter().find(|(k, _)| k == key) { + self._check_binding_pattern(pattern, ty, info)?; + } else { + return Err(TypeError::new( + format!( + "Field {} not found in record type {}", + key.to_string().yellow(), + expr_ty.pretty_print_color() + ), + info.clone(), + ) + .with_label( + format!( + "This record does not have the field {}", + key.to_string().yellow() + ), + pattern.info().clone(), + ) + .into()); + } + } + Ok(()) + } else { + Err(TypeError::new( + format!( + "Cannot match record pattern with non-record type {}", + expr_ty.pretty_print_color() + ), + info.clone(), + ) + .with_label("This is not a record type".to_string(), info.clone()) + .into()) + } + } + BindPattern::List { elements, .. } => { + if let Type::List(element_type) = expr_ty { + for element in elements { + self._check_binding_pattern(element, element_type, info)?; + } + Ok(()) + } else { + Err(TypeError::new( + format!( + "Cannot match list pattern with non-list type {}", + expr_ty.pretty_print_color() + ), + info.clone(), + ) + .with_label("This is not a list type".to_string(), info.clone()) + .into()) + } + } + BindPattern::Wildcard => Ok(()), + BindPattern::Literal { value, .. } => { + let value = value.as_value(); + let value_type = value.get_type(); + if !value_type.subtype(expr_ty).success { + return Err(TypeError::new( + format!( + "Literal pattern of type {} does not match type {}", + value_type.pretty_print_color(), + expr_ty.pretty_print_color() + ), + info.clone(), + ) + .with_label( + format!("This is of type {}", value_type.pretty_print_color()), + pattern.info().clone(), + ) + .into()); + } + Ok(()) + } + BindPattern::Rest { .. } => Ok(()), + } + } + // ================== Type inference functions ================== } + +/// Extract all names and their types from a binding pattern +fn binding_typed_names(pattern: &BindPattern, ty: &Type) -> HashSet<(String, Type)> { + fn visit(pattern: &BindPattern, ty: &Type, names: &mut HashSet<(String, Type)>) { + match pattern { + BindPattern::Variable { name, .. } => { + names.insert((name.clone(), ty.clone())); + } + BindPattern::Tuple { elements, .. } => { + if let Type::Tuple(types) = ty { + for (element, t) in elements.iter().zip(types) { + visit(element, t, names); + } + } + } + BindPattern::Record { fields, .. } => { + if let Type::Record(types) = ty { + for (key, pattern) in fields { + if let Some((_, t)) = types.iter().find(|(k, _)| k == key) { + visit(pattern, t, names); + } + } + } + } + BindPattern::List { elements, .. } => { + if let Type::List(element_type) = ty { + for element in elements { + visit(element, &element_type, names); + } + } + } + BindPattern::Wildcard | BindPattern::Rest { .. } | BindPattern::Literal { .. } => {} + } + } + let mut names = HashSet::new(); + visit(pattern, ty, &mut names); + names +} diff --git a/core/src/type_checker/mod.rs b/core/src/type_checker/mod.rs index ad6dc68..a01326a 100644 --- a/core/src/type_checker/mod.rs +++ b/core/src/type_checker/mod.rs @@ -1,4 +1,5 @@ pub mod checked_ast; pub mod checker; +mod specialize; mod tests; pub mod types; diff --git a/core/src/type_checker/specialize.rs b/core/src/type_checker/specialize.rs new file mode 100644 index 0000000..5e33221 --- /dev/null +++ b/core/src/type_checker/specialize.rs @@ -0,0 +1,950 @@ +use crate::parser::{ + ast::{Ast, ParamAst, TypeAst}, + error::ParseError, + op::{ASSIGNMENT_SYM, COMMA_SYM, MEMBER_ACCESS_SYM}, + parser::ParseResult, + pattern::LiteralPattern, +}; +use crate::{ + interpreter::{ + number::Number, + value::{RecordKey, Value}, + }, + parser::{op::OpInfo, pattern::BindPattern}, + type_checker::types::std_types, + util::error::{BaseErrorExt, LineInfo}, +}; +use colorful::Colorful; +use std::collections::HashSet; + +//--------------------------------------------------------------------------------------// +// Syntax Sugar // +//--------------------------------------------------------------------------------------// + +/// Specialize any top-level expressions by transforming them into +/// their concrete `AST` representation from loose `AST` representation. +/// This is done by transforming: +/// - `a; b; c` into `Block { exprs: [a, b, c] }` +/// - `x = y` into `Assignment { target: BindPatter::Variable { name: x }, expr: y }` +/// - `x.y` into `MemberAccess { expr: x, field: y }` +/// - `f(x, y) = b` into `Assignment { target: BindPattern::Function { name: f, params: [x, y] }, expr: b }` +/// - `f x, y = b` into `Assignment { target: BindPattern::Function { name: f, params: [x, y] }, expr: b }` +/// - `int f str x, bool y = b` into `Assignment { annotation: Some(int), target: BindPattern::Function { name: f, params: [str x, bool y] }, expr: b }` +/// - `List int` into `LiteralType { expr: TypeAst::Constructor { expr: List }, args: [int] }` +pub fn top(expr: Ast, types: &HashSet, variables: Option<&HashSet>) -> ParseResult { + match expr { + // Specialize type constructors to literal types + Ast::FunctionCall { expr, arg, info } => call(*expr, *arg, info, types, variables), + // Specialize assignments to binding patterns with optional type annotations + Ast::Binary { + lhs, + op_info, + rhs, + info, + } if op_info.symbol == ASSIGNMENT_SYM => assignment(*lhs, *rhs, info, types, variables), + Ast::Binary { + lhs, + op_info, + rhs, + info, + } if op_info.symbol == MEMBER_ACCESS_SYM => member_access(*lhs, *rhs, info), + Ast::Binary { + lhs, + op_info, + rhs, + info, + .. + } if op_info.symbol == COMMA_SYM => { + comma_sequence(*lhs, op_info, *rhs, info, types, variables) + } + // No specialization available + _ => Ok(expr), + } +} + +pub fn member_access(expr: Ast, rhs: Ast, info: LineInfo) -> ParseResult { + log::trace!( + "Specializing member access: {}.{}", + expr.print_expr().light_blue(), + rhs.print_expr().light_blue() + ); + Ok(Ast::MemderAccess { + expr: Box::new(expr), + field: record_key(rhs)?, + info, + }) +} + +/// Specialize a comma-separated sequence of expressions into a more specific expression. +/// E.g: +/// - `f x, y, z {...}` becomes `f x, y, z = {...}` if `f` is a function definition. +pub fn comma_sequence( + expr: Ast, + op_info: OpInfo, + rhs: Ast, + info: LineInfo, + types: &HashSet, + variables: Option<&HashSet>, +) -> ParseResult { + log::trace!("Specializing comma sequence: {:?}, {:?}", expr, rhs); + if let Some(res) = block_def_comma_sequence(&expr, op_info, &rhs, &info, types, variables) { + return res; + } + let mut exprs = flatten_sequence(expr); + exprs.push(rhs); + Ok(Ast::Tuple { + info: info.join(exprs.last().unwrap().info()), + exprs, + }) +} + +// Specialize function definitions with blocks +/// This is used to handle function definitions like: +/// ```ignore +/// f x, y, z {...} +/// List int f str x, int y {...} +/// ``` +/// Complex types can also be used here as seen above. +pub fn block_def_comma_sequence( + expr: &Ast, + op_info: OpInfo, + rhs: &Ast, + info: &LineInfo, + types: &HashSet, + variables: Option<&HashSet>, +) -> Option { + if variables.is_some() && matches!(rhs, Ast::FunctionCall { .. }) { + let Ast::FunctionCall { + expr: last_expr, + arg: last_arg, + .. + } = rhs + else { + unreachable!(); + }; + if matches!(**last_arg, Ast::Block { .. }) { + // If the outer arg is a block, we can potentially specialize it as a function definition with a block. + log::trace!("Trying to specialize potential function definition with block"); + let res = assignment( + Ast::Binary { + lhs: Box::new(expr.clone()), + op_info, + rhs: last_expr.clone(), + info: info.clone(), + }, + *last_arg.clone(), + info.clone(), + types, + variables, + ); + if res.is_ok() { + log::trace!("Specialized function definition successfully!"); + return Some(res); + } else { + log::trace!("Failed to specialize function definition: {:?}", res); + } + } + } + None +} + +// Specialize type constructors to literal types +pub fn call( + expr: Ast, + arg: Ast, + info: LineInfo, + types: &HashSet, + variables: Option<&HashSet>, +) -> ParseResult { + if let Some(res) = block_def_call(&expr, &arg, &info, types, variables) { + return res; + } + Ok(match expr { + Ast::Identifier { + name: constructor_name, + info: constructor_info, + } if types.contains(&constructor_name) && is_type_expr(&arg, types) => { + // If the function is a type constructor, we can specialize it as a literal type. + log::trace!( + "Specializing new type constructor: {}({})", + constructor_name.clone().light_blue(), + arg.print_expr().light_blue() + ); + let args = vec![into_type_ast(arg)?]; + Ast::LiteralType { + expr: TypeAst::Constructor { + expr: Box::new(TypeAst::Identifier { + name: constructor_name, + info: constructor_info, + }), + params: args, + info: info.clone(), + }, + } + } + Ast::LiteralType { + expr: + TypeAst::Constructor { + expr, + mut params, + info, + }, + } if is_type_expr(&arg, types) => { + // If the expression is a literal type constructor, we can add the argument as a type parameter. + log::trace!( + "Specializing existing type constructor: {}({}, {})", + expr.print_expr().light_blue(), + params + .iter() + .map(|p| p.pretty_print()) + .collect::>() + .join(", ") + .light_blue(), + arg.print_expr().light_blue() + ); + params.push(into_type_ast(arg)?); + Ast::LiteralType { + expr: TypeAst::Constructor { expr, params, info }, + } + } + Ast::FunctionCall { + expr: inner, + arg: arg_inner, + info: inner_info, + } => { + match *arg_inner { + Ast::Identifier { + name: constructor_name, + info: constructor_info, + } if types.contains(&constructor_name) && is_type_expr(&arg, types) => { + // Apply the new type argument to the type constructor + log::trace!( + "Specializing new type constructor call: {}({})", + constructor_name.clone().light_blue(), + arg.print_expr().light_blue() + ); + let args = vec![into_type_ast(arg)?]; + Ast::FunctionCall { + expr: inner, + arg: Box::new(Ast::LiteralType { + expr: TypeAst::Constructor { + expr: Box::new(TypeAst::Identifier { + name: constructor_name, + info: constructor_info, + }), + params: args, + info: inner_info.join(&info), + }, + }), + info: inner_info, + } + } + // If the expression is not a type constructor, create a function call as is. + _ => Ast::FunctionCall { + expr: Box::new(Ast::FunctionCall { + expr: inner, + arg: arg_inner, + info: inner_info, + }), + arg: Box::new(arg), + info, + }, + } + } + // If the expression is not a type constructor, create a function call as is. + _ => Ast::FunctionCall { + expr: Box::new(expr), + arg: Box::new(arg), + info, + }, + }) +} + +// Specialize function definitions with blocks +/// This is used to handle function definitions like: +/// ```ignore +/// f() {...} +/// f(x) {...} +/// bool f() {...} +/// int f str x {...} +/// Map int str f int x {...} +/// ``` +/// Complex types can also be used here as seen above. +pub fn block_def_call( + expr: &Ast, + arg: &Ast, + info: &LineInfo, + types: &HashSet, + variables: Option<&HashSet>, +) -> Option { + if variables.is_some() && matches!(arg, Ast::Block { .. }) { + // Check if function definition with block: + // `f() {...}`, `bool f() {...}`, `f(x) {...}` or `int f str x {...}` etc. + log::trace!("Trying to specializing potential function definition with block"); + let res = assignment(expr.clone(), arg.clone(), info.clone(), types, variables); + if res.is_ok() { + log::trace!("Specialized function definition successfully!"); + return Some(res); + } else { + log::trace!("Failed to specialize function definition: {:?}", res); + } + } + None +} + +pub fn record_key(expr: Ast) -> Result { + match expr { + Ast::Identifier { name, .. } => Ok(RecordKey::String(name.to_string())), + // Ast::Literal { + // value: Value::Number(Number::UnsignedInteger(n)), + // .. + // } => Ok(RecordKey::Number(Number::UnsignedInteger(n.clone()))), + _ => Err(ParseError::new( + format!( + "Field access via {} requires a identifier or {} literal", + ".".yellow(), + std_types::UINT().pretty_print_color() + ), + expr.info().clone(), + ) + .with_label( + format!( + "This is not an identifier or {}", + std_types::UINT().pretty_print_color() + ), + expr.info().clone(), + ) + .with_hint(format!( + "Did you mean to use indexing via {} instead?", + "[]".yellow() + ))), + } +} + +/// Convert a loose AST expression into a binding pattern. +pub fn assignment( + target: Ast, + body: Ast, + assignment_info: LineInfo, + types: &HashSet, + variables: Option<&HashSet>, +) -> ParseResult { + log::debug!( + "Specializing assignment: {} = {}", + target.print_expr(), + body.print_expr() + ); + log::trace!("Specializing assignment: {:?} = {:?}", target, body); + match &target { + // `x, y = ...`, `f x, y = ...`, `f int x, bool y = ...` or `int f int x, bool y = ...` etc. + // Potentially a function definition. + Ast::Binary { op_info, .. } if op_info.symbol == COMMA_SYM => definition( + flatten_sequence(target) + .into_iter() + .flat_map(flatten_calls) + .collect::>(), + body, + types, + variables, + ), + // `int x = ...`, `f x = ...`, `f int x = ...`, `int f x = ...` or `int f int x = ...`. + // Complex types can also be used here, like `List int list1 = ...` or `Map int str map1 = ...`. + // Potentially a function definition. + Ast::FunctionCall { .. } => definition(flatten_calls(target), body, types, variables), + // Try parse other generic binding patterns (non-typed) for assignments like: + // `_ = ...`, `x = ...`, `[x, y] = ...`, `{ a: x, b: y } = ...`, etc. + _ => Ok(Ast::Assignment { + annotation: None, + target: binding_pattern(target)?, + expr: Box::new(body), + info: assignment_info, + }), + } +} + +/// Parse a variable or function definition unless the variable is already defined (known per scope). +/// This is used to handle function definitions like: +/// ```ignore +/// x ... +/// int x ... +/// f x ... +/// f int x ... +/// int f x ... +/// int f int x ... +/// List int f int x ... +/// Map int str f int x ... +/// f x, y ... +/// f x, y, z ... +/// int f x, y ... +/// int f int x, y ... +/// Map int str f int x, List int y ... +/// ``` +/// Complex types can also be used here as seen above. +pub fn definition( + mut exprs: Vec, + body: Ast, + types: &HashSet, + variables: Option<&HashSet>, +) -> ParseResult { + log::trace!( + "Parsing function definition start from expressions: {:?}", + exprs + .iter() + .map(|e| e.print_expr()) + .collect::>() + ); + let func = next_typed_binding_pattern(&mut exprs, types)?; + let BindPattern::Variable { + name, + info: name_info, + } = func.pattern + else { + return Err(ParseError::new( + format!( + "Invalid function binding pattern: {}", + func.pattern.print_expr() + ), + func.pattern.info().clone(), + ) + .with_label( + format!( + "Expected a function name, found: {}", + func.pattern.print_expr() + ), + func.pattern.info().clone(), + )); + }; + log::trace!("Found variable or function name: {}", name.clone().yellow()); + if let Some(variables) = variables { + // Check if the variable is already defined in the current scope + // Otherwise create a new variable binding + log::trace!( + "Checking if variable `{}` is already defined in the current scope", + name + ); + if variables.contains(&name) { + return Err(ParseError::new( + format!("Variable `{}` is already defined", name), + name_info.clone(), + ) + .with_label("This already exist".to_string(), name_info.clone())); + } + } + if exprs.is_empty() { + // `int x = ...` + // If no parameters are found, it's a typed variable binding + log::trace!("Found a typed variable binding: {} = ...", name); + Ok(Ast::Assignment { + info: name_info.clone(), + annotation: func.ty, + target: BindPattern::Variable { + name, + info: name_info, + }, + expr: Box::new(body), + }) + } else { + // `f x = ...`, `int f x, y = ...`, or `int f int x, bool y, str z = ...`, etc. + let mut params = Vec::new(); + while !exprs.is_empty() { + // Parse the next (possibly typed) parameter from the remaining expressions + let param = next_typed_binding_pattern(&mut exprs, types)?; + log::trace!( + "Found function parameter: {}{}", + if let Some(annotation) = ¶m.ty { + format!("{} ", annotation.pretty_print().light_blue()) + } else { + "".to_string() + }, + param.pattern.pretty_print() + ); + params.push(param); + } + // Create a function definition via nested lambda expressions assigned to a variable + log::trace!( + "Creating function definition for {} with parameters: {:?}", + name, + params + .iter() + .map(|p| p.pattern.pretty_print()) + .collect::>() + ); + let info = name_info.join(body.info()); + Ok(create_function_assignment( + func.ty, name, name_info, params, body, info, + )) + } +} + +/// Create a function definition via nested lambda expressions assigned to a variable. +pub fn create_function_assignment( + annotation: Option, + name: String, + name_info: LineInfo, + params: Vec, + body: Ast, + info: LineInfo, +) -> Ast { + log::trace!( + "Creating function {}{}({}) = {}", + if let Some(annotation) = &annotation { + format!("{} ", annotation.pretty_print().light_blue()) + } else { + "".to_string() + }, + name, + params + .iter() + .map(|p| format!( + "{}{}", + if let Some(annotation) = &p.ty { + format!("{} ", annotation.pretty_print().light_blue()) + } else { + "".to_string() + }, + p.pattern.pretty_print() + )) + .collect::>() + .join(", "), + body.print_expr() + ); + let mut params = params.into_iter().rev(); + let first = params.next().unwrap(); + let mut lambda = Ast::Lambda { + info: first.pattern.info().join(body.info()), + param: first, + body: Box::new(body), + return_type: None, + }; + for param in params { + lambda = Ast::Lambda { + info: lambda.info().join(param.pattern.info()), + param, + body: Box::new(lambda), + return_type: None, + }; + } + Ast::Assignment { + info, + annotation, + target: BindPattern::Variable { + name, + info: name_info, + }, + expr: Box::new(lambda), + } +} + +//--------------------------------------------------------------------------------------// +// Utilities // +//--------------------------------------------------------------------------------------// + +pub fn flatten_calls(expr: Ast) -> Vec { + let mut exprs = Vec::new(); + let mut queue = vec![expr]; + while let Some(current) = queue.pop() { + match current { + // Flatten function calls into a list of expressions + Ast::FunctionCall { expr, arg, .. } => { + queue.push(*arg); + queue.push(*expr); + } + _ => exprs.push(current), + } + } + exprs +} + +pub fn flatten_sequence(expr: Ast) -> Vec { + let mut exprs = Vec::new(); + let mut queue = vec![expr]; + while let Some(current) = queue.pop() { + match current { + // Flatten sequences of expressions + Ast::Binary { + lhs, op_info, rhs, .. + } if op_info.symbol == COMMA_SYM => { + queue.push(*rhs); + queue.push(*lhs); + } + _ => exprs.push(current), + } + } + exprs +} + +/// Parse a typed binding pattern from a list of expressions, like: +/// ```ignore +/// int f str x, bool y +/// ``` +/// Yelds the first parameter with its type annotation, like: +/// ```ignore +/// ParamAst { +/// ty: Some(TypeAst::Identifier { name: "int".to_string(), info: ... }), +/// pattern: BindPattern::Variable { +/// name: "f".to_string(), +/// info: ..., +/// } +/// } +/// ``` +/// And modifies the `exprs` vector to remove the first parameter and its type annotation. +/// So the remaining expressions will be the rest of the parameters, like: +/// ```ignore +/// str x, bool y +/// ``` +pub fn next_typed_binding_pattern( + exprs: &mut Vec, + types: &HashSet, +) -> Result { + if exprs.is_empty() { + unreachable!("Expected at least one expression for a binding pattern"); + } + let annotation = match try_type_annotation(exprs, types) { + Some(Ok(annotation)) => Some(annotation), + Some(Err(err)) => return Err(err), + None => None, + }; + if exprs.is_empty() { + return Err(ParseError::new( + "Expected a binding pattern".to_string(), + annotation.unwrap().info().clone(), + )); + } + Ok(ParamAst { + ty: annotation, + pattern: binding_pattern(exprs.remove(0))?, + }) +} + +// Convert a loose AST expression into a binding pattern. +pub fn binding_pattern(expr: Ast) -> Result { + match expr { + Ast::Identifier { name, .. } if name.starts_with("_") => Ok(BindPattern::Wildcard), + Ast::Identifier { name, info } => Ok(BindPattern::Variable { name, info }), + Ast::Tuple { exprs, info } => { + let elements = exprs + .into_iter() + .map(binding_pattern) + .collect::, _>>()?; + Ok(BindPattern::Tuple { elements, info }) + } + Ast::Record { fields, info } => { + let fields = fields + .into_iter() + .map(|(k, v)| Ok((k, binding_pattern(v)?))) + .collect::, _>>()?; + Ok(BindPattern::Record { fields, info }) + } + Ast::List { exprs, info } => { + let elements = exprs + .into_iter() + .map(binding_pattern) + .collect::, _>>()?; + Ok(BindPattern::List { elements, info }) + } + Ast::Literal { value, info } => Ok(BindPattern::Literal { + value: literal_pattern(value).ok_or( + ParseError::new("Expected a literal value pattern".to_string(), info.clone()) + .with_label("This is not valid".to_string(), info.clone()), + )?, + info, + }), + _ => Err(ParseError::new( + format!("Invalid binding pattern: {}", expr.print_expr()), + expr.info().clone(), + )), + } +} + +/// A helper function to convert a `Value` into a `LiteralPattern`. +pub fn literal_pattern(value: Value) -> Option { + Some(match value { + Value::Number(n) => match n { + Number::UnsignedInteger(u) => LiteralPattern::UnsignedInteger(u), + Number::SignedInteger(i) => LiteralPattern::SignedInteger(i), + _ => return None, // Only unsigned and signed integers are supported as literals + }, + Value::String(s) => LiteralPattern::String(s), + Value::Char(c) => LiteralPattern::Char(c), + Value::Boolean(b) => LiteralPattern::Boolean(b), + _ => return None, + }) +} + +/// Try to parse a type annotation from a list of expressions given from a function definition assignment. +/// +/// ## Example +/// ```ignore +/// int f str x, bool y +/// ``` +/// Becomes: (`int`, `f(str(x)), bool(y)`) +/// +/// ## Example +/// ```ignore +/// List int f int x +/// ``` +/// Becomes: (`List(int)`, `f(int(x))`) +pub fn try_type_annotation( + exprs: &mut Vec, + types: &HashSet, +) -> Option> { + let mut annotations: Vec<&Ast> = Vec::new(); + let mut is_type_constructor = false; + for expr in exprs.iter() { + // Check if the expression is a type expression + if !is_type_expr(expr, types) { + break; + } + // If the expression is a type expression, we add it to the annotations + annotations.push(expr); + if annotations.is_empty() { + // `(((HashMap int) str) f) int x` + if let Ast::Identifier { name, .. } = expr { + if types.contains(name) { + is_type_constructor = true; + } + } + } + } + // Make immutable for the rest of the function + let annotations = annotations; + is_type_constructor = is_type_constructor && annotations.len() > 1; // A constructor must have at least one type argument + + // Create a type from the annotation expressions + if annotations.is_empty() { + return None; + } + // We have type annotation expressions! 🎉 + if is_type_constructor { + // We have a type constructor, like `List int` or `Map str int` + // The first expression is the type constructor name, the rest are type arguments + let Ast::Identifier { + name: constructor_name, + info: constructor_info, + } = annotations.first().unwrap() + else { + unreachable!(); + }; + let mut params = Vec::new(); + for param in annotations.iter().skip(1) { + match into_type_ast((*param).clone()) { + Ok(type_ast) => { + params.push(type_ast); + } + Err(err) => { + return Some(Err(err.with_label( + format!("Invalid type argument: {}", param.print_expr()), + param.info().clone(), + ))); + } + } + } + let last_info = params + .last() + .map_or(constructor_info.clone(), |p| p.info().clone()); + let type_expr = TypeAst::Constructor { + expr: Box::new(TypeAst::Identifier { + name: constructor_name.clone(), + info: constructor_info.clone(), + }), + params, + info: constructor_info.join(&last_info), + }; + // Remove the type annotation expressions from the list + exprs.drain(0..annotations.len()); + Some(Ok(type_expr)) + } else { + // No type constructor found, make sure there is only ONE type annotation expression + if annotations.len() != 1 { + let mut err = ParseError::new( + "Expected a single type annotation expression".to_string(), + annotations[0].info().clone(), + ); + if let Some(invalid) = annotations.get(1..) { + let first_info = annotations.first().unwrap().info(); + let last_info = invalid.last().unwrap().info(); + err = err.with_label("These are invalid".to_string(), first_info.join(last_info)); + } + err = err.with_hint("Did you mean to define a variable or function?".to_string()); + + return Some(Err(err)); + } + match into_type_ast(annotations[0].clone()) { + Ok(type_ast) => { + // Remove the type annotation expressions from the list + exprs.drain(0..annotations.len()); + Some(Ok(type_ast)) + } + Err(err) => Some(Err(err.with_label( + "Invalid type annotation".to_string(), + annotations[0].info().clone(), + ))), + } + } +} + +/// Quick check if a single expression can be converted into a `TypeAst`. +pub fn is_type_expr(expr: &Ast, types: &HashSet) -> bool { + match expr { + Ast::Identifier { name, .. } => types.contains(name), + // Sum types like `int | str` + Ast::Binary { lhs, rhs, .. } => is_type_expr(lhs, types) && is_type_expr(rhs, types), + Ast::List { exprs, .. } if exprs.len() == 1 => is_type_expr(&exprs[0], types), + Ast::LiteralType { .. } => true, + Ast::Literal { .. } => false, + // Product type like `(int, str)` + Ast::Tuple { exprs, .. } => exprs.iter().all(|e| is_type_expr(e, types)), + // Product type like `{ a: int, b: str }` + Ast::Record { fields, .. } => fields.iter().all(|(_, v)| is_type_expr(v, types)), + _ => false, + } +} + +pub fn into_type_ast(expr: Ast) -> Result { + log::trace!( + "Converting expression into TypeAst: {}", + expr.print_expr().light_blue() + ); + match expr { + Ast::Identifier { name, info } => Ok(TypeAst::Identifier { + name: name.clone(), + info: info.clone(), + }), + Ast::List { mut exprs, info } if exprs.len() == 1 => { + let inner = into_type_ast(exprs.remove(0))?; + Ok(TypeAst::Constructor { + expr: Box::new(TypeAst::Identifier { + name: "List".to_string(), + info: info.clone(), + }), + params: vec![inner], + info: info.clone(), + }) + } + Ast::Record { fields, info } => { + let fields = fields + .into_iter() + .map(|(key, value)| Ok((key, into_type_ast(value)?))) + .collect::, ParseError>>()?; + Ok(TypeAst::Record { fields, info }) + } + Ast::LiteralType { expr } => Ok(expr.clone()), + _ => Err(ParseError::new( + format!("Expected a type expression, found: {}", expr.print_expr()), + expr.info().clone(), + )), + } +} + +/// Takes a function name, a list of parameters and a body and rolls them into a single assignment expression. +/// Parameters are rolled into a nested function definition. +/// All parameters are sorted like: +/// ```lento +/// func(a, b, c) = expr +/// ``` +/// becomes: +/// ```lento +/// func = a -> b -> c -> expr +/// ``` +/// +/// # Arguments +/// - `func_name` The name of the function +/// - `params` A list of parameters in left-to-right order: `a, b, c` +/// - `body` The body of the function +pub fn _roll_function_definition(params: Vec, body: Ast) -> Ast { + assert!(!params.is_empty(), "Expected at least one parameter"); + let info = body + .info() + .join(params.last().map(|p| p.pattern.info()).unwrap()); + let mut params = params.iter().rev(); + let mut function = Ast::Lambda { + param: params.next().unwrap().clone(), + body: Box::new(body), + return_type: None, + info, + }; + for param in params { + function = Ast::Lambda { + info: function.info().join(param.pattern.info()), + param: param.clone(), + body: Box::new(function), + return_type: None, + }; + } + function +} + +/// Takes a function name, a list of arguments and rolls them into a single function call expression. +/// Arguments are rolled into a nested function call. +/// All arguments are sorted like: +/// ```lento +/// func(a, b, c) +/// ``` +/// becomes: +/// ```lento +/// func(a)(b)(c) +/// ``` +pub fn roll_function_call(expr: Ast, args: Vec, types: &HashSet) -> Ast { + let last_info = args + .last() + .map(|a| a.info().clone()) + .unwrap_or(expr.info().clone()); + let call_info = expr.info().join(&last_info); + match expr { + Ast::Identifier { + name: constructor_name, + info: constructor_info, + } if types.contains(&constructor_name) && args.iter().all(|a| is_type_expr(a, types)) => { + // If the expression is a type constructor, we can specialize it as a literal type. + log::trace!( + "Specializing new type constructor: {}({})", + constructor_name.clone().light_blue(), + args.iter() + .map(|a| a.print_expr().light_blue().to_string()) + .collect::>() + .join(", ") + ); + let args = args + .into_iter() + .map(into_type_ast) + .collect::, _>>() + .unwrap(); + Ast::LiteralType { + expr: TypeAst::Constructor { + expr: Box::new(TypeAst::Identifier { + name: constructor_name, + info: constructor_info, + }), + params: args, + info: call_info.clone(), + }, + } + } + expr => { + // If the expression is not a type constructor, we can create a function call as is. + log::trace!( + "Creating function call: {}({})", + expr.print_expr().light_blue(), + args.iter() + .map(|a| a.print_expr().light_blue().to_string()) + .collect::>() + .join(", ") + ); + let mut args = args.into_iter(); + let mut call = Ast::FunctionCall { + expr: Box::new(expr), + arg: Box::new(args.next().unwrap()), + info: call_info.clone(), + }; + for arg in args { + let arg_info = call_info.join(arg.info()); + call = Ast::FunctionCall { + expr: Box::new(call), + arg: Box::new(arg), + info: arg_info, + }; + } + call + } + } +} diff --git a/core/src/type_checker/tests.rs b/core/src/type_checker/tests.rs index 22aaf94..1cd5841 100644 --- a/core/src/type_checker/tests.rs +++ b/core/src/type_checker/tests.rs @@ -4,15 +4,28 @@ mod tests { use crate::{ interpreter::value::Value, - parser::parser::from_string, - stdlib::init::stdlib, + parser::{parser::from_string, pattern::BindPattern}, + stdlib::init::{stdlib, Initializer}, type_checker::{ checked_ast::CheckedAst, - checker::TypeChecker, + checker::{TypeChecker, TypeCheckerResult, TypeErrorVariant}, types::{std_types, Type, TypeTrait}, }, }; + fn check_str_one(input: &str, init: Option<&Initializer>) -> TypeCheckerResult { + let mut parser = from_string(input.to_string()); + let mut checker = TypeChecker::default(); + if let Some(init) = init { + init.init_parser(&mut parser); + init.init_type_checker(&mut checker); + } + match parser.parse_one() { + Ok(ast) => checker.check_expr(&ast), + Err(err) => Err(TypeErrorVariant::ParseError(err)), + } + } + #[test] fn types() { let types = [ @@ -53,4 +66,63 @@ mod tests { assert!(std_types::CHAR.subtype(&outer).success); assert!(std_types::BOOL.subtype(&outer).success); } + + #[test] + fn invalid_function() { + let result = check_str_one("() 1", Some(&stdlib())); + dbg!("{:?}", &result); + assert!(result.is_err()); + } + + #[test] + fn function_def_with_return_type_single_no_parens_block() { + let result = check_str_one("int f int x { x + 5 }", Some(&stdlib())).unwrap(); + if let CheckedAst::Assignment { target, expr, .. } = result { + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "f"); + } + assert!(matches!(*expr, CheckedAst::Lambda { .. })); + if let CheckedAst::Lambda { param, body, .. } = *expr { + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, CheckedAst::Block { .. })); + } + } else { + panic!( + "Expected function definition with return type and no parens: {:?}", + result + ); + } + } + + #[test] + fn function_def_with_return_type_many_no_parens_block() { + let result = check_str_one("int add int x, int y { x + y }", Some(&stdlib())).unwrap(); + if let CheckedAst::Assignment { target, expr, .. } = result { + assert!(matches!(target, BindPattern::Variable { .. })); + if let BindPattern::Variable { name, .. } = target { + assert_eq!(name, "add"); + } + assert!(matches!(*expr, CheckedAst::Lambda { .. })); + if let CheckedAst::Lambda { param, body, .. } = *expr { + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "y"); + } + assert!(matches!(*body, CheckedAst::Lambda { .. })); + if let CheckedAst::Lambda { param, body, .. } = *body { + if let BindPattern::Variable { name, .. } = ¶m.pattern { + assert_eq!(name, "x"); + } + assert!(matches!(*body, CheckedAst::Block { .. })); + } + } + } else { + panic!( + "Expected function definition with return type and no parens: {:?}", + result + ); + } + } } diff --git a/core/src/type_checker/types.rs b/core/src/type_checker/types.rs index b45aa67..55876a3 100644 --- a/core/src/type_checker/types.rs +++ b/core/src/type_checker/types.rs @@ -21,14 +21,14 @@ pub trait GetType { pub type TypeJudgements = HashMap; -pub struct TypeResult { +pub struct TypeJudgeResult { pub success: bool, pub judgements: TypeJudgements, } -impl TypeResult { +impl TypeJudgeResult { pub fn success() -> Self { - TypeResult { + TypeJudgeResult { success: true, judgements: HashMap::new(), } @@ -37,14 +37,14 @@ impl TypeResult { pub fn judge(self, name: Str, ty: Type) -> Self { let mut judgements = self.judgements; judgements.insert(name, ty); - TypeResult { + TypeJudgeResult { success: self.success, judgements, } } pub fn fail() -> Self { - TypeResult { + TypeJudgeResult { success: false, judgements: HashMap::new(), } @@ -52,7 +52,7 @@ impl TypeResult { pub fn and(self, other: Self) -> Self { if self.success && other.success { - TypeResult { + TypeJudgeResult { success: true, judgements: self .judgements @@ -61,7 +61,7 @@ impl TypeResult { .collect(), } } else { - TypeResult::fail() + TypeJudgeResult::fail() } } @@ -91,12 +91,12 @@ impl TypeResult { } } -impl From for TypeResult { +impl From for TypeJudgeResult { fn from(success: bool) -> Self { if success { - TypeResult::success() + TypeJudgeResult::success() } else { - TypeResult::fail() + TypeJudgeResult::fail() } } } @@ -106,7 +106,7 @@ pub trait TypeTrait { /// Check if the type is equal to the other type. /// Two types are equal if they are the same type. /// The equality relation is reflexive, symmetric, and transitive. - fn equals(&self, other: &Self) -> TypeResult { + fn equals(&self, other: &Self) -> TypeJudgeResult { self.subtype(other).and(other.subtype(self)) } @@ -130,12 +130,12 @@ pub trait TypeTrait { /// /// **Covariant return type**: The return type of the subtype function (int) is a subtype of the return type of the supertype function (number). \ /// **Contravariant parameter type** : The parameter type of the supertype function (number) is a supertype of the parameter type of the subtype function (int). - fn subtype(&self, other: &Self) -> TypeResult; + fn subtype(&self, other: &Self) -> TypeJudgeResult; fn simplify(self) -> Self; fn specialize(&self, judgements: &TypeJudgements, changed: &mut bool) -> Self; } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct FunctionType { pub param: CheckedParam, pub return_type: Type, @@ -165,7 +165,7 @@ impl FunctionType { } impl TypeTrait for FunctionType { - fn subtype(&self, other: &Self) -> TypeResult { + fn subtype(&self, other: &Self) -> TypeJudgeResult { self.param .ty .subtype(&other.param.ty) @@ -175,7 +175,7 @@ impl TypeTrait for FunctionType { fn simplify(self) -> Self { FunctionType { param: CheckedParam { - name: self.param.name, + pattern: self.param.pattern, ty: self.param.ty.simplify(), }, return_type: self.return_type.simplify(), @@ -185,7 +185,7 @@ impl TypeTrait for FunctionType { fn specialize(&self, judgements: &TypeJudgements, changed: &mut bool) -> Self { FunctionType { param: CheckedParam { - name: self.param.name.clone(), + pattern: self.param.pattern.clone(), ty: self.param.ty.specialize(judgements, changed), }, return_type: self.return_type.specialize(judgements, changed), @@ -193,7 +193,7 @@ impl TypeTrait for FunctionType { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] pub enum Type { /// A literal type (name). /// Examples such as `int`, `float`, `string`, `char`, `bool`, `unit`, `any`. @@ -230,6 +230,12 @@ pub enum Type { /// A list type is a product type. List(Box), + /// A map type. + /// The first argument is the type of the keys in the map. + /// The second argument is the type of the values in the map. + /// A map type is a product type. + Map(Box, Box), + /// A record type. /// The first argument is the list of fields in the record. /// A record type is a product type. @@ -277,7 +283,7 @@ pub enum Type { } impl TypeTrait for Type { - fn subtype(&self, other: &Type) -> TypeResult { + fn subtype(&self, other: &Type) -> TypeJudgeResult { let subtype = match (self, other) { (Type::Literal(Str::Str("any")), _) => true.into(), // TODO: Find a way to use `std_types::ANY` here. (_, Type::Literal(Str::Str("any"))) => true.into(), @@ -285,18 +291,18 @@ impl TypeTrait for Type { (Type::Alias(_, ty1), _) => ty1.subtype(other), (_, Type::Alias(_, ty)) => self.subtype(ty), (Type::Variable(s1), Type::Variable(s2)) => (s1 == s2).into(), - (Type::Variable(v), t) => TypeResult::success().judge(v.clone(), t.clone()), - (t, Type::Variable(v)) => TypeResult::success().judge(v.clone(), t.clone()), + (Type::Variable(v), t) => TypeJudgeResult::success().judge(v.clone(), t.clone()), + (t, Type::Variable(v)) => TypeJudgeResult::success().judge(v.clone(), t.clone()), (Type::Constructor(s1, params1, _), Type::Constructor(s2, params2, _)) => { if s1 == s2 && params1.len() == params2.len() { params1 .iter() .zip(params2) - .fold(TypeResult::success(), |acc, (p1, p2)| { + .fold(TypeJudgeResult::success(), |acc, (p1, p2)| { acc.and(p1.subtype(p2)) }) } else { - TypeResult::fail() + TypeJudgeResult::fail() } } (Type::Function(ty1), Type::Function(ty2)) => { @@ -308,42 +314,42 @@ impl TypeTrait for Type { types1 .iter() .zip(types2) - .fold(TypeResult::success(), |acc, (t1, t2)| { + .fold(TypeJudgeResult::success(), |acc, (t1, t2)| { acc.and(t1.subtype(t2)) }) } else { - TypeResult::fail() + TypeJudgeResult::fail() } } (Type::List(t1), Type::List(t2)) => t1.subtype(t2), (Type::Record(fields1), Type::Record(fields2)) => { if fields1.len() == fields2.len() { fields1.iter().zip(fields2).fold( - TypeResult::success(), + TypeJudgeResult::success(), |acc, ((n1, t1), (n2, t2))| { if n1 == n2 { acc.and(t1.subtype(t2)) } else { - TypeResult::fail() + TypeJudgeResult::fail() } }, ) } else { - TypeResult::fail() + TypeJudgeResult::fail() } } (Type::Sum(types1), Type::Sum(types2)) => { - types1.iter().fold(TypeResult::success(), |acc, t1| { + types1.iter().fold(TypeJudgeResult::success(), |acc, t1| { acc.and( types2 .iter() - .fold(TypeResult::fail(), |acc, t2| acc.or(t1.subtype(t2))), + .fold(TypeJudgeResult::fail(), |acc, t2| acc.or(t1.subtype(t2))), ) }) } (_, Type::Sum(types)) => types .iter() - .fold(TypeResult::fail(), |acc, t| acc.or(self.subtype(t))), + .fold(TypeJudgeResult::fail(), |acc, t| acc.or(self.subtype(t))), (Type::Variant(parent1, name1, fields1), Type::Variant(parent2, name2, fields2)) => { parent1 .equals(parent2) @@ -353,7 +359,7 @@ impl TypeTrait for Type { fields1 .iter() .zip(fields2) - .fold(TypeResult::success(), |acc, (t1, t2)| { + .fold(TypeJudgeResult::success(), |acc, (t1, t2)| { acc.and(t1.subtype(t2)) }), ) @@ -379,6 +385,9 @@ impl TypeTrait for Type { } Type::Tuple(types) => Type::Tuple(types.into_iter().map(Type::simplify).collect()), Type::List(t) => Type::List(Box::new(t.simplify())), + Type::Map(key, value) => { + Type::Map(Box::new(key.simplify()), Box::new(value.simplify())) + } Type::Record(fields) => { Type::Record(fields.into_iter().map(|(n, t)| (n, t.simplify())).collect()) } @@ -429,6 +438,10 @@ impl TypeTrait for Type { .collect(), ), Type::List(t) => Type::List(Box::new(t.specialize(judgements, changed))), + Type::Map(key, value) => Type::Map( + Box::new(key.specialize(judgements, changed)), + Box::new(value.specialize(judgements, changed)), + ), Type::Record(fields) => Type::Record( fields .iter() @@ -474,6 +487,7 @@ impl Display for Type { write!(f, ")") } Type::List(t) => write!(f, "[{}]", t), + Type::Map(key, value) => write!(f, "Map({}, {})", key, value), Type::Record(fields) => { write!(f, "{{")?; for (i, (name, t)) in fields.iter().enumerate() { @@ -544,6 +558,9 @@ impl Type { } } Type::List(t) => format!("[{}]", t.pretty_print()), + Type::Map(key, value) => { + format!("Map({}, {})", key.pretty_print(), value.pretty_print()) + } Type::Record(fields) => { let mut result = String::new(); result.push('{'); @@ -626,6 +643,13 @@ impl Type { } } Type::List(t) => format!("[{}]", t.pretty_print_color()), + Type::Map(key, value) => { + format!( + "Map({}, {})", + key.pretty_print_color(), + value.pretty_print_color() + ) + } Type::Record(fields) => { let mut result = String::new(); result.push_str(&"{".dark_gray().to_string()); @@ -810,4 +834,33 @@ pub mod std_types { pub fn NUM() -> Type { Type::Alias(Str::Str("num"), Box::new(Type::Sum(vec![INT(), FLOAT()]))) } + + //---------------------------------------------------------------------------------------// + // Container types // + //---------------------------------------------------------------------------------------// + + /// A list type. + /// The type of a list of elements of type `T`. + #[allow(non_snake_case)] + pub fn LIST() -> Type { + Type::Constructor( + Str::Str("List"), + vec![Type::Variable(Str::Str("T"))], + Box::new(Type::List(Box::new(Type::Variable(Str::Str("T"))))), + ) + } + + /// A map type. + /// The type of a map with keys of type `K` and values of type `V`. + #[allow(non_snake_case)] + pub fn MAP() -> Type { + Type::Constructor( + Str::Str("Map"), + vec![Type::Variable(Str::Str("K")), Type::Variable(Str::Str("V"))], + Box::new(Type::Map( + Box::new(Type::Variable(Str::Str("K"))), + Box::new(Type::Variable(Str::Str("V"))), + )), + ) + } } diff --git a/core/src/util/error.rs b/core/src/util/error.rs index 2e513fd..5388749 100644 --- a/core/src/util/error.rs +++ b/core/src/util/error.rs @@ -1,4 +1,4 @@ -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct LocationInfo { pub index: usize, pub line: usize, @@ -37,7 +37,7 @@ impl Default for LocationInfo { } } -#[derive(Debug, Clone, Default, PartialEq, PartialOrd, Eq, Ord)] +#[derive(Debug, Clone, Default, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct LineInfo { pub start: LocationInfo, pub end: LocationInfo,