--- a/pyforge/flyway/graph.py
+++ b/pyforge/flyway/graph.py
@@ -1,75 +1,107 @@
-import heapq
from itertools import product
from collections import defaultdict, namedtuple
-def shortest_path(nodes, start, end):
- '''Dijkstra's algorithm for shortest path from the start Node to the end Node'''
- start.distance = 0
- unvisited = list(nodes)
- heapq.heapify(unvisited)
- while unvisited:
- cur = heapq.heappop(unvisited)
- if cur.distance is None:
- raise ValueError, 'No migration path exists from %s to %s' % (
- start, end)
- if cur in end:
- return list(cur.path())
- cur.visit()
- heapq.heapify(unvisited)
-
-
-def gen_migration_states(migrations):
- '''Return all states referenced by the given migrations'''
- versions = defaultdict(lambda: [-1])
- for mod,version in migrations:
- versions[mod].append(version)
- State = namedtuple('State', versions)
- modules = versions.keys()
- states = [ State(*ver) for ver in product(*versions.values()) ]
- return State, modules, states
-
-def index_migration_states(modules, states):
- '''Return an index from (module,version) => all states with that mod,version'''
- index = defaultdict(list)
- for s in states:
- for m in modules:
- v = getattr(s, m)
- index[m,v].append(s)
- return index
-
-def build_graph(states, state_index, migrations):
- node_from_state = dict((s, Node(s)) for s in states)
- nodes = node_from_state.values()
- for m in migrations.itervalues():
- for direction in 'up', 'down':
- ms = MigrateStep(m, direction)
- for cur_state, next_state in ms.transitions(state_index):
- n = Node(ms)
- nodes.append(n)
- prev = node_from_state[cur_state]
- next = node_from_state[next_state]
- prev.succs.append(n)
- n.succs.append(next)
- return nodes
+class MigrationGraph(object):
+
+ def __init__(self, migrations):
+ self._build_graph(migrations)
+
+ def _build_graph(self, migrations):
+ '''Build a graph where the nodes are possible migration states and the
+ edges are transitions between states allowed by migrations.
+ '''
+ # Generate all states referenced by the given migrations. Also index
+ # nodes by state.
+ versions = defaultdict(lambda: [-1])
+ for mod,version in migrations:
+ versions[mod].append(version)
+ self._State = namedtuple('State', versions)
+ self._modules = versions.keys()
+ self._nodes = [ Node(self._State(*ver)) for ver in product(*versions.values()) ]
+ self.node_by_state = dict((n.state, n) for n in self._nodes)
+
+ # Index the nodes by (mod,version)
+ self._index = defaultdict(list)
+ for n in self._nodes:
+ for m in self._modules:
+ v = getattr(n.state, m)
+ self._index[m,v].append(n)
+
+ # Add edges for all the migrations
+ for m in migrations.itervalues():
+ for direction in 'up', 'down':
+ ms = MigrateStep(self, m, direction)
+ for prev, next in ms.transitions():
+ prev.succs.append((next, ms))
+
+ def nodes_with(self, requirements):
+ '''Return list of nodes that match the requirements listed in
+ requirements, which is either a dict or list of (mod,version) pairs.'''
+ if isinstance(requirements, dict):
+ requirements = requirements.iteritems()
+ nodes = None
+ for (mod, ver) in requirements:
+ if nodes is None: nodes = set(self._index[mod,ver])
+ else: nodes &= set(self._index[mod,ver])
+ return nodes
+
+ def shortest_path(self, start_requirements, end_requirements):
+ '''Dijkstra's algorithm for shortest path from the start Node to any end
+ Node'''
+ # Find the start node
+ start = dict((m, -1) for m in self._modules)
+ start.update(start_requirements)
+ start_state = self._State(**start)
+ start = self.node_by_state[start_state]
+ # Find the end node(s)
+ end = self.nodes_with(end_requirements)
+ # Run the algorithm
+ start.distance = 0
+ nodes = priority_dict(
+ (node, node.distance)
+ for node in self._nodes)
+ while nodes:
+ cur = nodes.pop_smallest()
+ if cur.distance is None: # pragma no cover
+ raise ValueError, 'No migration path exists from %s to %s' % (
+ start, end)
+ if cur in end:
+ return list(cur.path())
+ cur.visit(nodes)
+
+ def as_dot(self): # pragma no cover
+ yield 'digraph G {'
+ for n in self._nodes:
+ yield 'node_%d[label="%r"];' % (id(n), n.state)
+ for n in self._nodes:
+ for (next, ms) in n.succs:
+ yield 'node_%d->node_%d[label="%r"];' % (id(n), id(next), ms)
+ yield '}'
class MigrateStep(object):
-
- def __init__(self, migration, direction):
+ '''Object representing a single migration step in a single direction (either
+ up or down'''
+
+ def __init__(self, graph, migration, direction):
+ self.graph = graph
self.migration = migration
self.direction = direction
- def transitions(self, state_index):
+ def transitions(self):
+ '''Returns all node->node transitions made possible by this migratestep'''
if self.direction == 'up':
reqs = self.migration.up_requires()
postcondition = self.migration.up_postcondition()
else:
reqs = self.migration.down_requires()
postcondition = self.migration.down_postcondition()
- for prev in states_with(reqs, state_index):
- next = prev._replace(**postcondition)
+ for prev in self.graph.nodes_with(reqs):
+ next_state = prev.state._replace(**postcondition)
+ next = self.graph.node_by_state[next_state]
yield prev, next
def apply(self, state):
+ '''Actually run the migration, updating the state passed in'''
if self.direction == 'up':
self.migration.up()
state.update(self.migration.up_postcondition())
@@ -77,7 +109,7 @@
self.migration.down()
state.update(self.migration.down_postcondition())
- def __repr__(self):
+ def __repr__(self): # pragma no cover
return '<%s.%s %s>' % (
self.migration.module,
self.migration.version,
@@ -85,43 +117,124 @@
class Node(object):
- def __init__(self, data):
- self.data = data
+ def __init__(self, state):
+ self.state = state
self.visited = False
- self.distance = None
- self.pred = None
- self.succs = []
-
- def visit(self):
+ self.distance = 1e9 # effectively inf
+ self.pred = None # (state, migrationstep)
+ self.succs = [] # list of (state, migrationstep)
+
+ def visit(self, nodes):
+ '''The 'visit' step of Dijkstra's shortest-path algorithm'''
self.visited = True
- for succ in self.succs:
+ new_dist = self.distance + 1
+ for succ, ms in self.succs:
if succ.visited: continue
- if self < succ:
- succ.distance = self.distance + 1
- succ.pred = self
+ if new_dist < succ.distance:
+ succ.distance = new_dist
+ succ.pred = (self, ms)
+ nodes[succ] = new_dist
def path(self):
+ '''Read back the shortest path from the 'predecessor' field'''
if self.pred:
- for p in self.pred.path():
+ for p in self.pred[0].path():
yield p
- yield self.data
-
- def __lt__(self, other):
- if self.distance is None:
- return False
- if other.distance is None:
- return True
- return self.distance < other.distance
-
- def __repr__(self):
- return '<Node %r (%s)>' % (self.data,self.distance)
-
-def states_with(requirements, state_index):
- states = None
- for (mod, ver) in requirements:
- if states is None: states = set(state_index[mod,ver])
- else: states &= set(state_index[mod,ver])
- return states
-
-
-
+ yield self.pred[1]
+
+ def __repr__(self): # pragma no cover
+ return '<Node %r (%s)>' % (self.state,self.distance)
+
+# priority dictionary recipe copied from
+# http://code.activestate.com/recipes/522995-priority-dict-a-priority-queue-with-updatable-prio/
+# We use this rather than the raw heap because the priority_dict allows us to
+# update the priority of a node, which heapq does not (natively) allow without
+# re-running heapify() each time a priority changes. (And priorities change
+# often in Dijkstra's algorithm.)
+from heapq import heapify, heappush, heappop
+
+class priority_dict(dict):
+ """Dictionary that can be used as a priority queue.
+
+ Keys of the dictionary are items to be put into the queue, and values
+ are their respective priorities. All dictionary methods work as expected.
+ The advantage over a standard heapq-based priority queue is
+ that priorities of items can be efficiently updated (amortized O(1))
+ using code as 'thedict[item] = new_priority.'
+
+ The 'smallest' method can be used to return the object with lowest
+ priority, and 'pop_smallest' also removes it.
+
+ The 'sorted_iter' method provides a destructive sorted iterator.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(priority_dict, self).__init__(*args, **kwargs)
+ self._rebuild_heap()
+
+ def _rebuild_heap(self):
+ self._heap = [(v, k) for k, v in self.iteritems()]
+ heapify(self._heap)
+
+ def smallest(self):
+ """Return the item with the lowest priority.
+
+ Raises IndexError if the object is empty.
+ """
+
+ heap = self._heap
+ v, k = heap[0]
+ while k not in self or self[k] != v:
+ heappop(heap)
+ v, k = heap[0]
+ return k
+
+ def pop_smallest(self):
+ """Return the item with the lowest priority and remove it.
+
+ Raises IndexError if the object is empty.
+ """
+
+ heap = self._heap
+ v, k = heappop(heap)
+ while k not in self or self[k] != v:
+ v, k = heappop(heap)
+ del self[k]
+ return k
+
+ def __setitem__(self, key, val):
+ # We are not going to remove the previous value from the heap,
+ # since this would have a cost O(n).
+
+ super(priority_dict, self).__setitem__(key, val)
+
+ if len(self._heap) < 2 * len(self):
+ heappush(self._heap, (val, key))
+ else:
+ # When the heap grows larger than 2 * len(self), we rebuild it
+ # from scratch to avoid wasting too much memory.
+ self._rebuild_heap()
+
+ def setdefault(self, key, val):
+ if key not in self:
+ self[key] = val
+ return val
+ return self[key]
+
+ def update(self, *args, **kwargs):
+ # Reimplementing dict.update is tricky -- see e.g.
+ # http://mail.python.org/pipermail/python-ideas/2007-May/000744.html
+ # We just rebuild the heap from scratch after passing to super.
+
+ super(priority_dict, self).update(*args, **kwargs)
+ self._rebuild_heap()
+
+ def sorted_iter(self):
+ """Sorted iterator of the priority dictionary items.
+
+ Beware: this will destroy elements as they are returned.
+ """
+
+ while self:
+ yield self.pop_smallest()
+# End recipe