diff --git a/src/solutions/day12.rs b/src/solutions/day12.rs index de630f6..9f1dcb1 100644 --- a/src/solutions/day12.rs +++ b/src/solutions/day12.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use super::Solution; pub struct Day12 {} @@ -18,7 +20,7 @@ impl Solution for Day12 { .map(|n| n.parse::().unwrap()) .collect::>(); - ans += self.solve(&springs.as_bytes(), None, &nums); + ans += self.solve(&springs.as_bytes(), None, &nums, &mut HashMap::new()); } Ok(Box::new(ans)) } @@ -43,7 +45,7 @@ impl Solution for Day12 { .map(|n| n.parse::().unwrap()) .collect::>(); - ans += self.solve(&springs.as_bytes(), None, &nums); + ans += self.solve(&springs.as_bytes(), None, &nums, &mut HashMap::new()); } Ok(Box::new(ans)) } @@ -56,7 +58,13 @@ impl Solution for Day12 { impl Day12 { /// recursively counts the number of permutations of spring we could get from /// the splits specified in 'nums' - fn solve(&self, s: &[u8], in_group: Option, cons: &[usize]) -> usize { + fn solve<'a, 'b>( + &self, + s: &'a [u8], + in_group: Option, + cons: &'b [usize], + map: &mut HashMap<(&'a [u8], Option, &'b [usize]), usize>, + ) -> usize { if s.is_empty() { return match in_group { Some(n) if cons == &[n] => 1, @@ -64,15 +72,36 @@ impl Day12 { _ => 0, }; } - // Resursively match based on the whether we are in a block and/or we have spaces left to fill - match (s[0], in_group, cons) { - (b'.', None, _) | (b'?', None, []) => self.solve(&s[1..], None, cons), - (b'.' | b'?', Some(n), [e, ..]) if n == *e => self.solve(&s[1..], None, &cons[1..]), - (b'#' | b'?', Some(n), [e, ..]) if n < *e => self.solve(&s[1..], Some(n + 1), cons), - (b'#', None, [_, ..]) => self.solve(&s[1..], Some(1), cons), - (b'?', None, _) => self.solve(&s[1..], None, cons) + self.solve(&s[1..], Some(1), cons), - _ => 0, + + // Check for a cache hit + if s[0] == b'?' { + if let Some(result) = map.get(&(s, in_group, cons)) { + return *result; + } } + + // Resursively match based on the whether we are in a block and/or we have spaces left to fill + let ans = match (s[0], in_group, cons) { + (b'.', None, _) | (b'?', None, []) => self.solve(&s[1..], None, cons, map), + (b'.' | b'?', Some(n), [e, ..]) if n == *e => { + self.solve(&s[1..], None, &cons[1..], map) + } + (b'#' | b'?', Some(n), [e, ..]) if n < *e => { + self.solve(&s[1..], Some(n + 1), cons, map) + } + (b'#', None, [_, ..]) => self.solve(&s[1..], Some(1), cons, map), + (b'?', None, _) => { + self.solve(&s[1..], None, cons, map) + self.solve(&s[1..], Some(1), cons, map) + } + _ => 0, + }; + + // Store in cache + if s[0] == b'?' { + map.insert((s, in_group, cons), ans); + } + + ans } }