ctranslate2/
translator2.rs

1use std::path::Path;
2
3use crate::{
4    Tokenizer, TranslationOptions, Translator, TranslatorConfig, translator::TranslatorError,
5};
6
7pub struct Translator2<T: Tokenizer> {
8    t: Translator,
9    tokenizer: T,
10}
11
12#[inline]
13pub(crate) fn encode_all<T: Tokenizer, U: AsRef<str>>(
14    tokenizer: &T,
15    sources: &[U],
16) -> anyhow::Result<Vec<Vec<String>>> {
17    sources
18        .iter()
19        .map(|s| tokenizer.encode(s.as_ref()))
20        .collect()
21}
22
23impl<T: Tokenizer> Translator2<T> {
24    pub fn new<P: AsRef<Path>>(
25        model_path: P,
26        config: &TranslatorConfig,
27        tokenizer: T,
28    ) -> Result<Self, TranslatorError> {
29        Ok(Translator2 {
30            t: Translator::new(model_path, config)?,
31            tokenizer,
32        })
33    }
34
35    pub fn translate_batch(
36        &self,
37        sources: &[String],
38        options: TranslationOptions,
39    ) -> anyhow::Result<Vec<(String, f32)>> {
40        let out = self
41            .t
42            .translate_batch(&encode_all(&self.tokenizer, sources)?, options)?;
43        let mut res = Vec::new();
44        for r in out.into_iter() {
45            let score = r.score();
46            res.push((
47                self.tokenizer
48                    .decode(r.output())
49                    .map_err(|err| anyhow::anyhow!("failed to decode: {err}"))?,
50                score,
51            ));
52        }
53        Ok(res)
54    }
55}