Skip to content

Commit b50641d

Browse files
authored
chore: Implement auth caching (#78)
* Initial implementation of auth caching * Add auth mod * Fix tests * Move Result to err mod * fix: refactor codes
1 parent e373905 commit b50641d

36 files changed

+302
-125
lines changed

.vscode/settings.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"cSpell.words": [
3+
"Mpesa"
4+
]
5+
}

Cargo.toml

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,37 @@ readme = "./README.md"
1010
license = "MIT"
1111

1212
[dependencies]
13-
chrono = {version = "0.4", optional = true, default-features = false, features = ["clock", "serde"] }
14-
openssl = {version = "0.10", optional = true}
15-
reqwest = {version = "0.11", features = ["json"]}
16-
secrecy = "0.8.0"
17-
serde = {version="1.0", features= ["derive"]}
13+
cached = { version = "0.46", features = ["wasm", "async", "proc_macro"] }
14+
chrono = { version = "0.4", optional = true, default-features = false, features = [
15+
"clock",
16+
"serde",
17+
] }
18+
openssl = { version = "0.10", optional = true }
19+
reqwest = { version = "0.11", features = ["json"] }
20+
serde = { version = "1.0", features = ["derive"] }
1821
serde_json = "1.0"
1922
serde_repr = "0.1"
2023
thiserror = "1.0.37"
2124
wiremock = "0.5"
25+
secrecy = "0.8.0"
2226

2327
[dev-dependencies]
2428
dotenv = "0.15"
25-
tokio = {version = "1", features = ["rt", "rt-multi-thread", "macros"]}
29+
tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros"] }
2630
wiremock = "0.5"
2731

2832
[features]
29-
default = ["account_balance", "b2b", "b2c", "bill_manager", "c2b_register", "c2b_simulate", "express_request", "transaction_reversal", "transaction_status"]
33+
default = [
34+
"account_balance",
35+
"b2b",
36+
"b2c",
37+
"bill_manager",
38+
"c2b_register",
39+
"c2b_simulate",
40+
"express_request",
41+
"transaction_reversal",
42+
"transaction_status",
43+
]
3044
account_balance = ["dep:openssl"]
3145
b2b = ["dep:openssl"]
3246
b2c = ["dep:openssl"]
@@ -35,4 +49,4 @@ c2b_register = []
3549
c2b_simulate = []
3650
express_request = ["dep:chrono"]
3751
transaction_reversal = ["dep:openssl"]
38-
transaction_status= ["dep:openssl"]
52+
transaction_status = ["dep:openssl"]

src/auth.rs

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
use cached::proc_macro::cached;
2+
use serde::{Deserialize, Serialize};
3+
4+
use crate::{ApiEnvironment, ApiError, Mpesa, MpesaError, MpesaResult};
5+
6+
const AUTHENTICATION_URL: &str = "/oauth/v1/generate?grant_type=client_credentials";
7+
8+
#[cached(
9+
size = 1,
10+
time = 3600,
11+
key = "String",
12+
result = true,
13+
convert = r#"{ format!("{}", client.client_key()) }"#
14+
)]
15+
pub(crate) async fn auth(client: &Mpesa<impl ApiEnvironment>) -> MpesaResult<String> {
16+
let url = format!("{}{}", client.environment.base_url(), AUTHENTICATION_URL);
17+
18+
let response = client
19+
.http_client
20+
.get(&url)
21+
.basic_auth(client.client_key(), Some(&client.client_secret()))
22+
.send()
23+
.await?;
24+
25+
if response.status().is_success() {
26+
let value = response.json::<AuthenticationResponse>().await?;
27+
let access_token = value.access_token;
28+
29+
return Ok(access_token);
30+
}
31+
32+
let error = response.json::<ApiError>().await?;
33+
Err(MpesaError::AuthenticationError(error))
34+
}
35+
36+
/// Response returned from the authentication function
37+
#[derive(Debug, Serialize, Deserialize)]
38+
pub struct AuthenticationResponse {
39+
/// Access token which is used as the Bearer-Auth-Token
40+
pub access_token: String,
41+
/// Expiry time in seconds
42+
pub expiry_in: u64,
43+
}
44+
45+
impl std::fmt::Display for AuthenticationResponse {
46+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47+
write!(
48+
f,
49+
"token :{} expires in: {}",
50+
self.access_token, self.expiry_in
51+
)
52+
}
53+
}
54+
55+
#[cfg(test)]
56+
mod tests {
57+
use wiremock::{Mock, MockServer};
58+
59+
use super::*;
60+
61+
#[derive(Debug, Clone)]
62+
pub struct TestEnvironment {
63+
pub server_url: String,
64+
}
65+
66+
impl TestEnvironment {
67+
pub async fn new(server: &MockServer) -> Self {
68+
TestEnvironment {
69+
server_url: server.uri(),
70+
}
71+
}
72+
}
73+
74+
impl ApiEnvironment for TestEnvironment {
75+
fn base_url(&self) -> &str {
76+
&self.server_url
77+
}
78+
79+
fn get_certificate(&self) -> &str {
80+
include_str!("../src/certificates/sandbox")
81+
}
82+
}
83+
84+
#[tokio::test]
85+
async fn test_cached_auth() {
86+
use cached::Cached;
87+
88+
use crate::Mpesa;
89+
90+
let server = MockServer::start().await;
91+
92+
let env = TestEnvironment::new(&server).await;
93+
94+
let client = Mpesa::new("test_api_key", "test_public_key", env);
95+
96+
Mock::given(wiremock::matchers::method("GET"))
97+
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(
98+
AuthenticationResponse {
99+
access_token: "test_token".to_string(),
100+
expiry_in: 3600,
101+
},
102+
))
103+
.expect(1)
104+
.mount(&server)
105+
.await;
106+
107+
auth_prime_cache(&client).await.unwrap();
108+
109+
let mut cache = AUTH.lock().await;
110+
111+
assert!(cache.cache_get(&client.client_key().to_string()).is_some());
112+
assert_eq!(cache.cache_hits().unwrap(), 1);
113+
assert_eq!(cache.cache_capacity().unwrap(), 1);
114+
}
115+
}

