use core::fmt;
use std::{
    collections::{HashMap, HashSet},
    env,
    io::{self, Read},
    iter,
    sync::Mutex,
};

use intaglio::{Symbol, SymbolTable};
use itertools::Itertools;
use once_cell::sync::Lazy;
use rand::{distributions::Alphanumeric, Rng};

/// Global symbol table that interns every atom name: atoms read by the parser, the keyword
/// symbols (`forall`, `def`, `let`, `eq`, `genslop`, `log`) and names produced by `genslop`.
///
/// NOTE: every keyword static below locks `TABLE` inside its `Lazy` initializer, so a keyword
/// must never be dereferenced for the first time while `TABLE` is already locked on the same
/// thread (`std::sync::Mutex` is not reentrant).
static TABLE: Lazy<Mutex<SymbolTable>> = Lazy::new(|| Mutex::new(SymbolTable::new()));
/// Head symbol of a rule with pattern variables: `(forall (vars...) lhs rhs)`.
static FORALL: Lazy<Symbol> = Lazy::new(|| TABLE.lock().unwrap().intern("forall").unwrap());
/// Head symbol of a rule without pattern variables: `(def lhs rhs)`.
///
/// The Rust name mirrors `Rule::Concrete`, but the interned string is actually "def".
/// This is for the user's convenience.
static CONCRETE: Lazy<Symbol> = Lazy::new(|| TABLE.lock().unwrap().intern("def").unwrap());

/// Default upper bound on `Sexp::complexity`; `simp` stops once the expression exceeds it.
const DEFAULT_COMPLEXITY_LIMIT: usize = 20_000;
/// Default maximum number of simplification steps `simp` may run.
const DEFAULT_STEP_LIMIT: usize = 2_000;

fn main() {
    match do_it() {
        Some(out) => println!("{out}"),
        None => println!("error :("),
    }
}

/// Reads a program, evaluates it, and returns the printed result.
///
/// The program text is the first command-line argument if one is given, and all of
/// standard input otherwise. Returns `None` in these cases:
/// - the argument is not valid UTF-8;
/// - stdin cannot be read as UTF-8;
/// - `run_program` fails.
fn do_it() -> Option<String> {
    let src = match env::args_os().nth(1) {
        // `env::args` would panic on a non-UTF-8 argument; report it as an error instead.
        Some(arg) => arg.into_string().ok()?,
        None => {
            let mut buf = String::new();
            io::stdin().read_to_string(&mut buf).ok()?;
            buf
        }
    };

    run_program(&src, DEFAULT_STEP_LIMIT, DEFAULT_COMPLEXITY_LIMIT).map(|sexp| sexp.to_string())
}

/// An s-expression: either an interned symbol or a (possibly empty) list of s-expressions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Sexp {
    Atom(Symbol),
    List(Vec<Sexp>),
}

impl Sexp {
    /// The empty list `()`.
    fn nil() -> Self {
        Self::List(Vec::new())
    }

    /// Returns `true` if the sexp is [`Atom`].
    ///
    /// [`Atom`]: Sexp::Atom
    #[must_use]
    fn is_atom(&self) -> bool {
        matches!(self, Self::Atom(..))
    }

    /// Returns `true` if the sexp is [`List`].
    ///
    /// [`List`]: Sexp::List
    #[must_use]
    fn is_list(&self) -> bool {
        matches!(self, Self::List(..))
    }

    /// Returns the symbol if this is an atom, `None` if it is a list.
    #[must_use]
    fn atom(&self) -> Option<&Symbol> {
        match self {
            Sexp::Atom(at) => Some(at),
            Sexp::List(_) => None,
        }
    }

    /// Returns the elements if this is a list, `None` if it is an atom.
    #[must_use]
    fn list(&self) -> Option<&[Sexp]> {
        match self {
            Sexp::Atom(_) => None,
            Sexp::List(xs) => Some(xs),
        }
    }

    /// Returns the symbol of an atom.
    ///
    /// # Panics
    ///
    /// Panics if called on a list; callers must check [`Sexp::is_atom`] first.
    #[must_use]
    fn unwrap_atom(&self) -> Symbol {
        match self {
            Sexp::Atom(sym) => *sym,
            Sexp::List(_) => unreachable!(),
        }
    }

    /// Push a sexp to the end of a sexp. Does nothing if the sexp is an atom.
    fn push(&mut self, item: Self) {
        match self {
            Sexp::Atom(_) => (),
            Sexp::List(xs) => xs.push(item),
        }
    }

    /// Applies `rule` to this expression (method form of the free function `rw`).
    fn rw(&self, rule: &Rule) -> Sexp {
        rw(rule, self)
    }

    /// Returns the set of every symbol that occurs anywhere in the expression.
    #[must_use]
    fn pertinent(&self) -> HashSet<Symbol> {
        let mut out = HashSet::new();
        self.collect_symbols(&mut out);
        out
    }

    /// Inserts every symbol occurring in `self` into `out`.
    ///
    /// Accumulating into one set avoids rebuilding a fresh union for every child.
    fn collect_symbols(&self, out: &mut HashSet<Symbol>) {
        match self {
            Sexp::Atom(sym) => {
                out.insert(*sym);
            }
            Sexp::List(xs) => xs.iter().for_each(|x| x.collect_symbols(out)),
        }
    }

    /// Returns every distinct subexpression of `self`, including `self` itself.
    fn concretion_targets(&self) -> HashSet<&Sexp> {
        let mut out = HashSet::new();
        self.collect_subexprs(&mut out);
        out
    }

    /// Inserts `self` and all of its subexpressions into `out`.
    fn collect_subexprs<'a>(&'a self, out: &mut HashSet<&'a Sexp>) {
        // If a structurally equal subtree was already inserted, its whole walk has already
        // finished (an equal tree cannot be its own proper descendant), so all of its
        // subexpressions are present and the walk can be skipped.
        if out.insert(self) {
            if let Sexp::List(xs) = self {
                for x in xs {
                    x.collect_subexprs(out);
                }
            }
        }
    }

    /// Returns `true` if the symbol `sym` occurs anywhere in the expression.
    #[must_use]
    fn contains_sym(&self, sym: Symbol) -> bool {
        match self {
            Sexp::Atom(at) => *at == sym,
            Sexp::List(xs) => xs.iter().any(|s| s.contains_sym(sym)),
        }
    }

    /// Size measure used to cap simplification: 1 per atom, 10 per list plus its elements.
    #[must_use]
    fn complexity(&self) -> usize {
        match self {
            Sexp::Atom(_) => 1,
            Sexp::List(xs) => 10 + xs.iter().map(Sexp::complexity).sum::<usize>(),
        }
    }

    /// Evaluates every occurrence of the special form `F`, top-down.
    ///
    /// A node recognized by `F::from_sexp` is replaced by its evaluation, and the result is
    /// not re-scanned. Any other list has the pass applied to each of its elements.
    #[must_use]
    fn apply_special_form<'src, F: SpecialForm<'src>>(&'src self) -> Sexp {
        match self {
            Sexp::Atom(_) => match F::from_sexp(self) {
                Some(form) => form.eval(),
                None => self.clone(),
            },
            Sexp::List(xs) => match F::from_sexp(self) {
                Some(form) => form.eval(),
                None => Sexp::List(xs.iter().map(Sexp::apply_special_form::<F>).collect()),
            },
        }
    }
}

impl fmt::Display for Sexp {
    /// Prints an atom as its interned name, and a list as its space-separated elements
    /// wrapped in parentheses.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Sexp::List(items) => {
                let joined = items.iter().map(|item| item.to_string()).join(" ");
                write!(f, "({})", joined)
            }
            Sexp::Atom(sym) => {
                let table = TABLE.lock().unwrap();
                write!(f, "{}", table.get(*sym).unwrap())
            }
        }
    }
}

/// A lexical token that borrows its text from the source string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Token<'src> {
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// A maximal run of characters that are not whitespace, parentheses, or commented out.
    Atom(&'src str),
}

/// Splits `src` into tokens.
///
/// `;` and `` ` `` both start a comment that lasts until the next newline. Atoms are
/// delimited by whitespace and parentheses; an unterminated atom at end of input is kept.
fn lex(src: &str) -> Vec<Token> {
    // maybe a bad idea?
    let mut tokens = Vec::with_capacity(src.len());
    // Byte range `start..end` of the atom currently being scanned, if any.
    let mut pending: Option<(usize, usize)> = None;
    let mut commenting = false;

    for (pos, ch) in src.char_indices() {
        if ch == ';' || ch == '`' {
            commenting = true;
        } else if ch == '\n' {
            commenting = false;
        }
        if commenting {
            continue;
        }

        let is_delimiter = ch == '(' || ch == ')' || ch.is_whitespace();
        if !is_delimiter {
            let end = pos + ch.len_utf8();
            pending = Some(pending.map_or((pos, end), |(start, _)| (start, end)));
            continue;
        }

        if let Some((start, end)) = pending.take() {
            tokens.push(Token::Atom(&src[start..end]));
        }
        match ch {
            '(' => tokens.push(Token::LParen),
            ')' => tokens.push(Token::RParen),
            _ => {}
        }
    }

    if let Some((start, end)) = pending {
        tokens.push(Token::Atom(&src[start..end]));
    }

    tokens
}

/// Descends `depth` levels into `sexp`, always following the last element of each list,
/// and returns the node reached.
///
/// Descent stops early at an atom, which is then returned as-is. `parse` uses this to find
/// the innermost list that is still open.
///
/// # Panics
///
/// Panics if an empty list is reached while levels remain to descend.
fn at_depth(depth: usize, sexp: &mut Sexp) -> &mut Sexp {
    match sexp {
        Sexp::List(children) if depth > 0 => {
            let last = children.last_mut().unwrap();
            at_depth(depth - 1, last)
        }
        node => node,
    }
}

/// Parses a token stream into its sequence of top-level s-expressions, interning every atom.
///
/// Returns `None` if the parentheses are unbalanced. That covers a `)` with no matching `(`
/// and a `(` that is never closed.
fn parse<'src>(tokens: &'src [Token<'src>]) -> Option<Vec<Sexp>> {
    // `out` is a synthetic root list; `depth` counts the lists currently open beneath it.
    // Each open list is always the last element of its parent, which `at_depth` relies on.
    let mut out = Sexp::nil();
    let mut depth: usize = 0;

    for token in tokens {
        match token {
            Token::LParen => {
                at_depth(depth, &mut out).push(Sexp::nil());
                depth += 1;
            }
            Token::RParen => {
                // A `)` at depth 0 closes nothing.
                depth = depth.checked_sub(1)?;
            }
            Token::Atom(at) => {
                let id = TABLE.lock().unwrap().intern((*at).to_string()).unwrap();
                at_depth(depth, &mut out).push(Sexp::Atom(id));
            }
        }
    }

    // Unclosed `(`: reject, consistently with the stray-`)` case above.
    if depth != 0 {
        return None;
    }

    match out {
        Sexp::List(xs) => Some(xs),
        Sexp::Atom(_) => unreachable!(),
    }
}

/// A rewrite rule `lhs ==> rhs`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Rule {
    /// Written `(forall (vars...) lhs rhs)`: each symbol in `vars` is a pattern variable
    /// that may stand for any subexpression.
    Forall {
        vars: Vec<Symbol>,
        lhs: Sexp,
        rhs: Sexp,
    },
    /// Written `(def lhs rhs)`: rewrites subexpressions that are exactly equal to `lhs`.
    Concrete {
        lhs: Sexp,
        rhs: Sexp,
    },
}

impl Rule {
    fn forall(vars: Vec<Symbol>, lhs: Sexp, rhs: Sexp) -> Rule {
        Rule::Forall { vars, lhs, rhs }
    }

    fn concrete(lhs: Sexp, rhs: Sexp) -> Rule {
        Rule::Concrete { lhs, rhs }
    }

    /// Substitutes `sexp` for the first pattern variable of a `Forall` rule.
    ///
    /// Returns a `Concrete` rule once no variables remain. A `Concrete` rule is returned
    /// unchanged.
    ///
    /// NOTE(review): repeated calls substitute one variable at a time. If `sexp` contains
    /// the name of a later variable, a later call will replace that occurrence as well.
    fn concrify(&self, sexp: &Sexp) -> Rule {
        match self {
            Rule::Forall { vars, lhs, rhs } => {
                if vars.is_empty() {
                    Rule::Concrete {
                        lhs: lhs.clone(),
                        rhs: rhs.clone(),
                    }
                } else {
                    let rule = Rule::var_replace(vars[0], sexp);
                    let lhs = rw(&rule, lhs);
                    let rhs = rw(&rule, rhs);

                    if vars.len() == 1 {
                        Rule::Concrete { lhs, rhs }
                    } else {
                        Rule::Forall {
                            vars: vars[1..].to_vec(),
                            lhs,
                            rhs,
                        }
                    }
                }
            }
            Rule::Concrete { .. } => self.clone(),
        }
    }

    /// Returns the set of concrete instances of this rule obtained by substituting
    /// elements of `targets` for its pattern variables, in every combination.
    fn concretions(&self, targets: &HashSet<&Sexp>) -> HashSet<Rule> {
        if self.is_concrete() {
            let mut out = HashSet::with_capacity(1);
            out.insert(self.clone());
            return out;
        }

        let mut out = HashSet::with_capacity(targets.len() * self.num_vars());
        for target in targets {
            let concred = self.concrify(target);
            if concred.is_concrete() {
                // Every target yields its own instance, so keep iterating rather than
                // stopping at the first one.
                out.insert(concred);
            } else {
                out.extend(concred.concretions(targets));
            }
        }
        out
    }

    #[must_use]
    fn lhs(&self) -> &Sexp {
        match self {
            Rule::Concrete { lhs, .. } | Rule::Forall { lhs, .. } => lhs,
        }
    }

    #[must_use]
    fn rhs(&self) -> &Sexp {
        match self {
            Rule::Forall { rhs, .. } | Rule::Concrete { rhs, .. } => rhs,
        }
    }

    /// The pattern variables of a `Forall` rule; `None` for a `Concrete` rule.
    #[must_use]
    fn vars(&self) -> Option<&Vec<Symbol>> {
        match self {
            Rule::Forall { vars, .. } => Some(vars),
            Rule::Concrete { .. } => None,
        }
    }

    /// Number of pattern variables (0 for a `Concrete` rule).
    #[must_use]
    fn num_vars(&self) -> usize {
        self.vars().map_or(0, Vec::len)
    }

    /// Applies this rule to `sexp` (method form of the free function `rw`).
    #[must_use]
    fn rw(&self, sexp: &Sexp) -> Sexp {
        rw(self, sexp)
    }

    /// Returns `true` if the rule is [`Forall`].
    ///
    /// [`Forall`]: Rule::Forall
    #[must_use]
    fn is_forall(&self) -> bool {
        matches!(self, Self::Forall { .. })
    }

    /// Returns `true` if the rule is [`Concrete`].
    ///
    /// [`Concrete`]: Rule::Concrete
    #[must_use]
    fn is_concrete(&self) -> bool {
        matches!(self, Self::Concrete { .. })
    }

    /// A concrete rule that replaces the atom `from` with `to`.
    fn var_replace(from: Symbol, to: &Sexp) -> Rule {
        Rule::Concrete {
            lhs: Sexp::Atom(from),
            rhs: to.clone(),
        }
    }

    /// Instantiates the rule with the variable bindings in `matches`, always yielding a
    /// `Concrete` rule.
    ///
    /// All variables are substituted simultaneously, so a bound expression that happens to
    /// contain another pattern variable's name is left intact. That means the instantiated
    /// `lhs` equals the expression that was matched. Variables with no binding are replaced
    /// by `()`. A `Concrete` rule is returned unchanged.
    #[must_use]
    fn concrete_with_matches(&self, matches: &Matches<'_>) -> Rule {
        match self {
            Rule::Forall { vars, lhs, rhs } => Rule::Concrete {
                lhs: Self::substitute(lhs, vars, matches),
                rhs: Self::substitute(rhs, vars, matches),
            },
            Rule::Concrete { .. } => self.clone(),
        }
    }

    /// Replaces every atom naming a variable in `vars` with its binding from `matches`
    /// (or `()` if unbound), without descending into the substituted expressions.
    fn substitute(sexp: &Sexp, vars: &[Symbol], matches: &Matches<'_>) -> Sexp {
        match sexp {
            Sexp::Atom(sym) if vars.contains(sym) => matches
                .get(sym)
                .copied()
                .cloned()
                .unwrap_or_else(Sexp::nil),
            Sexp::Atom(_) => sexp.clone(),
            Sexp::List(xs) => Sexp::List(
                xs.iter()
                    .map(|x| Self::substitute(x, vars, matches))
                    .collect(),
            ),
        }
    }
}

impl fmt::Display for Rule {
    /// Renders the rule on one line, padded to fixed widths so short rules line up.
    ///
    /// `Forall` rules render as `∀vars, lhs ==> rhs`; `Concrete` rules render as a
    /// right-aligned `lhs ==> rhs`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (vars, lhs, rhs) = match self {
            Rule::Concrete { lhs, rhs } => {
                return write!(f, "{:>50} ==> {rhs}", lhs.to_string());
            }
            Rule::Forall { vars, lhs, rhs } => (vars, lhs, rhs),
        };

        let names: Vec<String> = vars
            .iter()
            .map(|v| TABLE.lock().unwrap().get(*v).unwrap().to_string())
            .collect();
        let vars = names.join(" ");
        let abs = format!("∀{vars},");
        let lhs = format!("{abs:>15} {lhs}");
        write!(f, "{lhs:<50} ==> {rhs}")
    }
}

/// Structurally checks whether pattern `lhs` could match `expr`.
///
/// Symbols in `vars` act as wildcards. Unlike `matches`, this does not check that repeated
/// variables would bind the same expression.
fn could_match(vars: Option<&[Symbol]>, lhs: &Sexp, expr: &Sexp) -> bool {
    let wildcard = |sym: &Symbol| vars.map_or(false, |vs| vs.contains(sym));

    match lhs {
        Sexp::Atom(sym) => {
            matches!(expr, Sexp::Atom(other) if other == sym) || wildcard(sym)
        }
        Sexp::List(patterns) => match expr {
            Sexp::List(items) => {
                patterns.len() == items.len()
                    && patterns
                        .iter()
                        .zip(items)
                        .all(|(pattern, item)| could_match(vars, pattern, item))
            }
            Sexp::Atom(_) => false,
        },
    }
}

/// Variable bindings produced by `matches`: each pattern variable maps to the
/// subexpression of the matched expression that it stands for.
type Matches<'src> = HashMap<Symbol, &'src Sexp>;

/// Matches the pattern `lhs` against `expr`, treating every symbol in `vars` as a pattern
/// variable.
///
/// Returns the bindings if `expr` is an instance of `lhs`, and `None` otherwise. A variable
/// that occurs several times in `lhs` must bind structurally equal expressions each time.
fn matches<'src>(vars: &[Symbol], lhs: &Sexp, expr: &'src Sexp) -> Option<Matches<'src>> {
    match (lhs, expr) {
        // Variables are checked before literal equality. A variable matched against an atom
        // with its own name must still be bound; otherwise it would later be filled with `()`
        // and escape the consistency check for repeated occurrences.
        (Sexp::Atom(var), _) if vars.contains(var) => {
            let mut out = HashMap::with_capacity(1);
            out.insert(*var, expr);
            Some(out)
        }
        (Sexp::Atom(a), Sexp::Atom(b)) => (a == b).then(HashMap::new),
        (Sexp::Atom(_), Sexp::List(_)) | (Sexp::List(_), Sexp::Atom(_)) => None,
        (Sexp::List(xs), Sexp::List(ys)) => {
            if xs.len() != ys.len() {
                return None;
            }
            // Seeding with an empty map makes `()` match `()`; `try_fold` stops at the first
            // element that fails to match or binds a variable inconsistently.
            xs.iter()
                .zip(ys)
                .try_fold(HashMap::new(), |acc, (pattern, sub)| {
                    merge_matches(vars, acc, matches(vars, pattern, sub)?)
                })
        }
    }
}

/// Combines two sets of variable bindings into one.
///
/// Only variables listed in `vars` are carried over. Returns `None` if some variable is
/// bound in both `a` and `b` to structurally different expressions.
fn merge_matches<'src>(
    vars: &[Symbol],
    a: Matches<'src>,
    b: Matches<'src>,
) -> Option<Matches<'src>> {
    let mut merged = HashMap::with_capacity(vars.len());

    for var in vars {
        let bound = match (a.get(var), b.get(var)) {
            (Some(left), Some(right)) if left != right => return None,
            (Some(left), _) => *left,
            (None, Some(right)) => *right,
            (None, None) => continue,
        };
        merged.insert(*var, bound);
    }

    Some(merged)
}

// fn absorb_matches<'src>(from: Matches<'src>, into: &mut Matches<'src>) {
//     for (var, matches) in from {
//         let insert_into = into.entry(var).or_default();
//         for matc in matches {
//             insert_into.insert(matc);
//         }
//     }
// }

/// Applies `rule` once, top-down, to `sexp`.
///
/// If `sexp` itself matches the rule's left-hand side, it is replaced by the instantiated
/// right-hand side. Otherwise the rule is tried on each element of a list, and an atom is
/// returned unchanged. A rule whose two sides are identical is a no-op.
fn rw(rule: &Rule, sexp: &Sexp) -> Sexp {
    if rule.lhs() == rule.rhs() {
        return sexp.clone();
    }

    let rewritten = match rule {
        Rule::Concrete { lhs, rhs } => (sexp == lhs).then(|| rhs.clone()),
        Rule::Forall { vars, lhs, .. } => {
            matches(vars, lhs, sexp).map(|found| rule.concrete_with_matches(&found).rw(sexp))
        }
    };

    rewritten.unwrap_or_else(|| match sexp {
        Sexp::List(children) => {
            Sexp::List(children.iter().map(|child| rw(rule, child)).collect())
        }
        Sexp::Atom(_) => sexp.clone(),
    })
}

/// Repeatedly rewrites `expr` with `rules`.
///
/// Each step does three things in order:
/// 1. Evaluates the special forms `genslop`, `log`, `let` and `eq`, in that order.
/// 2. Stops if the expression's complexity exceeds `complexity_limit`.
/// 3. Applies every rule once, in order.
///
/// A step in which the rules change nothing uses up one unit of a "grace" budget of
/// `step_limit / 4` (at least 1). These steps need not be consecutive, and the loop stops
/// when the budget is exhausted. At most `step_limit` steps are run.
///
/// Returns the resulting expression and the index of the last step that did not trigger a
/// stop (0 if none did).
fn simp(
    rules: &[Rule],
    mut expr: Sexp,
    step_limit: usize,
    complexity_limit: usize,
) -> (Sexp, usize) {
    let mut max = 0;
    // At least one: with `step_limit < 4` the budget would otherwise start at 0, stopping
    // after the first step regardless of progress and underflowing on a no-progress step.
    let mut grace = (step_limit / 4).max(1);
    for i in 0..step_limit {
        expr = expr
            .apply_special_form::<Genslop>()
            .apply_special_form::<Log>()
            .apply_special_form::<Let>()
            .apply_special_form::<Eq>();

        if expr.complexity() > complexity_limit {
            break;
        }

        let last = expr.clone();

        for rule in rules {
            expr = rw(rule, &expr);
        }

        if expr == last {
            grace -= 1;
            if grace == 0 {
                break;
            }
        }

        max = i;
    }

    (expr, max)
}

/// A built-in form that is recognized syntactically and evaluated by `simp` before the
/// rewrite rules in each step.
///
/// `'a` is the lifetime of the expression the form was recognized in, so implementors may
/// borrow their operands from it.
trait SpecialForm<'a>: Sized {
    /// Recognizes `sexp` as an instance of this form, returning `None` if it is not one.
    fn from_sexp(sexp: &'a Sexp) -> Option<Self>;

    /// Consumes the form and produces the expression that replaces it.
    fn eval(self) -> Sexp;
}

static LET: Lazy<Symbol> = Lazy::new(|| TABLE.lock().unwrap().intern("let").unwrap());
/// The `(let name value body)` special form: evaluates to `body` with every occurrence of
/// the atom `name` replaced by `value`.
#[derive(Debug, Clone)]
struct Let<'src> {
    /// The atom being bound.
    from: Symbol,
    /// The expression substituted for `from`.
    to: &'src Sexp,
    /// The expression the substitution is applied to.
    body: &'src Sexp,
}

impl<'src> SpecialForm<'src> for Let<'src> {
    fn from_sexp(sexp: &'src Sexp) -> Option<Self> {
        if let [Sexp::Atom(head), Sexp::Atom(from), to, body] = sexp.list()? {
            if *head == *LET {
                return Some(Self {
                    from: *from,
                    to,
                    body,
                });
            }
        }
        None
    }

    fn eval(self) -> Sexp {
        let substitution = Rule::var_replace(self.from, self.to);
        rw(&substitution, self.body)
    }
}

static EQ: Lazy<Symbol> = Lazy::new(|| TABLE.lock().unwrap().intern("eq").unwrap());
/// The `(eq a b then else)` special form: evaluates to `then` if `a` and `b` are
/// structurally equal, and to `else` otherwise.
#[derive(Debug, Clone)]
struct Eq<'src> {
    a: &'src Sexp,
    b: &'src Sexp,
    then: &'src Sexp,
    els: &'src Sexp,
}

impl<'src> SpecialForm<'src> for Eq<'src> {
    fn from_sexp(sexp: &'src Sexp) -> Option<Self> {
        match sexp.list()? {
            [Sexp::Atom(head), a, b, then, els] if *head == *EQ => {
                Some(Self { a, b, then, els })
            }
            _ => None,
        }
    }

    fn eval(self) -> Sexp {
        let branch = if self.a == self.b { self.then } else { self.els };
        branch.clone()
    }
}

static GENSLOP: Lazy<Symbol> = Lazy::new(|| TABLE.lock().unwrap().intern("genslop").unwrap());

/// The `genslop` special form: every bare `genslop` atom is replaced by a newly interned
/// symbol with a random name.
#[derive(Debug, Clone)]
struct Genslop;

impl Genslop {
    /// Builds a random name: `slop-` followed by eight random alphanumeric characters.
    fn gen() -> String {
        let suffix: String = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(8)
            .map(char::from)
            .collect();
        format!("slop-{}", suffix)
    }
}

impl SpecialForm<'_> for Genslop {
    fn from_sexp(sexp: &Sexp) -> Option<Self> {
        sexp.atom()
            .filter(|sym| **sym == *GENSLOP)
            .map(|_| Genslop)
    }

    fn eval(self) -> Sexp {
        let name = Self::gen();
        let mut table = TABLE.lock().unwrap();
        Sexp::Atom(table.intern(name).unwrap())
    }
}

static LOG: Lazy<Symbol> = Lazy::new(|| TABLE.lock().unwrap().intern("log").unwrap());

/// The `(log x)` special form: prints `x` to stderr, prefixed with `LOG: `, and evaluates
/// to `x`.
#[derive(Debug, Clone)]
struct Log<'src>(&'src Sexp);

impl<'src> SpecialForm<'src> for Log<'src> {
    fn from_sexp(sexp: &'src Sexp) -> Option<Self> {
        match sexp.list()? {
            [head, arg] if head.atom() == Some(&*LOG) => Some(Self(arg)),
            _ => None,
        }
    }

    fn eval(self) -> Sexp {
        let Log(inner) = self;
        eprintln!("LOG: {}", inner);
        inner.clone()
    }
}

fn make_rule(sexp: &Sexp) -> Option<Rule> {
    match sexp {
        Sexp::List(xs) => match &xs[..] {
            [Sexp::Atom(forall), Sexp::List(vars), lhs, rhs]
                if *forall == FORALL.to_owned() && vars.iter().all(Sexp::is_atom) =>
            {
                Some(Rule::forall(
                    vars.iter().map(Sexp::unwrap_atom).collect(),
                    lhs.clone(),
                    rhs.clone(),
                ))
            }
            [Sexp::Atom(concrete), lhs, rhs] if *concrete == CONCRETE.to_owned() => {
                Some(Rule::concrete(lhs.clone(), rhs.clone()))
            }
            _ => None,
        },
        Sexp::Atom(_) => None,
    }
}

fn run_program(src: &str, step_limit: usize, complexity_limit: usize) -> Option<Sexp> {
    let mut parsed = parse(&lex(src))?;

    let expr = parsed.pop()?;
    let mut rules = parsed.iter().filter_map(make_rule).collect_vec();
    rules.sort_by_key(Rule::num_vars);

    for (i, rule) in rules.iter().enumerate() {
        eprintln!("{}{rule}", if i % 2 == 0 { "\x1b[0m" } else { "\x1b[1m" });
    }
    //eprintln!("{:<50}\n{:<50}\n{:^50}", "|", "▾", expr.to_string());

    let (simped, steps) = simp(&rules, expr, step_limit, complexity_limit);
    eprintln!("\x1b[0m");
    eprintln!("Complexity: {}", simped.complexity());
    eprintln!("Steps:      {}", steps + 1);
    Some(simped)
}

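// Example program: rules for `+` on successor numerals and for `=`, followed by the
// expression to simplify (always the last top-level expression):
// (forall (a b) (+ a (succ b)) (succ (+ a b))) (forall (a) (+ a 0) a) (forall (a) (= a a) true) (= (+ (succ 0) (succ 0)) (succ (succ 0)))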