Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cert-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use dstack_types::{AppKeys, KeyProvider};
use ra_rpc::client::{RaClient, RaClientConfig};
use ra_tls::{
attestation::{QuoteContentType, VersionedAttestation},
cert::{generate_ra_cert, CaCert, CertConfig, CertSigningRequestV2, Csr},
cert::{generate_ra_cert, CaCert, CertConfigV2, CertSigningRequestV2, Csr},
rcgen::KeyPair,
};

Expand Down Expand Up @@ -96,7 +96,7 @@ impl CertRequestClient {
pub async fn request_cert(
&self,
key: &KeyPair,
config: CertConfig,
config: CertConfigV2,
attestation_override: Option<VersionedAttestation>,
) -> Result<Vec<String>> {
let pubkey = key.public_key_der();
Expand Down
11 changes: 5 additions & 6 deletions dstack-util/src/system_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use luks2::{
LuksSegmentSize,
};
use ra_rpc::client::{CertInfo, RaClient, RaClientConfig};
use ra_tls::cert::generate_ra_cert;
use ra_tls::cert::{generate_ra_cert, CertConfigV2};
use rand::Rng as _;
use scopeguard::defer;
use serde::{Deserialize, Serialize};
Expand All @@ -48,10 +48,7 @@ use cmd_lib::run_fun as cmd;
use dstack_gateway_rpc::{
gateway_client::GatewayClient, RegisterCvmRequest, RegisterCvmResponse, WireGuardPeer,
};
use ra_tls::{
cert::CertConfig,
rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256},
};
use ra_tls::rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256};
use serde_human_bytes as hex_bytes;
use serde_json::Value;

Expand Down Expand Up @@ -388,13 +385,15 @@ impl<'a> GatewayContext<'a> {
let sk = cmd!(wg genkey)?;
let pk = cmd!(echo $sk | wg pubkey).or(Err(anyhow!("Failed to generate public key")))?;

let config = CertConfig {
let config = CertConfigV2 {
org_name: None,
subject: "dstack-guest-agent".to_string(),
subject_alt_names: vec![],
usage_server_auth: false,
usage_client_auth: true,
ext_quote: true,
not_before: None,
not_after: None,
};
let cert_client = CertRequestClient::create(
self.keys,
Expand Down
2 changes: 2 additions & 0 deletions gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ async fn maybe_gen_certs(config: &Config, tls_config: &TlsConfig) -> Result<()>
usage_ra_tls: true,
usage_server_auth: true,
usage_client_auth: false,
not_before: None,
not_after: None,
})
.await?;

Expand Down
2 changes: 2 additions & 0 deletions gateway/src/main_service/sync_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ pub(crate) async fn sync_task(proxy: Proxy) -> Result<()> {
usage_ra_tls: false,
usage_server_auth: false,
usage_client_auth: true,
not_after: None,
not_before: None,
})
.await
.context("Failed to get sync-client keys")?;
Expand Down
4 changes: 4 additions & 0 deletions guest-agent/rpc/proto/agent_rpc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ message GetTlsKeyArgs {
bool usage_server_auth = 4;
// Key usage client auth
bool usage_client_auth = 5;
// Certificate validity start time as seconds since UNIX epoch
optional uint64 not_before = 6;
// Certificate validity end time as seconds since UNIX epoch
optional uint64 not_after = 7;
}

