Ver Fonte

Made non-existent record handling consistent for AWS

Nuwan Goonasekera há 8 anos atrás
pai
commit
b0c423fb64
2 ficheiros alterados com 87 adições e 49 exclusões
  1. 86 48
      cloudbridge/cloud/providers/aws/services.py
  2. 1 1
      test/helpers/__init__.py

+ 86 - 48
cloudbridge/cloud/providers/aws/services.py

@@ -100,8 +100,13 @@ class AWSKeyPairService(BaseKeyPairService):
             kps = self.provider.ec2_conn.get_all_key_pairs(
                 keynames=[key_pair_id])
             return AWSKeyPair(self.provider, kps[0])
-        except EC2ResponseError:
-            return None
+        except EC2ResponseError as ec2e:
+            if ec2e.code == 'InvalidKeyPair.NotFound':
+                return None
+            elif ec2e.code == 'InvalidParameterValue':
+                return None
+            else:
+                raise ec2e
 
     def list(self, limit=None, marker=None):
         """
@@ -123,10 +128,15 @@ class AWSKeyPairService(BaseKeyPairService):
             key_pairs = [
                 AWSKeyPair(self.provider, kp) for kp in
                 self.provider.ec2_conn.get_all_key_pairs(keynames=[name])]
-        except EC2ResponseError:
-            key_pairs = []
-        return ClientPagedResultList(self.provider, key_pairs,
-                                     limit=limit, marker=marker)
+            return ClientPagedResultList(self.provider, key_pairs,
+                                         limit=limit, marker=marker)
+        except EC2ResponseError as ec2e:
+            if ec2e.code == 'InvalidKeyPair.NotFound':
+                return []
+            elif ec2e.code == 'InvalidParameterValue':
+                return []
+            else:
+                raise ec2e
 
     def create(self, name):
         """
@@ -157,8 +167,13 @@ class AWSSecurityGroupService(BaseSecurityGroupService):
             sgs = self.provider.ec2_conn.get_all_security_groups(
                 group_ids=[sg_id])
             return AWSSecurityGroup(self.provider, sgs[0]) if sgs else None
-        except EC2ResponseError:
-            return None
+        except EC2ResponseError as ec2e:
+            if ec2e.code == 'InvalidGroup.NotFound':
+                return None
+            elif ec2e.code == 'InvalidGroupId.Malformed':
+                return None
+            else:
+                raise ec2e
 
     def list(self, limit=None, marker=None):
         """
@@ -200,13 +215,12 @@ class AWSSecurityGroupService(BaseSecurityGroupService):
         """
         Get all security groups associated with your account.
         """
-        try:
-            flters = {'group-name': name}
-            security_groups = self.provider.ec2_conn.get_all_security_groups(
-                filters=flters)
-        except EC2ResponseError:
-            security_groups = []
-        return [AWSSecurityGroup(self.provider, sg) for sg in security_groups]
+        flters = {'group-name': name}
+        ec2_sgs = self.provider.ec2_conn.get_all_security_groups(
+            filters=flters)
+        sgs = [AWSSecurityGroup(self.provider, sg) for sg in ec2_sgs]
+        return ClientPagedResultList(self.provider, sgs,
+                                     limit=limit, marker=marker)
 
     def delete(self, group_id):
         """
@@ -221,16 +235,11 @@ class AWSSecurityGroupService(BaseSecurityGroupService):
                   been deleted by this method but instead has not existed in
                   the first place.
         """
-        try:
-            for sg in self.provider.ec2_conn.get_all_security_groups(
-                    group_ids=[group_id]):
-                try:
-                    sg.delete()
-                except EC2ResponseError:
-                    return False
-        except EC2ResponseError:
-            pass
-        return True
+        sg = self.get(group_id)
+        if sg:
+            sg.delete()
+            return True
+        return False
 
 
 class AWSBlockStoreService(BaseBlockStoreService):
@@ -260,8 +269,17 @@ class AWSVolumeService(BaseVolumeService):
         """
         Returns a volume given its id.
         """
-        vols = self.provider.ec2_conn.get_all_volumes(volume_ids=[volume_id])
-        return AWSVolume(self.provider, vols[0]) if vols else None
+        try:
+            vols = self.provider.ec2_conn.get_all_volumes(
+                volume_ids=[volume_id])
+            return AWSVolume(self.provider, vols[0]) if vols else None
+        except EC2ResponseError as ec2e:
+            if ec2e.code == 'InvalidVolume.NotFound':
+                return None
+            elif ec2e.code == 'InvalidParameterValue':
+                # Occurs if volume_id does not start with 'vol-...'
+                return None
+            raise ec2e
 
     def find(self, name, limit=None, marker=None):
         """
