Просмотр исходного кода

Removed forward slash from aws zone ids and relaxed test

Nuwan Goonasekera 6 лет назад
Родитель
Сommit
2e8f1d79c3

+ 13 - 0
cloudbridge/providers/aws/resources.py

@@ -1152,8 +1152,21 @@ class AWSDnsZone(BaseDnsZone):
 
 
     @property
     @property
     def id(self):
     def id(self):
+        # The ID contains a slash, do not allow this
+        return self.escape_zone_id(self.aws_id)
+
+    @property
+    def aws_id(self):
         return self._dns_zone.get('Id')
         return self._dns_zone.get('Id')
 
 
+    @staticmethod
+    def escape_zone_id(value):
+        return value.replace("/", "-") if value else None
+
+    @staticmethod
+    def unescape_zone_id(value):
+        return value.replace("-", "/") if value else None
+
     @property
     @property
     def name(self):
     def name(self):
         return self._dns_zone.get('Name')
         return self._dns_zone.get('Name')

+ 7 - 6
cloudbridge/providers/aws/services.py

@@ -1338,7 +1338,8 @@ class AWSDnsZoneService(BaseDnsZoneService):
               priority=BaseDnsZoneService.STANDARD_EVENT_PRIORITY)
               priority=BaseDnsZoneService.STANDARD_EVENT_PRIORITY)
     def get(self, dns_zone_id):
     def get(self, dns_zone_id):
         try:
         try:
-            dns_zone = self.provider.dns.client.get_hosted_zone(Id=dns_zone_id)
+            dns_zone = self.provider.dns.client.get_hosted_zone(
+                Id=AWSDnsZone.unescape_zone_id(dns_zone_id))
             return AWSDnsZone(self.provider, dns_zone.get('HostedZone'))
             return AWSDnsZone(self.provider, dns_zone.get('HostedZone'))
         except self.provider.dns.client.exceptions.NoSuchHostedZone:
         except self.provider.dns.client.exceptions.NoSuchHostedZone:
             return None
             return None
@@ -1382,7 +1383,7 @@ class AWSDnsZoneService(BaseDnsZoneService):
         dns_zone = (dns_zone if isinstance(dns_zone, AWSDnsZone)
         dns_zone = (dns_zone if isinstance(dns_zone, AWSDnsZone)
                     else self.get(dns_zone))
                     else self.get(dns_zone))
         if dns_zone:
         if dns_zone:
-            self.provider.dns.client.delete_hosted_zone(Id=dns_zone.id)
+            self.provider.dns.client.delete_hosted_zone(Id=dns_zone.aws_id)
 
 
 
 
 class AWSDnsRecordService(BaseDnsRecordService):
 class AWSDnsRecordService(BaseDnsRecordService):
@@ -1395,7 +1396,7 @@ class AWSDnsRecordService(BaseDnsRecordService):
             if rec_id and ":" in rec_id:
             if rec_id and ":" in rec_id:
                 rec_name, rec_type = rec_id.split(":")
                 rec_name, rec_type = rec_id.split(":")
                 response = self.provider.dns.client.list_resource_record_sets(
                 response = self.provider.dns.client.list_resource_record_sets(
-                    HostedZoneId=dns_zone.id,
+                    HostedZoneId=dns_zone.aws_id,
                     StartRecordName=rec_name,
                     StartRecordName=rec_name,
                     StartRecordType=rec_type)
                     StartRecordType=rec_type)
                 return AWSDnsRecord(self.provider, dns_zone,
                 return AWSDnsRecord(self.provider, dns_zone,
@@ -1414,7 +1415,7 @@ class AWSDnsRecordService(BaseDnsRecordService):
     def list(self, dns_zone, limit=None, marker=None):
     def list(self, dns_zone, limit=None, marker=None):
         response = self.provider.dns.client.list_resource_record_sets(
         response = self.provider.dns.client.list_resource_record_sets(
             **trim_empty_params({
             **trim_empty_params({
-                'HostedZoneId': dns_zone.id,
+                'HostedZoneId': dns_zone.aws_id,
                 'MaxItems': limit,
                 'MaxItems': limit,
                 'StartRecordIdentifier': marker
                 'StartRecordIdentifier': marker
             })
             })
@@ -1444,7 +1445,7 @@ class AWSDnsRecordService(BaseDnsRecordService):
         AWSDnsRecord.assert_valid_resource_name(name)
         AWSDnsRecord.assert_valid_resource_name(name)
 
 
         response = self.provider.dns.client.change_resource_record_sets(
         response = self.provider.dns.client.change_resource_record_sets(
-            HostedZoneId=dns_zone.id,
+            HostedZoneId=dns_zone.aws_id,
             ChangeBatch={
             ChangeBatch={
                 'Changes': [{
                 'Changes': [{
                     'Action': 'CREATE',
                     'Action': 'CREATE',
@@ -1468,7 +1469,7 @@ class AWSDnsRecordService(BaseDnsRecordService):
 
 
         rec_name, rec_type = rec_id.split(":")
         rec_name, rec_type = rec_id.split(":")
         response = self.provider.dns.client.change_resource_record_sets(
         response = self.provider.dns.client.change_resource_record_sets(
-            HostedZoneId=dns_zone.id,
+            HostedZoneId=dns_zone.aws_id,
             ChangeBatch={
             ChangeBatch={
                 'Changes': [{
                 'Changes': [{
                     'Action': 'DELETE',
                     'Action': 'DELETE',

+ 4 - 7
tests/helpers/standard_interface_tests.py

@@ -7,8 +7,6 @@ This includes:
 """
 """
 import uuid
 import uuid
 
 
-from six.moves.urllib.parse import quote_plus
-
 import tenacity
 import tenacity
 
 
 from cloudbridge.base import helpers as cb_helpers
 from cloudbridge.base import helpers as cb_helpers
@@ -149,11 +147,10 @@ def check_obj_id(test, obj):
     id_property = getattr(type(obj), 'id', None)
     id_property = getattr(type(obj), 'id', None)
     test.assertIsInstance(id_property, property)
     test.assertIsInstance(id_property, property)
     test.assertIsNone(id_property.fset, "Id should not have a setter")
     test.assertIsNone(id_property.fset, "Id should not have a setter")
-    # Non-url safe characters trip up djcloudbridge or anything that needs to
-    # use the ID in a url so make sure ids do not contain them
-    test.assertEqual(quote_plus(obj.id), obj.id,
-                     "IDs should only contain URL friendly chars that do not "
-                     "require encoding but contains: %s" % (obj.id,))
+    # Some delimiter characters can trip up djcloudbridge url reversing
+    # so make sure ids do not contain them
+    test.assertTrue("/" not in obj.id,
+                    "IDs should not contain slash but is: %s" % (obj.id,))
 
 
 
 
 def check_obj_name(test, obj):
 def check_obj_name(test, obj):