HIGHER ORDER CHAINS???
This commit is contained in:
parent
34d47fe749
commit
abbe59d011
1 changed files with 31 additions and 10 deletions
41
src/main.rs
41
src/main.rs
|
@ -2,11 +2,11 @@ use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
error::Error,
|
error::Error,
|
||||||
io::{self, Read, Write},
|
io::{self, Read, Write},
|
||||||
|
process,
|
||||||
};
|
};
|
||||||
|
|
||||||
use bstr::ByteSlice;
|
use bstr::ByteSlice;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use itertools::Itertools;
|
|
||||||
use rand::seq::IndexedRandom;
|
use rand::seq::IndexedRandom;
|
||||||
|
|
||||||
#[derive(clap::Parser)]
|
#[derive(clap::Parser)]
|
||||||
|
@ -14,7 +14,9 @@ struct Args {
|
||||||
first_word: String,
|
first_word: String,
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
count: usize,
|
count: usize,
|
||||||
// TODO: higher-order chains
|
#[arg(short = 'n', long, default_value_t = 2, value_parser = clap::value_parser!(u8).range(2..=5))]
|
||||||
|
order: u8,
|
||||||
|
// DONE: higher-order chains
|
||||||
// TODO: save chains
|
// TODO: save chains
|
||||||
}
|
}
|
||||||
// long-term TODO: live up to name :P
|
// long-term TODO: live up to name :P
|
||||||
|
@ -24,6 +26,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let Args {
|
let Args {
|
||||||
first_word,
|
first_word,
|
||||||
mut count,
|
mut count,
|
||||||
|
order,
|
||||||
} = Args::parse();
|
} = Args::parse();
|
||||||
|
|
||||||
// printing the first word immediately gives an intuitive idea of how long freq construction takes
|
// printing the first word immediately gives an intuitive idea of how long freq construction takes
|
||||||
|
@ -35,26 +38,44 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
io::stdin().read_to_end(&mut buf)?;
|
io::stdin().read_to_end(&mut buf)?;
|
||||||
let buf = buf.to_lowercase();
|
let buf = buf.to_lowercase();
|
||||||
|
|
||||||
let mut freq: HashMap<&str, HashMap<&str, usize>> = HashMap::new();
|
// it's probably fine to make a big ol vec of the words…
|
||||||
for (word, next) in buf.words().tuple_windows() {
|
let words: Vec<_> = buf.words().collect();
|
||||||
*freq.entry(word).or_default().entry(next).or_default() += 1;
|
|
||||||
|
let mut freq: HashMap<&[&str], HashMap<&[&str], usize>> = HashMap::new();
|
||||||
|
|
||||||
|
for window in words.windows(order as usize) {
|
||||||
|
let [words @ .., _] = window else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
*freq
|
||||||
|
.entry(words)
|
||||||
|
.or_default()
|
||||||
|
.entry(&window[1..])
|
||||||
|
.or_default() += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut rng = rand::rng();
|
let mut rng = rand::rng();
|
||||||
let mut current_word = first_word.as_str();
|
let Some(mut context) = freq.keys().find(|k| k[0] == first_word) else {
|
||||||
while let Some(nexts) = freq.get(current_word) {
|
eprintln!("well *I* didn't find {first_word} in the corpus");
|
||||||
|
process::exit(1);
|
||||||
|
};
|
||||||
|
while let Some(nexts) = freq.get(context) {
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
current_word = nexts
|
let new_context @ [next, ..] = nexts
|
||||||
.iter()
|
.iter()
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.choose_weighted(&mut rng, |(_, count)| **count)
|
.choose_weighted(&mut rng, |(_, count)| **count)
|
||||||
.expect("freq maps must be nonempty")
|
.expect("freq maps must be nonempty")
|
||||||
.0;
|
.0
|
||||||
|
else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
context = new_context;
|
||||||
|
|
||||||
print!("{current_word}");
|
print!("{next}");
|
||||||
if count > 1 {
|
if count > 1 {
|
||||||
print!(" ");
|
print!(" ");
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue