Commit 3d4e7e43 authored by Lukas Burgey's avatar Lukas Burgey
Browse files

Create deployments for user during the userinfo update

parent 9d5c3389
...@@ -157,7 +157,7 @@ class NewDeployment(models.Model): ...@@ -157,7 +157,7 @@ class NewDeployment(models.Model):
state_target = models.CharField( state_target = models.CharField(
max_length=50, max_length=50,
choices=STATE_CHOICES, choices=STATE_CHOICES,
default=DEPLOYED, default=NOT_DEPLOYED,
) )
is_active = models.BooleanField( is_active = models.BooleanField(
...@@ -274,7 +274,8 @@ class NewDeployment(models.Model): ...@@ -274,7 +274,8 @@ class NewDeployment(models.Model):
def user_remove(self): def user_remove(self):
self._set_target('not_deployed') self._set_target('not_deployed')
for item in self.state_items.all(): for item in self.state_items.all():
item.user_remove() if item.state != NOT_DEPLOYED:
item.user_remove()
self.publish_to_client() self.publish_to_client()
# each state item publishes its state to the user # each state item publishes its state to the user
...@@ -447,15 +448,21 @@ class NewDeploymentStateItem(models.Model): ...@@ -447,15 +448,21 @@ class NewDeploymentStateItem(models.Model):
# user: removal requested # user: removal requested
def user_remove(self): def user_remove(self):
if self.state == 'not_deployed': if self.parent.state_target == DEPLOYED and (self.state == FAILED or self.state == REJECTED):
self._reset()
self._set_state(NOT_DEPLOYED)
return
if self.state == NOT_DEPLOYED:
LOGGER.info(self.msg('ignoring invalid state transition user_remove')) LOGGER.info(self.msg('ignoring invalid state transition user_remove'))
return return
# FIXME this will break if the client 'finishes' the deployment, after user_remove
if ( if (
self.state == 'deployment_pending' self.state == DEPLOYMENT_PENDING
or self.state == 'questionnaire' or self.state == QUESTIONNAIRE
): ):
self._set_state('not_deployed') self._set_state(NOT_DEPLOYED)
return return
self._set_state( self._set_state(
...@@ -471,33 +478,35 @@ class NewDeploymentStateItem(models.Model): ...@@ -471,33 +478,35 @@ class NewDeploymentStateItem(models.Model):
# returns None on success, or a string describing an error # returns None on success, or a string describing an error
def client_response(self, output): def client_response(self, output):
status = output.get('state', 'undefined') state = output.get('state', 'undefined')
self.message = output.get('message', '') self.message = output.get('message', '')
self.save() self.save()
if status != 'undefined': if state != 'undefined':
# update values # update values
if status == 'deployed': if state == 'deployed':
self.credentials = output.get('credentials', {}) if self.parent.state_target == NOT_DEPLOYED:
self.save() self.user_remove()
elif status == 'not_deployed': else:
self.credentials = output.get('credentials', {})
self.save()
elif state == 'not_deployed':
# reset credentials and questi # reset credentials and questi
self._reset() self._reset()
self.save() self.save()
elif status == 'questionnaire': elif state == 'questionnaire':
self.questionnaire = output.get('questionnaire', {}) self.questionnaire = output.get('questionnaire', {})
self.save() self.save()
elif status == 'rejected': elif state == 'rejected':
pass pass
elif status == 'failed': elif state == 'failed':
pass pass
else: else:
return 'unknown state \''+status+'\'' return 'unknown state \''+state+'\''
self._set_state(status) self._set_state(state)
return None return None
return 'missing status in output' return 'missing state in output'
# resets all client sent values # resets all client sent values
......
...@@ -150,7 +150,7 @@ class DeploymentTest(TestCase): ...@@ -150,7 +150,7 @@ class DeploymentTest(TestCase):
self.assertIsNotNone(deployment) self.assertIsNotNone(deployment)
self.assertEqual(len(deployment.services), service_count) self.assertEqual(len(deployment.services), service_count)
self.assertEqual(deployment.state_items.count(), service_count) self.assertEqual(deployment.state_items.count(), service_count)
self.assertEqual(deployment.state_target, models.DEPLOYED) self.assertEqual(deployment.state_target, models.NOT_DEPLOYED)
if service_count > 0: if service_count > 0:
self.assertEqual(deployment.state, models.NOT_DEPLOYED) self.assertEqual(deployment.state, models.NOT_DEPLOYED)
...@@ -160,6 +160,8 @@ class DeploymentTest(TestCase): ...@@ -160,6 +160,8 @@ class DeploymentTest(TestCase):
if service_count > 0: if service_count > 0:
self.assertEqual(deployment.state, models.DEPLOYMENT_PENDING) self.assertEqual(deployment.state, models.DEPLOYMENT_PENDING)
self.assertFalse(deployment.target_reached) self.assertFalse(deployment.target_reached)
for item in deployment.state_items.all():
self.assertEqual(item.state, models.DEPLOYMENT_PENDING)
else: else:
self.assertEqual(deployment.state, models.DEPLOYED) self.assertEqual(deployment.state, models.DEPLOYED)
self.assertTrue(deployment.target_reached) self.assertTrue(deployment.target_reached)
...@@ -168,6 +170,8 @@ class DeploymentTest(TestCase): ...@@ -168,6 +170,8 @@ class DeploymentTest(TestCase):
LOGGER.debug('deployment_run: %s state items', deployment.state_items.count()) LOGGER.debug('deployment_run: %s state items', deployment.state_items.count())
for item in deployment.state_items.all(): for item in deployment.state_items.all():
item.client_response({'state': models.DEPLOYED}) item.client_response({'state': models.DEPLOYED})
self.assertEqual(item.state, models.DEPLOYED)
self.assertEqual(deployment.state, models.DEPLOYED) self.assertEqual(deployment.state, models.DEPLOYED)
...@@ -179,6 +183,8 @@ class DeploymentTest(TestCase): ...@@ -179,6 +183,8 @@ class DeploymentTest(TestCase):
if service_count > 0: if service_count > 0:
self.assertEqual(deployment.state, models.REMOVAL_PENDING) self.assertEqual(deployment.state, models.REMOVAL_PENDING)
self.assertFalse(deployment.target_reached) self.assertFalse(deployment.target_reached)
for item in deployment.state_items.all():
self.assertEqual(item.state, models.REMOVAL_PENDING)
else: else:
self.assertEqual(deployment.state, models.NOT_DEPLOYED) self.assertEqual(deployment.state, models.NOT_DEPLOYED)
self.assertTrue(deployment.target_reached) self.assertTrue(deployment.target_reached)
...@@ -186,6 +192,7 @@ class DeploymentTest(TestCase): ...@@ -186,6 +192,7 @@ class DeploymentTest(TestCase):
# execute removals # execute removals
for item in deployment.state_items.all(): for item in deployment.state_items.all():
item.client_response({'state': models.NOT_DEPLOYED}) item.client_response({'state': models.NOT_DEPLOYED})
self.assertEqual(item.state, models.NOT_DEPLOYED)
self.assertEqual(deployment.state, models.NOT_DEPLOYED) self.assertEqual(deployment.state, models.NOT_DEPLOYED)
self.assertTrue(deployment.target_reached) self.assertTrue(deployment.target_reached)
...@@ -194,21 +201,25 @@ class DeploymentTest(TestCase): ...@@ -194,21 +201,25 @@ class DeploymentTest(TestCase):
def deployment_run_delayed_service(self, deployment, group, service_count): def deployment_run_delayed_service(self, deployment, group, service_count):
self.assertIsNotNone(deployment) self.assertIsNotNone(deployment)
self.assertEqual(deployment.state, models.NOT_DEPLOYED) self.assertEqual(deployment.state, models.NOT_DEPLOYED)
self.assertEqual(deployment.state_target, models.NOT_DEPLOYED)
self.assertEqual(len(deployment.services), service_count) self.assertEqual(len(deployment.services), service_count)
self.assertEqual(deployment.state_items.count(), service_count) self.assertEqual(deployment.state_items.count(), service_count)
# start deployment # start deployment
deployment.user_deploy() deployment.user_deploy()
self.assertFalse(deployment.state_items.exists()) if service_count == 0:
self.assertTrue(deployment.target_reached) self.assertFalse(deployment.state_items.exists())
self.assertTrue(deployment.target_reached)
# add service
delayed_service = models.Service.objects.get(name='DELAYED_SERVICE') delayed_service = models.Service.objects.get(name='DELAYED_SERVICE')
delayed_service.groups.add(group) delayed_service.groups.add(group)
# add service
deployment.service_added(delayed_service) deployment.service_added(delayed_service)
for item in deployment.state_items.all():
self.assertEqual(item.state, models.DEPLOYMENT_PENDING)
self.assertEqual(deployment.state_items.count(), service_count + 1) self.assertEqual(deployment.state_items.count(), service_count + 1)
self.assertEqual(deployment.state, models.DEPLOYMENT_PENDING) self.assertEqual(deployment.state, models.DEPLOYMENT_PENDING)
......
...@@ -5,7 +5,7 @@ from django.test import TestCase ...@@ -5,7 +5,7 @@ from django.test import TestCase
from feudal.backend import models from feudal.backend import models
from feudal.backend.auth.v1 import models as auth_models from feudal.backend.auth.v1 import models as auth_models
from feudal.backend.models.test_models import setup_fixture, teardown_fixture, TEST_NAME, TEST_USERINFO, TEST_PASSWORD from feudal.backend.models.test_models import setup_fixture, TEST_NAME, TEST_PASSWORD, TEST_USERINFO
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
...@@ -13,9 +13,6 @@ class UserTest(TestCase): ...@@ -13,9 +13,6 @@ class UserTest(TestCase):
def setUp(self): def setUp(self):
setup_fixture() setup_fixture()
def tearDown(self):
teardown_fixture()
def test_user(self): def test_user(self):
user = auth.authenticate( user = auth.authenticate(
username=TEST_NAME, username=TEST_NAME,
......
...@@ -212,7 +212,7 @@ class User(AbstractUser): ...@@ -212,7 +212,7 @@ class User(AbstractUser):
if self.user_type == 'oidcuser': if self.user_type == 'oidcuser':
for dep in self.deployments.all(): for dep in self.deployments.all():
dep.deactivate() dep.user_remove()
def update_userinfo(self, userinfo): def update_userinfo(self, userinfo):
groups = userinfo.get('groups', []) groups = userinfo.get('groups', [])
...@@ -232,18 +232,23 @@ class User(AbstractUser): ...@@ -232,18 +232,23 @@ class User(AbstractUser):
for group_name in groups: for group_name in groups:
group = None
try: try:
group = Group.objects.get(name=group_name) group = Group.objects.get(name=group_name)
self.groups.add(group)
except Group.DoesNotExist: except Group.DoesNotExist:
LOGGER.info('Adding group %s', group_name) LOGGER.info('Adding group %s', group_name)
group = Group(name=group_name) group = Group(name=group_name)
group.save() group.save()
self.groups.add(group)
for dep in self.deployments.filter(group=group): for dep in self.deployments.filter(group=group):
dep.activate() dep.activate()
self.groups.add(group)
from . import NewDeployment
dep = NewDeployment.get_deployment(self, group=group)
dep.save()
# include the ssh key from unity # include the ssh key from unity
unity_key_value = userinfo.get('ssh_key', '') unity_key_value = userinfo.get('ssh_key', '')
unity_key_name = 'unity_key' unity_key_name = 'unity_key'
...@@ -288,19 +293,13 @@ class SSHPublicKey(models.Model): ...@@ -288,19 +293,13 @@ class SSHPublicKey(models.Model):
# somewhere # somewhere
# the receiver 'delete_removen_ssh_key' does the actual deletion # the receiver 'delete_removen_ssh_key' does the actual deletion
def delete_key(self): def delete_key(self):
# if this key is not deployed anywhere we delete it now
if not self.deployed_anywhere:
LOGGER.info(self.msg('Direct deletion of key'))
self.delete()
return
LOGGER.info(self.msg('Deletion of key started')) LOGGER.info(self.msg('Deletion of key started'))
self.deleted = True self.deleted = True
self.save() self.save()
# delete implies removeing the key from all clients # delete implies removeing the key from all clients
for deployment in self.deployments.all(): for dep in self.user.deployments.all():
deployment.remove_key(self) dep.remove_key(self)
# when a key is removed by a client we try to finally delete it # when a key is removed by a client we try to finally delete it
def try_final_deletion(self): def try_final_deletion(self):
...@@ -326,6 +325,7 @@ class SSHPublicKey(models.Model): ...@@ -326,6 +325,7 @@ class SSHPublicKey(models.Model):
def msg(self, msg): def msg(self, msg):
return '[SSHKey:{}] {}'.format(self, msg) return '[SSHKey:{}] {}'.format(self, msg)
@receiver(post_save, sender=User) @receiver(post_save, sender=User)
def deactivate_user(sender, instance=None, created=False, **kwargs): def deactivate_user(sender, instance=None, created=False, **kwargs):
if created: if created:
...@@ -334,7 +334,6 @@ def deactivate_user(sender, instance=None, created=False, **kwargs): ...@@ -334,7 +334,6 @@ def deactivate_user(sender, instance=None, created=False, **kwargs):
if not instance.is_active and instance.is_active_at_clients: if not instance.is_active and instance.is_active_at_clients:
instance.deactivate() instance.deactivate()
@receiver(post_save, sender=User) @receiver(post_save, sender=User)
def activate_user(sender, instance=None, created=False, **kwargs): def activate_user(sender, instance=None, created=False, **kwargs):
if created: if created:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment