1use crate::{EstError, EstResult};
14use base64::Engine;
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19pub enum MlKemLevel {
20 Level512,
22 Level768,
24 Level1024,
26}
27
28impl MlKemLevel {
29 pub fn oid(&self) -> &'static str {
31 match self {
32 Self::Level512 => "2.16.840.1.101.3.4.4.1",
33 Self::Level768 => "2.16.840.1.101.3.4.4.2",
34 Self::Level1024 => "2.16.840.1.101.3.4.4.3",
35 }
36 }
37
38 pub fn level(&self) -> u16 {
40 match self {
41 Self::Level512 => 512,
42 Self::Level768 => 768,
43 Self::Level1024 => 1024,
44 }
45 }
46
47 pub fn from_level(level: u16) -> EstResult<Self> {
49 match level {
50 512 => Ok(Self::Level512),
51 768 => Ok(Self::Level768),
52 1024 => Ok(Self::Level1024),
53 _ => Err(EstError::UnsupportedAlgorithm(format!("ML-KEM-{level}"))),
54 }
55 }
56
57 pub fn from_oid(oid: &str) -> EstResult<Self> {
59 match oid {
60 "2.16.840.1.101.3.4.4.1" => Ok(Self::Level512),
61 "2.16.840.1.101.3.4.4.2" => Ok(Self::Level768),
62 "2.16.840.1.101.3.4.4.3" => Ok(Self::Level1024),
63 _ => Err(EstError::UnsupportedAlgorithm(oid.to_string())),
64 }
65 }
66}
67
68#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
74pub struct MlKemKeyGenHint {
75 pub level: MlKemLevel,
77}
78
79impl MlKemKeyGenHint {
80 pub fn new(level: MlKemLevel) -> Self {
82 Self { level }
83 }
84
85 pub fn level_512() -> Self {
87 Self::new(MlKemLevel::Level512)
88 }
89
90 pub fn level_768() -> Self {
92 Self::new(MlKemLevel::Level768)
93 }
94
95 pub fn level_1024() -> Self {
97 Self::new(MlKemLevel::Level1024)
98 }
99
100 pub fn validate_supported(&self, supported_levels: &[MlKemLevel]) -> EstResult<()> {
107 if !supported_levels.contains(&self.level) {
108 return Err(EstError::MlKemLevelMismatch {
109 requested: self.level.level(),
110 supported: supported_levels.first().map(|l| l.level()).unwrap_or(0),
111 });
112 }
113 Ok(())
114 }
115}
116
/// Version tag carried alongside a [`Pkcs8PrivateKey`].
///
/// NOTE(review): presumably mirrors the RFC 5958 `Version` field
/// (`v1` = `PrivateKeyInfo`, `v2` = `OneAsymmetricKey` with a public key) — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Pkcs8Version {
    /// Version 1.
    V1,
    /// Version 2.
    V2,
}
129
130#[derive(Debug, Clone, PartialEq, Eq)]
147pub struct Pkcs8PrivateKey {
148 der: Vec<u8>,
150 algorithm_oid: String,
152 version: Pkcs8Version,
154}
155
156impl Pkcs8PrivateKey {
157 pub fn new(der: Vec<u8>, algorithm_oid: String, version: Pkcs8Version) -> Self {
165 Self {
166 der,
167 algorithm_oid,
168 version,
169 }
170 }
171
172 pub fn v1(der: Vec<u8>, algorithm_oid: String) -> Self {
174 Self::new(der, algorithm_oid, Pkcs8Version::V1)
175 }
176
177 pub fn v2(der: Vec<u8>, algorithm_oid: String) -> Self {
179 Self::new(der, algorithm_oid, Pkcs8Version::V2)
180 }
181
182 pub fn der(&self) -> &[u8] {
184 &self.der
185 }
186
187 pub fn algorithm_oid(&self) -> &str {
189 &self.algorithm_oid
190 }
191
192 pub fn version(&self) -> Pkcs8Version {
194 self.version
195 }
196
197 pub fn validate_algorithm(&self, expected_oid: &str) -> EstResult<()> {
204 if self.algorithm_oid != expected_oid {
205 return Err(EstError::UnsupportedAlgorithm(format!(
206 "key algorithm OID mismatch: expected {expected_oid}, got {}",
207 self.algorithm_oid
208 )));
209 }
210 Ok(())
211 }
212
213 pub fn validate(&self) -> EstResult<()> {
215 if self.der.is_empty() {
216 return Err(EstError::InvalidPkcs8("empty PKCS#8 key".to_string()));
217 }
218 if self.der[0] != 0x30 {
219 return Err(EstError::InvalidPkcs8(
220 "invalid DER: expected SEQUENCE tag".to_string(),
221 ));
222 }
223 if self.algorithm_oid.is_empty() {
224 return Err(EstError::InvalidPkcs8("missing algorithm OID".to_string()));
225 }
226 Ok(())
227 }
228
229 pub fn to_enveloped_data(&self, recipient_cert_der: &[u8]) -> EstResult<Vec<u8>> {
245 if recipient_cert_der.is_empty() {
246 return Err(EstError::InvalidPkcs7(
247 "empty recipient certificate for EnvelopedData".to_string(),
248 ));
249 }
250 Ok(self.der.clone())
254 }
255}
256
257#[derive(Debug, Clone, PartialEq, Eq)]
269pub struct EncryptedPrivateKey {
270 der: Vec<u8>,
272 encryption_algorithm_oid: String,
274}
275
276impl EncryptedPrivateKey {
277 pub fn new(der: Vec<u8>, encryption_algorithm_oid: String) -> Self {
279 Self {
280 der,
281 encryption_algorithm_oid,
282 }
283 }
284
285 pub fn der(&self) -> &[u8] {
287 &self.der
288 }
289
290 pub fn encryption_algorithm_oid(&self) -> &str {
292 &self.encryption_algorithm_oid
293 }
294
295 pub fn validate(&self) -> EstResult<()> {
297 if self.der.is_empty() {
298 return Err(EstError::InvalidPkcs8(
299 "empty EncryptedPrivateKeyInfo".to_string(),
300 ));
301 }
302 if self.der[0] != 0x30 {
303 return Err(EstError::InvalidPkcs8(
304 "invalid DER: expected SEQUENCE tag".to_string(),
305 ));
306 }
307 Ok(())
308 }
309}
310
311#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
321pub struct ServerKeygenRequest {
322 #[serde(with = "serde_bytes")]
324 csr_der: Vec<u8>,
325
326 #[serde(skip_serializing_if = "Option::is_none")]
328 ml_kem_hint: Option<MlKemKeyGenHint>,
329}
330
331impl ServerKeygenRequest {
332 pub fn new(csr_der: Vec<u8>, ml_kem_hint: Option<MlKemKeyGenHint>) -> Self {
334 Self {
335 csr_der,
336 ml_kem_hint,
337 }
338 }
339
340 pub fn with_ml_kem(csr_der: Vec<u8>, level: MlKemLevel) -> Self {
342 Self::new(csr_der, Some(MlKemKeyGenHint::new(level)))
343 }
344
345 pub fn csr_der(&self) -> &[u8] {
347 &self.csr_der
348 }
349
350 pub fn ml_kem_hint(&self) -> Option<&MlKemKeyGenHint> {
352 self.ml_kem_hint.as_ref()
353 }
354
355 pub fn to_base64(&self) -> String {
357 base64::engine::general_purpose::STANDARD.encode(&self.csr_der)
358 }
359
360 pub fn from_base64(base64_data: &str) -> EstResult<Self> {
362 let csr_der = base64::engine::general_purpose::STANDARD
363 .decode(base64_data)
364 .map_err(|e| EstError::InvalidBase64(e.to_string()))?;
365
366 Ok(Self::new(csr_der, None))
367 }
368
369 pub fn validate(&self) -> EstResult<()> {
371 if self.csr_der.is_empty() {
372 return Err(EstError::InvalidPkcs10("Empty CSR".to_string()));
373 }
374
375 if self.csr_der[0] != 0x30 {
376 return Err(EstError::InvalidPkcs10(
377 "Invalid DER: expected SEQUENCE tag".to_string(),
378 ));
379 }
380
381 if self.csr_der.len() < 100 {
382 return Err(EstError::InvalidPkcs10(format!(
383 "CSR too small: {} bytes",
384 self.csr_der.len()
385 )));
386 }
387
388 Ok(())
389 }
390}
391
392#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
401pub struct ServerKeygenResponse {
402 #[serde(with = "serde_bytes")]
404 cert_pkcs7_der: Vec<u8>,
405
406 #[serde(with = "serde_bytes")]
408 key_pkcs8_der: Vec<u8>,
409
410 boundary: String,
412}
413
414impl ServerKeygenResponse {
415 pub fn new(cert_pkcs7_der: Vec<u8>, key_pkcs8_der: Vec<u8>, boundary: String) -> Self {
417 Self {
418 cert_pkcs7_der,
419 key_pkcs8_der,
420 boundary,
421 }
422 }
423
424 pub fn with_default_boundary(cert_pkcs7_der: Vec<u8>, key_pkcs8_der: Vec<u8>) -> Self {
426 Self::new(
427 cert_pkcs7_der,
428 key_pkcs8_der,
429 crate::content_type::DEFAULT_BOUNDARY.to_string(),
430 )
431 }
432
433 pub fn cert_pkcs7_der(&self) -> &[u8] {
435 &self.cert_pkcs7_der
436 }
437
438 pub fn key_pkcs8_der(&self) -> &[u8] {
440 &self.key_pkcs8_der
441 }
442
443 pub fn boundary(&self) -> &str {
445 &self.boundary
446 }
447
448 pub fn to_multipart_body(&self) -> String {
453 let cert_b64 = base64::engine::general_purpose::STANDARD.encode(&self.cert_pkcs7_der);
454 let key_b64 = base64::engine::general_purpose::STANDARD.encode(&self.key_pkcs8_der);
455
456 format!(
457 "--{boundary}\r\n\
458 Content-Type: application/pkcs7-mime\r\n\
459 Content-Transfer-Encoding: base64\r\n\
460 \r\n\
461 {cert_b64}\r\n\
462 --{boundary}\r\n\
463 Content-Type: application/pkcs8\r\n\
464 Content-Transfer-Encoding: base64\r\n\
465 \r\n\
466 {key_b64}\r\n\
467 --{boundary}--\r\n",
468 boundary = self.boundary,
469 cert_b64 = cert_b64,
470 key_b64 = key_b64
471 )
472 }
473
474 pub fn from_multipart_body(body: &str, boundary: &str) -> EstResult<Self> {
480 let boundary_line = format!("--{boundary}");
481 let end_boundary = format!("--{boundary}--");
482
483 let parts: Vec<&str> = body.split(&boundary_line).collect();
484
485 if parts.len() < 3 {
486 return Err(EstError::InvalidMultipart(
487 "Expected at least 2 MIME parts".to_string(),
488 ));
489 }
490
491 let cert_part = parts[1];
493 let cert_b64 = Self::extract_base64_content(cert_part, "application/pkcs7-mime")?;
494 let cert_pkcs7_der = base64::engine::general_purpose::STANDARD
495 .decode(cert_b64.trim())
496 .map_err(|e| EstError::InvalidBase64(e.to_string()))?;
497
498 let key_part = parts[2].trim_end_matches(&end_boundary).trim();
500 let key_b64 = Self::extract_base64_content(key_part, "application/pkcs8")?;
501 let key_pkcs8_der = base64::engine::general_purpose::STANDARD
502 .decode(key_b64.trim())
503 .map_err(|e| EstError::InvalidBase64(e.to_string()))?;
504
505 Ok(Self::new(
506 cert_pkcs7_der,
507 key_pkcs8_der,
508 boundary.to_string(),
509 ))
510 }
511
512 fn extract_base64_content<'a>(part: &'a str, expected_ct: &str) -> EstResult<&'a str> {
514 if !part.contains(expected_ct) {
515 return Err(EstError::InvalidMultipart(format!(
516 "Expected Content-Type: {expected_ct}"
517 )));
518 }
519
520 let body_start = part
522 .find("\r\n\r\n")
523 .or_else(|| part.find("\n\n"))
524 .ok_or_else(|| {
525 EstError::InvalidMultipart("Missing blank line after headers".to_string())
526 })?;
527
528 Ok(part[body_start..].trim())
529 }
530
531 pub fn validate(&self) -> EstResult<()> {
533 if self.cert_pkcs7_der.is_empty() {
535 return Err(EstError::InvalidPkcs7(
536 "Empty certificate PKCS#7".to_string(),
537 ));
538 }
539
540 if self.cert_pkcs7_der[0] != 0x30 {
541 return Err(EstError::InvalidPkcs7(
542 "Invalid DER: expected SEQUENCE tag".to_string(),
543 ));
544 }
545
546 if self.key_pkcs8_der.is_empty() {
548 return Err(EstError::InvalidPkcs8("Empty PKCS#8 key".to_string()));
549 }
550
551 if self.key_pkcs8_der[0] != 0x30 {
552 return Err(EstError::InvalidPkcs8(
553 "Invalid DER: expected SEQUENCE tag".to_string(),
554 ));
555 }
556
557 Ok(())
558 }
559}
560
561mod serde_bytes {
563 use serde::{Deserialize, Deserializer, Serializer};
564
565 pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
566 where
567 S: Serializer,
568 {
569 serializer.serialize_bytes(bytes)
570 }
571
572 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
573 where
574 D: Deserializer<'de>,
575 {
576 Vec::<u8>::deserialize(deserializer)
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    const OID_512: &str = "2.16.840.1.101.3.4.4.1";
    const OID_768: &str = "2.16.840.1.101.3.4.4.2";
    const OID_1024: &str = "2.16.840.1.101.3.4.4.3";
    /// Encryption-scheme OID used by the encrypted-key test.
    const PBES2_OID: &str = "1.2.840.113549.1.5.13";

    /// Returns `header` followed by `zeros` zero bytes.
    fn padded(header: &[u8], zeros: usize) -> Vec<u8> {
        let mut der = header.to_vec();
        der.resize(header.len() + zeros, 0x00);
        der
    }

    /// 256-byte SEQUENCE-tagged blob standing in for a CSR or certificate bundle.
    fn large_der() -> Vec<u8> {
        padded(&[0x30, 0x82, 0x01, 0x00], 252)
    }

    /// 128-byte SEQUENCE-tagged blob standing in for a private key.
    fn key_der() -> Vec<u8> {
        padded(&[0x30, 0x82, 0x00, 0x80], 124)
    }

    /// 4-byte SEQUENCE header used by the PKCS#8 wrapper tests.
    fn tiny_der() -> Vec<u8> {
        vec![0x30, 0x82, 0x00, 0x10]
    }

    #[test]
    fn test_ml_kem_levels() {
        let cases: [(MlKemLevel, &str, u16); 3] = [
            (MlKemLevel::Level512, OID_512, 512),
            (MlKemLevel::Level768, OID_768, 768),
            (MlKemLevel::Level1024, OID_1024, 1024),
        ];
        for &(variant, oid, bits) in &cases {
            assert_eq!(variant.oid(), oid);
            assert_eq!(variant.level(), bits);
        }
    }

    #[test]
    fn test_ml_kem_from_level() {
        let cases: [(u16, MlKemLevel); 3] = [
            (512, MlKemLevel::Level512),
            (768, MlKemLevel::Level768),
            (1024, MlKemLevel::Level1024),
        ];
        for &(bits, expected) in &cases {
            assert_eq!(MlKemLevel::from_level(bits).unwrap(), expected);
        }
        assert!(MlKemLevel::from_level(256).is_err());
    }

    #[test]
    fn test_ml_kem_from_oid() {
        let cases: [(&str, MlKemLevel); 3] = [
            (OID_512, MlKemLevel::Level512),
            (OID_768, MlKemLevel::Level768),
            (OID_1024, MlKemLevel::Level1024),
        ];
        for &(oid, expected) in &cases {
            assert_eq!(MlKemLevel::from_oid(oid).unwrap(), expected);
        }
        assert!(MlKemLevel::from_oid("1.2.3.4.5").is_err());
    }

    #[test]
    fn test_ml_kem_hint() {
        let hint = MlKemKeyGenHint::level_768();
        assert_eq!(hint.level, MlKemLevel::Level768);

        let supported = [MlKemLevel::Level512, MlKemLevel::Level768];
        assert!(hint.validate_supported(&supported).is_ok());

        let unsupported = [MlKemLevel::Level512];
        assert!(matches!(
            hint.validate_supported(&unsupported),
            Err(EstError::MlKemLevelMismatch { .. })
        ));
    }

    #[test]
    fn test_serverkeygen_request() {
        let csr = large_der();
        let request = ServerKeygenRequest::with_ml_kem(csr.clone(), MlKemLevel::Level1024);

        assert_eq!(request.csr_der(), &csr);
        assert_eq!(
            request.ml_kem_hint().map(|hint| hint.level),
            Some(MlKemLevel::Level1024)
        );
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_multipart_roundtrip() {
        let cert = large_der();
        let key = key_der();
        let response = ServerKeygenResponse::with_default_boundary(cert.clone(), key.clone());

        let rendered = response.to_multipart_body();
        assert!(rendered.contains("application/pkcs7-mime"));
        assert!(rendered.contains("application/pkcs8"));

        let reparsed = ServerKeygenResponse::from_multipart_body(&rendered, response.boundary())
            .unwrap();
        assert_eq!(reparsed.cert_pkcs7_der(), &cert);
        assert_eq!(reparsed.key_pkcs8_der(), &key);
    }

    #[test]
    fn test_validate_response() {
        let response = ServerKeygenResponse::with_default_boundary(large_der(), key_der());
        assert!(response.validate().is_ok());
    }

    #[test]
    fn test_validate_empty_cert() {
        let response = ServerKeygenResponse::with_default_boundary(Vec::new(), vec![0x30, 0x00]);
        assert!(matches!(response.validate(), Err(EstError::InvalidPkcs7(_))));
    }

    #[test]
    fn test_validate_empty_key() {
        let response = ServerKeygenResponse::with_default_boundary(vec![0x30, 0x00], Vec::new());
        assert!(matches!(response.validate(), Err(EstError::InvalidPkcs8(_))));
    }

    #[test]
    fn test_pkcs8_private_key_v1() {
        let der = tiny_der();
        let key = Pkcs8PrivateKey::v1(der.clone(), OID_768.to_string());
        assert_eq!(key.der(), &der);
        assert_eq!(key.algorithm_oid(), OID_768);
        assert_eq!(key.version(), Pkcs8Version::V1);
    }

    #[test]
    fn test_pkcs8_private_key_v2() {
        let key = Pkcs8PrivateKey::v2(tiny_der(), OID_1024.to_string());
        assert_eq!(key.version(), Pkcs8Version::V2);
    }

    #[test]
    fn test_pkcs8_validate_algorithm_match() {
        let key = Pkcs8PrivateKey::v1(tiny_der(), OID_768.to_string());
        assert!(key.validate_algorithm(OID_768).is_ok());
    }

    #[test]
    fn test_pkcs8_validate_algorithm_mismatch() {
        let key = Pkcs8PrivateKey::v1(tiny_der(), OID_768.to_string());
        assert!(matches!(
            key.validate_algorithm(OID_512),
            Err(EstError::UnsupportedAlgorithm(_))
        ));
    }

    #[test]
    fn test_pkcs8_validate_structure() {
        let well_formed = Pkcs8PrivateKey::v1(tiny_der(), OID_512.to_string());
        assert!(well_formed.validate().is_ok());

        let rejected = [
            Pkcs8PrivateKey::v1(Vec::new(), OID_512.to_string()),
            Pkcs8PrivateKey::v1(vec![0x31, 0x00], OID_512.to_string()),
            Pkcs8PrivateKey::v1(vec![0x30, 0x00], String::new()),
        ];
        for key in &rejected {
            assert!(matches!(key.validate(), Err(EstError::InvalidPkcs8(_))));
        }
    }

    #[test]
    fn test_pkcs8_to_enveloped_data() {
        let der = tiny_der();
        let key = Pkcs8PrivateKey::v1(der.clone(), OID_768.to_string());

        let wrapped = key.to_enveloped_data(&[0x30, 0x00]);
        assert!(wrapped.is_ok());
        assert_eq!(wrapped.unwrap(), der);

        assert!(matches!(
            key.to_enveloped_data(&[]),
            Err(EstError::InvalidPkcs7(_))
        ));
    }

    #[test]
    fn test_encrypted_private_key() {
        let der = tiny_der();
        let encrypted = EncryptedPrivateKey::new(der.clone(), PBES2_OID.to_string());
        assert_eq!(encrypted.der(), &der);
        assert_eq!(encrypted.encryption_algorithm_oid(), PBES2_OID);
        assert!(encrypted.validate().is_ok());

        let empty = EncryptedPrivateKey::new(Vec::new(), PBES2_OID.to_string());
        assert!(matches!(empty.validate(), Err(EstError::InvalidPkcs8(_))));
    }
}