Skip to content

Commit

Permalink
Add a custom_transform method
Browse files Browse the repository at this point in the history
  • Loading branch information
jneem committed Sep 10, 2024
1 parent 4e003d4 commit a0550b6
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 58 deletions.
5 changes: 1 addition & 4 deletions cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ use nickel_lang_core::error::report::ErrorFormat;
use crate::repl::ReplCommand;

#[cfg(feature = "doc")]
use crate::doc::DocCommand;

#[cfg(feature = "doc")]
use crate::doctest::TestCommand;
use crate::{doc::DocCommand, doctest::TestCommand};

#[cfg(feature = "format")]
use crate::format::FormatCommand;
Expand Down
127 changes: 99 additions & 28 deletions cli/src/doctest.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
use std::{collections::HashMap, io::Write as _};
//! The `nickel test` command.
//!
//! Extracts tests from docstrings and evaluates them, printing out any failures.

use std::{collections::HashMap, io::Write as _, path::PathBuf, rc::Rc};

use comrak::{arena_tree::NodeEdge, nodes::AstNode, Arena, ComrakOptions};
use nickel_lang_core::{
cache::{Cache, EntryState, SourcePath},
cache::{Cache, ImportResolver, InputFormat, SourcePath},
error::{Error as CoreError, EvalError, IntoDiagnostics},
eval::{cache::CacheImpl, Closure, Environment},
identifier::{Ident, LocIdent},
label::Label,
match_sharedterm, mk_app, mk_fun,
program::Program,
term::{record::RecordData, RichTerm, Term, Traverse as _, TraverseOrder},
term::{
make, record::RecordData, LabeledType, RichTerm, Term, Traverse as _, TraverseOrder,
TypeAnnotation,
},
typ::{Type, TypeF},
};
use once_cell::sync::Lazy;
use regex::Regex;
Expand All @@ -26,10 +35,18 @@ pub struct TestCommand {
pub input: InputOptions<ExtractFieldOnly>,
}

/// The expected outcome of a test.
#[derive(Debug)]
enum Expected {
/// The test is expected to evaluate (without errors) to a specific value.
///
/// The string here will be parsed into a nickel term, and then wrapped in a `std.contract.Equal`
/// contract to provide a nice error message.
Value(String),
/// The test is expected to raise an error, and the error message is expected to contain
/// this string as a substring.
Error(String),
/// The test is expected to evaluate without errors, but we don't care what it evaluates to.
None,
}

Expand Down Expand Up @@ -114,29 +131,22 @@ impl DocTest {
let expected = Expected::extract(&input);
DocTest { input, expected }
}

fn code(&self) -> String {
match &self.expected {
Expected::Value(v) => format!("(({}\n)| std.contract.Equal ({v}))", self.input),
_ => self.input.clone(),
}
}
}

struct Error {
/// The record path to the field whose doctest triggered this error.
path: Vec<LocIdent>,
/// The field whose doctest triggered this error might have multiple tests in its
/// doc metadata. This is the index of the failing test.
idx: usize,
kind: ErrorKind,
}

