Switch to unified view

a b/pyforge/flyway/runner.py
1
import logging
2
3
from .model import MigrationInfo
4
from .migrate import Migration
5
6
log = logging.getLogger(__name__)
7
8
def run_migration(datastore, target_versions, dry_run):
9
    '''Attempt to migrate the database to a specific set of required
10
    modules  & versions.'''
11
    # Get the migration status of the db
12
    session = MigrationInfo.__mongometa__.session
13
    session.bind = datastore
14
    info = MigrationInfo.m.get()
15
    if info is None:
16
        info = MigrationInfo.make({})
17
    latest_versions = Migration.latest_versions()
18
    for k,v in target_versions.iteritems():
19
        cur = info.versions.get(k, -1)
20
        islatest = ' (LATEST)' if v == latest_versions[k] else ''
21
        log.info('Target %s=%s%s (current=%s)', k, v, islatest, cur)
22
    # Create a migration plan
23
    plan = list(plan_migration(session, info, target_versions))
24
    # Execute (or print) the plan
25
    for step in plan:
26
        log.info('Migrate %r', step)
27
        if dry_run: continue
28
        step.apply(info.versions)
29
        info.m.save()
30
31
def reset_migration(datastore, dry_run):
32
    '''Reset the state of the database to non-version-controlled WITHOUT migrating
33
34
    This is equivalent to setting all the versions to -1.'''
35
    session = MigrationInfo.__mongometa__.session
36
    session.bind = datastore
37
    log.info('Reset migrations')
38
    if not dry_run:
39
        MigrationInfo.m.remove()
40
41
class MigrationStep(object):
42
43
    def __init__(self, session, module, version, direction):
44
        self.module = module
45
        self.version = version
46
        self.direction = direction
47
        self.session = session
48
        self.migration = Migration.get(module, version)(session)
49
        if direction == 'up':
50
            self.requires = dict(self.migration.requires())
51
            self.postcondition = {module:version}
52
        else:
53
            self.requires = {module:version}
54
            self.postcondition = {module:version-1}
55
56
    def __repr__(self):
57
        return '<%s on %s.%s>' % (self.direction, self.module, self.version)
58
59
    def apply(self, state):
60
        state.update(self.postcondition)
61
        if self.direction == 'up':
62
            self.migration.up()
63
        else:
64
            self.migration.down()
65
66
    def unmet_requirements(self, state):
67
        result = {}
68
        for k,v in self.requires.iteritems():
69
            if state.get(k, -1) != v: result[k] = v
70
        return result
71
72
    def precluded_by(self, other_step):
73
        for k,v in self.requires.iteritems():
74
            if other_step.postcondition.get(k, v) != v:
75
                return True
76
        return False
77
78
    def add_requirements(self, steps, req):
79
        if self.direction == 'down':
80
            mod,ver = self.module, self.version-1
81
            if (mod,ver) in steps: return
82
            if req[mod] == ver: return
83
            step = MigrationStep(self.session, mod, ver, 'down')
84
            steps[mod,ver] = step
85
            step.add_requirements(steps, req)
86
        for mod, ver in self.requires.iteritems():
87
            if (mod,ver) in steps: continue
88
            if ver != -1:
89
                step = MigrationStep(self.session, mod, ver, self.direction)
90
                steps[mod,ver] = step
91
                step.add_requirements(steps, req)
92
93
def plan_migration(session, info, target_versions):
94
    '''Create a migration plan based on the current DB state and the
95
    target version set'''
96
    # Determine all the (final) migrations that need to be run
97
    steps = {}
98
    for mod,req_ver in target_versions.iteritems():
99
        cur_ver = info.versions.get(mod, -1)
100
        if cur_ver < req_ver:
101
            steps[mod,req_ver] = MigrationStep(session, mod, req_ver, 'up')
102
        elif cur_ver > req_ver:
103
            steps[mod,cur_ver] = MigrationStep(session, mod, cur_ver, 'down')
104
    # Add the dependencies of all the migrations
105
    current = dict(info.versions)
106
    for step in steps.values():
107
        step.add_requirements(steps, target_versions)
108
    # Schedule migrations to be run
109
    steps = sorted(steps.values(), key=lambda s:(s.version, s.module))
110
    log.debug('Migrations to be run: %r', steps)
111
    while steps:
112
        step = _pop_step(steps, current)
113
        log.info('State %s, step %s', current, step)
114
        yield step
115
        current.update(step.postcondition)
116
117
def _pop_step(steps, current):
118
    '''This method looks at all the available migration steps and the current
119
    current versioning state and chooses a migration step to run next, removing it
120
    from the list of available migration steps and returning it.
121
    '''
122
    # Find all "valid" migrations, i.e. migrations whose requirements() are met
123
    valid = []
124
    invalid = []
125
    for s in steps:
126
        if s.unmet_requirements(current):
127
            invalid.append(s)
128
        else:
129
            valid.append(s)
130
    # If there's only one valid migration, then it's the next one we'll run
131
    if len(valid) == 1:
132
        steps[:] = invalid
133
        return valid[0]
134
    # Find a migration that does not preclude other valid migrations
135
    #   from running
136
    for i, step_a in enumerate(valid):
137
        for j, step_b in enumerate(valid):
138
            if i == j: continue # don't check against self
139
            if step_b.precluded_by(step_a): break # conflict, step_a is not next
140
        else:
141
            # No conflicts found, step_a is the next step
142
            # Remove step_a from the list of steps
143
            steps[:] = [ s for s in steps if s is not step_a ]
144
            return step_a
145
    # No next step found, could be circular dependency.  Log the error and raise
146
    # a ValueError
147
    log.error('Cannot find valid step at state %s', current)
148
    for v in valid:
149
        log.error('  %r', v)
150
    raise ValueError, "Plan stuck"