HIGHER ORDER CHAINS???

This commit is contained in:
mehbark 2025-03-23 00:21:17 -04:00
parent 34d47fe749
commit abbe59d011

View file

@ -2,11 +2,11 @@ use std::{
collections::HashMap,
error::Error,
io::{self, Read, Write},
process,
};
use bstr::ByteSlice;
use clap::Parser;
use itertools::Itertools;
use rand::seq::IndexedRandom;
#[derive(clap::Parser)]
@ -14,7 +14,9 @@ struct Args {
first_word: String,
#[arg(short, long)]
count: usize,
// TODO: higher-order chains
#[arg(short = 'n', long, default_value_t = 2, value_parser = clap::value_parser!(u8).range(2..=5))]
order: u8,
// DONE: higher-order chains
// TODO: save chains
}
// long-term TODO: live up to name :P
@ -24,6 +26,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let Args {
first_word,
mut count,
order,
} = Args::parse();
// printing the first word immediately gives an intuitive idea of how long freq construction takes
@ -35,26 +38,44 @@ fn main() -> Result<(), Box<dyn Error>> {
io::stdin().read_to_end(&mut buf)?;
let buf = buf.to_lowercase();
let mut freq: HashMap<&str, HashMap<&str, usize>> = HashMap::new();
for (word, next) in buf.words().tuple_windows() {
*freq.entry(word).or_default().entry(next).or_default() += 1;
// it's probably fine to make a big ol vec of the words…
let words: Vec<_> = buf.words().collect();
let mut freq: HashMap<&[&str], HashMap<&[&str], usize>> = HashMap::new();
for window in words.windows(order as usize) {
let [words @ .., _] = window else {
unreachable!()
};
*freq
.entry(words)
.or_default()
.entry(&window[1..])
.or_default() += 1;
}
let mut rng = rand::rng();
let mut current_word = first_word.as_str();
while let Some(nexts) = freq.get(current_word) {
let Some(mut context) = freq.keys().find(|k| k[0] == first_word) else {
eprintln!("well *I* didn't find {first_word} in the corpus");
process::exit(1);
};
while let Some(nexts) = freq.get(context) {
if count == 0 {
break;
}
current_word = nexts
let new_context @ [next, ..] = nexts
.iter()
.collect::<Vec<_>>()
.choose_weighted(&mut rng, |(_, count)| **count)
.expect("freq maps must be nonempty")
.0;
.0
else {
unreachable!();
};
context = new_context;
print!("{current_word}");
print!("{next}");
if count > 1 {
print!(" ");
}