Commit 9c6d53d2 authored by benjamin.ertl's avatar benjamin.ertl
Browse files

oidc + scim user info update

parent 3bceb177
...@@ -20,16 +20,18 @@ import org.springframework.stereotype.Component; ...@@ -20,16 +20,18 @@ import org.springframework.stereotype.Component;
import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.oauth2.sdk.token.AccessToken; import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens; import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import edu.kit.scc.oidc.OidcClient; import edu.kit.scc.oidc.OidcClient;
import edu.kit.scc.scim.ScimClient; import edu.kit.scc.scim.ScimClient;
import edu.kit.scc.scim.ScimUser; import edu.kit.scc.scim.ScimUser;
import edu.kit.scc.scim.ScimUserAttributeMapper;
@Component @Component
public class Harmonizer { public class IdentityHarmonizer {
private static final Logger log = LoggerFactory.getLogger(Harmonizer.class); private static final Logger log = LoggerFactory.getLogger(IdentityHarmonizer.class);
@Autowired @Autowired
private ScimClient scimClient; private ScimClient scimClient;
...@@ -39,12 +41,11 @@ public class Harmonizer { ...@@ -39,12 +41,11 @@ public class Harmonizer {
public ScimUser harmonizeIdentities(String subject, OIDCTokens tokens) { public ScimUser harmonizeIdentities(String subject, OIDCTokens tokens) {
ScimUser scimUser = new ScimUser(); ScimUser scimUser = new ScimUser();
scimUser.setSchemas(Arrays.asList(scimUser.USER_SCHEMA));
scimUser.setUserName(subject); scimUser.setUserName(subject);
// OIDC // OIDC
log.debug("Try to get OIDC user information"); log.debug("Try to get OIDC user information");
JSONObject userInfo = null; UserInfo userInfo = null;
if (tokens != null) { if (tokens != null) {
try { try {
JWT jwt = tokens.getIDToken(); JWT jwt = tokens.getIDToken();
...@@ -54,9 +55,16 @@ public class Harmonizer { ...@@ -54,9 +55,16 @@ public class Harmonizer {
AccessToken accessToken = tokens.getAccessToken(); AccessToken accessToken = tokens.getAccessToken();
userInfo = oidcClient.requestUserInfo(accessToken.getValue()); userInfo = oidcClient.requestUserInfo(accessToken.getValue(), claimsSet);
log.debug("User info {}", userInfo.toString()); if (userInfo != null) {
log.debug("User info {}", userInfo.toJSONObject().toJSONString());
ScimUserAttributeMapper attributeMapper = new ScimUserAttributeMapper();
scimUser = attributeMapper.mapFromUserInfo(userInfo);
scimUser.setSchemas(Arrays.asList(scimUser.USER_SCHEMA));
}
} catch (ParseException e) { } catch (ParseException e) {
log.error(e.getMessage()); log.error(e.getMessage());
...@@ -65,12 +73,19 @@ public class Harmonizer { ...@@ -65,12 +73,19 @@ public class Harmonizer {
// SCIM // SCIM
log.debug("Try to get SCIM user information"); log.debug("Try to get SCIM user information");
JSONObject userJson = scimClient.getUser(subject); JSONObject userJson = scimClient.getUser(scimUser.getUserName());
log.debug("SCIM user info {}", userJson.toString()); if (userJson != null) {
log.debug("SCIM user info {}", userJson.toString());
// TODO merge with SCIM user
}
// LDAP // LDAP
// TODO // TODO
// REGAPP
// TODO
log.debug("Aggregated SCIM user information {}", scimUser.toString()); log.debug("Aggregated SCIM user information {}", scimUser.toString());
return scimUser; return scimUser;
} }
......
...@@ -50,7 +50,7 @@ public class RestServiceController { ...@@ -50,7 +50,7 @@ public class RestServiceController {
private OidcClient oidcClient; private OidcClient oidcClient;
@Autowired @Autowired
private Harmonizer identityHarmonizer; private IdentityHarmonizer identityHarmonizer;
// expected body e.g. // expected body e.g.
// password=password // password=password
......
...@@ -185,24 +185,24 @@ public class HttpClient { ...@@ -185,24 +185,24 @@ public class HttpClient {
} catch (IOException e) { } catch (IOException e) {
// e.printStackTrace(); // e.printStackTrace();
log.error(e.getMessage()); log.error("ERROR {}", e.getMessage());
} catch (Exception e) { } catch (Exception e) {
// e.printStackTrace(); // e.printStackTrace();
log.error(e.getMessage()); log.error("ERROR {}", e.getMessage());
} finally { } finally {
if (in != null) if (in != null)
try { try {
in.close(); in.close();
} catch (IOException e) { } catch (IOException e) {
// e.printStackTrace(); // e.printStackTrace();
log.error(e.getMessage()); log.error("ERROR {}", e.getMessage());
} }
if (out != null) if (out != null)
try { try {
out.close(); out.close();
} catch (IOException e) { } catch (IOException e) {
// e.printStackTrace(); // e.printStackTrace();
log.error(e.getMessage()); log.error("ERROR {}", e.getMessage());
} }
} }
return response; return response;
......
...@@ -18,7 +18,7 @@ import org.springframework.ldap.core.AttributesMapper; ...@@ -18,7 +18,7 @@ import org.springframework.ldap.core.AttributesMapper;
import edu.kit.scc.dto.GroupDTO; import edu.kit.scc.dto.GroupDTO;
public class GroupAttributeMapper implements AttributesMapper<GroupDTO> { public class LdapGroupAttributeMapper implements AttributesMapper<GroupDTO> {
@Override @Override
public GroupDTO mapFromAttributes(Attributes attributes) throws NamingException { public GroupDTO mapFromAttributes(Attributes attributes) throws NamingException {
......
...@@ -45,7 +45,7 @@ public class LdapGroupDAO implements GroupDAO { ...@@ -45,7 +45,7 @@ public class LdapGroupDAO implements GroupDAO {
@Override @Override
public List<GroupDTO> getAllGroups() { public List<GroupDTO> getAllGroups() {
return ldapTemplate.search(groupBase, "(objectclass=posixGroup)", new GroupAttributeMapper()); return ldapTemplate.search(groupBase, "(objectclass=posixGroup)", new LdapGroupAttributeMapper());
} }
...@@ -55,7 +55,7 @@ public class LdapGroupDAO implements GroupDAO { ...@@ -55,7 +55,7 @@ public class LdapGroupDAO implements GroupDAO {
andFilter.and(new EqualsFilter("objectclass", "posixGroup")).and(new EqualsFilter("cn", commonName)); andFilter.and(new EqualsFilter("objectclass", "posixGroup")).and(new EqualsFilter("cn", commonName));
log.debug("LDAP query {}", andFilter.encode()); log.debug("LDAP query {}", andFilter.encode());
return ldapTemplate.search("", andFilter.encode(), new GroupAttributeMapper()); return ldapTemplate.search("", andFilter.encode(), new LdapGroupAttributeMapper());
} }
@Override @Override
......
...@@ -16,7 +16,7 @@ import org.springframework.ldap.core.AttributesMapper; ...@@ -16,7 +16,7 @@ import org.springframework.ldap.core.AttributesMapper;
import edu.kit.scc.dto.UserDTO; import edu.kit.scc.dto.UserDTO;
public class UserAttributeMapper implements AttributesMapper<UserDTO> { public class LdapUserAttributeMapper implements AttributesMapper<UserDTO> {
@Override @Override
public UserDTO mapFromAttributes(Attributes attributes) throws NamingException { public UserDTO mapFromAttributes(Attributes attributes) throws NamingException {
......
...@@ -44,7 +44,7 @@ public class LdapUserDAO implements UserDAO { ...@@ -44,7 +44,7 @@ public class LdapUserDAO implements UserDAO {
@Override @Override
public List<UserDTO> getAllUsers() { public List<UserDTO> getAllUsers() {
return ldapTemplate.search(userBase, "(objectclass=posixAccount)", new UserAttributeMapper()); return ldapTemplate.search(userBase, "(objectclass=posixAccount)", new LdapUserAttributeMapper());
} }
@Override @Override
...@@ -53,7 +53,7 @@ public class LdapUserDAO implements UserDAO { ...@@ -53,7 +53,7 @@ public class LdapUserDAO implements UserDAO {
andFilter.and(new EqualsFilter("objectclass", "posixAccount")).and(new EqualsFilter("uid", uid)); andFilter.and(new EqualsFilter("objectclass", "posixAccount")).and(new EqualsFilter("uid", uid));
log.debug("LDAP query {}", andFilter.encode()); log.debug("LDAP query {}", andFilter.encode());
return ldapTemplate.search("", andFilter.encode(), new UserAttributeMapper()); return ldapTemplate.search("", andFilter.encode(), new LdapUserAttributeMapper());
} }
@Override @Override
......
...@@ -21,6 +21,7 @@ import org.slf4j.LoggerFactory; ...@@ -21,6 +21,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.oauth2.sdk.AuthorizationCode; import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant; import com.nimbusds.oauth2.sdk.AuthorizationGrant;
...@@ -40,7 +41,11 @@ import com.nimbusds.oauth2.sdk.token.BearerAccessToken; ...@@ -40,7 +41,11 @@ import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.oauth2.sdk.token.Tokens; import com.nimbusds.oauth2.sdk.token.Tokens;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse; import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser; import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser;
import com.nimbusds.openid.connect.sdk.UserInfoErrorResponse;
import com.nimbusds.openid.connect.sdk.UserInfoRequest; import com.nimbusds.openid.connect.sdk.UserInfoRequest;
import com.nimbusds.openid.connect.sdk.UserInfoResponse;
import com.nimbusds.openid.connect.sdk.UserInfoSuccessResponse;
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens; import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import edu.kit.scc.http.CustomSSLContext; import edu.kit.scc.http.CustomSSLContext;
...@@ -82,9 +87,8 @@ public class OidcClient { ...@@ -82,9 +87,8 @@ public class OidcClient {
* @return a {@link JSONObject} with the OIDC user information * @return a {@link JSONObject} with the OIDC user information
*/ */
@SuppressWarnings("static-access") @SuppressWarnings("static-access")
public JSONObject requestUserInfo(String accessToken) { public UserInfo requestUserInfo(String accessToken, JWTClaimsSet claimsSet) {
JSONObject userInfoResponse = null; UserInfo userInfo = null;
try { try {
AccessToken token = AccessToken.parse("Bearer " + accessToken); AccessToken token = AccessToken.parse("Bearer " + accessToken);
...@@ -96,38 +100,47 @@ public class OidcClient { ...@@ -96,38 +100,47 @@ public class OidcClient {
httpRequest.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); httpRequest.setDefaultSSLSocketFactory(sslContext.getSocketFactory());
HTTPResponse response = null; HTTPResponse response = null;
// DEBUG
logHttpRequest(httpRequest);
response = request.toHTTPRequest().send(); response = request.toHTTPRequest().send();
log.debug(response.getContentAsJSONObject().toJSONString()); // DEBUG
logHttpResponse(response);
return new JSONObject(response.getContentAsJSONObject().toJSONString());
net.minidev.json.JSONObject jsonResponse = response.getContentAsJSONObject();
// userInfoResponse = UserInfoResponse.parse(response); jsonResponse.put("sub", claimsSet.getSubject());
//
// if (userInfoResponse instanceof UserInfoErrorResponse) { response.setContent(jsonResponse.toJSONString());
// ErrorObject error = ((UserInfoErrorResponse)
// userInfoResponse).getErrorObject(); UserInfoResponse userInfoResponse = UserInfoResponse.parse(response);
// System.out.println("ERROR " + error.getDescription());
// return null; if (userInfoResponse instanceof UserInfoErrorResponse) {
// } UserInfoErrorResponse errorResponse = (UserInfoErrorResponse) userInfoResponse;
//
// UserInfoSuccessResponse successResponse = ErrorObject error = ((UserInfoErrorResponse) errorResponse).getErrorObject();
// (UserInfoSuccessResponse) userInfoResponse; log.warn("ERROR HTTP {} code {}", error.getHTTPStatusCode(), error.getCode());
// String claims = log.warn("ERROR " + error.getDescription());
// successResponse.getUserInfo().toJSONObject().toJSONString(); return null;
// }
// System.out.println(claims);
// UserInfoSuccessResponse successResponse = (UserInfoSuccessResponse) userInfoResponse;
// return successResponse;
userInfo = successResponse.getUserInfo();
log.debug(userInfo.toJSONObject().toJSONString());
return userInfo;
} catch (ParseException e) { } catch (ParseException e) {
e.printStackTrace(); log.error("ERROR {}", e.getMessage());
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
e.printStackTrace(); log.error("ERROR {}", e.getMessage());
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("ERROR {}", e.getMessage());
} }
return userInfoResponse;
return userInfo;
} }
/** /**
...@@ -169,19 +182,17 @@ public class OidcClient { ...@@ -169,19 +182,17 @@ public class OidcClient {
httpRequest.setDefaultHostnameVerifier(new NullHostNameVerifier()); httpRequest.setDefaultHostnameVerifier(new NullHostNameVerifier());
httpRequest.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); httpRequest.setDefaultSSLSocketFactory(sslContext.getSocketFactory());
log.debug("------HTTP REQUEST DEBUG------"); // DEBUG
for (Entry<String, String> e : httpRequest.getHeaders().entrySet()) logHttpRequest(httpRequest);
log.debug("{} {}", e.getKey(), e.getValue());
log.debug("Method {}", httpRequest.getMethod());
log.debug("Query {}", httpRequest.getQuery());
log.debug("Url {}", httpRequest.getURL());
log.debug("------HTTP REQUEST DEBUG------");
httpResponse = httpRequest.send(); httpResponse = httpRequest.send();
TokenResponse response = null; TokenResponse response = null;
response = OIDCTokenResponseParser.parse(httpResponse); response = OIDCTokenResponseParser.parse(httpResponse);
// DEBUG
logHttpResponse(httpResponse);
if (response instanceof TokenErrorResponse) { if (response instanceof TokenErrorResponse) {
TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) response; TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) response;
...@@ -198,15 +209,35 @@ public class OidcClient { ...@@ -198,15 +209,35 @@ public class OidcClient {
tokens = oidcTokenResponse.getOIDCTokens(); tokens = oidcTokenResponse.getOIDCTokens();
log.debug(oidcTokenResponse.getOIDCTokens().toJSONObject().toJSONString()); log.debug(tokens.toJSONObject().toJSONString());
return tokens; return tokens;
} catch (ParseException e) { } catch (ParseException e) {
e.printStackTrace(); log.error("ERROR {}", e.getMessage());
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("ERROR {}", e.getMessage());
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
e.printStackTrace(); log.error("ERROR {}", e.getMessage());
} }
return tokens; return tokens;
} }
private void logHttpRequest(HTTPRequest httpRequest) {
log.debug("------HTTP REQUEST DEBUG------");
for (Entry<String, String> e : httpRequest.getHeaders().entrySet())
log.debug("{} {}", e.getKey(), e.getValue());
log.debug("Method {}", httpRequest.getMethod());
log.debug("Query {}", httpRequest.getQuery());
log.debug("Url {}", httpRequest.getURL());
log.debug("------HTTP REQUEST DEBUG------");
}
private void logHttpResponse(HTTPResponse httpResponse) {
log.debug("------HTTP RESPONSE DEBUG------");
for (Entry<String, String> e : httpResponse.getHeaders().entrySet())
log.debug("{} {}", e.getKey(), e.getValue());
log.debug("Status code {}", httpResponse.getStatusCode());
log.debug("Content {}", httpResponse.getContent());
log.debug("------HTTP RESPONSE DEBUG------");
}
} }
...@@ -52,7 +52,9 @@ public class ScimClient { ...@@ -52,7 +52,9 @@ public class ScimClient {
public JSONObject getUser(String name) { public JSONObject getUser(String name) {
JSONObject json = null; JSONObject json = null;
HttpClient client = new HttpClient(); HttpClient client = new HttpClient();
String url = userEndpoint + "?userNameEq" + name; String url = userEndpoint.replaceAll("/$", "");
url += "?filter=userNameEq" + name;
HttpResponse response = client.makeHttpsGetRequest(user, password, url); HttpResponse response = client.makeHttpsGetRequest(user, password, url);
if (response != null) { if (response != null) {
...@@ -72,7 +74,8 @@ public class ScimClient { ...@@ -72,7 +74,8 @@ public class ScimClient {
public JSONObject getUsers() { public JSONObject getUsers() {
JSONObject json = null; JSONObject json = null;
HttpClient client = new HttpClient(); HttpClient client = new HttpClient();
HttpResponse response = client.makeHttpsGetRequest(user, password, userEndpoint); String url = userEndpoint.replaceAll("/$", "");
HttpResponse response = client.makeHttpsGetRequest(user, password, url);
if (response != null) { if (response != null) {
log.debug(response.toString()); log.debug(response.toString());
...@@ -91,7 +94,8 @@ public class ScimClient { ...@@ -91,7 +94,8 @@ public class ScimClient {
public JSONObject getGroups(String user, String password) { public JSONObject getGroups(String user, String password) {
JSONObject json = null; JSONObject json = null;
HttpClient client = new HttpClient(); HttpClient client = new HttpClient();
HttpResponse response = client.makeHttpsGetRequest(user, password, groupEndpoint); String url = groupEndpoint.replaceAll("/$", "");
HttpResponse response = client.makeHttpsGetRequest(user, password, url);
if (response != null) { if (response != null) {
log.debug(response.toString()); log.debug(response.toString());
......
...@@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; ...@@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonInclude;
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
public class ScimUser { public class ScimUser {
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class Name { public static class Name {
private String formatted, familyName, givenName, middleName, honorificPrefix, honorificSufix; private String formatted, familyName, givenName, middleName, honorificPrefix, honorificSufix;
...@@ -79,6 +80,7 @@ public class ScimUser { ...@@ -79,6 +80,7 @@ public class ScimUser {
} }
} }
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class Email { public static class Email {
private String value, type; private String value, type;
private boolean primary; private boolean primary;
...@@ -114,6 +116,7 @@ public class ScimUser { ...@@ -114,6 +116,7 @@ public class ScimUser {
} }
} }
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class Address { public static class Address {
private String type, streetAddress, locality, region, postalCode, country, formatted; private String type, streetAddress, locality, region, postalCode, country, formatted;
private boolean primary; private boolean primary;
...@@ -194,6 +197,7 @@ public class ScimUser { ...@@ -194,6 +197,7 @@ public class ScimUser {
} }
} }
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class PhoneNumber { public static class PhoneNumber {
private String value, type; private String value, type;
...@@ -220,6 +224,7 @@ public class ScimUser { ...@@ -220,6 +224,7 @@ public class ScimUser {
} }
} }
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class Ims { public static class Ims {
private String value, type; private String value, type;
...@@ -246,6 +251,7 @@ public class ScimUser { ...@@ -246,6 +251,7 @@ public class ScimUser {
} }
} }
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class Photo { public static class Photo {
private String value, type; private String value, type;
...@@ -272,6 +278,7 @@ public class ScimUser { ...@@ -272,6 +278,7 @@ public class ScimUser {
} }
} }
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class Group { public static class Group {
private String value, $ref, display; private String value, $ref, display;
...@@ -306,6 +313,7 @@ public class ScimUser { ...@@ -306,6 +313,7 @@ public class ScimUser {
} }
} }