import networkx as nx
import matplotlib.pyplot as plt

def genere_positions():
    """Return the fixed 2D drawing coordinates of every node in the network.

    Returns:
        dict: maps each node name to an ``(x, y)`` tuple. A new dictionary
        is built on every call, so callers may mutate the result freely.
    """
    noms = ("Entrepot", "A", "B", "C", "D", "E", "Client")
    coordonnees = (
        (0, 0), (1, 2), (2, -1.5),
        (3, 1), (4, -2), (5, 1.5),
        (6, 0),
    )
    # Names and coordinates are listed in the same order; pair them up.
    return dict(zip(noms, coordonnees))

def construit_reseau_drones():
    """Build the undirected, weighted graph describing the drone network.

    Nodes are the warehouse ('Entrepot'), five relay points ('A'-'E') and
    the customer ('Client'). Each edge stores its cost in the ``weight``
    attribute.

    Returns:
        networkx.Graph: the populated graph (7 nodes, 10 edges).
    """
    sommets = ['Entrepot', 'A', 'B', 'C', 'D', 'E', 'Client']
    liaisons = [
        ('Entrepot', 'A', 30),
        ('Entrepot', 'B', 40),
        ('A', 'C', 20),
        ('A', 'B', 50),
        ('B', 'D', 10),
        ('C', 'E', 60),
        ('C', 'D', 30),
        ('D', 'E', 20),
        ('E', 'Client', 40),
        ('D', 'Client', 70),
    ]

    reseau = nx.Graph()
    # Nodes are added first so their insertion order is the declared one.
    reseau.add_nodes_from(sommets)
    # Each (u, v, w) triple becomes an edge whose 'weight' attribute is w.
    reseau.add_weighted_edges_from(liaisons)
    return reseau

def affiche_reseau_drones(G):
    """Display the whole drone network with its edge weights.

    The warehouse is drawn in light blue, the customer in salmon and every
    other node in light gray. The call opens a matplotlib window via
    ``plt.show()``.

    Args:
        G: networkx graph whose edges carry a ``weight`` attribute.
    """
    positions = genere_positions()

    # Special nodes get their own colour; everything else falls back to gray.
    couleurs_speciales = {'Entrepot': 'lightblue', 'Client': 'salmon'}
    couleurs = [couleurs_speciales.get(sommet, 'lightgray')
                for sommet in G.nodes()]

    plt.figure(figsize=(10, 6))

    nx.draw_networkx_nodes(G, positions, node_color=couleurs, node_size=800)
    nx.draw_networkx_edges(G, positions, width=2, alpha=0.7)
    nx.draw_networkx_labels(G, positions, font_size=10, font_weight='bold')

    # Label every edge with its weight.
    poids_aretes = nx.get_edge_attributes(G, 'weight')
    nx.draw_networkx_edge_labels(G, positions, edge_labels=poids_aretes,
                                 font_size=9)

    plt.title("Réseau de drones - Chemins arête-disjoints Entrepot → Client")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

def trouve_chemins_disjoints(G):
    """Find two edge-disjoint paths from 'Entrepot' to 'Client'.

    The first path is the weighted shortest path. Its edges are then removed
    from a copy of the graph and the shortest path in what remains becomes
    the second path. Both paths are printed as they are found.

    Note:
        This is a greedy heuristic. Removing the shortest path can
        disconnect the target even when two edge-disjoint paths exist, and
        the pair it returns is not guaranteed to have the minimal total
        length (Suurballe's algorithm addresses that problem).

    Args:
        G: networkx graph containing the nodes 'Entrepot' and 'Client',
            with edge costs stored under the ``weight`` attribute.

    Returns:
        tuple: ``(chemin1, longueur1, chemin2, longueur2)`` where each
        ``chemin`` is a list of node names and each ``longueur`` is the sum
        of the weights along that path.

    Raises:
        networkx.NetworkXNoPath: if the target is unreachable, or if no
            second path remains once the first path's edges are removed.
    """
    source, target = 'Entrepot', 'Client'

    # One Dijkstra run returns both the distance and the path, instead of
    # running the search once for the path and again for its length.
    longueur1, chemin1 = nx.single_source_dijkstra(G, source, target)
    print(f"Chemin 1 (minimal) : {' → '.join(chemin1)} = {longueur1}s")

    # Work on a copy so the caller's graph keeps all of its edges.
    G_temp = G.copy()
    G_temp.remove_edges_from(zip(chemin1, chemin1[1:]))

    try:
        longueur2, chemin2 = nx.single_source_dijkstra(G_temp, source, target)
    except nx.NetworkXNoPath as err:
        # Same exception type as before, with a message naming the real cause.
        raise nx.NetworkXNoPath(
            f"No path from {source} to {target} remains once the edges of "
            f"the shortest path are removed"
        ) from err
    print(f"Chemin 2 (disjoint) : {' → '.join(chemin2)} = {longueur2}s")

    return chemin1, longueur1, chemin2, longueur2

def affiche_chemins_disjoints(G, chemin1, chemin2):
    """Display the network with the two delivery paths highlighted.

    The first path is drawn in blue and the second in orange, both with
    thick lines; every remaining edge is drawn thin and gray. Only the edges
    belonging to one of the two paths get a weight label. The call opens a
    matplotlib window via ``plt.show()``.

    Args:
        G: undirected networkx graph whose edges carry a ``weight``
            attribute.
        chemin1: first path, as a sequence of node names.
        chemin2: second path, as a sequence of node names.
    """
    pos = genere_positions()

    plt.figure(figsize=(12, 6))

    node_colors = ['lightblue' if node == 'Entrepot' else
                   'salmon' if node == 'Client' else
                   'lightgray' for node in G.nodes()]

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=800)
    nx.draw_networkx_labels(G, pos, font_size=10, font_weight='bold')

    # Consecutive node pairs along each path are that path's edges.
    aretes1 = list(zip(chemin1, chemin1[1:]))
    aretes2 = list(zip(chemin2, chemin2[1:]))

    nx.draw_networkx_edges(G, pos, edgelist=aretes1,
                           edge_color='blue', width=4, alpha=0.9)

    nx.draw_networkx_edges(G, pos, edgelist=aretes2,
                           edge_color='orange', width=4, alpha=0.9)

    # G.edges() may report an undirected edge as (u, v) while a path crosses
    # it as (v, u). Compare edges without orientation so that a path edge is
    # never redrawn in gray on top of its highlighted version.
    aretes_utilisees = {frozenset(arete) for arete in aretes1 + aretes2}
    autres_aretes = [arete for arete in G.edges()
                     if frozenset(arete) not in aretes_utilisees]
    nx.draw_networkx_edges(G, pos, edgelist=autres_aretes,
                           edge_color='gray', width=1, alpha=0.5)

    # Label only the edges used by one of the two paths.
    edge_labels = {(u, v): G[u][v]['weight'] for u, v in aretes1 + aretes2}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9)

    plt.title("2 Chemins arête-disjoints : Bleu (1er) + Orange (2e)")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

def calcule_temps_total(longueur1, longueur2):
    """Print and return the combined travel time of both drones.

    Args:
        longueur1: travel time of the first drone.
        longueur2: travel time of the second drone.

    Returns:
        The sum ``longueur1 + longueur2``.
    """
    total = longueur1 + longueur2
    lignes = (
        f"\n TEMPS TOTAL des 2 drones : {total} secondes",
        f"   (Drone 1: {longueur1}s + Drone 2: {longueur2}s)",
    )
    for ligne in lignes:
        print(ligne)
    return total

def main():
    """Run the whole exercise: build, display, search, display, total."""
    print("--- Exercice 02 ---\n")

    # Step 1: build the network and report its size.
    print("(1) Construction du réseau...")
    G = construit_reseau_drones()
    nb_sommets = len(G.nodes())
    nb_aretes = len(G.edges())
    print(f"Graphe: {nb_sommets} sommets, {nb_aretes} arêtes\n")

    # Step 2: show the raw network.
    print("(2) Affichage du réseau")
    affiche_reseau_drones(G)

    # Step 3: look for the two edge-disjoint paths.
    print("\n(3) Recherche de 2 chemins arête-disjoints")
    resultat = trouve_chemins_disjoints(G)
    chemin1, duree1, chemin2, duree2 = resultat

    # Step 4: show the network again with both paths coloured.
    print("\n(4) Affichage des 2 chemins")
    affiche_chemins_disjoints(G, chemin1, chemin2)

    # Step 5: report the combined travel time.
    print("\n(5) Calcul temps total")
    calcule_temps_total(duree1, duree2)
    

# Run the exercise only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
