Switch to unified view

a/pyforge/flyway/graph.py b/pyforge/flyway/graph.py
1
import heapq
2
from itertools import product
1
from itertools import product
3
from collections import defaultdict, namedtuple
2
from collections import defaultdict, namedtuple
4
3
5
def shortest_path(nodes, start, end):
4
class MigrationGraph(object):
5
6
    def __init__(self, migrations):
7
        self._build_graph(migrations)
8
9
    def _build_graph(self, migrations):
10
        '''Build a graph where the nodes are possible migration states and the
11
        edges are transitions between states allowed by migrations.
12
        '''
13
        # Generate all states referenced by the given migrations.  Also index
14
        # nodes by state.
15
        versions = defaultdict(lambda: [-1])
16
        for mod,version in migrations:
17
            versions[mod].append(version)
18
        self._State = namedtuple('State', versions)
19
        self._modules = versions.keys()
20
        self._nodes = [ Node(self._State(*ver)) for ver in product(*versions.values()) ]
21
        self.node_by_state = dict((n.state, n) for n in self._nodes)
22
23
        # Index the nodes by (mod,version)
24
        self._index = defaultdict(list)
25
        for n in self._nodes:
26
            for m in self._modules:
27
                v = getattr(n.state, m)
28
                self._index[m,v].append(n)
29
30
        # Add edges for all the migrations
31
        for m in migrations.itervalues():
32
            for direction in 'up', 'down':
33
                ms = MigrateStep(self, m, direction)
34
                for prev, next in ms.transitions():
35
                    prev.succs.append((next, ms))
36
37
    def nodes_with(self, requirements):
38
        '''Return the set of nodes that match the requirements listed in
39
        requirements, which is either a dict or list of (mod,version) pairs.'''
40
        if isinstance(requirements, dict):
41
            requirements = requirements.iteritems()
42
        nodes = None
43
        for (mod, ver) in requirements:
44
            if nodes is None: nodes = set(self._index[mod,ver])
45
            else: nodes &= set(self._index[mod,ver])
46
        return nodes
47
48
    def shortest_path(self, start_requirements, end_requirements):
6
    '''Dijkstra's algorithm for shortest path from the start Node to the end Node'''
49
        '''Dijkstra's algorithm for shortest path from the start Node to any end
50
        Node'''
51
        # Find the start node
52
        start = dict((m, -1) for m in self._modules)
53
        start.update(start_requirements)
54
        start_state = self._State(**start)
55
        start = self.node_by_state[start_state]
56
        # Find the end node(s)
57
        end = self.nodes_with(end_requirements)
58
        # Run the algorithm
7
    start.distance = 0
59
        start.distance = 0
8
    unvisited = list(nodes)
60
        nodes = priority_dict(
9
    heapq.heapify(unvisited)
61
            (node, node.distance)
10
    while unvisited:
62
            for node in self._nodes)
11
        cur = heapq.heappop(unvisited)
63
        while nodes:
64
            cur = nodes.pop_smallest()
12
        if cur.distance is None:
65
            if cur.distance is None: # pragma no cover
13
            raise ValueError, 'No migration path exists from %s to %s' % (
66
                raise ValueError, 'No migration path exists from %s to %s' % (
14
                start, end)
67
                    start, end)
15
        if cur in end:
68
            if cur in end:
16
            return list(cur.path())
69
                return list(cur.path())
17
        cur.visit()
70
            cur.visit(nodes)
18
        heapq.heapify(unvisited)
19
71
20
72
    def as_dot(self): # pragma no cover
21
def gen_migration_states(migrations):
73
        yield 'digraph G {'
22
    '''Return all states referenced by the given migrations'''
23
    versions = defaultdict(lambda: [-1])
24
    for mod,version in migrations:
25
        versions[mod].append(version)
26
    State = namedtuple('State', versions)
27
    modules = versions.keys()
28
    states = [ State(*ver) for ver in product(*versions.values()) ]
29
    return State, modules, states
30
31
def index_migration_states(modules, states):
32
    '''Return an index from (module,version) => all states with that mod,version'''
33
    index = defaultdict(list)
34
    for s in states:
35
        for m in modules:
74
        for n in self._nodes:
36
            v = getattr(s, m)
75
            yield 'node_%d[label="%r"];' % (id(n), n.state)
37
            index[m,v].append(s)
76
        for n in self._nodes:
38
    return index
77
            for (next, ms) in n.succs:
39
78
                yield 'node_%d->node_%d[label="%r"];' % (id(n), id(next), ms)
40
def build_graph(states, state_index, migrations):
79
        yield '}'
41
    node_from_state = dict((s, Node(s)) for s in states)
42
    nodes = node_from_state.values()
43
    for m in migrations.itervalues():
44
        for direction in 'up', 'down':
45
            ms = MigrateStep(m, direction)
46
            for cur_state, next_state in ms.transitions(state_index):
47
                n = Node(ms)
48
                nodes.append(n)
49
                prev = node_from_state[cur_state]
50
                next = node_from_state[next_state]
51
                prev.succs.append(n)
52
                n.succs.append(next)
53
    return nodes
54
80
55
class MigrateStep(object):
81
class MigrateStep(object):
82
    '''Object representing a single migration step in a single direction (either
83
    up or down)'''
56
84
57
    def __init__(self, migration, direction):
85
    def __init__(self, graph, migration, direction):
86
        self.graph = graph
58
        self.migration = migration
87
        self.migration = migration
59
        self.direction = direction
88
        self.direction = direction
60
89
61
    def transitions(self, state_index):
90
    def transitions(self):
91
        '''Returns all node->node transitions made possible by this migratestep'''
62
        if self.direction == 'up':
92
        if self.direction == 'up':
63
            reqs = self.migration.up_requires()
93
            reqs = self.migration.up_requires()
64
            postcondition = self.migration.up_postcondition()
94
            postcondition = self.migration.up_postcondition()
65
        else:
95
        else:
66
            reqs = self.migration.down_requires()
96
            reqs = self.migration.down_requires()
67
            postcondition = self.migration.down_postcondition()
97
            postcondition = self.migration.down_postcondition()
68
        for prev in states_with(reqs, state_index):
98
        for prev in self.graph.nodes_with(reqs):
69
            next = prev._replace(**postcondition)
99
            next_state = prev.state._replace(**postcondition)
100
            next = self.graph.node_by_state[next_state]
70
            yield prev, next
101
            yield prev, next
71
102
72
    def apply(self, state):
103
    def apply(self, state):
104
        '''Actually run the migration, updating the state passed in'''
73
        if self.direction == 'up':
105
        if self.direction == 'up':
74
            self.migration.up()
106
            self.migration.up()
75
            state.update(self.migration.up_postcondition())
107
            state.update(self.migration.up_postcondition())
76
        else:
108
        else:
77
            self.migration.down()
109
            self.migration.down()
78
            state.update(self.migration.down_postcondition())
110
            state.update(self.migration.down_postcondition())
79
111
80
    def __repr__(self):
112
    def __repr__(self): # pragma no cover
81
        return '<%s.%s %s>' % (
113
        return '<%s.%s %s>' % (
82
            self.migration.module,
114
            self.migration.module,
83
            self.migration.version,
115
            self.migration.version,
84
            self.direction)
116
            self.direction)
85
117
86
class Node(object):
118
class Node(object):
87
119
88
    def __init__(self, data):
120
    def __init__(self, state):
89
        self.data = data
121
        self.state = state
90
        self.visited = False
122
        self.visited = False
91
        self.distance = None
123
        self.distance = 1e9 # effectively inf
92
        self.pred = None
124
        self.pred = None # (node, migrationstep)
93
        self.succs = []
125
        self.succs = [] # list of (node, migrationstep)
94
126
95
    def visit(self):
127
    def visit(self, nodes):
128
        '''The 'visit' step of Dijkstra's shortest-path algorithm'''
96
        self.visited = True
129
        self.visited = True
130
        new_dist = self.distance + 1
97
        for succ in self.succs:
131
        for succ, ms in self.succs:
98
            if succ.visited: continue
132
            if succ.visited: continue
99
            if self < succ:
133
            if new_dist < succ.distance:
100
                succ.distance = self.distance + 1
134
                succ.distance = new_dist
101
                succ.pred = self
135
                succ.pred = (self, ms)
136
                nodes[succ] = new_dist
102
137
103
    def path(self):
138
    def path(self):
139
        '''Read back the shortest path from the 'predecessor' field'''
104
        if self.pred:
140
        if self.pred:
105
            for p in self.pred.path():
