#!/usr/bin/python3
#
# Copyright (C) The Phosh Developers
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
# Author: Guido Günther <agx@sigxcpu.org>
#
# The extraction is based on the MIT licensed
# https://github.com/mpoyraz/ngram-lm-wiki

import fileinput
import os
import sys
import re
import json
import argparse
import random
import subprocess
import sqlite3
from tqdm import tqdm
from multiprocessing import Pool
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.util import ngrams
from pathlib import Path
from collections import defaultdict

# Callable splitting a text into sentences; assigned in main() depending on
# the selected language
tokenize_fn = None
# Callable lowercasing a sentence; assigned in main()
lower_fn = None
# Predicate returning a true value for words whose n-grams must not be
# counted; assigned in main()
drop_word_fn = None
# Character class of the chars that are removed from the wiki sentences.
# NOTE: `,-.` inside the class is a range; it covers exactly `,`, `-` and `.`
chars_to_remove_regex = r"[#$%&()*+,-./:;<=>?@\[\]^_{|}~!\"\\]"
# Character class of apostrophe look-alikes that get unified to a plain "'"
apostrophes = "[’`´ʹʻʼʽʿˈ‘]"
# Largest n-gram order to build: one table per order 1..max_ngrams is created
max_ngrams = 3


def parse_sentences_from_wiki_json_file(fpath):
    """Return the cleaned up sentences of one extracted wiki file.

    The file is expected to hold one JSON object per line, each with a
    "text" member. Every text is split into sentences with ``tokenize_fn``
    and each sentence is lowercased with ``lower_fn``, stripped of the
    characters in ``chars_to_remove_regex``, gets its apostrophes unified
    and its whitespace collapsed.

    Args:
        fpath: Path of the JSON lines file to parse.

    Returns:
        A list of non-empty sentence strings.
    """
    sentences = []
    # Read as UTF-8 explicitly so the result does not depend on the
    # locale's preferred encoding
    with open(fpath, encoding="utf-8") as fp:
        for line in fp:
            line = line.strip()
            # Tolerate blank lines instead of failing in json.loads()
            if not line:
                continue
            text = json.loads(line)["text"]

            # Sentences from paragraphs
            for sent in tokenize_fn(text):
                # Lower the sentence
                sent = lower_fn(sent)
                # Remove pre-defined chars
                sent = re.sub(chars_to_remove_regex, "", sent)
                # Unify apostrophes
                sent = re.sub(apostrophes, "'", sent)
                # Collapse whitespace runs and drop leading/trailing
                # whitespace so that blank sentences are detected below
                sent = re.sub(r"\s+", " ", sent).strip()
                if sent:
                    sentences.append(sent)

    return sentences


def extract_wiki_dump(extract_dir, wiki_dump, n_procs):
    """Run the ``wikiextractor`` command on a wiki dump.

    Args:
        extract_dir: Directory passed to wikiextractor via ``-o``.
        wiki_dump: Path of the dump handed to wikiextractor.
        n_procs: Value passed to wikiextractor via ``--processes``.

    Raises:
        subprocess.CalledProcessError: If wikiextractor exits with a
            non-zero status.
    """
    command = ["wikiextractor", wiki_dump]
    command += ["-o", extract_dir]
    command += ["--no-templates", "--json"]
    command += ["--processes", str(n_procs)]
    subprocess.check_call(command)


def parse_sentences(sentence_file, extract_dir, n_files, n_procs):
    """Parse randomly picked extracted wiki files into a sentence file.

    ``n_files`` times a random two uppercase letter sub directory of
    ``extract_dir`` and a random ``wiki_??`` file therein is picked. The
    picked files are parsed by a pool of worker processes and the resulting
    sentences are written to ``sentence_file``, one sentence per line.

    NOTE(review): files are drawn with replacement so the same file can be
    picked (and its sentences written) more than once - confirm that this
    is intended.

    Args:
        sentence_file: Path of the file the sentences are written to. It is
            written as UTF-8.
        extract_dir: ``pathlib.Path`` of the directory with the extracted
            wiki data.
        n_files: Number of wiki files to pick.
        n_procs: Number of worker processes.
    """
    # Paths of the extracted wiki files
    subdirs = list(extract_dir.glob("[A-Z][A-Z]"))
    filepaths = []
    for _ in range(n_files):
        subdir = random.choice(subdirs)
        wiki_file = random.choice(list(subdir.glob("wiki_??")))
        print(wiki_file)
        filepaths.append(wiki_file)

    # Write UTF-8 explicitly: the sentence file is read back as UTF-8, so
    # relying on the locale's encoding could corrupt non-ASCII text
    with open(sentence_file, "w", encoding="utf-8") as out, Pool(n_procs) as pool:
        # Load each wiki file and parse sentences
        n_sentences = 0
        parsed = pool.imap(parse_sentences_from_wiki_json_file, filepaths)
        for sentences in tqdm(parsed, total=n_files):
            out.writelines(f"{sent}\n" for sent in sentences)
            n_sentences += len(sentences)
        print(f"Number of extracted sentences: {n_sentences}")


def build_where_clause(words):
    """Build the SQL ``WHERE`` clause that selects an n-gram.

    The last word is matched against the ``word`` column, the word before
    it against ``word_1``, the one before that against ``word_2`` and so
    on.

    Args:
        words: Non-empty sequence with the words of the n-gram, first word
            first.

    Returns:
        The clause as a string, e.g.
        ``WHERE word_1 = 'das' AND word = 'haus'``.

    Raises:
        IndexError: If ``words`` is empty.
    """
    if not words:
        raise IndexError("words must contain at least one word")

    def quote(word):
        # A single quote inside an SQL string literal is escaped by
        # doubling it; without this a word containing "'" would break out
        # of the literal
        return str(word).replace("'", "''")

    columns = [f"word_{i}" for i in range(len(words) - 1, 0, -1)] + ["word"]
    conditions = [f"{col} = '{quote(word)}'" for col, word in zip(columns, words)]
    return "WHERE " + " AND ".join(conditions)


def _ngram_columns(n):
    """Return the word columns of the ``n``-gram table, first word first.

    For ``n == 3`` this is ``["word_2", "word_1", "word"]``.
    """
    return [f"word_{i}" for i in range(n - 1, 0, -1)] + ["word"]


def _ngram_statements(n):
    """Return the parameterized ``(update, insert)`` SQL of the ``n``-gram table."""
    table = f"_{n}_gram"
    where = " AND ".join(f"{col} = ?" for col in _ngram_columns(n))
    placeholders = ", ".join("?" * (n + 1))
    update = f"UPDATE {table} SET count = count + ? WHERE {where}"
    insert = f"INSERT INTO {table} VALUES ({placeholders})"
    return update, insert


def build_ngrams(sentences, ngram_file, db_file):
    """Count the n-grams of the given sentences into a fresh SQLite database.

    For every order ``n`` in ``1..max_ngrams`` a table ``_<n>_gram`` with
    the word columns and a ``count`` column is created and filled.
    N-grams that contain a word for which ``drop_word_fn`` returns a true
    value are not counted. Afterwards a unique index is created per table
    and the database is optimized and vacuumed.

    Args:
        sentences: Iterable of sentence strings.
        ngram_file: Unused; kept so existing callers keep working.
        db_file: Path of the database. An already existing file is removed
            first.

    Raises:
        sqlite3.Error: If a statement fails; the failing statement is
            printed before the error is re-raised.
    """
    try:
        os.remove(db_file)
    except FileNotFoundError:
        pass
    print("Creating database:")
    con = sqlite3.connect(db_file)
    try:
        cur = con.cursor()
        orders = range(1, max_ngrams + 1)

        for n in orders:
            columns = _ngram_columns(n)
            cols = ", ".join(f"{col} TEXT" for col in columns)
            unique = ", ".join(reversed(columns))
            cur.execute(
                f"CREATE TABLE _{n}_gram({cols}, count INTEGER, UNIQUE({unique}))"
            )

        # Build the statements once instead of once per n-gram
        statements = {n: _ngram_statements(n) for n in orders}

        print("Filling database tables:")
        for i, sentence in enumerate(tqdm(sentences)):
            counts = defaultdict(int)
            tokens = word_tokenize(sentence)

            for n in orders:
                for n_gram in ngrams(tokens, n):
                    if not any(drop_word_fn(word) for word in n_gram):
                        counts[n_gram] += 1

            # Insert after each sentence to keep memory usage under control
            for key, count in counts.items():
                update, insert = statements[len(key)]
                # Add all occurrences found in this sentence to an existing
                # row; words are bound as parameters, never put into the SQL
                stmt, params = update, (count, *key)
                try:
                    cur.execute(stmt, params)
                    if cur.rowcount == 0:
                        # No row was updated: first time this n-gram is seen
                        stmt, params = insert, (*key, count)
                        cur.execute(stmt, params)
                except sqlite3.Error:
                    print(f"Statement failed: {stmt} {params}")
                    raise

            if i % 100000 == 0:
                con.commit()

        con.commit()

        # Create index
        for n in orders:
            word_cols = ", ".join(_ngram_columns(n))
            cur.execute(f"CREATE UNIQUE INDEX _{n}_index ON _{n}_gram({word_cols})")
        con.commit()

        cur.execute("pragma optimize")
        con.commit()

        # TODO: drop rare items from tables
        con.execute("VACUUM")
        con.commit()
    finally:
        # Release the database even when a statement failed
        con.close()


def main():
    """Build the n-gram database from a Wikipedia dump.

    Parses the command line, sets up the language dependent helper
    functions and then runs the extract, parse and n-gram steps unless the
    respective ``--skip-*`` option is given. All results are placed in the
    output directory, which is created if it does not exist.

    Returns:
        0, which the caller uses as exit status.
    """
    parser = argparse.ArgumentParser(
        description="Build ngram database from Wikipedia dumps"
    )
    parser.add_argument(
        "--dump", type=str, required=True, help="Path to a wikipedia dump"
    )
    parser.add_argument("--output", type=str, required=True, help="Output directory")
    parser.add_argument(
        "--language", type=str, default="de", help="Language of the wikipedia dump"
    )
    parser.add_argument(
        "--processes", type=int, default=8, help="Number of processes to use"
    )
    parser.add_argument(
        "--skip-extract",
        default=False,
        action="store_true",
        help="Skip extracting the wiki data",
    )
    parser.add_argument(
        "--files", type=int, default=10, help="Number of wiki files to use to build the DB"
    )
    parser.add_argument(
        "--skip-parse",
        default=False,
        action="store_true",
        help="Skip parsing the extracted wiki data into sentences",
    )
    parser.add_argument(
        "--skip-presage-ngrams",
        default=False,
        action="store_true",
        help="Skip building the n-grams of the parsed sentences for presage",
    )
    args = parser.parse_args()

    # NOTE(review): the pool workers started by parse_sentences() read these
    # module globals; presumably this relies on the workers inheriting them
    # from this process (fork start method) - confirm on other platforms
    global tokenize_fn, lower_fn, drop_word_fn
    # Defaults for all languages
    tokenize_fn = sent_tokenize
    lower_fn = lambda x: x.lower()
    drop_word_fn = lambda x: "'" in x

    if args.language in ["de"]:
        tokenize_fn = lambda x: sent_tokenize(x, language="german")

    output_path = Path(args.output)
    # Make sure the sentence file and the database can be created
    output_path.mkdir(parents=True, exist_ok=True)
    extract_dir = output_path / "extract"
    sentence_file = output_path / "sentences.txt"
    ngram_file = output_path / f"n-gram-{args.language}.txt"
    db_file = output_path / f"database_{args.language}.db"

    if not args.skip_extract:
        print("Extracting Wiki source")
        extract_wiki_dump(extract_dir, args.dump, args.processes)

    if not args.skip_parse:
        print("Parsing sentences")
        parse_sentences(sentence_file, extract_dir, args.files, args.processes)

    if not args.skip_presage_ngrams:
        print("Building N-grams")
        # Only open the sentence file when it is needed and close it again
        # once the database got built
        with open(sentence_file, encoding="utf-8") as sentences:
            build_ngrams(sentences, ngram_file, db_file)

    return 0


# Only run when executed as a script; main()'s return value becomes the
# exit status
if __name__ == "__main__":
    sys.exit(main())