enum ErrorKind {
UnexpectedFailure {
error: EvalError,
},
/// A doctest was expected to succeed, but it failed.
UnexpectedFailure { error: EvalError },
/// A doctest was expected to fail, but instead it succeeded.
UnexpectedSuccess {
result: RichTerm,
},
UnexpectedSuccess { result: RichTerm },
/// A doctest failed with an unexpected message.
WrongTestFailure {
messages: Vec<String>,
Expand All @@ -146,9 +156,9 @@ enum ErrorKind {

// Go through the record spine, running tests one-by-one.
//
// `spine` is the already-evaluated record spine, which we use only for finding the path to
// each test. We don't use `spine` for evaluating the test because `RecRecord`s have already
// been evaluated to `Record`s so it's challenging to reconstruct the environment.
// `spine` is the already-evaluated record spine. It was previously transformed
// with [`doctest_transform`], so all the tests are present in the record spine.
// They've already been closurized with the correct environment.
fn run_tests(
path: &mut Vec<LocIdent>,
prog: &mut Program<CacheImpl>,
Expand Down Expand Up @@ -278,14 +288,15 @@ impl TestCommand {
program: &mut Program<CacheImpl>,
) -> Result<(RichTerm, TestRegistry), CoreError> {
let mut registry = TestRegistry::default();
let term = program.parse()?;
let cache = &mut program.vm.import_resolver;
let transformed = doctest_transform(cache, &mut registry, term);
cache.set(program.main_id, transformed.clone()?, EntryState::Parsed);
program.typecheck()?;
program
.custom_transform(|cache, rt| doctest_transform(cache, &mut registry, rt))
.map_err(|e| e.unwrap_error("transforming doctest"))?;
Ok((program.eval_record_spine()?, registry))
}
}

/// Extract all the nickel code blocks from a single doc comment.
fn nickel_code_blocks<'a>(document: &'a AstNode<'a>) -> Vec<DocTest> {
use comrak::arena_tree::Node;
use comrak::nodes::{Ast, NodeCodeBlock, NodeValue};
Expand Down Expand Up @@ -351,6 +362,9 @@ fn nickel_code_blocks<'a>(document: &'a AstNode<'a>) -> Vec<DocTest> {
// field that declares it. We wrap the test in a function so that it doesn't get
// evaluated too soon.
//
// The generated test field ids (i.e. `%0` in the example above) are collected
// in `registry` so that a later pass can go through and evaluate them.
//
// One disadvantage with this traversal approach is that any parse errors in
// the test will be encountered as soon as we explore the record spine. We might
// prefer to delay parsing the tests until it's time to evaluate them.
Expand All @@ -361,6 +375,38 @@ fn doctest_transform(
registry: &mut TestRegistry,
rt: RichTerm,
) -> Result<RichTerm, CoreError> {
// Get the path that of the current term, so we can pretend that test snippets
// came from the same path. This allows imports to work.
let path = rt
.pos
.as_opt_ref()
.and_then(|sp| cache.get_path(sp.src_id))
.map(PathBuf::from);

let source_path = match path {
Some(p) => SourcePath::Snippet(p),
None => SourcePath::Generated("test".to_owned()),
};

// Prepare a test snippet. Skips typechecking and transformations, because
// the returned term will get inserted into a bigger term that will be
// typechecked and transformed.
fn prepare(
cache: &mut Cache,
input: &str,
source_path: &SourcePath,
) -> Result<RichTerm, CoreError> {
let src_id = cache.add_string(source_path.clone(), input.to_owned());
cache.parse(src_id, InputFormat::Nickel)?;
// We could probably skip import resolution here also, but `Cache::get` insists
// that imports be resolved.
cache
.resolve_imports(src_id)
.map_err(|e| e.unwrap_error("test snippet"))?;
// unwrap: we just populated it
Ok(cache.get(src_id).unwrap())
}

let mut record_with_doctests =
|mut record_data: RecordData, dyn_fields, pos| -> Result<_, CoreError> {
let mut doc_fields: Vec<(Ident, RichTerm)> = Vec::new();
Expand All @@ -374,11 +420,36 @@ fn doctest_transform(
));

for (i, snippet) in snippets.iter().enumerate() {
let src_id = cache
.add_string(SourcePath::Generated("test".to_owned()), snippet.code());
let (test_term, errors) = cache.parse_nocache(src_id)?;
if !errors.errors.is_empty() {
return Err(errors.into());
let mut test_term = prepare(cache, &snippet.input, &source_path)?;

if let Expected::Value(s) = &snippet.expected {
// Create the contract `std.contract.Equal <expected>` and apply it to the
// test term.
let expected_term = prepare(cache, s, &source_path)?;
// unwrap: we just parsed it, so it will have a span
let expected_span = expected_term.pos.into_opt().unwrap();

let eq = make::static_access(
RichTerm::from(Term::Var("std".into())),
["contract", "Equal"],
);
let eq = mk_app!(eq, expected_term);
let eq_ty = Type::from(TypeF::Flat(eq));
test_term = Term::Annotated(
TypeAnnotation {
typ: None,
contracts: vec![LabeledType {
typ: eq_ty.clone(),
label: Label {
typ: Rc::new(eq_ty),
span: expected_span,
..Default::default()
},
}],
},
test_term,
)
.into();
}

// Make the test term lazy, so that the tests don't automatically get evaluated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ expression: err
---
error: contract broken by the caller of `at`
invalid array indexing
┌─ <stdlib/std.ncl>:164:9
┌─ <stdlib/std.ncl>:165:9
164| std.contract.unstable.IndexedArrayFun 'Index
165| std.contract.unstable.IndexedArrayFun 'Index
-------------------------------------------- expected type
┌─ [INPUTS_PATH]/errors/array_at_empty_array.ncl:3:16
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ expression: err
---
error: contract broken by the caller of `at`
invalid array indexing
┌─ <stdlib/std.ncl>:164:9
┌─ <stdlib/std.ncl>:165:9
164| std.contract.unstable.IndexedArrayFun 'Index
165| std.contract.unstable.IndexedArrayFun 'Index
-------------------------------------------- expected type
┌─ [INPUTS_PATH]/errors/array_at_out_of_bound.ncl:3:16
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ expression: err
---
error: contract broken by the caller of `range`
invalid range
┌─ <stdlib/std.ncl>:755:9
┌─ <stdlib/std.ncl>:769:9
755| std.contract.unstable.RangeFun Dyn
769| std.contract.unstable.RangeFun Dyn
---------------------------------- expected type
┌─ [INPUTS_PATH]/errors/array_range_reversed_indices.ncl:3:19
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ expression: err
---
error: contract broken by the caller of `range_step`
invalid range step
┌─ <stdlib/std.ncl>:730:9
┌─ <stdlib/std.ncl>:744:9
730| std.contract.unstable.RangeFun (std.contract.unstable.RangeStep -> Dyn)
744| std.contract.unstable.RangeFun (std.contract.unstable.RangeStep -> Dyn)
----------------------------------------------------------------------- expected type
┌─ [INPUTS_PATH]/errors/array_range_step_negative_step.ncl:3:27
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ expression: err
---
error: contract broken by the caller of `map`
expected an array
┌─ <stdlib/std.ncl>:148:33
┌─ <stdlib/std.ncl>:149:33
148 │ : forall a b. (a -> b) -> Array a -> Array b
149 │ : forall a b. (a -> b) -> Array a -> Array b
------- expected type of the argument provided by the caller
┌─ [INPUTS_PATH]/errors/caller_contract_violation.ncl:3:31
Expand Down
49 changes: 37 additions & 12 deletions core/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,43 @@ impl Cache {
}
}

/// Applies a custom transform to an input and its imports, leaving them
/// in the same state as before. Requires that the input has been parsed.
/// In order for the transform to apply to imports, they need to have been
/// resolved.
pub fn custom_transform<E>(
&mut self,
file_id: FileId,
transform: &mut impl FnMut(&mut Cache, RichTerm) -> Result<RichTerm, E>,
) -> Result<(), CacheError<E>> {
match self.entry_state(file_id) {
Some(state) if state >= EntryState::Parsed => {
if state < EntryState::Transforming {
let cached_term = self.terms.remove(&file_id).unwrap();
let term = transform(self, cached_term.term)?;
self.terms.insert(
file_id,
TermEntry {
term,
state: EntryState::Transforming,
..cached_term
},
);

if let Some(imports) = self.imports.get(&file_id).cloned() {
for f in imports.into_iter() {
self.custom_transform(f, transform)?;
}
}
// TODO: We're setting the state back to whatever it was.
self.update_state(file_id, state);
}
Ok(())
}
_ => Err(CacheError::NotParsed),
}
}

/// Apply program transformations to all the fields of a record.
///
/// Used to transform stdlib modules and other records loaded in the environment, when using
Expand Down Expand Up @@ -1068,18 +1105,6 @@ impl Cache {
self.terms.get(&file_id).map(|TermEntry { term, .. }| term)
}

/// Set a new value for a cached term.
pub fn set(&mut self, file_id: FileId, term: RichTerm, state: EntryState) {
self.terms.insert(
file_id,
TermEntry {
term,
state,
parse_errs: Default::default(),
},
);
}

/// Returns true if a particular file id represents a Nickel standard library file, false
/// otherwise.
pub fn is_stdlib_module(&self, file: FileId) -> bool {
Expand Down
2 changes: 1 addition & 1 deletion core/src/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ pub struct VirtualMachine<R: ImportResolver, C: Cache> {
// The call stack, for error reporting.
call_stack: CallStack,
// The interface used to fetch imports.
pub import_resolver: R,
import_resolver: R,
// The evaluation cache.
pub cache: C,
// The initial environment containing stdlib and builtin functions accessible from anywhere
Expand Down
20 changes: 19 additions & 1 deletion core/src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl FieldOverride {
/// code of imported expressions, and a dictionary which stores corresponding parsed terms.
pub struct Program<EC: EvalCache> {
/// The id of the program source in the file database.
pub main_id: FileId,
main_id: FileId,
/// The state of the Nickel virtual machine.
pub vm: VirtualMachine<Cache, EC>,
/// The color option to use when reporting errors.
Expand Down Expand Up @@ -386,6 +386,24 @@ impl<EC: EvalCache> Program<EC> {
.clone())
}

/// Applies a custom transformation to the main term, assuming that it has been parsed but not
/// yet transformed.
///
/// The term is left in the `Parsed` state, so it will be transformed as usual after this custom
/// transformation.
///
/// This state-management isn't great, as it breaks the usual linear order of state changes.
/// In particular, there's no protection against double-applying the same transformation, and if
/// you accidentally already advanced past `Parsed` then this function will do nothing.
pub fn custom_transform<E, F>(&mut self, mut transform: F) -> Result<(), CacheError<E>>
where
F: FnMut(&mut Cache, RichTerm) -> Result<RichTerm, E>,
{
self.vm
.import_resolver_mut()
.custom_transform(self.main_id, &mut transform)
}

/// Retrieve the parsed term, typecheck it, and generate a fresh initial environment. If
/// `self.overrides` isn't empty, generate the required merge parts and return a merge
/// expression including the overrides. Extract the field corresponding to `self.field`, if not
Expand Down
4 changes: 2 additions & 2 deletions doc/manual/typing.md
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,9 @@ calling to the statically typed `std.array.filter` from dynamically typed code:
```nickel #repl
> std.array.filter (fun x => if x % 2 == 0 then x else null) [1,2,3,4,5,6]
error: contract broken by the caller of `filter`
┌─ <stdlib/std.ncl>:427:25
┌─ <stdlib/std.ncl>:431:25
427 │ : forall a. (a -> Bool) -> Array a -> Array a
431 │ : forall a. (a -> Bool) -> Array a -> Array a
│ ---- expected return type of a function provided by the caller
┌─ <repl-input-6>:1:55
Expand Down

0 comments on commit a0550b6

Please sign in to comment.