#!/usr/bin/env python3
"""
GAA OptiTrace DEM Analysis v1.0
Reads TINITALY 10m DEM + IFFI landslide data
Cross-references with engineer's KML route
Compares the route profile with real DEM elevations, flags nearby landslides and checks the piezometric line (with/without torrini)

Usage: python3.12 gaa_dem_analysis.py
Output: /home/bangherangstudio/gaa.wpeitalia.eu/geodata/gaa_analysis.json
"""

import json, math, os, sys, glob
import numpy as np

# Input/output locations; every file lives under a single geodata directory.
GEODATA_DIR = '/home/bangherangstudio/gaa.wpeitalia.eu/geodata'
DEM_DIR = GEODATA_DIR + '/dem'                        # DEM GeoTIFF tiles (*.tif)
FRANE_DIR = GEODATA_DIR + '/frane'                    # landslide GeoJSON files (*.json)
KML_FILE = GEODATA_DIR + '/GAA_v6_0__MOD2_.kml'       # route KML to analyse
OUTPUT_FILE = GEODATA_DIR + '/gaa_analysis.json'      # analysis JSON written at the end

_RULE = "=" * 70
for _banner_line in (_RULE, "GAA OptiTrace DEM Analysis v1.0", _RULE):
    print(_banner_line)

# ============================================================
# 1. LOAD DEM TILES
# ============================================================
print("\n[1/5] Loading DEM tiles...")
import rasterio

dem_files = sorted(glob.glob(os.path.join(DEM_DIR, '*.tif')))
print(f"  Found {len(dem_files)} TIF files")

# Keep every readable tile open (they are closed at the end of the script)
# together with the metadata get_elevation() uses. Tiles that fail to open
# or describe are reported and skipped.
dem_datasets = []
for tif_path in dem_files:
    tile_name = os.path.basename(tif_path)
    try:
        handle = rasterio.open(tif_path)
        bbox = handle.bounds
        dem_datasets.append(dict(path=tif_path, name=tile_name, ds=handle,
                                 bounds=bbox, crs=str(handle.crs)))
        # Extent printed in km of easting / northing.
        print(f"  ✓ {tile_name}: {bbox.left/1000:.0f}-{bbox.right/1000:.0f}E, "
              f"{bbox.bottom/1000:.0f}-{bbox.top/1000:.0f}N")
    except Exception as err:
        print(f"  ✗ {tile_name}: {err}")

_dem_read_warned = set()  # names of tiles whose read error was already reported


def get_elevation(utm_e, utm_n):
    """Return the DEM elevation at a UTM zone 32N point, or None.

    Tiles in ``dem_datasets`` are tried in order; the first tile that covers
    the point and yields a valid value wins. A value is rejected (and the next
    covering tile tried) when it equals the tile's declared nodata value or
    lies outside the plausible range (-100 m, 5000 m).

    Args:
        utm_e: Easting in metres, in the tiles' CRS (assumed to be UTM 32N,
            as produced by ``ll_to_utm32``; the tile CRS is not checked here).
        utm_n: Northing in metres.

    Returns:
        The elevation as a float, or None when no tile gives a valid value.
    """
    for dem in dem_datasets:
        b = dem['bounds']
        # Half-open extent: a point exactly on the right or bottom edge would
        # index one pixel past the raster, so it is left to the adjacent tile.
        if not (b.left <= utm_e < b.right and b.bottom < utm_n <= b.top):
            continue
        ds = dem['ds']
        try:
            row, col = ds.index(utm_e, utm_n)
            band = ds.read(1, window=rasterio.windows.Window(col, row, 1, 1))
            val = float(band[0, 0])
        except Exception as err:  # best effort: fall back to the next tile
            # Report once per tile instead of silently swallowing (and
            # without flooding the log for every sample in a broken tile).
            if dem['name'] not in _dem_read_warned:
                _dem_read_warned.add(dem['name'])
                print(f"  ! DEM read failed on {dem['name']}: {err}")
            continue
        if ds.nodata is not None and val == ds.nodata:
            continue
        if -100 < val < 5000:  # plausible range; also rejects NaN
            return val
    return None

