HIGHER ORDER CHAINS???
This commit is contained in:
parent
34d47fe749
commit
abbe59d011
1 changed files with 31 additions and 10 deletions
41
src/main.rs
41
src/main.rs
|
@ -2,11 +2,11 @@ use std::{
|
|||
collections::HashMap,
|
||||
error::Error,
|
||||
io::{self, Read, Write},
|
||||
process,
|
||||
};
|
||||
|
||||
use bstr::ByteSlice;
|
||||
use clap::Parser;
|
||||
use itertools::Itertools;
|
||||
use rand::seq::IndexedRandom;
|
||||
|
||||
#[derive(clap::Parser)]
|
||||
|
@ -14,7 +14,9 @@ struct Args {
|
|||
first_word: String,
|
||||
#[arg(short, long)]
|
||||
count: usize,
|
||||
// TODO: higher-order chains
|
||||
#[arg(short = 'n', long, default_value_t = 2, value_parser = clap::value_parser!(u8).range(2..=5))]
|
||||
order: u8,
|
||||
// DONE: higher-order chains
|
||||
// TODO: save chains
|
||||
}
|
||||
// long-term TODO: live up to name :P
|
||||
|
@ -24,6 +26,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
let Args {
|
||||
first_word,
|
||||
mut count,
|
||||
order,
|
||||
} = Args::parse();
|
||||
|
||||
// printing the first word immediately gives an intuitive idea of how long freq construction takes
|
||||
|
@ -35,26 +38,44 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
io::stdin().read_to_end(&mut buf)?;
|
||||
let buf = buf.to_lowercase();
|
||||
|
||||
let mut freq: HashMap<&str, HashMap<&str, usize>> = HashMap::new();
|
||||
for (word, next) in buf.words().tuple_windows() {
|
||||
*freq.entry(word).or_default().entry(next).or_default() += 1;
|
||||
// it's probably fine to make a big ol vec of the words…
|
||||
let words: Vec<_> = buf.words().collect();
|
||||
|
||||
let mut freq: HashMap<&[&str], HashMap<&[&str], usize>> = HashMap::new();
|
||||
|
||||
for window in words.windows(order as usize) {
|
||||
let [words @ .., _] = window else {
|
||||
unreachable!()
|
||||
};
|
||||
*freq
|
||||
.entry(words)
|
||||
.or_default()
|
||||
.entry(&window[1..])
|
||||
.or_default() += 1;
|
||||
}
|
||||
|
||||
let mut rng = rand::rng();
|
||||
let mut current_word = first_word.as_str();
|
||||
while let Some(nexts) = freq.get(current_word) {
|
||||
let Some(mut context) = freq.keys().find(|k| k[0] == first_word) else {
|
||||
eprintln!("well *I* didn't find {first_word} in the corpus");
|
||||
process::exit(1);
|
||||
};
|
||||
while let Some(nexts) = freq.get(context) {
|
||||
if count == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
current_word = nexts
|
||||
let new_context @ [next, ..] = nexts
|
||||
.iter()
|
||||
.collect::<Vec<_>>()
|
||||
.choose_weighted(&mut rng, |(_, count)| **count)
|
||||
.expect("freq maps must be nonempty")
|
||||
.0;
|
||||
.0
|
||||
else {
|
||||
unreachable!();
|
||||
};
|
||||
context = new_context;
|
||||
|
||||
print!("{current_word}");
|
||||
print!("{next}");
|
||||
if count > 1 {
|
||||
print!(" ");
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue