From 96eb5e2c6d17999656602e9924d4411e089910ef Mon Sep 17 00:00:00 2001 From: Dhruv Bansal Date: Fri, 30 Sep 2016 16:02:45 +0000 Subject: [PATCH] Some bug squashing and mild refactoring. --- ldap_attr | 77 +++++++++++++++++++++++++---------------------------- ldap_upsert | 70 +++++++++++++++++++++--------------------------- 2 files changed, 66 insertions(+), 81 deletions(-) diff --git a/ldap_attr b/ldap_attr index ff0693f..ee0f37d 100755 --- a/ldap_attr +++ b/ldap_attr @@ -146,24 +146,36 @@ class LdapAttr(object): def __init__(self, module): self.module = module - # python-ldap doesn't understand unicode strings. Parameters that are - # just going to get passed to python-ldap APIs are stored as utf-8. - self.dn = self._utf8_param('dn') - self.name = self._utf8_param('name') - self.values = self._normalized_values() - self.state = self.module.params['state'] + # Parameters that we have to directly pass to python-ldap need + # to converted to UTF-8 first, as python-ldap doesn't + # understand unicode strings. + + # Server parameters self.server_uri = self.module.params['server_uri'] - self.start_tls = self.module.boolean(self.module.params['start_tls']) - self.bind_dn = self._utf8_param('bind_dn') - self.bind_pw = self._utf8_param('bind_pw') + self.start_tls = self.module.boolean(self.module.params['start_tls']) + self.bind_dn = self._utf8_param('bind_dn') + self.bind_pw = self._utf8_param('bind_pw') + + # Attribute parameters + self.dn = self._utf8_param('dn') + self.name = self._utf8_param('name') + self.values = self._normalized_values() + self.state = self.module.params['state'] self._connection = None + def _force_utf8(self, value): + """If value is Unicode, encode to UTF-8.""" + if isinstance(value, unicode): + return value.encode('utf-8') + return value + def _utf8_param(self, name): + """Extract a parameter as UTF-8.""" return self._force_utf8(self.module.params[name]) def _normalized_values(self): - """ Parses the value parameter into a list of utf-8 strings. """ + """Parses the 'values' parameter into a list of UTF-8 strings.""" values = self.module.params['values'] if isinstance(values, basestring): @@ -177,13 +189,6 @@ class LdapAttr(object): return map(self._force_utf8, values) - def _force_utf8(self, value): - """ If value is unicode, encode to utf-8. """ - if isinstance(value, unicode): - value = value.encode('utf-8') - - return value - def main(self): if self.state == 'present': modlist = self.handle_present() @@ -210,57 +215,47 @@ class LdapAttr(object): def handle_present(self): values_to_add = filter(self.is_value_absent, self.values) if len(values_to_add) > 0: - modlist = [(ldap.MOD_ADD, self.name, values_to_add)] + return [(ldap.MOD_ADD, self.name, values_to_add)] else: - modlist = [] - - return modlist + return [] def handle_absent(self): values_to_delete = filter(self.is_value_present, self.values) if len(values_to_delete) > 0: - modlist = [(ldap.MOD_DELETE, self.name, values_to_delete)] + return [(ldap.MOD_DELETE, self.name, values_to_delete)] else: - modlist = [] - - return modlist + return [] def handle_exact(self): - modlist = [] - current = self.current_values() if frozenset(self.values) != frozenset(current): if len(current) == 0: - modlist = [(ldap.MOD_ADD, self.name, self.values)] + return [(ldap.MOD_ADD, self.name, self.values)] elif len(self.values) == 0: - modlist = [(ldap.MOD_DELETE, self.name, None)] + return [(ldap.MOD_DELETE, self.name, None)] else: - modlist = [(ldap.MOD_REPLACE, self.name, self.values)] - - return modlist + return [(ldap.MOD_REPLACE, self.name, self.values)] + return [] # # Util # def is_value_present(self, value): - """ True if the target attribute has the given value. """ + """True if the target attribute has the given value.""" try: - is_present = bool(self.connection.compare_s(self.dn, self.name, value)) + return bool(self.connection.compare_s(self.dn, self.name, value)) except ldap.NO_SUCH_ATTRIBUTE: - is_present = False - - return is_present + return False def is_value_absent(self, value): - """ True if the target attribute does not have the given value. """ + """True if the target attribute does not have the given value.""" return (not self.is_value_present(value)) def current_values(self): - """ Returns the full list of values on the target attribute. """ + """Returns the full list of values on the target attribute.""" results = self.connection.search_s(self.dn, ldap.SCOPE_BASE, attrlist=[self.name]) - values = results[0][1].get(self.name, []) - + values = results[0][1].get(self.name, []) return values # diff --git a/ldap_upsert b/ldap_upsert index 7806073..1b72978 100755 --- a/ldap_upsert +++ b/ldap_upsert @@ -104,23 +104,39 @@ class LdapUpsert(object): def __init__(self, module): self.module = module + # Parameters that we have to directly pass to python-ldap need + # to converted to UTF-8 first, as python-ldap doesn't + # understand unicode strings. + + # Server parameters self.server_uri = self.module.params['server_uri'] - self.start_tls = self._boolean_param('start_tls') + self.start_tls = self.module.boolean(self.module.params['start_tls']) self.bind_dn = self._utf8_param('bind_dn') self.bind_pw = self._utf8_param('bind_pw') - + + # Entry parameters self.dn = self._utf8_param('dn') self._load_attrs() if 'objectClass' not in self.attrs: self.module.fail_json(msg="At least one objectClass must be provided") - def _boolean_param(self, name): - return self.module.boolean(self.module.params[name]) - + def _force_utf8(self, value): + """If value is Unicode, encode to UTF-8.""" + if isinstance(value, unicode): + return value.encode('utf-8') + return value + def _utf8_param(self, name): + """Extract a parameter as UTF-8.""" return self._force_utf8(self.module.params[name]) + def _load_attrs(self): + self.attrs = {} + for name, raw in self.module.params.iteritems(): + if name not in self.module.argument_spec: + self.attrs[name] = self._load_attr_values(name, raw) + def _load_attr_values(self, name, raw): if isinstance(raw, basestring): values = raw.split(',') @@ -132,19 +148,7 @@ class LdapUpsert(object): return map(self._force_utf8, values) - def _force_utf8(self, value): - """ If value is unicode, encode to utf-8. """ - if isinstance(value, unicode): - value = value.encode('utf-8') - - return value - def _load_attrs(self): - self.attrs = {} - for name, raw in self.module.params.iteritems(): - if name not in self.module.argument_spec: - self.attrs[name] = self._load_attr_values(name, raw) - def main(self): if self.entry_exists(): results = self.update_entry() @@ -169,10 +173,9 @@ class LdapUpsert(object): def update_entry(self): results = [] - for attr, value in self.attrs.iteritems(): + for attr, values in self.attrs.iteritems(): if attr == 'objectClass': continue - value = self._extract_value(value) - check = self._attribute_value_check(attr, value) + check = self._attribute_values_check(attr, values) if check is False: op = ldap.MOD_REPLACE elif check is None: @@ -180,34 +183,21 @@ class LdapUpsert(object): else: op = None # Nothing to see here... if op is not None: - result = self.connection.modify_s(self.dn, [(op, attr, value)]) + result = self.connection.modify_s(self.dn, [(op, attr, values)]) results.append(result) if len(results) == 0: return dict(changed=False) else: return dict(changed=True, results=results) - def _attribute_value_check(self, attr, value): + def _attribute_values_check(self, attr, values): try: - return bool(self.connection.compare_s(self.dn, attr, value)) - except ldap.NO_SUCH_ATTRIBUTE, ldap.UNDEFINED_TYPE: + return all(self._attribute_value_check(attr, value) for value in values) + except ldap.NO_SUCH_ATTRIBUTE: return None - - def _extract_value(self, values): - if isinstance(values, basestring): - if values == '': - values = [] - else: - values = [values] - - if not (isinstance(values, list) and all(isinstance(value, basestring) for value in values)): - self.module.fail_json(msg="Attribute values must be strings or lists of strings.") - - values = map(self._force_utf8, values) - if len(values) == 1: - return values[0] - else: - return values + + def _attribute_value_check(self, attr, value): + return bool(self.connection.compare_s(self.dn, attr, value)) # # LDAP Connection