diff --git a/askama_shared/src/generator.rs b/askama_shared/src/generator.rs index 437331ca0..5c1a42e66 100644 --- a/askama_shared/src/generator.rs +++ b/askama_shared/src/generator.rs @@ -542,7 +542,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { ctx: &'a Context<'_>, buf: &mut Buffer, ws1: Ws, - expr: &Expr<'_>, + expr: &Expr<'a>, arms: &'a [When<'_>], ws2: Ws, ) -> Result { @@ -659,7 +659,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { ws: Ws, scope: Option<&str>, name: &str, - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result { if name == "super" { return self.write_block(buf, None, ws); @@ -832,7 +832,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { buf: &mut Buffer, ws: Ws, var: &'a Target<'_>, - val: &Expr<'_>, + val: &Expr<'a>, ) -> Result<(), CompileError> { self.handle_ws(ws); let mut expr_buf = Buffer::new(0); @@ -1032,7 +1032,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { /* Visitor methods for expression types */ - fn visit_expr_root(&mut self, expr: &Expr<'_>) -> Result { + fn visit_expr_root(&mut self, expr: &Expr<'a>) -> Result { let mut buf = Buffer::new(0); self.visit_expr(&mut buf, expr)?; Ok(buf.buf) @@ -1041,7 +1041,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn visit_expr( &mut self, buf: &mut Buffer, - expr: &Expr<'_>, + expr: &Expr<'a>, ) -> Result { Ok(match *expr { Expr::BoolLit(s) => self.visit_bool_lit(buf, s), @@ -1063,6 +1063,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { Expr::MethodCall(ref obj, method, ref args) => { self.visit_method_call(buf, obj, method, args)? } + Expr::Closure(ref params, ref body) => self.visit_closure(buf, params, body)?, Expr::RustMacro(name, args) => self.visit_rust_macro(buf, name, args), }) } @@ -1080,7 +1081,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { &mut self, buf: &mut Buffer, mut name: &str, - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result { if matches!(name, "escape" | "e") { self._visit_escape_filter(buf, args)?; @@ -1132,7 +1133,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn _visit_escape_filter( &mut self, buf: &mut Buffer, - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result<(), CompileError> { if args.len() > 2 { return Err("only two arguments allowed to escape filter".into()); @@ -1163,7 +1164,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn _visit_format_filter( &mut self, buf: &mut Buffer, - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result<(), CompileError> { buf.write("format!("); if let Some(Expr::StrLit(v)) = args.first() { @@ -1182,7 +1183,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn _visit_fmt_filter( &mut self, buf: &mut Buffer, - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result<(), CompileError> { buf.write("format!("); if let Some(Expr::StrLit(v)) = args.get(1) { @@ -1203,7 +1204,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn _visit_join_filter( &mut self, buf: &mut Buffer, - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result<(), CompileError> { buf.write("::askama::filters::join((&"); for (i, arg) in args.iter().enumerate() { @@ -1219,7 +1220,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { Ok(()) } - fn _visit_args(&mut self, buf: &mut Buffer, args: &[Expr<'_>]) -> Result<(), CompileError> { + fn _visit_args(&mut self, buf: &mut Buffer, args: &[Expr<'a>]) -> Result<(), CompileError> { if args.is_empty() { return Ok(()); } @@ -1260,7 +1261,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn visit_attr( &mut self, buf: &mut Buffer, - obj: &Expr<'_>, + obj: &Expr<'a>, attr: &str, ) -> Result { if let Expr::Var(name) = *obj { @@ -1290,8 +1291,8 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn visit_index( &mut self, buf: &mut Buffer, - obj: &Expr<'_>, - key: &Expr<'_>, + obj: &Expr<'a>, + key: &Expr<'a>, ) -> Result { buf.write("&"); self.visit_expr(buf, obj)?; @@ -1304,9 +1305,9 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn visit_method_call( &mut self, buf: &mut Buffer, - obj: &Expr<'_>, + obj: &Expr<'a>, method: &str, - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result { if matches!(obj, Expr::Var("loop")) { match method { @@ -1343,11 +1344,43 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { Ok(DisplayWrap::Unwrapped) } + fn visit_closure( + &mut self, + buf: &mut Buffer, + params: &[&'a str], + body: &Expr<'a>, + ) -> Result { + self.locals.push(); + + buf.write("|"); + for (i, param) in params.iter().enumerate() { + if i > 0 { + buf.write(", "); + } + buf.write(param); + self.locals.insert(param, LocalMeta::initialized()) + } + buf.write("|"); + + let borrow = !body.is_copyable(); + if borrow { + buf.write("&("); + } + self.visit_expr(buf, body)?; + if borrow { + buf.write(")"); + } + + self.locals.pop(); + + Ok(DisplayWrap::Unwrapped) + } + fn visit_unary( &mut self, buf: &mut Buffer, op: &str, - inner: &Expr<'_>, + inner: &Expr<'a>, ) -> Result { buf.write(op); self.visit_expr(buf, inner)?; @@ -1358,8 +1391,8 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { &mut self, buf: &mut Buffer, op: &str, - left: &Option>>, - right: &Option>>, + left: &Option>>, + right: &Option>>, ) -> Result { if let Some(left) = left { self.visit_expr(buf, left)?; @@ -1375,8 +1408,8 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { &mut self, buf: &mut Buffer, op: &str, - left: &Expr<'_>, - right: &Expr<'_>, + left: &Expr<'a>, + right: &Expr<'a>, ) -> Result { self.visit_expr(buf, left)?; buf.write(&format!(" {} ", op)); @@ -1387,7 +1420,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn visit_group( &mut self, buf: &mut Buffer, - inner: &Expr<'_>, + inner: &Expr<'a>, ) -> Result { buf.write("("); self.visit_expr(buf, inner)?; @@ -1398,7 +1431,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn visit_array( &mut self, buf: &mut Buffer, - elements: &[Expr<'_>], + elements: &[Expr<'a>], ) -> Result { buf.write("["); for (i, el) in elements.iter().enumerate() { @@ -1425,7 +1458,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { &mut self, buf: &mut Buffer, path: &[&str], - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result { for (i, part) in path.iter().enumerate() { if i > 0 { @@ -1453,7 +1486,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { &mut self, buf: &mut Buffer, s: &str, - args: &[Expr<'_>], + args: &[Expr<'a>], ) -> Result { buf.write("("); let s = normalize_identifier(s); diff --git a/askama_shared/src/parser.rs b/askama_shared/src/parser.rs index f5685b9fd..d5144a000 100644 --- a/askama_shared/src/parser.rs +++ b/askama_shared/src/parser.rs @@ -64,6 +64,7 @@ pub enum Expr<'a> { Range(&'a str, Option>>, Option>>), Group(Box>), MethodCall(Box>, &'a str, Vec>), + Closure(Vec<&'a str>, Box>), RustMacro(&'a str, &'a str), } @@ -87,6 +88,9 @@ impl Expr<'_> { // as in that case the call is more likely to return a // reference in the first place then. VarCall(..) | Path(..) | PathCall(..) | MethodCall(..) => true, + // Closures likely don't need to be borrowed, + // as they are usually used in place. + Closure(..) => true, // If the `expr` is within a `Unary` or `BinOp` then // an assumption can be made that the operand is copy. // If not, then the value is moved and adding `.clone()` @@ -640,7 +644,7 @@ expr_prec_layer!(expr_compare, expr_bor, "==", "!=", ">=", ">", "<=", "<"); expr_prec_layer!(expr_and, expr_compare, "&&"); expr_prec_layer!(expr_or, expr_and, "||"); -fn expr_any(i: &str) -> IResult<&str, Expr<'_>> { +fn expr_range(i: &str) -> IResult<&str, Expr<'_>> { let range_right = |i| pair(ws(alt((tag("..="), tag("..")))), opt(expr_or))(i); alt(( map(range_right, |(op, right)| { @@ -656,6 +660,27 @@ fn expr_any(i: &str) -> IResult<&str, Expr<'_>> { ))(i) } +fn expr_closure(i: &str) -> IResult<&str, Expr<'_>> { + let parameters = delimited( + ws(char('|')), + separated_list0(char(','), ws(identifier)), + ws(char('|')), + ); + + let (i, (parameters, expr)) = tuple((opt(parameters), expr_range))(i)?; + Ok(( + i, + match parameters { + Some(parameters) => Expr::Closure(parameters, Box::new(expr)), + None => expr, + }, + )) +} + +fn expr_any(i: &str) -> IResult<&str, Expr<'_>> { + expr_closure(i) +} + fn expr_node<'a>(i: &'a str, s: &State<'_>) -> IResult<&'a str, Node<'a>> { let mut p = tuple(( |i| tag_expr_start(i, s), @@ -1429,6 +1454,50 @@ mod tests { ); } + #[test] + fn test_parse_closure() { + let syntax = Syntax::default(); + assert_eq!( + super::parse("{{ || 12 }}", &syntax).unwrap(), + vec![Node::Expr( + Ws(false, false), + Expr::Closure(vec![], Expr::NumLit("12").into()), + )], + ); + assert_eq!( + super::parse("{{ |a| a.b }}", &syntax).unwrap(), + vec![Node::Expr( + Ws(false, false), + Expr::Closure(vec!["a"], Expr::Attr(Expr::Var("a").into(), "b").into()), + )], + ); + assert_eq!( + super::parse("{{ |a, b, c| a + b }}", &Syntax::default()).unwrap(), + vec![Node::Expr( + Ws(false, false), + Expr::Closure( + vec!["a", "b", "c"], + Expr::BinOp("+", Expr::Var("a").into(), Expr::Var("b").into()).into(), + ), + )], + ); + + assert_eq!( + super::parse("{{ user_opt.map(|user| user.name) }}", &syntax).unwrap(), + vec![Node::Expr( + Ws(false, false), + Expr::MethodCall( + Expr::Var("user_opt").into(), + "map", + vec![Expr::Closure( + vec!["user"], + Expr::Attr(Expr::Var("user").into(), "name").into(), + )] + ), + )], + ); + } + #[test] fn change_delimiters_parse_filter() { let syntax = Syntax { diff --git a/testing/tests/closures.rs b/testing/tests/closures.rs new file mode 100644 index 000000000..c7b0737c0 --- /dev/null +++ b/testing/tests/closures.rs @@ -0,0 +1,77 @@ +use askama::Template; + +#[derive(Debug, Clone)] +struct User { + name: String, + flag: bool, +} + +impl User { + fn ferris() -> Self { + Self { + name: "Ferris".to_string(), + flag: true, + } + } +} + +#[derive(Template)] +#[template( + source = r#"Hello {{ user_opt.map(|user| user.name.as_str()).unwrap_or("World") }}"#, + ext = "txt" +)] +struct ClosureTemplate<'a> { + user_opt: Option<&'a User>, +} + +#[test] +fn test_closure() { + let user = User::ferris(); + let t = ClosureTemplate { + user_opt: Some(&user), + }; + assert_eq!(t.render().unwrap(), "Hello Ferris"); + + let t = ClosureTemplate { user_opt: None }; + assert_eq!(t.render().unwrap(), "Hello World"); +} + +#[derive(Template)] +#[template( + source = r#"Hello {{ user.map(|user| user.name.as_str()).unwrap_or("World") }}"#, + ext = "txt" +)] +struct ClosureShadowTemplate<'a> { + user: Option<&'a User>, +} + +#[test] +fn test_closure_shadow() { + let user = User::ferris(); + let t = ClosureShadowTemplate { user: Some(&user) }; + assert_eq!(t.render().unwrap(), "Hello Ferris"); + + let t = ClosureShadowTemplate { user: None }; + assert_eq!(t.render().unwrap(), "Hello World"); +} + +#[derive(Template)] +#[template( + source = r#"{{ user_opt.map(|user| user.flag).copied().unwrap_or(false) }}"#, + ext = "txt" +)] +struct ClosureBorrowTemplate<'a> { + user_opt: Option<&'a User>, +} + +#[test] +fn test_closure_borrow() { + let user = User::ferris(); + let t = ClosureBorrowTemplate { + user_opt: Some(&user), + }; + assert_eq!(t.render().unwrap(), "true"); + + let t = ClosureBorrowTemplate { user_opt: None }; + assert_eq!(t.render().unwrap(), "false"); +}