def ll_to_utm32(lat, lon):
    """Project a WGS84 latitude/longitude (degrees) onto UTM zone 32N.

    Uses the Transverse Mercator series expansion (as in Snyder, "Map
    Projections - A Working Manual") with the zone-32 central meridian (9 deg E),
    scale factor 0.9996 and a 500 000 m false easting. No false northing is
    added, so results are only meaningful in the northern hemisphere.

    Returns:
        (easting, northing) in metres.
    """
    phi = math.radians(lat)
    central_meridian = math.radians(9)
    k0 = 0.9996
    a = 6378137.0          # WGS84 semi-major axis (m)
    e2 = 0.00669438        # WGS84 first eccentricity squared
    ep2 = e2 / (1 - e2)    # second eccentricity squared

    sin_phi = math.sin(phi)
    cos_phi = math.cos(phi)
    tan_phi = math.tan(phi)

    n_radius = a / math.sqrt(1 - e2 * sin_phi**2)   # prime-vertical radius
    t_sq = tan_phi**2
    c_term = ep2 * cos_phi**2
    a_term = (math.radians(lon) - central_meridian) * cos_phi

    # Meridional arc length from the equator to phi.
    arc = a * ((1 - e2/4 - 3*e2**2/64 - 5*e2**3/256) * phi
               - (3*e2/8 + 3*e2**2/32 + 45*e2**3/1024) * math.sin(2*phi)
               + (15*e2**2/256 + 45*e2**3/1024) * math.sin(4*phi)
               - (35*e2**3/3072) * math.sin(6*phi))

    easting = k0 * n_radius * (a_term + (1-t_sq+c_term)*a_term**3/6 + (5-18*t_sq+t_sq**2+72*c_term-58*ep2)*a_term**5/120) + 500000
    northing = k0 * (arc + n_radius * tan_phi * (a_term**2/2 + (5-t_sq+9*c_term+4*c_term**2)*a_term**4/24 + (61-58*t_sq+t_sq**2+600*c_term-330*ep2)*a_term**6/720))
    return easting, northing

def haversine(lat1, lon1, lat2, lon2):
    """Return the great-circle distance in metres between two WGS84 points.

    Inputs are latitude/longitude in degrees; the haversine formula is
    evaluated on a sphere of radius 6 371 000 m.
    """
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = math.sin(dlat/2)**2 + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon/2)**2
    # Rounding can push `a` a hair above 1 for (near-)antipodal points,
    # which would make sqrt(1 - a) raise a math domain error.
    a = min(a, 1.0)
    return 6371000 * 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))

# ============================================================
# 2. PARSE KML
# ============================================================
print("\n[2/5] Parsing KML...")
import re

# KML is XML and UTF-8 by specification: decode it as such rather than with
# the platform's default locale encoding.
with open(KML_FILE, 'r', encoding='utf-8') as f:
    kml = f.read()

# Extract LineString coordinates (DN2500 + DN2200). Every LineString is
# appended, in document order, to one continuous waypoint sequence.
lines = re.findall(r'<LineString>.*?<coordinates>\s*(.*?)\s*</coordinates>', kml, re.DOTALL)

waypoints = []
for line_coords in lines:
    # Coordinate tuples are "lon,lat,alt" separated by whitespace.
    for p in line_coords.strip().split():
        parts = p.split(',')
        if len(parts) < 3:
            # The profile analysis needs the design altitude of every vertex.
            sys.exit(f"KML route coordinate without altitude: {p!r}")
        wp = {
            'lon': float(parts[0]),
            'lat': float(parts[1]),
            'kml_alt': float(parts[2])
        }
        wp['utm_e'], wp['utm_n'] = ll_to_utm32(wp['lat'], wp['lon'])
        waypoints.append(wp)

if not waypoints:
    sys.exit(f"No LineString coordinates found in {KML_FILE}")

# Cumulative along-route distance (m) at each waypoint
waypoints[0]['cum_dist'] = 0
for prev, wp in zip(waypoints, waypoints[1:]):
    wp['cum_dist'] = prev['cum_dist'] + haversine(
        prev['lat'], prev['lon'],
        wp['lat'], wp['lon']
    )

# Extract torrini: every Placemark that carries a <Point> geometry
torrini = []
for match in re.finditer(r'<Placemark[^>]*>(.*?)</Placemark>', kml, re.DOTALL):
    pm = match.group(1)
    point_match = re.search(r'<Point>\s*<coordinates>(.*?)</coordinates>', pm, re.DOTALL)
    if point_match:
        parts = point_match.group(1).strip().split(',')
        pt_lon, pt_lat = float(parts[0]), float(parts[1])
        pt_e, pt_n = ll_to_utm32(pt_lat, pt_lon)  # project once, not per field
        torrini.append({
            'lon': pt_lon,
            'lat': pt_lat,
            'utm_e': pt_e,
            'utm_n': pt_n
        })

print(f"  Waypoints: {len(waypoints)}")
print(f"  Torrini: {len(torrini)}")
print(f"  Distanza totale: {waypoints[-1]['cum_dist']/1000:.1f} km")

# ============================================================
# 3. SAMPLE DEM ALONG ROUTE (every 100m)
# ============================================================
# Densify the route. Each waypoint-to-waypoint segment is split into
# n_samples = max(1, floor(seg_dist / SAMPLE_INTERVAL)) equal steps, and a
# sample is taken at the start of each step. The real spacing is therefore
# seg_dist / n_samples:
#   - between 100 m and just under 200 m on segments of at least 100 m;
#   - the whole segment length on shorter segments.
# A segment's end point is sampled as the start of the next segment; the
# final waypoint is appended after the loop.
print("\n[3/5] Sampling DEM every 100m along route...")

sampled = []
SAMPLE_INTERVAL = 100  # meters

for i in range(len(waypoints) - 1):
    wp1 = waypoints[i]
    wp2 = waypoints[i+1]
    seg_dist = haversine(wp1['lat'], wp1['lon'], wp2['lat'], wp2['lon'])
    n_samples = max(1, int(seg_dist / SAMPLE_INTERVAL))
    
    for s in range(n_samples):
        # Linear interpolation at fraction t of the segment: lat/lon, the KML
        # design altitude and the cumulative distance.
        t = s / n_samples
        lat = wp1['lat'] + t * (wp2['lat'] - wp1['lat'])
        lon = wp1['lon'] + t * (wp2['lon'] - wp1['lon'])
        kml_alt = wp1['kml_alt'] + t * (wp2['kml_alt'] - wp1['kml_alt'])
        cum_dist = wp1['cum_dist'] + t * seg_dist
        
        utm_e, utm_n = ll_to_utm32(lat, lon)
        dem_alt = get_elevation(utm_e, utm_n)  # None where no tile gives a valid value
        
        sampled.append({
            'lat': lat,
            'lon': lon,
            'kml_alt': round(kml_alt, 1),
            'dem_alt': round(dem_alt, 1) if dem_alt is not None else None,
            'cum_dist': round(cum_dist, 1),
            'utm_e': round(utm_e, 1),
            'utm_n': round(utm_n, 1)
        })

# Add last point
# NOTE(review): unlike the interpolated samples, kml_alt is not rounded here.
wp_last = waypoints[-1]
dem_last = get_elevation(wp_last['utm_e'], wp_last['utm_n'])
sampled.append({
    'lat': wp_last['lat'],
    'lon': wp_last['lon'],
    'kml_alt': wp_last['kml_alt'],
    'dem_alt': round(dem_last, 1) if dem_last is not None else None,
    'cum_dist': round(wp_last['cum_dist'], 1),
    'utm_e': round(wp_last['utm_e'], 1),
    'utm_n': round(wp_last['utm_n'], 1)
})

n_with_dem = sum(1 for s in sampled if s['dem_alt'] is not None)
n_without = sum(1 for s in sampled if s['dem_alt'] is None)
print(f"  Sampled {len(sampled)} points")
print(f"  With DEM: {n_with_dem} ({n_with_dem*100/len(sampled):.0f}%)")
print(f"  Without DEM (no tile coverage): {n_without} ({n_without*100/len(sampled):.0f}%)")

if n_with_dem > 0:
    # diff = DEM - KML. Positive: ground above the design grade (pipe below
    # ground). Negative: design grade above ground.
    diffs = [s['dem_alt'] - s['kml_alt'] for s in sampled if s['dem_alt'] is not None]
    print(f"  DEM - KML difference: min={min(diffs):.1f}m, max={max(diffs):.1f}m, avg={sum(diffs)/len(diffs):.1f}m")
    
    # Trenches (trincee): DEM more than 10 m above KML (pipe buried)
    trincee = [d for d in diffs if d > 10]
    print(f"  Trincee (DEM > KML+10m): {len(trincee)} punti")
    if trincee:
        print(f"    Max profondità trincea: {max(trincee):.1f}m")
    
    # Pipe in the air: KML more than 10 m above DEM (embankment/viaduct needed)
    viadotti = [d for d in diffs if d < -10]
    print(f"  Viadotti/rilevati (KML > DEM+10m): {len(viadotti)} punti")
    if viadotti:
        print(f"    Max altezza rilevato: {-min(viadotti):.1f}m")

# ============================================================
# 4. LOAD AND CHECK LANDSLIDES
# ============================================================
print("\n[4/5] Loading landslide data...")

frane_files = sorted(glob.glob(os.path.join(FRANE_DIR, '*.json')))
print(f"  Found {len(frane_files)} GeoJSON files")

all_frane_points = []  # PIFF points
all_frane_areas = []   # Area polygons


def _frana_tipo(props):
    """Return the landslide type label from a feature's properties ('?' if absent)."""
    return props.get('tipologia', props.get('tipo_movim', '?'))


for ff in frane_files:
    # Resolve the name before the try so the error message below always names
    # the file that failed, even when open() or json.load() raises.
    fname = os.path.basename(ff)
    try:
        with open(ff, 'r', encoding='utf-8') as f:  # GeoJSON is UTF-8 (RFC 7946)
            gj = json.load(f)
        features = gj.get('features') or []

        if 'piff' in fname:
            for feat in features:
                geom = feat.get('geometry', {})
                if geom and geom.get('type') == 'Point':
                    coords = geom['coordinates']
                    # GeoJSON allows "properties": null
                    props = feat.get('properties') or {}
                    all_frane_points.append({
                        'lon': coords[0],
                        'lat': coords[1],
                        'tipo': _frana_tipo(props),
                        'file': fname
                    })
        elif 'aree' in fname:
            for feat in features:
                geom = feat.get('geometry', {})
                props = feat.get('properties') or {}
                if not geom:
                    continue
                gtype = geom.get('type', '')
                if gtype == 'Polygon':
                    polygons = [geom['coordinates']]
                elif gtype == 'MultiPolygon':
                    polygons = geom['coordinates']
                else:
                    polygons = []
                for poly in polygons:
                    all_frane_areas.append({
                        'coords': poly,
                        'tipo': _frana_tipo(props),
                        'file': fname
                    })

        print(f"  ✓ {fname}: {len(features)} features")
    except Exception as e:
        print(f"  ✗ {fname}: {e}")

print(f"  Total landslide points (PIFF): {len(all_frane_points)}")
print(f"  Total landslide areas: {len(all_frane_areas)}")

# Check proximity of route to landslide points
FRANA_BUFFER = 500  # meters
frane_nearby = []
print(f"\n  Checking route proximity to landslides (buffer {FRANA_BUFFER}m)...")

# Compare each landslide against every route sample. Thinning the route to
# every 10th sample (~1 km apart) would leave gaps wider than the buffer, so a
# landslide lying on the route midway between two check points could be
# missed. The haversine is vectorised with numpy (same 6 371 000 m sphere as
# haversine()) to keep the full check fast.
route_lat = np.radians([sp['lat'] for sp in sampled])
route_lon = np.radians([sp['lon'] for sp in sampled])
route_cos_lat = np.cos(route_lat)

for fp in all_frane_points:
    fr_lat = math.radians(fp['lat'])
    fr_lon = math.radians(fp['lon'])
    hav = (np.sin((fr_lat - route_lat) / 2) ** 2
           + route_cos_lat * math.cos(fr_lat) * np.sin((fr_lon - route_lon) / 2) ** 2)
    hav = np.minimum(hav, 1.0)  # guard sqrt(1 - hav) against rounding
    dist = 6371000 * 2 * np.arctan2(np.sqrt(hav), np.sqrt(1 - hav))
    nearest = int(np.argmin(dist))  # one match per frana: the closest sample
    if dist[nearest] < FRANA_BUFFER:
        sp = sampled[nearest]
        frane_nearby.append({
            'frana_lat': fp['lat'],
            'frana_lon': fp['lon'],
            'frana_tipo': fp['tipo'],
            'route_km': sp['cum_dist'] / 1000,
            'distance_m': round(float(dist[nearest]), 0),
            'route_lat': sp['lat'],
            'route_lon': sp['lon']
        })

print(f"  Landslides within {FRANA_BUFFER}m of route: {len(frane_nearby)}")

# Check if route passes through landslide areas
def point_in_polygon(lat, lon, polygon_coords):
    """Return True if (lat, lon) lies inside a GeoJSON polygon.

    Even-odd ray casting is applied over *all* rings of ``polygon_coords``
    (GeoJSON order: outer ring first, then holes; each ring a list of
    [lon, lat, ...] positions). A point inside a hole therefore toggles back
    to outside. An empty polygon contains nothing.

    Args:
        lat: Latitude of the point (compared against position[1]).
        lon: Longitude of the point (compared against position[0]).
        polygon_coords: GeoJSON Polygon ``coordinates`` (list of rings).

    Returns:
        True when the point is inside the polygon, False otherwise.
    """
    x, y = lon, lat
    inside = False
    for ring in polygon_coords:
        if not ring:
            continue
        # Start with the closing edge: last vertex -> first vertex.
        xj, yj = ring[-1][0], ring[-1][1]
        for vertex in ring:
            xi, yi = vertex[0], vertex[1]
            # The edge straddles the horizontal line through the point, which
            # guarantees yi != yj, so the division below is safe.
            if (yi > y) != (yj > y) and x < (xj - xi) * (y - yi) / (yj - yi) + xi:
                inside = not inside
            xj, yj = xi, yi
    return inside

frane_through = []
print(f"  Checking if route passes through landslide areas...")

# Test every route sample: a sparser stride would miss landslide polygons
# that are shorter, along the route, than the stride. Areas are prefiltered
# with the bounding box of their outer ring. A point outside that box cannot
# be inside the polygon, since holes lie within the outer ring.
area_boxes = []
area_refs = []
for area in all_frane_areas:
    outer = area['coords'][0] if area['coords'] else []
    if not outer:
        continue  # degenerate polygon: nothing can be inside
    ring_lons = [pt[0] for pt in outer]
    ring_lats = [pt[1] for pt in outer]
    area_boxes.append((min(ring_lons), min(ring_lats), max(ring_lons), max(ring_lats)))
    area_refs.append(area)
area_boxes = np.array(area_boxes, dtype=float).reshape(-1, 4)

for sp in sampled:
    sp_lat, sp_lon = sp['lat'], sp['lon']
    candidates = np.flatnonzero(
        (area_boxes[:, 0] <= sp_lon) & (sp_lon <= area_boxes[:, 2])
        & (area_boxes[:, 1] <= sp_lat) & (sp_lat <= area_boxes[:, 3])
    )
    for k in candidates:  # ascending index = original area order
        area = area_refs[k]
        if point_in_polygon(sp_lat, sp_lon, area['coords']):
            frane_through.append({
                'km': sp['cum_dist'] / 1000,
                'lat': sp_lat,
                'lon': sp_lon,
                'tipo': area['tipo'],
                'file': area['file']
            })
            break  # one area match per point

print(f"  Route points inside landslide areas: {len(frane_through)}")

# ============================================================
# 5. PIEZOMETRIC ANALYSIS WITH TORRINI
# ============================================================
print("\n[5/5] Piezometric analysis...")

LOSS_PER_M = 0.28 / 1000  # head loss: m per m of pipe
TORRINO_HEIGHT = 6  # m above pipe

# Match each torrino to the nearest route waypoint (nearest vertex, not the
# nearest point on a segment).
for t in torrini:
    min_dist = float('inf')
    min_idx = 0
    for i, wp in enumerate(waypoints):
        d = haversine(wp['lat'], wp['lon'], t['lat'], t['lon'])
        if d < min_dist:
            min_dist = d
            min_idx = i
    t['wp_idx'] = min_idx
    t['km'] = waypoints[min_idx]['cum_dist'] / 1000
    t['kml_alt'] = waypoints[min_idx]['kml_alt']
    t['dem_alt'] = get_elevation(t['utm_e'], t['utm_n'])

print(f"  Torrini matched:")
for i, t in enumerate(torrini):
    # Compare with None: a genuine 0 m elevation must not read as "no DEM".
    dem_str = f"{t['dem_alt']:.0f}m DEM" if t['dem_alt'] is not None else "no DEM"
    print(f"    T{i+1}: km {t['km']:.1f}, livelletta {t['kml_alt']:.0f}m, {dem_str}")


def _piezometric_profile(reset_points):
    """Walk the route and track the piezometric level.

    The level starts at the first waypoint's KML altitude and drops by
    LOSS_PER_M per metre of haversine distance between consecutive waypoints.
    Each torrino in ``reset_points`` is applied at the first waypoint index
    >= its ``wp_idx``, resetting the level to that waypoint's KML altitude
    plus TORRINO_HEIGHT.

    Args:
        reset_points: torrino dicts carrying at least 'wp_idx' and 'km'.

    Returns:
        dict with:
        - 'violations': number of waypoints where the level is below the pipe;
        - 'max_deficit': in m;
        - 'max_pressure': in bar, taking 10.2 m of water as 1 bar;
        - 'used': the torrini applied, in route order;
        - 'piez_arrivo' and 'margine_arrivo': in m, rounded to 0.1.
    """
    piez = waypoints[0]['kml_alt']
    pending = sorted(reset_points, key=lambda tr: tr['wp_idx'])
    next_idx = 0
    stats = {'violations': 0, 'max_deficit': 0, 'max_pressure': 0, 'used': []}
    for i in range(1, len(waypoints)):
        prev, cur = waypoints[i-1], waypoints[i]
        d = haversine(prev['lat'], prev['lon'], cur['lat'], cur['lon'])
        piez -= d * LOSS_PER_M

        # Apply every torrino at or before this waypoint. A `while` (not a
        # single `if`) handles several torrini snapped to the same waypoint,
        # which would otherwise each be applied one waypoint later.
        while next_idx < len(pending) and i >= pending[next_idx]['wp_idx']:
            tor = pending[next_idx]
            reset_alt = cur['kml_alt'] + TORRINO_HEIGHT
            piez = reset_alt
            stats['used'].append({
                'idx': next_idx + 1,
                'km': tor['km'],
                'reset_to': reset_alt
            })
            next_idx += 1

        margin = piez - cur['kml_alt']
        if margin < 0:
            stats['violations'] += 1
            stats['max_deficit'] = max(stats['max_deficit'], -margin)
        stats['max_pressure'] = max(stats['max_pressure'], margin / 10.2 if margin > 0 else 0)

    stats['piez_arrivo'] = round(piez, 1)
    stats['margine_arrivo'] = round(piez - waypoints[-1]['kml_alt'], 1)
    return stats


# Without torrini: pure head loss from the start elevation.
stats_no_torrini = _piezometric_profile([])
del stats_no_torrini['used']  # this case never had a 'used' entry in the output

# With torrini
stats_torrini = _piezometric_profile(torrini)

print(f"\n  SENZA TORRINI:")
print(f"    Piez arrivo: {stats_no_torrini['piez_arrivo']}m")
print(f"    Margine: {stats_no_torrini['margine_arrivo']:+.0f}m")
print(f"    Violazioni: {stats_no_torrini['violations']}")
print(f"    Max pressione: {stats_no_torrini['max_pressure']:.1f} bar")

print(f"\n  CON {len(stats_torrini['used'])} TORRINI:")
print(f"    Piez arrivo: {stats_torrini['piez_arrivo']}m")
print(f"    Margine: {stats_torrini['margine_arrivo']:+.0f}m")
print(f"    Violazioni: {stats_torrini['violations']}")
print(f"    Max pressione: {stats_torrini['max_pressure']:.1f} bar")

# ============================================================
# 6. IDENTIFY CRITICAL ZONES
# ============================================================
print("\n" + "="*70)
print("ZONE CRITICHE")
print("="*70)

# Deep trenches: DEM - KML > 20 m, i.e. the design grade lies more than 20 m
# below ground.
deep_trincee = []
for s in sampled:
    if s['dem_alt'] is not None:
        diff = s['dem_alt'] - s['kml_alt']
        if diff > 20:
            deep_trincee.append({
                'km': s['cum_dist'] / 1000,
                'lat': s['lat'],
                'lon': s['lon'],
                'dem': s['dem_alt'],
                'kml': s['kml_alt'],
                'depth': round(diff, 1)
            })

# Group trench samples into zones: a new zone starts whenever the gap to the
# previous trench sample is 0.5 km or more.
trincea_zones = []
if deep_trincee:
    zone = [deep_trincee[0]]
    for prev, cur in zip(deep_trincee, deep_trincee[1:]):
        if cur['km'] - prev['km'] < 0.5:
            zone.append(cur)
        else:
            trincea_zones.append(zone)
            zone = [cur]
    trincea_zones.append(zone)

print(f"\nTrincee profonde (>20m): {len(trincea_zones)} zone")
for z in trincea_zones:
    km_start = z[0]['km']
    km_end = z[-1]['km']
    max_depth = max(p['depth'] for p in z)
    print(f"  km {km_start:.1f}-{km_end:.1f}: max {max_depth:.0f}m ({len(z)} punti)")

# Steep grade rises: KML altitude increases by more than 15 m between
# consecutive waypoints.
print(f"\nRisalite livelletta >15m:")
for i in range(1, len(waypoints)):
    delta = waypoints[i]['kml_alt'] - waypoints[i-1]['kml_alt']
    if delta > 15:
        km = waypoints[i]['cum_dist'] / 1000
        print(f"  km {km:.1f}: {waypoints[i-1]['kml_alt']:.0f}→{waypoints[i]['kml_alt']:.0f}m (+{delta:.0f}m) [{waypoints[i]['lat']:.4f},{waypoints[i]['lon']:.4f}]")

# Landslides near the route, grouped by route km (rounded with Python's
# round(), i.e. half-to-even).
if frane_nearby:
    # Report the buffer actually used by the proximity check.
    print(f"\nFrane entro {FRANA_BUFFER}m dal tracciato: {len(frane_nearby)}")
    frane_by_km = {}
    for fn in frane_nearby:
        frane_by_km.setdefault(round(fn['route_km']), []).append(fn)
    for km_key in sorted(frane_by_km):
        flist = frane_by_km[km_key]
        closest = min(fr['distance_m'] for fr in flist)
        print(f"  km ~{km_key}: {len(flist)} frane (min dist: {closest:.0f}m)")

