#!/usr/bin/python3

# see apertium-editdist --help for usage

import sys
import struct
import argparse

# One-line summary shown at the top of the --help output (passed to
# argparse as the parser description in main()).
description_string = "Produce an edit distance transducer in ATT format."

# Long-form help printed after the option list (passed to argparse as the
# parser epilog in main()). It documents the specification-file syntax
# accepted by Transducer.read_input_file() and the expected size of the
# generated transducer. The text is emitted verbatim because main() selects
# argparse.RawDescriptionHelpFormatter, so the line breaks here are the line
# breaks the user sees.
epilog_string = """
For the default case, all the desired transitions are generated with weight 1.0.

The specification file should be in the following format:
* First, an (optional) list of tokens separated by newlines
  All transitions involving these tokens that are otherwise unspecified
  are generated with weight 1.0. Symbol weights can be specified by appending
  a tab and a weight to the symbol. Transitions involving such a symbol
  will have the user-specified weight added to it.
* If you want to exclude symbols that may be induced from a transducer,
  add a leading ~ character to that line.
* If you want to specify transitions, insert a line with the content "@@"
  (without the quotes)
* In the following lines, specified transitions with the form
  FROM <TAB> TO <TAB> WEIGHT
  where FROM is the source token, TO is the destination token and WEIGHT is
  a nonnegative floating point number specifying the weight. By default,
  if only one transition involving FROM and TO is specified, the same WEIGHT
  will be used to specify the transition TO -> FROM (assuming that both are
  listed in the list of tokens).
* If the command line option to generate swaps is set, you can also specify swap
  weights with
  FROM,TO <TAB> TO,FROM <TAB> WEIGHT
  Again, unspecified swaps will be generated automatically with weight 1.0.
* Lines starting with ## are comments.

with d for distance and S for size of alphabet plus one
(for epsilon), expected output is a transducer in ATT format with
* Swapless:
** d + 1 states
** d*(S^2 + S - 1) transitions
* Swapful:
** d*(S^2 - 3S + 3) + 1 states
** d*(3S^2 - 5S + 3) transitions
"""

def maketrans(from_st, to_st, from_sy, to_sy, weight):
    """Return one transition as a tab-separated ATT line.

    The five fields are, in order: source state, target state, input
    symbol, output symbol and weight. Each field is rendered with its
    default formatting and no trailing newline is added.
    """
    columns = (from_st, to_st, from_sy, to_sy, weight)
    # format(x) with no spec is what an empty "{}" replacement field uses.
    return "\t".join(map(format, columns))

