#!/usr/bin/env python3

################################################################################
##                                                                            ##
##  This file is part of NCrystal (see https://mctools.github.io/ncrystal/)   ##
##                                                                            ##
##  Copyright 2015-2021 NCrystal developers                                   ##
##                                                                            ##
##  Licensed under the Apache License, Version 2.0 (the "License");           ##
##  you may not use this file except in compliance with the License.          ##
##  You may obtain a copy of the License at                                   ##
##                                                                            ##
##      http://www.apache.org/licenses/LICENSE-2.0                            ##
##                                                                            ##
##  Unless required by applicable law or agreed to in writing, software       ##
##  distributed under the License is distributed on an "AS IS" BASIS,         ##
##  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  ##
##  See the License for the specific language governing permissions and       ##
##  limitations under the License.                                            ##
##                                                                            ##
################################################################################

#########################################################################
########################### System setup ################################
#########################################################################

import sys
import os

#In unit tests we don't display interactive images and we reduce cpu consumption
#by watching for a special env var. This is defined before the python version
#check below, since that check consults it (any non-empty value enables it):
_unittest = bool(os.environ.get('NCRYSTAL_INSPECTFILE_UNITTESTS',0))

#Refuse to run with Python2 and warn (except in unit tests) if the Python3
#version is older than the minimum recommended one:
pyversion = sys.version_info[0:3]
_minpyversion=(3,5,0)
if pyversion < (3,0,0):
    raise SystemExit('NCrystal no longer supports Python2.')
if pyversion < _minpyversion:
    if not _unittest:
        print('WARNING: Unsupported python version %i.%i.%i detected (recommended is %i.%i.%i or later).'%(pyversion+_minpyversion))

import pathlib
import argparse
import math

#Function for importing required python modules which may be missing, to provide
#a somewhat more helpful error to the user:
def import_optpymod(name):
    """Import and return the module called `name`.

    If the import fails with an ImportError, the script is terminated via
    SystemExit with a message naming the missing module.
    """
    from importlib import import_module
    try:
        return import_module(name)
    except ImportError:
        raise SystemExit('ERROR: Could not import a required python module: %s'%name)

#import NCrystal. Prefer the one from our own installation (ok to modify
#sys.path since we are in a script!):
#
#The package is looked for at a fixed location relative to the directory of
#this script. If its __init__.py file exists there, the directory containing
#the NCrystal package directory is put at the front of sys.path, so it is found
#before any other NCrystal module which might be available.
_ = pathlib.Path( __file__ ).parent / '../share/NCrystal/python/NCrystal/__init__.py'
if _.exists():
    sys.path.insert(0,str(_.parent.parent.absolute().resolve()))

#Abort with a readable message rather than a traceback if NCrystal is missing:
try:
    import NCrystal
except ImportError:
    raise SystemExit("ERROR: Could not import the NCrystal Python module (perhaps your PYTHONPATH is misconfigured).")

#########################################################################
#########################################################################
#########################################################################

