1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.wss4j.policy.stax.test;
20
import org.apache.neethi.builders.AssertionBuilder;
import org.apache.wss4j.common.crypto.WSProviderConfig;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.wss4j.common.saml.SAMLCallback;
import org.apache.wss4j.common.saml.SamlAssertionWrapper;

import org.apache.wss4j.policy.SPConstants;
import org.apache.wss4j.policy.stax.enforcer.PolicyEnforcer;
import org.apache.wss4j.policy.stax.enforcer.PolicyEnforcerFactory;
import org.apache.wss4j.common.WSSPolicyException;
import org.apache.wss4j.stax.securityToken.WSSecurityTokenConstants;
import org.apache.wss4j.stax.setup.WSSec;
import org.apache.wss4j.stax.impl.securityToken.*;
import org.apache.wss4j.stax.test.AbstractTestBase;
import org.apache.xml.security.binding.xmldsig11.ECKeyValueType;
import org.apache.xml.security.binding.xmldsig11.NamedCurveType;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.stax.config.Init;
import org.apache.xml.security.stax.impl.util.IDGenerator;

import org.junit.jupiter.api.BeforeAll;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

import javax.xml.namespace.QName;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.List;
63
64 public class AbstractPolicyTestBase extends AbstractTestBase {
65
66 @BeforeAll
67 public static void setUp() throws Exception {
68 WSProviderConfig.init();
69 Init.init(WSSec.class.getClassLoader().getResource("wss/wss-config.xml").toURI(), WSSec.class);
70 }
71
72 protected PolicyEnforcer buildAndStartPolicyEngine(String policyString)
73 throws ParserConfigurationException, SAXException, IOException, WSSPolicyException {
74 return this.buildAndStartPolicyEngine(policyString, false);
75 }
76
77 protected PolicyEnforcer buildAndStartPolicyEngine(String policyString, boolean replacePolicyElement)
78 throws ParserConfigurationException, SAXException, IOException, WSSPolicyException {
79 return buildAndStartPolicyEngine(policyString, replacePolicyElement, null);
80 }
81
82 protected PolicyEnforcer buildAndStartPolicyEngine(
83 String policyString, boolean replacePolicyElement, List<AssertionBuilder<Element>> customAssertionBuilders)
84 throws ParserConfigurationException, SAXException, IOException, WSSPolicyException {
85 DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance();
86 documentBuilderFactory.setNamespaceAware(true);
87 documentBuilderFactory.setValidating(false);
88 DocumentBuilder documentBuilder = documentBuilderFactory.newDocumentBuilder();
89 Document document = documentBuilder.parse(
90 this.getClass().getClassLoader().getResourceAsStream("testdata/wsdl/wsdl-template.wsdl"));
91 NodeList nodeList = document.getElementsByTagNameNS("*", SPConstants.P_LOCALNAME);
92
93 Document policyDocument = documentBuilder.parse(new ByteArrayInputStream(policyString.getBytes(StandardCharsets.UTF_8)));
94 Node policyNode = document.importNode(policyDocument.getDocumentElement(), true);
95 Element element = (Element) nodeList.item(0);
96 if (replacePolicyElement) {
97 element.getParentNode().replaceChild(element, policyNode);
98 } else {
99 element.appendChild(policyNode);
100 }
101 PolicyEnforcerFactory policyEnforcerFactory = PolicyEnforcerFactory.newInstance(document, customAssertionBuilders);
102 PolicyEnforcer policyEnforcer = policyEnforcerFactory.newPolicyEnforcer("", false, null, 0, false);
103
104 return policyEnforcer;
105 }
106
107 public X509SecurityTokenImpl getX509Token(WSSecurityTokenConstants.TokenType tokenType) throws Exception {
108 return getX509Token(tokenType, "transmitter");
109 }
110
111 public X509SecurityTokenImpl getX509Token(WSSecurityTokenConstants.TokenType tokenType, final String keyAlias) throws Exception {
112
113 final KeyStore keyStore = KeyStore.getInstance("jks");
114 keyStore.load(this.getClass().getClassLoader().getResourceAsStream("transmitter.jks"), "default".toCharArray());
115
116 X509SecurityTokenImpl x509SecurityToken =
117 new X509SecurityTokenImpl(
118 tokenType, null, null, null, IDGenerator.generateID(null),
119 WSSecurityTokenConstants.KEYIDENTIFIER_THUMBPRINT_IDENTIFIER, null, true) {
120 @Override
121 protected String getAlias() throws XMLSecurityException {
122 return keyAlias;
123 }
124
125 @Override
126 public List<QName> getElementPath() {
127 List<QName> elementPath = super.getElementPath();
128 if (elementPath != null) {
129 return elementPath;
130 }
131 return Collections.emptyList();
132 }
133 };
134 x509SecurityToken.setSecretKey("", keyStore.getKey(keyAlias, "default".toCharArray()));
135 x509SecurityToken.setPublicKey(keyStore.getCertificate(keyAlias).getPublicKey());
136
137 Certificate[] certificates;
138 try {
139 certificates = keyStore.getCertificateChain(keyAlias);
140 } catch (Exception e) {
141 throw new XMLSecurityException(e);
142 }
143
144 X509Certificate[] x509Certificates = new X509Certificate[certificates.length];
145 for (int i = 0; i < certificates.length; i++) {
146 Certificate certificate = certificates[i];
147 x509Certificates[i] = (X509Certificate) certificate;
148 }
149 x509SecurityToken.setX509Certificates(x509Certificates);
150 return x509SecurityToken;
151 }
152
153 public KerberosServiceSecurityTokenImpl getKerberosServiceSecurityToken(WSSecurityTokenConstants.TokenType tokenType) throws Exception {
154 return new KerberosServiceSecurityTokenImpl(
155 null, null, null, null, IDGenerator.generateID(null),
156 WSSecurityTokenConstants.KEYIDENTIFIER_SECURITY_TOKEN_DIRECT_REFERENCE);
157 }
158
159 public HttpsSecurityTokenImpl getHttpsSecurityToken(WSSecurityTokenConstants.TokenType tokenType) throws Exception {
160 return new HttpsSecurityTokenImpl(getX509Token(tokenType).getX509Certificates()[0]);
161 }
162
163 public RsaKeyValueSecurityTokenImpl getRsaKeyValueSecurityToken() throws Exception {
164 return new RsaKeyValueSecurityTokenImpl(null, null, null, null, null);
165 }
166
167 public DsaKeyValueSecurityTokenImpl getDsaKeyValueSecurityToken() throws Exception {
168 return new DsaKeyValueSecurityTokenImpl(null, null, null, null, null);
169 }
170
171 public ECKeyValueSecurityTokenImpl getECKeyValueSecurityToken() throws Exception {
172 ECKeyValueType ecKeyValueType = new ECKeyValueType();
173 ecKeyValueType.setNamedCurve(new NamedCurveType());
174 return new ECKeyValueSecurityTokenImpl(ecKeyValueType, null, null, null, null);
175 }
176
177 protected String loadResourceAsString(String resource, Charset encoding) throws IOException {
178 InputStreamReader inputStreamReader = new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream(resource), encoding);
179 StringBuilder stringBuilder = new StringBuilder();
180 int read = 0;
181 char[] buffer = new char[1024];
182 while ((read = inputStreamReader.read(buffer)) != -1) {
183 stringBuilder.append(buffer, 0, read);
184 }
185 return stringBuilder.toString();
186 }
187
188 public static SamlAssertionWrapper createSamlAssertionWrapper(SAMLCallback samlCallback) throws WSSecurityException {
189 return new SamlAssertionWrapper(samlCallback);
190 }
191 }