ctranslate2/
translator2.rs1use std::path::Path;
2
3use crate::{
4 Tokenizer, TranslationOptions, Translator, TranslatorConfig, translator::TranslatorError,
5};
6
7pub struct Translator2<T: Tokenizer> {
8 t: Translator,
9 tokenizer: T,
10}
11
12#[inline]
13pub(crate) fn encode_all<T: Tokenizer, U: AsRef<str>>(
14 tokenizer: &T,
15 sources: &[U],
16) -> anyhow::Result<Vec<Vec<String>>> {
17 sources
18 .iter()
19 .map(|s| tokenizer.encode(s.as_ref()))
20 .collect()
21}
22
23impl<T: Tokenizer> Translator2<T> {
24 pub fn new<P: AsRef<Path>>(
25 model_path: P,
26 config: &TranslatorConfig,
27 tokenizer: T,
28 ) -> Result<Self, TranslatorError> {
29 Ok(Translator2 {
30 t: Translator::new(model_path, config)?,
31 tokenizer,
32 })
33 }
34
35 pub fn translate_batch(
36 &self,
37 sources: &[String],
38 options: TranslationOptions,
39 ) -> anyhow::Result<Vec<(String, f32)>> {
40 let out = self
41 .t
42 .translate_batch(&encode_all(&self.tokenizer, sources)?, options)?;
43 let mut res = Vec::new();
44 for r in out.into_iter() {
45 let score = r.score();
46 res.push((
47 self.tokenizer
48 .decode(r.output())
49 .map_err(|err| anyhow::anyhow!("failed to decode: {err}"))?,
50 score,
51 ));
52 }
53 Ok(res)
54 }
55}