001/*
002 * $HeadURL: http://juliusdavies.ca/svn/not-yet-commons-ssl/tags/commons-ssl-0.3.9/src/java/org/apache/commons/ssl/PKCS8Key.java $
003 * $Revision: 121 $
004 * $Date: 2007-11-13 21:26:57 -0800 (Tue, 13 Nov 2007) $
005 *
006 * ====================================================================
007 * Licensed to the Apache Software Foundation (ASF) under one
008 * or more contributor license agreements.  See the NOTICE file
009 * distributed with this work for additional information
010 * regarding copyright ownership.  The ASF licenses this file
011 * to you under the Apache License, Version 2.0 (the
012 * "License"); you may not use this file except in compliance
013 * with the License.  You may obtain a copy of the License at
014 *
015 *   http://www.apache.org/licenses/LICENSE-2.0
016 *
017 * Unless required by applicable law or agreed to in writing,
018 * software distributed under the License is distributed on an
019 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
020 * KIND, either express or implied.  See the License for the
021 * specific language governing permissions and limitations
022 * under the License.
023 * ====================================================================
024 *
025 * This software consists of voluntary contributions made by many
026 * individuals on behalf of the Apache Software Foundation.  For more
027 * information on the Apache Software Foundation, please see
028 * <http://www.apache.org/>.
029 *
030 */
031
032package org.apache.commons.ssl;
033
034import org.apache.commons.ssl.asn1.ASN1EncodableVector;
035import org.apache.commons.ssl.asn1.ASN1OutputStream;
036import org.apache.commons.ssl.asn1.DEREncodable;
037import org.apache.commons.ssl.asn1.DERInteger;
038import org.apache.commons.ssl.asn1.DERNull;
039import org.apache.commons.ssl.asn1.DERObjectIdentifier;
040import org.apache.commons.ssl.asn1.DEROctetString;
041import org.apache.commons.ssl.asn1.DERSequence;
042
043import javax.crypto.BadPaddingException;
044import javax.crypto.Cipher;
045import javax.crypto.IllegalBlockSizeException;
046import javax.crypto.Mac;
047import javax.crypto.NoSuchPaddingException;
048import javax.crypto.SecretKey;
049import javax.crypto.spec.IvParameterSpec;
050import javax.crypto.spec.RC2ParameterSpec;
051import javax.crypto.spec.RC5ParameterSpec;
052import javax.crypto.spec.SecretKeySpec;
053import java.io.ByteArrayInputStream;
054import java.io.ByteArrayOutputStream;
055import java.io.File;
056import java.io.FileInputStream;
057import java.io.IOException;
058import java.io.InputStream;
059import java.math.BigInteger;
060import java.security.GeneralSecurityException;
061import java.security.InvalidAlgorithmParameterException;
062import java.security.InvalidKeyException;
063import java.security.KeyFactory;
064import java.security.MessageDigest;
065import java.security.NoSuchAlgorithmException;
066import java.security.PrivateKey;
067import java.security.spec.KeySpec;
068import java.security.spec.PKCS8EncodedKeySpec;
069import java.util.Arrays;
070import java.util.Collections;
071import java.util.Iterator;
072import java.util.List;
073
/**
 * Utility for decrypting PKCS8 private keys.  Way easier to use than
 * javax.crypto.EncryptedPrivateKeyInfo since all you need is the byte[] array
 * and the password.  You don't need to know anything else about the PKCS8
 * key you pass in.
 * <p>
 * Can handle base64 PEM, or raw DER.
 * Can handle PKCS8 Version 1.5 and 2.0.
 * Can also handle OpenSSL encrypted or unencrypted private keys (DSA or RSA).
 * </p><p>
 * The PKCS12 key derivation (the "pkcs12()" method) comes from BouncyCastle.
 * </p>
 *
 * @author Credit Union Central of British Columbia
 * @author <a href="http://www.cucbc.com/">www.cucbc.com</a>
 * @author <a href="mailto:juliusdavies@cucbc.com">juliusdavies@cucbc.com</a>
 * @author <a href="http://www.bouncycastle.org/">bouncycastle.org</a>
 * @since 7-Nov-2006
 */
