Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pythaitts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
class TTS:
def __init__(self, pretrained="lunarlist_onnx", mode="last_checkpoint", version="1.0", device:str="cpu") -> None:
"""
:param str pretrained: TTS pretrained (lunarlist_onnx, khanomtan, lunarlist, vachana)
:param str pretrained: TTS pretrained (lunarlist_onnx, khanomtan, lunarlist, vachana, archa)
:param str mode: pretrained mode (lunarlist_onnx and vachana don't support)
:param str version: model version (default is 1.0 or 1.1)
:param str device: device for running model. (lunarlist_onnx and vachana support CPU only.)
:param str device: device for running model. (lunarlist_onnx and vachana support CPU only. archa supports cpu and cuda.)

**Options for mode**
* *last_checkpoint* (default) - last checkpoint of model
Expand All @@ -32,6 +32,10 @@ def __init__(self, pretrained="lunarlist_onnx", mode="last_checkpoint", version=
For vachana tts model, \
You can see more about vachana tts at `https://github.com/VYNCX/VachanaTTS2 <https://github.com/VYNCX/VachanaTTS2>`_

For archa tts model, you must install the required packages before use: \
pip install torch transformers snac soundfile noisereduce scipy numpy. \
You can see more about archa tts at `https://github.com/YangNobody12/Archa-TTS-0.5B-th <https://github.com/YangNobody12/Archa-TTS-0.5B-th>`_


"""
self.pretrained = pretrained
Expand All @@ -55,6 +59,9 @@ def load_pretrained(self,version):
elif self.pretrained == "vachana":
from pythaitts.pretrained.vachana_tts import VachanaTTS
self.model = VachanaTTS()
elif self.pretrained == "archa":
from pythaitts.pretrained.archa_tts import ArchaTTS
self.model = ArchaTTS(device=self.device)
else:
raise NotImplementedError(
"PyThaiTTS doesn't support %s pretrained." % self.pretrained
Expand All @@ -80,6 +87,8 @@ def tts(self, text: str, speaker_idx: str = "Linda", language_idx: str = "th-th"
return self.model(text=text,return_type=return_type,filename=filename)
elif self.pretrained == "vachana":
return self.model(text=text,speaker_idx=speaker_idx,return_type=return_type,filename=filename)
elif self.pretrained == "archa":
return self.model(text=text,return_type=return_type,filename=filename)
return self.model(
text=text,
speaker_idx=speaker_idx,
Expand Down
197 changes: 197 additions & 0 deletions pythaitts/pretrained/archa_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
"""
Archa TTS model (YangNobody12/Archa-TTS-0.5B-th)

Archa TTS is a Thai text-to-speech model built on Qwen2.5-0.5B with LoRA fine-tuning
and SNAC 24kHz audio codec.

See more: https://github.com/YangNobody12/Archa-TTS-0.5B-th
HuggingFace model: https://huggingface.co/Pakorn2112/Archa-TTS-0.5B-th
"""
import tempfile
import os
import numpy as np

# HuggingFace repo of the LoRA fine-tuned Qwen2.5-0.5B TTS model.
BASE_MODEL_PATH = "Pakorn2112/Archa-TTS-0.5B-th"
# HuggingFace repo of the SNAC neural audio codec (24 kHz variant).
SNAC_MODEL_PATH = "hubertsiuzdak/snac_24khz"
SNAC_SR = 24000  # SNAC output sample rate in Hz
TOKENISER_LENGTH = 151665  # base offset for all added control/audio token IDs
VOCAB_SIZE = 180500  # embedding table size passed to resize_token_embeddings

# Special token IDs: control tokens appended after the text vocabulary.
END_OF_TEXT = TOKENISER_LENGTH + 2
START_OF_SPEECH = TOKENISER_LENGTH + 3
END_OF_SPEECH = TOKENISER_LENGTH + 4
START_OF_HUMAN = TOKENISER_LENGTH + 5
END_OF_HUMAN = TOKENISER_LENGTH + 6
START_OF_AI = TOKENISER_LENGTH + 7
AUDIO_TOKENS_START = TOKENISER_LENGTH + 10  # 151675; first SNAC audio token ID

# Token generation defaults: a rough per-character token budget plus a hard
# floor so that short texts still get enough room to finish the utterance.
ESTIMATED_TOKENS_PER_CHAR = 30
MIN_MAX_NEW_TOKENS = 16384


class ArchaTTS:
    def __init__(self, device: str = None) -> None:
        """
        Load the Archa TTS language model and its SNAC vocoder.

        Weights are fetched from HuggingFace automatically the first time
        this class is instantiated.

        :param str device: Device to run the model on ('cpu' or 'cuda').
            Defaults to 'cuda' if available, otherwise 'cpu'.
        """
        try:
            import torch
        except ImportError:
            raise ImportError(
                "torch is not installed. Please install it with: pip install torch"
            )
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError:
            raise ImportError(
                "transformers is not installed. Please install it with: pip install transformers"
            )
        try:
            from snac import SNAC
        except ImportError:
            raise ImportError(
                "snac is not installed. Please install it with: pip install snac"
            )

        # Pick the best available device when none was requested.
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Prefer bf16 on capable GPUs, fall back to fp16 on CUDA, fp32 on CPU.
        if self.device == "cuda" and torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
        elif self.device == "cuda":
            dtype = torch.float16
        else:
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
        self.model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH, torch_dtype=dtype, device_map=self.device
        )
        # The embedding table must cover the extra control/audio token IDs.
        self.model.resize_token_embeddings(VOCAB_SIZE)
        self.snac_model = SNAC.from_pretrained(SNAC_MODEL_PATH).eval().to(self.device)

    def _decode_tokens(self, token_list):
        """Turn a flat list of audio tokens into a waveform via SNAC."""
        # Audio tokens come in frames of 7; drop any incomplete trailing frame.
        usable = (len(token_list) // 7) * 7
        frames = token_list[:usable]
        if usable < 7:
            return np.array([], dtype=np.float32)

        import torch

        # De-interleave each 7-token frame into SNAC's three codebook layers.
        # Position k within a frame is offset by k * 4096 in token-ID space.
        coarse, medium, fine = [], [], []
        for start in range(0, usable, 7):
            offsets = [t - AUDIO_TOKENS_START for t in frames[start : start + 7]]
            coarse.append(offsets[0])
            medium.append(offsets[1] - 4096)
            fine.append(offsets[2] - 2 * 4096)
            fine.append(offsets[3] - 3 * 4096)
            medium.append(offsets[4] - 4 * 4096)
            fine.append(offsets[5] - 5 * 4096)
            fine.append(offsets[6] - 6 * 4096)

        layers = [
            torch.tensor([coarse], dtype=torch.long).to(self.device),
            torch.tensor([medium], dtype=torch.long).to(self.device),
            torch.tensor([fine], dtype=torch.long).to(self.device),
        ]
        with torch.no_grad():
            waveform = self.snac_model.decode(layers)
        return waveform.squeeze().cpu().numpy()

    def _generate_audio_tokens(self, text: str) -> list:
        """Run the language model and return the generated audio token IDs."""
        import torch

        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        token_ids.append(END_OF_TEXT)
        # Wrap the text in the human/AI turn markers the model was trained on.
        prompt = [START_OF_HUMAN] + token_ids + [
            END_OF_HUMAN,
            START_OF_AI,
            START_OF_SPEECH,
        ]
        model_input = torch.tensor([prompt]).to(self.device)

        # Budget enough tokens for the whole utterance, with a generous floor.
        budget = max(MIN_MAX_NEW_TOKENS, len(text) * ESTIMATED_TOKENS_PER_CHAR)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=model_input,
                max_new_tokens=budget,
                use_cache=True,
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
                repetition_penalty=1.1,
                eos_token_id=END_OF_SPEECH,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Keep only newly generated tokens that fall in the audio-token range.
        continuation = outputs[0][len(prompt):].tolist()
        return [t for t in continuation if t >= AUDIO_TOKENS_START]

    def _denoise(self, audio: np.ndarray) -> np.ndarray:
        """Run optional noise reduction; pass audio through when unavailable."""
        try:
            import noisereduce as nr

            cleaned = nr.reduce_noise(y=audio, sr=SNAC_SR, prop_decrease=0.8)
        except ImportError:
            # noisereduce is an optional dependency; fall back to raw audio.
            return audio.astype(np.float32)
        return cleaned.astype(np.float32)

    def __call__(
        self,
        text: str,
        return_type: str = "file",
        filename: str = None,
        **kwargs,
    ):
        """
        Synthesize Thai speech from text using Archa TTS.

        :param str text: Input Thai text to synthesize
        :param str return_type: Return type ("file" or "waveform"). Default is "file".
        :param str filename: Output filename for the generated audio (WAV). Used when
            return_type is "file". A temporary file is created if None.
        :return: File path if return_type is "file", otherwise numpy waveform array.
        """
        try:
            import soundfile as sf
        except ImportError:
            raise ImportError(
                "soundfile is not installed. Please install it with: pip install soundfile"
            )

        waveform = self._decode_tokens(self._generate_audio_tokens(text))

        if len(waveform) == 0:
            raise RuntimeError("Archa TTS failed to generate audio for the given text.")

        waveform = self._denoise(waveform)

        if return_type == "waveform":
            return waveform

        # File output: create a temporary WAV path when none was supplied.
        if filename is None:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                filename = tmp.name

        sf.write(filename, waveform, SNAC_SR)
        return filename
124 changes: 124 additions & 0 deletions tests/test_archa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
"""
Unit tests for ArchaTTS integration
"""
import unittest
from unittest.mock import Mock, patch, MagicMock
import numpy as np
from pythaitts import TTS


class TestArchaIntegration(unittest.TestCase):
    """Test ArchaTTS integration with the top-level TTS wrapper.

    Fix: several tests previously captured ``result`` from ``tts.tts(...)``
    but never asserted on it, so the wrapper's return-value plumbing was
    untested. Each test now verifies the returned value as well.
    """

    @patch('pythaitts.pretrained.archa_tts.ArchaTTS')
    def test_archa_model_initialization(self, mock_archa_class):
        """TTS(pretrained="archa") constructs an ArchaTTS model."""
        tts = TTS(pretrained="archa")
        self.assertIsNotNone(tts.model)
        self.assertEqual(tts.pretrained, "archa")
        # The patched constructor must actually have been used.
        mock_archa_class.assert_called_once()

    @patch('pythaitts.pretrained.archa_tts.ArchaTTS')
    def test_archa_tts_call(self, mock_archa_class):
        """tts() forwards text/filename/return_type and returns the model output."""
        mock_instance = Mock()
        mock_instance.return_value = "/tmp/output.wav"
        mock_archa_class.return_value = mock_instance

        tts = TTS(pretrained="archa")
        result = tts.tts("สวัสดีครับ", filename="/tmp/test.wav")

        # The wrapper must return whatever the underlying model produced.
        self.assertEqual(result, "/tmp/output.wav")
        mock_instance.assert_called_once()
        call_args = mock_instance.call_args
        self.assertEqual(call_args.kwargs['text'], "สวัสดีครับ")
        self.assertEqual(call_args.kwargs['filename'], "/tmp/test.wav")
        self.assertEqual(call_args.kwargs['return_type'], "file")

    @patch('pythaitts.pretrained.archa_tts.ArchaTTS')
    def test_archa_with_preprocessing(self, mock_archa_class):
        """Text preprocessing is applied before the archa model is called."""
        mock_instance = Mock()
        mock_instance.return_value = "/tmp/output.wav"
        mock_archa_class.return_value = mock_instance

        tts = TTS(pretrained="archa")
        result = tts.tts("มี 5 คนๆ", preprocess=True)

        self.assertEqual(result, "/tmp/output.wav")
        mock_instance.assert_called_once()
        call_args = mock_instance.call_args
        processed_text = call_args.kwargs['text']

        # Text should have numbers converted and ๆ expanded
        self.assertNotIn("5", processed_text)
        self.assertNotIn("ๆ", processed_text)
        self.assertIn("ห้า", processed_text)
        self.assertIn("คนคน", processed_text)

    @patch('pythaitts.pretrained.archa_tts.ArchaTTS')
    def test_archa_waveform_return(self, mock_archa_class):
        """return_type="waveform" is forwarded and the waveform is returned."""
        mock_instance = Mock()
        mock_waveform = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
        mock_instance.return_value = mock_waveform
        mock_archa_class.return_value = mock_instance

        tts = TTS(pretrained="archa")
        result = tts.tts("สวัสดี", return_type="waveform")

        # The exact waveform object from the model must be passed through.
        self.assertIs(result, mock_waveform)
        mock_instance.assert_called_once()
        call_args = mock_instance.call_args
        self.assertEqual(call_args.kwargs['return_type'], "waveform")

    @patch('pythaitts.pretrained.archa_tts.ArchaTTS')
    def test_archa_no_filename_returns_temp_file(self, mock_archa_class):
        """filename=None is forwarded and the model's temp path is returned."""
        mock_instance = Mock()
        mock_instance.return_value = "/tmp/tmpXXXXXX.wav"
        mock_archa_class.return_value = mock_instance

        tts = TTS(pretrained="archa")
        result = tts.tts("สวัสดี")

        self.assertEqual(result, "/tmp/tmpXXXXXX.wav")
        mock_instance.assert_called_once()
        call_args = mock_instance.call_args
        self.assertIsNone(call_args.kwargs['filename'])


class TestArchaTTSUnit(unittest.TestCase):
    """Unit tests for ArchaTTS class methods"""

    def _make_archa(self):
        """Build an ArchaTTS without running __init__ (avoids real torch/snac model loads)."""
        from pythaitts.pretrained.archa_tts import ArchaTTS

        instance = ArchaTTS.__new__(ArchaTTS)
        instance.device = "cpu"
        instance.tokenizer = MagicMock()
        instance.model = MagicMock()
        instance.snac_model = MagicMock()
        return instance

    def test_decode_tokens_empty(self):
        """_decode_tokens yields an empty array when given fewer than 7 tokens."""
        from pythaitts.pretrained.archa_tts import ArchaTTS, AUDIO_TOKENS_START

        tts = self._make_archa()
        tokens = [AUDIO_TOKENS_START] * 3  # below one full 7-token frame
        decoded = tts._decode_tokens(tokens)
        self.assertEqual(len(decoded), 0)

    def test_denoise_fallback_without_noisereduce(self):
        """_denoise passes audio through unchanged when noisereduce is missing."""
        from pythaitts.pretrained.archa_tts import ArchaTTS

        tts = self._make_archa()
        samples = np.array([0.1, 0.2, 0.3], dtype=np.float32)

        # A None entry in sys.modules makes `import noisereduce` raise ImportError.
        with patch.dict('sys.modules', {'noisereduce': None}):
            denoised = tts._denoise(samples)

        np.testing.assert_array_almost_equal(denoised, samples)


# Allow running this test module directly: `python tests/test_archa.py`.
if __name__ == '__main__':
    unittest.main()