class Transducer:
    """Builder for an edit-distance transducer written in ATT text format.

    Typical use: construct, feed the alphabet through any combination of
    read_input_file(), extend_alphabet() and
    read_optimized_lookup_alphabet(), call clean_alphabet(), then
    generate() and make_transitions(). The resulting lines are collected
    in ``self.transitions``.

    States 0..distance are the "main" states (state k is reached after k
    edits). Numbers above ``distance`` are handed out one at a time as
    auxiliary states: the intermediate state of a two-step swap
    (``swapstate``) and the state entered after an insertion/deletion when
    redundancy elimination is active (``skipstate``).
    """

    def __init__(self, dist, other, epsilon, swap, elim):
        """Create an empty builder.

        dist    -- maximum edit distance (number of edit layers)
        other   -- symbol standing for "any unknown symbol"
        epsilon -- symbol standing for the empty string
        swap    -- if true, make_swaps() emits transposition transitions
        elim    -- if true, make_substitutions() avoids insertion directly
                   followed by deletion (and vice versa)
        """
        self.distance = dist
        # symbol -> extra weight added to default edits involving the symbol
        self.alphabet = {}
        # symbols that must never end up in the alphabet
        self.exclusions = set()
        # (from_symbol, to_symbol) -> weight
        self.substitutions = {}
        # ((a, b), (b, a)) -> weight
        self.swaps = {}
        self.should_swap = swap
        self.should_elim = elim
        self.other = other
        self.epsilon = epsilon
        # Next unused auxiliary state number for each purpose. Both values
        # are always distinct and above every state number already emitted.
        self.swapstate = self.distance + 1
        self.skipstate = self.swapstate + 1
        # Output lines, in emission order.
        self.transitions = []

    def read_input_file(self, fname):
        """Read a specification file (syntax described in the --help epilog).

        Before the "@@" line, each line adds an alphabet symbol (optionally
        followed by a tab and a weight) or, with a leading "~", an
        exclusion. After it, each line is a FROM<TAB>TO<TAB>WEIGHT
        substitution or swap. Entries already present are not overwritten
        by later duplicates of a substitution/swap.

        On a syntax error a message is written to stderr and the process
        exits with status 1.
        """
        def parse_error(err, line):
            # Arguments are ordered to match the template: line number
            # first, then the file name.
            sys.stderr.write('Syntax error on line %s of %s: %s.\n' % (line, fname, err))
            sys.exit(1)

        # Symbols are decoded as UTF-8, matching
        # read_optimized_lookup_alphabet().
        with open(fname, encoding='utf-8') as fin:
            in_alpha = True
            for i, line_ in enumerate(fin, 1):
                line = line_.strip()
                if not line or line.startswith('##'):
                    continue
                if in_alpha:
                    if line == '@@':
                        in_alpha = False
                    elif len(line) > 1 and line[0] == '~':
                        self.exclusions.add(line[1:].strip())
                    elif '\t' in line:
                        if line.count('\t') > 1:
                            parse_error('Too many tabs', i)
                        symbol, weight = line.split('\t')
                        try:
                            self.alphabet[symbol] = float(weight)
                        except ValueError:
                            parse_error('Unable to parse weight', i)
                    else:
                        self.alphabet[line] = 0.0
                    continue

                tabs = line.count('\t')
                if tabs > 2 or tabs == 1:
                    parse_error('Wrong number of tabs, expected 3 tab-separated columns', i)
                elif tabs == 0:
                    parse_error('Substitutions and swaps must be tab-separated', i)
                l, r, w = line.split('\t')
                weight = 0.0
                try:
                    weight = float(w)
                except ValueError:
                    parse_error('Unable to parse weight', i)
                if ',' in line:
                    frompair = tuple(l.split(','))
                    topair = tuple(r.split(','))
                    if not (len(frompair) == len(topair) == 2):
                        parse_error('Swap-specification has wrong number of comma separators', i)
                    self.swaps.setdefault((frompair, topair), weight)
                else:
                    self.substitutions.setdefault((l, r), weight)

    def read_optimized_lookup_alphabet(self, fname):
        """Add the single-character symbols of an optimized-lookup file.

        An optional "HFST\\0" header is skipped, then a 56-byte block is
        read whose little-endian unsigned 16-bit value at offset 2 is taken
        as the symbol count. That many NUL-terminated UTF-8 strings follow.
        Symbols that are not exactly one character long are reported on
        stderr and ignored; whitespace and excluded symbols are skipped
        silently. Existing alphabet weights are kept.

        Raises ValueError if the file ends in the middle of the symbol
        table.
        """
        with open(fname, "rb") as fin:
            byt = fin.read(5)
            if byt == b"HFST\0":
                # just ignore any hfst3 header
                header_length = struct.unpack_from("<H", fin.read(3), 0)[0]
                fin.read(header_length)
                # hopefully there's nothing surprising in here
                byt = fin.read(56)
            else:
                byt += fin.read(56 - 5)
            symbol_count = struct.unpack_from("<H", byt, 2)[0]
            for _ in range(symbol_count):
                raw = bytearray()
                while True:
                    ch = fin.read(1)
                    if not ch:
                        # read() returns b'' at EOF; without this check the
                        # loop would never see a terminator and never end.
                        raise ValueError(
                            "Unexpected end of file in symbol table of " + fname)
                    if ch == b"\0":
                        break
                    raw += ch
                sym = raw.decode('utf-8')
                if len(sym) != 1:
                    sys.stderr.write("Ignored symbol " + sym + "\n")
                elif not sym.isspace() and sym not in self.exclusions:
                    self.alphabet.setdefault(sym, 0.0)

    def extend_alphabet(self, alpha):
        """Add every non-excluded element of ``alpha`` with extra weight 0.0.

        ``alpha`` is iterated item by item (a string contributes its
        characters). Symbols already in the alphabet keep their weight.
        """
        for c in alpha:
            if c not in self.exclusions:
                self.alphabet.setdefault(c, 0.0)

    def clean_alphabet(self):
        """Remove every excluded symbol from the alphabet.

        Depending on the order in which read_input_file(),
        extend_alphabet() and read_optimized_lookup_alphabet() were called,
        a symbol may have been added before its exclusion was known.
        """
        for s in self.exclusions:
            self.alphabet.pop(s, None)

    def generate(self):
        """Fill in default weights for edits the user did not specify.

        Insertions, deletions and unknown-symbol replacements involving a
        symbol get weight 1.0 plus that symbol's extra weight. A
        substitution or swap between two symbols reuses the weight of the
        reverse direction when that is already known; otherwise it gets
        1.0 plus the extra weights of exactly those two symbols.
        """
        self.substitutions.setdefault((self.other, self.epsilon), 1.0)
        for symbol, symbol_weight in self.alphabet.items():
            w = 1.0 + symbol_weight
            self.substitutions.setdefault((self.other, symbol), w)
            self.substitutions.setdefault((self.epsilon, symbol), w)
            self.substitutions.setdefault((symbol, self.epsilon), w)
            for symbol2, symbol2_weight in self.alphabet.items():
                if symbol == symbol2:
                    continue
                # Computed fresh for every pair so that the weights of
                # previously visited symbols do not leak into this one.
                pair_weight = w + symbol2_weight
                p12 = (symbol, symbol2)
                p21 = (symbol2, symbol)
                self.swaps.setdefault((p12, p21), self.swaps.get((p21, p12), pair_weight))
                self.substitutions.setdefault(p12, self.substitutions.get(p21, pair_weight))

    def next_special(self, state):
        """Reserve a fresh auxiliary state number.

        ``state`` is "swap" or "skip" and selects which of ``swapstate`` /
        ``skipstate`` is replaced. The new number is one above both current
        values, so it can never coincide with a number handed out earlier
        for either purpose.

        Raises ValueError for any other ``state``.
        """
        if state not in ("swap", "skip"):
            raise ValueError("unknown special state kind: %r" % (state,))
        fresh = max(self.swapstate, self.skipstate) + 1
        if state == "swap":
            self.swapstate = fresh
        else:
            self.skipstate = fresh

    def make_identities(self, state, nextstate = None):
        """Return zero-weight x:x transitions for every alphabet symbol.

        The transitions go from ``state`` to ``nextstate`` (default: a
        self-loop on ``state``). The epsilon and unknown symbols are
        skipped.
        """
        if nextstate is None:
            nextstate = state
        ret = []
        for symbol in self.alphabet:
            if symbol not in (self.epsilon, self.other):
                ret.append(maketrans(state, nextstate, symbol, symbol, 0.0))
        return ret

    def make_swaps(self, state, nextstate = None):
        """Return the transitions realising every swap from ``state``.

        Each swap is two transitions through its own intermediate state:
        the first carries the swap weight, the second is free and lands in
        ``nextstate`` (default ``state + 1``). Returns an empty list when
        swaps are disabled.
        """
        if nextstate is None:
            nextstate = state + 1
        ret = []
        if self.should_swap:
            for swap in self.swaps:
                ret.append(maketrans(state, self.swapstate, swap[0][0], swap[0][1], self.swaps[swap]))
                ret.append(maketrans(self.swapstate, nextstate, swap[1][0], swap[1][1], 0.0))
                self.next_special("swap")
        return ret

    def _make_guarded_edit(self, state, nextstate, sub, blocked_side):
        """Emit an insertion/deletion ``sub`` that goes through a skip state.

        From the skip state the path may continue with an identity (to
        ``nextstate``), a swap, or any substitution whose symbol at index
        ``blocked_side`` is not epsilon (both to ``nextstate + 1``).
        """
        ret = [maketrans(state, self.skipstate, sub[0], sub[1], self.substitutions[sub])]
        ret += self.make_identities(self.skipstate, nextstate)
        ret += self.make_swaps(self.skipstate, nextstate + 1)
        for sub2 in self.substitutions:
            if sub2[blocked_side] != self.epsilon:
                ret.append(maketrans(self.skipstate, nextstate + 1, sub2[0], sub2[1], self.substitutions[sub2]))
        self.next_special("skip")
        return ret

    # for substitutions, we try to eliminate redundancies by refusing to do
    # deletion right after insertion and insertion right after deletion
    def make_substitutions(self, state, nextstate = None):
        """Return the substitution transitions leaving ``state``.

        Without elimination (or when ``nextstate + 1 >= distance``) every
        substitution is a single transition to ``nextstate`` (default
        ``state + 1``). Otherwise insertions and deletions are routed
        through a skip state from which the opposite operation is not
        offered.
        """
        if nextstate is None:
            nextstate = state + 1
        ret = []
        for sub in self.substitutions:
            if (nextstate + 1 >= self.distance) or not self.should_elim:
                ret.append(maketrans(state, nextstate, sub[0], sub[1], self.substitutions[sub]))
            # Compare by value: symbols read from a specification file are
            # equal to, but not the same object as, self.epsilon.
            elif sub[1] == self.epsilon: # deletion
                # after deletion, refuse to do insertion
                ret += self._make_guarded_edit(state, nextstate, sub, 0)
            elif sub[0] == self.epsilon: # insertion
                # after insertion, refuse to do deletion
                ret += self._make_guarded_edit(state, nextstate, sub, 1)
            else:
                ret.append(maketrans(state, nextstate, sub[0], sub[1], self.substitutions[sub]))
        return ret

    def make_transitions(self):
        """Append the complete transducer to ``self.transitions``.

        For each layer 0..distance-1 this emits the final-state line for
        the following state, the identity loops, the substitutions and the
        swaps; the last state only gets identity loops.
        """
        for state in range(self.distance):
            self.transitions.append(str(state + 1) + "\t0.0") # final states
            self.transitions += self.make_identities(state)
            self.transitions += self.make_substitutions(state)
            self.transitions += self.make_swaps(state)
        self.transitions += self.make_identities(self.distance)

    def verbose_report(self):
        """Write a summary of the generated transducer to stderr.

        Note that the ``trans`` figure is the number of output lines,
        which includes the final-state lines.
        """
        template = '''
{states} states and {trans} transitions written for
distance {dist} and base alphabet size {alpha_size}

The alphabet was:
{alpha}
The exclusions were:
{excl}
'''
        alpha = ' '.join('%s:%s' % (s, w) for s, w in self.alphabet.items())
        excl = ' '.join(self.exclusions)
        # State numbers start at 0 and the two current special numbers are
        # reserved but unused; every next_special() call raises the maximum
        # by exactly one, so the number of states in use is max - 1.
        states = max(self.skipstate, self.swapstate) - 1
        report = template.format(states=states,
                                 trans=len(self.transitions),
                                 dist=self.distance,
                                 alpha_size=len(self.alphabet),
                                 alpha=alpha,
                                 excl=excl)
        sys.stderr.write(report)

