/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasIndex;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.naturalli.ClauseSplitter;
import edu.stanford.nlp.naturalli.NaturalLogicAnnotator;
import edu.stanford.nlp.naturalli.RelationTripleSegmenter;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.process.TSVSentenceProcessor;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher;
import edu.stanford.nlp.semgraph.semgrex.SemgrexPattern;
import edu.stanford.nlp.trees.PennTreeReader;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalStructureFactory;
import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.ArrayCoreMap;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class CreateClauseDataset
implements TSVSentenceProcessor {
    /** Redwood logging channel for this class. */
    private static final Redwood.RedwoodChannels log = Redwood.channels(CreateClauseDataset.class);
    /**
     * Command-line option "in"; defaults to standard input.
     * NOTE(review): this field is not read anywhere in this file — confirm whether it is still needed.
     */
    @ArgumentParser.Option(name="in", gloss="The input to read from")
    private static InputStream in = System.in;
    /**
     * Matches node labels of the form {@code NP-<anything>-<digits>}; group 2 is the numeric
     * trace index (see {@code findTraceTargets}).
     */
    private static final Pattern TRACE_TARGET_PATTERN = Pattern.compile("(NP-.*)-([0-9]+)");
    /**
     * Matches node labels ending in {@code *-<digits>}; group 1 is the numeric trace index
     * (see {@code findTraceSources}).
     */
    private static final Pattern TRACE_SOURCE_PATTERN = Pattern.compile(".*\\*-([0-9]+)");
    /** Factory used by {@code parse} to turn a constituency tree into typed dependencies. */
    private static final UniversalEnglishGrammaticalStructureFactory parser = new UniversalEnglishGrammaticalStructureFactory();
    /** Supplies the chunking routine, the valid-arc sets and the Semgrex patterns used in this class. */
    private static final RelationTripleSegmenter segmenter = new RelationTripleSegmenter();
    /** Annotator applied to every sentence built in {@code processDirectory}. */
    private static final NaturalLogicAnnotator natlog = new NaturalLogicAnnotator();

    /** Private no-op constructor; only this class can create instances. */
    private CreateClauseDataset() {
    }

    /**
     * Builds the span that covers every element of a chunk.
     *
     * <p>The span starts at the smallest {@code index() - 1} found in the chunk and ends at the
     * largest {@code index()} found in the chunk.
     *
     * @param chunk the indexed elements to cover
     * @return a {@link Span} constructed from the computed start and end
     */
    private static Span toSpan(List<? extends HasIndex> chunk) {
        int start = Integer.MAX_VALUE;
        int end = -1;
        for (HasIndex element : chunk) {
            int index = element.index();
            int zeroBased = index - 1;
            if (zeroBased < start) {
                start = zeroBased;
            }
            if (index > end) {
                end = index;
            }
        }
        // Sanity checks: these only fire when the JVM runs with assertions enabled.
        assert (start >= 0);
        assert (end < Integer.MAX_VALUE && end > 0);
        return new Span(start, end);
    }

    /**
     * Logs the text of the first sentence of {@code doc} and collects the spans of its
     * non-overlapping subject chunks.
     *
     * <p>Nodes are visited in topological order. A node is a candidate when its tag starts with
     * {@code "N"} or equals {@code "PRP"} and the segmenter returns a chunk for it. A chunk is
     * dropped if any of its tokens already belongs to a previously accepted chunk.
     *
     * <p>The collected spans are local to this method and are not returned or stored.
     *
     * @param id the sentence id (unused)
     * @param doc the annotation whose first sentence is examined
     */
    @Override
    public void process(long id, Annotation doc) {
        CoreMap sentence = doc.get(CoreAnnotations.SentencesAnnotation.class).get(0);
        SemanticGraph depparse = sentence.get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
        log.info("| " + sentence.get(CoreAnnotations.TextAnnotation.class));

        BitSet consumedAsSubjects = new BitSet();
        List<Span> subjectSpans = new ArrayList<>();
        for (IndexedWord head : depparse.topologicalSort()) {
            String tag = head.tag();
            boolean nominal = tag.startsWith("N") || tag.equals("PRP");
            if (!nominal) {
                continue;
            }
            Optional<List<IndexedWord>> subjectChunk = segmenter.getValidChunk(depparse, head, segmenter.VALID_SUBJECT_ARCS, Optional.empty(), true);
            if (!subjectChunk.isPresent()) {
                continue;
            }
            List<IndexedWord> chunkTokens = subjectChunk.get();
            // Skip chunks that overlap a subject accepted earlier in the traversal.
            boolean alreadyConsumed = chunkTokens.stream().anyMatch(tok -> consumedAsSubjects.get(tok.index()));
            if (alreadyConsumed) {
                continue;
            }
            chunkTokens.forEach(tok -> consumedAsSubjects.set(tok.index()));
            subjectSpans.add(toSpan(chunkTokens));
        }
    }

    /**
     * Converts a constituency tree into a dependency graph built from the collapsed typed
     * dependencies produced by the shared grammatical-structure factory.
     *
     * @param tree the constituency tree to convert
     * @return a new {@link SemanticGraph} over the tree's collapsed typed dependencies
     */
    private static SemanticGraph parse(Tree tree) {
        edu.stanford.nlp.trees.GrammaticalStructure structure = parser.newGrammaticalStructure(tree);
        return new SemanticGraph(structure.typedDependenciesCollapsed());
    }

    /**
     * Collects (subject span, object span) pairs from a dependency parse.
     *
     * <p>Two passes are made:
     * <ol>
     *   <li>For every match of the segmenter's {@code VP_PATTERNS} where neither the verb nor the
     *       object has an outgoing relation containing {@code "subj"}, the subject is recovered
     *       through a trace: a trace source lying within the verb span widened by one position
     *       on each side selects a trace id, and the corresponding trace target is the subject
     *       span.</li>
     *   <li>For every match of the segmenter's {@code VERB_PATTERNS}, the subject and object
     *       chunks are taken directly from the matched nodes.</li>
     * </ol>
     *
     * @param depparse the dependency parse to search
     * @param tokens the sentence tokens (not used by this method)
     * @param traceTargets map from trace id to the span of the trace target
     * @param traceSources map from trace id to the position of the trace source
     * @return the extracted pairs, in discovery order
     */
    private static Collection<Pair<Span, Span>> subjectObjectPairs(SemanticGraph depparse, List<CoreLabel> tokens, Map<Integer, Span> traceTargets, Map<Integer, Integer> traceSources) {
        List<Pair<Span, Span>> pairs = new ArrayList<>();

        // Pass 1: verb phrases without an explicit subject; the subject comes from a trace.
        for (SemgrexPattern pattern : segmenter.VP_PATTERNS) {
            SemgrexMatcher vpMatch = pattern.matcher(depparse);
            while (vpMatch.find()) {
                IndexedWord verbNode = vpMatch.getNode("verb");
                IndexedWord objectNode = vpMatch.getNode("object");
                if (verbNode == null || objectNode == null) {
                    continue;
                }
                boolean verbHasSubject = hasSubjectArc(depparse, verbNode);
                boolean objectHasSubject = hasSubjectArc(depparse, objectNode);
                if (verbHasSubject || objectHasSubject) {
                    continue;
                }
                Optional<List<IndexedWord>> verbWords = segmenter.getValidChunk(depparse, verbNode, segmenter.VALID_ADVERB_ARCS, Optional.empty(), true);
                Optional<List<IndexedWord>> objectWords = segmenter.getValidChunk(depparse, objectNode, segmenter.VALID_OBJECT_ARCS, Optional.empty(), true);
                if (verbWords.isPresent() && objectWords.isPresent()) {
                    Comparator<IndexedWord> byIndex = Comparator.comparingInt(IndexedWord::index);
                    verbWords.get().sort(byIndex);
                    objectWords.get().sort(byIndex);

                    Span verbSpan = toSpan(verbWords.get());
                    // Look for a trace source inside the verb span widened by one on each side.
                    Span searchWindow = Span.fromValues(verbSpan.start() - 1, verbSpan.end() + 1);
                    int traceId = -1;
                    for (Map.Entry<Integer, Integer> source : traceSources.entrySet()) {
                        if (searchWindow.contains(source.getValue())) {
                            // No early exit: the last matching entry in iteration order wins.
                            traceId = source.getKey();
                        }
                    }
                    if (traceId >= 0) {
                        Span subjectSpan = traceTargets.get(traceId);
                        Span objectSpan = toSpan(objectWords.get());
                        if (subjectSpan != null) {
                            pairs.add(Pair.makePair(subjectSpan, objectSpan));
                        }
                    }
                }
            }
        }

        // Pass 2: patterns that bind both a subject and an object node directly.
        for (SemgrexPattern pattern : segmenter.VERB_PATTERNS) {
            SemgrexMatcher verbMatch = pattern.matcher(depparse);
            while (verbMatch.find()) {
                IndexedWord subjectNode = verbMatch.getNode("subject");
                IndexedWord objectNode = verbMatch.getNode("object");
                if (subjectNode == null || objectNode == null) {
                    continue;
                }
                Optional<List<IndexedWord>> subjectWords = segmenter.getValidChunk(depparse, subjectNode, segmenter.VALID_SUBJECT_ARCS, Optional.empty(), true);
                Optional<List<IndexedWord>> objectWords = segmenter.getValidChunk(depparse, objectNode, segmenter.VALID_OBJECT_ARCS, Optional.empty(), true);
                if (subjectWords.isPresent() && objectWords.isPresent()) {
                    Span subjectSpan = toSpan(subjectWords.get());
                    Span objectSpan = toSpan(objectWords.get());
                    pairs.add(Pair.makePair(subjectSpan, objectSpan));
                }
            }
        }
        return pairs;
    }

    /** Returns whether {@code node} has an outgoing edge whose relation name contains "subj". */
    private static boolean hasSubjectArc(SemanticGraph depparse, IndexedWord node) {
        for (SemanticGraphEdge edge : depparse.outgoingEdgeIterable(node)) {
            if (edge.getRelation().toString().contains("subj")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Recursively collects the trace targets in a tree.
     *
     * <p>A node is a trace target when its label value matches {@code TRACE_TARGET_PATTERN}; the
     * trailing number of the label is the trace id, and the node's span (converted with
     * {@code toExclusive()}) is the value. Entries found in children overwrite entries with the
     * same id found at the parent or in earlier children.
     *
     * @param root the (sub)tree to search
     * @return a map from trace id to the span of its target
     */
    private static Map<Integer, Span> findTraceTargets(Tree root) {
        Map<Integer, Span> targets = new HashMap<>(4);
        String labelValue = root.label().value();
        if (labelValue == null) {
            labelValue = "NULL";
        }
        Matcher labelMatch = TRACE_TARGET_PATTERN.matcher(labelValue);
        if (labelMatch.matches()) {
            targets.put(Integer.parseInt(labelMatch.group(2)), Span.fromPair(root.getSpan()).toExclusive());
        }
        for (Tree subtree : root.children()) {
            Map<Integer, Span> fromSubtree = findTraceTargets(subtree);
            targets.putAll(fromSubtree);
        }
        return targets;
    }

    /**
     * Recursively collects the trace sources in a tree.
     *
     * <p>A node is a trace source when its label value matches {@code TRACE_SOURCE_PATTERN}; the
     * trailing number of the label is the trace id, and the value stored is the label's
     * {@code index()} minus one. Entries found in children overwrite entries with the same id
     * found at the parent or in earlier children.
     *
     * @param root the (sub)tree to search
     * @return a map from trace id to the position of its source
     */
    private static Map<Integer, Integer> findTraceSources(Tree root) {
        Map<Integer, Integer> sources = new HashMap<>(4);
        String labelValue = root.label().value();
        if (labelValue == null) {
            labelValue = "NULL";
        }
        Matcher labelMatch = TRACE_SOURCE_PATTERN.matcher(labelValue);
        if (labelMatch.matches()) {
            sources.put(Integer.parseInt(labelMatch.group(1)), ((CoreLabel) root.label()).index() - 1);
        }
        for (Tree subtree : root.children()) {
            Map<Integer, Integer> fromSubtree = findTraceSources(subtree);
            sources.putAll(fromSubtree);
        }
        return sources;
    }

    /**
     * Counts the extractions across a dataset.
     *
     * @param data the (sentence, extractions) pairs to count over
     * @return the sum of the sizes of every pair's extraction collection
     */
    private static int countDatums(List<Pair<CoreMap, Collection<Pair<Span, Span>>>> data) {
        return data.stream()
                .mapToInt(datum -> datum.second.size())
                .sum();
    }

    /**
     * Reads every {@code mrg} treebank file under a directory and extracts training data from
     * each tree.
     *
     * <p>For each tree, the method indexes it, converts it to a dependency graph, locates its
     * trace targets and sources, wraps tokens and graph into a sentence, runs the natural logic
     * annotator on that sentence, and records the subject/object pairs found. Progress is logged
     * every 100 trees.
     *
     * <p>A failure while processing a single tree is printed and that tree is skipped; a failure
     * while reading from a file propagates to the caller. Each file's reader is closed before
     * moving on to the next file.
     *
     * @param name a human-readable name for the corpus, used in log output
     * @param directory the directory to search recursively for {@code mrg} files
     * @return one (sentence, extractions) pair per successfully processed tree
     * @throws IOException if a treebank file cannot be opened or read
     */
    private static List<Pair<CoreMap, Collection<Pair<Span, Span>>>> processDirectory(String name, File directory) throws IOException {
        Redwood.Util.forceTrack("Processing " + name);
        Iterable<File> files = IOUtils.iterFilesRecursive(directory, "mrg");
        int numTreesProcessed = 0;
        List<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData = new ArrayList<>(1024);
        for (File file : files) {
            // try-with-resources: the reader (and its underlying file handle) is released per file.
            try (PennTreeReader reader = new PennTreeReader(IOUtils.readerFromFile(file))) {
                Tree tree;
                while ((tree = reader.readTree()) != null) {
                    try {
                        tree.indexSpans();
                        tree.setSpans();
                        List<CoreLabel> tokens = tree.getLeaves().stream().map(leaf -> (CoreLabel) leaf.label()).collect(Collectors.toList());
                        SemanticGraph graph = CreateClauseDataset.parse(tree);
                        Map<Integer, Span> targets = CreateClauseDataset.findTraceTargets(tree);
                        Map<Integer, Integer> sources = CreateClauseDataset.findTraceSources(tree);
                        // The same graph is registered under all three dependency annotations.
                        CoreMap sentence = new ArrayCoreMap(4);
                        sentence.set(CoreAnnotations.TokensAnnotation.class, tokens);
                        sentence.set(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class, graph);
                        sentence.set(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class, graph);
                        sentence.set(SemanticGraphCoreAnnotations.EnhancedPlusPlusDependenciesAnnotation.class, graph);
                        natlog.doOneSentence(null, sentence);
                        Collection<Pair<Span, Span>> trainingDataFromSentence = CreateClauseDataset.subjectObjectPairs(graph, tokens, targets, sources);
                        trainingData.add(Pair.makePair(sentence, trainingDataFromSentence));
                        numTreesProcessed++;
                        if (numTreesProcessed % 100 == 0) {
                            Redwood.Util.log("[" + new DecimalFormat("00000").format(numTreesProcessed) + "] " + CreateClauseDataset.countDatums(trainingData) + " known extractions");
                        }
                    }
                    catch (Throwable t) {
                        // Deliberately broad (this also catches AssertionError): one bad tree
                        // must not abort the whole corpus run.
                        t.printStackTrace();
                    }
                }
            }
        }
        Redwood.Util.log("" + numTreesProcessed + " trees processed yielding " + CreateClauseDataset.countDatums(trainingData) + " known extractions");
        Redwood.Util.endTrack("Processing " + name);
        return trainingData;
    }

    /**
     * Builds the clause-splitter training set from the WSJ and Brown treebanks and trains the
     * clause splitter on it.
     *
     * <p>All input and output paths are hard-coded; {@code args} is not consulted.
     *
     * @param args command-line arguments (ignored)
     * @throws IOException if a treebank directory cannot be read
     */
    public static void main(String[] args) throws IOException {
        File wsjDirectory = new File("/home/gabor/lib/data/penn_treebank/wsj");
        File brownDirectory = new File("/home/gabor/lib/data/penn_treebank/brown");
        File modelFile = new File("/home/gabor/tmp/clauseSearcher.ser.gz");
        File featureDumpFile = new File("/home/gabor/tmp/clauseSearcherData.tab.gz");

        Redwood.Util.forceTrack("Processing treebanks");
        List<Pair<CoreMap, Collection<Pair<Span, Span>>>> corpus = new ArrayList<>();
        corpus.addAll(processDirectory("WSJ", wsjDirectory));
        corpus.addAll(processDirectory("Brown", brownDirectory));
        Redwood.Util.endTrack("Processing treebanks");

        Redwood.Util.forceTrack("Training");
        Redwood.Util.log("dataset size: " + corpus.size());
        ClauseSplitter.train(corpus.stream(), modelFile, featureDumpFile);
        Redwood.Util.endTrack("Training");
    }
}

