merge with 1305

This commit is contained in:
vladimir.p
2011-07-22 16:05:14 -07:00
5 changed files with 102 additions and 31 deletions

View File

@@ -315,9 +315,9 @@ def migration_get(context, migration_id):
return IMPL.migration_get(context, migration_id)
def migration_get_by_instance_and_status(context, instance_id, status):
"""Finds a migration by the instance id its migrating."""
return IMPL.migration_get_by_instance_and_status(context, instance_id,
def migration_get_by_instance_and_status(context, instance_uuid, status):
"""Finds a migration by the instance uuid its migrating."""
return IMPL.migration_get_by_instance_and_status(context, instance_uuid,
status)
@@ -333,13 +333,14 @@ def fixed_ip_associate(context, address, instance_id):
return IMPL.fixed_ip_associate(context, address, instance_id)
def fixed_ip_associate_pool(context, network_id, instance_id):
"""Find free ip in network and associate it to instance.
def fixed_ip_associate_pool(context, network_id, instance_id=None, host=None):
"""Find free ip in network and associate it to instance or host.
Raises if one is not available.
"""
return IMPL.fixed_ip_associate_pool(context, network_id, instance_id)
return IMPL.fixed_ip_associate_pool(context, network_id,
instance_id, host)
def fixed_ip_create(context, values):
@@ -362,9 +363,9 @@ def fixed_ip_get_all(context):
return IMPL.fixed_ip_get_all(context)
def fixed_ip_get_all_by_host(context, host):
"""Get all defined fixed ips used by a host."""
return IMPL.fixed_ip_get_all_by_host(context, host)
def fixed_ip_get_all_by_instance_host(context, host):
"""Get all allocated fixed ips filtered by instance host."""
return IMPL.fixed_ip_get_all_instance_by_host(context, host)
def fixed_ip_get_by_address(context, address):
@@ -377,6 +378,11 @@ def fixed_ip_get_by_instance(context, instance_id):
return IMPL.fixed_ip_get_by_instance(context, instance_id)
def fixed_ip_get_by_network_host(context, network_id, host):
"""Get fixed ip for a host in a network."""
return IMPL.fixed_ip_get_by_network_host(context, network_id, host)
def fixed_ip_get_by_virtual_interface(context, vif_id):
"""Get fixed ips by virtual interface or raise if none exist."""
return IMPL.fixed_ip_get_by_virtual_interface(context, vif_id)
@@ -1012,10 +1018,16 @@ def block_device_mapping_create(context, values):
def block_device_mapping_update(context, bdm_id, values):
"""Create an entry of block device mapping"""
"""Update an entry of block device mapping"""
return IMPL.block_device_mapping_update(context, bdm_id, values)
def block_device_mapping_update_or_create(context, values):
"""Update an entry of block device mapping.
If not existed, create a new entry"""
return IMPL.block_device_mapping_update_or_create(context, values)
def block_device_mapping_get_all_by_instance(context, instance_id):
"""Get all block device mapping belonging to a instance"""
return IMPL.block_device_mapping_get_all_by_instance(context, instance_id)
@@ -1322,9 +1334,9 @@ def instance_type_get_all(context, inactive=False):
return IMPL.instance_type_get_all(context, inactive)
def instance_type_get_by_id(context, id):
def instance_type_get(context, id):
"""Get instance type by id."""
return IMPL.instance_type_get_by_id(context, id)
return IMPL.instance_type_get(context, id)
def instance_type_get_by_name(context, name):

View File

@@ -18,7 +18,6 @@
"""
Implementation of SQLAlchemy backend.
"""
import traceback
import warnings
from nova import db
@@ -33,7 +32,6 @@ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload_all
from sqlalchemy.sql import exists
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import literal_column
@@ -672,7 +670,7 @@ def fixed_ip_associate(context, address, instance_id):
@require_admin_context
def fixed_ip_associate_pool(context, network_id, instance_id):
def fixed_ip_associate_pool(context, network_id, instance_id=None, host=None):
session = get_session()
with session.begin():
network_or_none = or_(models.FixedIp.network_id == network_id,
@@ -682,6 +680,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id):
filter_by(reserved=False).\
filter_by(deleted=False).\
filter_by(instance=None).\
filter_by(host=None).\
with_lockmode('update').\
first()
# NOTE(vish): if with_lockmode isn't supported, as in sqlite,
@@ -692,9 +691,12 @@ def fixed_ip_associate_pool(context, network_id, instance_id):
fixed_ip_ref.network = network_get(context,
network_id,
session=session)
fixed_ip_ref.instance = instance_get(context,
instance_id,
session=session)
if instance_id:
fixed_ip_ref.instance = instance_get(context,
instance_id,
session=session)
if host:
fixed_ip_ref.host = host
session.add(fixed_ip_ref)
return fixed_ip_ref['address']
@@ -750,7 +752,7 @@ def fixed_ip_get_all(context, session=None):
@require_admin_context
def fixed_ip_get_all_by_host(context, host=None):
def fixed_ip_get_all_by_instance_host(context, host=None):
session = get_session()
result = session.query(models.FixedIp).\
@@ -799,6 +801,20 @@ def fixed_ip_get_by_instance(context, instance_id):
return rv
@require_context
def fixed_ip_get_by_network_host(context, network_id, host):
session = get_session()
rv = session.query(models.FixedIp).\
filter_by(network_id=network_id).\
filter_by(host=host).\
filter_by(deleted=False).\
first()
if not rv:
raise exception.FixedIpNotFoundForNetworkHost(network_id=network_id,
host=host)
return rv
@require_context
def fixed_ip_get_by_virtual_interface(context, vif_id):
session = get_session()
@@ -1362,7 +1378,11 @@ def instance_update(context, instance_id, values):
instance_metadata_update_or_create(context, instance_id,
values.pop('metadata'))
with session.begin():
instance_ref = instance_get(context, instance_id, session=session)
if utils.is_uuid_like(instance_id):
instance_ref = instance_get_by_uuid(context, instance_id,
session=session)
else:
instance_ref = instance_get(context, instance_id, session=session)
instance_ref.update(values)
instance_ref.save(session=session)
return instance_ref
@@ -1509,8 +1529,6 @@ def network_associate(context, project_id, force=False):
called by project_get_networks under certain conditions
and network manager add_network_to_project()
only associates projects with networks that have configured hosts
only associate if the project doesn't already have a network
or if force is True
@@ -1526,7 +1544,6 @@ def network_associate(context, project_id, force=False):
def network_query(project_filter):
return session.query(models.Network).\
filter_by(deleted=False).\
filter(models.Network.host != None).\
filter_by(project_id=project_filter).\
with_lockmode('update').\
first()
@@ -1733,9 +1750,16 @@ def network_get_all_by_instance(_context, instance_id):
def network_get_all_by_host(context, host):
session = get_session()
with session.begin():
# NOTE(vish): return networks that have host set
# or that have a fixed ip with host set
host_filter = or_(models.Network.host == host,
models.FixedIp.host == host)
return session.query(models.Network).\
filter_by(deleted=False).\
filter_by(host=host).\
join(models.Network.fixed_ips).\
filter(host_filter).\
filter_by(deleted=False).\
all()
@@ -1767,6 +1791,7 @@ def network_update(context, network_id, values):
network_ref = network_get(context, network_id, session=session)
network_ref.update(values)
network_ref.save(session=session)
return network_ref
###################
@@ -2281,6 +2306,23 @@ def block_device_mapping_update(context, bdm_id, values):
update(values)
@require_context
def block_device_mapping_update_or_create(context, values):
session = get_session()
with session.begin():
result = session.query(models.BlockDeviceMapping).\
filter_by(instance_id=values['instance_id']).\
filter_by(device_name=values['device_name']).\
filter_by(deleted=False).\
first()
if not result:
bdm_ref = models.BlockDeviceMapping()
bdm_ref.update(values)
bdm_ref.save(session=session)
else:
result.update(values)
@require_context
def block_device_mapping_get_all_by_instance(context, instance_id):
session = get_session()
@@ -2839,13 +2881,13 @@ def migration_get(context, id, session=None):
@require_admin_context
def migration_get_by_instance_and_status(context, instance_id, status):
def migration_get_by_instance_and_status(context, instance_uuid, status):
session = get_session()
result = session.query(models.Migration).\
filter_by(instance_id=instance_id).\
filter_by(instance_uuid=instance_uuid).\
filter_by(status=status).first()
if not result:
raise exception.MigrationNotFoundByStatus(instance_id=instance_id,
raise exception.MigrationNotFoundByStatus(instance_id=instance_uuid,
status=status)
return result
@@ -3026,7 +3068,7 @@ def instance_type_get_all(context, inactive=False):
@require_context
def instance_type_get_by_id(context, id):
def instance_type_get(context, id):
"""Returns a dict describing specific instance_type"""
session = get_session()
inst_type = session.query(models.InstanceTypes).\

View File

@@ -64,8 +64,8 @@ def db_version():
'users', 'user_project_association',
'user_project_role_association',
'user_role_association',
'volumes',
'virtual_storage_arrays', 'drive_types'):
'virtual_storage_arrays', 'drive_types',
'volumes'):
assert table in meta.tables
return db_version_control(1)
except AssertionError:

View File

@@ -31,6 +31,7 @@ import unittest
import mox
import nose.plugins.skip
import nova.image.fake
import shutil
import stubout
from eventlet import greenthread
@@ -119,6 +120,9 @@ class TestCase(unittest.TestCase):
if hasattr(fake.FakeConnection, '_instance'):
del fake.FakeConnection._instance
if FLAGS.image_service == 'nova.image.fake.FakeImageService':
nova.image.fake.FakeImageService_reset()
# Reset any overriden flags
self.reset_flags()
@@ -248,3 +252,15 @@ class TestCase(unittest.TestCase):
for d1, d2 in zip(L1, L2):
self.assertDictMatch(d1, d2, approx_equal=approx_equal,
tolerance=tolerance)
def assertSubDictMatch(self, sub_dict, super_dict):
"""Assert a sub_dict is subset of super_dict."""
self.assertTrue(set(sub_dict.keys()).issubset(set(super_dict.keys())))
for k, sub_value in sub_dict.items():
super_value = super_dict[k]
if isinstance(sub_value, dict):
self.assertSubDictMatch(sub_value, super_value)
elif 'DONTCARE' in (sub_value, super_value):
continue
else:
self.assertEqual(sub_value, super_value)

View File

@@ -59,6 +59,7 @@ def setup():
network.create_networks(ctxt,
label='test',
cidr=FLAGS.fixed_range,
multi_host=FLAGS.multi_host,
num_networks=FLAGS.num_networks,
network_size=FLAGS.network_size,
cidr_v6=FLAGS.fixed_range_v6,
@@ -68,7 +69,7 @@ def setup():
vpn_start=FLAGS.vpn_start,
vlan_start=FLAGS.vlan_start)
for net in db.network_get_all(ctxt):
network.set_network_host(ctxt, net['id'])
network.set_network_host(ctxt, net)
cleandb = os.path.join(FLAGS.state_path, FLAGS.sqlite_clean_db)
shutil.copyfile(testdb, cleandb)