// The request to derive a key
Expand Down
14 changes: 10 additions & 4 deletions guest-agent/src/rpc_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use or_panic::ResultOrPanic;
use ra_rpc::{Attestation, CallContext, RpcCall};
use ra_tls::{
attestation::{QuoteContentType, VersionedAttestation, DEFAULT_HASH_ALGORITHM},
cert::CertConfig,
cert::CertConfigV2,
kdf::{derive_ecdsa_key, derive_ecdsa_key_pair_from_bytes},
};
use rcgen::KeyPair;
Expand Down Expand Up @@ -78,13 +78,15 @@ impl AppStateInner {
.cert_client
.request_cert(
&key,
CertConfig {
CertConfigV2 {
org_name: None,
subject: "demo-cert".to_string(),
subject_alt_names: vec![],
usage_server_auth: false,
usage_client_auth: true,
ext_quote: true,
not_after: None,
not_before: None,
},
attestation_override,
)
Expand Down Expand Up @@ -233,13 +235,15 @@ impl DstackGuestRpc for InternalRpcHandler {
.context("Failed to generate secure seed")?;
let derived_key =
derive_ecdsa_key_pair_from_bytes(&seed, &[]).context("Failed to derive key")?;
let config = CertConfig {
let config = CertConfigV2 {
org_name: None,
subject: request.subject,
subject_alt_names: request.alt_names,
usage_server_auth: request.usage_server_auth,
usage_client_auth: request.usage_client_auth,
ext_quote: request.usage_ra_tls,
not_after: request.not_after,
not_before: request.not_before,
};
let attestation_override = self
.state
Expand Down Expand Up @@ -493,13 +497,15 @@ impl TappdRpc for InternalRpcHandlerV0 {
};
let derived_key = derive_ecdsa_key_pair_from_bytes(seed, &[request.path.as_bytes()])
.context("Failed to derive key")?;
let config = CertConfig {
let config = CertConfigV2 {
org_name: None,
subject: request.subject,
subject_alt_names: request.alt_names,
usage_server_auth: request.usage_server_auth,
usage_client_auth: request.usage_client_auth,
ext_quote: request.usage_ra_tls,
not_before: None,
not_after: None,
};
let attestation_override = self
.state
Expand Down
60 changes: 53 additions & 7 deletions ra-tls/src/cert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//! Certificate creation functions.

use std::time::SystemTime;
use std::time::{SystemTime, UNIX_EPOCH};
use std::{path::Path, time::Duration};

use anyhow::{anyhow, bail, Context, Result};
Expand Down Expand Up @@ -99,6 +99,8 @@ impl CaCert {
.maybe_attestation(attestation)
.maybe_app_id(app_id)
.special_usage(usage)
.maybe_not_before(cfg.not_before.map(unix_time_to_system_time))
.maybe_not_after(cfg.not_after.map(unix_time_to_system_time))
.build();
self.sign(req).context("Failed to sign certificate")
}
Expand All @@ -121,6 +123,42 @@ pub struct CertConfig {
pub ext_quote: bool,
}

/// The configuration of the certificate with optional validity overrides.
#[derive(Encode, Decode, Clone, PartialEq)]
pub struct CertConfigV2 {
/// The organization name of the certificate.
pub org_name: Option<String>,
/// The subject of the certificate.
pub subject: String,
/// The subject alternative names of the certificate.
pub subject_alt_names: Vec<String>,
/// The purpose of the certificate.
pub usage_server_auth: bool,
/// The purpose of the certificate.
pub usage_client_auth: bool,
/// Whether the certificate is quoted.
pub ext_quote: bool,
/// The certificate validity start time as seconds since UNIX epoch.
pub not_before: Option<u64>,
/// The certificate validity end time as seconds since UNIX epoch.
pub not_after: Option<u64>,
}

impl From<CertConfig> for CertConfigV2 {
fn from(config: CertConfig) -> Self {
Self {
org_name: config.org_name,
subject: config.subject,
subject_alt_names: config.subject_alt_names,
usage_server_auth: config.usage_server_auth,
usage_client_auth: config.usage_client_auth,
ext_quote: config.ext_quote,
not_before: None,
not_after: None,
}
}
}

/// A certificate signing request.
#[derive(Encode, Decode, Clone)]
pub struct CertSigningRequestV1 {
Expand Down Expand Up @@ -240,7 +278,7 @@ pub struct CertSigningRequestV2 {
/// The public key of the certificate.
pub pubkey: Vec<u8>,
/// The certificate configuration.
pub config: CertConfig,
pub config: CertConfigV2,
/// The attestation.
pub attestation: VersionedAttestation,
}
Expand All @@ -251,7 +289,7 @@ impl TryFrom<CertSigningRequestV1> for CertSigningRequestV2 {
Ok(Self {
confirm: v0.confirm,
pubkey: v0.pubkey,
config: v0.config,
config: v0.config.into(),
attestation: Attestation::from_tdx_quote(v0.quote, &v0.event_log)?.into_versioned(),
})
}
Expand Down Expand Up @@ -381,6 +419,10 @@ fn add_ext(params: &mut CertificateParams, oid: &[u64], content: impl AsRef<[u8]
.push(CustomExtension::from_oid_content(oid, content));
}

fn unix_time_to_system_time(secs: u64) -> SystemTime {
UNIX_EPOCH + Duration::from_secs(secs)
}

impl CertRequest<'_, KeyPair> {
/// Create a self-signed certificate.
pub fn self_signed(self) -> Result<Certificate> {
Expand Down Expand Up @@ -624,13 +666,15 @@ mod tests {
let csr = CertSigningRequestV2 {
confirm: "please sign cert:".to_string(),
pubkey: vec![1, 2, 3],
config: CertConfig {
config: CertConfigV2 {
org_name: None,
subject: "test.example.com".to_string(),
subject_alt_names: vec![],
usage_server_auth: true,
usage_client_auth: false,
ext_quote: false,
not_before: None,
not_after: None,
},
attestation: Attestation {
quote: AttestationQuote::DstackTdx(TdxQuote {
Expand All @@ -646,7 +690,7 @@ mod tests {
};

let actual = hex::encode(csr.encode());
let expected = "44706c65617365207369676e20636572743a0c0102030040746573742e6578616d706c652e636f6d0001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000";
let expected = "44706c65617365207369676e20636572743a0c0102030040746573742e6578616d706c652e636f6d00010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000";
assert_eq!(actual, expected);
}

Expand All @@ -655,13 +699,15 @@ mod tests {
let csr = CertSigningRequestV2 {
confirm: "please sign cert:".to_string(),
pubkey: vec![1, 2, 3],
config: CertConfig {
config: CertConfigV2 {
org_name: None,
subject: "test.example.com".to_string(),
subject_alt_names: vec![],
usage_server_auth: true,
usage_client_auth: false,
ext_quote: true,
not_before: None,
not_after: None,
},
attestation: Attestation {
quote: AttestationQuote::DstackTdx(TdxQuote {
Expand All @@ -677,7 +723,7 @@ mod tests {
};

let actual = hex::encode(csr.encode());
let expected = "44706c65617365207369676e20636572743a0c0102030040746573742e6578616d706c652e636f6d000100010000040900000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000";
let expected = "44706c65617365207369676e20636572743a0c0102030040746573742e6578616d706c652e636f6d0001000100000000040900000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000";
assert_eq!(actual, expected);
}
}