ctranslate2/
translator.rs

1use std::{
2    ffi::{CStr, CString, NulError, c_char, c_int, c_long, c_void},
3    fmt,
4    path::Path,
5    ptr::{self, NonNull},
6};
7
8use ctranslate2_sys::{
9    CTranslationOptions, CTranslationResult, CTranslator, translation_result_free,
10    translation_result_has_attention, translation_result_has_scores,
11    translation_result_num_hypotheses, translation_result_output_at,
12    translation_result_output_size, translation_result_score, translator_create,
13    translator_destroy,
14};
15
16use crate::{compute_type::ComputeType, device::Device};
17
18pub struct Translator {
19    inner: NonNull<CTranslator>,
20}
21
22pub struct TranslationResult {
23    inner: *mut CTranslationResult,
24}
25
26impl TranslationResult {
27    pub fn score(&self) -> f32 {
28        unsafe { translation_result_score(self.inner) }
29    }
30
31    pub fn has_attention(&self) -> bool {
32        unsafe { translation_result_has_attention(self.inner) }
33    }
34
35    pub fn has_scores(&self) -> bool {
36        unsafe { translation_result_has_scores(self.inner) }
37    }
38
39    pub fn num_hypotheses(&self) -> usize {
40        unsafe { translation_result_num_hypotheses(self.inner) }
41    }
42
43    pub fn output(&self) -> Vec<String> {
44        unsafe {
45            let len = translation_result_output_size(self.inner);
46            let mut out = Vec::with_capacity(len);
47            for idx in 0..len {
48                let ptr = translation_result_output_at(self.inner, idx);
49                out.push(CStr::from_ptr(ptr).to_string_lossy().to_string());
50            }
51            out
52        }
53    }
54}
55
56impl Drop for TranslationResult {
57    fn drop(&mut self) {
58        unsafe {
59            translation_result_free(self.inner);
60        }
61    }
62}
63
64impl Drop for Translator {
65    fn drop(&mut self) {
66        unsafe {
67            translator_destroy(self.inner.as_ptr());
68        }
69    }
70}
71
/// Errors returned by the translator wrapper.
#[derive(Debug)]
pub enum TranslatorError {
    /// A string destined for the native library contained an interior NUL byte.
    ///
    /// Used for the model path in `Translator::new` and for input tokens in
    /// `Translator::translate_batch`.
    NulInPath(NulError),
    /// `translator_create` returned a null handle.
    CreationFailed,
}

impl fmt::Display for TranslatorError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CreationFailed => f.write_str("Failed to create the translator"),
            Self::NulInPath(cause) => {
                write!(f, "Invalid path (contains null byte): {}", cause)
            }
        }
    }
}

// Implementing `std::error::Error` lets callers use `?` and generic error handling.
impl std::error::Error for TranslatorError {
    /// Exposes the wrapped `NulError` as the underlying cause, when there is one.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::CreationFailed => None,
            Self::NulInPath(cause) => Some(cause),
        }
    }
}
98
99pub struct TranslatorConfig {
100    pub device: Device,
101    pub compute_type: ComputeType,
102    pub device_indices: Vec<i32>,
103    pub tensor_parallel: bool,
104    pub num_threads_per_replica: usize,
105    pub max_queued_batches: i64,
106    pub cpu_core_offset: i32,
107}
108
109impl Default for TranslatorConfig {
110    fn default() -> Self {
111        Self {
112            device: Device::Cpu,
113            compute_type: ComputeType::Default,
114            device_indices: vec![0],
115            tensor_parallel: false,
116            num_threads_per_replica: 0,
117            max_queued_batches: 0,
118            cpu_core_offset: -1,
119        }
120    }
121}
122
/// Unit in which `TranslationOptions::max_batch_size` is expressed.
///
/// The value is cast to `i32` when passed to `translator_translate_batch`, so the
/// discriminants are spelled out explicitly.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BatchType {
    Examples = 0,
    Tokens = 1,
}

