support more attributes from the Encoding structure #5

Merged: 8 commits, Nov 15, 2023
100 changes: 79 additions & 21 deletions lib/src/lib.rs
@@ -6,6 +6,9 @@ use tokenizers::tokenizer::Tokenizer;
#[repr(C)]
pub struct Buffer {
    ids: *mut u32,
    type_ids: *mut u32,
    special_tokens_mask: *mut u32,
    attention_mask: *mut u32,
    tokens: *mut *mut libc::c_char,
    len: usize,
}
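The optional attribute pointers are left null when the corresponding data was not requested, so a consumer has to null-check them before use. A minimal in-crate sketch of reading a Buffer back into slices; the helper name `buffer_views` is hypothetical and not part of this change:

// Illustrative only: `len` describes every non-null array, as in this change.
unsafe fn buffer_views(buf: &Buffer) -> (&[u32], Option<&[u32]>) {
    let ids = std::slice::from_raw_parts(buf.ids, buf.len);
    let attention_mask = if buf.attention_mask.is_null() {
        None
    } else {
        Some(std::slice::from_raw_parts(buf.attention_mask, buf.len))
    };
    (ids, attention_mask)
}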
@@ -50,33 +53,70 @@ pub extern "C" fn from_file(config: *const libc::c_char) -> *mut libc::c_void {
}
}

#[repr(C)]
pub struct EncodeOptions {
    add_special_tokens: bool,

    return_type_ids: bool,
    return_tokens: bool,
    return_special_tokens_mask: bool,
    return_attention_mask: bool,
}
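Every flag has to be populated by the caller before the struct is passed in. A small in-crate sketch of filling it out; the helper name `all_attributes` is hypothetical:

// Illustrative only: options that request every attribute.
fn all_attributes() -> EncodeOptions {
    EncodeOptions {
        add_special_tokens: true,
        return_type_ids: true,
        return_tokens: true,
        return_special_tokens_mask: true,
        return_attention_mask: true,
    }
}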

#[no_mangle]
pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, add_special_tokens: bool) -> Buffer {
pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, options: &EncodeOptions) -> Buffer {
    let tokenizer: &Tokenizer;
    unsafe {
        tokenizer = ptr.cast::<Tokenizer>().as_ref().expect("failed to cast tokenizer");
    }
    let message_cstr = unsafe { CStr::from_ptr(message) };
    let message = message_cstr.to_str().unwrap();

    let encoding = tokenizer.encode(message, add_special_tokens).expect("failed to encode input");
    // Token ids are always returned; the other attributes are only copied out when the
    // client explicitly requests them, to avoid that cost for callers that don't need them.
    let encoding = tokenizer.encode(message, options.add_special_tokens).expect("failed to encode input");
    let mut vec_ids = encoding.get_ids().to_vec();
    let mut vec_tokens = encoding.get_tokens()
        .to_vec().into_iter()
        .map(|s| std::ffi::CString::new(s).unwrap().into_raw())
        .collect::<Vec<_>>();

    vec_ids.shrink_to_fit();
    vec_tokens.shrink_to_fit();

    let ids = vec_ids.as_mut_ptr();
    let tokens = vec_tokens.as_mut_ptr();
    let len = vec_ids.len();

    std::mem::forget(vec_ids);
    std::mem::forget(vec_tokens);

    Buffer { ids, tokens, len }
    let mut type_ids: *mut u32 = ptr::null_mut();
    if options.return_type_ids {
        let mut vec_type_ids = encoding.get_type_ids().to_vec();
        vec_type_ids.shrink_to_fit();
        type_ids = vec_type_ids.as_mut_ptr();
        std::mem::forget(vec_type_ids);
    }

    let mut tokens: *mut *mut libc::c_char = ptr::null_mut();
    if options.return_tokens {
        let mut vec_tokens = encoding.get_tokens()
            .to_vec().into_iter()
            .map(|s| std::ffi::CString::new(s).unwrap().into_raw())
            .collect::<Vec<_>>();
        vec_tokens.shrink_to_fit();
        tokens = vec_tokens.as_mut_ptr();
        std::mem::forget(vec_tokens);
    }

    let mut special_tokens_mask: *mut u32 = ptr::null_mut();
    if options.return_special_tokens_mask {
        let mut vec_special_tokens_mask = encoding.get_special_tokens_mask().to_vec();
        vec_special_tokens_mask.shrink_to_fit();
        special_tokens_mask = vec_special_tokens_mask.as_mut_ptr();
        std::mem::forget(vec_special_tokens_mask);
    }

    let mut attention_mask: *mut u32 = ptr::null_mut();
    if options.return_attention_mask {
        let mut vec_attention_mask = encoding.get_attention_mask().to_vec();
        vec_attention_mask.shrink_to_fit();
        attention_mask = vec_attention_mask.as_mut_ptr();
        std::mem::forget(vec_attention_mask);
    }

    Buffer { ids, type_ids, special_tokens_mask, attention_mask, tokens, len }
}
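Putting the pieces together, a minimal sketch of exercising the new signature from an in-crate test; the tokenizer path, input text, and test name are placeholders, while `from_file`, `free_buffer`, and `free_tokenizer` are the existing entry points in this file:

// Illustrative only: reading Buffer fields directly works here because the test
// lives in the same crate; external callers go through the C ABI instead.
#[test]
fn encode_with_options_sketch() {
    let config = std::ffi::CString::new("tokenizer.json").unwrap(); // placeholder path
    let tokenizer = from_file(config.as_ptr());

    let text = std::ffi::CString::new("Hello, world!").unwrap();
    let options = EncodeOptions {
        add_special_tokens: true,
        return_type_ids: true,
        return_tokens: false,
        return_special_tokens_mask: false,
        return_attention_mask: true,
    };
    let buffer = encode(tokenizer, text.as_ptr(), &options);

    unsafe {
        let ids = std::slice::from_raw_parts(buffer.ids, buffer.len);
        let attention = std::slice::from_raw_parts(buffer.attention_mask, buffer.len);
        assert_eq!(ids.len(), attention.len());
    }

    free_buffer(buffer);
    free_tokenizer(tokenizer);
}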

#[no_mangle]
@@ -111,15 +151,33 @@ pub extern "C" fn free_tokenizer(ptr: *mut ::libc::c_void) {

#[no_mangle]
pub extern "C" fn free_buffer(buf: Buffer) {
    if buf.ids.is_null() {
        return;
    if !buf.ids.is_null() {
        unsafe {
            Vec::from_raw_parts(buf.ids, buf.len, buf.len);
        }
    }
    unsafe {
        Vec::from_raw_parts(buf.ids, buf.len, buf.len);
        let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len);
        for s in strings {
            drop(std::ffi::CString::from_raw(s));
        }
    if !buf.type_ids.is_null() {
        unsafe {
            Vec::from_raw_parts(buf.type_ids, buf.len, buf.len);
        }
    }
    if !buf.special_tokens_mask.is_null() {
        unsafe {
            Vec::from_raw_parts(buf.special_tokens_mask, buf.len, buf.len);
        }
    }
    if !buf.attention_mask.is_null() {
        unsafe {
            Vec::from_raw_parts(buf.attention_mask, buf.len, buf.len);
        }
    }
    if !buf.tokens.is_null() {
        unsafe {
            let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len);
            for s in strings {
                drop(std::ffi::CString::from_raw(s));
            }
        }
    }
}
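Since `free_buffer` now checks each pointer before rebuilding its `Vec`, the same call frees a buffer no matter which attributes were requested. A short in-crate sketch of that path; the helper name `ids_only_roundtrip` is hypothetical, and the tokenizer and text are assumed to come from `from_file` and a `CString` as in the earlier sketch:

// Illustrative only: an ids-only encode leaves the optional pointers null,
// and free_buffer simply skips them.
fn ids_only_roundtrip(tokenizer: *mut libc::c_void, text: &std::ffi::CStr) {
    let minimal = EncodeOptions {
        add_special_tokens: true,
        return_type_ids: false,
        return_tokens: false,
        return_special_tokens_mask: false,
        return_attention_mask: false,
    };
    let buffer = encode(tokenizer, text.as_ptr(), &minimal);
    assert!(buffer.type_ids.is_null() && buffer.tokens.is_null());
    free_buffer(buffer);
}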
