diff --git a/crates/deno_task_shell/src/grammar.pest b/crates/deno_task_shell/src/grammar.pest index abc6947..6289e84 100644 --- a/crates/deno_task_shell/src/grammar.pest +++ b/crates/deno_task_shell/src/grammar.pest @@ -202,27 +202,33 @@ bitwise_or = { "|" } logical_and = { "&&" } logical_or = { "||" } +unary_plus = { "+" } +unary_minus = { "-" } +logical_not = { "!" } +bitwise_not = { "~" } +increment = { "++" } +decrement = { "--" } + unary_arithmetic_expr = !{ - (unary_arithmetic_op | post_arithmetic_op) ~ (parentheses_expr | VARIABLE | NUMBER) | - (parentheses_expr | VARIABLE | NUMBER) ~ post_arithmetic_op + unary_pre_arithmetic_expr | unary_post_arithmetic_expr } -unary_arithmetic_op = _{ - unary_plus | unary_minus | logical_not | bitwise_not +unary_pre_arithmetic_expr = !{ + pre_arithmetic_op ~ (parentheses_expr | VARIABLE | NUMBER) } -unary_plus = { "+" } -unary_minus = { "-" } -logical_not = { "!" } -bitwise_not = { "~" } +unary_post_arithmetic_expr = !{ + (parentheses_expr | VARIABLE | NUMBER) ~ post_arithmetic_op +} + +pre_arithmetic_op= !{ + increment | decrement | unary_plus | unary_minus | logical_not | bitwise_not +} post_arithmetic_op = !{ increment | decrement } -increment = { "++" } -decrement = { "--" } - assignment_operator = _{ assign | multiply_assign | divide_assign | modulo_assign | add_assign | subtract_assign | left_shift_assign | right_shift_assign | bitwise_and_assign | bitwise_xor_assign | bitwise_or_assign diff --git a/crates/deno_task_shell/src/parser.rs b/crates/deno_task_shell/src/parser.rs index 2745ca0..9a3cbf8 100644 --- a/crates/deno_task_shell/src/parser.rs +++ b/crates/deno_task_shell/src/parser.rs @@ -431,11 +431,6 @@ pub enum ArithmeticPart { operator: UnaryArithmeticOp, operand: Box, }, - #[error("Invalid post arithmetic expression")] - PostArithmeticExpr { - operand: Box, - operator: PostArithmeticOp, - }, #[error("Invalid variable")] Variable(String), #[error("Invalid number")] @@ -481,7 +476,9 @@ pub enum AssignmentOp { #[cfg_attr(feature = "serialization", derive(serde::Serialize))] #[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] #[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum UnaryArithmeticOp { +pub enum PreArithmeticOp { + Increment, // ++ + Decrement, // -- Plus, // + Minus, // - LogicalNot, // ! @@ -490,12 +487,20 @@ pub enum UnaryArithmeticOp { #[cfg_attr(feature = "serialization", derive(serde::Serialize))] #[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Copy)] pub enum PostArithmeticOp { Increment, // ++ Decrement, // -- } +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum UnaryArithmeticOp { + Pre(PreArithmeticOp), + Post(PostArithmeticOp), +} + #[cfg_attr(feature = "serialization", derive(serde::Serialize))] #[cfg_attr( feature = "serialization", @@ -1419,56 +1424,123 @@ fn parse_arithmetic_expr(pair: Pair) -> Result { fn parse_unary_arithmetic_expr(pair: Pair) -> Result { let mut inner = pair.into_inner(); - let first = inner.next().unwrap(); + let first = inner + .next() + .ok_or_else(|| miette!("Expected unary operator"))?; match first.as_rule() { - Rule::unary_arithmetic_op => { - let op = parse_unary_arithmetic_op(first)?; - let operand = parse_arithmetic_expr(inner.next().unwrap())?; + Rule::unary_pre_arithmetic_expr => unary_pre_arithmetic_expr(first), + Rule::unary_post_arithmetic_expr => unary_post_arithmetic_expr(first), + _ => Err(miette!( + "Unexpected rule in unary arithmetic expression: {:?}", + first.as_rule() + )), + } +} + +fn unary_pre_arithmetic_expr(pair: Pair) -> Result { + let mut inner = pair.into_inner(); + let first = inner + .next() + .ok_or_else(|| miette!("Expected unary pre operator"))?; + let second = inner.next().ok_or_else(|| miette!("Expected operand"))?; + let operand = match second.as_rule() { + Rule::parentheses_expr => { + let inner = second.into_inner().next().unwrap(); + let parts = parse_arithmetic_sequence(inner)?; + Ok(ArithmeticPart::ParenthesesExpr(Box::new(Arithmetic { + parts, + }))) + } + Rule::VARIABLE => Ok(ArithmeticPart::Variable(second.as_str().to_string())), + Rule::NUMBER => Ok(ArithmeticPart::Number(second.as_str().to_string())), + _ => Err(miette!( + "Unexpected rule in arithmetic expression: {:?}", + second.as_rule() + )), + }?; + + match first.as_rule() { + Rule::pre_arithmetic_op => { + let op = parse_pre_arithmetic_op(first)?; Ok(ArithmeticPart::UnaryArithmeticExpr { - operator: op, + operator: UnaryArithmeticOp::Pre(op), operand: Box::new(operand), }) } Rule::post_arithmetic_op => { - let operand = parse_arithmetic_expr(inner.next().unwrap())?; let op = parse_post_arithmetic_op(first)?; - Ok(ArithmeticPart::PostArithmeticExpr { - operand: Box::new(operand), - operator: op, - }) - } - _ => { - let operand = parse_arithmetic_expr(first)?; - let op = parse_post_arithmetic_op(inner.next().unwrap())?; - Ok(ArithmeticPart::PostArithmeticExpr { + Ok(ArithmeticPart::UnaryArithmeticExpr { + operator: UnaryArithmeticOp::Post(op), operand: Box::new(operand), - operator: op, }) } + _ => Err(miette!( + "Unexpected rule in unary arithmetic operator: {:?}", + first.as_rule() + )), } } -fn parse_unary_arithmetic_op(pair: Pair) -> Result { - match pair.as_str() { - "+" => Ok(UnaryArithmeticOp::Plus), - "-" => Ok(UnaryArithmeticOp::Minus), - "!" => Ok(UnaryArithmeticOp::LogicalNot), - "~" => Ok(UnaryArithmeticOp::BitwiseNot), +fn unary_post_arithmetic_expr(pair: Pair) -> Result { + let mut inner = pair.into_inner(); + let first = inner + .next() + .ok_or_else(|| miette!("Expected unary post operator"))?; + let second = inner.next().ok_or_else(|| miette!("Expected operand"))?; + + let operand = match first.as_rule() { + Rule::parentheses_expr => { + let inner = first.into_inner().next().unwrap(); + let parts = parse_arithmetic_sequence(inner)?; + Ok(ArithmeticPart::ParenthesesExpr(Box::new(Arithmetic { + parts, + }))) + } + Rule::VARIABLE => Ok(ArithmeticPart::Variable(first.as_str().to_string())), + Rule::NUMBER => Ok(ArithmeticPart::Number(first.as_str().to_string())), + _ => Err(miette!( + "Unexpected rule in arithmetic expression: {:?}", + first.as_rule() + )), + }?; + let op = parse_post_arithmetic_op(second)?; + Ok(ArithmeticPart::UnaryArithmeticExpr { + operator: UnaryArithmeticOp::Post(op), + operand: Box::new(operand), + }) +} + +fn parse_pre_arithmetic_op(pair: Pair) -> Result { + let first = pair + .into_inner() + .next() + .ok_or_else(|| miette!("Expected increment or decrement operator"))?; + match first.as_rule() { + Rule::increment => Ok(PreArithmeticOp::Increment), + Rule::decrement => Ok(PreArithmeticOp::Decrement), + Rule::unary_plus => Ok(PreArithmeticOp::Plus), + Rule::unary_minus => Ok(PreArithmeticOp::Minus), + Rule::logical_not => Ok(PreArithmeticOp::LogicalNot), + Rule::bitwise_not => Ok(PreArithmeticOp::BitwiseNot), _ => Err(miette!( - "Invalid unary arithmetic operator: {}", - pair.as_str() + "Unexpected rule in pre arithmetic operator: {:?}", + first.as_rule() )), } } fn parse_post_arithmetic_op(pair: Pair) -> Result { - match pair.as_str() { - "++" => Ok(PostArithmeticOp::Increment), - "--" => Ok(PostArithmeticOp::Decrement), + let first = pair + .into_inner() + .next() + .ok_or_else(|| miette!("Expected increment or decrement operator"))?; + match first.as_rule() { + Rule::increment => Ok(PostArithmeticOp::Increment), + Rule::decrement => Ok(PostArithmeticOp::Decrement), _ => Err(miette!( - "Invalid post arithmetic operator: {}", - pair.as_str() + "Unexpected rule in post arithmetic operator: {:?}", + first.as_rule() )), } } diff --git a/crates/deno_task_shell/src/shell/execute.rs b/crates/deno_task_shell/src/shell/execute.rs index c17c39e..52ce617 100644 --- a/crates/deno_task_shell/src/shell/execute.rs +++ b/crates/deno_task_shell/src/shell/execute.rs @@ -12,47 +12,23 @@ use thiserror::Error; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; -use crate::parser::AssignmentOp; -use crate::parser::BinaryOp; -use crate::parser::Condition; -use crate::parser::ConditionInner; -use crate::parser::ElsePart; -use crate::parser::IoFile; -use crate::parser::RedirectOpInput; -use crate::parser::RedirectOpOutput; -use crate::parser::UnaryOp; -use crate::shell::commands::ShellCommand; -use crate::shell::commands::ShellCommandContext; -use crate::shell::types::pipe; -use crate::shell::types::ArithmeticResult; -use crate::shell::types::ArithmeticValue; -use crate::shell::types::EnvChange; -use crate::shell::types::ExecuteResult; -use crate::shell::types::FutureExecuteResult; -use crate::shell::types::ShellPipeReader; -use crate::shell::types::ShellPipeWriter; -use crate::shell::types::ShellState; - -use crate::parser::Arithmetic; -use crate::parser::ArithmeticPart; -use crate::parser::BinaryArithmeticOp; -use crate::parser::Command; -use crate::parser::CommandInner; -use crate::parser::IfClause; -use crate::parser::PipeSequence; -use crate::parser::PipeSequenceOperator; -use crate::parser::Pipeline; -use crate::parser::PipelineInner; -use crate::parser::Redirect; -use crate::parser::RedirectFd; -use crate::parser::RedirectOp; -use crate::parser::Sequence; -use crate::parser::SequentialList; -use crate::parser::SimpleCommand; -use crate::parser::UnaryArithmeticOp; -use crate::parser::Word; -use crate::parser::WordPart; -use crate::shell::types::WordEvalResult; +use crate::parser::{ + AssignmentOp, BinaryOp, Condition, ConditionInner, ElsePart, IoFile, + RedirectOpInput, RedirectOpOutput, UnaryArithmeticOp, UnaryOp, +}; +use crate::shell::commands::{ShellCommand, ShellCommandContext}; +use crate::shell::types::{ + pipe, ArithmeticResult, ArithmeticValue, EnvChange, ExecuteResult, + FutureExecuteResult, ShellPipeReader, ShellPipeWriter, ShellState, + WordEvalResult, +}; + +use crate::parser::{ + Arithmetic, ArithmeticPart, BinaryArithmeticOp, Command, CommandInner, + IfClause, PipeSequence, PipeSequenceOperator, Pipeline, PipelineInner, + Redirect, RedirectFd, RedirectOp, Sequence, SequentialList, SimpleCommand, + Word, WordPart, +}; use super::command::execute_unresolved_command_name; use super::command::UnresolvedCommandName; @@ -598,7 +574,10 @@ async fn evaluate_arithmetic_part( }? } }; - state.apply_env_var(name, &applied_value.to_string()); + state.apply_change(&EnvChange::SetShellVar( + (&name).to_string(), + applied_value.value.to_string(), + )); Ok( applied_value .clone() @@ -640,11 +619,7 @@ async fn evaluate_arithmetic_part( } ArithmeticPart::UnaryArithmeticExpr { operator, operand } => { let val = Box::pin(evaluate_arithmetic_part(operand, state)).await?; - apply_unary_op(*operator, val) - } - ArithmeticPart::PostArithmeticExpr { operand, .. } => { - let val = Box::pin(evaluate_arithmetic_part(operand, state)).await?; - Ok(val) + apply_unary_op(state, *operator, val, operand) } ArithmeticPart::Variable(name) => state .get_var(name) @@ -728,19 +703,15 @@ fn apply_conditional_binary_op( } fn apply_unary_op( + state: &mut ShellState, op: UnaryArithmeticOp, val: ArithmeticResult, + operand: &ArithmeticPart, ) -> Result { - match op { - UnaryArithmeticOp::Plus => Ok(val), - UnaryArithmeticOp::Minus => val.checked_neg(), - UnaryArithmeticOp::LogicalNot => Ok(if val.is_zero() { - ArithmeticResult::new(ArithmeticValue::Integer(1)) - } else { - ArithmeticResult::new(ArithmeticValue::Integer(0)) - }), - UnaryArithmeticOp::BitwiseNot => val.checked_not(), - } + let result = val.unary_op(operand, op)?; + let result_clone = result.clone(); + state.apply_changes(&result_clone.changes); + Ok(result) } async fn execute_pipe_sequence( @@ -1349,6 +1320,7 @@ fn evaluate_word_parts( if !current_text.is_empty() { result.extend(evaluate_word_text(state, current_text, is_quoted)?); } + result.with_changes(changes); Ok(result) } .boxed_local() diff --git a/crates/deno_task_shell/src/shell/types.rs b/crates/deno_task_shell/src/shell/types.rs index 3c87caa..f8db925 100644 --- a/crates/deno_task_shell/src/shell/types.rs +++ b/crates/deno_task_shell/src/shell/types.rs @@ -13,12 +13,17 @@ use std::rc::Rc; use std::str::FromStr; use futures::future::LocalBoxFuture; +use miette::miette; use miette::Error; use miette::IntoDiagnostic; use miette::Result; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; +use crate::parser::ArithmeticPart; +use crate::parser::PostArithmeticOp; +use crate::parser::PreArithmeticOp; +use crate::parser::UnaryArithmeticOp; use crate::shell::fs_util; use super::commands::builtin_commands; @@ -579,6 +584,141 @@ impl ArithmeticResult { } } + pub fn unary_op( + &self, + operand: &ArithmeticPart, + op: UnaryArithmeticOp, + ) -> Result { + match op { + UnaryArithmeticOp::Post(op_type) => match &self.value { + ArithmeticValue::Integer(val) => match operand { + ArithmeticPart::Variable(name) => { + let mut new_changes = self.changes.clone(); + new_changes.push(EnvChange::SetShellVar( + name.to_string(), + match op_type { + PostArithmeticOp::Increment => (*val + 1).to_string(), + PostArithmeticOp::Decrement => (*val - 1).to_string(), + }, + )); + Ok(ArithmeticResult { + value: ArithmeticValue::Integer(*val), + changes: new_changes, + }) + } + _ => Err(miette!( + "Invalid arithmetic result type for post-increment: {}", + self + )), + }, + ArithmeticValue::Float(val) => match operand { + ArithmeticPart::Variable(name) => { + let mut new_changes = self.changes.clone(); + new_changes.push(EnvChange::SetShellVar( + name.to_string(), + match op_type { + PostArithmeticOp::Increment => (*val + 1.0).to_string(), + PostArithmeticOp::Decrement => (*val - 1.0).to_string(), + }, + )); + Ok(ArithmeticResult { + value: ArithmeticValue::Float(*val), + changes: new_changes, + }) + } + _ => Err(miette!( + "Invalid arithmetic result type for post-increment: {}", + self + )), + }, + }, + UnaryArithmeticOp::Pre(op_type) => match &self.value { + ArithmeticValue::Integer(val) => match operand { + ArithmeticPart::Variable(name) => { + let mut new_changes = self.changes.clone(); + if op_type == PreArithmeticOp::Increment + || op_type == PreArithmeticOp::Decrement + { + new_changes.push(EnvChange::SetShellVar( + name.to_string(), + match op_type { + PreArithmeticOp::Increment => (*val + 1).to_string(), + PreArithmeticOp::Decrement => (*val - 1).to_string(), + _ => Err(miette!("No change to ENV need for: {}", self))?, + }, + )); + } + + Ok(ArithmeticResult { + value: match op_type { + PreArithmeticOp::Increment => { + ArithmeticValue::Integer(*val + 1) + } + PreArithmeticOp::Decrement => { + ArithmeticValue::Integer(*val - 1) + } + PreArithmeticOp::Plus => ArithmeticValue::Integer((*val).abs()), + PreArithmeticOp::Minus => { + ArithmeticValue::Integer(-(*val).abs()) + } + PreArithmeticOp::BitwiseNot => ArithmeticValue::Integer(!*val), + PreArithmeticOp::LogicalNot => { + ArithmeticValue::Integer(if *val == 0 { 1 } else { 0 }) + } + }, + changes: new_changes, + }) + } + _ => Err(miette!( + "Invalid arithmetic result type for pre-increment: {}", + self + )), + }, + ArithmeticValue::Float(val) => match operand { + ArithmeticPart::Variable(name) => { + let mut new_changes = self.changes.clone(); + if op_type == PreArithmeticOp::Increment + || op_type == PreArithmeticOp::Decrement + { + new_changes.push(EnvChange::SetShellVar( + name.to_string(), + match op_type { + PreArithmeticOp::Increment => (*val + 1.0).to_string(), + PreArithmeticOp::Decrement => (*val - 1.0).to_string(), + _ => Err(miette!("No change to ENV need for: {}", self))?, + }, + )); + } + + Ok(ArithmeticResult { + value: match op_type { + PreArithmeticOp::Increment => { + ArithmeticValue::Float(*val + 1.0) + } + PreArithmeticOp::Decrement => { + ArithmeticValue::Float(*val - 1.0) + } + PreArithmeticOp::Plus => ArithmeticValue::Float((*val).abs()), + PreArithmeticOp::Minus => ArithmeticValue::Float(-(*val).abs()), + PreArithmeticOp::BitwiseNot => { + ArithmeticValue::Integer(!(*val as i64)) + } + PreArithmeticOp::LogicalNot => { + ArithmeticValue::Float(if *val == 0.0 { 1.0 } else { 0.0 }) + } + }, + changes: new_changes, + }) + } + _ => Err(miette!( + "Invalid arithmetic result type for pre-increment: {}", + self + )), + }, + }, + } + } + pub fn checked_add( &self, other: &ArithmeticResult, @@ -880,45 +1020,6 @@ impl ArithmeticResult { }) } - pub fn checked_neg(&self) -> Result { - let result = match &self.value { - ArithmeticValue::Integer(val) => val - .checked_neg() - .map(ArithmeticValue::Integer) - .ok_or_else(|| miette::miette!("Integer overflow: -{}", val))?, - ArithmeticValue::Float(val) => { - let result = -val; - if result.is_finite() { - ArithmeticValue::Float(result) - } else { - return Err(miette::miette!("Float overflow: -{}", val)); - } - } - }; - - Ok(ArithmeticResult { - value: result, - changes: self.changes.clone(), - }) - } - - pub fn checked_not(&self) -> Result { - let result = match &self.value { - ArithmeticValue::Integer(val) => ArithmeticValue::Integer(!val), - ArithmeticValue::Float(_) => { - return Err(miette::miette!( - "Invalid arithmetic result type for bitwise NOT: {}", - self - )) - } - }; - - Ok(ArithmeticResult { - value: result, - changes: self.changes.clone(), - }) - } - pub fn checked_shl( &self, other: &ArithmeticResult, @@ -1117,4 +1218,8 @@ impl WordEvalResult { pub fn join(&self, sep: &str) -> String { self.value.join(sep) } + + pub fn with_changes(&mut self, changes: Vec) { + self.changes.extend(changes); + } } diff --git a/crates/tests/src/lib.rs b/crates/tests/src/lib.rs index 45ab92e..6c73b7e 100644 --- a/crates/tests/src/lib.rs +++ b/crates/tests/src/lib.rs @@ -880,6 +880,54 @@ async fn arithmetic() { .assert_stdout("16\n") .run() .await; + + TestBuilder::new() + .command("echo $((a=1, ++a))") + .assert_stdout("2\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=1, a++))") + .assert_stdout("1\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=1, a--))") + .assert_stdout("1\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=1, --a))") + .assert_stdout("0\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=1, +a))") + .assert_stdout("1\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=1, -a))") + .assert_stdout("-1\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=3, ~a))") + .assert_stdout("-4\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=0, !a))") + .assert_stdout("1\n") + .run() + .await; } #[tokio::test] diff --git a/scripts/arithmetic.sh b/scripts/arithmetic.sh index 50c0752..1596791 100644 --- a/scripts/arithmetic.sh +++ b/scripts/arithmetic.sh @@ -1 +1,13 @@ -echo $((2 ** 3)) \ No newline at end of file +echo $((2 ** 3)) + +a=1 + +echo $a + +echo $((++a)) + +echo $a + +echo $((a--)) + +echo $a