def main():
    """Command-line entry point.

    Parses the options, loads the alphabet from every source given
    (specification files first, then command-line strings, then
    optimized-lookup transducers), builds the transducer and prints one
    transition per line on stdout. With --verbose a summary is written to
    stderr afterwards.
    """
    parser = argparse.ArgumentParser(
        description=description_string,
        epilog=epilog_string,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        '-d', '--distance', type=int, metavar='DIST', default=1,
        help='edit depth, default is 1')
    parser.add_argument(
        '-e', '--epsilon', default='@0@', metavar='EPS',
        help='epsilon symbol, default is @0@')
    parser.add_argument(
        '--no-elim', action='store_true',
        help="don't reduce elimination")
    parser.add_argument(
        '-s', '--swap', action='store_true',
        help='generate swaps in addition to insertions and deletions')
    parser.add_argument(
        '-v', '--verbose', action='store_true',
        help='print summary to stderr')

    alphabet_group = parser.add_argument_group(
        'Alphabet',
        'Specify the alphabet of the transducer (these may be combined and repeated)')
    alphabet_group.add_argument(
        '-i', '--input', dest='inputfile', metavar='INPUT',
        action='append', default=[],
        help='specification file in edit-distance syntax')
    alphabet_group.add_argument(
        '-a', '--alphabet', dest='alphabetfile', metavar='TRANS',
        action='append', default=[],
        help='optimized-lookup format transducer to read alphabet from (ignoring multi-character symbols)')
    alphabet_group.add_argument(
        'alphabet_string', nargs='*',
        help='the alphabet as a string on the command line')

    args = parser.parse_args()

    builder = Transducer(
        args.distance,
        '@_UNKNOWN_SYMBOL_@',
        args.epsilon,
        args.swap,
        not args.no_elim,
    )

    # At least one alphabet source is mandatory.
    if not (args.inputfile or args.alphabet_string or args.alphabetfile):
        parser.error('Must provide an alphabet')

    # Each loader is applied to all of its sources before the next loader
    # runs; exclusions seen late are applied by clean_alphabet() below.
    loaders = (
        (builder.read_input_file, args.inputfile),
        (builder.extend_alphabet, args.alphabet_string),
        (builder.read_optimized_lookup_alphabet, args.alphabetfile),
    )
    for load, sources in loaders:
        for source in sources:
            load(source)
    builder.clean_alphabet()

    builder.generate()
    builder.make_transitions()
    for att_line in builder.transitions:
        print(att_line)

    if args.verbose:
        builder.verbose_report()

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