def parse_cmdline():
    """Parse and validate the command line arguments of this script.

    Returns the argparse namespace. If one of --extract, --plugins or --browse
    was specified, the namespace is returned right after parsing, without any
    of the validation or post-processing described below.

    Otherwise the following is done before the namespace is returned:

      * args.dpi is resolved: an explicit --dpi value is used if given,
        otherwise the NCRYSTAL_DPI environment variable (if set and
        non-empty), otherwise a default of 200.
      * args.comp is set to one of 'coh_elas', 'incoh_elas', 'elastic',
        'inelastic' or 'all', depending on the component selection flags.
      * args.common is converted from a list of strings into a single
        semicolon-separated string.

    Invalid arguments are reported via parser.error(..), and an invalid value
    of NCRYSTAL_DPI terminates the script via SystemExit with an error message
    (and thus a non-zero exit status).
    """
    descr="""
Load input crystal files (usually .ncmat, .nxs, .laz or .lau) files with
NCrystal (v%s) and plot resulting isotropic cross sections for thermal
neutrons. Specifying more than one file results in a single comparison plot of
the total scattering cross section based on the different materials, whereas
specifying just a single file, results in a more detailed cross section plot as
well as a 2D plot of generated scatter angles."""%NCrystal.__version__

    descr=descr.strip()

    epilog="""
examples:
  $ %(prog)s Al_sg225.ncmat
  plot aluminium cross sections and scatter-angles versus neutron wavelength.
  $ %(prog)s Al_sg225.ncmat Ge_sg227.ncmat --common temp=200
  cross sections for aluminium and germanium at T=200K
  $ %(prog)s "Al_sg225.ncmat;dcutoff=0.1" "Al_sg225.ncmat;dcutoff=0.4" "Al_sg225.ncmat;dcutoff=0.8"
  effect of d-spacing cut-off on aluminium cross sections
  $ %(prog)s "Al_sg225.ncmat;temp=20" "Al_sg225.ncmat;temp=293.15" "Al_sg225.ncmat;temp=600"
  effect of temperature on aluminium cross sections"""

    parser = argparse.ArgumentParser(description=descr,
                                     epilog=epilog,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('input_files', metavar='FILE', type=str, nargs='*',
                        help="""Input crystallographic file in NXS, Laz or Lau format. It is possible to
                        specify additional parameters by appending them to each
                        file by using the usual format accepted by NCMatCfg (see
                        also examples below).""")
    parser.add_argument('-d','--dump', action='store_true',
                        help='Print derived crystallographic information rather than displaying plots')
    parser.add_argument('--common','-c', metavar='CFG', type=str, default=[],
                        help='Common configuration items that will be applied to all files',action='append')
    parser.add_argument('--coh_elas','--bragg', action='store_true',
                        help="""Only generate coherent-elastic (Bragg diffraction) component""")
    parser.add_argument('--incoh_elas', action='store_true',
                        help="""Only generate incoherent-elastic component""")
    parser.add_argument('--elastic', action='store_true',
                        help="""Only generate elastic component""")
    parser.add_argument('--inelastic', action='store_true',
                        help="""Only generate inelastic components""")
    parser.add_argument('-a','--absorption', action='store_true',
                        help="""Include absorption in cross section plots""")
    parser.add_argument('-p','--pdf', action='store_true',
                        help="""Generate PDF file rather than launching an interactive plot viewer.""")
    parser.add_argument('--test', action='store_true',
                        help="""Perform quick validation of NCrystal installation.""")
    dpi_default=200
    #The default of -1 is a sentinel meaning "not specified on the command line":
    parser.add_argument('--dpi', default=-1,type=int,
                        help="""Change plot resolution. Set to 0 to leave matplotlib defaults alone.
                        (default value is %i, or whatever the NCRYSTAL_DPI env var is set to)."""%dpi_default)

    parser.add_argument('--plugins', action='store_true',
                        help='List currently enabled loaded plugins.')
    parser.add_argument('-b','--browse', action='store_true',
                        help='List .ncmat files available in standard locations (current directory, compiled-in, or in NCrystal search path)')
    parser.add_argument('--extract', type=str, default=None, metavar="FILE",
                        help='''Extract contents of FILE using the same lookup mechanism as used for files
                        specified in NCrystal cfg strings, and thus can even be used to inspect
                        in-memory files''')

    args=parser.parse_args()

    #Modes which need no further validation of the other arguments:
    if args.extract or args.plugins or args.browse:
        return args

    if args.dpi>3000:
        parser.error('Too high DPI value requested.')
    if args.dpi<-1:
        #Only -1 has a special meaning (see above), other negative values are
        #simply invalid:
        parser.error('Negative DPI value requested.')

    if args.dpi==-1:
        envdpi=os.environ.get('NCRYSTAL_DPI',None)
        if envdpi:
            try:
                envdpi=int(envdpi)
                if envdpi<0:
                    raise ValueError
            except ValueError:
                #Passing the message to SystemExit ensures a non-zero exit
                #status (a bare SystemExit would signal success):
                raise SystemExit("ERROR: NCRYSTAL_DPI environment variable must be set to integer >=0")
            if envdpi>3000:
                parser.error('Too high DPI value requested via NCRYSTAL_DPI environment variable.')
            args.dpi=envdpi
        else:
            args.dpi=dpi_default

    if args.test:
        if any((args.input_files,args.dump,args.coh_elas,args.incoh_elas,args.elastic,args.inelastic,args.absorption,args.pdf)):
            parser.error('Do not specify other arguments with --test.')
    elif not args.input_files:
        parser.error('Missing input file arguments')

    #At most one component selection flag is allowed:
    ncomp_select = sum((1 if flag else 0) for flag in (args.coh_elas,args.incoh_elas,args.elastic,args.inelastic))
    if ncomp_select > 1:
        parser.error('Do not specify more than one of: --coh_elas/--bragg, --incoh_elas, --elastic or --inelastic.')

    if args.coh_elas: args.comp = 'coh_elas'
    elif args.incoh_elas: args.comp = 'incoh_elas'
    elif args.elastic: args.comp = 'elastic'
    elif args.inelastic: args.comp = 'inelastic'
    else: args.comp = 'all'

    if args.absorption and ncomp_select>0:
        parser.error('Do not specify --absorption with either of: --coh_elas/--bragg, --incoh_elas, --elastic or --inelastic.')

    if args.dump and (ncomp_select>0 or args.absorption):
        parser.error('Do not specify --dump with either of: --coh_elas/--bragg, --incoh_elas, --elastic, --inelastic or --absorption.')

    if args.dump and len(args.input_files)>1:
        parser.error('Do not specify more than one input file with --dump [-d].')

    args.common=';'.join(args.common)
    return args

def create_wavelengths(np,cfgs,npoints,logspace=False):
    """Create the wavelength grid used for plotting the given configurations.

    The upper limit of the grid is derived from the largest wavelength
    corresponding to the lower domain edge of the (non-null) 'coh_elas'
    processes of the configurations in cfgs, with a fallback value of 10.0 in
    case no such threshold is available or it is infinite.

    The np argument is the numpy module. The returned value is the result of
    calling np.linspace (or np.logspace if logspace is set) with the lower
    limit 0.001, the upper limit, and npoints.
    """
    fallback = 10.0#materials with no bragg threshold

    #Collect coherent-elastic processes (creation is allowed to fail):
    bragg_procs = []
    for cfg in cfgs:
        proc = cfg.get_scatter('coh_elas',allowfail=True)
        if proc:
            bragg_procs.append(proc)

    #Threshold wavelengths of those processes which are not null-processes:
    thresholds = []
    for proc in bragg_procs:
        if proc._nullprocess:
            continue
        thresholds.append(NCrystal.ekin2wl(proc.domain()[0]) or fallback)

    longest = (max(thresholds) if thresholds else None) or fallback
    if math.isinf(longest):
        wl_upper = 10.0
    else:
        #slightly above the threshold, rounded up to a whole number:
        wl_upper = float(int(longest*1.01+1.0))
    wl_upper *= 1.2
    wl_lower = 0.001

    gridfct = np.logspace if logspace else np.linspace
    return gridfct(wl_lower,wl_upper,npoints)

#Module-level state used by import_npplt() and main():

#Single-element list holding the DPI value which import_npplt() applies to
#matplotlib (when non-zero). The element is assigned in main():
_mpldpi=[None]
#Name of the PDF file created when PDF output is requested:
_pdffilename='ncrystal.pdf'
#Cache for the tuple returned by import_npplt() (None until the first call):
_npplt = None
def import_npplt(pdf=False):
    """Import numpy and matplotlib, configure matplotlib, and cache the result.

    Returns the tuple (np,plt,pdfpages), where np is the numpy module, plt is
    matplotlib.pyplot, and pdfpages is a PdfPages object writing to the file
    named by _pdffilename if pdf is set (otherwise it is None).

    The tuple is created on the first call only, and subsequent calls return
    the cached tuple (they must be made with the same pdf setting as the first
    call). The script is terminated via SystemExit if a required module or the
    matplotlib PDF support is unavailable.
    """
    global _npplt
    if _npplt:
        #pdf par must be same as last call:
        requested_pdf = bool(pdf)
        cached_has_pdf = bool(_npplt[2] is not None)
        assert requested_pdf==cached_has_pdf
        return _npplt

    np = import_optpymod('numpy')
    mpl = import_optpymod('matplotlib')
    #NB: No matplotlib version checks are performed here (mpl.compare_versions
    #is deprecated, and it is anyway unclear exactly which versions are
    #supported).

    dpi = _mpldpi[0]
    if dpi:
        mpl.rcParams['figure.dpi']=dpi

    #ability to quit plot windows with Q:
    if 'keymap.quit' in mpl.rcParams:
        quitkeys = mpl.rcParams['keymap.quit']
        if 'q' not in quitkeys:
            mpl.rcParams['keymap.quit'] = tuple(list(quitkeys)+['q','Q'])

    #non-interactive backend when no plot windows are to be shown:
    if pdf or _unittest:
        mpl.use('agg')

    pdfpages_class = None
    if pdf:
        try:
            from matplotlib.backends.backend_pdf import PdfPages as pdfpages_class
        except ImportError:
            raise SystemExit("ERROR: Your installation of matplotlib does not have the required support for PDF output.")
    plt = import_optpymod('matplotlib.pyplot')

    pdfpages = pdfpages_class(_pdffilename) if pdf else None
    _npplt = (np,plt,pdfpages)
    return _npplt

def getfields(cfgstr):
    """Return the set of parameter names assigned in the cfg-string.

    The string is stripped and split on semicolons, and for each entry
    containing an '=' sign, the part before the first '=' is taken as the
    name. Empty names are ignored. Names are not stripped of whitespace.
    """
    names = set()
    for entry in cfgstr.strip().split(';'):
        if '=' not in entry:
            continue
        name = entry.split('=',1)[0]
        if name:
            names.add(name)
    return names

def combine_cfgs(cfg1,cfg2,prefercfg1=True):
    """Combine the two semicolon-separated cfg-strings cfg1 and cfg2.

    It is assumed that cfg1 starts with "filename;", possibly followed by the
    ignorefilecfg keyword.

    If prefercfg1 is set, the items of cfg2 are inserted after the filename (and
    ignorefilecfg keyword) and before the other items of cfg1, in order to give
    preference to the cfg1 items. In this case all items are stripped and empty
    items are discarded.

    If prefercfg1 is not set, the items of cfg2 are simply appended after those
    of cfg1 (without any stripping or removal of empty items).
    """
    cfg1_items = cfg1.split(';')
    cfg2_items = cfg2.split(';')
    if not prefercfg1:
        return ';'.join(cfg1_items+cfg2_items)

    #Separate the leading filename (and ignorefilecfg keyword) from the rest:
    head = cfg1_items[:1]
    tail = cfg1_items[1:]
    if tail and tail[0].strip()=='ignorefilecfg':
        head.append(tail[0].strip())
        tail = tail[1:]

    stripped = (item.strip() for item in head+cfg2_items+tail)
    return ';'.join(item for item in stripped if item)

#functions for creating labels and title:

def _remove_common_keyvals(dicts):
    """Strip entries shared by all of the passed dicts.

    Any key which is present with an identical value in every one of the dicts
    is removed from all of them (the dicts are modified in-place). The removed
    entries are returned as a single dictionary, which is empty if no dicts
    were passed.
    """
    item_sets = [set(d.items()) for d in dicts]
    if not item_sets:
        return {}
    shared = dict(item_sets[0].intersection(*item_sets[1:]))
    for d in dicts:
        for key in shared:
            d.pop(key,None)
    return shared

def _classify_differences(parts):
    """Split the settings of the passed parts into common and unique ones.

    For each part, a dictionary is created from the 'key=value' entries of the
    string 'FILENAME=<cfgstr>;COMPNAME=<compname>', where <cfgstr> and
    <compname> are taken from part._cfg.cfgstr and part._compname respectively
    (keys and values are stripped, and a bare ignorefilecfg keyword is
    registered with the value '1').

    Returns a tuple (common,unique), where common is a dictionary of the entries
    shared by all parts, and unique is a list with one dictionary per part,
    holding the remaining entries of that part.
    """
    cfgdicts = []
    for part in parts:
        fullcfg = 'FILENAME=%s;COMPNAME=%s' % (part._cfg.cfgstr,part._compname)
        entries = {}
        for item in fullcfg.split(';'):
            if item.strip()=='ignorefilecfg':
                item = 'ignorefilecfg=1'
            if '=' in item:
                key,value = item.split('=',1)
                entries[key.strip()] = value.strip()
        cfgdicts.append(entries)
    shared = _remove_common_keyvals(cfgdicts)
    return shared,cfgdicts

def _cfgdict_to_str(cfgdict):
    """Format a dictionary of cfg entries as a string for titles and labels.

    Up to three pieces are produced, in this order: the value of the 'FILENAME'
    entry, a comma-separated list of the other entries (sorted by key), and a
    description of the component named in the 'COMPNAME' entry. Absent or empty
    pieces are skipped. The first piece present is shown as it is, the second
    inside square brackets and the third inside parentheses, and the pieces are
    separated by single spaces.

    Note that the 'FILENAME' and 'COMPNAME' entries are removed from the passed
    dictionary, and that a KeyError results from unknown component names.
    """
    filename = cfgdict.pop('FILENAME','')
    compname = cfgdict.pop('COMPNAME','')

    pieces = []
    if filename:
        pieces.append(filename)
    if cfgdict:
        shown = []
        for key,value in sorted(cfgdict.items()):
            #the ignorefilecfg keyword is shown without a value:
            shown.append(key if key=='ignorefilecfg' else '%s=%s'%(key,value))
        pieces.append(', '.join(shown))
    if compname:
        descriptions = { 'coh_elas':'Coherent elastic',
                         'incoh_elas':'Incoherent elastic',
                         'elastic':'Elastic',
                         'inelastic':'Inelastic',
                         'absorption':'Absorption',
                         'all':'Total scattering',
                         'elastic+inelastic+absorption':'Total scattering+Absorption'}
        pieces.append(descriptions[compname])

    decorations = ('%s','[%s]','(%s)')
    return ' '.join(fmt%piece for piece,fmt in zip(pieces,decorations))

def create_title_and_labels(parts):
    """Create a plot title and a list of labels (one per part).

    The title is formatted from the settings which are common to all the parts,
    and each label from the settings which are specific to the part in
    question. Parts with no specific settings get the label 'default'.
    """
    common,unique = _classify_differences(parts)
    title = _cfgdict_to_str(common)
    labels = []
    for cfgdict in unique:
        labels.append(_cfgdict_to_str(cfgdict) or 'default')
    return title,labels

def _end_plot(plt,pdf):
    """Finish the current plot in the manner appropriate for the current mode.

    In unit tests the plot is rendered at low resolution to the null device and
    closed. Otherwise, if pdf is set (it should then be the object on which
    savefig() adds the current figure to the PDF output), the plot is saved
    with it and closed. If not, the plot is shown with plt.show().
    """
    if _unittest:
        #Use a context manager, so the file handle is always closed again:
        with open(os.devnull,'wb') as sink:
            plt.savefig(sink,format='raw',dpi=10)
        plt.close()
    elif pdf:
        pdf.savefig()
        plt.close()
    else:
        plt.show()

def plot_xsect(cfgs,comp,absorption,pdf=False):
    """Plot cross sections versus neutron wavelength.

    Arguments:
      cfgs:       list of Cfg objects to plot.
      comp:       which scattering component to plot ('coh_elas', 'incoh_elas',
                  'elastic', 'inelastic' or 'all').
      absorption: whether to include absorption (requires comp=='all').
      pdf:        whether to produce PDF output rather than showing the plot.

    If a single Cfg object is passed and comp is 'all' or 'elastic', the
    individual components are plotted along with a curve showing their sum
    (when more than one component is present). Otherwise a single curve is
    plotted for each Cfg object. Processes flagged as null-processes are left
    out of the plot.
    """
    assert comp in ('coh_elas','incoh_elas','elastic','inelastic','all')
    assert not absorption or comp=='all'
    np,plt,pdf = import_npplt(pdf)
    plt.xlabel('Neutron wavelength [angstrom]')
    plt.ylabel('Cross section [barn]')

    if len(cfgs)==1 and comp in ('all','elastic'):
        #Single material, break the cross section down by component:
        if comp=='all':
            parts=[cfgs[0].get_scatter('coh_elas'),cfgs[0].get_scatter('incoh_elas'),cfgs[0].get_scatter('inelastic')]
        else:
            parts=[cfgs[0].get_scatter('coh_elas'),cfgs[0].get_scatter('incoh_elas')]
        if absorption:
            assert comp=='all'
            parts += [cfgs[0].get_absorption()]
    else:
        #One curve per material:
        if absorption:
            assert comp=='all'
            parts=[c.get_totalxsect() for c in cfgs]
        else:
            parts=[c.get_scatter(comp) for c in cfgs]
    parts = [p for p in parts if not p._nullprocess]
    title,labels = create_title_and_labels(parts)
    if len(set(labels))!=len(labels):
        print("WARNING: Comparing identical setups?")

    wavelengths = create_wavelengths(np,cfgs,3000)
    ekins = NCrystal.wl2ekin(wavelengths)
    #A curve with the sum is only needed when showing several components of a
    #single material:
    need_tot = (len(cfgs)==1 and len(parts)>1)
    xsects_tot = None
    max_len_label = 0

    #colors inspired by http://www.mulinblog.com/a-color-palette-optimized-for-data-visualization/
    col_red = "#F15854"
    partcols = [
        "#5DA5DA", # (blue)
        "#FAA43A", # (orange)
        "#60BD68", # (green)
        #"#B2912F", # (brown)
        "#B276B2", # (purple)
        #"#DECF3F", # (yellow)
        #"#F17CB0", # (pink)
        "#4D4D4D", # (gray)
        ]
    if not need_tot:
        #red is not reserved for the sum curve, so use it for the parts:
        partcols = [col_red]+partcols

    linewidth = 2.0

    for ipart,part in enumerate(parts):
        xsects = part.crossSectionNonOriented(ekins)
        if need_tot:
            if xsects_tot is None:
                xsects_tot = np.zeros(len(xsects))
            xsects_tot += xsects
        label=labels[ipart]
        max_len_label = max(max_len_label,len(label))
        #Change line style each time the list of colours has been exhausted:
        ls={0:'-',1:'--',2:':'}.get(ipart//len(partcols),'-.')
        plt.plot(wavelengths,xsects,label=label,color=partcols[ipart%len(partcols)],lw=linewidth,ls=ls)
    if need_tot:
        plt.plot(wavelengths,xsects_tot,
                 label={'all':'Total','elastic':'Total elastic'}[comp],
                 color=col_red,lw=linewidth)

    #Reduce legend font size when labels are long. NB: 'x-small' is used for the
    #longest labels since the relative size 'smaller' gives the same result as
    #'small' when applied to the default font size:
    leg_fsize = 'large'
    if max_len_label > 40: leg_fsize = 'medium'
    if max_len_label > 60: leg_fsize = 'small'
    if max_len_label > 80: leg_fsize = 'x-small'
    try:
        if len(parts)>1:
            leg=plt.legend(loc='best',fontsize=leg_fsize)
            #Method name depends on the matplotlib version:
            if hasattr(leg,'set_draggable'):
                leg.set_draggable(True)
            else:
                leg.draggable(True)
    except TypeError:
        plt.legend(loc='best')
    plt.grid()
    plt.gca().set_xlim(0.0,wavelengths[-1])
    plt.gca().set_ylim(0.0,None)
    if title:
        plt.title(title)
    try:
        plt.tight_layout()
    except (ValueError, AttributeError):
        #fall back to manual margins if tight_layout is not usable:
        plt.subplots_adjust(bottom=0.15, right=0.9, top=0.9, left = 0.15)
    _end_plot(plt,pdf)

def plot_2d_scatangle(cfg,comp,pdf=False):
    """Create a 2D scatter plot of scattering angle versus neutron wavelength.

    Arguments:
      cfg:  the Cfg object to plot.
      comp: which scattering component to use ('coh_elas', 'incoh_elas',
            'elastic', 'inelastic' or 'all').
      pdf:  whether to produce PDF output rather than showing the plot.

    The number of points at each wavelength is sampled from a Poisson
    distribution with a mean proportional to the cross section at that
    wavelength, and the scattering angle of each point is generated with the
    generateScatteringNonOriented method of the scatter process.
    """
    assert comp in ('coh_elas','incoh_elas','elastic','inelastic','all')
    part=cfg.get_scatter(comp)

    np,plt,pdf = import_npplt(pdf)

    #reproducible plots (the numpy generator must be seeded as well, since it
    #is the one used for the Poisson sampling below):
    import random
    random.seed(123456)
    np.random.seed(123456)

    #higher granularity wavelengths than for 1D plot to avoid artifacts:
    wavelengths = create_wavelengths(np,[cfg],100 if _unittest else 30000 )

    #get title (label should be uninteresting for a single part):
    title,_ = create_title_and_labels([part])

    #First figure out how many points to put at each wavelength
    if not part._nullprocess:
        xsects = part.crossSectionNonOriented(NCrystal.wl2ekin(wavelengths))
        n2d = 100 if _unittest else 25000
        sumxs = np.sum(xsects)
        if sumxs:
            n_at_wavelength = np.random.poisson(xsects*n2d/sumxs)
        else:
            n_at_wavelength = np.zeros(len(xsects))
    else:
        n_at_wavelength = np.zeros(len(wavelengths))

    n2d=int(np.sum(n_at_wavelength))#correction for random fluctuations
    if n2d>0:
        plot_angles = np.zeros(n2d)
        plot_wls = np.zeros(n2d)
        j = 0
        for i,n in enumerate(n_at_wavelength):
            n = int(n)
            #only the angles are needed, not the energy transfers:
            ang,_ = part.generateScatteringNonOriented(NCrystal.wl2ekin(wavelengths[i]),repeat=n)
            plot_angles[j:j+n] = ang
            plot_wls[j:j+n].fill(wavelengths[i])
            j+=n
        plot_angles *= (180./np.pi)#radians to degrees
    else:
        plot_angles = None
        plot_wls = None

    if plot_wls is not None:
        try:
            plt.scatter(plot_wls,plot_angles,alpha=0.2,marker='.',edgecolor=None,color='black',s=2,zorder=1)
        except ValueError:
            #retry without specifying the marker:
            plt.scatter(plot_wls,plot_angles,alpha=0.2,edgecolor=None,color='black',s=2,zorder=1)

    plt.gca().set_xlim(0.0,wavelengths[-1])
    plt.gca().set_ylim(0,180)
    plt.xlabel('Neutron wavelength [angstrom]')
    plt.ylabel('Scattering angle [degrees]')
    if title:
        plt.title(title)
    plt.grid()
    try:
        plt.tight_layout()
    except (ValueError, AttributeError):
        #fall back to manual margins if tight_layout is not usable:
        plt.subplots_adjust(bottom=0.12, right=0.95, top=0.9, left = 0.12)
    _end_plot(plt,pdf)

def dump_info(info):
    """Dump the passed info object by forwarding it to NCrystal.dump(..)."""
    NCrystal.dump(info)

class XSSum:
    """Sum of processes, providing their combined non-oriented cross section."""

    def __init__(self,*processes):
        #keep the processes to be added up:
        self._p = tuple(processes)

    def crossSectionNonOriented(self,ekin):
        """Returns the sum of the crossSectionNonOriented(ekin) results of all
        the processes (starting the summation from the integer 0)."""
        total = 0
        for process in self._p:
            total = total + process.crossSectionNonOriented(ekin)
        return total

class Cfg:
    """A material configuration, based on a cfg-string for a single input file
    combined with the common cfg-string applying to all files.

    The NCrystal objects needed for the material (scatter processes for the
    various components, absorption process, info object) are created upon first
    request and subsequently cached.
    """

    def __init__(self,cfgstr, commoncfgstr):
        """Initialise from the cfg-string of the file and the common cfg-string
        (items in the former take precedence over those in the latter).

        Raises ValueError if the resulting cfg-string contains parameters for
        enabling or disabling physics components, since the command line flags
        of the script must be used for that purpose.
        """
        #very crude normalisation, we should really let the C++ code NCMatCfg.cc do this splitting (todo):
        while '= ' in cfgstr:
            cfgstr=cfgstr.replace('= ','=')
        while ' ;' in cfgstr:
            cfgstr=cfgstr.replace(' ;',';')
        cfgstr=cfgstr.rstrip(' ')
        self._cfgstr = combine_cfgs(cfgstr,commoncfgstr)
        #Strip the names, so e.g. "bragg =0" can not be used to bypass the checks:
        fields={name.strip() for name in getfields(self._cfgstr)}
        forbidden=['bragg','coh_elas','incoh_elas','elas','bkgd']
        em="when using this script, use --coh_elas/--bragg/--incoh_elas/--elastic/--inelastic to enable/disable components rather than putting these variables in the cfg string: %s, inelas. Only exception is that inelas can be used to select alternative inelastic models."%(', '.join(forbidden))
        #NB: Explicit exceptions are used rather than assert statements, since
        #the latter would be skipped when python runs with optimisations (-O).
        for name in forbidden:
            if name in fields:
                raise ValueError(em)
        if 'inelas' in fields:
            #inelas may select a model, but not disable inelastic physics. The
            #value is compared exactly, to avoid rejecting other values which
            #merely start with one of the disabling values:
            disabling_values = ('none','0','sterile','false')
            for item in self._cfgstr.split(';'):
                name,sep,value = item.partition('=')
                if sep and name.strip()=='inelas' and value.strip() in disabling_values:
                    raise ValueError(em)
        self._sc = {}#cache of scatter processes, with component names as keys
        self._abs = None
        self._totxs = None
        self._info = None

    def get_scatter(self,comp = 'all', allowfail = False):
        """Get (cached) scatter process for the component comp, which must be
        one of 'all', 'coh_elas', 'incoh_elas', 'elastic' or 'inelastic'.

        If the creation fails with an NCrystal.NCException, None is returned
        when allowfail is set, otherwise the exception is re-raised.
        """
        assert comp in ('all','coh_elas','incoh_elas','elastic','inelastic')
        if comp not in self._sc:
            #Parameters selecting the component are appended to the cfg-string:
            cfgstr = combine_cfgs(self._cfgstr,{'coh_elas':'coh_elas=1;incoh_elas=0;inelas=0',
                                                'incoh_elas':'incoh_elas=1;coh_elas=0;inelas=0',
                                                'elastic':'elas=1;inelas=0',
                                                'inelastic':'elas=0'}.get(comp,''),False)
            try:
                sc = NCrystal.createScatter(cfgstr)
            except NCrystal.NCException:
                if allowfail:
                    return None
                raise
            #Decorate the process with info used by the plotting code:
            sc._nullprocess = bool(sc.domain()[0]>1e99)
            sc._cfg = self
            sc._compname = comp
            self._sc[comp] = sc
        return self._sc[comp]

    def get_absorption(self):
        """Get (cached) absorption process."""
        if self._abs is None:
            a = NCrystal.createAbsorption(self._cfgstr)
            a._nullprocess = bool(a.domain()[0]>1e99)
            a._cfg = self
            a._compname = 'absorption'
            self._abs = a
        return self._abs

    def get_totalxsect(self):
        """Get (cached) XSSum object combining the absorption process with the
        scatter process for all components."""
        if self._totxs is None:
            a,s = self.get_absorption(),self.get_scatter('all')
            t = XSSum(a,s)
            t._nullprocess = a._nullprocess and s._nullprocess
            t._cfg = self
            t._compname = 'elastic+inelastic+absorption'
            self._totxs = t
        return self._totxs

    def get_info(self):
        """Get (cached) info object created with NCrystal.createInfo(..)."""
        if self._info is None:
            self._info = NCrystal.createInfo(self._cfgstr)
        return self._info

    @property
    def cfgstr(self):
        """The combined cfg-string."""
        return self._cfgstr

def main():
    """Entry point of the script.

    Parses the command line and carries out the requested action: extracting
    the contents of a file, listing plugins or files, running the NCrystal
    test, dumping material information, or producing plots (shown
    interactively or saved in a PDF file).
    """
    args = parse_cmdline()

    #Modes which do not involve the input files:
    if args.extract:
        contents=NCrystal.getFileContents(args.extract)
        if contents is None:
            raise SystemExit('Error: unknown file "%s"'%args.extract)
        print(contents,end='')
        sys.exit(0)
    if args.plugins or args.browse or args.test:
        if args.plugins:
            NCrystal.browsePlugins(dump=True)
        elif args.browse:
            NCrystal.browseFiles(dump=True)
        else:
            NCrystal.test()
        sys.exit(0)

    _mpldpi[0] = args.dpi

    cfgs = []
    for input_file in args.input_files:
        cfgs.append(Cfg(input_file,args.common))

    if args.dump:
        assert len(cfgs)==1
        cfgs[0].get_info().dump()
        return

    plot_xsect( cfgs, comp  = args.comp, absorption = args.absorption, pdf=args.pdf )

    #The 2D plot is only made for a single file, and can be disabled via an
    #environment variable:
    if len(cfgs)==1:
        skip_2d = bool(os.environ.get('NCRYSTAL_INSPECTFILE_NO2DSCATTER',0))
        if not skip_2d:
            plot_2d_scatangle( cfgs[0], comp = args.comp, pdf=args.pdf )

    if not args.pdf:
        return

    #Add meta-data to the PDF file and finalise it:
    pdf = import_npplt(True)[2]
    import datetime
    try:
        metadata = pdf.infodict()
    except AttributeError:
        metadata={}
    plural = '' if len(args.input_files)==1 else 's'
    basenames = ','.join(os.path.basename(f) for f in args.input_files)
    metadata['Title'] = 'Plots made with NCrystal-inspectfile from file%s %s'%(plural,basenames)
    metadata['Author'] = 'NCrystal %s (via inspectfile)'%NCrystal.__version__
    metadata['Subject'] = 'NCrystal plots'
    metadata['Keywords'] = 'NCrystal'
    metadata['CreationDate'] = datetime.datetime.today()
    metadata['ModDate'] = datetime.datetime.today()
    pdf.close()
    print("created %s"%_pdffilename)

if __name__=='__main__':
    main()

# TODO:
#   - allow tuning of n2d and alpha pars for 2D plot?
#   - show hkl values in 1d and 2d plot?
#   - option to show energy & mfp rather than wavelength & xsect?
#   - deltaE/wl_out plots as well (todo). Perhaps --wl=... option, resulting in plots for that wl.
