package machine;

import java.util.*;
import java.util.List;

import metalexer.ast.*;

/**
 * Builds an epsilon-NFA ({@link ENFA}) from a list of meta-patterns using a
 * Thompson-style construction: every meta-pattern becomes a fragment with a
 * single start state and a single end state, and fragments are glued together
 * with epsilon edges.
 *
 * During construction the machine is kept in two parallel, state-indexed lists:
 * <ul>
 *   <li>{@code transitions.get(state)[symbol]} - the destination state of the
 *       (at most one) symbol edge, or {@code null} if there is none;</li>
 *   <li>{@code epsilons.get(state)} - the set of states reachable by a single
 *       epsilon edge.</li>
 * </ul>
 */
public class Machine {
    /** Symbol number used for the beginning-of-file marker. */
    private static final int BOF_SYM_NUM = 0;

    /**
     * Builds the complete epsilon-NFA for a list of (meta-pattern, action) pairs.
     *
     * The two lists are really a single list of pairs: {@code actions.get(i)} is
     * attached to the end state of the fragment built for {@code mpats.get(i)}.
     * Order matters: fragments are created, and their states numbered, in list order.
     *
     * The global start state (state 0) has a self-loop on every symbol and an
     * epsilon edge to the start state of every fragment.
     *
     * @param mpats      the meta-patterns, one fragment is built for each
     * @param actions    the action for each meta-pattern, parallel to {@code mpats}
     * @param numSymbols size of the symbol alphabet; symbols are {@code 0..numSymbols-1}
     * @return the assembled epsilon-NFA
     * @throws IllegalArgumentException if the two lists differ in length
     */
    public static ENFA buildENFA(List<MetaPattern> mpats, List<Integer> actions, int numSymbols) {
        if(mpats.size() != actions.size()) {
            throw new IllegalArgumentException("Expected one action per meta-pattern, but got "
                    + mpats.size() + " meta-patterns and " + actions.size() + " actions");
        }

        List<SubMachine> machines = new ArrayList<SubMachine>(mpats.size());
        List<Integer[]> transitions = new ArrayList<Integer[]>();
        List<Set<Integer>> epsilons = new ArrayList<Set<Integer>>();

        int start = addState(transitions, epsilons, numSymbols);
        Integer[] startTransitions = transitions.get(start);
        for(int sym = 0; sym < numSymbols; sym++) {
            startTransitions[sym] = start; //self-loop on all symbols
        }

        Set<Integer> startEpsilons = epsilons.get(start);
        for(MetaPattern mpat : mpats) {
            SubMachine sub = buildSubMachine(mpat, transitions, epsilons, numSymbols);
            startEpsilons.add(sub.startState);
            machines.add(sub);
        }

        // The caller's actions are indexed by meta-pattern, but the ENFA needs them
        // indexed by state: only the end state of each fragment carries an action.
        final int numStates = transitions.size();
        List<Integer> stateActions = new ArrayList<Integer>(numStates);
        for(int state = 0; state < numStates; state++) {
            stateActions.add(null);
        }
        for(int i = 0; i < machines.size(); i++) {
            stateActions.set(machines.get(i).endState, actions.get(i));
        }

        return buildENFA(transitions, epsilons, stateActions, numSymbols);
    }

