import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.collections import LineCollection
import cmocean
import os

# Column names for the STORM track file, in file order (one row per track time step).
colnames = ['sim', 'month', 't', 'i', 'basin', 'lat', 'lon', 'min_press', 'wind', 'radius', 'CAT', 'landfall', 'dist_land']

# Read the STORMS data file
file_path = 'STORM/STORM_DATA_IBTRACS_NA_1000_YEARS_0.txt'

# Fields may be separated by commas and/or whitespace, so split on runs of either.
# A regex separator requires the (slower) python parsing engine.
df = pd.read_csv(file_path, names=colnames, sep=r'[,\s]+', engine='python')

# Convert wind speed from m/s to knots. 1 kt = 1852 m/h, so 1 m/s = 3600/1852 kt
# (~1.943844); the previous truncated factor 1.944 slightly over-reported speeds.
MS_TO_KNOTS = 3600 / 1852
df['wind_knots'] = df['wind'] * MS_TO_KNOTS

# Shift longitudes above 180 by -360 so they share the negative (west) longitude
# convention used by the map extent. NOTE(review): presumes the file stores
# 0-360 east longitudes — confirm against the data source.
df['lon'] = np.where(df['lon'] > 180, df['lon'] - 360, df['lon'])

# First 5 distinct 'sim' values, in order of first appearance in the file.
first_5_sims = df['sim'].unique()[:5]

# Keep only the rows belonging to those 5 sims.
df_first5 = df[df['sim'].isin(first_5_sims)]

# Colour map shared by the track segments and the colour bar.
cmap = cmocean.cm.thermal

# Base map: PlateCarree (plain lon/lat) axes clipped to the North Atlantic window.
fig = plt.figure(figsize=(14, 10))
ax = plt.axes(projection=ccrs.PlateCarree())
map_extent = [-100, -10, 5, 50]  # [lon_min, lon_max, lat_min, lat_max] in degrees
ax.set_extent(map_extent, crs=ccrs.PlateCarree())

# Background layers, added in this exact order (later layers draw over earlier
# ones at equal zorder), each paired with its styling keywords.
base_layers = (
    (cfeature.LAND, {'facecolor': 'lightgray', 'edgecolor': 'k'}),
    (cfeature.COASTLINE, {}),
    (cfeature.BORDERS, {'linestyle': ':'}),
    (cfeature.OCEAN, {'facecolor': 'azure'}),
)
for _layer, _layer_style in base_layers:
    ax.add_feature(_layer, **_layer_style)

# Labelled, dashed, semi-transparent lat/lon grid.
ax.gridlines(draw_labels=True, linestyle='--', alpha=0.5)

# Plot each storm as a line coloured by wind speed.
# One normalisation (30-100 kt) is shared by every track so colours are
# comparable; it matches the range used for the colour bar.
track_norm = plt.Normalize(30, 100)
# Track coordinates are plain longitude/latitude degrees; passing this CRS as
# the transform keeps the plot correct even if the map projection changes.
data_crs = ccrs.PlateCarree()

# Each (sim, t) pair identifies one storm track.
for _, group in df_first5.groupby(['sim', 't']):
    # Sort by 'i' to ensure proper track ordering
    group = group.sort_values('i')

    # Skip tracks with fewer than 2 points: there is no segment to draw.
    if len(group) < 2:
        continue

    x = group['lon'].to_numpy()
    y = group['lat'].to_numpy()
    c = group['wind_knots'].to_numpy()

    # Build an (n-1, 2, 2) array of [[x0, y0], [x1, y1]] segments.
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    lc = LineCollection(segments, cmap=cmap, norm=track_norm, transform=data_crs,
                        linewidth=2.5, alpha=0.8, zorder=2)
    lc.set_array(c[:-1])  # Colour by wind speed at the start of each segment
    ax.add_collection(lc)

    # Open circle at the track's first point (lowest 'i' after sorting), so
    # every drawn track gets a start marker even if its index does not begin at 0.
    ax.plot(x[0], y[0], 'o',
            markersize=10, markerfacecolor='none',
            markeredgecolor='black', markeredgewidth=2, zorder=3,
            transform=data_crs)

# Colour bar for the track wind speed, using the same 30-100 kt normalisation
# as the track segments. extend='both' draws end arrows so readers can see that
# winds outside that range are clipped to the end colours rather than shown exactly.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(30, 100))
sm.set_array([])  # data-less mappable; kept for compatibility with older Matplotlib
cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.01, fraction=0.03,
                    extend='both')
cbar.set_label('Track Wind Speed (kt)')

# Title the map axes explicitly rather than whatever pyplot considers current.
ax.set_title('Simulated N. Atlantic Tropical Cyclone Tracks 5 Seasons: Colored by Wind Speed', fontsize=16)

# Save as both TIFF and JPG at 300 dpi, trimmed to the drawn content.
fig.savefig('NA_basin_tracks_STORMS.tiff', dpi=300, bbox_inches='tight', format='tiff')
fig.savefig('NA_basin_tracks_STORMS.jpg', dpi=300, bbox_inches='tight', format='jpeg')

plt.show()
plt.close(fig)