use std::{fmt::Display, marker::PhantomData, mem, ptr};
use windows::{
core::{PCWSTR, PWSTR},
Win32::{Globalization::lstrlenW, System::Com},
};
pub struct CoTaskMemPWSTR<'a>(PWSTR, PhantomData<&'a PWSTR>);
pub struct CoTaskMemRef<'a>(PCWSTR, PhantomData<&'a PCWSTR>);
impl<'a> CoTaskMemRef<'a> {
pub fn as_pcwstr(&self) -> &PCWSTR {
&self.0
}
}
impl<'a> From<&'a CoTaskMemPWSTR<'a>> for CoTaskMemRef<'a> {
fn from(value: &'a CoTaskMemPWSTR<'a>) -> Self {
Self(PCWSTR::from_raw(value.0.as_ptr()), PhantomData)
}
}
pub struct CoTaskMemMut<'a>(&'a PWSTR);
impl<'a> CoTaskMemMut<'a> {
pub fn as_pwstr(&mut self) -> &'a PWSTR {
self.0
}
}
impl<'a> From<&'a mut CoTaskMemPWSTR<'a>> for CoTaskMemMut<'a> {
fn from(value: &'a mut CoTaskMemPWSTR<'a>) -> Self {
Self(&value.0)
}
}
impl<'a> CoTaskMemPWSTR<'a> {
pub fn as_mut(&'a mut self) -> CoTaskMemMut<'a> {
From::from(self)
}
pub fn as_ref(&'a self) -> CoTaskMemRef<'a> {
From::from(self)
}
pub fn take(&mut self) -> PWSTR {
let result = self.0;
self.0 = PWSTR::null();
result
}
}
impl<'a> Drop for CoTaskMemPWSTR<'a> {
fn drop(&mut self) {
if !self.0.is_null() {
unsafe {
Com::CoTaskMemFree(Some(self.0.as_ptr() as *mut _ as *const _));
}
}
}
}
impl<'a> Default for CoTaskMemPWSTR<'a> {
fn default() -> Self {
Self(PWSTR::null(), PhantomData)
}
}
impl<'a> From<PWSTR> for CoTaskMemPWSTR<'a> {
fn from(value: PWSTR) -> Self {
Self(value, PhantomData)
}
}
impl<'a> From<&str> for CoTaskMemPWSTR<'a> {
fn from(value: &str) -> Self {
match value {
"" => Default::default(),
value => {
let encoded: Vec<_> = value.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let mut buffer =
Com::CoTaskMemAlloc(encoded.len() * mem::size_of::<u16>()) as *mut u16;
let result = PWSTR::from_raw(buffer);
for char in encoded {
*buffer = char;
buffer = buffer.add(1);
}
Self(result, PhantomData)
}
}
}
}
}
impl<'a> Display for CoTaskMemPWSTR<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let value = string_from_pcwstr(self.as_ref().as_pcwstr());
f.write_str(value.as_str())
}
}
pub fn string_from_pcwstr(source: &PCWSTR) -> String {
if source.0.is_null() {
String::new()
} else {
let len = unsafe { lstrlenW(*source) };
if len > 0 {
unsafe {
let buffer = ptr::slice_from_raw_parts(source.0, len as usize);
String::from_utf16_lossy(&*buffer)
}
} else {
String::new()
}
}
}
pub fn take_pwstr(source: PWSTR) -> String {
CoTaskMemPWSTR::from(source).to_string()
}
pub fn pwstr_from_str(source: &str) -> PWSTR {
CoTaskMemPWSTR::from(source).take()
}