src/client.rs

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
1+
use std::cell::RefCell;
2+
3+
use cached::Cached;
4+
use openssl::base64;
5+
use openssl::rsa::Padding;
6+
use openssl::x509::X509;
7+
use reqwest::Client as HttpClient;
8+
9+
use crate::auth::AUTH;
110
use crate::environment::ApiEnvironment;
211
use crate::services::{
312
AccountBalanceBuilder, B2bBuilder, B2cBuilder, BulkInvoiceBuilder, C2bRegisterBuilder,
413
C2bSimulateBuilder, CancelInvoiceBuilder, MpesaExpressRequestBuilder, OnboardBuilder,
514
OnboardModifyBuilder, ReconciliationBuilder, SingleInvoiceBuilder, TransactionReversalBuilder,
615
TransactionStatusBuilder,
716
};
8-
use crate::{ApiError, MpesaError};
9-
use openssl::base64;
10-
use openssl::rsa::Padding;
11-
use openssl::x509::X509;
12-
use reqwest::Client as HttpClient;
17+
use crate::{auth, MpesaResult};
1318
use secrecy::{ExposeSecret, Secret};
14-
use serde_json::Value;
15-
use std::cell::RefCell;
1619

1720
/// Source: [test credentials](https://developer.safaricom.co.ke/test_credentials)
1821
const DEFAULT_INITIATOR_PASSWORD: &str = "Safcom496!";
1922
/// Get current package version from metadata
2023
const CARGO_PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION");
2124

22-
/// `Result` enum type alias
23-
pub type MpesaResult<T> = Result<T, MpesaError>;
24-
2525
/// Mpesa client that will facilitate communication with the Safaricom API
2626
#[derive(Clone, Debug)]
2727
pub struct Mpesa<Env: ApiEnvironment> {
@@ -72,6 +72,16 @@ impl<'mpesa, Env: ApiEnvironment> Mpesa<Env> {
7272
p.expose_secret().into()
7373
}
7474

75+
/// Get the client key
76+
pub(crate) fn client_key(&self) -> &str {
77+
&self.client_key
78+
}
79+
80+
/// Get the client secret
81+
pub(crate) fn client_secret(&self) -> &str {
82+
self.client_secret.expose_secret()
83+
}
84+
7585
/// Optional in development but required for production, you will need to call this method and set your production initiator password.
7686
/// If in development, default initiator password is already pre-set
7787
/// ```ignore
@@ -107,31 +117,27 @@ impl<'mpesa, Env: ApiEnvironment> Mpesa<Env> {
107117
/// # Errors
108118
/// Returns a `MpesaError` on failure
109119
pub(crate) async fn auth(&self) -> MpesaResult<String> {
110-
let url = format!(
111-
"{}/oauth/v1/generate?grant_type=client_credentials",
112-
self.environment.base_url()
113-
);
114-
let response = self
115-
.http_client
116-
.get(&url)
117-
.basic_auth(&self.client_key, Some(&self.client_secret.expose_secret()))
118-
.send()
119-
.await?;
120-
if response.status().is_success() {
121-
let value = response.json::<Value>().await?;
122-
let access_token = value
123-
.get("access_token")
124-
.ok_or_else(|| String::from("Failed to extract token from the response"))
125-
.unwrap();
126-
let access_token = access_token
127-
.as_str()
128-
.ok_or_else(|| String::from("Error converting access token to string"))
129-
.unwrap();
130-
131-
return Ok(access_token.to_string());
120+
if let Some(token) = AUTH.lock().await.cache_get(&self.client_key) {
121+
return Ok(token.to_owned());
132122
}
133-
let error = response.json::<ApiError>().await?;
134-
Err(MpesaError::AuthenticationError(error))
123+
124+
// Generate a new access token
125+
let new_token = match auth::auth_prime_cache(self).await {
126+
Ok(token) => token,
127+
Err(e) => return Err(e),
128+
};
129+
130+
// Double-check if the access token is cached by another thread
131+
if let Some(token) = AUTH.lock().await.cache_get(&self.client_key) {
132+
return Ok(token.to_owned());
133+
}
134+
135+
// Cache the new token
136+
AUTH.lock()
137+
.await
138+
.cache_set(self.client_key.clone(), new_token.to_owned());
139+
140+
Ok(new_token)
135141
}
136142

137143
/// **B2C Builder**
@@ -529,9 +535,8 @@ impl<'mpesa, Env: ApiEnvironment> Mpesa<Env> {
529535

530536
#[cfg(test)]
531537
mod tests {
532-
use crate::Sandbox;
533-
534538
use super::*;
539+
use crate::Sandbox;
535540

536541
#[test]
537542
fn test_setting_initator_password() {

src/constants.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
use std::fmt::{Display, Formatter, Result as FmtResult};
2+
13
use chrono::prelude::{DateTime, Utc};
24
use serde::{Deserialize, Serialize};
35
use serde_repr::{Deserialize_repr, Serialize_repr};
4-
use std::fmt::{Display, Formatter, Result as FmtResult};
56

67
/// Mpesa command ids
78
#[derive(Debug, Serialize, Deserialize)]

src/environment.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
//! and the `public key` an X509 certificate used for encrypting initiator passwords. You can read more about that from
1111
//! the Safaricom API [docs](https://developer.safaricom.co.ke/docs?javascript#security-credentials).
1212
13+
use std::convert::TryFrom;
14+
use std::str::FromStr;
15+
1316
use crate::MpesaError;
14-
use std::{convert::TryFrom, str::FromStr};
1517

1618
#[derive(Debug, Clone)]
1719
/// Enum to map to desired environment so as to access certificate

src/errors.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
use std::env::VarError;
2+
use std::fmt;
3+
14
use serde::{Deserialize, Serialize};
2-
use std::{env::VarError, fmt};
35

46
/// Mpesa error stack
57
#[derive(thiserror::Error, Debug)]
@@ -46,6 +48,9 @@ pub enum MpesaError {
4648
Message(&'static str),
4749
}
4850

51+
/// `Result` enum type alias
52+
pub type MpesaResult<T> = Result<T, MpesaError>;
53+
4954
#[derive(Debug, Serialize, Deserialize)]
5055
pub struct ApiError {
5156
pub request_id: String,

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
#![doc = include_str!("../README.md")]
22

3+
mod auth;
34
mod client;
45
mod constants;
56
pub mod environment;
67
mod errors;
78
pub mod services;
89

9-
pub use client::{Mpesa, MpesaResult};
10+
pub use client::Mpesa;
1011
pub use constants::{
1112
CommandId, IdentifierTypes, Invoice, InvoiceItem, ResponseType, SendRemindersTypes,
1213
};
1314
pub use environment::ApiEnvironment;
1415
pub use environment::Environment::{self, Production, Sandbox};
15-
pub use errors::{ApiError, MpesaError};
16+
pub use errors::{ApiError, MpesaError, MpesaResult};

src/services/account_balance.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
use crate::client::MpesaResult;
1+
use serde::{Deserialize, Serialize};
2+
23
use crate::constants::{CommandId, IdentifierTypes};
34
use crate::environment::ApiEnvironment;
4-
use crate::{Mpesa, MpesaError};
5-
use serde::{Deserialize, Serialize};
5+
use crate::{Mpesa, MpesaError, MpesaResult};
66

77
#[derive(Debug, Serialize)]
88
/// Account Balance payload

src/services/b2b.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
use crate::client::{Mpesa, MpesaResult};
1+
use serde::{Deserialize, Serialize};
2+
3+
use crate::client::Mpesa;
24
use crate::constants::{CommandId, IdentifierTypes};
35
use crate::environment::ApiEnvironment;
4-
use crate::errors::MpesaError;
5-
use serde::{Deserialize, Serialize};
6+
use crate::errors::{MpesaError, MpesaResult};
67

78
#[derive(Debug, Serialize)]
89
struct B2bPayload<'mpesa> {

0 commit comments

Comments
 (0)