if frane_through:
    print(f"\nTracciato ATTRAVERSA aree frana: {len(frane_through)} punti")
    for ft in frane_through:
        print(f"  km {ft['km']:.1f}: tipo={ft['tipo']} [{ft['lat']:.4f},{ft['lon']:.4f}]")

# ============================================================
# 7. SAVE OUTPUT
# ============================================================
print(f"\n{'='*70}")
print("Saving analysis...")

# Reduce sampled for output: every 5th sample (nominally ~500 m; the real
# spacing depends on segment lengths).
sampled_reduced = sampled[::5]

# One DEM lookup per waypoint, tested against None so that a genuine 0 m
# elevation is written as 0.0 rather than null.
waypoint_dem = [get_elevation(w['utm_e'], w['utm_n']) for w in waypoints]

output = {
    'version': 'GAA_v6.0_MOD2_analysis',
    # NOTE(review): fixed date string, not the run date — confirm intent.
    'date': '2026-02-16',
    'summary': {
        'total_waypoints': len(waypoints),
        'total_distance_km': round(waypoints[-1]['cum_dist']/1000, 1),
        'start_alt': waypoints[0]['kml_alt'],
        'end_alt': waypoints[-1]['kml_alt'],
        'sampled_points': len(sampled),
        'dem_coverage_pct': round(n_with_dem * 100 / len(sampled)),
        'torrini': len(torrini)
    },
    'piezometric': {
        'without_torrini': stats_no_torrini,
        'with_torrini': stats_torrini
    },
    'landslides': {
        'nearby_500m': len(frane_nearby),
        'through_areas': len(frane_through),
        'details_nearby': frane_nearby[:50],
        'details_through': frane_through
    },
    'critical_zones': {
        'deep_trincee': [{
            'km_start': z[0]['km'],
            'km_end': z[-1]['km'],
            'max_depth': max(p['depth'] for p in z),
            'points': len(z)
        } for z in trincea_zones],
        'risalite_15m': [{
            'km': waypoints[i]['cum_dist']/1000,
            'from': waypoints[i-1]['kml_alt'],
            'to': waypoints[i]['kml_alt'],
            'delta': waypoints[i]['kml_alt'] - waypoints[i-1]['kml_alt']
        } for i in range(1, len(waypoints)) if waypoints[i]['kml_alt'] - waypoints[i-1]['kml_alt'] > 15]
    },
    'waypoints': [{
        'idx': i,
        'lat': round(w['lat'], 6),
        'lon': round(w['lon'], 6),
        'kml_alt': w['kml_alt'],
        'dem_alt': round(dem, 1) if dem is not None else None,
        'cum_dist_km': round(w['cum_dist']/1000, 2)
    } for i, (w, dem) in enumerate(zip(waypoints, waypoint_dem))],
    'profile_500m': [{
        'km': round(s['cum_dist']/1000, 2),
        'lat': round(s['lat'], 6),
        'lon': round(s['lon'], 6),
        'kml': s['kml_alt'],
        'dem': s['dem_alt']
    } for s in sampled_reduced],
    'torrini': [{
        'idx': i+1,
        'lat': round(t['lat'], 6),
        'lon': round(t['lon'], 6),
        'km': round(t['km'], 1),
        'kml_alt': t['kml_alt'],
        'dem_alt': round(t['dem_alt'], 1) if t['dem_alt'] is not None else None
    } for i, t in enumerate(torrini)]
}

with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    json.dump(output, f, indent=2)

print(f"  Saved to {OUTPUT_FILE}")
print(f"  File size: {os.path.getsize(OUTPUT_FILE)/1024:.0f} KB")

# Close DEM datasets
for dem in dem_datasets:
    dem['ds'].close()

print(f"\n{'='*70}")
print("ANALYSIS COMPLETE")
print(f"{'='*70}")