Merge remote-tracking branch 'origin/group_chat' into group_chat
# Conflicts:
#	citadel_workspace_service/src/kernel/mod.rs
#	citadel_workspace_service/src/kernel/request_handler.rs
Tjemmmic committed Oct 10, 2023
2 parents c191728 + 3c6213a commit a10c93c
Showing 6 changed files with 401 additions and 345 deletions.
198 changes: 112 additions & 86 deletions citadel_workspace_service/src/kernel/mod.rs
@@ -46,6 +46,7 @@ pub struct Connection {
peers: HashMap<u64, PeerConnection>,
associated_tcp_connection: Uuid,
c2s_file_transfer_handlers: HashMap<u64, Option<ObjectTransferHandler>>,
groups: HashMap<MessageGroupKey, GroupConnection>,
}

#[allow(dead_code)]
@@ -59,8 +60,6 @@ struct PeerConnection {
#[allow(dead_code)]
struct GroupConnection {
channel: GroupChannel,
implicated_cid: u64,
associated_tcp_connection: Uuid,
}

impl Connection {
@@ -75,6 +74,7 @@ impl Connection {
client_server_remote,
associated_tcp_connection,
c2s_file_transfer_handlers: HashMap::new(),
groups: HashMap::new(),
}
}

@@ -116,6 +116,14 @@ impl Connection {
}
}

pub fn add_group_channel(
&mut self,
group_key: MessageGroupKey,
group_channel: GroupConnection,
) {
self.groups.insert(group_key, group_channel);
}

// fn clear_group_connection(&mut self, group_key: MessageGroupKey) -> Option<GroupConnection> {
// self.groups.remove(&group_key)
// }
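
For context, the hunk above stores one `GroupConnection` per `MessageGroupKey` on the owning `Connection`. Below is a minimal, self-contained sketch of that bookkeeping pattern, using simplified stand-in types rather than the service's real `MessageGroupKey` and `GroupChannel`:

```rust
use std::collections::HashMap;

// Simplified stand-ins for MessageGroupKey and GroupChannel.
#[derive(PartialEq, Eq, Hash)]
struct GroupKey(u64);
struct GroupChannel;

#[allow(dead_code)]
struct GroupConnection {
    channel: GroupChannel,
}

struct Connection {
    groups: HashMap<GroupKey, GroupConnection>,
}

impl Connection {
    fn new() -> Self {
        Self {
            groups: HashMap::new(),
        }
    }

    // Same shape as add_group_channel: later group broadcasts keyed by
    // `group_key` can be routed through the stored connection.
    fn add_group_channel(&mut self, group_key: GroupKey, group_channel: GroupConnection) {
        self.groups.insert(group_key, group_channel);
    }
}

fn main() {
    let mut conn = Connection::new();
    conn.add_group_channel(GroupKey(1), GroupConnection { channel: GroupChannel });
    assert!(conn.groups.contains_key(&GroupKey(1)));
}
```
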
@@ -180,30 +188,39 @@ impl NetKernel for CitadelWorkspaceService {
let remote_for_closure = remote.clone();
let listener = tokio::net::TcpListener::bind(self.bind_address).await?;

let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<InternalServiceRequest>();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();

let tcp_connection_map = &self.tcp_connection_map;
let server_connection_map = &self.server_connection_map;

let listener_task = async move {
while let Ok((conn, _addr)) = listener.accept().await {
let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel::<InternalServiceResponse>();
let id = Uuid::new_v4();
tcp_connection_map.lock().await.insert(id, tx1);
handle_connection(conn, tx.clone(), rx1, id);
handle_connection(
conn,
tx.clone(),
rx1,
id,
tcp_connection_map.clone(),
server_connection_map.clone(),
);
}
Ok(())
};

let server_connection_map = &self.server_connection_map;

let inbound_command_task = async move {
while let Some(command) = rx.recv().await {
while let Some((command, conn_id)) = rx.recv().await {
// TODO: handle error once payload_handler is fallible
handle_request(
command,
conn_id,
server_connection_map,
&mut remote,
tcp_connection_map,
&self.group_map
)
.await;
}
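
The main change in this hunk is that requests forwarded to the kernel now travel as `(InternalServiceRequest, Uuid)` pairs, so the handler knows which TCP client issued each request. A minimal sketch of that tagged-channel pattern (tokio only, with a plain `u64` standing in for the connection `Uuid`):

```rust
use tokio::sync::mpsc::unbounded_channel;

// Stand-in for InternalServiceRequest.
#[derive(Debug)]
enum Request {
    Ping,
}

#[tokio::main]
async fn main() {
    // The channel carries (request, connection id) pairs so the consumer
    // can route each response back to the right TCP client.
    let (tx, mut rx) = unbounded_channel::<(Request, u64)>();

    // Two simulated TCP connections submitting requests.
    for conn_id in [1u64, 2] {
        let tx = tx.clone();
        tokio::spawn(async move {
            tx.send((Request::Ping, conn_id)).ok();
        });
    }
    drop(tx); // close the channel once all senders are gone

    while let Some((request, conn_id)) = rx.recv().await {
        println!("handling {request:?} from connection {conn_id}");
    }
}
```
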
@@ -348,16 +365,12 @@ impl NetKernel for CitadelWorkspaceService {
}
}
NodeResult::GroupChannelCreated(group_channel_created) => {
let group_channel = group_channel_created.channel;
let group_key = group_channel.key();
let implicated_cid = group_channel.cid();
let channel = group_channel_created.channel;
let cid = channel.cid();
let key = channel.key();
let mut server_connection_map = self.server_connection_map.lock().await;
if let Some(connection) = server_connection_map.get_mut(&implicated_cid) {
self.group_map.lock().await.insert(group_key, GroupConnection{
channel: group_channel,
implicated_cid,
associated_tcp_connection: connection.associated_tcp_connection,
} );
connection.add_group_channel(key, GroupConnection { channel });
send_response_to_tcp_client(
&self.tcp_connection_map,
InternalServiceResponse::GroupChannelCreateSuccess(GroupChannelCreateSuccess {
@@ -370,37 +383,39 @@
.await;
}
}
NodeResult::PeerEvent(event) => {
match event.event {
PeerSignal::Disconnect(
PeerConnectionType::LocalGroupPeer {
implicated_cid,
peer_cid,
},
_,
) => {
if let Some(conn) = self.clear_peer_connection(implicated_cid, peer_cid).await {
let response = InternalServiceResponse::Disconnected(Disconnected {
cid: implicated_cid,
peer_cid: Some(peer_cid),
request_id: None,
});
send_response_to_tcp_client(
&self.tcp_connection_map,
response,
conn.associated_tcp_connection,
)
.await;
}
}
PeerSignal::BroadcastConnected(group_broadcast) => {
let mut group_map = self.group_map.lock().await;
let mut server_connection_map = self.server_connection_map.lock().await;
handle_group_broadcast(group_broadcast, &mut group_map, &mut server_connection_map, self.tcp_connection_map.clone()).await;
NodeResult::PeerEvent(event) => match event.event {
PeerSignal::Disconnect(
PeerConnectionType::LocalGroupPeer {
implicated_cid,
peer_cid,
},
_,
) => {
if let Some(conn) = self.clear_peer_connection(implicated_cid, peer_cid).await {
let response = InternalServiceResponse::Disconnected(Disconnected {
cid: implicated_cid,
peer_cid: Some(peer_cid),
request_id: None,
});
send_response_to_tcp_client(
&self.tcp_connection_map,
response,
conn.associated_tcp_connection,
)
.await;
}
_ => {}
}
}
PeerSignal::BroadcastConnected(group_broadcast) => {
let mut group_map = self.group_map.lock().await;
handle_group_broadcast(
group_broadcast,
&mut group_map,
self.tcp_connection_map.clone(),
)
.await;
}
_ => {}
},

NodeResult::GroupEvent(group_event) => {
let mut group_map = self.group_map.lock().await;
@@ -453,10 +468,11 @@ async fn sink_send_payload(

fn send_to_kernel(
payload_to_send: &[u8],
sender: &UnboundedSender<InternalServiceRequest>,
sender: &UnboundedSender<(InternalServiceRequest, Uuid)>,
conn_id: Uuid,
) -> Result<(), NetworkError> {
if let Some(payload) = deserialize(payload_to_send) {
sender.send(payload)?;
sender.send((payload, conn_id))?;
Ok(())
} else {
error!(target: "citadel", "w task: failed to deserialize payload");
@@ -466,20 +482,19 @@ fn send_to_kernel(

fn handle_connection(
conn: TcpStream,
to_kernel: UnboundedSender<InternalServiceRequest>,
to_kernel: UnboundedSender<(InternalServiceRequest, Uuid)>,
mut from_kernel: tokio::sync::mpsc::UnboundedReceiver<InternalServiceResponse>,
conn_id: Uuid,
tcp_connection_map: Arc<Mutex<HashMap<Uuid, UnboundedSender<InternalServiceResponse>>>>,
server_connection_map: Arc<Mutex<HashMap<u64, Connection>>>,
) {
tokio::task::spawn(async move {
let framed = wrap_tcp_conn(conn);
let (mut sink, mut stream) = framed.split();

let write_task = async move {
let response =
InternalServiceResponse::ServiceConnectionAccepted(ServiceConnectionAccepted {
id: conn_id,
request_id: None,
});
InternalServiceResponse::ServiceConnectionAccepted(ServiceConnectionAccepted);

sink_send_payload(&response, &mut sink).await;

@@ -492,7 +507,7 @@ fn handle_connection(
while let Some(message) = stream.next().await {
match message {
Ok(message) => {
if let Err(err) = send_to_kernel(&message, &to_kernel) {
if let Err(err) = send_to_kernel(&message, &to_kernel, conn_id) {
error!(target: "citadel", "Failed to send to kernel: {:?}", err);
break;
}
@@ -509,6 +524,11 @@ fn handle_connection(
res0 = write_task => res0,
res1 = read_task => res1,
}

tcp_connection_map.lock().await.remove(&conn_id);
let mut server_connection_map = server_connection_map.lock().await;
// Remove all connections whose associated_tcp_connection is conn_id
server_connection_map.retain(|_, v| v.associated_tcp_connection != conn_id);
});
}
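
The cleanup added at the end of `handle_connection` removes every piece of per-connection state once the read or write task finishes. A minimal sketch of that select-then-cleanup shape (tokio only, with `u64` ids standing in for `Uuid` and a trimmed-down `Connection`):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

// Trimmed-down stand-in: only the field the cleanup inspects.
struct Connection {
    associated_tcp_connection: u64,
}

#[tokio::main]
async fn main() {
    let tcp_connection_map: Arc<Mutex<HashMap<u64, &'static str>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let server_connection_map: Arc<Mutex<HashMap<u64, Connection>>> =
        Arc::new(Mutex::new(HashMap::new()));

    let conn_id = 7u64;
    tcp_connection_map.lock().await.insert(conn_id, "response sender");
    server_connection_map.lock().await.insert(
        1234,
        Connection { associated_tcp_connection: conn_id },
    );

    let tcp_map = tcp_connection_map.clone();
    let server_map = server_connection_map.clone();
    tokio::spawn(async move {
        // Placeholders for the framed write/read tasks; whichever
        // finishes first ends the connection.
        let write_task = async {};
        let read_task = async {};
        tokio::select! {
            _ = write_task => {},
            _ = read_task => {},
        }

        // Purge all state tied to this TCP client once it goes away.
        tcp_map.lock().await.remove(&conn_id);
        server_map
            .lock()
            .await
            .retain(|_, v| v.associated_tcp_connection != conn_id);
    })
    .await
    .unwrap();

    assert!(tcp_connection_map.lock().await.is_empty());
    assert!(server_connection_map.lock().await.is_empty());
}
```
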

@@ -521,63 +541,70 @@ async fn handle_group_broadcast(
let Some((response, cid)) = match group_broadcast {
GroupBroadcast::Invitation(group_key) => {
if let Some(group_connection) = group_map.get_mut(&group_key) {
let cid = group_connection.implicated_cid;
Some((InternalServiceResponse::GroupInvitation(GroupInvitation {
let cid = group_connection.channel.cid();
Some((
InternalServiceResponse::GroupInvitation(GroupInvitation {
cid,
group_key,
request_id: None,
}),
cid,
group_key,
request_id: None,
}), cid))
}
else {
))
} else {
None
}

},
}

GroupBroadcast::RequestJoin(group_key) => {
if let Some(group_connection) = group_map.get_mut(&group_key) {
let cid = group_connection.implicated_cid;
Some((InternalServiceResponse::GroupRequestJoinAccepted(GroupRequestJoinAccepted {
Some((
InternalServiceResponse::GroupJoinRequest(GroupJoinRequest {
cid,
peer_cid,
group_key,
request_id: None,
}),
cid,
peer_cid,
group_key,
request_id: None,
}), cid))
}
else {
))
} else {
None
}
},
}

GroupBroadcast::AcceptMembership(group_key) => {
if let Some(group_connection) = group_map.get_mut(&group_key) {
let cid = group_connection.implicated_cid;
Some((InternalServiceResponse::GroupRequestJoinAccepted(GroupRequestJoinAccepted {
Some((
InternalServiceResponse::GroupRequestJoinAccepted(GroupRequestJoinAccepted {
cid,
group_key,
request_id: None,
}),
cid,
group_key,
request_id: None,
}), cid))
}
else {
))
} else {
None
}
},
}

GroupBroadcast::Message(peer_cid, group_key, message) => {
if let Some(group_connection) = group_map.get_mut(&group_key) {
let cid = group_connection.implicated_cid;
Some((InternalServiceResponse::GroupMessageReceived(GroupMessageReceived {
Some((
InternalServiceResponse::GroupMessageReceived(GroupMessageReceived {
cid,
peer_cid,
message: message.into(),
group_key,
request_id: None,
}),
cid,
peer_cid,
message: Vec::from(message.into_buffer()),
group_key,
request_id: None,
}), cid))
}
else {
))
} else {
None
}
},
}

GroupBroadcast::LeaveRoomResponse(group_key, status, message) => {None},

@@ -603,7 +630,7 @@ async fn handle_group_broadcast(

GroupBroadcast::SignalResponse(result) => {None},

_ => {None},
_ => None,
};
match Some((response, cid)) {
Some((internal_service_response, response_cid)) => {
@@ -638,7 +665,6 @@ fn spawn_tick_updater(
match tcp_connection_map.lock().await.get(&uuid) {
Some(entry) => {
let message = InternalServiceResponse::FileTransferTick(FileTransferTick {
uuid,
cid: implicated_cid,
peer_cid,
status: status_message,
[Diffs for the remaining 5 changed files are not shown.]
