// -*- C++ -*-
#include <cmath>
#include <iostream>
#include <limits>
#include "Rivet/Analysis.hh"
#include "Rivet/Projections/FinalState.hh"
#include "Rivet/Projections/PromptFinalState.hh"
#include "Rivet/Tools/AliceCommon.hh"
#include "Rivet/Projections/AliceCommon.hh"
#include "Rivet/Projections/UnstableParticles.hh"
#include "Rivet/Projections/CentralityProjection.hh"
#include "Rivet/Analysis.hh"
#include "Rivet/Projections/UnstableParticles.hh"
#include "Rivet/Tools/Cuts.hh"
#include "Rivet/Projections/HepMCHeavyIon.hh"

namespace Rivet {

  /// @brief Prompt strange-meson production at 5 TeV
  class ALICE_2021_I1946131 : public Analysis {
  public:

    RIVET_DEFAULT_ANALYSIS_CTOR(ALICE_2021_I1946131);

    void BuildAverage(Scatter2DPtr AvgScatter, Histo1DPtr AvgHisto, const vector<Scatter2DPtr>& Histograms)
    {
        for (auto& pointAvg : AvgHisto->bins())
        {
            double ValWeight = 0.0;  
            double WeightTot = 0.0;
            double ErrWeight = 0.0;  
        
            for (const auto& histo : Histograms)
            {
                
                for (const auto& pointD : histo->points())
                {
                
                    if (pointD.x() >= pointAvg.xMin() && pointD.x() < pointAvg.xMax())
                    {
                        double relErr = pointD.yErrPlus() / pointD.y();
                        double weight = 1 / std::pow(relErr, 2);

                        ValWeight += pointD.y()*weight;
                        ErrWeight += std::pow(pointD.yErrPlus()*weight, 2);

                        WeightTot += weight;
                                                        
                    }
                }
            }
                
            double avg; 
            double AvgErr;

            avg = ValWeight / WeightTot;
            AvgErr = std::sqrt(ErrWeight) / WeightTot;

            AvgScatter->addPoint(pointAvg.xMid(), avg, pointAvg.xWidth()/2., AvgErr);
        }
    }

    void init() {

      declareCentrality(ALICE::V0MMultiplicity(), "ALICE_2015_PBPBCentrality", "V0M","V0M");
   
      const UnstableParticles ufsD0(Cuts::absrap < 0.5 && Cuts::pT > 0.*GeV && Cuts::abspid == 421);
      declare(ufsD0, "ufsD0");

      const UnstableParticles ufsDplus(Cuts::absrap < 0.5 && Cuts::pT > 2.*GeV && Cuts::abspid == 411);
      declare(ufsDplus, "ufsDplus");

      const UnstableParticles ufsDstar(Cuts::absrap < 0.5 && Cuts::pT > 2.*GeV && Cuts::abspid == 413);
      declare(ufsDstar, "ufsDstar");

      book(_h["D0Pt_0010"], 1, 1, 1);                                   // D0 transverse momentum distribution (0-10%)
      book(_h["D0Pt_3050"], 2, 1, 1);                                   // D0 transverse momentum distribution (30-50%)
      
      book(_h["DplusPt_0010"], 3, 1, 1);                                // D+ transverse momentum distribution (0-10%)
      book(_h["DplusPt_3050"], 4, 1, 1);                                // D+ transverse momentum distribution (30-50%)
      
      book(_h["DstarPt_0010"], 5, 1, 1);                                // DS+ transverse momentum distribution (0-10%)
      book(_h["DstarPt_3050"], 6, 1, 1);                                // DS+ transverse momentum distribution (30-50%)
   
      string refname7 = mkAxisCode(7, 1, 1);                            // Ratio D+/D0 (0-10%)
      const Scatter2D& refdata7 = refData(refname7);
      book(_h["Ratio7_Dplus"], "_" + refname7 + "_Dplus", refdata7);
      book(_h["Ratio7_D0"], "_" + refname7 + "_D0", refdata7);
      book(_s["Ratio7_Dplus_D0"], refname7);

      string refname8 = mkAxisCode(8, 1, 1);                            // Ratio D+/D0 (30-50%)
      const Scatter2D& refdata8 = refData(refname8);
      book(_h["Ratio8_Dplus"], "_" + refname8 + "_Dplus", refdata8);
      book(_h["Ratio8_D0"], "_" + refname8 + "_D0", refdata8);
      book(_s["Ratio8_Dplus_D0"], refname8);

      string refname9 = mkAxisCode(9, 1, 1);                            // Ratio DS+/D0 (0-10%)
      const Scatter2D& refdata9 = refData(refname9);
      book(_h["Ratio9_Dstar"], "_" + refname9 + "_Dstar", refdata9);
      book(_h["Ratio9_D0"], "_" + refname9 + "_D0", refdata9);
      book(_s["Ratio9_Dstar_D0"], refname9);

      string refname10 = mkAxisCode(10, 1, 1);                          // Ratio DS+/D0 (30-50%)
      const Scatter2D& refdata10 = refData(refname10);
      book(_h["Ratio10_Dstar"], "_" + refname10 + "_Dstar", refdata10);
      book(_h["Ratio10_D0"], "_" + refname10 + "_D0", refdata10);
      book(_s["Ratio10_Dstar_D0"], refname10);
         
      string refnameRaa11 = mkAxisCode(11, 1, 1);                                   //Raa D0 (0-10%) 
      const Scatter2D& refdataRaa11 = refData(refnameRaa11);
      book(_h["Raa11_D0_PbPb"], "_" + refnameRaa11 + "_D0_PbPb", refdataRaa11);
      book(_h["Raa11_D0_pp"], "_" + refnameRaa11 + "_D0_pp", refdataRaa11);
      book(_s["Raa11_D0"], refnameRaa11);

      string refnameRaa12 = mkAxisCode(12, 1, 1);                                   //Raa D+ (0-10%)
      const Scatter2D& refdataRaa12 = refData(refnameRaa12);
      book(_h["Raa12_Dplus_PbPb"], "_" + refnameRaa12 + "_Dplus_PbPb", refdataRaa12);
      book(_h["Raa12_Dplus_pp"], "_" + refnameRaa12 + "_Dplus_pp", refdataRaa12);
      book(_s["Raa12_Dplus"], refnameRaa12);

      string refnameRaa13 = mkAxisCode(13, 1, 1);                                   //Raa DS+ (0-10%)
      const Scatter2D& refdataRaa13 = refData(refnameRaa13);
      book(_h["Raa13_Dstar_PbPb"], "_" + refnameRaa13 + "_Dstar_PbPb", refdataRaa13);
      book(_h["Raa13_Dstar_pp"], "_" + refnameRaa13 + "_Dstar_pp", refdataRaa13);
      book(_s["Raa13_Dstar"], refnameRaa13);

      string refnameRaa14 = mkAxisCode(14, 1, 1);                                   //Raa D0 (30-50%) 
      const Scatter2D& refdataRaa14 = refData(refnameRaa14);
      book(_h["Raa14_D0_PbPb"], "_" + refnameRaa14 + "_D0_PbPb", refdataRaa14);
      book(_h["Raa14_D0_pp"], "_" + refnameRaa14 + "_D0_pp", refdataRaa14);
      book(_s["Raa14_D0"], refnameRaa14);

      string refnameRaa15 = mkAxisCode(15, 1, 1);                                   //Raa D+ (30-50%) 
      const Scatter2D& refdataRaa15 = refData(refnameRaa15);
      book(_h["Raa15_Dplus_PbPb"], "_" + refnameRaa15 + "_Dplus_PbPb", refdataRaa15);
      book(_h["Raa15_Dplus_pp"], "_" + refnameRaa15 + "_Dplus_pp", refdataRaa15);
      book(_s["Raa15_Dplus"], refnameRaa15);

      string refnameRaa16 = mkAxisCode(16, 1, 1);                                   //Raa DS+ (30-50%) 
      const Scatter2D& refdataRaa16 = refData(refnameRaa16);
      book(_h["Raa16_Dstar_PbPb"], "_" + refnameRaa16 + "_Dstar_PbPb", refdataRaa16);
      book(_h["Raa16_Dstar_pp"], "_" + refnameRaa16 + "_Dstar_pp", refdataRaa16);
      book(_s["Raa16_Dstar"], refnameRaa16);

      string refnameRaa17 = mkAxisCode(17, 1, 1);                                   // Average RAA D0, D+, DS+ (0-10%) 
      const Scatter2D& refdataRaa17 = refData(refnameRaa17);
      book(_h["Raa17_Average_Aux"], "_" + refnameRaa17 + "_Aux", refdataRaa17);
      book(_s["Raa17_Average"],refnameRaa17);
  
      string refnameRaa18 = mkAxisCode(18, 1, 1);                                   // Average RAA D0, D+, DS+ (30-50%) 
      const Scatter2D& refdataRaa18 = refData(refnameRaa18);
      book(_h["Raa18_Average_Aux"], "_" + refnameRaa18 + "_Aux", refdataRaa18);
      book(_s["Raa18_Average"],refnameRaa18);

      book(_c["sow_pp5TeV"], "_sow_pp5TeV");
      book(_c["sow_PbPb5TeV_0010"], "_sow_PbPb5TeV_0010");
      book(_c["sow_PbPb5TeV_3050"], "_sow_PbPb5TeV_3050");

      string refnameRaa19 = mkAxisCode(19, 1, 1);                                   // pt-integrated RAA D0 (0-10%) 
      const Scatter2D& refdataRaa19 = refData(refnameRaa19);
      book(_h["Raa19_D0_PbPb_0010"], "_" + refnameRaa19 + "_D0_PbPb_0010", refdataRaa19);
      book(_h["Raa19_D0_pp_0010"], "_" + refnameRaa19 + "_D0_pp_0010", refdataRaa19);
      book(_s["Raa19_D0_0010"], refnameRaa19);

      string refnameRaa20 = mkAxisCode(20, 1, 1);                                   // pt-integrated RAA D0 (30-50%)
      const Scatter2D& refdataRaa20 = refData(refnameRaa20);
      book(_h["Raa20_D0_PbPb_3050"], "_" + refnameRaa20 + "_D0_PbPb_3050", refdataRaa20);
      book(_h["Raa20_D0_pp_3050"], "_" + refnameRaa20 + "_D0_pp_3050", refdataRaa20);
      book(_s["Raa20_D0_3050"], refnameRaa20);

    }

    void analyze(const Event& event) {

      const ParticlePair& beam = beams();
      string CollSystem = "Empty";
      double NN = 208;

      if (beam.first.pid() == PID::LEAD && beam.second.pid() == PID::LEAD)
      {
          CollSystem = "PBPB";
          if(fuzzyEquals(sqrtS()/GeV, 5020*NN, 1E-3)){
          CollSystem += "5TeV";
          }        
      }
      if (beam.first.pid() == PID::PROTON && beam.second.pid() == PID::PROTON)
      {
          CollSystem = "PP";
          if(fuzzyEquals(sqrtS()/GeV, 5020, 1E-3)) CollSystem += "5TeV";
      }

      Particles particlesD0 = apply<UnstableParticles>(event,"ufsD0").particles();
      Particles particlesDplus = apply<UnstableParticles>(event,"ufsDplus").particles();
      Particles particlesDstar = apply<UnstableParticles>(event,"ufsDstar").particles();

      if(CollSystem == "PP5TeV")
      {
          _c["sow_pp5TeV"]->fill();

          for(const Particle& p : particlesD0)
          {
              if(p.fromBottom()) continue;
              _h["Raa11_D0_pp"]->fill(p.pT()/GeV);
              _h["Raa14_D0_pp"]->fill(p.pT()/GeV);
              _h["Raa19_D0_pp_0010"]->fill(5);
              _h["Raa20_D0_pp_3050"]->fill(40);

          }

          for(const Particle& p : particlesDplus)
          {
              if(p.fromBottom()) continue;
              _h["Raa12_Dplus_pp"]->fill(p.pT()/GeV);
              _h["Raa15_Dplus_pp"]->fill(p.pT()/GeV);

          }

          for(const Particle& p : particlesDstar)
          {
              if(p.fromBottom()) continue;
              _h["Raa13_Dstar_pp"]->fill(p.pT()/GeV);
              _h["Raa16_Dstar_pp"]->fill(p.pT()/GeV);
          }

      }

      const CentralityProjection& centProj = apply<CentralityProjection>(event,"V0M");

      const double cent = centProj();

      if(cent >= 50.) vetoEvent;

      if(CollSystem == "PBPB5TeV")
        {
        if(cent < 10.)
            { 
                _c["sow_PbPb5TeV_0010"]->fill();
                for(const Particle& p : particlesD0)
                {             
                    if(p.fromBottom()) continue;
                    _h["D0Pt_0010"]->fill(p.pT()/GeV);
                    _h["Ratio7_D0"]->fill(p.pT()/GeV);
                    _h["Ratio9_D0"]->fill(p.pT()/GeV);
                    _h["Raa11_D0_PbPb"]->fill(p.pT()/GeV);
                    _h["Raa19_D0_PbPb_0010"]->fill(5);    
                }

                for(const Particle& p : particlesDplus)
                {
                    if(p.fromBottom()) continue;
                    _h["DplusPt_0010"]->fill(p.pT()/GeV);
                    _h["Ratio7_Dplus"]->fill(p.pT()/GeV);
                    _h["Raa12_Dplus_PbPb"]->fill(p.pT()/GeV);
                }

                for(const Particle& p : particlesDstar)
                {
                    if(p.fromBottom()) continue;
                    _h["DstarPt_0010"]->fill(p.pT()/GeV);
                    _h["Ratio9_Dstar"]->fill(p.pT()/GeV);
                    _h["Raa13_Dstar_PbPb"]->fill(p.pT()/GeV);
                }

            }
            else if(cent >= 30. && cent < 50.)
            {
                _c["sow_PbPb5TeV_3050"]->fill();
                for(const Particle& p : particlesD0)
                {
                    if(p.fromBottom()) continue;
                    _h["D0Pt_3050"]->fill(p.pT()/GeV);
                    _h["Ratio8_D0"]->fill(p.pT()/GeV);
                    _h["Ratio10_D0"]->fill(p.pT()/GeV);
                    _h["Raa14_D0_PbPb"]->fill(p.pT()/GeV);
                    _h["Raa20_D0_PbPb_3050"]->fill(40); 
                }

                for(const Particle& p : particlesDplus)
                {
                    if(p.fromBottom()) continue;
                    _h["DplusPt_3050"]->fill(p.pT()/GeV);
                    _h["Ratio8_Dplus"]->fill(p.pT()/GeV);
                    _h["Raa15_Dplus_PbPb"]->fill(p.pT()/GeV);                                    
                }

                for(const Particle& p : particlesDstar)
                {
                    if(p.fromBottom()) continue;
                    _h["DstarPt_3050"]->fill(p.pT()/GeV);
                    _h["Ratio10_Dstar"]->fill(p.pT()/GeV);
                    _h["Raa16_Dstar_PbPb"]->fill(p.pT()/GeV);
                }
            }         
        }   
    }

