From a0550b62416f9acac5a49e435a1544f5b3ffb911 Mon Sep 17 00:00:00 2001 From: Joe Neeman Date: Tue, 10 Sep 2024 14:49:07 +0700 Subject: [PATCH] Add a custom_transform method --- cli/src/cli.rs | 5 +- cli/src/doctest.rs | 127 ++++++++++++++---- ..._eval_stderr_array_at_empty_array.ncl.snap | 4 +- ...eval_stderr_array_at_out_of_bound.ncl.snap | 4 +- ...derr_array_range_reversed_indices.ncl.snap | 4 +- ...rr_array_range_step_negative_step.ncl.snap | 4 +- ..._stderr_caller_contract_violation.ncl.snap | 4 +- core/src/cache.rs | 49 +++++-- core/src/eval/mod.rs | 2 +- core/src/program.rs | 20 ++- doc/manual/typing.md | 4 +- 11 files changed, 169 insertions(+), 58 deletions(-) diff --git a/cli/src/cli.rs b/cli/src/cli.rs index dbc0e64849..37339f2a40 100644 --- a/cli/src/cli.rs +++ b/cli/src/cli.rs @@ -13,10 +13,7 @@ use nickel_lang_core::error::report::ErrorFormat; use crate::repl::ReplCommand; #[cfg(feature = "doc")] -use crate::doc::DocCommand; - -#[cfg(feature = "doc")] -use crate::doctest::TestCommand; +use crate::{doc::DocCommand, doctest::TestCommand}; #[cfg(feature = "format")] use crate::format::FormatCommand; diff --git a/cli/src/doctest.rs b/cli/src/doctest.rs index fe2599db14..aed075a3b8 100644 --- a/cli/src/doctest.rs +++ b/cli/src/doctest.rs @@ -1,14 +1,23 @@ -use std::{collections::HashMap, io::Write as _}; +//! The `nickel test` command. +//! +//! Extracts tests from docstrings and evaluates them, printing out any failures. 
+ +use std::{collections::HashMap, io::Write as _, path::PathBuf, rc::Rc}; use comrak::{arena_tree::NodeEdge, nodes::AstNode, Arena, ComrakOptions}; use nickel_lang_core::{ - cache::{Cache, EntryState, SourcePath}, + cache::{Cache, ImportResolver, InputFormat, SourcePath}, error::{Error as CoreError, EvalError, IntoDiagnostics}, eval::{cache::CacheImpl, Closure, Environment}, identifier::{Ident, LocIdent}, + label::Label, match_sharedterm, mk_app, mk_fun, program::Program, - term::{record::RecordData, RichTerm, Term, Traverse as _, TraverseOrder}, + term::{ + make, record::RecordData, LabeledType, RichTerm, Term, Traverse as _, TraverseOrder, + TypeAnnotation, + }, + typ::{Type, TypeF}, }; use once_cell::sync::Lazy; use regex::Regex; @@ -26,10 +35,18 @@ pub struct TestCommand { pub input: InputOptions, } +/// The expected outcome of a test. #[derive(Debug)] enum Expected { + /// The test is expected to evaluate (without errors) to a specific value. + /// + /// The string here will be parsed into a nickel term, and then wrapped in a `std.contract.Equal` + /// contract to provide a nice error message. Value(String), + /// The test is expected to raise an error, and the error message is expected to contain + /// this string as a substring. Error(String), + /// The test is expected to evaluate without errors, but we don't care what it evaluates to. None, } @@ -114,29 +131,22 @@ impl DocTest { let expected = Expected::extract(&input); DocTest { input, expected } } - - fn code(&self) -> String { - match &self.expected { - Expected::Value(v) => format!("(({}\n)| std.contract.Equal ({v}))", self.input), - _ => self.input.clone(), - } - } } struct Error { + /// The record path to the field whose doctest triggered this error. path: Vec, + /// The field whose doctest triggered this error might have multiple tests in its + /// doc metadata. This is the index of the failing test. 
idx: usize, kind: ErrorKind, } enum ErrorKind { - UnexpectedFailure { - error: EvalError, - }, + /// A doctest was expected to succeed, but it failed. + UnexpectedFailure { error: EvalError }, /// A doctest was expected to fail, but instead it succeeded. - UnexpectedSuccess { - result: RichTerm, - }, + UnexpectedSuccess { result: RichTerm }, /// A doctest failed with an unexpected message. WrongTestFailure { messages: Vec, @@ -146,9 +156,9 @@ enum ErrorKind { // Go through the record spine, running tests one-by-one. // -// `spine` is the already-evaluated record spine, which we use only for finding the path to -// each test. We don't use `spine` for evaluating the test because `RecRecord`s have already -// been evaluated to `Record`s so it's challenging to reconstruct the environment. +// `spine` is the already-evaluated record spine. It was previously transformed +// with [`doctest_transform`], so all the tests are present in the record spine. +// They've already been closurized with the correct environment. fn run_tests( path: &mut Vec, prog: &mut Program, @@ -278,14 +288,15 @@ impl TestCommand { program: &mut Program, ) -> Result<(RichTerm, TestRegistry), CoreError> { let mut registry = TestRegistry::default(); - let term = program.parse()?; - let cache = &mut program.vm.import_resolver; - let transformed = doctest_transform(cache, &mut registry, term); - cache.set(program.main_id, transformed.clone()?, EntryState::Parsed); + program.typecheck()?; + program + .custom_transform(|cache, rt| doctest_transform(cache, &mut registry, rt)) + .map_err(|e| e.unwrap_error("transforming doctest"))?; Ok((program.eval_record_spine()?, registry)) } } +/// Extract all the nickel code blocks from a single doc comment. fn nickel_code_blocks<'a>(document: &'a AstNode<'a>) -> Vec { use comrak::arena_tree::Node; use comrak::nodes::{Ast, NodeCodeBlock, NodeValue}; @@ -351,6 +362,9 @@ fn nickel_code_blocks<'a>(document: &'a AstNode<'a>) -> Vec { // field that declares it. 
We wrap the test in a function so that it doesn't get // evaluated too soon. // +// The generated test field ids (i.e. `%0` in the example above) are collected +// in `registry` so that a later pass can go through and evaluate them. +// // One disadvantage with this traversal approach is that any parse errors in // the test will be encountered as soon as we explore the record spine. We might // prefer to delay parsing the tests until it's time to evaluate them. @@ -361,6 +375,38 @@ fn doctest_transform( registry: &mut TestRegistry, rt: RichTerm, ) -> Result { + // Get the path of the current term, so we can pretend that test snippets + // came from the same path. This allows imports to work. + let path = rt + .pos + .as_opt_ref() + .and_then(|sp| cache.get_path(sp.src_id)) + .map(PathBuf::from); + + let source_path = match path { + Some(p) => SourcePath::Snippet(p), + None => SourcePath::Generated("test".to_owned()), + }; + + // Prepare a test snippet. Skips typechecking and transformations, because + // the returned term will get inserted into a bigger term that will be + // typechecked and transformed. + fn prepare( + cache: &mut Cache, + input: &str, + source_path: &SourcePath, + ) -> Result { + let src_id = cache.add_string(source_path.clone(), input.to_owned()); + cache.parse(src_id, InputFormat::Nickel)?; + // We could probably skip import resolution here also, but `Cache::get` insists + // that imports be resolved. 
+ cache + .resolve_imports(src_id) + .map_err(|e| e.unwrap_error("test snippet"))?; + // unwrap: we just populated it + Ok(cache.get(src_id).unwrap()) + } + let mut record_with_doctests = |mut record_data: RecordData, dyn_fields, pos| -> Result<_, CoreError> { let mut doc_fields: Vec<(Ident, RichTerm)> = Vec::new(); @@ -374,11 +420,36 @@ fn doctest_transform( )); for (i, snippet) in snippets.iter().enumerate() { - let src_id = cache - .add_string(SourcePath::Generated("test".to_owned()), snippet.code()); - let (test_term, errors) = cache.parse_nocache(src_id)?; - if !errors.errors.is_empty() { - return Err(errors.into()); + let mut test_term = prepare(cache, &snippet.input, &source_path)?; + + if let Expected::Value(s) = &snippet.expected { + // Create the contract `std.contract.Equal ` and apply it to the + // test term. + let expected_term = prepare(cache, s, &source_path)?; + // unwrap: we just parsed it, so it will have a span + let expected_span = expected_term.pos.into_opt().unwrap(); + + let eq = make::static_access( + RichTerm::from(Term::Var("std".into())), + ["contract", "Equal"], + ); + let eq = mk_app!(eq, expected_term); + let eq_ty = Type::from(TypeF::Flat(eq)); + test_term = Term::Annotated( + TypeAnnotation { + typ: None, + contracts: vec![LabeledType { + typ: eq_ty.clone(), + label: Label { + typ: Rc::new(eq_ty), + span: expected_span, + ..Default::default() + }, + }], + }, + test_term, + ) + .into(); } // Make the test term lazy, so that the tests don't automatically get evaluated diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_empty_array.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_empty_array.ncl.snap index ab6e249b8b..65b7d36f63 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_empty_array.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_empty_array.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `at` invalid 
array indexing - ┌─ :164:9 + ┌─ :165:9 │ -164 │ | std.contract.unstable.IndexedArrayFun 'Index +165 │ | std.contract.unstable.IndexedArrayFun 'Index │ -------------------------------------------- expected type │ ┌─ [INPUTS_PATH]/errors/array_at_empty_array.ncl:3:16 diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_out_of_bound.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_out_of_bound.ncl.snap index 1f42c95a48..454d46b5cf 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_out_of_bound.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_out_of_bound.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `at` invalid array indexing - ┌─ :164:9 + ┌─ :165:9 │ -164 │ | std.contract.unstable.IndexedArrayFun 'Index +165 │ | std.contract.unstable.IndexedArrayFun 'Index │ -------------------------------------------- expected type │ ┌─ [INPUTS_PATH]/errors/array_at_out_of_bound.ncl:3:16 diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_reversed_indices.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_reversed_indices.ncl.snap index a06c9f1260..7408ee1e51 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_reversed_indices.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_reversed_indices.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `range` invalid range - ┌─ :755:9 + ┌─ :769:9 │ -755 │ | std.contract.unstable.RangeFun Dyn +769 │ | std.contract.unstable.RangeFun Dyn │ ---------------------------------- expected type │ ┌─ [INPUTS_PATH]/errors/array_range_reversed_indices.ncl:3:19 diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_step_negative_step.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_step_negative_step.ncl.snap index a0750a16f1..ad589a24e2 100644 --- 
a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_step_negative_step.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_step_negative_step.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `range_step` invalid range step - ┌─ :730:9 + ┌─ :744:9 │ -730 │ | std.contract.unstable.RangeFun (std.contract.unstable.RangeStep -> Dyn) +744 │ | std.contract.unstable.RangeFun (std.contract.unstable.RangeStep -> Dyn) │ ----------------------------------------------------------------------- expected type │ ┌─ [INPUTS_PATH]/errors/array_range_step_negative_step.ncl:3:27 diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_caller_contract_violation.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_caller_contract_violation.ncl.snap index e16b46d0ff..8a26cdc5dd 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_caller_contract_violation.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_caller_contract_violation.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `map` expected an array - ┌─ :148:33 + ┌─ :149:33 │ -148 │ : forall a b. (a -> b) -> Array a -> Array b +149 │ : forall a b. (a -> b) -> Array a -> Array b │ ------- expected type of the argument provided by the caller │ ┌─ [INPUTS_PATH]/errors/caller_contract_violation.ncl:3:31 diff --git a/core/src/cache.rs b/core/src/cache.rs index bfe8e7792d..0f0046f00e 100644 --- a/core/src/cache.rs +++ b/core/src/cache.rs @@ -666,6 +666,43 @@ impl Cache { } } + /// Applies a custom transform to an input and its imports, leaving them + /// in the same state as before. Requires that the input has been parsed. + /// In order for the transform to apply to imports, they need to have been + /// resolved. 
+ pub fn custom_transform( + &mut self, + file_id: FileId, + transform: &mut impl FnMut(&mut Cache, RichTerm) -> Result, + ) -> Result<(), CacheError> { + match self.entry_state(file_id) { + Some(state) if state >= EntryState::Parsed => { + if state < EntryState::Transforming { + let cached_term = self.terms.remove(&file_id).unwrap(); + let term = transform(self, cached_term.term)?; + self.terms.insert( + file_id, + TermEntry { + term, + state: EntryState::Transforming, + ..cached_term + }, + ); + + if let Some(imports) = self.imports.get(&file_id).cloned() { + for f in imports.into_iter() { + self.custom_transform(f, transform)?; + } + } + // TODO: We're setting the state back to whatever it was. + self.update_state(file_id, state); + } + Ok(()) + } + _ => Err(CacheError::NotParsed), + } + } + /// Apply program transformations to all the fields of a record. /// /// Used to transform stdlib modules and other records loaded in the environment, when using @@ -1068,18 +1105,6 @@ impl Cache { self.terms.get(&file_id).map(|TermEntry { term, .. }| term) } - /// Set a new value for a cached term. - pub fn set(&mut self, file_id: FileId, term: RichTerm, state: EntryState) { - self.terms.insert( - file_id, - TermEntry { - term, - state, - parse_errs: Default::default(), - }, - ); - } - /// Returns true if a particular file id represents a Nickel standard library file, false /// otherwise. pub fn is_stdlib_module(&self, file: FileId) -> bool { diff --git a/core/src/eval/mod.rs b/core/src/eval/mod.rs index 027c3d8089..93c0da0441 100644 --- a/core/src/eval/mod.rs +++ b/core/src/eval/mod.rs @@ -124,7 +124,7 @@ pub struct VirtualMachine { // The call stack, for error reporting. call_stack: CallStack, // The interface used to fetch imports. - pub import_resolver: R, + import_resolver: R, // The evaluation cache. 
pub cache: C, // The initial environment containing stdlib and builtin functions accessible from anywhere diff --git a/core/src/program.rs b/core/src/program.rs index 8f6c6d41b9..07a19cdc35 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -177,7 +177,7 @@ impl FieldOverride { /// code of imported expressions, and a dictionary which stores corresponding parsed terms. pub struct Program { /// The id of the program source in the file database. - pub main_id: FileId, + main_id: FileId, /// The state of the Nickel virtual machine. pub vm: VirtualMachine, /// The color option to use when reporting errors. @@ -386,6 +386,24 @@ impl Program { .clone()) } + /// Applies a custom transformation to the main term, assuming that it has been parsed but not + /// yet transformed. + /// + /// The term is left in whatever state it was in before the call, so it will be transformed as + /// usual after this custom transformation. + /// + /// This state-management isn't great, as it breaks the usual linear order of state changes. + /// In particular, there's no protection against double-applying the same transformation, and if + /// the term has already reached the `Transforming` state then this function will do nothing. + pub fn custom_transform(&mut self, mut transform: F) -> Result<(), CacheError> + where + F: FnMut(&mut Cache, RichTerm) -> Result, + { + self.vm + .import_resolver_mut() + .custom_transform(self.main_id, &mut transform) + } + /// Retrieve the parsed term, typecheck it, and generate a fresh initial environment. If /// `self.overrides` isn't empty, generate the required merge parts and return a merge /// expression including the overrides. 
Extract the field corresponding to `self.field`, if not diff --git a/doc/manual/typing.md b/doc/manual/typing.md index b42fbd6a82..977973e986 100644 --- a/doc/manual/typing.md +++ b/doc/manual/typing.md @@ -688,9 +688,9 @@ calling to the statically typed `std.array.filter` from dynamically typed code: ```nickel #repl > std.array.filter (fun x => if x % 2 == 0 then x else null) [1,2,3,4,5,6] error: contract broken by the caller of `filter` - ┌─ :427:25 + ┌─ :431:25 │ -427 │ : forall a. (a -> Bool) -> Array a -> Array a +431 │ : forall a. (a -> Bool) -> Array a -> Array a │ ---- expected return type of a function provided by the caller │ ┌─ :1:55