Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

添加了流式解码器,做到更好的控制台体验和更高的解码效率 #817

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Align output with ChatGLM model
Add unit test for stream decoder
  • Loading branch information
lwh9346 committed Apr 26, 2023
commit 430224bf13ba1329e41505b8ba98b49e2e6088d5
4 changes: 2 additions & 2 deletions stream_cli_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from transformers import AutoTokenizer, AutoModel
import signal
import platform
from stream_utils import SPStreamDecoder
from stream_utils import ChatGLMStreamDecoder


tokenizer = AutoTokenizer.from_pretrained(
"THUDM/chatglm-6b", trust_remote_code=True)
stream_decoder = SPStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
stream_decoder = ChatGLMStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
model = AutoModel.from_pretrained(
"THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
Expand Down
96 changes: 94 additions & 2 deletions stream_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sentencepiece as spm
from typing import Tuple
import re
import unittest

# python implantation of https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc

Expand All @@ -24,6 +26,7 @@ def DecodeSentencePiece(piece: str, id: int, is_bos_ws: bool, sp: spm.SentencePi
if is_bos_ws and (add_dummy_prefix or remove_extra_whitespaces):
t = piece.removeprefix(SPStreamDecoder.SpaceSymbol)
has_bos_ws = t != piece
piece = t
# if we are removing extra whitespace, we remove all leading whitespace
if remove_extra_whitespaces:
has_bos_ws = False
Expand Down Expand Up @@ -67,17 +70,21 @@ def __init__(self, sp: spm.SentencePieceProcessor, remove_extra_whitespaces=Fals
self._nothing_decoded = True
self._ids = []
self._decoded = ""
self._ending = False
self.remove_extra_whitespaces = remove_extra_whitespaces
self.add_dummy_prefix = add_dummy_prefix

def put(self, ids: list[int]) -> None:
self._ending = False
self._ids += ids
self._decode(eos=False)

def end(self) -> None:
self._decode(eos=True)
self._is_bos_ws = True
self._bos_ws_seen = False
self._nothing_decoded = True
self._ending = True
self._ids = []

def _decode(self, eos=False) -> None:
Expand All @@ -88,7 +95,7 @@ def _decode(self, eos=False) -> None:
if not self._sp.IsByte(self._ids[i]):
self._decoded += ProcessBytePieces(byte_pieces)
consumed += len(byte_pieces)
if consumed > 0:
if len(self._decoded) > 0:
self._nothing_decoded = False
byte_pieces = []
# if we have seen a bos_ws token or any non-empty token
Expand All @@ -98,7 +105,7 @@ def _decode(self, eos=False) -> None:
piece, self._ids[i], self._is_bos_ws, self._sp)
self._decoded += decoded
consumed += 1
if consumed > 0:
if len(self._decoded) > 0:
self._nothing_decoded = False
else:
byte_pieces.append(piece)
Expand All @@ -111,3 +118,88 @@ def get(self) -> str:
t = self._decoded
self._decoded = ""
return t


class ChatGLMStreamDecoder(SPStreamDecoder):

def get(self) -> str:
# if prefix of special tokens found, wait till it's impossible or end of decode
if "[" in self._decoded and len(self._decoded)-self._decoded.index("[") < 8 and not self._ending:
return ""
if "<" in self._decoded and len(self._decoded)-self._decoded.index("<") < 12 and not self._ending:
return ""
self._ending = False
t = self._decoded
self._decoded = ""
t = t.replace("<n>", "\n")
t = t.replace("[[训练时间]]", "2023年")
punkts = [
[",", ","],
["!", "!"],
[":", ":"],
[";", ";"],
["\?", "?"],
]
for item in punkts:
t = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], t)
t = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], t)
# for i in range(max_len, 1, -1):
# t = t.replace(f"<|blank_{i}|>", " " * i)
for blank_token in re.findall(r"<\|blank_\d+\|>", t):
t = t.replace(blank_token, " " *
int(re.search(r"\d+", blank_token)[0]))
return t


class ChatGLMStreamDecoderTest(unittest.TestCase):
def test_ChatGLM_StreamDecoder(self):
from transformers import AutoTokenizer, AutoModel
test_strings = [
"你好👋", # multi-byte encoding
"Hello this is ChatGLM!", # normal text
"你好👋 This is ChatGLM!", # multi-byte encoding with tail
"!?.,!?。,", # punctuations
"A\nB", # "<n>" -> "\n"
"[[训练时间]]", # training time token
"[[训练时间]123", # broken training time token
"1 1", # blank token. Note: It's hard to match the results of strip(), so add leading and tailing "1"
"<|blank_8|123", # broken blank token
]
tokenizer = AutoTokenizer.from_pretrained(
"THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(
"THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
encoded_ids = [tokenizer(x)['input_ids'] for x in test_strings]
stream_decoder = ChatGLMStreamDecoder(
tokenizer.sp_tokenizer.text_tokenizer.sp)
# original output
expected_outputs = [model.process_response(
tokenizer.decode(x)) for x in encoded_ids]
# decode token by token
decoded_strings_stream_token_by_token = [None for _ in test_strings]
for i in range(len(test_strings)):
res = []
for t in encoded_ids[i]:
stream_decoder.put([t])
res.append(stream_decoder.get())
stream_decoder.end()
res.append(stream_decoder.get())
res = "".join(res)
decoded_strings_stream_token_by_token[i] = res
# decode all at once
decoded_strings_stream = [None for _ in test_strings]
for i in range(len(test_strings)):
stream_decoder.put(encoded_ids[i])
stream_decoder.end()
decoded_strings_stream[i] = stream_decoder.get()
for i in range(len(test_strings)):
print(
f"Stream decoder test{i}: expected: '{expected_outputs[i]}', token_by_token: '{decoded_strings_stream_token_by_token[i]}', all at once: '{decoded_strings_stream[i]}'")
self.assertEqual(
expected_outputs[i], decoded_strings_stream_token_by_token[i])
self.assertEqual(expected_outputs[i], decoded_strings_stream[i])


if __name__ == "__main__":
unittest.main()