    void finalize() {

      _h["D0Pt_0010"]->scaleW(1./(2*_c["sow_PbPb5TeV_0010"]->sumW()));
      _h["DplusPt_0010"]->scaleW(1./(2*_c["sow_PbPb5TeV_0010"]->sumW()));
      _h["DstarPt_0010"]->scaleW(1./(2*_c["sow_PbPb5TeV_0010"]->sumW()));

      _h["D0Pt_3050"]->scaleW(1./(2*_c["sow_PbPb5TeV_3050"]->sumW()));
      _h["DplusPt_3050"]->scaleW(1./(2*_c["sow_PbPb5TeV_3050"]->sumW()));
      _h["DstarPt_3050"]->scaleW(1./(2*_c["sow_PbPb5TeV_3050"]->sumW()));

      divide(_h["Ratio7_Dplus"], _h["Ratio7_D0"], _s["Ratio7_Dplus_D0"]);
      divide(_h["Ratio8_Dplus"], _h["Ratio8_D0"], _s["Ratio8_Dplus_D0"]);

      divide(_h["Ratio9_Dstar"], _h["Ratio9_D0"], _s["Ratio9_Dstar_D0"]);
      divide(_h["Ratio10_Dstar"], _h["Ratio10_D0"], _s["Ratio10_Dstar_D0"]);

      _h["Raa11_D0_PbPb"]->scaleW(1./_c["sow_PbPb5TeV_0010"]->sumW());
      _h["Raa11_D0_pp"]->scaleW(Ncoll0010/_c["sow_pp5TeV"]->sumW());
      divide(_h["Raa11_D0_PbPb"], _h["Raa11_D0_pp"], _s["Raa11_D0"]);

      _h["Raa12_Dplus_PbPb"]->scaleW(1./_c["sow_PbPb5TeV_0010"]->sumW());
      _h["Raa12_Dplus_pp"]->scaleW(crossSection()*Ncoll0010/_c["sow_pp5TeV"]->sumW());
      divide(_h["Raa12_Dplus_PbPb"], _h["Raa12_Dplus_pp"], _s["Raa12_Dplus"]);

      _h["Raa13_Dstar_PbPb"]->scaleW(1./_c["sow_PbPb5TeV_0010"]->sumW());
      _h["Raa13_Dstar_pp"]->scaleW(Ncoll0010/_c["sow_pp5TeV"]->sumW());
      divide(_h["Raa13_Dstar_PbPb"], _h["Raa13_Dstar_pp"], _s["Raa13_Dstar"]);

      _h["Raa14_D0_PbPb"]->scaleW(1./_c["sow_PbPb5TeV_3050"]->sumW());
      _h["Raa14_D0_pp"]->scaleW(Ncoll3050/_c["sow_pp5TeV"]->sumW());
      divide(_h["Raa14_D0_PbPb"], _h["Raa14_D0_pp"], _s["Raa14_D0"]);

      _h["Raa15_Dplus_PbPb"]->scaleW(1./_c["sow_PbPb5TeV_3050"]->sumW());
      _h["Raa15_Dplus_pp"]->scaleW(Ncoll3050/_c["sow_pp5TeV"]->sumW());
      divide(_h["Raa15_Dplus_PbPb"], _h["Raa15_Dplus_pp"], _s["Raa15_Dplus"]);

      _h["Raa16_Dstar_PbPb"]->scaleW(1./_c["sow_PbPb5TeV_3050"]->sumW());
      _h["Raa16_Dstar_pp"]->scaleW(Ncoll3050/_c["sow_pp5TeV"]->sumW());
      divide(_h["Raa16_Dstar_PbPb"], _h["Raa16_Dstar_pp"], _s["Raa16_Dstar"]);

      vector<Scatter2DPtr> Histograms17 = {_s["Raa11_D0"], _s["Raa12_Dplus"], _s["Raa13_Dstar"]};
      BuildAverage(_s["Raa17_Average"], _h["Raa17_Average_Aux"], Histograms17);

      vector<Scatter2DPtr> Histograms18 = {_s["Raa14_D0"], _s["Raa15_Dplus"], _s["Raa16_Dstar"]};
      BuildAverage(_s["Raa18_Average"], _h["Raa18_Average_Aux"], Histograms18);

      _h["Raa19_D0_PbPb_0010"]->scaleW(1./_c["sow_PbPb5TeV_0010"]->sumW());
      _h["Raa19_D0_pp_0010"]->scaleW(Ncoll0010/_c["sow_pp5TeV"]->sumW());
      divide(_h["Raa19_D0_PbPb_0010"], _h["Raa19_D0_pp_0010"], _s["Raa19_D0_0010"]);

      _h["Raa20_D0_PbPb_3050"]->scaleW(1./_c["sow_PbPb5TeV_3050"]->sumW());
      _h["Raa20_D0_pp_3050"]->scaleW(Ncoll3050/_c["sow_pp5TeV"]->sumW());
      divide(_h["Raa20_D0_PbPb_3050"], _h["Raa20_D0_pp_3050"], _s["Raa20_D0_3050"]);

    }

    map<string, Histo1DPtr> _h;
    map<string, CounterPtr> _c;
    map<string, Scatter2DPtr> _s;
    const double Ncoll0010 = 1572.;
    const double Ncoll3050 = 264.8;

  };

  // Declare ALICE_2021_I1946131 as a Rivet analysis plugin.
  RIVET_DECLARE_PLUGIN(ALICE_2021_I1946131);

}
