Commit 2e43db90 authored by Lukas Burgey's avatar Lukas Burgey

Rework the userinfo updating

Ignores groups, when entitlements are present. Closes #22.

Closes #20
parent 56651d03
......@@ -267,10 +267,8 @@ class User(AbstractUser):
if state_item.state != 'not_deployed':
sites.append(state_item.site)
# return True if the userinfo contains entitlements (we ignore groups then)
def update_userinfo_entitlements(self, userinfo):
if self.idp.userinfo_field_entitlements is None:
return
local_entitlements = self.vos.instance_of(Entitlement)
remote_entitlements = [
Entitlement.extract_name(name)
......@@ -281,31 +279,28 @@ class User(AbstractUser):
# check if local_entitlements were removed
for loc_ent in local_entitlements:
if loc_ent.name not in remote_entitlements:
self.vos.remove(loc_ent)
self.user_changed_vo_removed(loc_ent)
self._remove_vo(loc_ent)
for rem_ent_name in remote_entitlements:
ent = Entitlement.get_entitlement(name=rem_ent_name, idp=self.idp)
# check if user needs to be in this entitlement
if not self.vos.filter(name=rem_ent_name, idp=self.idp).exists():
self.vos.add(ent)
self.user_changed_vo_added(ent)
self._add_vo(ent)
def update_userinfo_groups(self, userinfo):
if self.idp.userinfo_field_groups is None:
return
return len(remote_entitlements) > 0
def update_userinfo_groups(self, userinfo, ignore_groups=False):
local_groups = self.vos.instance_of(Group)
remote_groups = userinfo.get(self.idp.userinfo_field_groups, [])
remote_groups = []
if not ignore_groups:
remote_groups = userinfo.get(self.idp.userinfo_field_groups, [])
# check if groups were removed
for group in local_groups:
if group.name not in remote_groups:
self.vos.remove(group)
self.user_changed_vo_removed(group)
self._remove_vo(group)
# check if groups were added
for group_name in remote_groups:
......@@ -313,8 +308,7 @@ class User(AbstractUser):
# check if user needs to be in this group
if not self.vos.filter(name=group_name, idp=self.idp).exists():
self.user_changed_vo_added(group)
self.vos.add(group)
self._add_vo(group)
def update_userinfo_ssh_key(self, userinfo):
......@@ -328,29 +322,20 @@ class User(AbstractUser):
# is the idp key still present?
if idp_key_name not in userinfo:
self.user_remove_key(key)
self._remove_key(key)
return True
return
# is the idp key changed?
if key.key != unity_key_value:
self.user_remove_key(key)
self._remove_key(key)
new_key = SSHPublicKey(
name=unity_key_name,
key=unity_key_value,
user=self,
)
new_key.save()
self.user_changed_key_added(new_key)
return True
return False
# causes creation of a new key
raise SSHPublicKey.DoesNotExist
except SSHPublicKey.DoesNotExist:
if idp_key_name not in userinfo:
return False
return
key = SSHPublicKey(
name=unity_key_name,
......@@ -359,9 +344,7 @@ class User(AbstractUser):
)
key.save()
self.user_changed_key_added(key)
return True
self._add_key(key)
def update_userinfo(self, userinfo):
......@@ -376,29 +359,21 @@ class User(AbstractUser):
self.userinfo = userinfo
self.save()
changed = False
self.update_userinfo_groups(userinfo)
self.update_userinfo_entitlements(userinfo)
if self.update_userinfo_ssh_key(userinfo):
changed = True
if changed:
self.user_changed()
ignore_groups = self.update_userinfo_entitlements(userinfo)
self.update_userinfo_groups(userinfo, ignore_groups=ignore_groups)
def user_changed(self):
LOGGER.info('user_changed')
self.update_userinfo_ssh_key(userinfo)
def user_changed_key_added(self, key):
LOGGER.debug(self.msg('Added: {}'.format(key)))
def _add_key(self, key):
LOGGER.debug(self.msg('Key added: {}'.format(key)))
from . import deployments
for dep in deployments.get_deployment(self):
dep.user_credential_added(key)
def user_remove_key(self, key):
LOGGER.debug(self.msg('Remove: {}'.format(key)))
def _remove_key(self, key):
LOGGER.debug(self.msg('Key remove: {}'.format(key)))
if key.delete_key():
return
......@@ -407,25 +382,27 @@ class User(AbstractUser):
for dep in deployments.get_deployment(self):
dep.user_credential_removed(key)
def user_changed_vo_added(self, vo):
LOGGER.debug(self.msg('Added: {}'.format(vo)))
def _add_vo(self, vo):
self.vos.add(vo)
LOGGER.debug(self.msg('VO added: {}'.format(vo)))
# check if the user has deactivated deployments for this exact vo
# if yes: reactivate the deployments
# TODO this does nothing for ServiceDeployments
for dep in self.deployments.filter(vodeployment__vo=vo):
LOGGER.debug('user_changed_vo_added: need to activate deployment %s', dep)
LOGGER.debug('_add_vo: need to activate deployment %s', dep)
def user_changed_vo_removed(self, vo):
LOGGER.debug(self.msg('Removed: {}'.format(vo)))
def _remove_vo(self, vo):
self.vos.remove(vo)
LOGGER.debug(self.msg('VO removed: {}'.format(vo)))
# TODO this does nothing for ServiceDeployments
# check if the user has deployments which need member ship of this vo
# if yes remove them
for dep in self.deployments.filter(vodeployment__vo=vo):
LOGGER.debug('user_changed_vo_removed: need to deactivate deployment %s', dep)
LOGGER.debug('_remove_vo: need to deactivate deployment %s', dep)
class SSHPublicKey(models.Model):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment