Ver Fonte

Replace `RequestContext` kwarg `tenant` with `project_id`

Daniel Vincze há 4 anos atrás
pai
commit
5639974800
5 ficheiros alterados com 41 adições e 43 exclusões
  1. 3 3
      coriolis/api/middleware/auth.py
  2. 1 1
      coriolis/api/wsgi.py
  3. 4 4
      coriolis/context.py
  4. 32 34
      coriolis/db/api.py
  5. 1 1
      coriolis/keystone.py

+ 3 - 3
coriolis/api/middleware/auth.py

@@ -26,10 +26,10 @@ class CoriolisKeystoneContext(wsgi.Middleware):
         roles = [r.strip() for r in req.headers.get('X_ROLE', '').split(',')]
         if 'X_TENANT_ID' in req.headers:
             # This is the new header since Keystone went to ID/Name
-            tenant = req.headers['X_TENANT_ID']
+            project_id = req.headers['X_TENANT_ID']
         else:
             # This is for legacy compatibility
-            tenant = req.headers['X_TENANT']
+            project_id = req.headers['X_TENANT']
 
         project_name = req.headers.get('X_TENANT_NAME')
         project_domain_name = req.headers.get('X-Project-Domain-Name')
@@ -56,7 +56,7 @@ class CoriolisKeystoneContext(wsgi.Middleware):
                     explanation=_('Invalid service catalog json.'))
 
         ctx = context.RequestContext(user,
-                                     tenant,
+                                     project_id,
                                      project_name=project_name,
                                      project_domain_name=project_domain_name,
                                      user_domain_name=user_domain_name,

+ 1 - 1
coriolis/api/wsgi.py

@@ -907,7 +907,7 @@ class Resource(Application):
 
         project_id = action_args.pop("project_id", None)
         context = request.environ.get('coriolis.context')
-        if (context and project_id and (project_id != context.tenant)):
+        if (context and project_id and (project_id != context.project_id)):
             msg = _("Malformed request url")
             return Fault(webob.exc.HTTPBadRequest(explanation=msg))
 

+ 4 - 4
coriolis/context.py

@@ -13,7 +13,7 @@ from coriolis import policy
 
 @enginefacade.transaction_context_provider
 class RequestContext(context.RequestContext):
-    def __init__(self, user, tenant, is_admin=None,
+    def __init__(self, user, project_id, is_admin=None,
                  roles=None, project_name=None, remote_address=None,
                  timestamp=None, request_id=None, auth_token=None,
                  overwrite=True, domain_name=None, domain_id=None,
@@ -24,7 +24,7 @@ class RequestContext(context.RequestContext):
 
         super(RequestContext, self).__init__(auth_token=auth_token,
                                              user=user,
-                                             tenant=tenant,
+                                             project_id=project_id,
                                              domain_name=domain_name,
                                              domain_id=domain_id,
                                              user_domain_name=user_domain_name,
@@ -51,7 +51,7 @@ class RequestContext(context.RequestContext):
     def to_dict(self):
         result = super(RequestContext, self).to_dict()
         result['user'] = self.user
-        result['tenant'] = self.tenant
+        result['project_id'] = self.project_id
         result['project_name'] = self.project_name
         result['domain_id'] = self.domain_id
         result['domain_name'] = self.domain_name
@@ -102,5 +102,5 @@ class RequestContext(context.RequestContext):
 
 def get_admin_context(trust_id=None):
     return RequestContext(
-        user=None, tenant=None, is_admin=True,
+        user=None, project_id=None, is_admin=True,
         trust_id=trust_id)

+ 32 - 34
coriolis/db/api.py

@@ -105,7 +105,7 @@ def _get_replica_schedules_filter(context, replica_id=None,
     sched_filter = q.filter()
     if is_user_context(context):
         sched_filter = sched_filter.filter(
-            models.Replica.project_id == context.tenant)
+            models.Replica.project_id == context.project_id)
 
     if replica_id:
         sched_filter = sched_filter.filter(
@@ -141,7 +141,7 @@ def get_endpoints(context):
         orm.joinedload('mapped_regions'))
     if is_user_context(context):
         q = q.filter(
-            models.Endpoint.project_id == context.tenant)
+            models.Endpoint.project_id == context.project_id)
     return q.filter().all()
 
 
@@ -151,7 +151,7 @@ def get_endpoint(context, endpoint_id):
         orm.joinedload('mapped_regions'))
     if is_user_context(context):
         q = q.filter(
-            models.Endpoint.project_id == context.tenant)
+            models.Endpoint.project_id == context.project_id)
     return q.filter(
         models.Endpoint.id == endpoint_id).first()
 
@@ -159,7 +159,7 @@ def get_endpoint(context, endpoint_id):
 @enginefacade.writer
 def add_endpoint(context, endpoint):
     endpoint.user_id = context.user
-    endpoint.project_id = context.tenant
+    endpoint.project_id = context.project_id
     _session(context).add(endpoint)
 
 
@@ -169,14 +169,13 @@ def update_endpoint(context, endpoint_id, updated_values):
     if not endpoint:
         raise exception.NotFound("Endpoint with ID '%s' found" % endpoint_id)
 
-
     if not isinstance(updated_values, dict):
         raise exception.InvalidInput(
             "Update payload for endpoints must be a dict. Got the following "
             "(type: %s): %s" % (type(updated_values), updated_values))
 
     def _try_unmap_regions(region_ids):
-         for region_to_unmap in region_ids:
+        for region_to_unmap in region_ids:
             try:
                 LOG.debug(
                     "Attempting to unmap region '%s' from endpoint '%s'",
@@ -237,7 +236,6 @@ def update_endpoint(context, endpoint_id, updated_values):
             _try_unmap_regions(newly_mapped_regions)
             raise
 
-
     updateable_fields = ["name", "description", "connection_info"]
     try:
         _update_sqlalchemy_object_fields(
@@ -262,7 +260,7 @@ def delete_endpoint(context, endpoint_id):
     endpoint = get_endpoint(context, endpoint_id)
     args = {"id": endpoint_id}
     if is_user_context(context):
-        args["project_id"] = context.tenant
+        args["project_id"] = context.project_id
     count = _soft_delete_aware_query(context, models.Endpoint).filter_by(
         **args).soft_delete()
     if count == 0:
@@ -281,7 +279,7 @@ def get_replica_tasks_executions(context, replica_id, include_tasks=False):
     if include_tasks:
         q = _get_tasks_with_details_options(q)
     if is_user_context(context):
-        q = q.filter(models.Replica.project_id == context.tenant)
+        q = q.filter(models.Replica.project_id == context.project_id)
     return q.filter(
         models.Replica.id == replica_id).all()
 
@@ -292,7 +290,7 @@ def get_replica_tasks_execution(context, replica_id, execution_id):
         models.Replica)
     q = _get_tasks_with_details_options(q)
     if is_user_context(context):
-        q = q.filter(models.Replica.project_id == context.tenant)
+        q = q.filter(models.Replica.project_id == context.project_id)
     return q.filter(
         models.Replica.id == replica_id,
         models.TasksExecution.id == execution_id).first()
@@ -301,7 +299,7 @@ def get_replica_tasks_execution(context, replica_id, execution_id):
 @enginefacade.writer
 def add_replica_tasks_execution(context, execution):
     if is_user_context(context):
-        if execution.action.project_id != context.tenant:
+        if execution.action.project_id != context.project_id:
             raise exception.NotAuthorized()
 
     # include deleted records
@@ -309,7 +307,7 @@ def add_replica_tasks_execution(context, execution):
         context,
         func.max(
             models.TasksExecution.number)).filter(
-                models.TasksExecution.action_id==(
+                models.TasksExecution.action_id == (
                     execution.action.id)).first()[0] or 0
     execution.number = max_number + 1
 
@@ -322,7 +320,7 @@ def delete_replica_tasks_execution(context, execution_id):
         models.TasksExecution.id == execution_id)
     if is_user_context(context):
         if not q.join(models.Replica).filter(
-                models.Replica.project_id == context.tenant).first():
+                models.Replica.project_id == context.project_id).first():
             raise exception.NotAuthorized()
     count = q.soft_delete()
     if count == 0:
@@ -381,7 +379,7 @@ def delete_replica_schedule(context, replica_id,
             "No such schedule")
     if is_user_context(context):
         if not q.join(models.Replica).filter(
-                models.Replica.project_id == context.tenant).first():
+                models.Replica.project_id == context.project_id).first():
             raise exception.NotAuthorized()
     if pre_delete_callable:
         pre_delete_callable(context, schedule)
@@ -398,7 +396,7 @@ def add_replica_schedule(context, schedule, post_create_callable=None):
     # two-phase transactions or at least allow running these functions
     # inside a single transaction block.
 
-    if schedule.replica.project_id != context.tenant:
+    if schedule.replica.project_id != context.project_id:
         raise exception.NotAuthorized()
     _session(context).add(schedule)
     if post_create_callable:
@@ -422,7 +420,7 @@ def get_replicas(context,
     q = q.filter()
     if is_user_context(context):
         q = q.filter(
-            models.Replica.project_id == context.tenant)
+            models.Replica.project_id == context.project_id)
     db_result = q.all()
     if to_dict:
         return [
@@ -439,7 +437,7 @@ def get_replica(context, replica_id):
     q = _get_replica_with_tasks_executions_options(q)
     if is_user_context(context):
         q = q.filter(
-            models.Replica.project_id == context.tenant)
+            models.Replica.project_id == context.project_id)
     return q.filter(
         models.Replica.id == replica_id).first()
 
