1use std::{
2 ffi::{CStr, CString, NulError, c_char, c_int, c_long, c_void},
3 fmt,
4 path::Path,
5 ptr::{self, NonNull},
6};
7
8use ctranslate2_sys::{
9 CTranslationOptions, CTranslationResult, CTranslator, translation_result_free,
10 translation_result_has_attention, translation_result_has_scores,
11 translation_result_num_hypotheses, translation_result_output_at,
12 translation_result_output_size, translation_result_score, translator_create,
13 translator_destroy,
14};
15
16use crate::{compute_type::ComputeType, device::Device};
17
18pub struct Translator {
19 inner: NonNull<CTranslator>,
20}
21
22pub struct TranslationResult {
23 inner: *mut CTranslationResult,
24}
25
26impl TranslationResult {
27 pub fn score(&self) -> f32 {
28 unsafe { translation_result_score(self.inner) }
29 }
30
31 pub fn has_attention(&self) -> bool {
32 unsafe { translation_result_has_attention(self.inner) }
33 }
34
35 pub fn has_scores(&self) -> bool {
36 unsafe { translation_result_has_scores(self.inner) }
37 }
38
39 pub fn num_hypotheses(&self) -> usize {
40 unsafe { translation_result_num_hypotheses(self.inner) }
41 }
42
43 pub fn output(&self) -> Vec<String> {
44 unsafe {
45 let len = translation_result_output_size(self.inner);
46 let mut out = Vec::with_capacity(len);
47 for idx in 0..len {
48 let ptr = translation_result_output_at(self.inner, idx);
49 out.push(CStr::from_ptr(ptr).to_string_lossy().to_string());
50 }
51 out
52 }
53 }
54}
55
56impl Drop for TranslationResult {
57 fn drop(&mut self) {
58 unsafe {
59 translation_result_free(self.inner);
60 }
61 }
62}
63
64impl Drop for Translator {
65 fn drop(&mut self) {
66 unsafe {
67 translator_destroy(self.inner.as_ptr());
68 }
69 }
70}
71
/// Errors returned by [`Translator`] operations.
#[derive(Debug)]
pub enum TranslatorError {
    /// A string destined for the C API contained an interior NUL byte.
    ///
    /// Produced for the model path in `Translator::new` and, despite the
    /// name, also for tokens in `Translator::translate_batch`.
    NulInPath(NulError),
    /// `translator_create` returned a null pointer.
    CreationFailed,
}
77
78impl fmt::Display for TranslatorError {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 match self {
81 TranslatorError::NulInPath(err) => {
82 write!(f, "Invalid path (contains null byte): {}", err)
83 }
84 TranslatorError::CreationFailed => write!(f, "Failed to create the translator"),
85 }
86 }
87}
88
89impl std::error::Error for TranslatorError {
91 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
92 match self {
93 TranslatorError::NulInPath(err) => Some(err),
94 TranslatorError::CreationFailed => None,
95 }
96 }
97}
98
99pub struct TranslatorConfig {
100 pub device: Device,
101 pub compute_type: ComputeType,
102 pub device_indices: Vec<i32>,
103 pub tensor_parallel: bool,
104 pub num_threads_per_replica: usize,
105 pub max_queued_batches: i64,
106 pub cpu_core_offset: i32,
107}
108
109impl Default for TranslatorConfig {
110 fn default() -> Self {
111 Self {
112 device: Device::Cpu,
113 compute_type: ComputeType::Default,
114 device_indices: vec![0],
115 tensor_parallel: false,
116 num_threads_per_replica: 0,
117 max_queued_batches: 0,
118 cpu_core_offset: -1,
119 }
120 }
121}
122
/// Batch-size unit selector passed to `translator_translate_batch` as an `i32`.
///
/// NOTE(review): presumably chooses whether `max_batch_size` counts examples
/// or tokens — confirm against the CTranslate2 documentation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum BatchType {
    /// Discriminant `0`.
    Examples = 0,
    /// Discriminant `1`.
    Tokens = 1,
}
129
130pub struct TranslationOptions {
131 beam_size: usize,
132 patience: f32,
133 length_penalty: f32,
134 coverage_penalty: f32,
135 repetition_penalty: f32,
136 no_repeat_ngram_size: usize,
137 disable_unk: bool,
138 suppress_sequences: Vec<Vec<String>>,
139 prefix_bias_beta: f32,
140 return_end_token: bool,
141 max_input_length: usize,
142 max_decoding_length: usize,
143 min_decoding_length: usize,
144 sampling_topk: usize,
145 sampling_topp: f32,
146 sampling_temperature: f32,
147 use_vmap: bool,
148 num_hypotheses: usize,
149 return_scores: bool,
150 return_attention: bool,
151 return_logits_vocab: bool,
152 return_alternatives: bool,
153 min_alternative_expansion_prob: f32,
154 replace_unknowns: bool,
155
156 max_batch_size: usize,
157 batch_type: BatchType,
158}
159
160impl Default for TranslationOptions {
161 fn default() -> Self {
162 Self {
163 beam_size: 2,
168 patience: 1.0,
169 length_penalty: 1.0,
170 coverage_penalty: 0.0,
171 repetition_penalty: 1.0,
172 no_repeat_ngram_size: 0,
173 disable_unk: false,
174 suppress_sequences: Default::default(),
175 prefix_bias_beta: 0.0,
176 return_end_token: false,
177 max_input_length: 1024,
178 max_decoding_length: 256,
179 min_decoding_length: 1,
180 sampling_topk: 1,
181 sampling_topp: 1.0,
182 sampling_temperature: 1.0,
183 use_vmap: false,
184 num_hypotheses: 1,
185 return_scores: false,
186 return_attention: false,
187 return_logits_vocab: false,
188 return_alternatives: false,
189 min_alternative_expansion_prob: 0.0,
190 replace_unknowns: false,
191 max_batch_size: 0,
192 batch_type: BatchType::Examples,
193 }
194 }
195}
196
197impl Translator {
198 pub fn new<P: AsRef<Path>>(
199 model_path: P,
200 config: &TranslatorConfig,
201 ) -> Result<Self, TranslatorError> {
202 let c_model = CString::new(model_path.as_ref().to_string_lossy().into_owned())
203 .map_err(TranslatorError::NulInPath)?;
204
205 let (device_indices_ptr, num_device_indices) = (
206 config.device_indices.as_ptr() as *const c_int,
207 config.device_indices.len(),
208 );
209
210 let raw = unsafe {
211 translator_create(
212 c_model.as_ptr(),
213 config.device as c_int,
214 config.compute_type as c_int,
215 device_indices_ptr,
216 num_device_indices,
217 config.tensor_parallel as c_int,
218 config.num_threads_per_replica,
219 config.max_queued_batches as c_long,
220 config.cpu_core_offset as c_int,
221 )
222 };
223
224 let non_null = NonNull::new(raw).ok_or(TranslatorError::CreationFailed)?;
225 Ok(Translator { inner: non_null })
226 }
227
228 pub fn translate_batch(
229 &self,
230 tokens: &[Vec<String>],
231 options: TranslationOptions,
232 ) -> Result<Vec<TranslationResult>, TranslatorError> {
233 let opt = CTranslationOptions {
234 prefix_bias_beta: options.prefix_bias_beta,
235 return_end_token: options.return_end_token,
236 beam_size: options.beam_size,
237 patience: options.patience,
238 length_penalty: options.length_penalty,
239 coverage_penalty: options.coverage_penalty,
240 repetition_penalty: options.repetition_penalty,
241 no_repeat_ngram_size: options.no_repeat_ngram_size,
242 disable_unk: if options.disable_unk { 1 } else { 0 },
243 max_input_length: options.max_input_length,
244 max_decoding_length: options.max_decoding_length,
245 min_decoding_length: options.min_decoding_length,
246 sampling_topk: options.sampling_topk,
247 sampling_topp: options.sampling_topp,
248 sampling_temperature: options.sampling_temperature,
249 use_vmap: if options.use_vmap { 1 } else { 0 },
250 num_hypotheses: options.num_hypotheses,
251 return_scores: if options.return_scores { 1 } else { 0 },
252 return_attention: if options.return_attention { 1 } else { 0 },
253 return_logits_vocab: if options.return_logits_vocab { 1 } else { 0 },
254 return_alternatives: if options.return_alternatives { 1 } else { 0 },
255 min_alternative_expansion_prob: options.min_alternative_expansion_prob,
256 replace_unknowns: if options.replace_unknowns { 1 } else { 0 },
257 };
258 unsafe {
259 let c_sentences: Result<Vec<Vec<CString>>, TranslatorError> = tokens
260 .iter()
261 .map(|sentence| {
262 sentence
263 .iter()
264 .map(|s| {
265 CString::new(s.as_str()).map_err(|e| TranslatorError::NulInPath(e))
266 })
267 .collect()
268 })
269 .collect();
270 let c_sentences = c_sentences?;
271 let c_ptrs: Vec<Vec<*const c_char>> = c_sentences
272 .iter()
273 .map(|sentence| {
274 let mut s: Vec<*const c_char> = sentence.iter().map(|s| s.as_ptr()).collect();
275 s.push(ptr::null());
276 s
277 })
278 .collect();
279 let c_sentences_ptrs: Vec<*const *const c_char> =
280 c_ptrs.iter().map(|s| s.as_ptr()).collect();
281 let num_sentences = c_sentences_ptrs.len();
282
283 let mut out_num_translations: usize = 0;
284
285 let results_ptr = ctranslate2_sys::translator_translate_batch(
286 self.inner.as_ptr(),
287 c_sentences_ptrs.as_ptr() as *mut *mut *const c_char,
288 num_sentences,
289 &opt,
290 options.max_batch_size,
291 options.batch_type as i32,
292 &mut out_num_translations,
293 );
294 let results = take_c_results(results_ptr, out_num_translations)
295 .into_iter()
296 .map(|v| TranslationResult { inner: v })
297 .collect::<Vec<_>>();
298
299 Ok(results)
300 }
301 }
302}
303
304fn take_c_results<T>(c_results: *mut *mut T, n: usize) -> Vec<*mut T> {
305 unsafe {
306 let owned = std::slice::from_raw_parts(c_results.clone(), n).to_vec();
307 ctranslate2_sys::free_pointer_array(c_results as *mut *mut c_void);
308 owned
309 }
310}