#!/usr/bin/python

import heapq
import numpy
import pylab
from matplotlib.collections import LineCollection

# This is useful when playing around in the python interpreter:
# it causes plots to be displayed and updated immediately.
#pylab.ion()

class Node(object):
    """A graph vertex located at `loc`.

    `adj` holds the neighboring nodes and `dists` holds the matching edge
    lengths: dists[i] is the length of the edge to adj[i].
    """

    def __init__(self, loc):
        # Every node starts out isolated, with its own fresh (unshared) lists.
        self.loc, self.adj, self.dists = loc, [], []

    def __repr__(self):
        edge_count = len(self.adj)
        return '<Node(%s) with %s edges>' % (self.loc, edge_count)

def euclidean(a, b):
    """Return the Euclidean (ell_2) distance between two vectors.

    Coordinates are paired up with zip, so if one vector is longer than the
    other its extra trailing coordinates are ignored.
    """
    squared_total = 0
    for x, y in zip(a, b):
        squared_total = squared_total + (x - y) ** 2
    return squared_total ** .5

def manhattan(a, b):
    """Return the Manhattan (ell_1) distance between two vectors.

    Coordinates are paired up with zip, so if one vector is longer than the
    other its extra trailing coordinates are ignored.
    """
    total = 0
    for x, y in zip(a, b):
        total = total + abs(x - y)
    return total

################################################################################
# Dijkstra / A*
################################################################################

def search(s, t, potential = lambda u: 0):
    """Find the shortest path from s to t using Dijkstra/A*.

    Given:
      - s, a Node to start from
      - t, a Node to finish at (or None)
      - potential, an admissible heuristic for the distance to t.  It is
        called with a node and its result is added to the node's tentative
        distance to form the heap priority.  The default (always 0) gives
        plain Dijkstra.

    Nodes must be hashable and expose `adj` (neighbors) and `dists` (the
    matching edge lengths); they do not need to be orderable.

    Returns a tuple containing:
       - path, a list containing nodes along the shortest path from s to t.
         It is [s] when s and t are the same node, and [] (after printing a
         warning) when t cannot be reached from s.
       - visited, the set of nodes that were visited by the time t was reached.
       - prev, a dict giving back pointers on the shortest path
         tree starting at s.  prev[v] is correct for all v in visited.
       - dists, a dict giving the distance to nodes.  This is correct
         for visited nodes.

    If t is None, it returns the full shortest path tree from s, and
    'path' is None.
    """
    unreached = float('inf')  # distance of any node not discovered yet
    dists = {s: 0}
    prev = {}
    visited = set()

    # Heap entries are (priority, push_number, node).  The push number breaks
    # priority ties so that two nodes are never compared with each other.
    pushes = 0
    heap = [(dists[s] + potential(s), pushes, s)]
    while heap and t not in visited:
        _, _, u = heapq.heappop(heap)
        if u in visited:
            # Stale entry: u was already settled through a cheaper push.
            continue
        visited.add(u)
        for v, length in zip(u.adj, u.dists):
            candidate = dists[u] + length
            if candidate < dists.get(v, unreached):
                dists[v] = candidate
                prev[v] = u
                pushes += 1
                heapq.heappush(heap, (candidate + potential(v), pushes, v))

    if t is None:
        return (None, visited, prev, dists)

    # The loop only stops early once t is settled, so membership in visited
    # is the reachability test.  Unlike a lookup in prev, it also covers the
    # s == t case, where t has no back pointer.
    if t not in visited:
        print('Warning: no path from s to t')
        return ([], visited, prev, dists)

    # Walk the back pointers from t to s, then flip to get the s-t order.
    v = t
    path = [v]
    while v != s:
        v = prev[v]
        path.append(v)
    path.reverse()
    return (path, visited, prev, dists)

################################################################################
# Graph generation
################################################################################

def add_edge(u, v, dist):
    """Record an undirected edge of length dist between nodes u and v.

    Each endpoint gets the other appended to its `adj` list and dist appended
    to its `dists` list, so the two lists stay index-aligned.
    """
    for endpoint, neighbor in ((u, v), (v, u)):
        endpoint.adj.append(neighbor)
        endpoint.dists.append(dist)

def generate_graph(n, d):
    """Create a size n graph where each vertex is neighbor to the nearest d.

    The n vertices are placed uniformly at random in the unit square.  Every
    vertex is linked to its d nearest *other* vertices, with the Euclidean
    distance as edge length.  Edges are undirected, so a vertex that is among
    the nearest d of many others can end up with more than d neighbors.  Each
    pair of vertices is joined by at most one edge.

    If d is n or larger, every vertex is linked to all n - 1 others.  If d is
    zero or negative, or n < 2, no edges are created.

    Return a list of Node instances.
    """

    # Compute the positions of the n points
    locs = numpy.random.rand(n, 2)
    nodes = [Node(loc) for loc in locs]

    # A vertex has at most n - 1 other vertices it can be linked to.
    k = min(d, n - 1)
    if k <= 0:
        return nodes

    # Compute the pairwise difference vectors (for an n x n x 2 matrix)
    gaps = locs.reshape(n, 1, 2) - locs.reshape(1, n, 2)
    # Compute the pairwise distances
    distances = numpy.linalg.norm(gaps, axis=2)
    # A vertex is at distance 0 from itself; push that entry to infinity so
    # it never takes up one of the k nearest-neighbor slots.
    numpy.fill_diagonal(distances, numpy.inf)
    # Locate the closest k in each row (kth is a 0-based position)
    nearest = numpy.argpartition(distances, k - 1)[:, :k]

    # Collect each pair once, however many of its endpoints selected it.
    edges = set()
    for i, row in enumerate(nearest):
        for j in row:
            j = int(j)
            edges.add((i, j) if i < j else (j, i))

    # Convert into an edge per node; sorted only to make the order repeatable.
    for i, j in sorted(edges):
        add_edge(nodes[i], nodes[j], distances[i, j])

    return nodes

