1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.wss4j.common.util;
20
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import javax.xml.XMLConstants;
import javax.xml.transform.Source;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.sax.SAXSource;
import javax.xml.transform.stream.StreamResult;
import javax.xml.transform.stream.StreamSource;

import org.w3c.dom.Attr;
import org.w3c.dom.CDATASection;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NamedNodeMap;
import org.w3c.dom.Node;
import org.w3c.dom.Text;
import org.xml.sax.InputSource;
48
49 public final class XMLUtils {
50
51 public static final String XMLNS_NS = "http://www.w3.org/2000/xmlns/";
52 public static final String XML_NS = "http://www.w3.org/XML/1998/namespace";
53 public static final String WSU_NS =
54 "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd";
55
56 private static final org.slf4j.Logger LOG =
57 org.slf4j.LoggerFactory.getLogger(XMLUtils.class);
58
59 private XMLUtils() {
60
61 }
62
63
64
65
66
67
68
69
70
71 public static Element getDirectChildElement(Node parentNode, String localName, String namespace) {
72 if (parentNode == null) {
73 return null;
74 }
75 for (Node currentChild = parentNode.getFirstChild();
76 currentChild != null;
77 currentChild = currentChild.getNextSibling()
78 ) {
79 if (Node.ELEMENT_NODE == currentChild.getNodeType()
80 && localName.equals(currentChild.getLocalName())
81 && namespace.equals(currentChild.getNamespaceURI())) {
82 return (Element) currentChild;
83 }
84 }
85 return null;
86 }
87
88
89
90
91 public static String getElementText(Element e) {
92 if (e != null) {
93 Node node = e.getFirstChild();
94 StringBuilder builder = new StringBuilder();
95 boolean found = false;
96 while (node != null) {
97 if (Node.TEXT_NODE == node.getNodeType()) {
98 found = true;
99 builder.append(((Text)node).getData());
100 } else if (Node.CDATA_SECTION_NODE == node.getNodeType()) {
101 found = true;
102 builder.append(((CDATASection)node).getData());
103 }
104 node = node.getNextSibling();
105 }
106
107 if (!found) {
108 return null;
109 }
110 return builder.toString();
111 }
112 return null;
113 }
114
115 public static String getNamespace(String prefix, Node e) {
116 while (e != null && e.getNodeType() == Node.ELEMENT_NODE) {
117 Attr attr = null;
118 if (prefix == null) {
119 attr = ((Element) e).getAttributeNode("xmlns");
120 } else {
121 attr = ((Element) e).getAttributeNodeNS(XMLNS_NS, prefix);
122 }
123 if (attr != null) {
124 return attr.getValue();
125 }
126 e = e.getParentNode();
127 }
128 return null;
129 }
130
131 public static String prettyDocumentToString(Document doc) throws IOException, TransformerException {
132 try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
133 elementToStream(doc.getDocumentElement(), baos);
134 return new String(baos.toByteArray(), StandardCharsets.UTF_8);
135 }
136 }
137
138 public static void elementToStream(Element element, OutputStream out)
139 throws TransformerException {
140 DOMSource source = new DOMSource(element);
141 StreamResult result = new StreamResult(out);
142
143 TransformerFactory transFactory = TransformerFactory.newInstance();
144 transFactory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
145 try {
146 transFactory.setAttribute(XMLConstants.ACCESS_EXTERNAL_DTD, "");
147 transFactory.setAttribute(XMLConstants.ACCESS_EXTERNAL_STYLESHEET, "");
148 } catch (IllegalArgumentException ex) {
149
150 }
151
152 Transformer transformer = transFactory.newTransformer();
153 transformer.transform(source, result);
154 }
155
156
157
158
159
160
161 public static InputSource sourceToInputSource(Source source) throws IOException, TransformerException {
162 if (source instanceof SAXSource) {
163 return ((SAXSource) source).getInputSource();
164 } else if (source instanceof DOMSource) {
165 Node node = ((DOMSource) source).getNode();
166 if (node instanceof Document) {
167 node = ((Document) node).getDocumentElement();
168 }
169 Element domElement = (Element) node;
170 try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
171 elementToStream(domElement, baos);
172 InputSource isource = new InputSource(source.getSystemId());
173 isource.setByteStream(new ByteArrayInputStream(baos.toByteArray()));
174 return isource;
175 }
176 } else if (source instanceof StreamSource) {
177 StreamSource ss = (StreamSource) source;
178 InputSource isource = new InputSource(ss.getSystemId());
179 isource.setByteStream(ss.getInputStream());
180 isource.setCharacterStream(ss.getReader());
181 isource.setPublicId(ss.getPublicId());
182 return isource;
183 } else {
184 return getInputSourceFromURI(source.getSystemId());
185 }
186 }
187
188
189
190
191
192
193
194
195 public static InputSource getInputSourceFromURI(String uri) {
196 return new InputSource(uri);
197 }
198
199
200
201
202
203
204
205
206
207
208
209
210 public static String setNamespace(Element element, String namespace, String prefix) {
211 String pre = getPrefixNS(namespace, element);
212 if (pre != null) {
213 return pre;
214 }
215 element.setAttributeNS(XMLNS_NS, "xmlns:" + prefix, namespace);
216 return prefix;
217 }
218
219 public static String getPrefixNS(String uri, Node e) {
220 while (e != null && e.getNodeType() == Element.ELEMENT_NODE) {
221 NamedNodeMap attrs = e.getAttributes();
222 int length = attrs.getLength();
223 for (int n = 0; n < length; n++) {
224 Attr a = (Attr) attrs.item(n);
225 String name = a.getName();
226 if (name.startsWith("xmlns:") && a.getNodeValue().equals(uri)) {
227 return name.substring("xmlns:".length());
228 }
229 }
230 e = e.getParentNode();
231 }
232 return null;
233 }
234
235
236
237
238
239
240
241
242 public static String getIDFromReference(String ref) {
243 if (ref == null) {
244 return null;
245 }
246 String id = ref.trim();
247 if (id.length() == 0) {
248 return null;
249 }
250 if (id.charAt(0) == '#') {
251 id = id.substring(1);
252 }
253 return id;
254 }
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273 public static Element findElementById(
274 Node startNode, String value, boolean checkMultipleElements
275 ) {
276
277
278
279 if (startNode == null) {
280 return null;
281 }
282 Node startParent = startNode.getParentNode();
283 Node processedNode = null;
284 Element foundElement = null;
285 String id = XMLUtils.getIDFromReference(value);
286
287 while (startNode != null && id != null) {
288
289 if (startNode.getNodeType() == Node.ELEMENT_NODE) {
290 Element se = (Element) startNode;
291
292 String attributeNS = se.getAttributeNS(WSU_NS, "Id");
293 if (attributeNS.length() == 0 || !id.equals(attributeNS)) {
294 attributeNS = se.getAttributeNS(null, "Id");
295 }
296 if (attributeNS.length() != 0 && id.equals(attributeNS)) {
297 if (!checkMultipleElements) {
298 return se;
299 } else if (foundElement == null) {
300 foundElement = se;
301 } else {
302 LOG.warn("Multiple elements with the same 'Id' attribute value!");
303 return null;
304 }
305 }
306 }
307
308 processedNode = startNode;
309 startNode = startNode.getFirstChild();
310
311
312 if (startNode == null) {
313
314 startNode = processedNode.getNextSibling();
315 }
316
317
318 while (startNode == null) {
319 processedNode = processedNode.getParentNode();
320 if (processedNode == startParent) {
321 return foundElement;
322 }
323
324 startNode = processedNode.getNextSibling();
325 }
326 }
327 return foundElement;
328 }
329
330
331
332
333
334
335
336
337
338
339
340
341
342 public static Element findElement(Node startNode, String name, String namespace) {
343
344
345
346
347 if (startNode == null) {
348 return null;
349 }
350 Node startParent = startNode.getParentNode();
351 Node processedNode = null;
352
353 while (startNode != null) {
354
355 if (startNode.getNodeType() == Node.ELEMENT_NODE
356 && startNode.getLocalName().equals(name)) {
357 String ns = startNode.getNamespaceURI();
358 if (ns != null && ns.equals(namespace)) {
359 return (Element)startNode;
360 }
361
362 if ((namespace == null || namespace.length() == 0)
363 && (ns == null || ns.length() == 0)) {
364 return (Element)startNode;
365 }
366 }
367 processedNode = startNode;
368 startNode = startNode.getFirstChild();
369
370
371 if (startNode == null) {
372
373 startNode = processedNode.getNextSibling();
374 }
375
376
377 while (startNode == null) {
378 processedNode = processedNode.getParentNode();
379 if (processedNode == startParent) {
380 return null;
381 }
382
383 startNode = processedNode.getNextSibling();
384 }
385 }
386 return null;
387 }
388
389
390
391
392
393
394
395
396
397
398
399
400 public static List<Element> findElements(Node startNode, String name, String namespace) {
401
402
403
404
405 if (startNode == null) {
406 return Collections.emptyList();
407 }
408 Node startParent = startNode.getParentNode();
409 Node processedNode = null;
410
411 List<Element> foundNodes = new ArrayList<>();
412 while (startNode != null) {
413
414 if (startNode.getNodeType() == Node.ELEMENT_NODE
415 && startNode.getLocalName().equals(name)) {
416 String ns = startNode.getNamespaceURI();
417 if (ns != null && ns.equals(namespace)) {
418 foundNodes.add((Element)startNode);
419 }
420
421 if ((namespace == null || namespace.length() == 0)
422 && (ns == null || ns.length() == 0)) {
423 foundNodes.add((Element)startNode);
424 }
425 }
426 processedNode = startNode;
427 startNode = startNode.getFirstChild();
428
429
430 if (startNode == null) {
431
432 startNode = processedNode.getNextSibling();
433 }
434
435
436 while (startNode == null) {
437 processedNode = processedNode.getParentNode();
438 if (processedNode == startParent) {
439 return foundNodes;
440 }
441
442 startNode = processedNode.getNextSibling();
443 }
444 }
445 return foundNodes;
446 }
447
448
449
450
451
452
453
454
455
456
457 public static Element findSAMLAssertionElementById(Node startNode, String value) {
458 Element foundElement = null;
459
460
461
462
463
464 if (startNode == null || value == null) {
465 return null;
466 }
467 Node startParent = startNode.getParentNode();
468 Node processedNode = null;
469
470 while (startNode != null) {
471
472 if (startNode.getNodeType() == Node.ELEMENT_NODE) {
473 Element se = (Element) startNode;
474 if (se.hasAttributeNS(null, "ID") && value.equals(se.getAttributeNS(null, "ID"))
475 || se.hasAttributeNS(null, "AssertionID")
476 && value.equals(se.getAttributeNS(null, "AssertionID"))) {
477 if (foundElement == null) {
478 foundElement = se;
479 } else {
480 LOG.warn("Multiple elements with the same 'ID' attribute value!");
481 return null;
482 }
483 }
484 }
485
486 processedNode = startNode;
487 startNode = startNode.getFirstChild();
488
489
490 if (startNode == null) {
491
492 startNode = processedNode.getNextSibling();
493 }
494
495
496 while (startNode == null) {
497 processedNode = processedNode.getParentNode();
498 if (processedNode == startParent) {
499 return foundElement;
500 }
501
502 startNode = processedNode.getNextSibling();
503 }
504 }
505 return foundElement;
506 }
507
508 }