@@ -313,11 +331,14 @@ class AWSSnapshotService(BaseSnapshotService):
         try:
             snaps = self.provider.ec2_conn.get_all_snapshots(
                 snapshot_ids=[snapshot_id])
+            return AWSSnapshot(self.provider, snaps[0]) if snaps else None
         except EC2ResponseError as ec2e:
             if ec2e.code == 'InvalidSnapshot.NotFound':
                 return None
+            elif ec2e.code == 'InvalidParameterValue':
+                # Occurs if snapshot_id does not start with 'snap-...'
+                return None
             raise ec2e
-        return AWSSnapshot(self.provider, snaps[0]) if snaps else None
 
     def find(self, name, limit=None, marker=None):
         """
@@ -426,12 +447,14 @@ class AWSImageService(BaseImageService):
         """
         try:
             image = self.provider.ec2_conn.get_image(image_id)
-            if image:
-                return AWSMachineImage(self.provider, image)
-        except EC2ResponseError:
-            pass
-
-        return None
+            return AWSMachineImage(self.provider, image) if image else None
+        except EC2ResponseError as ec2e:
+            if ec2e.code == 'InvalidAMIID.NotFound':
+                return None
+            elif ec2e.code == 'InvalidAMIID.Malformed':
+                # Occurs if image_id does not start with 'ami-...'
+                return None
+            raise ec2e
 
     def find(self, name, limit=None, marker=None):
         """
@@ -613,12 +636,15 @@ class AWSInstanceService(BaseInstanceService):
         try:
             reservation = self.provider.ec2_conn.get_all_reservations(
                 instance_ids=[instance_id])
-        except EC2ResponseError:
-            return None
-        if reservation:
-            return AWSInstance(self.provider, reservation[0].instances[0])
-        else:
-            return None
+            return (AWSInstance(self.provider, reservation[0].instances[0])
+                    if reservation else None)
+        except EC2ResponseError as ec2e:
+            if ec2e.code == 'InvalidInstanceID.NotFound':
+                return None
+            elif ec2e.code == 'InvalidParameterValue':
+                # Occurs if id does not start with 'inst-...'
+                return None
+            raise ec2e
 
     def find(self, name, limit=None, marker=None):
         """
@@ -715,10 +741,16 @@ class AWSNetworkService(BaseNetworkService):
         self._subnet_svc = AWSSubnetService(self.provider)
 
     def get(self, network_id):
-        network = self.provider.vpc_conn.get_all_vpcs(vpc_ids=[network_id])
-        if network:
-            return AWSNetwork(self.provider, network[0])
-        return None
+        try:
+            network = self.provider.vpc_conn.get_all_vpcs(vpc_ids=[network_id])
+            return AWSNetwork(self.provider, network[0]) if network else None
+        except EC2ResponseError as ec2e:
+            if ec2e.code == 'InvalidVpcID.NotFound':
+                return None
+            elif ec2e.code == 'InvalidParameterValue':
+                # Occurs if id does not start with 'vpc-...'
+                return None
+            raise ec2e
 
     def list(self, limit=None, marker=None):
         networks = [AWSNetwork(self.provider, network)
@@ -781,10 +813,16 @@ class AWSSubnetService(BaseSubnetService):
         super(AWSSubnetService, self).__init__(provider)
 
     def get(self, subnet_id):
-        subnets = self.provider.vpc_conn.get_all_subnets([subnet_id])
-        if subnets:
-            return AWSSubnet(self.provider, subnets[0])
-        return None
+        try:
+            subnets = self.provider.vpc_conn.get_all_subnets([subnet_id])
+            return AWSSubnet(self.provider, subnets[0]) if subnets else None
+        except EC2ResponseError as ec2e:
+            if ec2e.code == 'InvalidSubnetID.NotFound':
+                return None
+            elif ec2e.code == 'InvalidParameterValue':
+                # Occurs if id does not start with 'subnet-...'
+                return None
+            raise ec2e
 
     def list(self, network=None, limit=None, marker=None):
         fltr = None

+ 1 - 1
test/helpers/__init__.py

@@ -169,7 +169,7 @@ def cleanup_test_resources(instance=None, network=None, security_group=None,
 
 
 def get_uuid():
-    return str(uuid.uuid4()).replace("-", "")
+    return str(uuid.uuid4())
 
 
 class ProviderTestBase(unittest.TestCase):