Ver código fonte

- added initial implementation for the find method.

madhugilla 9 anos atrás
pai
commit
91faaccdf0

+ 15 - 6
cloudbridge/cloud/providers/azure/azure_client.py

@@ -77,8 +77,11 @@ class AzureClient(object):
     def list_locations(self):
         return self.subscription_client.subscriptions.list_locations(self.subscription_id)
 
-    def list_security_group(self):
-        return self.network_management_client.network_security_groups.list(self.resource_group_name)
+    def list_security_group(self, filters=None):
+        security_groups = FilterList(
+            self.network_management_client.network_security_groups.list(self.resource_group_name))
+        security_groups.filter(filters)
+        return security_groups
 
     def create_security_group(self, name, parameters):
         sg_create = self.network_management_client.network_security_groups.create_or_update(self.resource_group_name,
@@ -118,7 +121,7 @@ class AzureClient(object):
     def delete_container(self, container_name):
         self.blob_service.delete_container(container_name)
         return None
-    
+
     def list_blobs(self, container_name):
         return self.blob_service.list_blobs(container_name)
 
@@ -168,7 +171,7 @@ class AzureClient(object):
                 'location': region or self.region_name,
                 'creation_data': {
                     'create_option': 'copy',
-                    'source_uri':snapshot_id
+                    'source_uri': snapshot_id
                 }
             }
         )
@@ -179,11 +182,17 @@ class AzureClient(object):
         return self.compute_client.disks.get(self.resource_group_name, disk_name)
 
 
+# TODO: find out a better way.
 class FilterList(list):
     def filter(self, filters):
+        filtered_list = []
         if filters:
             for obj in self:
                 for key in filters:
-                    print('original value' + str(getattr(obj, key)) + 'key value' + filters[key])
+                    print('original value ' + str(getattr(obj, key)) + ' key value ' + filters[key])
                     if filters[key] not in str(getattr(obj, key)):
-                        self.remove(obj)
+                        print("removing " + str(getattr(obj, key)))
+                        filtered_list.append(obj)
+                        # self.remove(obj)
+            for s in filtered_list:
+                self.remove(s)

+ 1 - 4
cloudbridge/cloud/providers/azure/mock_azure_client.py

@@ -101,11 +101,8 @@ class MockAzureClient:
         self.security_groups.append(sg_create)
         return sg_create
 
-    # def list_security_group(self):
-    #     return self.security_groups
-
     def list_security_group(self, filters=None):
-        security_groups= FilterList(self.security_groups)
+        security_groups = FilterList(self.security_groups)
         security_groups.filter(filters)
         return security_groups
 

+ 14 - 10
cloudbridge/cloud/providers/azure/services.py

@@ -62,7 +62,7 @@ class AzureSecurityGroupService(BaseSecurityGroupService):
 
     def list(self, limit=None, marker=None):
         sgs = [AzureSecurityGroup(self.provider, sg)
-                                  for sg in self.provider.azure_client.list_security_group()]
+               for sg in self.provider.azure_client.list_security_group()]
         return ClientPagedResultList(self.provider, sgs, limit, marker)
 
     def create(self, name, description, network_id):
@@ -72,9 +72,14 @@ class AzureSecurityGroupService(BaseSecurityGroupService):
             return AzureSecurityGroup(self.provider, sg)
         return None
 
-    def find(self, name, limit=None, marker=None):
-        raise NotImplementedError(
-            "AzureSecurityGroupService does not implement this method")
+    def find(self, name: object, limit: object = None, marker: object = None) -> object:
+        """
+        Searches for a security group by a given list of attributes.
+        """
+        security_groups = [AzureSecurityGroup(self.provider, security_group)
+                           for security_group in self.provider.azure_client.list_security_group({'name': name})]
+        return ClientPagedResultList(self.provider, security_groups,
+                                     limit=limit, marker=marker)
 
     def delete(self, group_id):
         params = TemplateUrlParser.parse(NETWORK_SECURITY_GROUP_RESOURCE_ID, group_id)
@@ -99,12 +104,12 @@ class AzureObjectStoreService(BaseObjectStoreService):
         except AzureMissingResourceHttpError:
             return None
 
-    def find(self, name, limit=None, marker=None):
+    def find(self, name: object, limit: object = None, marker: object = None) -> object:
         """
         Searches for a bucket by a given list of attributes.
         """
         buckets = [AzureBucket(self.provider, bucket)
-                   for bucket in self.provider.azure_client.list_containers({'name':name})]
+                   for bucket in self.provider.azure_client.list_containers({'name': name})]
         return ClientPagedResultList(self.provider, buckets,
                                      limit=limit, marker=marker)
 
@@ -142,7 +147,6 @@ class AzureBlockStoreService(BaseBlockStoreService):
 
 
 class AzureVolumeService(BaseVolumeService):
-
     def __init__(self, provider):
         super(AzureVolumeService, self).__init__(provider)
 
@@ -150,7 +154,7 @@ class AzureVolumeService(BaseVolumeService):
         volume = self.provider.azure_client.get_disk(volume_id)
         return AzureVolume(self.provider, volume)
 