/// Decoding options consumed by `Translator::translate_batch`.
///
/// Most fields are copied one-to-one into the native `CTranslationOptions` struct;
/// `max_batch_size` and `batch_type` are passed as separate arguments of
/// `translator_translate_batch`.
///
/// All fields are public so callers can override individual values with struct-update
/// syntax, e.g. `TranslationOptions { beam_size: 4, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct TranslationOptions {
    /// Default: `2`.
    pub beam_size: usize,
    /// Default: `1.0`.
    pub patience: f32,
    /// Default: `1.0`.
    pub length_penalty: f32,
    /// Default: `0.0`.
    pub coverage_penalty: f32,
    /// Default: `1.0`.
    pub repetition_penalty: f32,
    /// Default: `0`.
    pub no_repeat_ngram_size: usize,
    /// Default: `false`.
    pub disable_unk: bool,
    /// Default: empty.
    ///
    /// Stored here but currently NOT forwarded to the native call by `translate_batch`.
    pub suppress_sequences: Vec<Vec<String>>,
    /// Default: `0.0`.
    pub prefix_bias_beta: f32,
    /// Default: `false`.
    pub return_end_token: bool,
    /// Default: `1024`.
    pub max_input_length: usize,
    /// Default: `256`.
    pub max_decoding_length: usize,
    /// Default: `1`.
    pub min_decoding_length: usize,
    /// Default: `1`.
    pub sampling_topk: usize,
    /// Default: `1.0`.
    pub sampling_topp: f32,
    /// Default: `1.0`.
    pub sampling_temperature: f32,
    /// Default: `false`.
    pub use_vmap: bool,
    /// Default: `1`.
    pub num_hypotheses: usize,
    /// Default: `false`.
    pub return_scores: bool,
    /// Default: `false`.
    pub return_attention: bool,
    /// Default: `false`.
    pub return_logits_vocab: bool,
    /// Default: `false`.
    pub return_alternatives: bool,
    /// Default: `0.0`.
    pub min_alternative_expansion_prob: f32,
    /// Default: `false`.
    pub replace_unknowns: bool,

    /// Batch size limit passed to `translator_translate_batch`. Default: `0`.
    pub max_batch_size: usize,
    /// Unit of `max_batch_size`. Default: `BatchType::Examples`.
    pub batch_type: BatchType,
}

