"""
Python-Version: 3.10.9
Packages: Check requirements.txt! (run "pip install -r requirements.txt")

Check README.html for further information!
"""

# Import of Modules
from time import perf_counter, time

import vlfpy as vp
from vlfpy.simulation import ParameterSet, OpticalSetup, parameter_scan_1D, parameter_scan_2D
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Optical system file shared by every example below (relative to the
# current working directory)
OS_PATH = r'resources\TUT.0350_GratingEfficiency.os'

# Unit conversion: one micrometre expressed in metres [µ]
MICRO = 1e-6

def single_run():
    """Perform one simulation of the optical setup and report the outcome.

    Loads the optical system from ``OS_PATH``, runs it once, prints every
    result on its own line and then prints the elapsed simulation time in
    seconds.
    """
    # Initialize the optical setup
    optical_setup = OpticalSetup(OS_PATH)

    # perf_counter is monotonic and high-resolution, which makes it the right
    # clock for durations (time() follows the wall clock and can jump).
    start = perf_counter()

    # Perform the simulation
    results = optical_setup.perform()

    elapsed = perf_counter() - start

    # Print simulation results to the console, one per line
    print(*results, sep='\n')

    # Print simulation time
    print(f'Simulation Time (s): {elapsed:.2f}')

def parameter_scanning_1D():
    """Scan the grating's modulation depth and plot two efficiencies over it.

    Steps:

    1. Sweep the modulation depth of the Rectangular Grating (light path
       element index 1) from 0.1 µm to 10 µm in 100 equidistant steps.
    2. Print the raw scan results and the scan time.
    3. Plot the efficiencies labelled T[-1;0] and T[0;0] (in percent)
       against the modulation depth (in µm).
    4. Save the plot as Parameter_Scan_1D.png in the resources folder and
       show it.
    """
    # Initialize the optical setup
    optical_setup = OpticalSetup(OS_PATH)

    # Modulation depth values in metres: 0.1 µm - 10 µm, 100 steps
    values = np.linspace(0.1 * MICRO, 10 * MICRO, 100)

    # Define the parameter set (further information in README.html)
    # 1. parameter: Index attribute of the LightPathElement, here: Rectangular Grating
    # 2. parameter: ID tag of the parameter to be scanned
    # 3. parameter: the already defined values
    params = ParameterSet(1, 'Stack2.LayersAsArray[0].Interface.ModulationDepth', values)

    # Time the scan with a monotonic, high-resolution clock
    start = perf_counter()
    results = parameter_scan_1D(optical_setup, params)
    elapsed = perf_counter() - start

    # Print results to the console, one per line
    print(*results, sep='\n')

    # Print simulation time
    print(f'Simulation Time (s): {elapsed:.2f}')

    # --- Plotting -----------------------------------------------------------
    # x-axis: modulation depth converted from metres to µm (done once)
    depths_um = values / MICRO

    # Presumably one entry of results[0].data per scan value; element 0 is
    # plotted as T[-1;0] and element 1 as T[0;0], converted to percent.
    scan_data = results[0].data
    eff_minus1 = [step[0].value * 100 for step in scan_data]
    eff_zero = [step[1].value * 100 for step in scan_data]

    # A dedicated figure keeps earlier pyplot state from leaking into the plot
    fig, ax = plt.subplots()
    ax.set_xlabel('Modulation Depth [µm]')
    ax.set_ylabel('Efficiencies [%]')

    # Eff T[-1;0] in blue, Eff T[0;0] in red
    ax.plot(depths_um, eff_minus1, 'b', label='Efficiencies T[-1;0]')
    ax.plot(depths_um, eff_zero, 'r', label='Efficiencies T[0;0]')

    # Show legend in lower left corner
    ax.legend(loc='lower left')

    # Save exactly the figure built above, then display it
    fig.savefig(r'resources\Parameter_Scan_1D.png')
    plt.show()

def parameter_scanning_2D():
    """Scan modulation depth and relative slit width; map efficiency T[-1;0].

    Both parameters of the Rectangular Grating (light path element index 1)
    are swept on a 31 x 31 grid:

    * modulation depth: 0.1 µm - 10 µm
    * relative slit width: 20 % - 80 %

    The function prints the iteration count, the raw results and the scan
    time. It then draws the efficiency labelled T[-1;0] (in percent) as a
    colour map, saves it as Parameter_Scan_2D.png (300 dpi) in the resources
    folder and shows it.
    """
    # Initialize the optical setup
    optical_setup = OpticalSetup(OS_PATH)

    # Number of grid points per scanned parameter
    md_steps = 31  # modulation depth
    sw_steps = 31  # relative slit width

    # Modulation depth in metres: 0.1 µm - 10 µm (+ 0.33 µm per step)
    md_values = np.linspace(0.1 * MICRO, 10 * MICRO, md_steps)

    # Relative slit width as a fraction: 20 % - 80 % (+ 2 % per step)
    sw_values = np.linspace(0.2, 0.8, sw_steps)

    # Print the number of iterations to the console
    print(f'Number of Iterations = {len(md_values) * len(sw_values)}\n')

    # Define the parameter sets (further information in README.html)
    # 1. parameter: Index attribute of the LightPathElement, here: Rectangular Grating
    # 2. parameter: ID tag of the parameter to be scanned
    # 3. parameter: the already defined values
    md_params = ParameterSet(1, 'Stack2.LayersAsArray[0].Interface.ModulationDepth', md_values)
    sw_params = ParameterSet(1, 'Stack2.LayersAsArray[0].Interface.RelativeSlitWidth', sw_values)

    # Time the scan with a monotonic, high-resolution clock
    start = perf_counter()
    results = parameter_scan_2D(optical_setup, md_params, sw_params)
    elapsed = perf_counter() - start

    # Print results one per line, consistent with the other examples
    print(*results, sep='\n')

    # Print simulation time
    print(f'Simulation Time (s): {elapsed:.2f}')

    # --- Plotting -----------------------------------------------------------
    # x-axis: relative slit width as percent; y-axis: modulation depth in µm
    x = sw_values * 100
    y = md_values / MICRO

    # Efficiency T[-1;0] (element 0 of each result entry) in percent
    efficiencies = np.array([step[0].value * 100 for step in results[0].data])

    # NOTE(review): this reshape assumes the scan results are ordered with
    # modulation depth as the outer loop and slit width as the inner loop
    # (rows = y, columns = x). Confirm against the parameter_scan_2D docs.
    # Since both axes have 31 steps, a transposed ordering would not raise.
    # np.reshape does raise ValueError if the result count is not
    # md_steps * sw_steps.
    efficiency_map = np.reshape(efficiencies, (len(y), len(x)))

    # Create the figure and axes
    fig, ax = plt.subplots()

    # 'gouraud' interpolates colours between grid points; drop the shading
    # argument to get flat, per-cell colouring instead.
    mesh = ax.pcolormesh(x, y, efficiency_map, shading='gouraud', cmap='magma')

    # Set title and label names
    ax.set_title('Efficiencies T[-1;0] [%]')
    ax.set_ylabel('Modulation Depth [µm]')
    ax.set_xlabel('Relative Slit Width [%]')

    # Create a colorbar
    fig.colorbar(mesh, ax=ax)

    # Save exactly the figure built above, then display it
    fig.savefig(r'resources\Parameter_Scan_2D.png', dpi=300)
    plt.show()

    
if __name__ == '__main__':
    # Examples executed in order when the file is run as a script.
    # Add parameter_scanning_1D and/or parameter_scanning_2D to this tuple
    # to run the parameter scans as well.
    examples_to_run = (single_run,)
    for example in examples_to_run:
        example()