Skip to content

Commit

Permalink
Merge pull request #25 from noctisynth/refactor/async-swap
Browse files Browse the repository at this point in the history
refactor: use `async-swap` instead of rwlock
  • Loading branch information
fu050409 committed Jun 23, 2024
2 parents 08403d0 + 595d633 commit 8113103
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 26 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion oblivion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ scrypt = "0.11"
hkdf = "0.12"

# Utils
arc-swap = "1.7.1"
oblivion-codegen = { workspace = true }
proc-macro2 = { workspace = true }
futures = { workspace = true }
Expand Down Expand Up @@ -51,7 +52,7 @@ name = "main"
bench = []
perf = []
unsafe = ["elliptic-curve", "p256"]
python = ["pyo3"]
pyo3 = ["dep:pyo3"]
serde = ["dep:serde"]

[[bench]]
Expand Down
6 changes: 3 additions & 3 deletions oblivion/src/exceptions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! # Oblivion exception
//! All exceptions to the Oblivion function return `OblivionException`.
#[cfg(feature = "python")]
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use ring::error::Unspecified;
use scrypt::errors::InvalidOutputLen;
Expand Down Expand Up @@ -39,13 +39,13 @@ pub enum Exception {
ConnectionClosed,
}

#[cfg(feature = "python")]
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct PyOblivionException {
pub message: String,
}

#[cfg(feature = "python")]
#[cfg(feature = "pyo3")]
#[pymethods]
impl PyOblivionException {
#[new]
Expand Down
24 changes: 12 additions & 12 deletions oblivion/src/models/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,38 @@ use serde::{Deserialize, Serialize};
use tokio::{net::TcpStream, sync::Mutex, task::JoinHandle};

use crate::exceptions::Exception;
#[cfg(feature = "python")]
#[cfg(feature = "pyo3")]
use crate::exceptions::PyOblivionException;

use crate::utils::gear::Socket;
use crate::utils::parser::OblivionPath;

#[cfg(feature = "python")]
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
#[cfg(not(feature = "python"))]
#[cfg(not(feature = "pyo3"))]
use serde_json::{from_slice, Value};
#[cfg(feature = "python")]
#[cfg(feature = "pyo3")]
use serde_json::{json, Value};

use super::session::Session;

#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "pyo3", pyclass)]
#[derive(Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Response {
#[cfg_attr(feature = "python", pyo3(get))]
#[cfg_attr(feature = "pyo3", pyo3(get))]
pub header: Option<String>,
#[cfg_attr(feature = "python", pyo3(get))]
#[cfg_attr(feature = "pyo3", pyo3(get))]
pub content: Vec<u8>,
#[cfg_attr(feature = "python", pyo3(get))]
#[cfg_attr(feature = "pyo3", pyo3(get))]
pub entrance: Option<String>,
#[cfg_attr(feature = "python", pyo3(get))]
#[cfg_attr(feature = "pyo3", pyo3(get))]
pub status_code: u32,
#[cfg_attr(feature = "python", pyo3(get))]
#[cfg_attr(feature = "pyo3", pyo3(get))]
pub flag: u32,
}

#[cfg(not(feature = "python"))]
#[cfg(not(feature = "pyo3"))]
impl Response {
pub fn new(
header: Option<String>,
Expand Down Expand Up @@ -92,7 +92,7 @@ impl PartialEq for Response {
}
}

#[cfg(feature = "python")]
#[cfg(feature = "pyo3")]
#[pymethods]
impl Response {
#[new]
Expand Down
17 changes: 7 additions & 10 deletions oblivion/src/models/session.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::sync::Arc;

use anyhow::{anyhow, Result};
use arc_swap::ArcSwap;
use chrono::{DateTime, Local};
use serde_json::Value;
use tokio::sync::RwLock;

#[cfg(feature = "unsafe")]
use p256::{ecdh::EphemeralSecret, PublicKey};
Expand Down Expand Up @@ -37,7 +37,7 @@ pub struct Session {
pub request_time: DateTime<Local>,
pub request: OblivionRequest,
pub socket: Arc<Socket>,
closed: RwLock<bool>,
closed: ArcSwap<bool>,
}

impl Session {
Expand All @@ -54,7 +54,7 @@ impl Session {
request_time: Local::now(),
request: Default::default(),
socket: Arc::new(socket),
closed: RwLock::new(false),
closed: ArcSwap::new(Arc::new(false)),
})
}

Expand All @@ -71,7 +71,7 @@ impl Session {
request_time: Local::now(),
request: Default::default(),
socket: Arc::new(socket),
closed: RwLock::new(false),
closed: ArcSwap::new(Arc::new(false)),
})
}

Expand Down Expand Up @@ -190,10 +190,7 @@ impl Session {
let socket = &self.socket;

let flag = OSC::from_stream(socket).await?.status_code;
let content = OED::new(&self.aes_key)
.from_stream(socket)
.await?
.take();
let content = OED::new(&self.aes_key).from_stream(socket).await?.take();
let status_code = OSC::from_stream(socket).await?.status_code;
let response = Response::new(None, content, None, status_code, flag);

Expand All @@ -205,15 +202,15 @@ impl Session {

pub async fn close(&self) -> Result<()> {
if !self.closed().await {
*self.closed.write().await = true;
self.closed.store(Arc::new(true));
self.socket.close().await
} else {
Ok(())
}
}

pub async fn closed(&self) -> bool {
*self.closed.read().await
**self.closed.load()
}

pub fn header(&self) -> &str {
Expand Down

0 comments on commit 8113103

Please sign in to comment.