    /**
     * Converts the list-based construction tables into the array/BitSet form
     * expected by the {@link ENFA} constructor.
     *
     * @param mTransitions state-indexed symbol edges ({@code null} entry = no edge)
     * @param mEpsilons    state-indexed epsilon successor sets
     * @param mActions     state-indexed actions ({@code null} entry = no action);
     *                     must have one entry per state
     * @param numSymbols   size of the symbol alphabet
     * @return the epsilon-NFA built from the converted tables
     */
    private static ENFA buildENFA(List<Integer[]> mTransitions, List<Set<Integer>> mEpsilons, List<Integer> mActions, int numSymbols) {
        final int numStates = mTransitions.size();

        BitSet[/*state*/][/*symbol*/] transitions = new BitSet[numStates][numSymbols];
        BitSet[/*state*/] epsilons = new BitSet[numStates];
        Integer[/*state*/] actions = new Integer[numStates];

        for(int state = 0; state < numStates; state++) {
            Integer[] row = mTransitions.get(state);
            for(int sym = 0; sym < numSymbols; sym++) {
                BitSet destinations = new BitSet(numStates);
                Integer destination = row[sym];
                if(destination != null) {
                    destinations.set(destination);
                }
                transitions[state][sym] = destinations;
            }

            BitSet epsilonSuccessors = new BitSet(numStates);
            for(Integer successor : mEpsilons.get(state)) {
                epsilonSuccessors.set(successor);
            }
            epsilons[state] = epsilonSuccessors;

            actions[state] = mActions.get(state);
        }

        return new ENFA(transitions, epsilons, actions);
    }

    /**
     * Recursively builds the fragment for one meta-pattern, appending the new
     * states to {@code transitions} and {@code epsilons}.
     *
     * Every fragment has exactly one start and one end state; the end state of
     * a freshly built fragment has no outgoing edges, so callers may freely add
     * epsilon edges leaving it.
     *
     * @param mpat        the meta-pattern to translate
     * @param transitions state-indexed symbol edges, extended in place
     * @param epsilons    state-indexed epsilon successor sets, extended in place
     * @param numSymbols  size of the symbol alphabet
     * @return the start and end state of the new fragment
     */
    //TODO: split into separate methods using aspect
    public static SubMachine buildSubMachine(MetaPattern mpat, List<Integer[]> transitions, List<Set<Integer>> epsilons, int numSymbols) {
        int start = addState(transitions, epsilons, numSymbols);
        int end = addState(transitions, epsilons, numSymbols);

        if(mpat instanceof AltMetaPattern) {
            AltMetaPattern alt = (AltMetaPattern) mpat;
            SubMachine sub1 = buildSubMachine(alt.getAlt1(), transitions, epsilons, numSymbols);
            SubMachine sub2 = buildSubMachine(alt.getAlt2(), transitions, epsilons, numSymbols);

            // either branch may be taken
            addEpsilon(epsilons, start, sub1.startState);
            addEpsilon(epsilons, start, sub2.startState);
            addEpsilon(epsilons, sub1.endState, end);
            addEpsilon(epsilons, sub2.endState, end);
        } else if(mpat instanceof SeqMetaPattern) {
            SeqMetaPattern seq = (SeqMetaPattern) mpat;
            SubMachine sub1 = buildSubMachine(seq.getSeq1(), transitions, epsilons, numSymbols);
            SubMachine sub2 = buildSubMachine(seq.getSeq2(), transitions, epsilons, numSymbols);

            // start -> sub1 -> sub2 -> end; only the END of sub2 completes the sequence
            addEpsilon(epsilons, start, sub1.startState);
            addEpsilon(epsilons, sub1.endState, sub2.startState);
            addEpsilon(epsilons, sub2.endState, end);
        } else if(mpat instanceof PlusMetaPattern) {
            SubMachine sub = buildSubMachine(((PlusMetaPattern) mpat).getMetaPattern(), transitions, epsilons, numSymbols);

            // one or more: the only way to reach end is through a full pass over sub
            addEpsilon(epsilons, start, sub.startState);
            addEpsilon(epsilons, sub.endState, end);
            addEpsilon(epsilons, sub.endState, sub.startState); //repeat
        } else if(mpat instanceof StarMetaPattern) {
            SubMachine sub = buildSubMachine(((StarMetaPattern) mpat).getMetaPattern(), transitions, epsilons, numSymbols);

            // zero or more: like plus, with an extra bypass edge
            addEpsilon(epsilons, start, sub.startState);
            addEpsilon(epsilons, sub.endState, end);
            addEpsilon(epsilons, sub.endState, sub.startState); //repeat
            addEpsilon(epsilons, start, end); //skip entirely
        } else if(mpat instanceof OptMetaPattern) {
            SubMachine sub = buildSubMachine(((OptMetaPattern) mpat).getMetaPattern(), transitions, epsilons, numSymbols);

            // zero or one: a single pass over sub, or the bypass edge
            addEpsilon(epsilons, start, sub.startState);
            addEpsilon(epsilons, sub.endState, end);
            addEpsilon(epsilons, start, end); //skip entirely

            //// Base Cases (i.e. actual transitions) //////////////////////////
        } else if(mpat instanceof ClassMetaPattern) {
            Integer[] startStateRow = transitions.get(start);

            ClassMetaPattern cmp = (ClassMetaPattern) mpat;
            if(cmp.getNegated()) {
                Set<Integer> negated = new HashSet<Integer>();
                for(BaseMetaPattern bmp : cmp.getBaseMetaPatterns()) {
                    negated.add(getValue(bmp));
                }
                // NOTE(review): unlike AnyMetaPattern, a negated class does not
                // exclude BOF_SYM_NUM - confirm that this is intended.
                for(int sym = 0; sym < numSymbols; sym++) {
                    if(!negated.contains(sym)) {
                        startStateRow[sym] = end; //edge from start to end
                    }
                }
            } else {
                for(BaseMetaPattern bmp : cmp.getBaseMetaPatterns()) {
                    startStateRow[getValue(bmp)] = end; //edge from start to end
                }
            }
        } else if(mpat instanceof AnyMetaPattern) {
            // any symbol except the beginning-of-file marker
            Integer[] startStateRow = transitions.get(start);
            for(int sym = 0; sym < numSymbols; sym++) {
                if(sym != BOF_SYM_NUM) {
                    startStateRow[sym] = end; //edge from start to end
                }
            }
        } else if(mpat instanceof BOFMetaPattern) {
            transitions.get(start)[BOF_SYM_NUM] = end; //edge from start to end
        } else if(mpat instanceof BaseMetaPattern) {
            int sym = getValue((BaseMetaPattern) mpat);
            transitions.get(start)[sym] = end; //edge from start to end
        }
        // NOTE(review): any other MetaPattern subtype falls through and yields a
        // fragment whose end state is unreachable (i.e. it never matches).
        return new SubMachine(start, end);
    }

