MODULE WAVELET_SUB
    use data

    CONTAINS

    SUBROUTINE WAVE_TRANS_OPT(x_vals, epr_exp, best, wave_exp)
        ! Pick the mother wavelet / parameter pair whose inverse transform best
        ! reproduces epr_exp, then return the wavelet transform for that pair.
        !
        ! Candidates come from GET_WAVELET_PARAMS for mother = 0 (Morlet),
        ! 1 (Paul) and 2 (DoG). Each candidate is transformed with WAVELET,
        ! reconstructed with RECON, and scored by the RMSD against epr_exp.
        !
        ! Arguments
        !   x_vals   (in)  : x values; n = SIZE(x_vals) is the number of points and
        !                    only x_vals(1) and x_vals(n) are read (to derive dt).
        !   epr_exp  (in)  : signal values; at least n elements.
        !   best     (out) : best(1) = selected mother, best(2) = selected parameter,
        !                    best(3) = RMSD of its reconstruction. At least 3 elements.
        !   wave_exp (out) : the final transform is left in wave_exp(1:n, 1:jtot) and
        !                    everything else is zero. It needs at least n rows and as
        !                    many columns as the largest jtot over the three mothers.
        !
        ! Prints an error and stops for n < 2, n >= 4096, undersized arguments, or
        ! when no candidate gives a comparable (non-NaN) RMSD.

        IMPLICIT NONE

        INTEGER, PARAMETER             :: subscale = 4
        REAL(KIND = 8), PARAMETER      :: dj = 1.0_8/REAL(subscale, 8)
        ! Number of steps in the optimisation of each parameter (passed to GET_WAVELET_PARAMS)
        INTEGER, PARAMETER             :: param_num_steps = 20

        REAL(KIND = 8), intent(in)     :: epr_exp(:), x_vals(:)
        REAL(KIND = 8),  intent(out)   :: best(:)
        COMPLEX(KIND = 8), intent(out) :: wave_exp(:,:)

        INTEGER                        :: mother, npad, p, n, num_param_cycles, jtot
        REAL(KIND = 8)                 :: pi, dt, RMSD, s0, default_param, param
        REAL(KIND = 8)                 :: ymean_exp, Cdelta, psi0
        REAL(KIND = 8), ALLOCATABLE    :: scales(:), period(:), param_mat(:), recon_epr(:), coi(:)

        Cdelta = 0.0_8
        psi0 = 0.0_8

        n = SIZE(x_vals)

        ! dt below is 0 for a single point, which would make every later LOG/scale NaN.
        IF (n .LT. 2) THEN
            WRITE(6, *) 'ERROR: at least 2 data points are required for the wavelet transform'
            STOP
        END IF
        IF (SIZE(epr_exp) .LT. n) THEN
            WRITE(6, *) 'ERROR: epr_exp has fewer points than x_vals'
            STOP
        END IF
        IF (SIZE(best) .LT. 3) THEN
            WRITE(6, *) 'ERROR: best must have at least 3 elements'
            STOP
        END IF

        ALLOCATE(coi(n))
        ALLOCATE(recon_epr(n))

        ! Mean of the signal; RECON adds it back to the reconstruction.
        ymean_exp = SUM(epr_exp(1:n)) / REAL(n, 8)

        ! Average x step. NOTE(review): n evenly spaced points are (x(n) - x(1))/(n - 1)
        ! apart; the /n form is kept because WAVELET_ERROR_SUB derives dt the same way
        ! and the two transforms must agree.
        dt = (x_vals(n) - x_vals(1))/n
        pi = 4.0_8*ATAN(1.0_8)

        ! npad is handed to WAVELET as the zero-padded length; it must be >= n and
        ! at least 1024.
        IF (n .LT. 1024) THEN
            npad = 1024
        ELSE IF (n .LT. 2048) THEN
            npad = 2048
        ELSE IF (n .LT. 4096) THEN
            npad = 4096
        ELSE
            WRITE(6, *) 'ERROR: experimental and simulated data must be less than 4096 data points'
            STOP
        END IF

        ! best(1:2) keep the historical placeholder 99 until a candidate is accepted;
        ! best(3) starts at HUGE so the first finite RMSD always wins.
        best = 99.0_8
        best(3) = HUGE(1.0_8)

        ! 0 = Morlet, 1 = Paul, 2 = Derivative of Gaussian (DoG)
        DO mother = 0, 2

            ! Minimum scale s0 and the candidate parameters for this mother wavelet
            CALL GET_WAVELET_PARAMS(mother, default_param, s0, dt, param_mat, param_num_steps, num_param_cycles)

            ! Rounded exactly like the final transform below and WAVELET_ERROR_SUB, so the
            ! RMSD used for the selection belongs to the set of scales that is returned.
            jtot = NINT(1.0_8+(LOG(n*dt/s0)/LOG(2.0_8))/dj)

            ! Every candidate writes wave_exp(1:n, 1:jtot); guard against overrunning it.
            IF (SIZE(wave_exp, 1) .LT. n .OR. SIZE(wave_exp, 2) .LT. jtot) THEN
                WRITE(6, *) 'ERROR: wave_exp too small for the wavelet transform; need at least ', n, ' x ', jtot
                STOP
            END IF

            ALLOCATE(scales(jtot))
            ALLOCATE(period(jtot))

            DO p = 1, num_param_cycles
                param = param_mat(p)

                scales = 0.0_8
                wave_exp = (0.0_8, 0.0_8)
                period = 0.0_8

                ! Forward transform of the experimental signal
                CALL WAVELET(n,epr_exp,dt,mother,param,s0,dj,jtot,npad,wave_exp(1:n,1:jtot),scales,period,coi)

                ! Reconstruction constants, then the reconstruction and its RMSD
                CALL GET_RECON_PARAMS(mother, Cdelta, psi0, param, pi)
                CALL RECON(n, jtot, wave_exp(1:n,1:jtot), scales, Cdelta, psi0, ymean_exp, dj, dt, recon_epr, RMSD, epr_exp)

                IF (RMSD .LT. best(3)) THEN
                    best(1) = mother
                    best(2) = param
                    best(3) = RMSD
                END IF
            END DO

            ! Sizes depend on the mother wavelet, so release before the next one
            DEALLOCATE(param_mat)
            DEALLOCATE(scales)
            DEALLOCATE(period)
        END DO

        !~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        !Carry out final wavelet transform
        !~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        ! A NaN RMSD never compares .LT., so best(1) can still be the placeholder here.
        mother = NINT(best(1))
        IF (mother .LT. 0 .OR. mother .GT. 2) THEN
            WRITE(6, *) 'ERROR: no wavelet reconstruction produced a usable RMSD'
            STOP
        END IF

        ! Only s0 is needed from this call
        CALL GET_WAVELET_PARAMS(mother, default_param, s0, dt, param_mat, param_num_steps, num_param_cycles)
        param = best(2)

        ! Same jtot as the loop above for this mother, so wave_exp is already known to fit
        jtot = NINT(1.0_8+(LOG(n*dt/s0)/LOG(2.0_8))/dj)

        ALLOCATE(scales(jtot))
        ALLOCATE(period(jtot))

        wave_exp = (0.0_8, 0.0_8)
        scales = 0.0_8
        period = 0.0_8

        CALL WAVELET(n,epr_exp,dt,mother,param,s0,dj,jtot,npad,wave_exp(1:n,1:jtot),scales,period,coi)

    END SUBROUTINE WAVE_TRANS_OPT

    SUBROUTINE WAVELET_ERROR_SUB(epr_sim, x_vals, wave_exp, best, wavelet_error)
        ! Transform epr_sim with the wavelet chosen in best and compare the result
        ! with the experimental coefficients wave_exp.
        !
        ! dt, npad, s0 and jtot are derived exactly as for the final transform in
        ! WAVE_TRANS_OPT, so the two coefficient matrices cover the same scales.
        !
        ! Arguments
        !   epr_sim       (in)  : simulated signal; at least n = SIZE(x_vals) elements.
        !   x_vals        (in)  : x values; only x_vals(1), x_vals(n) and the size are used.
        !   wave_exp      (in)  : experimental coefficients; at least n x jtot.
        !   best          (in)  : best(1) = mother (0, 1 or 2), best(2) = wavelet parameter.
        !   wavelet_error (out) : mean |wave_exp - wave_sim| over the n x jtot block
        !                         (see CALC_WAVE_ERROR).
        !
        ! Prints an error and stops on invalid input sizes or an invalid mother.

        IMPLICIT NONE

        INTEGER, PARAMETER             :: subscale = 4
        REAL(KIND = 8), PARAMETER      :: dj = 1.0_8/REAL(subscale, 8)

        REAL(KIND = 8), intent(in)     :: epr_sim(:), x_vals(:), best(:)
        REAL(KIND = 8),  intent(out)   :: wavelet_error
        COMPLEX(KIND = 8), intent(in)  :: wave_exp(:,:)

        INTEGER                        :: mother, npad, n, jtot
        REAL(KIND = 8)                 :: dt, s0, param
        REAL(KIND = 8), ALLOCATABLE    :: scales(:), period(:), coi(:)
        COMPLEX(KIND = 8), ALLOCATABLE :: wave_sim(:,:)

        wavelet_error = 0.0_8

        n = SIZE(x_vals)

        IF (n .LT. 2) THEN
            WRITE(6, *) 'ERROR: at least 2 data points are required for the wavelet transform'
            STOP
        END IF
        IF (SIZE(epr_sim) .LT. n) THEN
            WRITE(6, *) 'ERROR: epr_sim has fewer points than x_vals'
            STOP
        END IF
        IF (SIZE(best) .LT. 2) THEN
            WRITE(6, *) 'ERROR: best must hold at least the mother wavelet and its parameter'
            STOP
        END IF

        ! Same average-step formula as WAVE_TRANS_OPT (the two must agree)
        dt = (x_vals(n) - x_vals(1))/n

        ! Padded length handed to WAVELET; same rule as WAVE_TRANS_OPT
        IF (n .LT. 1024) THEN
            npad = 1024
        ELSE IF (n .LT. 2048) THEN
            npad = 2048
        ELSE IF (n .LT. 4096) THEN
            npad = 4096
        ELSE
            WRITE(6, *) 'ERROR: experimental and simulated data must be less than 4096 data points'
            STOP
        END IF

        mother = NINT(best(1))
        param = best(2)

        ! Minimum scale per mother wavelet; must match GET_WAVELET_PARAMS
        SELECT CASE (mother)
        CASE (0)
            s0 = dt
        CASE (1)
            s0 = 2*dt
        CASE (2)
            s0 = dt/4
        CASE DEFAULT
            WRITE(6, *) 'ERROR: Mother Wavelet value invalid. Must be 0, 1, or 2 for Morlet, Paul, or DoG respectively'
            STOP
        END SELECT

        ! Total number of scales, rounded as in WAVE_TRANS_OPT's final transform
        jtot = NINT(1.0_8+(LOG(n*dt/s0)/LOG(2.0_8))/dj)

        ! wave_exp is sliced as (1:n, 1:jtot) below
        IF (SIZE(wave_exp, 1) .LT. n .OR. SIZE(wave_exp, 2) .LT. jtot) THEN
            WRITE(6, *) 'ERROR: wave_exp too small for the wavelet transform; need at least ', n, ' x ', jtot
            STOP
        END IF

        ALLOCATE(scales(jtot))
        ALLOCATE(period(jtot))
        ALLOCATE(coi(n))
        ALLOCATE(wave_sim(n, jtot))

        scales = 0.0_8
        period = 0.0_8
        coi = 0.0_8
        wave_sim = (0.0_8, 0.0_8)

        CALL WAVELET(n,epr_sim,dt,mother,param,s0,dj,jtot,npad,wave_sim,scales,period,coi)

        CALL CALC_WAVE_ERROR(wave_exp(1:n,1:jtot), wave_sim, wavelet_error, n, jtot)

    END SUBROUTINE WAVELET_ERROR_SUB

    SUBROUTINE CALC_WAVE_ERROR(wave_exp, wave_sim, wave_error, n, jtot)
        ! Mean absolute difference between two complex coefficient matrices over
        ! their leading n x jtot block:
        !
        !   wave_error = SUM(|wave_exp(1:n,1:jtot) - wave_sim(1:n,1:jtot)|) / (n*jtot)
        !
        ! Both arrays must have at least n rows and jtot columns. The divisor is the
        ! size of the compared block (not of the whole argument), so passing larger
        ! arrays still yields a true mean. An empty block (n < 1 or jtot < 1) gives 0.

        IMPLICIT NONE

        INTEGER, intent(in)            :: n, jtot
        COMPLEX(KIND = 8), intent(in)  :: wave_exp(:,:), wave_sim(:,:)
        REAL(KIND = 8), intent(out)    :: wave_error

        wave_error = 0.0_8
        IF (n .LT. 1 .OR. jtot .LT. 1) RETURN

        ! Whole-array expression instead of an explicit n x jtot automatic array
        wave_error = SUM(ABS(wave_exp(1:n, 1:jtot) - wave_sim(1:n, 1:jtot))) &
                     / (REAL(n, 8)*REAL(jtot, 8))

    END SUBROUTINE CALC_WAVE_ERROR

    SUBROUTINE GET_WAVELET_PARAMS (mother, default_param, s0, dt, param_mat, param_num_steps, num_param_cycles)
        ! Minimum scale and candidate parameter values for one mother wavelet.
        !
        !   mother           (in)  : 0 = Morlet, 1 = Paul, 2 = Derivative of Gaussian (DoG)
        !   default_param    (out) : 6 (Morlet), 4 (Paul); for DoG it is the last
        !                            candidate assigned, 6
        !   s0               (out) : smallest scale: dt (Morlet), 2*dt (Paul), dt/4 (DoG)
        !   dt               (in)  : x step of the signal (only read)
        !   param_mat        (out) : candidate parameters: [6] (Morlet), [4] (Paul),
        !                            [2, 6] (DoG); (re)allocated on every call
        !   param_num_steps  (in)  : not used; kept for interface compatibility
        !   num_param_cycles (out) : SIZE(param_mat)
        !
        ! Any other mother value prints an error and stops the program instead of
        ! returning undefined s0 / param_mat.

        IMPLICIT NONE
        INTEGER, intent(in)                      :: param_num_steps, mother
        INTEGER, intent(out)                     :: num_param_cycles

        REAL(KIND = 8), intent(in)               :: dt
        REAL(KIND = 8), intent(out)              :: s0, default_param
        REAL(KIND = 8), ALLOCATABLE, intent(out) :: param_mat(:)

        SELECT CASE (mother)
        CASE (0)
            ! Morlet: only param = 6 is supported by GET_RECON_PARAMS
            default_param = 6.0_8
            s0 = dt
            ALLOCATE(param_mat(1))
            param_mat = [6.0_8]
        CASE (1)
            ! Paul: only param = 4 is supported by GET_RECON_PARAMS
            default_param = 4.0_8
            s0 = 2.0_8*dt
            ALLOCATE(param_mat(1))
            param_mat = [default_param]
        CASE (2)
            ! DoG: both supported derivative orders are tried
            s0 = dt/4.0_8
            ALLOCATE(param_mat(2))
            param_mat = [2.0_8, 6.0_8]
            default_param = param_mat(2)
        CASE DEFAULT
            WRITE(6, *) 'ERROR: Mother Wavelet value invalid. Must be 0, 1, or 2 for Morlet, Paul, or DoG respectively'
            STOP
        END SELECT

        num_param_cycles = SIZE(param_mat)

    END SUBROUTINE GET_WAVELET_PARAMS

        SUBROUTINE GET_RECON_PARAMS(mother, Cdelta, psi0, param, pi)
            ! Constants needed by RECON to invert the wavelet transform.
            !
            !   mother (in)  : 0 = Morlet, 1 = Paul, 2 = DoG
            !   param  (in)  : wavelet parameter
            !   pi     (in)  : pi; used for the Morlet value psi0 = pi**(-1/4)
            !   Cdelta (out) : reconstruction factor (RECON divides by Cdelta*psi0)
            !   psi0   (out) : psi0 value (RECON divides by Cdelta*psi0)
            !
            ! Supported combinations: Morlet param = 6; Paul INT(param) = 4;
            ! DoG INT(param) = 2 or 6. Anything else prints an error and stops.
            ! NOTE(review): the tabulated values appear to be those of Torrence &
            ! Compo (1998), Table 2 - confirm before adding new parameters.

            IMPLICIT NONE

            INTEGER, intent(in)         :: mother
            REAL(KIND = 8), intent(in)  :: param, pi
            REAL(KIND = 8), intent(out) :: Cdelta, psi0

            ! Tolerance for the Morlet parameter check (avoids exact real equality)
            REAL(KIND = 8), PARAMETER   :: param_tol = 1.0E-9_8

            SELECT CASE (mother)
            CASE (0)
                ! Morlet
                IF (ABS(param - 6.0_8) .GT. param_tol) THEN
                    WRITE(6, *) 'ERROR: Following condition not met for Morlet wavelet: param .EQ. 6. Program Stopped'
                    STOP
                END IF
                Cdelta = 0.776_8
                psi0 = pi**(-0.25_8)

            CASE (1)
                ! Paul
                IF (INT(param) .NE. 4) THEN
                    WRITE(6, *) 'ERROR: Following condition not met for Paul wavelet: INT(param) .EQ. 4. Program Stopped'
                    STOP
                END IF
                Cdelta = 1.132_8
                psi0 = 1.079_8

            CASE (2)
                ! Derivative of Gaussian
                SELECT CASE (INT(param))
                CASE (2)
                    Cdelta = 3.541_8
                    psi0 = 0.867_8
                CASE (6)
                    Cdelta = 1.966_8
                    psi0 = 0.884_8
                CASE DEFAULT
                    WRITE(6, *) 'ERROR: One of the following conditions not met for DoG wavelet: ', &
                                'INT(param) .EQ. 2 .OR. INT(param) .EQ. 6. Program Stopped'
                    STOP
                END SELECT

            CASE DEFAULT
                WRITE(6, *) 'ERROR: Mother Wavelet value invalid. Must be 0, 1, or 2 for Morlet, Paul, or DoG respectively'
                STOP
            END SELECT

        END SUBROUTINE GET_RECON_PARAMS

        SUBROUTINE RECON(n, jtot, wave, scales, Cdelta, psi0, ymean, dj, dt, recon_epr, RMSD, epr)
            ! Rebuild a signal from the real part of its wavelet coefficients and
            ! report the RMSD against the original signal.
            !
            !   recon_epr(i) = dj*SQRT(dt)/(Cdelta*psi0) * SUM_j DBLE(wave(i,j))/SQRT(scales(j)) + ymean
            !   RMSD         = SQRT( SUM_i (epr(i) - recon_epr(i))**2 / n )
            !
            ! ymean is added back because, per the original note, the reconstruction
            ! comes out shifted by -ymean.
            ! wave needs at least n rows and jtot columns, scales at least jtot
            ! elements and epr at least n elements. For n < 1 RMSD is 0.

            IMPLICIT NONE

            INTEGER, intent(in) :: n, jtot
            INTEGER             :: j

            REAL(KIND = 8), intent(in)    :: scales(:), Cdelta, psi0, ymean, dj, dt, epr(:)
            REAL(KIND = 8), intent(out)   :: recon_epr(n), RMSD

            COMPLEX(KIND = 8), intent(in) :: wave(:,:)

            recon_epr = 0.0_8
            RMSD = 0.0_8
            IF (n .LT. 1) RETURN

            ! Accumulate one column at a time: the array operation walks down a
            ! column of wave, which is contiguous in column-major storage. Each
            ! recon_epr(i) still receives its terms in the order j = 1..jtot.
            DO j = 1, jtot
                recon_epr = recon_epr + DBLE(wave(1:n, j))/SQRT(scales(j))
            END DO

            recon_epr = (dj*SQRT(dt)*recon_epr/(Cdelta*psi0)) + ymean

            RMSD = SQRT(SUM((epr(1:n) - recon_epr)**2)/n)

        END SUBROUTINE RECON

    END MODULE WAVELET_SUB