impl Default for TranslationOptions {
    /// Returns the default option set used by this wrapper.
    fn default() -> Self {
        Self {
            // TODO: options of the upstream C++ API that are not wired through yet:
            //   std::vector<std::vector<std::string>> suppress_sequences
            //     (the field exists here but is not forwarded by `translate_batch`)
            //   std::variant<std::string, std::vector<std::string>, std::vector<size_t>> end_token
            //   std::function<bool(GenerationStepResult)> callback = nullptr
            beam_size: 2,
            patience: 1.0,
            length_penalty: 1.0,
            coverage_penalty: 0.0,
            repetition_penalty: 1.0,
            no_repeat_ngram_size: 0,
            disable_unk: false,
            suppress_sequences: Vec::new(),
            prefix_bias_beta: 0.0,
            return_end_token: false,
            max_input_length: 1024,
            max_decoding_length: 256,
            min_decoding_length: 1,
            sampling_topk: 1,
            sampling_topp: 1.0,
            sampling_temperature: 1.0,
            use_vmap: false,
            num_hypotheses: 1,
            return_scores: false,
            return_attention: false,
            return_logits_vocab: false,
            return_alternatives: false,
            min_alternative_expansion_prob: 0.0,
            replace_unknowns: false,
            max_batch_size: 0,
            batch_type: BatchType::Examples,
        }
    }
}
196
197impl Translator {
198    pub fn new<P: AsRef<Path>>(
199        model_path: P,
200        config: &TranslatorConfig,
201    ) -> Result<Self, TranslatorError> {
202        let c_model = CString::new(model_path.as_ref().to_string_lossy().into_owned())
203            .map_err(TranslatorError::NulInPath)?;
204
205        let (device_indices_ptr, num_device_indices) = (
206            config.device_indices.as_ptr() as *const c_int,
207            config.device_indices.len(),
208        );
209
210        let raw = unsafe {
211            translator_create(
212                c_model.as_ptr(),
213                config.device as c_int,
214                config.compute_type as c_int,
215                device_indices_ptr,
216                num_device_indices,
217                config.tensor_parallel as c_int,
218                config.num_threads_per_replica,
219                config.max_queued_batches as c_long,
220                config.cpu_core_offset as c_int,
221            )
222        };
223
224        let non_null = NonNull::new(raw).ok_or(TranslatorError::CreationFailed)?;
225        Ok(Translator { inner: non_null })
226    }
227
228    pub fn translate_batch(
229        &self,
230        tokens: &[Vec<String>],
231        options: TranslationOptions,
232    ) -> Result<Vec<TranslationResult>, TranslatorError> {
233        let opt = CTranslationOptions {
234            prefix_bias_beta: options.prefix_bias_beta,
235            return_end_token: options.return_end_token,
236            beam_size: options.beam_size,
237            patience: options.patience,
238            length_penalty: options.length_penalty,
239            coverage_penalty: options.coverage_penalty,
240            repetition_penalty: options.repetition_penalty,
241            no_repeat_ngram_size: options.no_repeat_ngram_size,
242            disable_unk: if options.disable_unk { 1 } else { 0 },
243            max_input_length: options.max_input_length,
244            max_decoding_length: options.max_decoding_length,
245            min_decoding_length: options.min_decoding_length,
246            sampling_topk: options.sampling_topk,
247            sampling_topp: options.sampling_topp,
248            sampling_temperature: options.sampling_temperature,
249            use_vmap: if options.use_vmap { 1 } else { 0 },
250            num_hypotheses: options.num_hypotheses,
251            return_scores: if options.return_scores { 1 } else { 0 },
252            return_attention: if options.return_attention { 1 } else { 0 },
253            return_logits_vocab: if options.return_logits_vocab { 1 } else { 0 },
254            return_alternatives: if options.return_alternatives { 1 } else { 0 },
255            min_alternative_expansion_prob: options.min_alternative_expansion_prob,
256            replace_unknowns: if options.replace_unknowns { 1 } else { 0 },
257        };
258        unsafe {
259            let c_sentences: Result<Vec<Vec<CString>>, TranslatorError> = tokens
260                .iter()
261                .map(|sentence| {
262                    sentence
263                        .iter()
264                        .map(|s| {
265                            CString::new(s.as_str()).map_err(|e| TranslatorError::NulInPath(e))
266                        })
267                        .collect()
268                })
269                .collect();
270            let c_sentences = c_sentences?;
271            let c_ptrs: Vec<Vec<*const c_char>> = c_sentences
272                .iter()
273                .map(|sentence| {
274                    let mut s: Vec<*const c_char> = sentence.iter().map(|s| s.as_ptr()).collect();
275                    s.push(ptr::null());
276                    s
277                })
278                .collect();
279            let c_sentences_ptrs: Vec<*const *const c_char> =
280                c_ptrs.iter().map(|s| s.as_ptr()).collect();
281            let num_sentences = c_sentences_ptrs.len();
282
283            let mut out_num_translations: usize = 0;
284
285            let results_ptr = ctranslate2_sys::translator_translate_batch(
286                self.inner.as_ptr(),
287                c_sentences_ptrs.as_ptr() as *mut *mut *const c_char,
288                num_sentences,
289                &opt,
290                options.max_batch_size,
291                options.batch_type as i32,
292                &mut out_num_translations,
293            );
294            let results = take_c_results(results_ptr, out_num_translations)
295                .into_iter()
296                .map(|v| TranslationResult { inner: v })
297                .collect::<Vec<_>>();
298
299            Ok(results)
300        }
301    }
302}
303
304fn take_c_results<T>(c_results: *mut *mut T, n: usize) -> Vec<*mut T> {
305    unsafe {
306        let owned = std::slice::from_raw_parts(c_results.clone(), n).to_vec();
307        ctranslate2_sys::free_pointer_array(c_results as *mut *mut c_void);
308        owned
309    }
310}