@@ -460,7 +458,7 @@ def get_endpoint_replicas_count(context, endpoint_id):
 @enginefacade.writer
 def add_replica(context, replica):
     replica.user_id = context.user
-    replica.project_id = context.tenant
+    replica.project_id = context.project_id
     _session(context).add(replica)
 
 
@@ -468,7 +466,7 @@ def add_replica(context, replica):
 def _delete_transfer_action(context, cls, id):
     args = {"base_id": id}
     if is_user_context(context):
-        args["project_id"] = context.tenant
+        args["project_id"] = context.project_id
     count = _soft_delete_aware_query(context, cls).filter_by(
         **args).soft_delete()
     if count == 0:
@@ -490,7 +488,7 @@ def get_replica_migrations(context, replica_id):
     q = q.options(orm.joinedload("executions"))
     if is_user_context(context):
         q = q.filter(
-            models.Migration.project_id == context.tenant)
+            models.Migration.project_id == context.project_id)
     return q.filter(
         models.Replica.id == replica_id).all()
 
@@ -508,7 +506,7 @@ def get_migrations(context, include_tasks=False,
 
     args = {}
     if is_user_context(context):
-        args["project_id"] = context.tenant
+        args["project_id"] = context.project_id
     result = q.filter_by(**args).all()
     if to_dict:
         return [i.to_dict(
@@ -544,14 +542,14 @@ def get_migration(context, migration_id):
     q = _get_migration_task_query_options(q)
     args = {"id": migration_id}
     if is_user_context(context):
-        args["project_id"] = context.tenant
+        args["project_id"] = context.project_id
     return q.filter_by(**args).first()
 
 
 @enginefacade.writer
 def add_migration(context, migration):
     migration.user_id = context.user
-    migration.project_id = context.tenant
+    migration.project_id = context.project_id
     _session(context).add(migration)
 
 
@@ -568,7 +566,7 @@ def set_execution_status(
             models.TasksExecution.action)
     if is_user_context(context):
         execution = execution.filter(
-            models.BaseTransferAction.project_id == context.tenant)
+            models.BaseTransferAction.project_id == context.project_id)
     execution = execution.filter(
         models.TasksExecution.id == execution_id).first()
     if not execution:
@@ -588,7 +586,7 @@ def get_action(context, action_id):
         context, models.BaseTransferAction)
     if is_user_context(context):
         action = action.filter(
-            models.BaseTransferAction.project_id == context.tenant)
+            models.BaseTransferAction.project_id == context.project_id)
     action = action.filter(
         models.BaseTransferAction.base_id == action_id).first()
     if not action:
@@ -672,7 +670,7 @@ def get_tasks_execution(context, execution_id):
     q = q.options(orm.joinedload("tasks"))
     if is_user_context(context):
         q = q.filter(
-            models.BaseTransferAction.project_id == context.tenant)
+            models.BaseTransferAction.project_id == context.project_id)
     execution = q.filter(
         models.TasksExecution.id == execution_id).first()
     if not execution:
@@ -961,6 +959,7 @@ def delete_region(context, region_id):
     for svc in region.mapped_services:
         delete_service_region_mapping(context, svc.id, region_id)
 
+
 @enginefacade.writer
 def add_endpoint_region_mapping(context, endpoint_region_mapping):
     region_id = endpoint_region_mapping.region_id
@@ -1052,7 +1051,7 @@ def find_service(context, host, binary, topic=None):
     if topic:
         args["topic"] = topic
     q = _soft_delete_aware_query(context, models.Service).options(
-         orm.joinedload('mapped_regions')).filter_by(**args)
+        orm.joinedload('mapped_regions')).filter_by(**args)
     return q.first()
 
 
@@ -1072,7 +1071,7 @@ def update_service(context, service_id, updated_values):
             "(type: %s): %s" % (type(updated_values), updated_values))
 
     def _try_unmap_regions(region_ids):
-         for region_to_unmap in region_ids:
+        for region_to_unmap in region_ids:
             try:
                 LOG.debug(
                     "Attempting to unmap region '%s' from service '%s'",
@@ -1133,7 +1132,6 @@ def update_service(context, service_id, updated_values):
             _try_unmap_regions(newly_mapped_regions)
             raise
 
-
     updateable_fields = ["enabled", "status", "providers", "specs"]
     try:
         _update_sqlalchemy_object_fields(
@@ -1232,7 +1230,7 @@ def get_mapped_services_for_region(context, region_id):
 @enginefacade.writer
 def add_minion_machine(context, minion_machine):
     minion_machine.user_id = context.user
-    minion_machine.project_id = context.tenant
+    minion_machine.project_id = context.project_id
     # inherit pool user/tenant if none are given:
     if None in [minion_machine.user_id, minion_machine.project_id]:
         pool = get_minion_pool(context, minion_machine.pool_id)
@@ -1333,7 +1331,7 @@ def delete_minion_machine(context, minion_machine_id):
 @enginefacade.writer
 def add_minion_pool(context, minion_pool):
     minion_pool.user_id = context.user
-    minion_pool.project_id = context.tenant
+    minion_pool.project_id = context.project_id
     _session(context).add(minion_pool)
 
 
@@ -1341,7 +1339,7 @@ def add_minion_pool(context, minion_pool):
 def delete_minion_pool(context, minion_pool_id):
     args = {"id": minion_pool_id}
     if is_user_context(context):
-        args["project_id"] = context.tenant
+        args["project_id"] = context.project_id
     count = _soft_delete_aware_query(context, models.MinionPool).filter_by(
         **args).soft_delete()
     if count == 0:
@@ -1361,7 +1359,7 @@ def get_minion_pool(
         q = q.options(orm.joinedload('progress_updates'))
     if is_user_context(context):
         q = q.filter(
-            models.MinionPool.project_id == context.tenant)
+            models.MinionPool.project_id == context.project_id)
     return q.filter(
         models.MinionPool.id == minion_pool_id).first()
 
@@ -1374,7 +1372,7 @@ def get_minion_pools(
     q = q.filter()
     if is_user_context(context):
         q = q.filter(
-            models.MinionPool.project_id == context.tenant)
+            models.MinionPool.project_id == context.project_id)
     if include_machines:
         q = q.options(orm.joinedload('minion_machines'))
     if include_events:

+ 1 - 1
coriolis/keystone.py

@@ -62,7 +62,7 @@ def create_trust(ctxt):
         raise exception.NotAuthorized("Trustee authentication failed")
 
     trustor_user_id = ctxt.user
-    trustor_proj_id = ctxt.tenant
+    trustor_proj_id = ctxt.project_id
     roles = ctxt.roles
 
     LOG.debug("Granting Keystone trust. Trustor: %(trustor_user_id)s, trustee:"