more help, unify error handling

This commit is contained in:
mehbark 2025-03-23 13:22:50 -04:00
parent 64ccb0132f
commit 6c8ba9a9ac
2 changed files with 20 additions and 7 deletions

View file

@@ -2,6 +2,7 @@
name = "fast-markov-chain"
version = "0.1.0"
edition = "2024"
description = "Markov Chain generator and querier."
[dependencies]
bstr = "1.11.3"

View file

@@ -1,7 +1,6 @@
use std::{
error::Error,
io::{self, Read, Write},
process,
};
use bstr::ByteSlice;
@@ -10,12 +9,24 @@ use rand::seq::IndexedRandom;
use rustc_hash::FxHashMap;
#[derive(clap::Parser)]
#[command(
version,
about,
long_about = "Markov Chain generator and querier.
Reads all of STDIN, tokenizes with order ORDER, and outputs *up to* COUNT tokens, starting with FIRST_WORD or a random word from the corpus.
Tokens are currently lower-cased words."
)]
struct Args {
// TODO: multi-word prefix
first_word: String,
#[arg(short, long)]
#[arg(short, long, help = "maximum number of tokens to output")]
count: usize,
#[arg(short = 'n', long, default_value_t = 2, value_parser = clap::value_parser!(u8).range(2..=5))]
#[arg(
short = 'n', long,
default_value_t = 2,
value_parser = clap::value_parser!(u8).range(2..=8),
help = "the length of sample windows; higher orders increase coherence but also repetition",
)]
order: u8,
// DONE: higher-order chains
// TODO: save chains
@@ -29,6 +40,7 @@ fn main() -> Result<(), Box<dyn Error>> {
mut count,
order,
} = Args::parse();
let first_word = first_word.to_lowercase();
// printing the first word immediately gives an intuitive idea of how long freq construction takes
print!("{first_word} ");
@@ -53,10 +65,10 @@ fn main() -> Result<(), Box<dyn Error>> {
}
let mut rng = rand::rng();
let Some(mut context) = freq.keys().find(|k| k[0] == first_word) else {
eprintln!("well *I* didn't find {first_word} in the corpus");
process::exit(1);
};
let mut context = freq
.keys()
.find(|k| k[0] == first_word)
.ok_or(format!("well *I* didn't find {first_word} in the corpus"))?;
while let Some(nexts) = freq.get(context) {
if count == 0 {
break;