################################################################################
# Graph display
################################################################################

def draw_edges(edges, **kws):
    """Draw a line segment for every (u, v) node pair on the current axes.

    Keyword arguments are passed on to LineCollection and override the
    defaults used here (lw=0.5, color='black', zorder=-3).
    """
    style = dict(lw=0.5, color='black', zorder=-3)
    style.update(kws)
    segments = []
    for u, v in edges:
        segments.append((u.loc, v.loc))
    collection = LineCollection(segments, **style)
    pylab.gca().add_collection(collection)

def draw_nodes(nodes, **kws):
    """Scatter-plot the given nodes at their locations.

    Keyword arguments are passed on to pylab.scatter and override the
    defaults used here (s=10, lw=0.01).  The axis limits are set to
    [0, 1] in both directions.
    """
    style = dict(s=10, lw=0.01)
    style.update(kws)
    xs = [node.loc[0] for node in nodes]
    ys = [node.loc[1] for node in nodes]
    pylab.scatter(xs, ys, **style)
    pylab.xlim(0, 1)
    pylab.ylim(0, 1)


def draw_search_result(path, prev, visited, s, t):
    """Plot the outcome of a search on the current axes.

    Drawn in this order: every edge leaving a visited node (thin, gray), the
    back-pointer edges in prev, the visited nodes, the s-t path (thick), and
    finally s in green and t in red.
    """
    # Every edge that leaves a visited node.
    seen_edges = []
    for u in visited:
        for v in u.adj:
            seen_edges.append((u, v))
    draw_edges(seen_edges, lw=0.1, color='gray')

    # The back-pointer edges recorded in prev.
    tree_edges = [(child, prev[child]) for child in prev]
    draw_edges(tree_edges)

    draw_nodes(visited)

    # Consecutive nodes along the path found.
    path_edges = zip(path, path[1:])
    draw_edges(path_edges, lw=4)

    # Endpoints go last so that they sit on top of everything else.
    draw_nodes([s], s=60, color='green')
    draw_nodes([t], s=60, color='red')

def draw_graph(nodes, s, t):
    """Plot the whole graph, marking s in green and t in red."""
    draw_nodes(nodes)
    # One (u, v) pair per adjacency entry of every node.
    all_edges = []
    for u in nodes:
        for v in u.adj:
            all_edges.append((u, v))
    draw_edges(all_edges)
    draw_nodes([s], s=60, color='green')
    draw_nodes([t], s=60, color='red')

################################################################################
# Putting it together
################################################################################

def try_method(s, t, potential = lambda u: 0, potential_label=''):
    """Run one search from s to t, plot it, and print a summary line.

    potential is handed to search() as the heuristic; potential_label is the
    name shown in the plot title and in the printed summary.
    """
    path, visited, prev, dists = search(s, t, potential)
    pylab.title('Heuristic = %s' % potential_label)
    draw_search_result(path, prev, visited, s, t)

    # Sum of the adjacency-list lengths of the visited nodes.
    edge_count = sum(len(u.adj) for u in visited)
    summary = 'Searched %s nodes, %s edges using heuristic %s: distance %s' % (
        len(visited), edge_count, potential_label, dists.get(t))
    print(summary)

def run_experiment(n, d):
    """Build a random graph and plot three searches over it side by side.

    Given:
      - n, the number of nodes to generate
      - d, the nearest-neighbor parameter passed to generate_graph

    Opens a new figure with four panels: the whole graph, then the search
    from the first node to the last one using no heuristic (Dijkstra), the
    Euclidean distance to the target, and the Manhattan distance to the
    target.
    """
    # Create the graph using the size the caller asked for
    nodes = generate_graph(n, d)
    s, t = nodes[0], nodes[-1]

    # Look at all of them
    pylab.figure()
    pylab.subplot(2, 2, 1)
    pylab.title('All nodes in graph')
    draw_graph(nodes, s, t)

    pylab.subplot(2, 2, 2)
    try_method(s, t, lambda u: 0, 'None (Dijkstra)')
    pylab.subplot(2, 2, 3)
    try_method(s, t, lambda u: euclidean(u.loc, t.loc), 'Euclidean')
    pylab.subplot(2, 2, 4)
    try_method(s, t, lambda u: manhattan(u.loc, t.loc), 'Manhattan')


if __name__ == '__main__':
    # n=2000 and d=4 are the graph parameters this script has actually been
    # plotting with.
    run_experiment(2000, 4)
    pylab.show()
