|
| 1 | +from typing import List |
| 2 | +from datasets import Dataset |
| 3 | +from vllm import LLM, SamplingParams |
| 4 | + |
def generate_predictions(
    model_name: str,
    dataset: Dataset,
    temperature: float = 1.0,
    n: int = 1
) -> List[List[str]]:
    """
    Generate predictions for a given dataset using a specified language model.

    Builds one prompt per example (the example's text followed by its test
    cases) and samples ``n`` completions per prompt with vLLM.

    Args:
        model_name (str): Name of the model to use for generation.
        dataset (Dataset): The Dataset object; each example must provide
            'text' and 'test_list' fields.
        temperature (float, optional): Temperature setting for the model's
            sampling strategy. Default is 1.0.
        n (int, optional): Number of sampling runs per prompt. Default is 1.

    Returns:
        List[List[str]]: One inner list per example, containing the ``n``
        generated texts for that example, in dataset order.
    """
    sampling_params = SamplingParams(n=n, temperature=temperature, max_tokens=512)
    llm = LLM(model=model_name)

    prompts: List[str] = []
    for example in dataset:
        # Join the tests *outside* the f-string: a backslash inside an
        # f-string expression is a SyntaxError before Python 3.12 (PEP 701).
        tests = "\n".join(example["test_list"])
        prompts.append(
            f"{example['text']} Your code should satisfy these tests:\n\n{tests}"
        )

    outputs = llm.generate(prompts, sampling_params)

    # vLLM returns one RequestOutput per prompt; each holds n completions.
    return [[one.text for one in output.outputs] for output in outputs]
| 48 | + |
0 commit comments