#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for Kokoro TTS
"""
5
5
from concurrent import futures
6
-
6
+ import time
7
7
import argparse
8
8
import signal
9
9
import sys
10
10
import os
11
- import time
12
11
import backend_pb2
13
12
import backend_pb2_grpc
13
+
14
+ import torch
15
+ from kokoro import KPipeline
14
16
import soundfile as sf
17
+
15
18
import grpc
16
19
17
- from models import build_model
18
- from kokoro import generate
19
- import torch
20
20
21
- SAMPLE_RATE = 22050
22
21
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
23
22
24
23
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
25
24
MAX_WORKERS = int (os .environ .get ('PYTHON_GRPC_MAX_WORKERS' , '1' ))
25
+ KOKORO_LANG_CODE = os .environ .get ('KOKORO_LANG_CODE' , 'a' )
26
26
27
27
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service.

    It wraps the Kokoro text-to-speech pipeline and exposes the LocalAI
    backend methods Health, LoadModel and TTS.
    """

    def Health(self, request, context):
        """Report liveness of the backend.

        Args:
            request: Health request message (unused).
            context: grpc.ServicerContext for this RPC.

        Returns:
            backend_pb2.Reply carrying the literal payload b"OK".
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """Initialize the Kokoro TTS pipeline.

        Args:
            request: LoadModelRequest. request.Options is a list of
                "name:value" strings; request.CUDA is set when the caller
                requires GPU execution.
            context: grpc.ServicerContext for this RPC.

        Returns:
            backend_pb2.Result with success=True once the pipeline is ready,
            or success=False with a message on failure.
        """
        # Log which device is available. This is informational only: the
        # KPipeline below is not handed a device here, so it selects its own.
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            print("CUDA is available", file=sys.stderr)
        else:
            print("CUDA is not available", file=sys.stderr)

        # Fail fast when the caller explicitly asked for CUDA but none exists.
        if not cuda_available and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        try:
            print("Preparing Kokoro TTS pipeline, please wait", file=sys.stderr)
            # The options are a list of strings in this form optname:optvalue.
            # We store them all in a dict so they can be used later.
            self.options = {}
            options = request.Options
            for opt in options:
                if ":" not in opt:
                    continue
                # Split on the FIRST colon only, so values that themselves
                # contain ":" (paths, URLs, ...) are preserved intact.
                key, value = opt.split(":", 1)
                self.options[key] = value

            # Initialize Kokoro pipeline with the requested language code,
            # falling back to the environment/module default.
            lang_code = self.options.get("lang_code", KOKORO_LANG_CODE)
            self.pipeline = KPipeline(lang_code=lang_code)
            print(f"Kokoro TTS pipeline loaded with language code: {lang_code}", file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Kokoro TTS pipeline loaded successfully", success=True)

    def TTS(self, request, context):
        """Synthesize speech for request.text and write it to request.dst.

        Args:
            request: TTS request with .text to synthesize, an optional .voice
                name, and .dst, the output file path.
            context: grpc.ServicerContext for this RPC.

        Returns:
            backend_pb2.Result indicating success, or success=False with the
            error message on failure.
        """
        try:
            # Get voice from request, default to 'af_heart' if not specified.
            voice = request.voice if request.voice else 'af_heart'

            # Generate audio using the Kokoro pipeline; it yields
            # (graphemes, phonemes, audio) tuples, one per segment.
            generator = self.pipeline(request.text, voice=voice)

            for i, (gs, ps, audio) in enumerate(generator):
                # Save audio to the destination file at 24 kHz.
                sf.write(request.dst, audio, 24000)
                print(f"Generated audio segment {i}: gs={gs}, ps={ps}", file=sys.stderr)
                # For now, we only process the first segment.
                # If you need to handle multiple segments, you might want to
                # modify this.
                break

        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(success=True)
100
91
101
92
def serve (address ):
@@ -108,11 +99,11 @@ def serve(address):
108
99
backend_pb2_grpc .add_BackendServicer_to_server (BackendServicer (), server )
109
100
server .add_insecure_port (address )
110
101
server .start ()
111
- print ("[Kokoro] Server started. Listening on: " + address , file = sys .stderr )
102
+ print ("Server started. Listening on: " + address , file = sys .stderr )
112
103
113
104
    # Define the signal handler function
    def signal_handler(sig, frame):
        """Stop the gRPC server and exit the process on SIGINT/SIGTERM.

        Args:
            sig: Signal number delivered by the OS.
            frame: Current stack frame at delivery time (unused).
        """
        print("Received termination signal. Shutting down...")
        # grace=0: shut down without waiting for in-flight RPCs to finish.
        server.stop(0)
        sys.exit(0)
118
109
@@ -132,5 +123,5 @@ def signal_handler(sig, frame):
132
123
"--addr" , default = "localhost:50051" , help = "The address to bind the server to."
133
124
)
134
125
args = parser .parse_args ()
135
- print ( f"[Kokoro] startup: { args } " , file = sys . stderr )
126
+
136
127
serve (args .addr )