093public class PKCS8Key {
    /** OID that the constructor treats as identifying an RSA private key. */
    public final static String RSA_OID = "1.2.840.113549.1.1.1";
    /** OID that the constructor treats as identifying a DSA private key (id-dsa). */
    public final static String DSA_OID = "1.2.840.10040.4.1";

    // PEM type labels.  The constructor upper-cases each PEM item's type and
    // accepts the item when the type starts with one of these values.
    public final static String PKCS8_UNENCRYPTED = "PRIVATE KEY";
    public final static String PKCS8_ENCRYPTED = "ENCRYPTED PRIVATE KEY";
    public final static String OPENSSL_RSA = "RSA PRIVATE KEY";
    public final static String OPENSSL_DSA = "DSA PRIVATE KEY";

    // The key produced by KeyFactory.generatePrivate() in the constructor.
    private final PrivateKey privateKey;
    // The unencrypted PKCS8 bytes the private key was generated from.
    private final byte[] decryptedBytes;
    // Transformation reported by the decryption step ("UNENCRYPTED" when no
    // decryption was needed).
    private final String transformation;
    // Key size reported by the decryption step (0 when no decryption was needed).
    private final int keySize;
    // Exactly one of these two is true: isRSA is assigned as !isDSA.
    private final boolean isDSA;
    private final boolean isRSA;

    static {
        // NOTE(review): presumably registers/initializes the JCE helper
        // implementation before any key is parsed - confirm against JavaImpl.
        JavaImpl.load();
    }
112
113    /**
114     * @param in       pkcs8 file to parse (pem or der, encrypted or unencrypted)
115     * @param password password to decrypt the pkcs8 file.  Ignored if the
116     *                 supplied pkcs8 is already unencrypted.
117     * @throws GeneralSecurityException If a parsing or decryption problem
118     *                                  occured.
119     * @throws IOException              If the supplied InputStream could not be read.
120     */
121    public PKCS8Key(final InputStream in, char[] password)
122        throws GeneralSecurityException, IOException {
123        this(Util.streamToBytes(in), password);
124    }
125
126    /**
127     * @param in       pkcs8 file to parse (pem or der, encrypted or unencrypted)
128     * @param password password to decrypt the pkcs8 file.  Ignored if the
129     *                 supplied pkcs8 is already unencrypted.
130     * @throws GeneralSecurityException If a parsing or decryption problem
131     *                                  occured.
132     */
133    public PKCS8Key(final ByteArrayInputStream in, char[] password)
134        throws GeneralSecurityException {
135        this(Util.streamToBytes(in), password);
136    }
137
138    /**
139     * @param encoded  pkcs8 file to parse (pem or der, encrypted or unencrypted)
140     * @param password password to decrypt the pkcs8 file.  Ignored if the
141     *                 supplied pkcs8 is already unencrypted.
142     * @throws GeneralSecurityException If a parsing or decryption problem
143     *                                  occured.
144     */
145    public PKCS8Key(final byte[] encoded, char[] password)
146        throws GeneralSecurityException {
147        DecryptResult decryptResult =
148            new DecryptResult("UNENCRYPTED", 0, encoded);
149
150        List pemItems = PEMUtil.decode(encoded);
151        PEMItem keyItem = null;
152        byte[] derBytes = null;
153        if (pemItems.isEmpty()) {
154            // must be DER encoded - PEMUtil wasn't able to extract anything.
155            derBytes = encoded;
156        } else {
157            Iterator it = pemItems.iterator();
158            boolean opensslRSA = false;
159            boolean opensslDSA = false;
160
161            while (it.hasNext()) {
162                PEMItem item = (PEMItem) it.next();
163                String type = item.pemType.trim().toUpperCase();
164                boolean plainPKCS8 = type.startsWith(PKCS8_UNENCRYPTED);
165                boolean encryptedPKCS8 = type.startsWith(PKCS8_ENCRYPTED);
166                boolean rsa = type.startsWith(OPENSSL_RSA);
167                boolean dsa = type.startsWith(OPENSSL_DSA);
168                if (plainPKCS8 || encryptedPKCS8 || rsa || dsa) {
169                    opensslRSA = opensslRSA || rsa;
170                    opensslDSA = opensslDSA || dsa;
171                    if (derBytes != null) {
172                        throw new ProbablyNotPKCS8Exception("More than one pkcs8 or OpenSSL key found in the supplied PEM Base64 stream");
173                    }
174                    derBytes = item.getDerBytes();
175                    keyItem = item;
176                    decryptResult = new DecryptResult("UNENCRYPTED", 0, derBytes);
177                }
178            }
179            // after the loop is finished, did we find anything?
180            if (derBytes == null) {
181                throw new ProbablyNotPKCS8Exception("No pkcs8 or OpenSSL key found in the supplied PEM Base64 stream");
182            }
183
184            if (opensslDSA || opensslRSA) {
185                String c = keyItem.cipher.trim();
186                boolean encrypted = !"UNKNOWN".equals(c) && !"".equals(c);
187                if (encrypted) {
188                    decryptResult = opensslDecrypt(keyItem, password);
189                }
190
191                String oid = RSA_OID;
192                if (opensslDSA) {
193                    oid = DSA_OID;
194                }
195                derBytes = formatAsPKCS8(decryptResult.bytes, oid, null);
196
197                String tf = decryptResult.transformation;
198                int ks = decryptResult.keySize;
199                decryptResult = new DecryptResult(tf, ks, derBytes);
200            }
201        }
202
203        ASN1Structure pkcs8;
204        try {
205            pkcs8 = ASN1Util.analyze(derBytes);
206        }
207        catch (Exception e) {
208            throw new ProbablyNotPKCS8Exception("asn1 parse failure: " + e);
209        }
210
211        String oid = RSA_OID;
212        // With the OpenSSL unencrypted private keys in DER format, the only way
213        // to even have a hope of guessing what we've got (DSA or RSA?) is to
214        // count the number of DERIntegers occurring in the first DERSequence.
215        int derIntegerCount = -1;
216        if (pkcs8.derIntegers != null) {
217            derIntegerCount = pkcs8.derIntegers.size();
218        }
219        switch (derIntegerCount) {
220            case 6:
221                oid = DSA_OID;
222            case 9:
223                derBytes = formatAsPKCS8(derBytes, oid, pkcs8);
224                pkcs8.oid1 = oid;
225
226                String tf = decryptResult.transformation;
227                int ks = decryptResult.keySize;
228                decryptResult = new DecryptResult(tf, ks, derBytes);
229                break;
230            default:
231                break;
232        }
233
234        oid = pkcs8.oid1;
235        if (!oid.startsWith("1.2.840.113549.1")) {
236            boolean isOkay = false;
237            if (oid.startsWith("1.2.840.10040.4.")) {
238                String s = oid.substring("1.2.840.10040.4.".length());
239                // 1.2.840.10040.4.1 -- id-dsa
240                // 1.2.840.10040.4.3 -- id-dsa-with-sha1
241                isOkay = s.equals("1") || s.startsWith("1.") ||
242                         s.equals("3") || s.startsWith("3.");
243            }
244            if (!isOkay) {
245                throw new ProbablyNotPKCS8Exception("Valid ASN.1, but not PKCS8 or OpenSSL format.  OID=" + oid);
246            }
247        }
248
249        boolean isRSA = RSA_OID.equals(oid);
250        boolean isDSA = DSA_OID.equals(oid);
251        boolean encrypted = !isRSA && !isDSA;
252        byte[] decryptedPKCS8 = encrypted ? null : derBytes;
253
254        if (encrypted) {
255            decryptResult = decryptPKCS8(pkcs8, password);
256            decryptedPKCS8 = decryptResult.bytes;
257        }
258        if (encrypted) {
259            try {
260                pkcs8 = ASN1Util.analyze(decryptedPKCS8);
261            }
262            catch (Exception e) {
263                throw new ProbablyBadPasswordException("Decrypted stream not ASN.1.  Probably bad decryption password.");
264            }
265            oid = pkcs8.oid1;
266            isDSA = DSA_OID.equals(oid);
267        }
268
269        KeySpec spec = new PKCS8EncodedKeySpec(decryptedPKCS8);
270        String type = "RSA";
271        PrivateKey pk;
272        try {
273            KeyFactory KF;
274            if (isDSA) {
275                type = "DSA";
276                KF = KeyFactory.getInstance("DSA");
277            } else {
278                KF = KeyFactory.getInstance("RSA");
279            }
280            pk = KF.generatePrivate(spec);
281        }
282        catch (Exception e) {
283            throw new ProbablyBadPasswordException("Cannot create " + type + " private key from decrypted stream.  Probably bad decryption password. " + e);
284        }
285        if (pk != null) {
286            this.privateKey = pk;
287            this.isDSA = isDSA;
288            this.isRSA = !isDSA;
289            this.decryptedBytes = decryptedPKCS8;
290            this.transformation = decryptResult.transformation;
291            this.keySize = decryptResult.keySize;
292        } else {
293            throw new GeneralSecurityException("KeyFactory.generatePrivate() returned null and didn't throw exception!");
294        }
295    }
296
297    public boolean isRSA() {
298        return isRSA;
299    }
300
301    public boolean isDSA() {
302        return isDSA;
303    }
304
305    public String getTransformation() {
306        return transformation;
307    }
308
309    public int getKeySize() {
310        return keySize;
311    }
312
313    public byte[] getDecryptedBytes() {
314        return decryptedBytes;
315    }
316
317    public PrivateKey getPrivateKey() {
318        return privateKey;
319    }
320
321    public static class DecryptResult {
322        public final String transformation;
323        public final int keySize;
324        public final byte[] bytes;
325
326        protected DecryptResult(String transformation, int keySize,
327                                byte[] decryptedBytes) {
328            this.transformation = transformation;
329            this.keySize = keySize;
330            this.bytes = decryptedBytes;
331        }
332    }
333
334    private static DecryptResult opensslDecrypt(final PEMItem item,
335                                                final char[] password)
336        throws GeneralSecurityException {
337        final String cipher = item.cipher;
338        final String mode = item.mode;
339        final int keySize = item.keySizeInBits;
340        final byte[] salt = item.iv;
341        final boolean des2 = item.des2;
342        final DerivedKey dk = OpenSSL.deriveKey(password, salt, keySize, des2);
343        return decrypt(cipher, mode, dk, des2, null, item.getDerBytes());
344    }
345
346    public static Cipher generateCipher(String cipher, String mode,
347                                        final DerivedKey dk,
348                                        final boolean des2,
349                                        final byte[] iv,
350                                        final boolean decryptMode)
351        throws NoSuchAlgorithmException, NoSuchPaddingException,
352        InvalidKeyException, InvalidAlgorithmParameterException {
353        if (des2 && dk.key.length >= 24) {
354            // copy first 8 bytes into last 8 bytes to create 2DES key.
355            System.arraycopy(dk.key, 0, dk.key, 16, 8);
356        }
357
358        final int keySize = dk.key.length * 8;
359        cipher = cipher.trim();
360        String cipherUpper = cipher.toUpperCase();
361        mode = mode.trim().toUpperCase();
362        // Is the cipher even available?
363        Cipher.getInstance(cipher);
364        String padding = "PKCS5Padding";
365        if (mode.startsWith("CFB") || mode.startsWith("OFB")) {
366            padding = "NoPadding";
367        }
368
369        String transformation = cipher + "/" + mode + "/" + padding;
370        if (cipherUpper.startsWith("RC4")) {
371            // RC4 does not take mode or padding.
372            transformation = cipher;
373        }
374
375        SecretKey secret = new SecretKeySpec(dk.key, cipher);
376        IvParameterSpec ivParams;
377        if (iv != null) {
378            ivParams = new IvParameterSpec(iv);
379        } else {
380            ivParams = dk.iv != null ? new IvParameterSpec(dk.iv) : null;
381        }
382
383        Cipher c = Cipher.getInstance(transformation);
384        int cipherMode = Cipher.ENCRYPT_MODE;
385        if (decryptMode) {
386            cipherMode = Cipher.DECRYPT_MODE;
387        }
388
389        // RC2 requires special params to inform engine of keysize.
390        if (cipherUpper.startsWith("RC2")) {
391            RC2ParameterSpec rcParams;
392            if (mode.startsWith("ECB") || ivParams == null) {
393                // ECB doesn't take an IV.
394                rcParams = new RC2ParameterSpec(keySize);
395            } else {
396                rcParams = new RC2ParameterSpec(keySize, ivParams.getIV());
397            }
398            c.init(cipherMode, secret, rcParams);
399        } else if (cipherUpper.startsWith("RC5")) {
400            RC5ParameterSpec rcParams;
401            if (mode.startsWith("ECB") || ivParams == null) {
402                // ECB doesn't take an IV.
403                rcParams = new RC5ParameterSpec(16, 12, 32);
404            } else {
405                rcParams = new RC5ParameterSpec(16, 12, 32, ivParams.getIV());
406            }
407            c.init(cipherMode, secret, rcParams);
408        } else if (mode.startsWith("ECB") || cipherUpper.startsWith("RC4")) {
409            // RC4 doesn't require any params.
410            // Any cipher using ECB does not require an IV.
411            c.init(cipherMode, secret);
412        } else {
413            // DES, DESede, AES, BlowFish require IVParams (when in CBC, CFB,
414            // or OFB mode).  (In ECB mode they don't require IVParams).
415            c.init(cipherMode, secret, ivParams);
416        }
417        return c;
418    }
419
420    public static DecryptResult decrypt(String cipher, String mode,
421                                        final DerivedKey dk,
422                                        final boolean des2,
423                                        final byte[] iv,
424                                        final byte[] encryptedBytes)
425
426        throws NoSuchAlgorithmException, NoSuchPaddingException,
427        InvalidKeyException, InvalidAlgorithmParameterException,
428        IllegalBlockSizeException, BadPaddingException {
429        Cipher c = generateCipher(cipher, mode, dk, des2, iv, true);
430        final String transformation = c.getAlgorithm();
431        final int keySize = dk.key.length * 8;
432        byte[] decryptedBytes = c.doFinal(encryptedBytes);
433        return new DecryptResult(transformation, keySize, decryptedBytes);
434    }
435
436    private static DecryptResult decryptPKCS8(ASN1Structure pkcs8,
437                                              char[] password)
438        throws GeneralSecurityException {
439        boolean isVersion1 = true;
440        boolean isVersion2 = false;
441        boolean usePKCS12PasswordPadding = false;
442        boolean use2DES = false;
443        String cipher = null;
444        String hash = null;
445        int keySize = -1;
446        // Almost all PKCS8 encrypted keys use CBC.  Looks like the AES OID's can
447        // support different modes, and RC4 doesn't use any mode at all!
448        String mode = "CBC";
449
450        // In PKCS8 Version 2 the IV is stored in the ASN.1 structure for
451        // us, so we don't need to derive it.  Just leave "ivSize" set to 0 for
452        // those ones.
453        int ivSize = 0;
454
455        String oid = pkcs8.oid1;
456        if (oid.startsWith("1.2.840.113549.1.12."))  // PKCS12 key derivation!
457        {
458            usePKCS12PasswordPadding = true;
459
460            // Let's trim this OID to make life a little easier.
461            oid = oid.substring("1.2.840.113549.1.12.".length());
462
463            if (oid.equals("1.1") || oid.startsWith("1.1.")) {
464                // 1.2.840.113549.1.12.1.1
465                hash = "SHA1";
466                cipher = "RC4";
467                keySize = 128;
468            } else if (oid.equals("1.2") || oid.startsWith("1.2.")) {
469                // 1.2.840.113549.1.12.1.2
470                hash = "SHA1";
471                cipher = "RC4";
472                keySize = 40;
473            } else if (oid.equals("1.3") || oid.startsWith("1.3.")) {
474                // 1.2.840.113549.1.12.1.3
475                hash = "SHA1";
476                cipher = "DESede";
477                keySize = 192;
478            } else if (oid.equals("1.4") || oid.startsWith("1.4.")) {
479                // DES2 !!!
480
481                // 1.2.840.113549.1.12.1.4
482                hash = "SHA1";
483                cipher = "DESede";
484                keySize = 192;
485                use2DES = true;
486                // later on we'll copy the first 8 bytes of the 24 byte DESede key
487                // over top the last 8 bytes, making the key look like K1-K2-K1
488                // instead of the usual K1-K2-K3.
489            } else if (oid.equals("1.5") || oid.startsWith("1.5.")) {
490                // 1.2.840.113549.1.12.1.5
491                hash = "SHA1";
492                cipher = "RC2";
493                keySize = 128;
494            } else if (oid.equals("1.6") || oid.startsWith("1.6.")) {
495                // 1.2.840.113549.1.12.1.6
496                hash = "SHA1";
497                cipher = "RC2";
498                keySize = 40;
499            }
500        } else if (oid.startsWith("1.2.840.113549.1.5.")) {
501            // Let's trim this OID to make life a little easier.
502            oid = oid.substring("1.2.840.113549.1.5.".length());
503
504            if (oid.equals("1") || oid.startsWith("1.")) {
505                // 1.2.840.113549.1.5.1 -- pbeWithMD2AndDES-CBC
506                hash = "MD2";
507                cipher = "DES";
508                keySize = 64;
509            } else if (oid.equals("3") || oid.startsWith("3.")) {
510                // 1.2.840.113549.1.5.3 -- pbeWithMD5AndDES-CBC
511                hash = "MD5";
512                cipher = "DES";
513                keySize = 64;
514            } else if (oid.equals("4") || oid.startsWith("4.")) {
515                // 1.2.840.113549.1.5.4 -- pbeWithMD2AndRC2_CBC
516                hash = "MD2";
517                cipher = "RC2";
518                keySize = 64;
519            } else if (oid.equals("6") || oid.startsWith("6.")) {
520                // 1.2.840.113549.1.5.6 -- pbeWithMD5AndRC2_CBC
521                hash = "MD5";
522                cipher = "RC2";
523                keySize = 64;
524            } else if (oid.equals("10") || oid.startsWith("10.")) {
525                // 1.2.840.113549.1.5.10 -- pbeWithSHA1AndDES-CBC
526                hash = "SHA1";
527                cipher = "DES";
528                keySize = 64;
529            } else if (oid.equals("11") || oid.startsWith("11.")) {
530                // 1.2.840.113549.1.5.11 -- pbeWithSHA1AndRC2_CBC
531                hash = "SHA1";
532                cipher = "RC2";
533                keySize = 64;
534            } else if (oid.equals("12") || oid.startsWith("12.")) {
535                // 1.2.840.113549.1.5.12 - id-PBKDF2 - Key Derivation Function
536                isVersion2 = true;
537            } else if (oid.equals("13") || oid.startsWith("13.")) {
538                // 1.2.840.113549.1.5.13 - id-PBES2: PBES2 encryption scheme
539                isVersion2 = true;
540            } else if (oid.equals("14") || oid.startsWith("14.")) {
541                // 1.2.840.113549.1.5.14 - id-PBMAC1 message authentication scheme
542                isVersion2 = true;
543            }
544        }
545        if (isVersion2) {
546            isVersion1 = false;
547            hash = "HmacSHA1";
548            oid = pkcs8.oid2;
549
550            // really ought to be:
551            //
552            // if ( oid.startsWith( "1.2.840.113549.1.5.12" ) )
553            //
554            // but all my tests still pass, and I figure this to be more robust:
555            if (pkcs8.oid3 != null) {
556                oid = pkcs8.oid3;
557            }
558            if (oid.startsWith("1.3.6.1.4.1.3029.1.2")) {
559                // 1.3.6.1.4.1.3029.1.2 - Blowfish
560                cipher = "Blowfish";
561                mode = "CBC";
562                keySize = 128;
563            } else if (oid.startsWith("1.3.14.3.2.")) {
564                oid = oid.substring("1.3.14.3.2.".length());
565                if (oid.equals("6") || oid.startsWith("6.")) {
566                    // 1.3.14.3.2.6 - desECB
567                    cipher = "DES";
568                    mode = "ECB";
569                    keySize = 64;
570                } else if (oid.equals("7") || oid.startsWith("7.")) {
571                    // 1.3.14.3.2.7 - desCBC
572                    cipher = "DES";
573                    mode = "CBC";
574                    keySize = 64;
575                } else if (oid.equals("8") || oid.startsWith("8.")) {
576                    // 1.3.14.3.2.8 - desOFB
577                    cipher = "DES";
578                    mode = "OFB";
579                    keySize = 64;
580                } else if (oid.equals("9") || oid.startsWith("9.")) {
581                    // 1.3.14.3.2.9 - desCFB
582                    cipher = "DES";
583                    mode = "CFB";
584                    keySize = 64;
585                } else if (oid.equals("17") || oid.startsWith("17.")) {
586                    // 1.3.14.3.2.17 - desEDE
587                    cipher = "DESede";
588                    mode = "CBC";
589                    keySize = 192;
590
591                    // If the supplied IV is all zeroes, then this is DES2
592                    // (Well, that's what happened when I played with OpenSSL!)
593                    if (allZeroes(pkcs8.iv)) {
594                        mode = "ECB";
595                        use2DES = true;
596                        pkcs8.iv = null;
597                    }
598                }
599            }
600
601            // AES
602            // 2.16.840.1.101.3.4.1.1  - id-aes128-ECB
603            // 2.16.840.1.101.3.4.1.2  - id-aes128-CBC
604            // 2.16.840.1.101.3.4.1.3  - id-aes128-OFB
605            // 2.16.840.1.101.3.4.1.4  - id-aes128-CFB
606            // 2.16.840.1.101.3.4.1.21 - id-aes192-ECB
607            // 2.16.840.1.101.3.4.1.22 - id-aes192-CBC
608            // 2.16.840.1.101.3.4.1.23 - id-aes192-OFB
609            // 2.16.840.1.101.3.4.1.24 - id-aes192-CFB
610            // 2.16.840.1.101.3.4.1.41 - id-aes256-ECB
611            // 2.16.840.1.101.3.4.1.42 - id-aes256-CBC
612            // 2.16.840.1.101.3.4.1.43 - id-aes256-OFB
613            // 2.16.840.1.101.3.4.1.44 - id-aes256-CFB
614            else if (oid.startsWith("2.16.840.1.101.3.4.1.")) {
615                cipher = "AES";
616                if (pkcs8.iv == null) {
617                    ivSize = 128;
618                }
619                oid = oid.substring("2.16.840.1.101.3.4.1.".length());
620                int x = oid.indexOf('.');
621                int finalDigit;
622                if (x >= 0) {
623                    finalDigit = Integer.parseInt(oid.substring(0, x));
624                } else {
625                    finalDigit = Integer.parseInt(oid);
626                }
627                switch (finalDigit % 10) {
628                    case 1:
629                        mode = "ECB";
630                        break;
631                    case 2:
632                        mode = "CBC";
633                        break;
634                    case 3:
635                        mode = "OFB";
636                        break;
637                    case 4:
638                        mode = "CFB";
639                        break;
640                    default:
641                        throw new RuntimeException("Unknown AES final digit: " + finalDigit);
642                }
643                switch (finalDigit / 10) {
644                    case 0:
645                        keySize = 128;
646                        break;
647                    case 2:
648                        keySize = 192;
649                        break;
650                    case 4:
651                        keySize = 256;
652                        break;
653                    default:
654                        throw new RuntimeException("Unknown AES final digit: " + finalDigit);
655                }
656            } else if (oid.startsWith("1.2.840.113549.3.")) {
657                // Let's trim this OID to make life a little easier.
658                oid = oid.substring("1.2.840.113549.3.".length());
659
660                if (oid.equals("2") || oid.startsWith("2.")) {
661                    // 1.2.840.113549.3.2 - RC2-CBC
662                    // Note:  keysize determined in PKCS8 Version 2.0 ASN.1 field.
663                    cipher = "RC2";
664                    keySize = pkcs8.keySize * 8;
665                } else if (oid.equals("4") || oid.startsWith("4.")) {
666                    // 1.2.840.113549.3.4 - RC4
667                    // Note:  keysize determined in PKCS8 Version 2.0 ASN.1 field.
668                    cipher = "RC4";
669                    keySize = pkcs8.keySize * 8;
670                } else if (oid.equals("7") || oid.startsWith("7.")) {
671                    // 1.2.840.113549.3.7 - DES-EDE3-CBC
672                    cipher = "DESede";
673                    keySize = 192;
674                } else if (oid.equals("9") || oid.startsWith("9.")) {
675                    // 1.2.840.113549.3.9 - RC5 CBC Pad
676                    // Note:  keysize determined in PKCS8 Version 2.0 ASN.1 field.
677                    keySize = pkcs8.keySize * 8;
678                    cipher = "RC5";
679
680                    // Need to find out more about RC5.
681                    // How do I create the RC5ParameterSpec?
682                    // (int version, int rounds, int wordSize, byte[] iv)
683                }
684            }
685        }
686
687        // The pkcs8 structure has been thoroughly examined.  If we don't have
688        // a cipher or hash at this point, then we don't support the file we
689        // were given.
690        if (cipher == null || hash == null) {
691            throw new ProbablyNotPKCS8Exception("Unsupported PKCS8 format. oid1=[" + pkcs8.oid1 + "], oid2=[" + pkcs8.oid2 + "]");
692        }
693
694        // In PKCS8 Version 1.5 we need to derive an 8 byte IV.  In those cases
695        // the ASN.1 structure doesn't have the IV, anyway, so I can use that
696        // to decide whether to derive one or not.
697        //
698        // Note:  if AES, then IV has to be 16 bytes.
699        if (pkcs8.iv == null) {
700            ivSize = 64;
701        }
702
703        byte[] salt = pkcs8.salt;
704        int ic = pkcs8.iterationCount;
705
706        // PKCS8 converts the password to a byte[] array using a simple
707        // cast.  This byte[] array is ignored if we're using the PKCS12
708        // key derivation, since that employs a different technique.
709        byte[] pwd = new byte[password.length];
710        for (int i = 0; i < pwd.length; i++) {
711            pwd[i] = (byte) password[i];
712        }
713
714        DerivedKey dk;
715        if (usePKCS12PasswordPadding) {
716            MessageDigest md = MessageDigest.getInstance(hash);
717            dk = deriveKeyPKCS12(password, salt, ic, keySize, ivSize, md);
718        } else {
719            if (isVersion1) {
720                MessageDigest md = MessageDigest.getInstance(hash);
721                dk = deriveKeyV1(pwd, salt, ic, keySize, ivSize, md);
722            } else {
723                Mac mac = Mac.getInstance(hash);
724                dk = deriveKeyV2(pwd, salt, ic, keySize, ivSize, mac);
725            }
726        }
727
728
729        return decrypt(cipher, mode, dk, use2DES, pkcs8.iv, pkcs8.bigPayload);
730    }
731
732
733    public static DerivedKey deriveKeyV1(byte[] password, byte[] salt,
734                                         int iterations, int keySizeInBits,
735                                         int ivSizeInBits, MessageDigest md) {
736        int keySize = keySizeInBits / 8;
737        int ivSize = ivSizeInBits / 8;
738        md.reset();
739        md.update(password);
740        byte[] result = md.digest(salt);
741        for (int i = 1; i < iterations; i++) {
742            // Hash of the hash for each of the iterations.
743            result = md.digest(result);
744        }
745        byte[] key = new byte[keySize];
746        byte[] iv = new byte[ivSize];
747        System.arraycopy(result, 0, key, 0, key.length);
748        System.arraycopy(result, key.length, iv, 0, iv.length);
749        return new DerivedKey(key, iv);
750    }
751
752    public static DerivedKey deriveKeyPKCS12(char[] password, byte[] salt,
753                                             int iterations, int keySizeInBits,
754                                             int ivSizeInBits,
755                                             MessageDigest md) {
756        byte[] pwd;
757        if (password.length > 0) {
758            pwd = new byte[(password.length + 1) * 2];
759            for (int i = 0; i < password.length; i++) {
760                pwd[i * 2] = (byte) (password[i] >>> 8);
761                pwd[i * 2 + 1] = (byte) password[i];
762            }
763        } else {
764            pwd = new byte[0];
765        }
766        int keySize = keySizeInBits / 8;
767        int ivSize = ivSizeInBits / 8;
768        byte[] key = pkcs12(1, keySize, salt, pwd, iterations, md);
769        byte[] iv = pkcs12(2, ivSize, salt, pwd, iterations, md);
770        return new DerivedKey(key, iv);
771    }
772
773    /**
774     * This PKCS12 key derivation code comes from BouncyCastle.
775     *
776     * @param idByte         1 == key, 2 == iv
777     * @param n              keysize or ivsize
778     * @param salt           8 byte salt
779     * @param password       password
780     * @param iterationCount iteration-count
781     * @param md             The message digest to use
782     * @return byte[] the derived key
783     */
784    private static byte[] pkcs12(int idByte, int n, byte[] salt,
785                                 byte[] password, int iterationCount,
786                                 MessageDigest md) {
787        int u = md.getDigestLength();
788        // sha1, md2, md5 all use 512 bits.  But future hashes might not.
789        int v = 512 / 8;
790        md.reset();
791        byte[] D = new byte[v];
792        byte[] dKey = new byte[n];
793        for (int i = 0; i != D.length; i++) {
794            D[i] = (byte) idByte;
795        }
796        byte[] S;
797        if ((salt != null) && (salt.length != 0)) {
798            S = new byte[v * ((salt.length + v - 1) / v)];
799            for (int i = 0; i != S.length; i++) {
800                S[i] = salt[i % salt.length];
801            }
802        } else {
803            S = new byte[0];
804        }
805        byte[] P;
806        if ((password != null) && (password.length != 0)) {
807            P = new byte[v * ((password.length + v - 1) / v)];
808            for (int i = 0; i != P.length; i++) {
809                P[i] = password[i % password.length];
810            }
811        } else {
812            P = new byte[0];
813        }
814        byte[] I = new byte[S.length + P.length];
815        System.arraycopy(S, 0, I, 0, S.length);
816        System.arraycopy(P, 0, I, S.length, P.length);
817        byte[] B = new byte[v];
818        int c = (n + u - 1) / u;
819        for (int i = 1; i <= c; i++) {
820            md.update(D);
821            byte[] result = md.digest(I);
822            for (int j = 1; j != iterationCount; j++) {
823                result = md.digest(result);
824            }
825            for (int j = 0; j != B.length; j++) {
826                B[j] = result[j % result.length];
827            }
828            for (int j = 0; j < (I.length / v); j++) {
829                /*
830                     * add a + b + 1, returning the result in a. The a value is treated
831                     * as a BigInteger of length (b.length * 8) bits. The result is
832                     * modulo 2^b.length in case of overflow.
833                     */
834                int aOff = j * v;
835                int bLast = B.length - 1;
836                int x = (B[bLast] & 0xff) + (I[aOff + bLast] & 0xff) + 1;
837                I[aOff + bLast] = (byte) x;
838                x >>>= 8;
839                for (int k = B.length - 2; k >= 0; k--) {
840                    x += (B[k] & 0xff) + (I[aOff + k] & 0xff);
841                    I[aOff + k] = (byte) x;
842                    x >>>= 8;
843                }
844            }
845            if (i == c) {
846                System.arraycopy(result, 0, dKey, (i - 1) * u, dKey.length - ((i - 1) * u));
847            } else {
848                System.arraycopy(result, 0, dKey, (i - 1) * u, result.length);
849            }
850        }
851        return dKey;
852    }
853
854    public static DerivedKey deriveKeyV2(byte[] password, byte[] salt,
855                                         int iterations, int keySizeInBits,
856                                         int ivSizeInBits, Mac mac)
857        throws InvalidKeyException {
858        int keySize = keySizeInBits / 8;
859        int ivSize = ivSizeInBits / 8;
860
861        // Because we're using an Hmac, we need to initialize with a SecretKey.
862        // HmacSHA1 doesn't need SecretKeySpec's 2nd parameter, hence the "N/A".
863        SecretKeySpec sk = new SecretKeySpec(password, "N/A");
864        mac.init(sk);
865        int macLength = mac.getMacLength();
866        int derivedKeyLength = keySize + ivSize;
867        int blocks = (derivedKeyLength + macLength - 1) / macLength;
868        byte[] blockIndex = new byte[4];
869        byte[] finalResult = new byte[blocks * macLength];
870        for (int i = 1; i <= blocks; i++) {
871            int offset = (i - 1) * macLength;
872            blockIndex[0] = (byte) (i >>> 24);
873            blockIndex[1] = (byte) (i >>> 16);
874            blockIndex[2] = (byte) (i >>> 8);
875            blockIndex[3] = (byte) i;
876            mac.reset();
877            mac.update(salt);
878            byte[] result = mac.doFinal(blockIndex);
879            System.arraycopy(result, 0, finalResult, offset, result.length);
880            for (int j = 1; j < iterations; j++) {
881                mac.reset();
882                result = mac.doFinal(result);
883                for (int k = 0; k < result.length; k++) {
884                    finalResult[offset + k] ^= result[k];
885                }
886            }
887        }
888        byte[] key = new byte[keySize];
889        byte[] iv = new byte[ivSize];
890        System.arraycopy(finalResult, 0, key, 0, key.length);
891        System.arraycopy(finalResult, key.length, iv, 0, iv.length);
892        return new DerivedKey(key, iv);
893    }
894
895    public static byte[] formatAsPKCS8(byte[] privateKey, String oid,
896                                       ASN1Structure pkcs8) {
897        DERInteger derZero = new DERInteger(BigInteger.ZERO);
898        ASN1EncodableVector outterVec = new ASN1EncodableVector();
899        ASN1EncodableVector innerVec = new ASN1EncodableVector();
900        DEROctetString octetsToAppend;
901        try {
902            DERObjectIdentifier derOID = new DERObjectIdentifier(oid);
903            innerVec.add(derOID);
904            if (DSA_OID.equals(oid)) {
905                if (pkcs8 == null) {
906                    try {
907                        pkcs8 = ASN1Util.analyze(privateKey);
908                    }
909                    catch (Exception e) {
910                        throw new RuntimeException("asn1 parse failure " + e);
911                    }
912                }
913                if (pkcs8.derIntegers == null || pkcs8.derIntegers.size() < 6) {
914                    throw new RuntimeException("invalid DSA key - can't find P, Q, G, X");
915                }
916
917                DERInteger[] ints = new DERInteger[pkcs8.derIntegers.size()];
918                pkcs8.derIntegers.toArray(ints);
919                DERInteger p = ints[1];
920                DERInteger q = ints[2];
921                DERInteger g = ints[3];
922                DERInteger x = ints[5];
923
924                byte[] encodedX = encode(x);
925                octetsToAppend = new DEROctetString(encodedX);
926                ASN1EncodableVector pqgVec = new ASN1EncodableVector();
927                pqgVec.add(p);
928                pqgVec.add(q);
929                pqgVec.add(g);
930                DERSequence pqg = new DERSequence(pqgVec);
931                innerVec.add(pqg);
932            } else {
933                innerVec.add(DERNull.INSTANCE);
934                octetsToAppend = new DEROctetString(privateKey);
935            }
936
937            DERSequence inner = new DERSequence(innerVec);
938            outterVec.add(derZero);
939            outterVec.add(inner);
940            outterVec.add(octetsToAppend);
941            DERSequence outter = new DERSequence(outterVec);
942            return encode(outter);
943        }
944        catch (IOException ioe) {
945            throw JavaImpl.newRuntimeException(ioe);
946        }
947    }
948
949    private static boolean allZeroes(byte[] b) {
950        for (int i = 0; i < b.length; i++) {
951            if (b[i] != 0) {
952                return false;
953            }
954        }
955        return true;
956    }
957
958    public static byte[] encode(DEREncodable der) throws IOException {
959        ByteArrayOutputStream baos = new ByteArrayOutputStream(1024);
960        ASN1OutputStream out = new ASN1OutputStream(baos);
961        out.writeObject(der);
962        out.close();
963        return baos.toByteArray();
964    }
965
966    public static void main(String[] args) throws Exception {
967        String password = "changeit";
968        if (args.length == 0) {
969            System.out.println("Usage1:  [password] [file:private-key]      Prints decrypted PKCS8 key (base64).");
970            System.out.println("Usage2:  [password] [file1] [file2] etc...  Checks that all private keys are equal.");
971            System.out.println("Usage2 assumes that all files can be decrypted with the same password.");
972        } else if (args.length == 1 || args.length == 2) {
973            FileInputStream in = new FileInputStream(args[args.length - 1]);
974            if (args.length == 2) {
975                password = args[0];
976            }
977            byte[] bytes = Util.streamToBytes(in);
978            PKCS8Key key = new PKCS8Key(bytes, password.toCharArray());
979            PEMItem item = new PEMItem(key.getDecryptedBytes(), "PRIVATE KEY");
980            byte[] pem = PEMUtil.encode(Collections.singleton(item));
981            System.out.write(pem);
982        } else {
983            byte[] original = null;
984            File f = new File(args[0]);
985            int i = 0;
986            if (!f.exists()) {
987                // File0 doesn't exist, so it must be a password!
988                password = args[0];
989                i++;
990            }
991            for (; i < args.length; i++) {
992                FileInputStream in = new FileInputStream(args[i]);
993                byte[] bytes = Util.streamToBytes(in);
994                PKCS8Key key = null;
995                try {
996                    key = new PKCS8Key(bytes, password.toCharArray());
997                }
998                catch (Exception e) {
999                    System.out.println(" FAILED! " + args[i] + " " + e);
1000                }
1001                if (key != null) {
1002                    byte[] decrypted = key.getDecryptedBytes();
1003                    int keySize = key.getKeySize();
1004                    String keySizeStr = "" + keySize;
1005                    if (keySize < 10) {
1006                        keySizeStr = "  " + keySizeStr;
1007                    } else if (keySize < 100) {
1008                        keySizeStr = " " + keySizeStr;
1009                    }
1010                    StringBuffer buf = new StringBuffer(key.getTransformation());
1011                    int maxLen = "Blowfish/CBC/PKCS5Padding".length();
1012                    for (int j = buf.length(); j < maxLen; j++) {
1013                        buf.append(' ');
1014                    }
1015                    String transform = buf.toString();
1016                    String type = key.isDSA() ? "DSA" : "RSA";
1017
1018                    if (original == null) {
1019                        original = decrypted;
1020                        System.out.println("   SUCCESS    \t" + type + "\t" + transform + "\t" + keySizeStr + "\t" + args[i]);
1021                    } else {
1022                        boolean identical = Arrays.equals(original, decrypted);
1023                        if (!identical) {
1024                            System.out.println("***FAILURE*** \t" + type + "\t" + transform + "\t" + keySizeStr + "\t" + args[i]);
1025                        } else {
1026                            System.out.println("   SUCCESS    \t" + type + "\t" + transform + "\t" + keySizeStr + "\t" + args[i]);
1027                        }
1028                    }
1029                }
1030            }
1031        }
1032    }
1033
1034}