from __future__ import print_function
import logging
import time
import datetime
import re
import os
from .data import InvalidData


class BatchError(Exception):
    """
    base class for errors raised while processing a single block of a batch run

    :param e: the exception that caused the failure. Its first argument becomes the
              message of this error; an exception without arguments falls back to its
              class name. A non-exception value (e.g. a message string) is used as-is.
    """

    def __init__(self, e):
        if isinstance(e, BaseException):
            # e.args may be empty (e.g. ``raise ValueError()``); indexing it blindly would
            # raise IndexError from inside the caller's except clause
            message = e.args[0] if e.args else type(e).__name__
        else:
            message = e
        super(BatchError, self).__init__(message)

class LoadingError(BatchError):
    """raised when loading the data of a block fails"""


class AnalysisError(BatchError):
    """raised when analysing the data of a block fails"""


class SaveError(BatchError):
    """raised when saving the analysis result of a block fails"""


class BatchRun(object):
    """
    configure a batch run from a dataset, analysis and saving protocol

    Each block is loaded, analysed and saved in turn; a failure in one block is logged
    and counted but does not stop the run. When a progress log is configured, every
    block that is processed gets one entry ``<status> <block name>``: ``OK`` (completed),
    ``IV`` (invalid data), ``--`` (loading failed) or ``ER`` (analysis or saving failed).

    :param dataiter: iterable that returns DataBlock instances
    :param analysefn: function that analyses a dataset from <DataBlock instance>.load();
                      it is called as ``analysefn(*data, **analysissettings)`` and must
                      return a named tuple (or a generator of named tuples)
    :param outputwriter: object that writes the output
    :param progresslog: filename of a log file that registers the progress per block
    :param continue_previous: only run the blocks that have not been completed (checks the progress log)
    :param analysissettings: additional keyword arguments for analysis
    :param savesettings: additional keyword arguments for saving
    """

    def __init__(self, dataiter, analysefn, outputwriter,
                 progresslog=None, continue_previous=False,
                 analysissettings=None, savesettings=None):
        self.dataiter = dataiter
        self.analysefn = analysefn
        self.outputwriter = outputwriter
        self.progresslog = progresslog
        self.settings = dict(analysis=analysissettings or {},
                             save=savesettings or {},
                             continue_previous=continue_previous)

    def init_progresslog(self):
        """
        prepare the progress log for a new run and collect previously completed blocks

        When ``continue_previous`` is set and the log file exists, the names of all blocks
        recorded with status ``OK`` are collected and a new header is appended. Otherwise
        the log is (re)created containing only a fresh header.

        :return: set of block names completed in a previous run (empty unless continuing)
        """
        completed_blocks = set()
        continuing = self.settings['continue_previous'] and os.path.isfile(self.progresslog)
        if continuing:
            # an entry is a two-character status code, a space and the block name
            entry = re.compile(r'^(..) (.*)$')
            with open(self.progresslog, 'r') as f:
                f.readline()  # skip the header line written by the first run
                for line in f:
                    m = entry.match(line)
                    if m and m.group(1) == 'OK':
                        completed_blocks.add(m.group(2))

        # append a new header when continuing, otherwise start a fresh log
        with open(self.progresslog, 'a' if continuing else 'w') as f:
            f.write('progress {}\n'.format(datetime.datetime.now()))

        return completed_blocks

    def _log_progress(self, status, name):
        """append a ``<status> <name>`` entry to the progress log, if one is configured"""
        if self.progresslog:
            safewrite(self.progresslog, '{} {}\n'.format(status, name))

    def run(self):
        """
        start the batch run

        Blocks already recorded as ``OK`` in the progress log are skipped when
        ``continue_previous`` is set. Blocks with invalid data are skipped with a warning;
        blocks that fail to load, analyse or save are logged and counted as errors. The
        run ends by logging the elapsed time and the number of errors.
        """
        t0 = time.time()
        logger = logging.getLogger('{}.{}'.format(__name__, self.__class__.__name__))
        exception_count = 0

        # write header to the progress log and find blocks finished in an earlier run
        previously_completed = self.init_progresslog() if self.progresslog else set()

        for block in self.dataiter:
            iterlogger = logging.getLogger('{}.{}:{}'.format(__name__, self.__class__.__name__, block.name))

            if block.name in previously_completed:
                iterlogger.debug('skipping previously completed block {}'.format(block.name))
                continue

            iterlogger.debug('item started')
            t1 = time.time()

            try:
                # load the data; it is passed to the analysis function as positional arguments
                try:
                    data = block.load()
                    if not isinstance(data, tuple):
                        data = data,
                except InvalidData as e:
                    # invalid data: warn and skip the block without counting it as an error
                    iterlogger.warning(str(e))
                    self._log_progress('IV', block.name)
                    continue
                except Exception as e:
                    iterlogger.exception('error while loading {}'.format(block.name))
                    self._log_progress('--', block.name)
                    raise LoadingError(e)

                # analyse the data
                try:
                    out = self.analysefn(*data, **self.settings['analysis'])
                    listed_output = _validate_output(out)
                except Exception as e:
                    iterlogger.exception('error in analysis for {}'.format(block.name))
                    self._log_progress('ER', block.name)
                    raise AnalysisError(e)

                # save the data
                try:
                    for v in listed_output:
                        self.save_item(v, block, iterlogger)
                except Exception as e:
                    iterlogger.exception('could not save result from {}'.format(block.name))
                    self._log_progress('ER', block.name)
                    raise SaveError(e)
            except BatchError:
                # already logged above; count it and continue with the next block
                exception_count += 1
                continue

            # the block is only marked OK once all of its outputs have been handled
            self._log_progress('OK', block.name)

            # log loop completion
            iterlogger.info('completed in {:.2f}s'.format(time.time() - t1))

        # log run completion
        logger.warning('analysis completed in {:.2f}s ({} errors)'.format(time.time() - t0, exception_count))

    def save_item(self, out, block, logger):
        """
        mask one analysis result with its block and pass it to the output writer

        If masking leaves no data points, a warning is logged and nothing is saved.

        :param out: named tuple produced by the analysis function
        :param block: the DataBlock the result belongs to; ``block.mask_points(*out)``
                      provides the field values of the saved result
        :param logger: logger used for the warning about empty results
        """
        masked_data = block.mask_points(*out)
        if not masked_data or len(masked_data[0]) == 0:
            # nothing to save: rebuilding the named tuple from empty data would fail. The
            # progress log entry is left to run(), which writes it once per block
            logger.warning('no datapoins for {}'.format(block.name))
            return
        self.outputwriter.save(type(out)(*masked_data), block.name, **self.settings['save'])


def safewrite(fname, s, mode='a'):
    """
    write the string ``s`` to the file ``fname``, closing the file again right away

    Because the file is closed after every call, each entry is handed to the operating
    system immediately instead of waiting in an open buffer. By default the text is
    appended; pass another ``mode`` (e.g. ``'w'``) to change how the file is opened.
    """
    with open(fname, mode) as out_file:
        out_file.write(s)


def _validate_output(v):
    import types
    if isinstance(v, types.GeneratorType):
        listed_output = list(v)
    else:
        listed_output = [v]

    for item in listed_output:
        if not isinstance(item, tuple) or not hasattr(item, '_fields'):
            raise TypeError('analysis output {!r} not a named tuple'.format(item.__class__.__name__))
    return listed_output

