LLMのコード能力を向上させる取り組み
はじめに
東京科学大学(旧 東京工業大学)の藤井です。
この記事では、Llama-3.1-Swallow-8B-v0.2の学習に利用されたコードデータセットであるswallow-code-v0.1の開発について主に解説を行います。
また、近年(2024)におけるLLMのコーディング能力向上に関する国内外の主要な取り組みについて概説します。
加えて、学習に利用したswallow-code-v0.1をhuggingface上でfiltering元であるbigcode/the-stack-v2-train-smol-idsと同様のライセンスで配布します。LLMのコーディング能力向上のためのデータやLLMの分析等にご利用ください。このデータセットは、後述するようにstack v2に4段階のfilteringを適用し、品質を向上させたものになります。
Llama-3.1-Swallow-8B-v0.2
Llama-3.1-Swallow-8B-v0.2は、Llama-3.1-8Bから約 250B Token継続事前学習を施した日本語と英語に強いLLMです。商用利用可能なライセンスで配布されています。
(Llama-3.1-Swallow-8B-v0.2 のイメージ画像)
前バージョンのLlama-3.1-Swallow-8B-v0.1からの変更点は、以下のとおりです。
- 学習トークン数の増加 (230B -> 250B Token)
- 日本語:英語 割合の変更 (=日本語割合の増加)
- 日本語コーパス(Swallow Corpus v2)のフィルタリング方法の改善
- コードコーパス(The Stack v2)の品質の改善
以上の変更点により、日本語のタスクとして特にNIILC、JHumanEvalの性能が改善しています。
NIILCは、日本語の質問応答タスクのうちの1つであり、JHumanEvalは日本語のコーディング能力を測るためのベンチマークの一つです。特にJHumanEvalが向上していることはLlama-3.1-Swalllow-8B-v0.1よりもv0.2の方が高いコーディング能力を有していることを示しています。
モデル名 | NIILC | JHumanEval |
---|---|---|
Llama-3.1-Swallow-8B-v0.1 | 0.6011 | 0.2811 |
Llama-3.1-Swallow-8B-v0.2 | 0.6272 | 0.3360 |
ベースモデルの網羅的なベンチマークスコアについては、以下のモデルカードをご覧ください。
Swallow Projectで評価した多数のモデルの評価結果については以下も参照ください。
関連研究
以下にLLMのコード能力を向上させる取り組みとして今回Swallow Project内で行った実験の関連研究(Related Works)を示します。関連研究のすべての知見を吸収できているわけではないため、我々が行ったフィルタリングに追加して工夫を適用する際の補助や、実験背景を理解する際に参照ください。
Llama-3
最初は、学習元のベースモデルでもあるLlama-3の論文です。コード能力を上昇させるために、コードデータにどの様な工夫をMeta社が行っているかを説明します。
Data Curation
3.1.1 Web Data Curationセクション(arxiv)にあるように、コードや数学に関連するweb pageを抽出するために専用のパイプラインをMetaは開発しています。コード、reasoning両方とも、分類器(classifier)には、Llama 2でannoatationされたweb dataで学習されたDistilRobertaを利用しています。加えて、通常の自然言語とcodeやmathのtoken distributionは大きく異なるため、pipelineにはドメイン固有のHTML抽出器や、通常とは異なるtext feature filtering、heuritstics fitleringが使用されていると報告されています。
このように、CommonCrawlから数学、コードテキストを収集するには、HTML抽出機から独自のものを利用する必要があるだけでなく、多数の工夫が必要であり、簡単に良質なコードコーパスを手に入れることは難しいことが分かります。また逆に、工数をかけてData CurationすればLlama-3並の性能を手に入れられるということも示されています。
Code Quality Control
Llama-3では、post trainingデータとして合成データ(synthetic data)を利用しており、中でもprogrammingに関する問題(programming problem disscriptions generation)と、その答え(soultion generation)をLLMを使用して作成するという作業を行っています。
この作業において、どのようにして答えが正しいのかを保証するのかという問題がありますが、Llama-3では、次のように対応しています。ただし、ここで注意したいのは、"完全な"正確性というものは保証しておらず、おそらく確からしいというという程度までデータセットをフィルタリングしているということです。
- static analysis: paraserとlinterを利用し、LLMにより生成されたコードに文法エラー(syntax error)がないか、未初期化の変数がないか、importされていない関数を利用していないかなどを確認しています。
- Unit test generation and execution: 各問題と答えについて、モデルにunit testを書くようにプロンプトを与え、コンテナ環境でテストを実行し、Runtime Errorが発生しないか確認しています。
このようなQuality Controlは、1番のstatic analysisだけであれば、CPU環境で完結するので、比較的安いコストでpre-trainingデータについても適用することが可能です。
そのため、今回のfilteringにおいては、static analysisを採用しました。
Qwen-2.5 Coder
データの種類
Qwen-2.5-Coderの学習データセットであるQwen-2.5-Coder-Dataは、ソースコードデータ(Source Code Data)、テキストとコードの組のデータ(Text-Code Grounding Data)、Synthetic Data(合成データ)、数学データ(Math Data)、テキストデータ(Text Data)の5種類からなります。以下では、コード能力の向上に特に関連するSoruce Code、Text-Code Grounding data、Math Dataについて紹介します。
Source Code
Qwenチームは、StarCoder2やDS-Coderと同様に、GitHubから収集可能なコードデータをrule-basedな方法によりフィルタリングしています。また、ただのコードデータに加えて、Qwenチームは、Pull Requestやgit commit、Jupyternote Books、Kaggle datasetsなどについても収集し、同様にrule basedなフィルタリングによりクリーニングを行っています。
Text-Code Grounding Data
CommonCrawlより、コードに関連するドキュメント、チュートリアル、ブログなどを収集し、text-code mixedなデータを構築しています。
上図が示すように、Qwen-2.5-Coder-1.5Bでの実験では、フィルタリングの段階を経るごとにデータサイズは減少しますが、それとは逆にモデルの性能は上昇しています。これは、この規模では、高品質なText-Code Grounding Dataの方がたとえ量が少なくともコーディング能力向上には有効であることを示しています。
Math Data
Qwen-2.5-Coderでは、モデルの数学的能力(mathmatical capabilities)を向上させるために、Qwen-2.5-Mathで利用したデータを学習データに統合しています。
その結果、数学データを含めることはコードタスクにおけるモデルの性能になんら悪影響を与えなかったことが判明しています。これは非常に重要な知見であり、Qwen並にデータを高品質化できた場合に限られる可能性はありますが、LLMの数学能力とコード能力はベンチマークタスク上は排反なものではなく、学習データ量さえ増やせば、両方の指標の上昇を目指すことが可能であるということです。
学習粒度
Qwen-2.5-CoderのPre-Training(継続事前学習)フェーズは上図のように2つに分けられます。
1つ目は、File-Level Pretraining、2つ目は、Repo-Level Pretrainingです。
File-Level Pretrainingは、最もよくある方法で、個別のコードファイルに着目して学習を行います。Qwen-2.5-Coderにおけるこのフェーズでは、context sizeは8,192としており、long contextでの学習は行っていません、学習データサイズは5.2T Tokenであり、next token predictionとfill-in-the-middleにより学習されています。
File-Level Pretrainingの後で、モデルのlong-context能力を向上させる目的で、repo-level pretraining(GitHub等のリポジトリ単位での学習)を行っています。
この段階では、context sizeはFile-Levelのときの8.192から32,768に4倍に増加させており、Long context trainingでよく見られるRoPEのbase frequencyを10,000から1,000,0000に増加させる手法が取られています。
また、以下のような形でFile-LevelでのFill-In-the-Middleではなく、リポジトリレベルでのFill-In-the-Middleを行っています。
Swallow-Code-v0.1 の製法
以上のような先行研究をふまえ、我々は以下のとおりThe stack v2からフィルタリングを行い、高品質なコードデータを得ました。学習方法にはnext token predictionを採用し、Fill-In-the-Middleは採用しませんでした。
フィルタリングは第1段階から第4段階まで存在し、第1 -> 第2 -> 第3とより厳しいfilteringになるように設計されています。なお、現在、別の効果的なフィルタリング方法について検討中です。
今回は、実験の簡素化のために対象とするプログラミング言語を絞り、bigcode/the-stack-v2-train-smol-idsのlanguageフィールドの値がPython
であるものに限定しています。
しかし、stack v2は多様なプログラミング言語を含んでおり、stack v2で学習したモデルの下流タスク性能とPythonに限定したfiltering済みデータセットで学習したモデルの下流タスク性能を比較するのは、filteringの効果を測る上で適切な方法ではありません。
そこで、本研究では、ベースラインとして stack v2からPythonのみを抽出したデータであるpython_filtered_v1
を作成し、このデータで学習した場合と比較することで良いコードデータのフィルタリング方法を探っています。
以下にフィルタリングの段階とどのようなフィルタリングを施したのかの対応を示します。
- 第1段階: bigcode/the-stack-v2-train-smol-ids からlanguageフィールドがPythonであるものを抽出したデータ。(Llama-3 tokenizerで約36B Token)
- 第2段階: Pythonのcompile関数を利用し、コードにSyntaxErrorが存在しないか検証し、問題があるデータは排除したデータ。(約31B Token)
- 第3段階: pylintを利用し、Pythonコードを点数付けし、一定の点数未満のデータは低品質であるとみなし、データを削除したデータ。(約20B Token)
- 第4段階: Pythonコード中のコメントとリテラルを取得し、英語と日本語のものだけに限定するフィルタリングを実施したデータ。(約16B Token)
Stage 2: Python compile関数によるフィルタリング
def check_syntax_error(code: str) -> list[str]:
"""
Args:
code (str): python code
Returns:
list[str]: py compile errors
"""
issues = []
try:
compile(code, "<string>", "exec")
except SyntaxError as e:
issues.append(f"Syntax error: {str(e)}")
return issues
Python compile関数によるSyntaxエラーの検知は上記のように実装しています。
コンパイル言語においては、コンパイルエラーを検知することや、Syntax Errorを静的解析するツールなどを利用して同様の処理を行うことが可能です。
このフィルタリングは、LLMがSyntax Errorを含むコードを学習し、実行不可能なコードを出力することで下流タスク性能が下がってしまうことを防止する目的で行われています。
Stage 3: Pylintによるフィルタリング
以下のように、一時的にPythonファイルを作成し、そのファイルをpylintによりlintingすることでフィルタリングを行っています。pylintは、`Your code has been rated at"のようにコードの品質を点数で出力する機能を有しています。これを利用して、コードの品質によるフィルタリングを行っています。
with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
temp_file.write(code.encode())
temp_file.flush()
result = subprocess.run(
["pylint", "--persistent=n","--disable=E0401,C0114,C0301,C0103,C0116,C0411,R0903,W0511,C0412", temp_file.name],
capture_output=True,
text=True,
)
いくつかの規則を無視するように--disable=E0401,C0114,C0301,C0103,C0116,C0411,R0903,W0511,C0412
としていますが、これには理由があります。
まず、stack v2に存在するコードにはpip install
により特定のpackageをinstallする必要があるライブラリに依存したコードや、企業や組織内のみで利用可能なライブラリに依存したPythonコードが多数含まれています。理想的には、import Error
が発生した時点で該当のライブラリをinstallして再度lintingを行うことが理想的ですが、大量のコードデータをfilteringする際には、そのような方法はfiltering時間の観点から望ましくありません。また、機械的にpip install
することによる悪影響も考えられるため、慎重にならざるを得ません。そこで、import-error/ E0401を無視することで、この問題に対処しています。同様に、missing-module-docstring C0114など、本来の規則からすると望ましくないが、実際のコードでは多用されているものに関しては、有用なコードを低品質と判定するおそれがあるため無視することにしました。
Stage 4: コード中のコメントを言語判定しフィルタリング
Pythonのtokenizeを利用することで、Pythonコードからtokenize.STRING
とtokenize.COMMENT
を抽出します。
この抽出された文字が英語または、日本語であるか判定を行い、他の言語であると判定されたコード文書を破棄するフィルタリングを行いました。
このフィルタリング意図は、Swallowモデルは、日本語と英語をターゲットにしたLLMであり、他の言語を用いて記述されたコメントが存在するコードデータをLLMが学習することで下流タスク性能に悪影響を及ぼすのではないかと仮説を立てたためです。
実験概要
上述のフィルタリングの効果を観測するために、Llama-3.1-8Bから継続事前学習を行いました。
学習データには、日本語、英語、コードが含まれており、実験ごとにコードデータのみを変更しました。
また、主要なハイパーパラメータは以下の通りに設定しました。
ハイパラ | 値 |
---|---|
global batch size | 512 |
sequence length | 8,192 |
LR | 2.5E-5 |
WD | 0.1 |
GC | 1.0 |
Adam( |
(0.9, 0.95, 1e-8) |
実験結果
Llama-3.1-8Bから50B Token日本語、英語、コードが混ざったデータを学習させ、10B Token学習するごとに下流タスクの性能を評価する形で実験を行いました。以下は、50B Token学習した際のコーディングに関係する下流タスクの評価結果です。
モデル名 | JHumanEval | HumanEval | MBPP-Ja(pass@1) | MBPP(pass@1) |
---|---|---|---|---|
実験1(stage 1) | 0.3061 | 0.3689 | 0.4012 | 0.4402 |
実験2(stage 2) | 0.3030 | 0.3787 | 0.4182 | 0.4370 |
実験3(stage 3) | 0.3195 | 0.3945 | 0.4110 | 0.4322 |
実験4(stage 4) | 0.3195 | 0.3841 | 0.4204 | 0.4228 |
実験結果が示すように実験3, 4においては、ベースラインである実験1と比較してJHumanEval、HumanEval、MBPP-Ja(pass@1)がすべて上昇しています。しかしながら、MBPP(pass@1)は低下しており、MBPP(pass@10)にすると低下幅は減少するものの依然として微小な低下傾向は存在します。
MBPPのみ異なる傾向を示すものの、JHumanEval、HumanEval、MBPP-Jaは同様にスコアの改善を示しており、filteringによるコードデータの高品質化がLLMのコーディング能力を上昇させているとみなせます。
JHumanEval、HumanEval、MBPP-Ja、MBPPの平均値と、各Stageにおけるデータセットサイズを図示すると以下のようになります。
また、日本語のコード性能のみを抽出しJHumanEval、MBPP-Jaの平均値とすると以下のようになります。
こちらでは、Stage 1 -> Stage 2 -> Stage 3 -> Stage 4とfilteringを多段階にすればするほどスコアが上昇する傾向が示されています。
今後
Swallow Project、そして横田研究室では、LLMのReasoning能力の向上のために、数学とコーディング能力を強化するための取り組みを引き続き行っていきます。
オフライン環境でLLMを利用したい場面や、情報の取り扱いに難がある場合など、ローカル環境で動作できるLLM(通称 ローカルLLM)の需要は依然として存在すると考えています。
単純な性能だけを求める場合は、OpenAIのGPT-4o, o1を利用することが得策であると私も考えますが、Swallowで開発したモデルが研究開発や産業利用など何かしらの役に立つことを願っています。
Discussion