#!/usr/bin/env python
import argparse
import os
import sys
from pprint import pprint
from collections import defaultdict
from collections.abc import Iterable

from meshroom.core import graph as pg


def addPlots(curves, title, fileObj):
    """Plot a group of curves on one figure and write it as HTML to ``fileObj``.

    Args:
        curves: Iterable of ``(curveName, values)`` pairs. Series that are
            empty, or whose first sample is a string, are not plotted.
        title: Title displayed on the figure.
        fileObj: Writable text file object handed to ``mpld3.save_html``.

    Returns:
        None. Nothing is written when there is no plottable series.
    """
    # Filter first: indexing ``curve[0]`` on an empty series would raise, and
    # a group with no plottable series would only yield an empty figure.
    plottable = [
        (curveName, curve)
        for curveName, curve in (curves or ())
        if len(curve) > 0 and not isinstance(curve[0], str)
    ]
    if not plottable:
        return

    # Imported lazily so the plotting stack is only needed when a figure is
    # actually produced.
    import matplotlib.pyplot as plt
    import mpld3

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, facecolor='#EEEEEE')
        ax.grid(color='white', linestyle='solid')

        for curveName, curve in plottable:
            ax.plot(curve, label=curveName)
        ax.legend()
        ax.set_title(title)

        mpld3.save_html(fig, fileObj)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


# Command-line interface: one positional graph file plus optional filters.
parser = argparse.ArgumentParser(
    description='Query the status of nodes in a Graph of processes.',
)
parser.add_argument(
    'graphFile',
    metavar='GRAPHFILE.mg',
    help='Filepath to a graph file.',
)
parser.add_argument(
    '--node',
    metavar='NODE_NAME',
    help='Process the node alone.',
)
parser.add_argument(
    '--graph',
    metavar='NODE_NAME',
    help='Process the node and all previous nodes needed.',
)
parser.add_argument(
    '--exportHtml',
    metavar='FILE',
    help='Filepath to the output html file.',
)
# NOTE(review): args.verbose is not read anywhere in the visible script.
parser.add_argument(
    '--verbose',
    action='store_true',
    help='Print full status information',
)

args = parser.parse_args()

# Refuse to continue without an existing graph file.
if not os.path.exists(args.graphFile):
    print(f'ERROR: No graph file "{args.graphFile}".')
    sys.exit(-1)

graph = pg.loadGraph(args.graphFile)

# Refresh the graph, then load the statistics from the cache before reading.
graph.update()
graph.updateStatisticsFromCache()

# Select the nodes to report on: a single node (--node), a start node passed
# to dfsOnFinish (--graph), or dfsOnFinish without start nodes.
if args.node:
    nodes = [graph.findNode(args.node)]
else:
    # Resolve the start node with findNode, exactly as --node does, so both
    # options accept the same node names.
    startNodes = [graph.findNode(args.graph)] if args.graph else None
    # Only the node list is used; the edges returned alongside are ignored.
    nodes, _edges = graph.dfsOnFinish(startNodes=startNodes)

# Print the statistics of every chunk of every selected node.
for node in nodes:
    for chunk in node.chunks:
        print(f'{chunk.name}: {chunk.statistics.toDict()}\n')

# Optionally export the curves of every chunk as HTML figures.
if args.exportHtml:
    with open(args.exportHtml, 'w', encoding='utf-8') as fileObj:
        for node in nodes:
            for chunk in node.chunks:
                stats = chunk.statistics
                for curveSet in (stats.computer.curves, stats.process.curves):
                    # Group curves by the prefix before the first '.', so that
                    # "fig.sub" and "fig.other" end up on the same figure.
                    figures = defaultdict(list)
                    for fullName, curve in curveSet.items():
                        # Keep the rest of the name intact (including any
                        # further dots) so that labels stay readable.
                        figName, _, curveName = fullName.partition('.')
                        figures[figName].append((curveName, curve))

                    # Distinct names here: the loops above are still iterating
                    # over curveSet and must not be shadowed.
                    for figName, figCurves in figures.items():
                        addPlots(figCurves, figName, fileObj)
