1use crate::{EstError, EstResult};
12use base64::Engine;
13use serde::{Deserialize, Serialize};
14
/// Dotted-decimal object identifiers for the three ML-DSA parameter sets.
///
/// All three share the arc `2.16.840.1.101.3.4.3` and differ only in the
/// final component.
pub mod ml_dsa_oids {
    /// OID string for ML-DSA-44.
    pub const ML_DSA_44: &str = "2.16.840.1.101.3.4.3.17";
    /// OID string for ML-DSA-65.
    pub const ML_DSA_65: &str = "2.16.840.1.101.3.4.3.18";
    /// OID string for ML-DSA-87.
    pub const ML_DSA_87: &str = "2.16.840.1.101.3.4.3.19";
}
24
/// Dotted-decimal object identifiers for the three ML-KEM parameter sets.
///
/// All three share the arc `2.16.840.1.101.3.4.4` and differ only in the
/// final component.
pub mod ml_kem_oids {
    /// OID string for ML-KEM-512.
    pub const ML_KEM_512: &str = "2.16.840.1.101.3.4.4.1";
    /// OID string for ML-KEM-768.
    pub const ML_KEM_768: &str = "2.16.840.1.101.3.4.4.2";
    /// OID string for ML-KEM-1024.
    pub const ML_KEM_1024: &str = "2.16.840.1.101.3.4.4.3";
}
34
/// Dotted-decimal object identifiers for the classical public-key algorithms.
pub mod traditional_oids {
    /// OID string for RSA public keys (`rsaEncryption`).
    pub const RSA: &str = "1.2.840.113549.1.1.1";
    /// OID string for generic elliptic-curve public keys (`id-ecPublicKey`).
    ///
    /// This OID does not identify the curve; see `named_curve_oids` for that.
    pub const EC_PUBLIC_KEY: &str = "1.2.840.10045.2.1";
}
42
/// Dotted-decimal object identifiers for the named elliptic curves that
/// `KeyAlgorithm::from_ec_oid` recognises.
pub mod named_curve_oids {
    /// OID string for the NIST P-256 curve.
    pub const P256: &str = "1.2.840.10045.3.1.7";
    /// OID string for the NIST P-384 curve.
    pub const P384: &str = "1.3.132.0.34";
}
50
/// Base OID arc (dotted-decimal) under which composite ML-DSA algorithm
/// identifiers are expected to live.
///
/// Nothing in the visible part of this file reads this constant.
/// NOTE(review): composite-signature OID assignments have moved between draft
/// revisions; confirm this arc against the revision this crate targets.
pub const COMPOSITE_ML_DSA_BASE: &str = "2.16.840.1.114027.80.5.2";
55
/// Public-key algorithm associated with a certification request.
///
/// Values are produced from dotted-decimal OID strings by
/// `KeyAlgorithm::from_oid` and `KeyAlgorithm::from_ec_oid`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KeyAlgorithm {
    /// RSA.
    Rsa,
    /// Elliptic-curve key on P-256. Also what `from_oid` returns for the
    /// generic EC public-key OID, which carries no curve information.
    EcdsaP256,
    /// Elliptic-curve key on P-384 (only produced by `from_ec_oid`).
    EcdsaP384,
    /// ML-DSA-44.
    MlDsa44,
    /// ML-DSA-65.
    MlDsa65,
    /// ML-DSA-87.
    MlDsa87,
    /// ML-KEM-512.
    MlKem512,
    /// ML-KEM-768.
    MlKem768,
    /// ML-KEM-1024.
    MlKem1024,
    /// Any OID that was not recognised. Holds the OID string as given to
    /// `from_oid`, or `"ec-unknown-curve:<oid>"` when produced by `from_ec_oid`.
    Unknown(String),
}
84
85impl KeyAlgorithm {
86 pub fn from_oid(oid: &str) -> Self {
91 match oid {
92 "1.2.840.113549.1.1.1" => Self::Rsa,
93 "1.2.840.10045.2.1" => Self::EcdsaP256, "2.16.840.1.101.3.4.3.17" => Self::MlDsa44,
95 "2.16.840.1.101.3.4.3.18" => Self::MlDsa65,
96 "2.16.840.1.101.3.4.3.19" => Self::MlDsa87,
97 "2.16.840.1.101.3.4.4.1" => Self::MlKem512,
98 "2.16.840.1.101.3.4.4.2" => Self::MlKem768,
99 "2.16.840.1.101.3.4.4.3" => Self::MlKem1024,
100 other => Self::Unknown(other.to_string()),
101 }
102 }
103
104 pub fn from_ec_oid(curve_oid: &str) -> Self {
109 match curve_oid {
110 "1.2.840.10045.3.1.7" => Self::EcdsaP256,
111 "1.3.132.0.34" => Self::EcdsaP384,
112 other => Self::Unknown(format!("ec-unknown-curve:{other}")),
113 }
114 }
115
116 pub fn oid(&self) -> &str {
118 match self {
119 Self::Rsa => "1.2.840.113549.1.1.1",
120 Self::EcdsaP256 | Self::EcdsaP384 => "1.2.840.10045.2.1",
121 Self::MlDsa44 => "2.16.840.1.101.3.4.3.17",
122 Self::MlDsa65 => "2.16.840.1.101.3.4.3.18",
123 Self::MlDsa87 => "2.16.840.1.101.3.4.3.19",
124 Self::MlKem512 => "2.16.840.1.101.3.4.4.1",
125 Self::MlKem768 => "2.16.840.1.101.3.4.4.2",
126 Self::MlKem1024 => "2.16.840.1.101.3.4.4.3",
127 Self::Unknown(oid) => oid.as_str(),
128 }
129 }
130}
131
/// In-memory view of the fields of a PKCS#10 certification request.
///
/// Nothing in this file parses a CSR into this struct; the only constructor
/// visible here, `EnrollRequest::to_certification_request`, fills every field
/// with a placeholder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CertificationRequest {
    /// Request version number.
    pub version: u8,
    /// Subject name as a string (exact formatting is decided by the producer).
    pub subject: String,
    /// Algorithm of the subject public key.
    pub key_algorithm: KeyAlgorithm,
    /// Subject public key bytes (presumably the DER `SubjectPublicKeyInfo` — confirm with the producer).
    pub subject_public_key_info: Vec<u8>,
    /// Signature algorithm identifier as a string.
    pub signature_algorithm: String,
    /// Raw signature bytes; must be non-empty for `verify_self_signature` to succeed.
    pub signature: Vec<u8>,
    /// Requested subject alternative names.
    pub subject_alt_names: Vec<String>,
    /// Requested key usages.
    pub key_usage: Vec<String>,
    /// Value of the `challengePassword` attribute, if the request carried one.
    pub challenge_password: Option<String>,
    /// Bytes the signature is meant to cover (the `CertificationRequestInfo`);
    /// must be non-empty for `verify_self_signature` to succeed.
    pub tbs_der: Vec<u8>,
}
172
173impl CertificationRequest {
174 pub fn verify_self_signature(&self) -> EstResult<()> {
186 if self.tbs_der.is_empty() {
187 return Err(EstError::InvalidPkcs10(
188 "empty CertificationRequestInfo for signature verification".to_string(),
189 ));
190 }
191 if self.signature.is_empty() {
192 return Err(EstError::InvalidPkcs10(
193 "empty signature in CSR".to_string(),
194 ));
195 }
196 Ok(())
199 }
200
201 pub fn validate_challenge_password(&self, expected: &str) -> EstResult<()> {
207 match &self.challenge_password {
208 Some(pw) if pw == expected => Ok(()),
209 Some(pw) => Err(EstError::InvalidPop(format!(
210 "challengePassword mismatch: expected {expected:?}, got {pw:?}",
211 ))),
212 None => Err(EstError::MissingField(
213 "challengePassword attribute".to_string(),
214 )),
215 }
216 }
217}
218
/// Enrollment request wrapping the raw DER bytes of a certificate signing
/// request.
///
/// The bytes are stored as given; call `validate` for the basic sanity
/// checks this type offers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EnrollRequest {
    // (De)serialized through the private `serde_bytes` helper module defined
    // in this file rather than as a plain sequence of integers.
    #[serde(with = "serde_bytes")]
    csr_der: Vec<u8>,
}
237
238impl EnrollRequest {
239 pub fn new(csr_der: Vec<u8>) -> Self {
241 Self { csr_der }
242 }
243
244 pub fn csr_der(&self) -> &[u8] {
246 &self.csr_der
247 }
248
249 pub fn into_csr_der(self) -> Vec<u8> {
251 self.csr_der
252 }
253
254 pub fn to_base64(&self) -> String {
256 base64::engine::general_purpose::STANDARD.encode(&self.csr_der)
257 }
258
259 pub fn from_base64(base64_data: &str) -> EstResult<Self> {
261 let csr_der = base64::engine::general_purpose::STANDARD
262 .decode(base64_data)
263 .map_err(|e| EstError::InvalidBase64(e.to_string()))?;
264
265 Ok(Self::new(csr_der))
266 }
267
268 pub fn validate(&self) -> EstResult<()> {
273 if self.csr_der.is_empty() {
274 return Err(EstError::InvalidPkcs10("Empty CSR".to_string()));
275 }
276
277 if self.csr_der[0] != 0x30 {
279 return Err(EstError::InvalidPkcs10(
280 "Invalid DER: expected SEQUENCE tag".to_string(),
281 ));
282 }
283
284 if self.csr_der.len() < 100 {
286 return Err(EstError::InvalidPkcs10(format!(
287 "CSR too small: {} bytes",
288 self.csr_der.len()
289 )));
290 }
291
292 Ok(())
293 }
294
295 pub fn detect_signature_algorithm(&self) -> Option<String> {
304 None
307 }
308
309 pub fn contains_ml_dsa(&self) -> bool {
313 let ml_dsa_prefix = b"\x06\x0b\x60\x86\x48\x01\x65\x03\x04\x03"; self.csr_der
315 .windows(ml_dsa_prefix.len())
316 .any(|w| w == ml_dsa_prefix)
317 }
318
319 pub fn contains_ml_kem(&self) -> bool {
323 let ml_kem_prefix = b"\x06\x0b\x60\x86\x48\x01\x65\x03\x04\x04"; self.csr_der
325 .windows(ml_kem_prefix.len())
326 .any(|w| w == ml_kem_prefix)
327 }
328
329 pub fn to_certification_request(&self) -> CertificationRequest {
338 CertificationRequest {
339 version: 0, subject: String::new(),
341 key_algorithm: KeyAlgorithm::Unknown("unparsed".to_string()),
342 subject_public_key_info: Vec::new(),
343 signature_algorithm: String::new(),
344 signature: Vec::new(),
345 subject_alt_names: Vec::new(),
346 key_usage: Vec::new(),
347 challenge_password: None,
348 tbs_der: self.csr_der.clone(),
349 }
350 }
351}
352
/// Enrollment response wrapping the raw DER bytes of a PKCS#7 structure.
///
/// The bytes are stored as given; call `validate` for the basic sanity
/// checks this type offers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EnrollResponse {
    // (De)serialized through the private `serde_bytes` helper module defined
    // in this file rather than as a plain sequence of integers.
    #[serde(with = "serde_bytes")]
    pkcs7_der: Vec<u8>,
}
367
368impl EnrollResponse {
369 pub fn new(pkcs7_der: Vec<u8>) -> Self {
371 Self { pkcs7_der }
372 }
373
374 pub fn pkcs7_der(&self) -> &[u8] {
376 &self.pkcs7_der
377 }
378
379 pub fn into_pkcs7_der(self) -> Vec<u8> {
381 self.pkcs7_der
382 }
383
384 pub fn to_base64(&self) -> String {
386 base64::engine::general_purpose::STANDARD.encode(&self.pkcs7_der)
387 }
388
389 pub fn from_base64(base64_data: &str) -> EstResult<Self> {
391 let pkcs7_der = base64::engine::general_purpose::STANDARD
392 .decode(base64_data)
393 .map_err(|e| EstError::InvalidBase64(e.to_string()))?;
394
395 Ok(Self::new(pkcs7_der))
396 }
397
398 pub fn validate(&self) -> EstResult<()> {
400 if self.pkcs7_der.is_empty() {
401 return Err(EstError::InvalidPkcs7("Empty PKCS#7 structure".to_string()));
402 }
403
404 if self.pkcs7_der[0] != 0x30 {
405 return Err(EstError::InvalidPkcs7(
406 "Invalid DER: expected SEQUENCE tag".to_string(),
407 ));
408 }
409
410 if self.pkcs7_der.len() < 100 {
411 return Err(EstError::InvalidPkcs7(format!(
412 "PKCS#7 too small: {} bytes",
413 self.pkcs7_der.len()
414 )));
415 }
416
417 Ok(())
418 }
419}
420
421mod serde_bytes {
423 use serde::{Deserialize, Deserializer, Serializer};
424
425 pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
426 where
427 S: Serializer,
428 {
429 serializer.serialize_bytes(bytes)
430 }
431
432 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
433 where
434 D: Deserializer<'de>,
435 {
436 Vec::<u8>::deserialize(deserializer)
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Leading bytes shared by the synthetic DER buffers: a SEQUENCE tag
    /// followed by a long-form length.
    const DER_HEADER: [u8; 4] = [0x30, 0x82, 0x01, 0x00];

    /// Builds `DER_HEADER || extra || zero_count zero bytes`.
    fn synthetic_der(extra: &[u8], zero_count: usize) -> Vec<u8> {
        let mut der = Vec::with_capacity(DER_HEADER.len() + extra.len() + zero_count);
        der.extend_from_slice(&DER_HEADER);
        der.extend_from_slice(extra);
        der.extend(std::iter::repeat(0x00u8).take(zero_count));
        der
    }

    /// Builds an RSA-flavoured request where only the fields under test vary.
    fn rsa_request(
        signature: Vec<u8>,
        tbs_der: Vec<u8>,
        challenge_password: Option<&str>,
    ) -> CertificationRequest {
        CertificationRequest {
            version: 0,
            subject: String::new(),
            key_algorithm: KeyAlgorithm::Rsa,
            subject_public_key_info: Vec::new(),
            signature_algorithm: String::new(),
            signature,
            subject_alt_names: Vec::new(),
            key_usage: Vec::new(),
            challenge_password: challenge_password.map(String::from),
            tbs_der,
        }
    }

    #[test]
    fn test_enroll_request_roundtrip() {
        let full_der = synthetic_der(&[], 252);

        let request = EnrollRequest::new(full_der.clone());
        assert_eq!(request.csr_der(), full_der.as_slice());

        let encoded = request.to_base64();
        let decoded = EnrollRequest::from_base64(&encoded).unwrap();
        assert_eq!(decoded.csr_der(), full_der.as_slice());
    }

    #[test]
    fn test_enroll_response_roundtrip() {
        let full_der = synthetic_der(&[], 252);

        let response = EnrollResponse::new(full_der.clone());
        assert_eq!(response.pkcs7_der(), full_der.as_slice());

        let encoded = response.to_base64();
        let decoded = EnrollResponse::from_base64(&encoded).unwrap();
        assert_eq!(decoded.pkcs7_der(), full_der.as_slice());
    }

    #[test]
    fn test_validate_csr() {
        let request = EnrollRequest::new(synthetic_der(&[], 252));
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validate_empty_csr() {
        let outcome = EnrollRequest::new(Vec::new()).validate();
        assert!(matches!(outcome, Err(EstError::InvalidPkcs10(_))));
    }

    #[test]
    fn test_ml_dsa_oids() {
        assert_eq!(ml_dsa_oids::ML_DSA_44, "2.16.840.1.101.3.4.3.17");
        assert_eq!(ml_dsa_oids::ML_DSA_65, "2.16.840.1.101.3.4.3.18");
        assert_eq!(ml_dsa_oids::ML_DSA_87, "2.16.840.1.101.3.4.3.19");
    }

    #[test]
    fn test_ml_kem_oids() {
        assert_eq!(ml_kem_oids::ML_KEM_512, "2.16.840.1.101.3.4.4.1");
        assert_eq!(ml_kem_oids::ML_KEM_768, "2.16.840.1.101.3.4.4.2");
        assert_eq!(ml_kem_oids::ML_KEM_1024, "2.16.840.1.101.3.4.4.3");
    }

    #[test]
    fn test_contains_ml_dsa() {
        let der = synthetic_der(b"\x06\x0b\x60\x86\x48\x01\x65\x03\x04\x03\x11", 240);

        let request = EnrollRequest::new(der);
        assert!(request.contains_ml_dsa());
        assert!(!request.contains_ml_kem());
    }

    #[test]
    fn test_contains_ml_kem() {
        let der = synthetic_der(b"\x06\x0b\x60\x86\x48\x01\x65\x03\x04\x04\x01", 240);

        let request = EnrollRequest::new(der);
        assert!(!request.contains_ml_dsa());
        assert!(request.contains_ml_kem());
    }

    #[test]
    fn test_key_algorithm_from_oid() {
        let expectations = [
            ("1.2.840.113549.1.1.1", KeyAlgorithm::Rsa),
            ("2.16.840.1.101.3.4.3.17", KeyAlgorithm::MlDsa44),
            ("2.16.840.1.101.3.4.3.18", KeyAlgorithm::MlDsa65),
            ("2.16.840.1.101.3.4.3.19", KeyAlgorithm::MlDsa87),
            ("2.16.840.1.101.3.4.4.1", KeyAlgorithm::MlKem512),
            ("2.16.840.1.101.3.4.4.2", KeyAlgorithm::MlKem768),
            ("2.16.840.1.101.3.4.4.3", KeyAlgorithm::MlKem1024),
        ];
        for &(oid, ref expected) in expectations.iter() {
            assert_eq!(&KeyAlgorithm::from_oid(oid), expected);
        }

        assert!(matches!(
            KeyAlgorithm::from_oid("1.2.3"),
            KeyAlgorithm::Unknown(_)
        ));
    }

    #[test]
    fn test_key_algorithm_ec_curves() {
        assert_eq!(
            KeyAlgorithm::from_ec_oid("1.2.840.10045.3.1.7"),
            KeyAlgorithm::EcdsaP256
        );
        assert_eq!(
            KeyAlgorithm::from_ec_oid("1.3.132.0.34"),
            KeyAlgorithm::EcdsaP384
        );
        assert!(matches!(
            KeyAlgorithm::from_ec_oid("1.2.3.4"),
            KeyAlgorithm::Unknown(_)
        ));
    }

    #[test]
    fn test_certification_request_verify_empty_tbs() {
        let cr = rsa_request(vec![0x00], Vec::new(), None);
        assert!(matches!(
            cr.verify_self_signature(),
            Err(EstError::InvalidPkcs10(_))
        ));
    }

    #[test]
    fn test_certification_request_verify_empty_signature() {
        let cr = rsa_request(Vec::new(), vec![0x30, 0x00], None);
        assert!(matches!(
            cr.verify_self_signature(),
            Err(EstError::InvalidPkcs10(_))
        ));
    }

    #[test]
    fn test_challenge_password_validation() {
        let cr = rsa_request(vec![0x00], vec![0x30, 0x00], Some("secret123"));
        assert!(cr.validate_challenge_password("secret123").is_ok());
        assert!(matches!(
            cr.validate_challenge_password("wrong"),
            Err(EstError::InvalidPop(_))
        ));
    }

    #[test]
    fn test_challenge_password_missing() {
        let cr = rsa_request(vec![0x00], vec![0x30, 0x00], None);
        assert!(matches!(
            cr.validate_challenge_password("anything"),
            Err(EstError::MissingField(_))
        ));
    }

    #[test]
    fn test_to_certification_request() {
        let der = synthetic_der(&[], 252);
        let cr = EnrollRequest::new(der.clone()).to_certification_request();
        assert_eq!(cr.version, 0);
        assert_eq!(cr.tbs_der, der);
    }

    #[test]
    fn test_traditional_oids() {
        assert_eq!(traditional_oids::RSA, "1.2.840.113549.1.1.1");
        assert_eq!(traditional_oids::EC_PUBLIC_KEY, "1.2.840.10045.2.1");
    }

    #[test]
    fn test_named_curve_oids() {
        assert_eq!(named_curve_oids::P256, "1.2.840.10045.3.1.7");
        assert_eq!(named_curve_oids::P384, "1.3.132.0.34");
    }
}