    /**
     * Returns the symbol number that a base meta-pattern stands for.
     *
     * @param bmp a {@code SymMetaPattern} or {@code RegionMetaPattern}
     * @return the symbol value of the symbol, or the region symbol value of the
     *         component declaring the region
     * @throws IllegalArgumentException if {@code bmp} is of any other type
     */
    public static int getValue(BaseMetaPattern bmp) {
        if(bmp instanceof SymMetaPattern) {
            SymMetaPattern smp = (SymMetaPattern) bmp;
            return smp.getSym().getSymbolValue();
        } else if(bmp instanceof RegionMetaPattern) {
            RegionMetaPattern rmp = (RegionMetaPattern) bmp;
            return rmp.getRegion().getDecl().getComponent().getRegionSymbolValue();
        } else {
            throw new IllegalArgumentException("Unsupported base meta-pattern: "
                    + (bmp == null ? "null" : bmp.getClass().getName()));
        }
    }

    /**
     * Appends a fresh state with no symbol edges and no epsilon edges.
     *
     * @return the number (index) of the new state
     */
    private static int addState(List<Integer[]> transitions, List<Set<Integer>> epsilons, int numSymbols) {
        int stateNum = transitions.size();
        transitions.add(new Integer[numSymbols]);
        epsilons.add(new HashSet<Integer>());
        return stateNum;
    }

    /** Adds an epsilon edge from state {@code from} to state {@code to}. */
    private static void addEpsilon(List<Set<Integer>> epsilons, int from, int to) {
        epsilons.get(from).add(to);
    }

    /** The start and end state numbers of one fragment of the machine. */
    private static class SubMachine {
        final int startState;
        final int endState;
        public SubMachine(int startState, int endState) {
            this.startState = startState;
            this.endState = endState;
        }
    }
}
