Switch to side-by-side view

--- a/Allura/allura/model/auth.py
+++ b/Allura/allura/model/auth.py
@@ -294,8 +294,8 @@
                 yield project
 
     def role_iter(self):
-        anon_role = ProjectRole.query.get(name='*anonymous')
-        auth_role = ProjectRole.query.get(name='*authenticated')
+        anon_role = ProjectRole.anonymous()
+        auth_role = ProjectRole.authenticated()
         if anon_role:
             yield anon_role
         if self._id and auth_role:
@@ -306,20 +306,11 @@
                 yield role
 
     def project_role(self, project=None):
-        if project is None: project = c.project
+        if project is None: project = c.project.root_project
         with h.push_config(c, project=project, user=self):
             if self._id is None:
-                return ProjectRole.query.get(name='*anonymous')
-            pr = ProjectRole.query.get(user_id=self._id)
-            if pr: return pr
-            try:
-                obj = ProjectRole(user_id=self._id)
-                session(obj).insert_now(obj, state(obj))
-                self.projects.append(c.project._id)
-                return obj
-            except pymongo.errors.DuplicateKeyError:
-                session(obj).expunge(obj)
-                return ProjectRole.query.get(user_id=self._id)
+                return ProjectRole.anonymous(project)
+            return ProjectRole.upsert(user_id=self._id, project_id=project._id)
 
     def set_password(self, new_password):
         return plugin.AuthenticationProvider.get(request).set_password(
@@ -333,12 +324,17 @@
     class __mongometa__:
         session = project_orm_session
         name='user'
-        unique_indexes = [ ('name', 'user_id') ]
+        unique_indexes = [ ('user_id', 'project_id', 'name') ]
     
     _id = FieldProperty(S.ObjectId)
+    user_id = FieldProperty(S.ObjectId, if_missing=None) # if role is a user
+    project_id = FieldProperty(S.ObjectId, if_missing=lambda:c.project._id)
     name = FieldProperty(str)
-    user_id = FieldProperty(S.ObjectId, if_missing=None) # if role is a user
     roles = FieldProperty([S.ObjectId])
+
+    def __init__(self, **kw):
+        assert 'project_id' in kw, 'Project roles must specify a project id'
+        super(ProjectRole, self).__init__(**kw)
 
     def display(self):
         if self.name: return self.name
@@ -351,7 +347,30 @@
         return '**unknown name role: %s' % self._id # pragma no cover
 
     @classmethod
+    def by_user(cls, user=None, project=None):
+        if user is None: user = c.user
+        if project is None: project = c.project.root_project
+        return cls.query.get(user_id=user._id, project_id=project._id)
+
+    @classmethod
+    def by_name(cls, name, project=None):
+        if project is None: project = c.project.root_project
+        return cls.query.get(name=name, project_id=project._id)
+
+    @classmethod
+    def anonymous(cls, project=None):
+        if project is None: project = c.project.root_project
+        return cls.by_name('*anonymous', project)
+
+    @classmethod
+    def authenticated(cls, project=None):
+        if project is None: project = c.project.root_project
+        return cls.by_name('*authenticated', project)
+
+    @classmethod
     def upsert(cls, **kw):
+        obj = cls.query.get(**kw)
+        if obj is not None: return obj
         try:
             obj = cls(**kw)
             session(obj).insert_now(obj, state(obj))
@@ -370,19 +389,35 @@
 
     @property
     def user(self):
+        if self.user_id is None: return None
         return User.query.get(_id=self.user_id)
 
+    @classmethod
+    def roles_reachable_from(cls, *roots):
+        to_visit = list(roots)
+        visited = set()
+        while to_visit:
+            pr = to_visit.pop(0)
+            if pr in visited: continue
+            visited.add(pr)
+            yield pr
+            to_visit += cls.query.find(dict(_id={'$in':pr.roles})).all()
+
+    @classmethod
+    def roles_that_reach(cls, *roots):
+        to_visit = list(roots)
+        visited = set()
+        while to_visit:
+            pr = to_visit.pop(0)
+            if pr in visited: continue
+            visited.add(pr)
+            yield pr
+            to_visit += cls.query.find(dict(roles=pr._id)).all()
+
     def users_with_role(self):
-        return [pr.user for pr in ProjectRole.query.find({'roles':self._id}).all() if pr.user_id]
-
-    def role_iter(self, visited=None):
-        if visited is None: visited = set()
-        if self._id not in visited: 
-            yield self
-            visited.add(self._id)
-            for rid in self.roles:
-                pr = ProjectRole.query.get(_id=rid)
-                if pr is None: continue
-                for rr in pr.role_iter(visited):
-                    yield rr
-
+        return [
+            role.user for role in self.roles_that_reach(self) if role.user_id ]
+
+    def role_iter(self):
+        return self.roles_reachable_from(self)
+