-    def find(self, name, limit=None, marker=None):
+    def find(self, name: object, limit: object = None, marker: object = None) -> object:
         raise NotImplementedError('AzureVolumeService not imeplemented this method')
 
     def list(self, limit=None, marker=None):
@@ -159,9 +163,9 @@ class AzureVolumeService(BaseVolumeService):
     def create(self, name, size, zone=None, snapshot=None, description=None):
         zone_id = zone.id if isinstance(zone, PlacementZone) else zone
         snapshot_id = snapshot.id if isinstance(snapshot, Snapshot) and snapshot else snapshot
-        azure_vol =  self.provider.azure_client.create_empty_disk(name, size, zone_id, snapshot_id)
+        azure_vol = self.provider.azure_client.create_empty_disk(name, size, zone_id, snapshot_id)
 
-        cb_vol=AzureVolume(self.provider, azure_vol)
+        cb_vol = AzureVolume(self.provider, azure_vol)
         if description:
             cb_vol.description = description
 

+ 18 - 7
test/test_azure_security_service.py

@@ -24,9 +24,16 @@ class AzureSecurityServiceTestCase(ProviderTestBase):
         self.assertEqual(name, sg.name)
 
     @helpers.skipIfNoService(['security.security_groups'])
-    def test_azure_security_group_find(self):
-        with self.assertRaises(NotImplementedError):
-            sgs = self.security_groups.find("mygroup")
+    def test_azure_security_group_find_exists(self):
+        sgl = self.provider.security.security_groups.find("sg")
+        for sg in sgl:
+            self.assertTrue("sg" in sg.name)
+        self.assertTrue(sgl.total_results > 1)
+
+    @helpers.skipIfNoService(['security.security_groups'])
+    def test_azure_security_group_find_not_exists(self):
+        sgl = self.provider.security.security_groups.find('dontfindme')
+        self.assertTrue(sgl.total_results == 0)
 
     @helpers.skipIfNoService(['security.security_groups'])
     def test_azure_security_group_list(self):
@@ -40,7 +47,8 @@ class AzureSecurityServiceTestCase(ProviderTestBase):
 
     @helpers.skipIfNoService(['security.security_groups'])
     def test_azure_security_group_get_found(self):
-        sgl = self.security_groups.get("/subscriptions/7904d702-e01c-4826-8519-f5a25c866a96/resourceGroups/CloudBridge-Azure/providers/Microsoft.Network/networkSecurityGroups/sg3")
+        sgl = self.security_groups.get(
+            "/subscriptions/7904d702-e01c-4826-8519-f5a25c866a96/resourceGroups/CloudBridge-Azure/providers/Microsoft.Network/networkSecurityGroups/sg3")
         print("Get ( " + "Name - " + sgl.name + "  Id - " + sgl.id + " )")
         self.assertTrue(
             sgl.name == "sg3",
@@ -48,7 +56,8 @@ class AzureSecurityServiceTestCase(ProviderTestBase):
 
     @helpers.skipIfNoService(['security.security_groups'])
     def test_azure_security_group_get_not_found(self):
-        sgl = self.security_groups.get("/subscriptions/7904d702-e01c-4826-8519-f5a25c866a96/resourceGroups/CloudBridge-Azure/providers/Microsoft.Network/networkSecurityGroups/sg4")
+        sgl = self.security_groups.get(
+            "/subscriptions/7904d702-e01c-4826-8519-f5a25c866a96/resourceGroups/CloudBridge-Azure/providers/Microsoft.Network/networkSecurityGroups/sg4")
         print(str(sgl))
         self.assertTrue(
             sgl == None,
@@ -56,13 +65,15 @@ class AzureSecurityServiceTestCase(ProviderTestBase):
 
     @helpers.skipIfNoService(['security.security_groups'])
     def test_azure_security_group_delete_IdExists(self):
-        sg = self.security_groups.delete("/subscriptions/7904d702-e01c-4826-8519-f5a25c866a96/resourceGroups/CloudBridge-Azure/providers/Microsoft.Network/networkSecurityGroups/sg2")
+        sg = self.security_groups.delete(
+            "/subscriptions/7904d702-e01c-4826-8519-f5a25c866a96/resourceGroups/CloudBridge-Azure/providers/Microsoft.Network/networkSecurityGroups/sg2")
         print("Delete - ")
         self.assertEqual(sg, True)
 
     @helpers.skipIfNoService(['security.security_groups'])
     def test_azure_security_group_delete_IdNotExist(self):
-        sg = self.security_groups.delete("/subscriptions/7904d702-e01c-4826-8519-f5a25c866a96/resourceGroups/CloudBridge-Azure/providers/Microsoft.Network/networkSecurityGroups/sg5")
+        sg = self.security_groups.delete(
+            "/subscriptions/7904d702-e01c-4826-8519-f5a25c866a96/resourceGroups/CloudBridge-Azure/providers/Microsoft.Network/networkSecurityGroups/sg5")
         self.assertEqual(sg, False)
 
     @helpers.skipIfNoService(['security.security_groups'])