Skip to content

Commit 2c2350c

Browse files
committed
initial commit for star+commit0
1 parent df0dc34 commit 2c2350c

File tree

5 files changed

+1759
-0
lines changed

5 files changed

+1759
-0
lines changed

examples/star/inference.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import List
2+
from datasets import Dataset
3+
from vllm import LLM, SamplingParams
4+
5+
def generate_predictions(
    model_name: str,
    dataset: Dataset,
    temperature: float = 1.0,
    n: int = 1,
    max_tokens: int = 512,
) -> List[List[str]]:
    """Sample completions from a vLLM model for every example in a dataset.

    Each example is turned into one prompt: its ``text`` field, followed by
    the sentence "Your code should satisfy these tests:" and the entries of
    its ``test_list`` field joined one per line. All prompts are submitted to
    vLLM in a single ``generate`` call. The dataset is only read, never
    modified.

    Args:
        model_name: Model identifier passed to ``vllm.LLM``.
        dataset: Iterable of examples; each example must provide a ``text``
            string and a ``test_list`` sequence of strings.
        temperature: Sampling temperature. Default is 1.0.
        n: Number of completions requested per prompt. Default is 1.
        max_tokens: Maximum number of tokens generated per completion.
            Default is 512.

    Returns:
        One inner list per request output returned by vLLM, holding the
        generated text of each completion in that output.
    """
    sampling_params = SamplingParams(
        n=n, temperature=temperature, max_tokens=max_tokens
    )
    llm = LLM(model=model_name)

    prompts: List[str] = []
    for example in dataset:
        # Join outside the f-string: a backslash inside an f-string
        # replacement field is a SyntaxError before Python 3.12.
        tests = "\n".join(example["test_list"])
        prompts.append(
            f"{example['text']} Your code should satisfy these tests:\n\n{tests}"
        )

    outputs = llm.generate(prompts, sampling_params)
    return [[completion.text for completion in output.outputs] for output in outputs]
45+
#out_name = dataset_name.split("/")[-1]
46+
#out_name = f"wentingzhao/{out_name}_predictions_{n}"
47+
#ds.push_to_hub(out_name)
48+

examples/star/star.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Main STaR Loop"""
2+
import argparse
3+
from datasets import Dataset, load_dataset
4+
from inference import generate_predictions
5+
6+
def main():
    """Parse command-line arguments and sample predictions on the train split.

    Command-line arguments:
        --model_name: Model to use (required).
        --dataset_name: Dataset to load with ``load_dataset`` (required).
        --temperature: Sampling temperature. Default is 1.
        -n: Number of samples per prompt. Default is 1.

    Raises:
        ValueError: If the loaded dataset has no "train" split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True, help="model to use")
    parser.add_argument("--dataset_name", type=str, required=True, help="dataset to use")
    parser.add_argument("--temperature", type=float, default=1)
    parser.add_argument("-n", type=int, default=1)
    args = parser.parse_args()

    ds = load_dataset(args.dataset_name)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "train" not in ds:
        raise ValueError(
            f"dataset {args.dataset_name!r} has no 'train' split"
        )

    # NOTE: the returned predictions are not consumed by anything in this
    # function yet; they are currently discarded.
    generate_predictions(args.model_name, ds["train"], args.temperature, args.n)
17+
18+
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)