diff --git a/src/main.rs b/src/main.rs index ebad22a..cb76ac0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,11 +2,11 @@ use std::{ collections::HashMap, error::Error, io::{self, Read, Write}, + process, }; use bstr::ByteSlice; use clap::Parser; -use itertools::Itertools; use rand::seq::IndexedRandom; #[derive(clap::Parser)] @@ -14,7 +14,9 @@ struct Args { first_word: String, #[arg(short, long)] count: usize, - // TODO: higher-order chains + #[arg(short = 'n', long, default_value_t = 2, value_parser = clap::value_parser!(u8).range(2..=5))] + order: u8, + // DONE: higher-order chains // TODO: save chains } // long-term TODO: live up to name :P @@ -24,6 +26,7 @@ fn main() -> Result<(), Box<dyn Error>> { let Args { first_word, mut count, + order, } = Args::parse(); // printing the first word immediately gives an intuitive idea of how long freq construction takes @@ -35,26 +38,44 @@ fn main() -> Result<(), Box<dyn Error>> { io::stdin().read_to_end(&mut buf)?; let buf = buf.to_lowercase(); - let mut freq: HashMap<&str, HashMap<&str, usize>> = HashMap::new(); - for (word, next) in buf.words().tuple_windows() { - *freq.entry(word).or_default().entry(next).or_default() += 1; + // it's probably fine to make a big ol vec of the words… + let words: Vec<_> = buf.words().collect(); + + let mut freq: HashMap<&[&str], HashMap<&[&str], usize>> = HashMap::new(); + + for window in words.windows(order as usize) { + let [words @ .., _] = window else { + unreachable!() + }; + *freq + .entry(words) + .or_default() + .entry(&window[1..]) + .or_default() += 1; } let mut rng = rand::rng(); - let mut current_word = first_word.as_str(); - while let Some(nexts) = freq.get(current_word) { + let Some(mut context) = freq.keys().find(|k| k[0] == first_word) else { + eprintln!("well *I* didn't find {first_word} in the corpus"); + process::exit(1); + }; + while let Some(nexts) = freq.get(context) { if count == 0 { break; } - current_word = nexts + let new_context @ [next, ..] = nexts .iter() .collect::<Vec<_>>() .choose_weighted(&mut rng, |(_, count)| **count) .expect("freq maps must be nonempty") - .0; + .0 + else { + unreachable!(); + }; + context = new_context; - print!("{current_word}"); + print!("{next}"); if count > 1 { print!(" "); }