1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.wss4j.common.token;
21
import java.math.BigInteger;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Objects;

import javax.xml.namespace.QName;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.Text;
import org.apache.wss4j.common.WSS4JConstants;
import org.apache.wss4j.common.bsp.BSPEnforcer;
import org.apache.wss4j.common.bsp.BSPRule;
import org.apache.wss4j.common.crypto.Crypto;
import org.apache.wss4j.common.crypto.CryptoType;
import org.apache.wss4j.common.crypto.Merlin;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.wss4j.common.util.DOM2Writer;
import org.apache.wss4j.common.util.KeyUtils;
import org.apache.wss4j.common.util.XMLUtils;
43
44
45
46
47 public class SecurityTokenReference {
48 public static final String SECURITY_TOKEN_REFERENCE = "SecurityTokenReference";
49 public static final QName STR_QNAME =
50 new QName(WSS4JConstants.WSSE_NS, SECURITY_TOKEN_REFERENCE);
51 public static final String SKI_URI =
52 WSS4JConstants.X509TOKEN_NS + "#X509SubjectKeyIdentifier";
53 public static final String THUMB_URI =
54 WSS4JConstants.SOAPMESSAGE_NS11 + "#" + WSS4JConstants.THUMBPRINT;
55 public static final String ENC_KEY_SHA1_URI =
56 WSS4JConstants.SOAPMESSAGE_NS11 + "#" + WSS4JConstants.ENC_KEY_SHA1_URI;
57 public static final String X509_V3_TYPE = WSS4JConstants.X509TOKEN_NS + "#X509v3";
58
59 private static final org.slf4j.Logger LOG =
60 org.slf4j.LoggerFactory.getLogger(SecurityTokenReference.class);
61
62 private Element element;
63 private DOMX509IssuerSerial issuerSerial;
64 private byte[] skiBytes;
65 private Reference reference;
66
67
68
69
70
71
72
73
74 public SecurityTokenReference(Element elem, BSPEnforcer bspEnforcer) throws WSSecurityException {
75 element = elem;
76 QName el = new QName(element.getNamespaceURI(), element.getLocalName());
77 if (!STR_QNAME.equals(el)) {
78 throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "badElement",
79 new Object[] {STR_QNAME, el});
80 }
81
82 checkBSPCompliance(bspEnforcer);
83
84 if (containsReference()) {
85 Node node = element.getFirstChild();
86 while (node != null) {
87 if (Node.ELEMENT_NODE == node.getNodeType()
88 && WSS4JConstants.WSSE_NS.equals(node.getNamespaceURI())
89 && "Reference".equals(node.getLocalName())) {
90 reference = new Reference((Element)node);
91 break;
92 }
93 node = node.getNextSibling();
94 }
95 }
96 }
97
98
99
100
101
102
103 public SecurityTokenReference(Document doc) {
104 element = doc.createElementNS(WSS4JConstants.WSSE_NS, "wsse:SecurityTokenReference");
105 }
106
107
108
109
110
111 public void addWSSENamespace() {
112 XMLUtils.setNamespace(element, WSS4JConstants.WSSE_NS, WSS4JConstants.WSSE_PREFIX);
113 }
114
115
116
117
118
119 public void addWSUNamespace() {
120 element.setAttributeNS(XMLUtils.XMLNS_NS, "xmlns:" + WSS4JConstants.WSU_PREFIX, WSS4JConstants.WSU_NS);
121 }
122
123
124
125
126
127 public void addTokenType(String tokenType) {
128 if (tokenType != null) {
129 XMLUtils.setNamespace(element, WSS4JConstants.WSSE11_NS, WSS4JConstants.WSSE11_PREFIX);
130 element.setAttributeNS(
131 WSS4JConstants.WSSE11_NS,
132 WSS4JConstants.WSSE11_PREFIX + ":" + WSS4JConstants.TOKEN_TYPE,
133 tokenType
134 );
135 }
136 }
137
138
139
140
141
142 public String getTokenType() {
143 return element.getAttributeNS(
144 WSS4JConstants.WSSE11_NS, WSS4JConstants.TOKEN_TYPE
145 );
146 }
147
148
149
150
151
152
153 public void setReference(Reference ref) {
154 Element elem = getFirstElement();
155 if (elem != null) {
156 element.replaceChild(ref.getElement(), elem);
157 } else {
158 element.appendChild(ref.getElement());
159 }
160 this.reference = ref;
161 }
162
163
164
165
166
167
168
169
170 public Reference getReference() throws WSSecurityException {
171 return reference;
172 }
173
174
175
176
177
178
179
180
181
182 public void setKeyIdentifier(X509Certificate cert)
183 throws WSSecurityException {
184 Document doc = element.getOwnerDocument();
185 byte[] data = null;
186 try {
187 data = cert.getEncoded();
188 } catch (CertificateEncodingException e) {
189 throw new WSSecurityException(
190 WSSecurityException.ErrorCode.SECURITY_TOKEN_UNAVAILABLE, e, "encodeError"
191 );
192 }
193 Text text = doc.createTextNode(org.apache.xml.security.utils.XMLUtils.encodeToString(data));
194
195 createKeyIdentifier(doc, X509_V3_TYPE, text, true);
196 }
197
198
199
200
201
202
203
204
205
206
207 public void setKeyIdentifierSKI(X509Certificate cert, Crypto crypto)
208 throws WSSecurityException {
209
210
211
212 if (cert.getVersion() != 3) {
213 throw new WSSecurityException(
214 WSSecurityException.ErrorCode.UNSUPPORTED_SECURITY_TOKEN,
215 "invalidCertForSKI", new Object[] {cert.getVersion()});
216 }
217
218 Document doc = element.getOwnerDocument();
219
220 Crypto skiCrypto = crypto;
221 if (skiCrypto == null) {
222 skiCrypto = new Merlin();
223 }
224 byte[] data = skiCrypto.getSKIBytesFromCert(cert);
225
226 Text text = doc.createTextNode(org.apache.xml.security.utils.XMLUtils.encodeToString(data));
227 createKeyIdentifier(doc, SKI_URI, text, true);
228 }
229
230
231
232
233
234
235
236
237
238
239
240 public void setKeyIdentifierThumb(X509Certificate cert) throws WSSecurityException {
241 Document doc = element.getOwnerDocument();
242 byte[] encodedCert = null;
243 try {
244 encodedCert = cert.getEncoded();
245 } catch (CertificateEncodingException e1) {
246 throw new WSSecurityException(
247 WSSecurityException.ErrorCode.SECURITY_TOKEN_UNAVAILABLE, e1, "encodeError"
248 );
249 }
250 try {
251 byte[] encodedBytes = KeyUtils.generateDigest(encodedCert);
252 Text text = doc.createTextNode(org.apache.xml.security.utils.XMLUtils.encodeToString(encodedBytes));
253 createKeyIdentifier(doc, THUMB_URI, text, true);
254 } catch (WSSecurityException e1) {
255 throw new WSSecurityException(
256 WSSecurityException.ErrorCode.FAILURE, e1, "decoding.general"
257 );
258 }
259 }
260
261 public void setKeyIdentifierEncKeySHA1(String value) throws WSSecurityException {
262 Document doc = element.getOwnerDocument();
263 Text text = doc.createTextNode(value);
264 createKeyIdentifier(doc, ENC_KEY_SHA1_URI, text, true);
265 }
266
267 public void setKeyIdentifier(String valueType, String keyIdVal) throws WSSecurityException {
268 setKeyIdentifier(valueType, keyIdVal, false);
269 }
270
271 public void setKeyIdentifier(String valueType, String keyIdVal, boolean base64)
272 throws WSSecurityException {
273 Document doc = element.getOwnerDocument();
274 createKeyIdentifier(doc, valueType, doc.createTextNode(keyIdVal), base64);
275 }
276
277 private void createKeyIdentifier(Document doc, String uri, Node node, boolean base64) {
278 Element keyId = doc.createElementNS(WSS4JConstants.WSSE_NS, "wsse:KeyIdentifier");
279 keyId.setAttributeNS(null, "ValueType", uri);
280 if (base64) {
281 keyId.setAttributeNS(null, "EncodingType", WSS4JConstants.BASE64_ENCODING);
282 }
283
284 keyId.appendChild(node);
285 Element elem = getFirstElement();
286 if (elem != null) {
287 element.replaceChild(keyId, elem);
288 } else {
289 element.appendChild(keyId);
290 }
291 }
292
293
294
295
296
297
298 public Element getFirstElement() {
299 for (Node currentChild = element.getFirstChild();
300 currentChild != null;
301 currentChild = currentChild.getNextSibling()
302 ) {
303 if (Node.ELEMENT_NODE == currentChild.getNodeType()) {
304 return (Element) currentChild;
305 }
306 }
307 return null;
308 }
309
310
311
312
313
314
315
316 public X509Certificate[] getKeyIdentifier(Crypto crypto) throws WSSecurityException {
317 if (crypto == null) {
318 return new X509Certificate[0];
319 }
320
321 Element elem = getFirstElement();
322 String value = elem.getAttributeNS(null, "ValueType");
323
324 if (X509_V3_TYPE.equals(value)) {
325 X509Security token = new X509Security(elem, new BSPEnforcer(true));
326 X509Certificate cert = token.getX509Certificate(crypto);
327 return new X509Certificate[]{cert};
328 } else if (SKI_URI.equals(value)) {
329 X509Certificate cert = getX509SKIAlias(crypto);
330 if (cert != null) {
331 return new X509Certificate[]{cert};
332 }
333 } else if (THUMB_URI.equals(value)) {
334 String text = XMLUtils.getElementText(getFirstElement());
335 if (text != null) {
336 byte[] thumb = org.apache.xml.security.utils.XMLUtils.decode(text);
337
338 CryptoType cryptoType = new CryptoType(CryptoType.TYPE.THUMBPRINT_SHA1);
339 cryptoType.setBytes(thumb);
340 X509Certificate[] certs = crypto.getX509Certificates(cryptoType);
341 if (certs != null && certs.length > 0) {
342 return new X509Certificate[]{certs[0]};
343 }
344 }
345 }
346
347 return new X509Certificate[0];
348 }
349
350 public String getKeyIdentifierValue() {
351 if (containsKeyIdentifier()) {
352 return XMLUtils.getElementText(getFirstElement());
353 }
354 return null;
355 }
356
357 public String getKeyIdentifierValueType() {
358 if (containsKeyIdentifier()) {
359 Element elem = getFirstElement();
360 return elem.getAttributeNS(null, "ValueType");
361 }
362 return null;
363 }
364
365 public String getKeyIdentifierEncodingType() {
366 if (containsKeyIdentifier()) {
367 Element elem = getFirstElement();
368 return elem.getAttributeNS(null, "EncodingType");
369 }
370 return null;
371 }
372
373 public X509Certificate getX509SKIAlias(Crypto crypto) throws WSSecurityException {
374 if (crypto == null) {
375 return null;
376 }
377
378 if (skiBytes == null) {
379 skiBytes = getSKIBytes();
380 if (skiBytes == null) {
381 return null;
382 }
383 }
384 CryptoType cryptoType = new CryptoType(CryptoType.TYPE.SKI_BYTES);
385 cryptoType.setBytes(skiBytes);
386 X509Certificate[] certs = crypto.getX509Certificates(cryptoType);
387 if (certs != null && certs.length > 0) {
388 return certs[0];
389 }
390 return null;
391 }
392
393 public byte[] getSKIBytes() {
394 if (skiBytes != null) {
395 return skiBytes;
396 }
397 String text = XMLUtils.getElementText(getFirstElement());
398 if (text != null) {
399 skiBytes = org.apache.xml.security.utils.XMLUtils.decode(text);
400 }
401 return skiBytes;
402 }
403
404
405
406
407
408
409
410
411 public void setUnknownElement(Element unknownElement) {
412 Element elem = getFirstElement();
413 if (elem != null) {
414 element.replaceChild(unknownElement, elem);
415 } else {
416 element.appendChild(unknownElement);
417 }
418 }
419
420
421
422
423
424
425 public X509Certificate[] getX509IssuerSerial(Crypto crypto) throws WSSecurityException {
426 if (crypto == null) {
427 return new X509Certificate[0];
428 }
429
430 if (issuerSerial == null) {
431 issuerSerial = getIssuerSerial();
432 if (issuerSerial == null) {
433 return new X509Certificate[0];
434 }
435 }
436 CryptoType cryptoType = new CryptoType(CryptoType.TYPE.ISSUER_SERIAL);
437 cryptoType.setIssuerSerial(issuerSerial.getIssuer(), issuerSerial.getSerialNumber());
438 return crypto.getX509Certificates(cryptoType);
439 }
440
441 private DOMX509IssuerSerial getIssuerSerial() throws WSSecurityException {
442 if (issuerSerial != null) {
443 return issuerSerial;
444 }
445 Element elem = getFirstElement();
446 if (elem == null) {
447 return null;
448 }
449 if (WSS4JConstants.X509_DATA_LN.equals(elem.getLocalName())) {
450 elem =
451 XMLUtils.findElement(
452 elem, WSS4JConstants.X509_ISSUER_SERIAL_LN, WSS4JConstants.SIG_NS
453 );
454 }
455 issuerSerial = new DOMX509IssuerSerial(elem);
456
457 return issuerSerial;
458 }
459
460
461
462
463
464
465
466 public boolean containsReference() {
467 return containsElement(WSS4JConstants.WSSE_NS, "Reference");
468 }
469
470
471
472
473
474
475
476 public boolean containsX509IssuerSerial() {
477 return containsElement(WSS4JConstants.SIG_NS, WSS4JConstants.X509_ISSUER_SERIAL_LN);
478 }
479
480
481
482
483
484
485
486 public boolean containsX509Data() {
487 return containsElement(WSS4JConstants.SIG_NS, WSS4JConstants.X509_DATA_LN);
488 }
489
490
491
492
493
494
495
496 public boolean containsKeyIdentifier() {
497 return containsElement(WSS4JConstants.WSSE_NS, "KeyIdentifier");
498 }
499
500 private boolean containsElement(String namespace, String localname) {
501 Node node = element.getFirstChild();
502 while (node != null) {
503 if (Node.ELEMENT_NODE == node.getNodeType()) {
504 String ns = node.getNamespaceURI();
505 String name = node.getLocalName();
506 if ((namespace != null && namespace.equals(ns)
507 || namespace == null && ns == null)
508 && localname.equals(name)
509 ) {
510 return true;
511 }
512 }
513 node = node.getNextSibling();
514 }
515 return false;
516 }
517
518
519
520
521
522
523 public Element getElement() {
524 return element;
525 }
526
527
528
529
530
531
532 public void setID(String id) {
533 element.setAttributeNS(WSS4JConstants.WSU_NS, WSS4JConstants.WSU_PREFIX + ":Id", id);
534 }
535
536
537
538
539
540 public String getID() {
541 return element.getAttributeNS(WSS4JConstants.WSU_NS, "Id");
542 }
543
544
545
546
547
548
549 public String toString() {
550 return DOM2Writer.nodeToString(element);
551 }
552
553
554
555
556
557 private void checkBSPCompliance(BSPEnforcer bspEnforcer) throws WSSecurityException {
558
559 int result = 0;
560 Node node = element.getFirstChild();
561 Element child = null;
562 while (node != null) {
563 if (Node.ELEMENT_NODE == node.getNodeType()) {
564 result++;
565 child = (Element)node;
566 }
567 node = node.getNextSibling();
568 }
569 if (result != 1) {
570 bspEnforcer.handleBSPRule(BSPRule.R3061);
571 }
572 if ("KeyIdentifier".equals(child.getLocalName())
573 && WSS4JConstants.WSSE_NS.equals(child.getNamespaceURI())) {
574
575 String valueType = getKeyIdentifierValueType();
576
577 if (valueType == null || valueType.length() == 0) {
578 bspEnforcer.handleBSPRule(BSPRule.R3054);
579 }
580 String encodingType = getFirstElement().getAttributeNS(null, "EncodingType");
581
582 if (encodingType.length() != 0 && !WSS4JConstants.BASE64_ENCODING.equals(encodingType)) {
583 bspEnforcer.handleBSPRule(BSPRule.R3071);
584 }
585
586
587 if (!WSS4JConstants.WSS_SAML_KI_VALUE_TYPE.equals(valueType)
588 && !WSS4JConstants.WSS_SAML2_KI_VALUE_TYPE.equals(valueType)
589 && encodingType.length() == 0) {
590 bspEnforcer.handleBSPRule(BSPRule.R3070);
591 }
592 } else if ("Embedded".equals(child.getLocalName())) {
593 result = 0;
594 node = child.getFirstChild();
595 while (node != null) {
596 if (Node.ELEMENT_NODE == node.getNodeType()) {
597 result++;
598
599 if ("SecurityTokenReference".equals(node.getLocalName())
600 && WSS4JConstants.WSSE_NS.equals(node.getNamespaceURI())) {
601 bspEnforcer.handleBSPRule(BSPRule.R3056);
602 }
603 }
604 node = node.getNextSibling();
605 }
606
607 if (result != 1) {
608 bspEnforcer.handleBSPRule(BSPRule.R3060);
609 }
610 }
611 }
612
613 @Override
614 public int hashCode() {
615 int result = 17;
616 try {
617 Reference reference = getReference();
618 if (reference != null) {
619 result = 31 * result + reference.hashCode();
620 }
621 } catch (WSSecurityException e) {
622 LOG.error(e.getMessage(), e);
623 }
624 String keyIdentifierEncodingType = getKeyIdentifierEncodingType();
625 if (keyIdentifierEncodingType != null) {
626 result = 31 * result + keyIdentifierEncodingType.hashCode();
627 }
628 String keyIdentifierValueType = getKeyIdentifierValueType();
629 if (keyIdentifierValueType != null) {
630 result = 31 * result + keyIdentifierValueType.hashCode();
631 }
632 String keyIdentifierValue = getKeyIdentifierValue();
633 if (keyIdentifierValue != null) {
634 result = 31 * result + keyIdentifierValue.hashCode();
635 }
636 String tokenType = getTokenType();
637 if (tokenType != null) {
638 result = 31 * result + tokenType.hashCode();
639 }
640 byte[] skiBytes = getSKIBytes();
641 if (skiBytes != null) {
642 result = 31 * result + Arrays.hashCode(skiBytes);
643 }
644 String issuer = null;
645 BigInteger serialNumber = null;
646
647 try {
648 issuer = getIssuerSerial().getIssuer();
649 serialNumber = getIssuerSerial().getSerialNumber();
650 } catch (WSSecurityException e) {
651 LOG.error(e.getMessage(), e);
652 }
653 if (issuer != null) {
654 result = 31 * result + issuer.hashCode();
655 }
656 if (serialNumber != null) {
657 result = 31 * result + serialNumber.hashCode();
658 }
659 return result;
660 }
661
662 @Override
663 public boolean equals(Object object) {
664 if (!(object instanceof SecurityTokenReference)) {
665 return false;
666 }
667 SecurityTokenReference tokenReference = (SecurityTokenReference)object;
668 try {
669 if (!getReference().equals(tokenReference.getReference())) {
670 return false;
671 }
672 } catch (WSSecurityException e) {
673 LOG.error(e.getMessage(), e);
674 return false;
675 }
676 if (!compare(getKeyIdentifierEncodingType(), tokenReference.getKeyIdentifierEncodingType())) {
677 return false;
678 }
679 if (!compare(getKeyIdentifierValueType(), tokenReference.getKeyIdentifierValueType())) {
680 return false;
681 }
682 if (!compare(getKeyIdentifierValue(), tokenReference.getKeyIdentifierValue())) {
683 return false;
684 }
685 if (!compare(getTokenType(), tokenReference.getTokenType())) {
686 return false;
687 }
688 if (!Arrays.equals(getSKIBytes(), tokenReference.getSKIBytes())) {
689 return false;
690 }
691 try {
692 if (getIssuerSerial() != null && tokenReference.getIssuerSerial() != null) {
693 if (!compare(getIssuerSerial().getIssuer(), tokenReference.getIssuerSerial().getIssuer())) {
694 return false;
695 }
696 if (!compare(getIssuerSerial().getSerialNumber(), tokenReference.getIssuerSerial().getSerialNumber())) {
697 return false;
698 }
699 }
700 } catch (WSSecurityException e) {
701 LOG.error(e.getMessage(), e);
702 return false;
703 }
704
705 return true;
706 }
707
708 private boolean compare(String item1, String item2) {
709 if (item1 == null && item2 != null) {
710 return false;
711 } else if (item1 != null && !item1.equals(item2)) {
712 return false;
713 }
714 return true;
715 }
716
717 private boolean compare(BigInteger item1, BigInteger item2) {
718 if (item1 == null && item2 != null) {
719 return false;
720 } else if (item1 != null && !item1.equals(item2)) {
721 return false;
722 }
723 return true;
724 }
725 }