141
            for p in self.pred[0].path():
106
                yield p
142
                yield p
107
        yield self.data
143
            yield self.pred[1]
108
144
109
    def __lt__(self, other):
145
    def __repr__(self): # pragma no cover
110
        if self.distance is None:
111
            return False
112
        if other.distance is None:
113
            return True
114
        return self.distance < other.distance
115
116
    def __repr__(self):
117
        return '<Node %r (%s)>' % (self.data,self.distance)
146
        return '<Node %r (%s)>' % (self.state,self.distance)
118
147
119
def states_with(requirements, state_index):
148
# priority dictionary recipe copied from 
120
    states = None
149
# http://code.activestate.com/recipes/522995-priority-dict-a-priority-queue-with-updatable-prio/
121
    for (mod, ver) in requirements:
150
# We use this rather than the raw heap because the priority_dict allows us to
122
        if states is None: states = set(state_index[mod,ver])
151
# update the priority of a node, which heapq does not (natively) allow without
123
        else: states &= set(state_index[mod,ver])
152
# re-running heapify() each time a priority changes.  (And priorities change
124
    return states
153
# often in Dijkstra's algorithm.)
154
from heapq import heapify, heappush, heappop
125
155
156
class priority_dict(dict):
    """Dictionary that can be used as a priority queue.

    Keys of the dictionary are the items in the queue and values are their
    respective priorities.  The advantage over a plain heapq-based priority
    queue is that the priority of an item can be changed in place with
    'thedict[item] = new_priority' without re-heapifying on every change.

    Internally a heap of (priority, key) tuples shadows the dictionary.
    Outdated heap entries are not removed when a priority changes or a key
    is deleted; they are skipped lazily when the smallest item is requested.
    Because the heap orders (priority, key) tuples, keys that share a
    priority are compared with each other to break the tie.

    The 'smallest' method can be used to return the object with lowest
    priority, and 'pop_smallest' also removes it.

    The 'sorted_iter' method provides a destructive sorted iterator.
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as dict() and build the initial heap."""
        super(priority_dict, self).__init__(*args, **kwargs)
        self._rebuild_heap()

    def _rebuild_heap(self):
        """Recreate the heap from the current dictionary contents."""
        # items() rather than iteritems(): works on Python 2 and Python 3.
        self._heap = [(v, k) for k, v in self.items()]
        heapify(self._heap)

    def smallest(self):
        """Return the item with the lowest priority.

        Raises IndexError if the object is empty.
        """
        heap = self._heap
        v, k = heap[0]
        # Discard entries whose key was deleted or whose priority has since
        # been changed; the dictionary holds the authoritative priority.
        while k not in self or self[k] != v:
            heappop(heap)
            v, k = heap[0]
        return k

    def pop_smallest(self):
        """Return the item with the lowest priority and remove it.

        Raises IndexError if the object is empty.
        """
        heap = self._heap
        v, k = heappop(heap)
        # Skip stale entries (deleted keys or superseded priorities).
        while k not in self or self[k] != v:
            v, k = heappop(heap)
        del self[k]
        return k

    def __setitem__(self, key, val):
        """Set the priority of key to val, adding key if it is not present."""
        # We are not going to remove the previous value from the heap,
        # since this would have a cost O(n).
        super(priority_dict, self).__setitem__(key, val)

        if len(self._heap) < 2 * len(self):
            heappush(self._heap, (val, key))
        else:
            # When the heap grows larger than 2 * len(self), we rebuild it
            # from scratch to avoid wasting too much memory.
            self._rebuild_heap()

    def setdefault(self, key, val):
        """Return the priority of key, first setting it to val if absent."""
        # Routed through __setitem__ so the heap learns about the new key.
        if key not in self:
            self[key] = val
            return val
        return self[key]

    def update(self, *args, **kwargs):
        """Update priorities like dict.update(), then rebuild the heap."""
        # Reimplementing dict.update is tricky -- see e.g.
        # http://mail.python.org/pipermail/python-ideas/2007-May/000744.html
        # We just rebuild the heap from scratch after passing to super.
        super(priority_dict, self).update(*args, **kwargs)
        self._rebuild_heap()

    def sorted_iter(self):
        """Sorted iterator of the priority dictionary items.

        Beware: this will destroy elements as they are returned.
        """
        while self:
            yield self.pop_smallest()
240
# End recipe