/* * Copyright (C) 2021 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ //! Provides a backing task to implement a Dispatcher use crate::boot_time::{BootTime, Duration}; use anyhow::{bail, Result}; use log::{debug, trace, warn}; use std::collections::HashMap; use tokio::sync::{mpsc, oneshot}; use super::{Command, QueryError, Response}; use crate::network::{Network, ServerInfo, SocketTagger, ValidationReporter}; use crate::{config, network}; pub struct Driver { command_rx: mpsc::Receiver, networks: HashMap, validation: ValidationReporter, tagger: SocketTagger, config_cache: config::Cache, } fn debug_err(r: Result<()>) { if let Err(e) = r { debug!("Dispatcher loop got {:?}", e); } } impl Driver { pub fn new( command_rx: mpsc::Receiver, validation: ValidationReporter, tagger: SocketTagger, ) -> Self { Self { command_rx, networks: HashMap::new(), validation, tagger, config_cache: config::Cache::new(), } } pub async fn drive(mut self) -> Result<()> { loop { self.drive_once().await? } } async fn drive_once(&mut self) -> Result<()> { if let Some(command) = self.command_rx.recv().await { trace!("dispatch command: {:?}", command); match command { Command::Probe { info, timeout } => debug_err(self.probe(info, timeout).await), Command::Query { net_id, base64_query, expired_time, resp } => { debug_err(self.query(net_id, base64_query, expired_time, resp).await) } Command::Clear { net_id } => { self.networks.remove(&net_id); self.config_cache.garbage_collect(); } Command::Exit => { bail!("Death due to Exit") } } Ok(()) } else { bail!("Death due to command_tx dying") } } async fn query( &mut self, net_id: u32, query: String, expiry: BootTime, response: oneshot::Sender, ) -> Result<()> { if let Some(network) = self.networks.get_mut(&net_id) { network.query(network::Query { query, response, expiry }).await?; } else { warn!("Tried to send a query to non-existent network net_id={}", net_id); response.send(Response::Error { error: QueryError::Unexpected }).unwrap_or_else(|_| { warn!("Unable to send reply for non-existent network net_id={}", net_id); }) } Ok(()) } async fn probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()> { use std::collections::hash_map::Entry; if !self.networks.get(&info.net_id).map_or(true, |net| net.get_info() == &info) { // If we have a network registered to the provided net_id, but the server info doesn't // match, our API has been used incorrectly. Attempt to recover by deleting the old // network and recreating it according to the probe request. warn!("Probing net_id={} with mismatched server info {:?}", info.net_id, info); self.networks.remove(&info.net_id); } // Can't use or_insert_with because creating a network may fail let net = match self.networks.entry(info.net_id) { Entry::Occupied(network) => network.into_mut(), Entry::Vacant(vacant) => { let key = config::Key { cert_path: info.cert_path.clone(), max_idle_timeout: info.idle_timeout_ms, enable_early_data: info.enable_early_data, }; let config = self.config_cache.get(&key)?; vacant.insert( Network::new(info, config, self.validation.clone(), self.tagger.clone()) .await?, ) } }; net.probe(timeout).await?; Ok(()) } }