Commit d925c583 by ZhangXiaoyun

Initial commit

parents
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve OpenR
labels: [ "bug" ]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report! 🤗 We kindly suggest asking your question in **English** to make them more accessible for everyone and easier to search for. We believe this further helps create a more inclusive and collaborative environment for everyone. 🤗
# Before you submit your bug report:
# - If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
# - Try our [docs bot](https://huggingface.co/spaces/huggingchat/hf-docs-chat) -- it might be able to help you with your issue
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Please share your system info with us. You can manually provide the following details:
- Operating System (e.g., Linux, macOS, Windows)
- Python version (e.g., 3.8)
- Package version (relevant versions of your software)
- Hardware (e.g., CPU, GPU - if applicable)
- Other relevant libraries (any important dependencies)
placeholder: Operating System, Python version, package version, etc.
validations:
required: true
- type: textarea
id: who-can-help
attributes:
label: Who can help?
description: |
Your issue will be replied to more quickly if you can figure out the **right person to tag with @**, or **attach labels**
If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
a core maintainer will ping the right person.
Please tag fewer than 3 people.
Codebase:
- Reasoning: @ziyuwan
- RL Training: @morning9393
- Reward Model: @Gebro13
- Data Pre-process: @mengfn, @gzqaq
Website and Docs: @YanSong97, @iamlilAJ
Models and Datasets: @mengfn, @YanSong97, @iamlilAJ
placeholder: "@Username ..."
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "The official example scripts"
- label: "My own modified scripts"
- type: checkboxes
id: information-tasks
attributes:
label: Tasks
description: "The tasks I am working on are:"
options:
- label: "An officially supported task in the codebase (such as scrips/, ...)"
- label: "My own task or dataset (give details below)"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
If you have code snippets, error messages, stack traces please provide them here as well.
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
placeholder: |
Steps to reproduce the behavior:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."
name: "\U0001F680 Feature request"
description: Submit a proposal/request for a new OpenR feature
labels: [ "enhancement" ]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request
description: |
A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation
description: |
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution
description: |
Is there any way that you could help, e.g. by submitting a PR?
# Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md)
logs_fastchat/
__pycache__/
results/
*.*~
*.pkl
**/*.egg-info
.python-version
.idea/
.vscode/
.DS_Store
_build/
# build files of doc website
# Ignore the default location of the built site, and caches and metadata generated by Jekyll
docs/_site/
docs/.sass-cache/
docs/.jekyll-cache/
docs/.jekyll-metadata
# Ignore folders generated by Bundler
docs/.bundle/
docs/vendor/
# Contribute to *OpenR*
> Every contribution, no matter how small, is valuable to the community.
Thank you for your interest in ***OpenR*** ! 🥰 We are deeply committed to the open-source community,
and we welcome contributions from everyone. Your efforts, whether big or small, help us grow and improve.
Contributions aren’t limited to code—answering questions, helping others, enhancing our
documentation, and sharing the project are equally impactful.
## Ways to contribute
There are several ways you can contribute to ***OpenR***:
- Fix existing issues in the codebase.
- Report bugs or suggest new features by submitting issues.
- Implement new models or features.
- Improve our examples or contribute to the documentation.
The best way to do that is to open a **Pull Request** and link it to the **issue** that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
## Create a Pull Request
> This guide was heavily inspired by [huggingface guide to contribution](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#create-a-pull-request)
You will need basic `git` proficiency to contribute to ***OpenR***. While `git` is not the easiest tool to use, it has the greatest
manual. Type `git --help` in a shell and enjoy! If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow the steps below to start contributing:
1. Fork the [repository](https://github.com/openreasoner/openr) by
clicking on the **[Fork](https://github.com/openreasoner/openr/fork)** button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote:
```bash
git clone git@github.com:<your Github handle>/openr.git
cd openr
git remote add upstream https://github.com/openreasoner/openr.git
```
3. Create a new branch to hold your development changes:
```bash
git checkout -b a-descriptive-name-for-my-changes
```
🚨 **Do not** work on the `main` branch!
4. Set up a development environment by following the [README](https://github.com/openreasoner/openr/blob/main/README.md) file.
5. Develop the features in your branch.
Once you're happy with your changes, add the changed files with `git add` and
record your changes locally with `git commit`:
```bash
git add modified_file.py
git commit
```
Please remember to write [good commit
messages](https://chris.beams.io/posts/git-commit/) to clearly communicate the changes you made!
To keep your copy of the code up to date with the original
repository, rebase your branch on `upstream/branch` *before* you open a pull request or if requested by a maintainer:
```bash
git fetch upstream
git rebase upstream/main
```
Push your changes to your branch:
```bash
git push -u origin a-descriptive-name-for-my-changes
```
If you've already opened a pull request, you'll need to force push with the `--force` flag. Otherwise, if the pull request hasn't been opened yet, you can just push your changes normally.
6. Now you can go to your fork of the repository on GitHub and click on **Pull Request** to open a pull request.
When you're ready, you can send your changes to the project maintainers for review.
7. It's ok if maintainers request changes, it happens to our core contributors
too! So everyone can see the changes in the pull request, work in your local
branch and push the changes to your fork. They will automatically appear in
the pull request.
MIT License
Copyright (c) 2024 OpenR Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
This diff is collapsed. Click to expand it.
## ***OpenR*** Benchmark Results
\ No newline at end of file
## ***OpenR*** Benchmark Results
Here we present some preliminary results in the following tables.
<table>
<thead>
<tr>
<th>Generator</th>
<th>Reward Model</th>
<th>Search</th>
<th>Budget</th>
<th>MATH</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="15" style="text-align: center; vertical-align: middle;">
<img src="https://avatars.githubusercontent.com/u/141221163?s=200&v=4" alt="Qwen2.5 Logo" style="width: 50px; height: 50px;" align="center"><br>
Qwen2.5-Math-1.5B-Instruct
</td>
<td rowspan="5">Math-Sphered-Mistral-7B-PRM</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td rowspan="5">Math-psa-7B</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td rowspan="5">Skywork-o1-PRM-7B</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
<td></td>
<tr>
<td rowspan="15" style="text-align: center; vertical-align: middle;">
<img src="https://avatars.githubusercontent.com/u/141221163?s=200&v=4" alt="Qwen2.5 Logo" style="width: 50px; height: 50px;" align="center"><br>
Qwen2.5-Math-7B-Instruct
</td>
<td rowspan="5">Math-Sphered-Mistral-7B-PRM</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td rowspan="5">Math-psa-7B</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td rowspan="5">Skywork-o1-PRM-7B</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
<td></td>
<tr>
<td rowspan="15" style="text-align: center; vertical-align: middle;">
<img src="https://cdn-avatars.huggingface.co/v1/production/uploads/652260e0728c0b6dc72dd957/lPyWi7-XFLhH3s1c2vnvD.png" alt="Skywork Logo" style="width: 50px; height: 50px;" align="center"><br>
Skywork-o1-Open-Llama-3.1-8B
</td>
<td rowspan="5">Math-Sphered-Mistral-7B-PRM</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td rowspan="5">Math-psa-7B</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td rowspan="5">Skywork-o1-PRM-7B</td>
<td>Greedy</td>
<td>$2^0$</td>
<td></td>
</tr>
<tr>
<td>Majority Vote</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Best-of-N</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>Beam Search</td>
<td>$2^6$</td>
<td></td>
</tr>
<tr>
<td>MCTS</td>
<td>$2^6$</td>
<td></td>
</tr>
</tbody>
</table>
def str2bool(x: str):
if x == "False":
return False
elif x == "True":
return True
else:
raise ValueError(
'you should either input "True" or "False" but not {}'.format(x)
)
def list_of_ints(arg):
return list(map(int, arg.split(",")))
# Open O1 dev
A simple implementation of data generation of [Improve Mathematical Reasoning in Language Models by Automated Process Supervision](https://arxiv.org/pdf/2406.06592)
## Generating data
Edit the `config.yaml` file to set the model and other parameters. Input json file should be a list of problems and answers with key `problem` and `final_answer`.
```yaml
input:
json_file_path: 'extracted_problems_and_answers.json'
output:
file_prefix: 'math'
log_file_path: 'processing_log.log'
processing:
initial_rollouts: 20
num_rollouts: 20
max_iterations: 100
model:
# supported model_types: "hf", "openai", "anthropic"
model_type: "hf"
model_name: "Qwen/Qwen2.5-Math-7B-Instruct"
model_args:
max_tokens: 200
temperature_range: [0.7, 1.0]
```
```bash
python gen_data.py
```
## Data
Download the [**MATH dataset**](https://people.eecs.berkeley.edu/~hendrycks/MATH.tar) here.
\ No newline at end of file
input:
json_file_path: 'extracted_problems_and_answers.json'
output:
file_prefix: 'math'
log_file_path: 'processing_log.log'
processing:
initial_rollouts: 20
num_rollouts: 20
max_iterations: 100
model:
# supported model_types: "hf", "openai", "anthropic"
model_type: "hf"
model_name: "Qwen/Qwen2.5-Math-7B-Instruct"
model_args:
max_tokens: 200
temperature_range: [0.7, 1.0]
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
import json
import logging
from datetime import datetime
import yaml
from module import Node, perform_rollouts, process_annotations, calculate_mc_score
from model_utils import LM
def load_config(config_path):
"""
Load configuration from a YAML file.
Args:
config_path (str): Path to the YAML configuration file.
Returns:
dict: A dictionary containing the configuration.
"""
with open(config_path, 'r') as file:
return yaml.safe_load(file)
def load_json_file(file_path):
"""
Load data from a JSON file.
Args:
file_path (str): Path to the JSON file.
Returns:
list: A list of dictionaries containing the problem and final answer.
"""
with open(file_path, 'r') as file:
data = json.load(file)
return data
def setup_logging(log_file):
"""
Set up logging configuration to output to file and console.
Args:
log_file (str): Path to the log file.
"""
logging.basicConfig(
filename=log_file,
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
console_handler.setFormatter(formatter)
root_logger = logging.getLogger()
root_logger.addHandler(console_handler)
logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.WARNING)
def main():
# Load configuration
config = load_config('config.yaml')
# Get parameters from config
json_file_path = config['input']['json_file_path']
log_file_path = config['output']['log_file_path']
file_prefix = config['output']['file_prefix']
num_rollouts = config['processing']['num_rollouts']
initial_rollouts = config['processing']['initial_rollouts']
max_iterations = config['processing']['max_iterations']
lm_model = LM(model_type=config['model']['model_type'], model_name=config['model']['model_name'], num_rollouts=num_rollouts, **config['model']['model_args'])
# Set up logging
setup_logging(log_file_path)
# Start the process and log it
logging.info("Started processing the JSON file.")
# Load the JSON data
data = load_json_file(json_file_path)
# Process each problem and its final answer
for i, item in enumerate(data):
problem = item.get('problem', 'No problem found')
final_answer = item.get('final_answer', 'No answer found')
# Log each problem and answer
logging.info(f"Processed Problem {i + 1}: {problem}")
logging.info(f"Final Answer: {final_answer}")
# Initialize the root node and perform rollouts
nodes = []
root_node = Node(problem, "", final_answer)
rollouts, correctness_flags = perform_rollouts(root_node, lm_model, initial_rollouts)
mc_score = calculate_mc_score(root_node)
root_node.mc_score = mc_score
nodes.append(root_node)
# Check if further processing is needed
if 0 < sum(correctness_flags) < initial_rollouts:
print("Processing annotations ...\n")
filename = f"{file_prefix}_{i+1}_nodes_data.json"
process_annotations(problem, nodes, lm_model, filename, max_iterations)
# Log completion
logging.info("Finished processing the JSON file.")
if __name__ == "__main__":
main()
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import os
from transformers import set_seed
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from typing import List
# Set your Hugging Face token here
os.environ["HUGGINGFACE_HUB_TOKEN"] = "hf_yourkey"
# For reproducibility
SEED = 1234
set_seed(SEED)
random.seed(42)
class LM:
def __init__(self, model_name: str = "Qwen/Qwen2.5-Math-7B-Instruct", model_type: str = "hf", num_rollouts: int = 5, **model_args):
self.model_type = model_type.lower()
self.model_name = model_name
self.max_tokens = 200
self.temperature_range = [0.7, 1.0]
self.num_rollouts = num_rollouts
self.__dict__.update(model_args)
print("Updated model args:", self.__dict__)
if self.model_type == "vllm":
raise NotImplementedError("VLLM is not implemented yet")
from vllm import LLM, SamplingParams
self.llm = LLM(model=model_name, enable_prefix_caching=True)
self.SamplingParams = SamplingParams
elif self.model_type == "hf":
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float16, device_map="cuda"
)
elif self.model_type == "openai":
import openai
self.client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
elif self.model_type == "anthropic":
import anthropic
self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
else:
raise ValueError("Invalid model_type. Choose 'vllm', 'hf', 'openai', or 'anthropic'.")
def generate(self, question, partial_answer, num_rollouts=None):
if num_rollouts is None:
num_rollouts = self.num_rollouts
prompt = question + partial_answer
if self.model_type == "vllm":
return self.generate_vllm(prompt, num_rollouts)
elif self.model_type == "hf":
return self.generate_hf(prompt, num_rollouts)
elif self.model_type == "anthropic" or self.model_type == "openai":
return self.generate_api(prompt, num_rollouts)
def generate_hf(self, prompt, num_rollouts):
inputs = self.tokenizer(prompt, return_tensors="pt").to('cuda')
results = []
for _ in range(num_rollouts):
temperature = random.uniform(self.temperature_range[0], self.temperature_range[1])
outputs = self.model.generate(
**inputs, do_sample=True, max_new_tokens=self.max_tokens, temperature=temperature,
num_return_sequences=1
)
generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
result = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
results.append(result)
return results
def generate_vllm(self, prompt, num_rollouts):
raise NotImplementedError("VLLM is not implemented yet")
temperature = random.choice(self.temperature_range)
sampling_params = self.SamplingParams(
temperature=temperature,
top_p=1,
max_tokens=self.max_tokens,
n=num_rollouts,
seed=SEED
)
outputs = self.llm.generate(prompt, sampling_params)
result = [completion.text for output in outputs for completion in output.outputs]
return result
def generate_api(self, prompt: str, num_rollouts) -> List[str]:
def send_request(prompt):
temperature = random.choice(self.temperature_range)
if self.model_type == "openai":
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
max_tokens=self.max_tokens,
temperature=temperature
)
output = response.choices[0].message.content
elif self.model_type == "anthropic":
response = self.client.messages.create(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
max_tokens=self.max_tokens,
temperature=temperature
)
output = response.content[0].text
return output
responses = []
with ThreadPoolExecutor(max_workers=num_rollouts) as executor:
futures = [executor.submit(send_request, prompt) for _ in range(num_rollouts)]
for future in tqdm(as_completed(futures), total=len(futures)):
responses.append(future.result())
return responses
import torch
import random
import json
import re
import os
import math
from model_utils import LM
class Node:
def __init__(self, question, partial_answer, correct_answer):
self.question = question
self.partial_answer = partial_answer
self.correct_answer = correct_answer
self.mc_score = None
self.visits = 0
self.rollouts = []
self.visited_rollouts = []
def add_rollout(self, result):
self.rollouts.append(result)
self.visited_rollouts.append(False)
def increment_visits(self):
self.visits += 1
# Evaluation
def check_correctness(expected_answer, generated_response):
sentences = re.split(
r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', generated_response.strip()
)
last_sentence = sentences[-1] if sentences else ''
return expected_answer.strip() in last_sentence.strip()
def perform_rollouts(node, model: LM, num_rollouts=None):
correctness_flags = []
results = model.generate(node.question, node.partial_answer, num_rollouts)
for result in results:
node.add_rollout(result)
is_correct = check_correctness(node.correct_answer, result)
correctness_flags.append(int(is_correct))
return node.rollouts, correctness_flags
def calculate_mc_score(node):
correct_count = sum(
check_correctness(node.correct_answer, r) for r in node.rollouts
)
return correct_count / len(node.rollouts) if node.rollouts else 0
def select_best_node(nodes):
best_node = None
best_rollout_idx = -1
highest_qu_value = -1
for node in nodes:
mc_score = (
node.mc_score if node.mc_score is not None else calculate_mc_score(node)
)
if mc_score in [0, 1]:
continue
for idx, rollout in enumerate(node.rollouts):
if node.visited_rollouts[idx]:
continue
q_val = compute_q_value(rollout, mc_score)
u_val = compute_u_value(node, nodes)
qu_value = q_val + u_val
if qu_value > highest_qu_value:
highest_qu_value = qu_value
best_node = node
best_rollout_idx = idx
if best_rollout_idx != -1 and best_node is not None:
best_node.visited_rollouts[best_rollout_idx] = True
return best_node, best_node.rollouts[best_rollout_idx], highest_qu_value
else:
return None, None, None
def split_text_middle(text):
text = text.strip()
mid_idx = len(text) // 2
if text[mid_idx] != ' ':
left_space = text.rfind(' ', 0, mid_idx)
right_space = text.find(' ', mid_idx)
if left_space == -1:
split_idx = right_space
elif right_space == -1:
split_idx = left_space
else:
split_idx = (
left_space if (mid_idx - left_space) <= (right_space - mid_idx) else right_space
)
else:
split_idx = mid_idx
part1 = text[:split_idx].strip()
part2 = text[split_idx:].strip()
return part1, part2
def locate_error(node, rollout, model):
current_span = rollout
previous_text = ""
nodes_to_expand = []
leaf_nodes = []
while True:
if len(current_span.split()) < 2:
break
left_part, right_part = split_text_middle(current_span)
print("----")
print(" Left:", left_part)
print(" Right:", right_part)
new_node = Node(
node.question, previous_text + left_part, node.correct_answer
)
perform_rollouts(new_node, model)
mc_score = calculate_mc_score(new_node)
new_node.mc_score = mc_score
if mc_score == 1:
break
elif mc_score > 0:
current_span = right_part
previous_text += left_part
nodes_to_expand.append(new_node)
else:
current_span = left_part
leaf_nodes.append(new_node)
print("----")
return nodes_to_expand, leaf_nodes
def compute_q_value(rollout_text, mc_score, alpha=0.5, beta=0.9, max_length=500):
part1 = alpha ** (1 - mc_score)
part2 = beta ** (len(rollout_text) / max_length)
return part1 * part2
def compute_u_value(node, all_nodes, exploration_param=0.125):
total_visits = sum(n.visits for n in all_nodes)
numerator = math.sqrt(total_visits)
denominator = 1 + node.visits
return exploration_param * (numerator / denominator)
def process_annotations(question, nodes, model: LM, filename='nodes_data.json', max_iterations=100):
print("++++++")
iteration = 0
leaf_nodes = []
while True:
node, rollout, max_qu = select_best_node(nodes)
if node is not None and node.partial_answer != '':
new_entry = {
"question": question,
"partial_answer": node.partial_answer,
"mc_score": node.mc_score,
}
append_to_json(filename, new_entry)
iteration += 1
if iteration > max_iterations:
break
if node is None:
break
print()
print("[Selected Node]")
print(node)
print(" Rollout:", rollout, " || QU Value:", max_qu)
node.increment_visits()
expanded_nodes, leaves = locate_error(node, rollout, model)
if not expanded_nodes:
continue
nodes.extend(
n for n in expanded_nodes if n is not None and n.partial_answer != ''
)
leaf_nodes.extend(leaves)
for leaf_node in leaf_nodes:
new_entry = {
"question": question,
"partial_answer": leaf_node.partial_answer,
"mc_score": leaf_node.mc_score,
}
append_to_json(filename, new_entry)
print("++++++")
# Utils
def append_to_json(filename, data_entry):
if os.path.exists(filename):
with open(filename, 'r') as file:
data = json.load(file)
else:
data = []
data.append(data_entry)
with open(filename, 'w') as file:
json.dump(data, file, indent=4)
print(f"Data appended to {filename}")
\ No newline at end of file
import os
import threading
from transformers import pipeline
from typing import List
# Import vllm if using vLLM backend
try:
from vllm import LLM, SamplingParams
except ImportError:
print("vLLM not installed. Install it if you wish to use it as a model backend.")
# Set the environment variable for the endpoint
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
class LLMService:
"""
A class to manage a large language model (LLM) service using Hugging Face's transformers library.
"""
def __init__(self, model_name: str = "/root/.cache/modelscope/hub/Qwen/Qwen2___5-Math-7B-Instruct",
device: str = "cuda", max_new_tokens: int = 2048,
temperature: float = 0.7, top_k: int = 30, top_p: float = 0.9, model_type: str="hf"):
"""
Initialize the LLMService with model parameters and sampling settings.
Parameters:
- model_name (str): Path or Hugging Face hub model name.
- device (str): Device for computation, e.g., 'cuda' or 'cpu'.
- max_new_tokens (int): Maximum number of new tokens to generate.
- temperature (float): Sampling temperature for response generation.
- top_k (int): Top-K sampling parameter for response diversity.
- top_p (float): Top-P sampling parameter for response diversity.
"""
self.model_name = model_name
self.device = device
self.max_new_tokens = max_new_tokens
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
self.model_type = model_type.lower()
self.pipe = None
self.llm = None
self.load_lock = threading.Lock()
def start_service(self):
"""
Start the LLM service by loading the model into the chosen pipeline if it's not already loaded.
Ensures thread-safe loading using a lock.
"""
with self.load_lock:
if self.model_type == "hf":
if self.pipe is None:
print(f"Loading Hugging Face model '{self.model_name}' on device '{self.device}'...")
self.pipe = pipeline(
"text-generation",
model=self.model_name,
torch_dtype="auto",
device_map=self.device
)
print("Hugging Face model loaded successfully.")
elif self.model_type == "vllm":
if self.llm is None:
print(f"Loading vLLM model '{self.model_name}' on device '{self.device}'...")
self.llm = LLM(self.model_name, tensor_parallel_size=1)
print("vLLM model loaded successfully.")
else:
raise ValueError("Unsupported model_type. Choose 'hf' for Hugging Face or 'vllm' for vLLM.")
def generate_response(self, prompt: str, num_copies: int = 2) -> List[str]:
"""
Generate responses from the model based on the provided prompt, duplicated to form a batch.
Parameters:
- prompt (str): The input prompt to generate responses for.
- num_copies (int): The number of copies of the prompt to create for batch processing (default is 16).
Returns:
- List[str]: A list of generated responses, each corresponding to a duplicate of the input prompt.
"""
if self.model_type == "hf":
return self._generate_response_hf(prompt, num_copies)
elif self.model_type == "vllm":
return self._generate_response_vllm(prompt, num_copies)
else:
raise ValueError("Unsupported model_type. Choose 'hf' for Hugging Face or 'vllm'.")
def _generate_response_hf(self, prompt: str, num_copies: int = 2) -> List[str]:
"""
Generate responses from the model based on the provided prompt, duplicated to form a batch.
Parameters:
- prompt (str): The input prompt to generate responses for.
- num_copies (int): The number of copies of the prompt to create for batch processing (default is 16).
Returns:
- List[str]: A list of generated responses, each corresponding to a duplicate of the input prompt.
"""
if self.pipe is None:
raise ValueError("LLM service not started. Please call start_service() first.")
# Create a batch of the same prompt
prompts = [prompt] * num_copies
# Generate responses from the model
responses = self.pipe(
prompts,
max_new_tokens=self.max_new_tokens,
batch_size=num_copies,
do_sample=True,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
return_full_text=False
)
response_message_batch = [result[0]["generated_text"] for result in responses]
# Extract and return the generated text for each response
return response_message_batch
def _generate_response_vllm(self, prompt: str, num_copies: int) -> List[str]:
"""
Generate responses using vLLM.
"""
if self.llm is None:
raise ValueError("LLM service not started for vLLM model. Please call start_service() first.")
sampling_params = SamplingParams(
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
max_tokens=self.max_new_tokens
)
prompts = [prompt] * num_copies
responses = self.llm.generate(prompts, sampling_params=sampling_params)
return [response.outputs[0].text for response in responses]
if __name__ == "__main__":
# Initialize the service for vLLM
llm_service = LLMService(model_type="vllm")
llm_service.start_service()
prompt = "What is game theory?"
responses = llm_service.generate_response(prompt, num_copies=3)
print(responses)
import json
import os
import random
from typing import List
import math
def sample_questions(input_filepath: str, output_filepath: str, num_samples: int):
"""
Samples a specified number of questions from the input JSON file and saves them to an output JSON file.
Parameters:
- input_filepath (str): Path to the original JSON file containing all questions.
- output_filepath (str): Path to save the sampled questions JSON file.
- num_samples (int): Number of questions to sample.
"""
with open(input_filepath, 'r') as f:
questions = json.load(f)
# Sample questions
sampled_questions = random.sample(questions, min(num_samples, len(questions)))
# Save sampled questions to the output file
with open(output_filepath, 'w') as f:
json.dump(sampled_questions, f, indent=4)
print(f"Saved {len(sampled_questions)} sampled questions to {output_filepath}")
def split_questions(input_filepath: str, output_dir: str, questions_per_file: int):
"""
Splits the input JSON file into multiple smaller JSON files, each containing a specified number of questions.
Parameters:
- input_filepath (str): Path to the original JSON file containing all questions.
- output_dir (str): Directory to save the split JSON files.
- questions_per_file (int): Number of questions per split file.
"""
with open(input_filepath, 'r') as f:
questions = json.load(f)
# Create output directory if it does not exist
os.makedirs(output_dir, exist_ok=True)
# Split questions into chunks and save each chunk as a separate file
for i in range(0, len(questions), questions_per_file):
chunk = questions[i:i + questions_per_file]
output_filepath = os.path.join(output_dir, f"questions_part_{i // questions_per_file + 1}.json")
with open(output_filepath, 'w') as f:
json.dump(chunk, f, indent=4)
print(f"Saved {len(chunk)} questions to {output_filepath}")
def split_questions_uniformly(input_filepath: str, output_directory: str, num_files: int):
"""
Split a JSON file containing questions into a specified number of files with approximately equal questions.
Parameters:
- input_filepath (str): Path to the JSON file containing the list of questions.
- output_directory (str): Directory to save the split JSON files.
- num_files (int): Number of files to split the questions into.
Each output file will contain approximately len(questions) / num_files questions.
"""
# Load questions from the input file
with open(input_filepath, 'r') as f:
questions = json.load(f)
# Calculate the number of questions per file
total_questions = len(questions)
questions_per_file = math.ceil(total_questions / num_files)
# Ensure the output directory exists
os.makedirs(output_directory, exist_ok=True)
# Split the questions and write to output files
for i in range(num_files):
start_idx = i * questions_per_file
end_idx = min(start_idx + questions_per_file, total_questions)
questions_subset = questions[start_idx:end_idx]
output_filepath = os.path.join(output_directory, f"questions_part_{i + 1}.json")
with open(output_filepath, 'w') as f_out:
json.dump(questions_subset, f_out, indent=4)
print(f"Saved {len(questions_subset)} questions to {output_filepath}")
# Example usage
if __name__ == "__main__":
split_questions_uniformly("extracted_problems_and_answers.json", "output_directory", 8)
#
# # Sampling a subset of questions from the original file
# sample_questions("extracted_problems_and_answers.json", "sampled_questions.json", 10)
#
# # Splitting the original file into multiple files with each containing 5 questions
# split_questions("extracted_problems_and_answers.json", "output_directory", 5)
# OmegaPRM Multi-GPU Runner
This script runs `OmegaPRM` on multiple GPUs, with each GPU handling a different part of the dataset for parallel processing.
## Steps to Use
1. **Split the Input Data**
Use `process_json.py` to split your input JSON file into multiple parts for each GPU:
```bash
python process_json.py --input_file questions.json --output_dir output_directory --num_splits 8
```
2. **Run the Script**
Use `run_omegaprm_multi_gpu.sh` to start processing with OmegaPRM on each GPU:
``` bash
run_omegaprm_multi_gpu.sh
```
Results are saved in `output_results`.
**Note**: Before running, make sure to set the correct values for parameters in the script. Important parameters include:
-`MODEL_NAME`: Path to the model (e.g., Hugging Face or vLLM model).
-`MODEL_TYPE`: Set to "hf" for Hugging Face or "vllm" for vLLM support.
-`Other parameters` like `MAX_NEW_TOKENS`, `TEMPERATURE`, `TOP_K`, and other hyperparameters according to your needs.
## Run on a Single GPU
```bash
CUDA_VISIBLE_DEVICES=0 python run_omegaprm.py \
--question_file ../extracted_problems_and_answers.json \
--output_dir output_results \
--model_name "Your model name or path " \
--model_type $MODEL_TYPE \
--device "cuda" \
--max_new_tokens 2048 \
--temperature 0.7 \
--top_k 30 \
--top_p 0.9 \
--c_puct 0.125 \
--alpha 0.5 \
--beta 0.9 \
--length_scale 500 \
--num_rollouts 16 \
--max_search_count 20 \
--rollout_budget 200 \
--save_data_tree True\
--log_file_prefix "log/omega_prm_single_gpu"
```
# Generated Data Format
For data generated by `OmegaPRM_v2`, two formats are available:
1. Flat Format (`save_data_tree=False`):
Each entry is structured as:
{ "solution_prefix": [Q, x_1:x_i], "mc_value": 0.5 }
where `i` is a variable representing the number of reasoning steps. This format provides a linear view of the reasoning process without hierarchical structure.
2. Tree Format (`save_data_tree=True`):
In this format, data is organized as a tree structure, aligned with the figure presented in the paper. Each reasoning step (or node) includes:
- text: The cumulative reasoning from the root node up to this specific step.
- mc_value: The Monte Carlo score computed for the reasoning progression up to this step.
- children: A list of child nodes branching from the current node.
\ No newline at end of file
import argparse
import json
import logging
import os
from typing import Dict
from omegaprm import LanguageModel, OmegaPRM
DS_NAME = "math-aps-v2"
logger: logging.Logger
# Set up logging based on provided log file prefix and create log directory if it doesn't exist
def setup_logging(log_file_prefix: str):
log_dir = os.path.dirname(log_file_prefix)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
log_filename = f"{log_file_prefix}.log"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.FileHandler(log_filename), logging.StreamHandler()],
)
return logging.getLogger(__name__)
# Load questions from JSON
def load_questions(filepath: str):
with open(filepath, "r") as f:
return json.load(f)
# Filter a single question based on 32 rollouts
def should_process_question(question: Dict[str, str], llm: LanguageModel) -> bool:
prompt = question["problem"]
correct_answer = question["final_answer"]
has_correct = False
has_incorrect = False
initial_batch_answers = llm.generate_rollout(prompt, 32)
for answer in initial_batch_answers:
if llm.evaluate_correctness(answer, correct_answer):
has_correct = True
else:
has_incorrect = True
if has_correct and has_incorrect:
logger.info(f"Question passed filter: {question['problem']}")
return True
return False
# Run OmegaPRM on a question if it passes the filter
def process_question(omega_prm: OmegaPRM, question: Dict[str, str]):
logger.info(f"Processing question with OmegaPRM: {question['problem']}")
reasoning_steps = omega_prm.run(question["problem"], question["final_answer"])
collected_data = {
"question": question["problem"],
"final_answer": question["final_answer"],
"reasoning_steps": reasoning_steps,
}
return collected_data
# Save collected data for each question
def save_question_data(collected_data: Dict, index: int, output_path: str):
collected_data["question_id"] = index
with open(output_path, "a") as fd:
line = json.dumps(collected_data)
fd.write(f"{line}\n")
logger.info(f"Question {index} is saved to {output_path}")
def main(args):
global logger
logger = setup_logging(args.log_file_prefix)
os.makedirs(args.output_dir, exist_ok=True)
output_file = os.path.join(args.output_dir, f"{DS_NAME}.jsonl")
# ensure output_file is empty since we are appending to it later
with open(output_file, "w") as fd:
fd.write("")
logger.info("Starting OmegaPRM processing")
logger.info(f"Using model: {args.model_name} on device: {args.device}")
logger.info(f"Output file: {output_file}")
logger.info(f"Question file: {args.question_file}")
questions = load_questions(args.question_file)
llm = LanguageModel(
model_name=args.model_name,
device=args.device,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
model_type=args.model_type,
)
omega_prm = OmegaPRM(
LM=llm,
c_puct=args.c_puct,
alpha=args.alpha,
beta=args.beta,
L=args.length_scale,
k=args.num_rollouts,
N=args.max_search_count,
rollout_budget=args.rollout_budget,
save_data_tree=args.save_data_tree,
)
processed_count = 0 # Counter for processed questions
for idx, question in enumerate(questions):
if should_process_question(question, llm):
collected_data = process_question(omega_prm, question)
save_question_data(collected_data, idx, output_file)
processed_count += 1
else:
logger.info(f"Skipping question: {question['problem']}")
# Log summary
logger.info(
f"Total questions processed by OmegaPRM: {processed_count}/{len(questions)}"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run OmegaPRM on filtered questions")
parser.add_argument(
"--question_file",
type=str,
required=True,
help="Path to the questions JSON file",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help=f"Directory to save the output file {DS_NAME}.jsonl",
)
parser.add_argument(
"--log_file_prefix",
type=str,
default="omega_prm",
help="Prefix for the log files",
)
parser.add_argument(
"--model_name",
type=str,
default="/root/.cache/modelscope/hub/Qwen/Qwen2___5-Math-7B-Instruct",
help="Model name or path for the language model",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to run the model on (e.g., 'cuda', 'cpu')",
)
parser.add_argument(
"--max_new_tokens", type=int, default=2048, help="Max tokens for LLM generation"
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature for LLM generation",
)
parser.add_argument(
"--top_k", type=int, default=30, help="Top-K sampling for LLM generation"
)
parser.add_argument(
"--top_p", type=float, default=0.9, help="Top-P sampling for LLM generation"
)
parser.add_argument(
"--model_type",
type=str,
default="hf",
help="Model backend to use ('hf' for Hugging Face or 'vllm')",
)
# OmegaPRM parameters with provided defaults
parser.add_argument(
"--c_puct", type=float, default=0.125, help="Exploration constant for OmegaPRM"
)
parser.add_argument(
"--alpha", type=float, default=0.5, help="Weight for MC(s) in OmegaPRM"
)
parser.add_argument(
"--beta", type=float, default=0.9, help="Length penalty for OmegaPRM"
)
parser.add_argument(
"--length_scale", type=int, default=500, help="length scale in OmegaPRM"
)
parser.add_argument(
"--num_rollouts",
type=int,
default=16,
help="Number of rollouts for Monte Carlo estimation in OmegaPRM",
)
parser.add_argument(
"--max_search_count", type=int, default=20, help="Max search count in OmegaPRM"
)
parser.add_argument(
"--rollout_budget", type=int, default=200, help="Rollout budget for OmegaPRM"
)
parser.add_argument(
"--save_data_tree",
type=bool,
default=True,
help="Save data in tree structure for OmegaPRM",
)
args = parser.parse_args()
main(args)
#!/bin/bash
# Set the model and other parameters
MODEL_NAME="/root/.cache/modelscope/hub/Qwen/Qwen2___5-Math-7B-Instruct"
MODEL_TYPE="vllm" # Set to "vllm" for vLLM or "hf" for Hugging Face
DEVICE="cuda"
MAX_NEW_TOKENS=2048
TEMPERATURE=0.7
TOP_K=30
TOP_P=0.9
C_PUCT=0.125
ALPHA=0.5
BETA=0.9
LENGTH_SCALE=500
NUM_ROLLOUTS=16
MAX_SEARCH_COUNT=20
ROLLOUT_BUDGET=200
SAVE_DATA_TREE=True
OUTPUT_DIR="output_results_data"
# Split files directory
SPLIT_DIR="output_directory"
# Create output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
mkdir -p log
# Start the OmegaPRM process on each GPU with separate split files
for i in {1..8}
do
SPLIT_FILE="$SPLIT_DIR/questions_part_${i}.json"
GPU_ID=$((i-1))
OUTPUT_FILE="$OUTPUT_DIR/results_part_${i}.json"
LOG_FILE_PREFIX="log/omega_prm_gpu_$GPU_ID"
# Run the OmegaPRM process in the background on the specified GPU
CUDA_VISIBLE_DEVICES=$GPU_ID python3 run_omegaprm.py \
--question_file $SPLIT_FILE \
--output_dir $OUTPUT_FILE \
--model_name $MODEL_NAME \
--model_type $MODEL_TYPE \
--device $DEVICE \
--max_new_tokens $MAX_NEW_TOKENS \
--temperature $TEMPERATURE \
--top_k $TOP_K \
--top_p $TOP_P \
--c_puct $C_PUCT \
--alpha $ALPHA \
--beta $BETA \
--length_scale $LENGTH_SCALE \
--num_rollouts $NUM_ROLLOUTS \
--max_search_count $MAX_SEARCH_COUNT \
--rollout_budget $ROLLOUT_BUDGET \
--save_data_tree $SAVE_DATA_TREE \
--log_file_prefix $LOG_FILE_PREFIX &
done
# Wait for all processes to finish
wait
echo "All OmegaPRM processes complete."
import torch.distributed as dist
import torch
from datetime import timedelta
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if dist.is_initialized():
if dist.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def print_with_rank(message):
if dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
print("[{}/{}]: {}".format(rank, world_size, message), flush=True)
else:
print(message, flush=True)
def init_distributed(timeout=timedelta(seconds=3 * 60 * 60)):
if not dist.is_initialized():
dist.init_process_group(backend="nccl", timeout=timeout)
world_size = dist.get_world_size()
local_rank = dist.get_rank()
return local_rank, world_size
def gather_scalar(x, rank, world_size):
x_tensor = torch.tensor(x).to(f"cuda:{rank}")
if rank == 0:
x_list = [torch.zeros_like(x_tensor) for _ in range(world_size)]
dist.gather(x_tensor, x_list, 0)
return [x.item() for x in x_list]
else:
dist.gather(x_tensor)
from .env import Env, extract_answer, extract_groundtruth, judge_correct
from .data import get_train_test_dataset
from .prompt import COT_EXAMPLES, COT_TASK_DESC, PROBLEM_FORMAT_STR, SEP
from pathlib import Path
import jsonlines
from torch.utils.data import Dataset
def get_train_test_dataset(*args, **kwargs):
env_dir = Path(__file__).parent
test_ds = JsonlMathDataset(env_dir / "dataset/test500.jsonl")
train_ds = JsonlMathDataset(env_dir / "dataset/train.jsonl")
return train_ds, test_ds
class JsonlMathDataset(Dataset):
def __init__(self, data_path):
super().__init__()
self.data = []
with jsonlines.open(data_path, "r") as reader:
for obj in reader:
self.data.append(obj)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
return {"question": x["problem"], "answer": x["solution"]}
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
import copy
import re
from typing import List, Optional
import numpy as np
from envs.base_env import CoTEnv, NoLegalActionException, INVALID_ANS
from .prompt import COT_EXAMPLES, COT_TASK_DESC, PROBLEM_FORMAT_STR, SEP
# from .verify_utils import extract_answer as extract_fn, grade_answer
from .parse_utils_qwen import extract_answer as extract_fn, parse_ground_truth
from .grader import math_equal
ANS_RE = None
STOP_STR = None
def extract_answer(answer_str: str) -> str:
return extract_fn(answer_str, data_name='math')
def extract_groundtruth(groundtruth_str: str) -> str:
return parse_ground_truth(groundtruth_str, data_name='math')
def judge_correct(
problem_str: str, extracted_groundtruth: Optional[str], answer: str
) -> bool:
# return grade_answer(given_answer=answer, ground_truth=extracted_groundtruth)
result = math_equal(answer, extracted_groundtruth)
return result
class Env(CoTEnv):
sep = SEP
def __init__(
self,
config,
math_problems,
llm_gen_fn,
task_desc_str: str = COT_TASK_DESC,
cot_example_str: str = COT_EXAMPLES,
problem_format_str: str = PROBLEM_FORMAT_STR,
reset=True,
):
super().__init__(
config,
math_problems,
llm_gen_fn,
task_desc_str,
cot_example_str,
problem_format_str,
reset,
)
@property
def stop_str(self):
return STOP_STR
def post_process_act(self, action: str):
if not action.endswith(self.sep):
action = action.strip() + self.sep
return action
def _is_correct(self, completion):
extracted_answer = extract_answer(completion)
# print("Compare: {} -- {}".format(extrated_answer,
# self.math_problem['answer']))
# return extrated_answer == self.math_problem['answer']
return judge_correct(
self.math_problem["question"], self.math_problem["answer"], extracted_answer
)
def get_reward(self):
"""To implement based on learned reward model"""
return 0
# .coveragerc to control coverage.py
[run]
branch = True
include =
latex2sympy.py
omit =
sandbox/*
gen/*
asciimath_printer.py
setup.py
__init__.py
[report]
# Regexes for lines to exclude from consideration
exclude_lines =
# Have to re-enable the standard pragma
pragma: no cover
# Don't complain about missing debug-only code:
def __repr__
if self\.debug
# Don't complain if tests don't hit defensive assertion code:
raise AssertionError
raise NotImplementedError
# Don't complain if non-runnable code isn't run:
if 0:
if __name__ == .__main__.:
ignore_errors = True
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.antlr
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don’t work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Azure Functions artifacts
bin
obj
appsettings.json
local.settings.json
.python_packages
stemgen-solution-engine.zip
\ No newline at end of file
The MIT License (MIT)
Copyright 2016, latex2sympy
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
![Logo](https://picgo-1258602555.cos.ap-nanjing.myqcloud.com/icon.png)
# [latex2sympy2](https://github.com/OrangeX4/latex2sympy)
## About
`latex2sympy2` parses **LaTeX math expressions** and converts it into the equivalent **SymPy form**. The latex2sympy2 is adapted from [augustt198/latex2sympy](https://github.com/augustt198/latex2sympy) and [purdue-tlt / latex2sympy](https://github.com/purdue-tlt/latex2sympy).
This project is a part of a VS Code extension called [Latex Sympy Calculator](https://marketplace.visualstudio.com/items?itemName=OrangeX4.latex-sympy-calculator). It is designed for providing people writing in latex or markdown a ability to calculate something when writing math expression.
[ANTLR](http://www.antlr.org/) is used to generate the parser.
## Features
* **Arithmetic:** Add (+), Sub (-), Dot Mul (·), Cross Mul (×), Frac (/), Power (^), Abs (|x|), Sqrt (√), etc...
* **Alphabet:** a - z, A - Z, α - ω, Subscript (x_1), Accent Bar(ā), etc...
* **Common Functions:** gcd, lcm, floor, ceil, max, min, log, ln, exp, sin, cos, tan, csc, sec, cot, arcsin, sinh, arsinh, etc...
* **Funcion Symbol:** f(x), f(x-1,), g(x,y), etc...
* **Calculous:** Limit ($lim_{n\to\infty}$), Derivation ($\frac{d}{dx}(x^2+x)$), Integration ($\int xdx$), etc...
* **Linear Algebra:** Matrix, Determinant, Transpose, Inverse, Elementary Transformation, etc...
* **Other:** Binomial...
**NOTICE:** It will do some irreversible calculations when converting determinants, transposed matrixes and elementary transformations...
## Installation
```
pip install latex2sympy2
```
**Requirements:** `sympy` and `antlr4-python3-runtime` packages.
## Usage
### Basic
In Python:
```python
from latex2sympy2 import latex2sympy, latex2latex
tex = r"\frac{d}{dx}(x^{2}+x)"
# Or you can use '\mathrm{d}' to replace 'd'
latex2sympy(tex)
# => "Derivative(x**2 + x, x)"
latex2latex(tex)
# => "2 x + 1"
```
### Examples
|LaTeX|Converted SymPy|Calculated Latex|
|-----|-----|---------------|
|`x^{3}` $x^{3}$| `x**3`|`x^{3}` $x^{3}$|
|`\frac{d}{dx} tx` $\frac{d}{dx}tx$|`Derivative(x*t, x)`|`t` $t$|
|`\sum_{i = 1}^{n} i` $\sum_{i = 1}^{n} i$|`Sum(i, (i, 1, n))`|`\frac{n \left(n + 1\right)}{2}` $\frac{n \left(n + 1\right)}{2}$|
|`\int_{a}^{b} \frac{dt}{t}`|`Integral(1/t, (t, a, b))`|`-\log{(a)} + \log{(b)}` $-\log{(a)} + \log{(b)}$|
|`(2x^3 - x + z)|_{x=3}` $(2x^3 - x + z)\|_{x=3}$|`z + 51`| `z + 51` $z + 51$ |
If you want to read the math formula, you can click [GitNotes](https://notes.orangex4.cool/?git=github&github=OrangeX4/latex2sympy).
### Solve Equation
``` latex
# Before
x + y = 1
# After
[ y = 1 - x, \ x = 1 - y]
```
### Eval At
``` latex
# Before
(x+2)|_{x=y+1}
# After
y + 3
```
### Matrix
#### Identity matrix
```
tex = r"\bm{I}_3"
latex2sympy(tex)
# => "Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]])"
```
#### Determinant
``` python
from latex2sympy2 import latex2sympy
tex = r"\begin{vmatrix} x & 0 & 0 \\ 0 & x & 0 \\ 0 & 0 & x \end{vmatrix}"
latex2sympy(tex)
# => "x^{3}"
```
#### Transpose
``` python
from latex2sympy2 import latex2sympy
tex = r"\begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{pmatrix}^T"
# Or you can use "\begin{pmatrix}1&2&3\\4&5&6\\7&8&9\end{pmatrix}'"
latex2sympy(tex)
# => "Matrix([[1, 4, 7], [2, 5, 8], [3, 6, 9]])"
```
#### Elementary Transformation
``` python
from latex2sympy2 import latex2sympy
matrix = r'''
\begin{pmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9 \\
\end{pmatrix}
'''
# Scale the row with grammar "\xrightarrow{kr_n}"
tex = matrix + r'\xrightarrow{3r_1}'
latex2sympy(tex)
# => "Matrix([[3, 6, 9], [4, 5, 6], [7, 8, 9]])"
# Swap the cols with grammar "\xrightarrow{c_1<=>c_2}"
# Of course, you can use "\leftrightarrow" to replace "<=>"
tex = matrix + r'\xrightarrow{c_1<=>c_2}'
latex2sympy(tex)
# => "Matrix([[2, 1, 3], [5, 4, 6], [8, 7, 9]])"
# Scale the second row and add it to the first row
# with grammar "\xrightarrow{r_1+kr_2}"
tex = matrix + r'\xrightarrow{r_1+kr_2}'
latex2sympy(tex)
# => "Matrix([[4*k + 1, 5*k + 2, 6*k + 3], [4, 5, 6], [7, 8, 9]])"
# You can compose the transform with comma ","
# and grammar "\xrightarrow[4r_3]{2r_1, 3r_2}"
# Remember the priority of "{}" is higher than "[]"
tex = matrix + r'\xrightarrow[4r_3]{2r_1, 3r_2}'
latex2sympy(tex)
# => "Matrix([[2, 4, 6], [12, 15, 18], [28, 32, 36]])"
```
### Variances
``` python
from latex2sympy2 import latex2sympy, variances, var, set_variances
# Assign x a value of 1
latex2sympy(r"x = 1")
# Assign x a matrix symbol with dimension of n x m
latex2sympy(r"x \in \mathbb{R}^{n \times m}")
# Calculate x + y
latex2sympy(r"x + y")
# => "y + 1"
# Get all variances
print(variances)
# => "{x: 1}"
# Get variance of "x"
print(var["x"])
# => "1"
# Reset all variances
set_variances({})
latex2sympy(r"x + y")
# => "x + y"
```
### Complex Number Support
``` python
from latex2sympy2 import set_real
set_real(False)
```
## Contributing
If you want to add a new grammar, you can fork the code from [OrangeX4/latex2sympy](https://github.com/OrangeX4/latex2sympy).
* To modify parser grammar, view the existing structure in `PS.g4`.
* To modify the action associated with each grammar, look into `latex2sympy.py`.
Contributors are welcome! Feel free to open a pull request or an issue.
import latex2sympy
\ No newline at end of file
from sympy.printing.str import StrPrinter
from sympy.core import S
class AsciiMathPrinter(StrPrinter):
def _print_Limit(self, expr):
e, z = expr.args
return "lim_(%s -> %s) %s" % (self._print(z), self._print(z), self._print(e))
def _print_Integral(self, expr):
e, lims = expr.args
if len(lims) > 1:
return "int_(%s)^(%s) %s d%s" % (self._print(lims[1]), self._print(lims[2]), self._print(e), self._print(lims[0]))
else:
return "int %s d%s" % (self._print(e), self._print(lims))
def _print_Sum(self, expr):
e, lims = expr.args
return "sum_(%s = %s)^(%s) %s" % (self._print(lims[0]), self._print(lims[1]), self._print(lims[2]), self._print(e))
def _print_Product(self, expr):
e, lims = expr.args
return "prod_(%s = %s)^(%s) %s" % (self._print(lims[0]), self._print(lims[1]), self._print(lims[2]), self._print(e))
def _print_factorial(self, expr):
return "%s!" % self._print(expr.args[0])
def _print_Derivative(self, expr):
e = expr.args[0]
wrt = expr.args[1]
return "d/d%s %s" % (self._print(wrt), self._print(e))
def _print_Abs(self, expr):
return "|%s|" % self._print(expr.args[0])
def _print_Equality(self, expr):
return "%s = %s" % (self._print(expr.args[0]), self._print(expr.args[1]))
def _print_Pow(self, expr):
b = self._print(expr.base)
if expr.exp is S.Half:
return "sqrt(%s)" % b
if -expr.exp is S.Half:
return "1/sqrt(%s)" % b
if expr.exp is -S.One:
return "1/%s" % b
return "%s^(%s)" % (b, self._print(expr.exp))
latex2sympy2: https://github.com/OrangeX4/latex2sympy
About
`latex2sympy2` parses **LaTeX math expressions** and converts it into the equivalent **SymPy form**. The latex2sympy2 is adapted from [augustt198/latex2sympy](https://github.com/augustt198/latex2sympy) and [purdue-tlt / latex2sympy](https://github.com/purdue-tlt/latex2sympy).
[ANTLR](http://www.antlr.org/) is used to generate the parser.
Features
* **Arithmetic:** Add (+), Sub (-), Dot Mul (·), Cross Mul (×), Frac (/), Power (^), Abs (|x|), Sqrt (√), etc...
* **Alphabet:** a - z, A - Z, α - ω, Subscript (x_1), Accent Bar(ā), etc...
* **Common Functions:** gcd, lcm, floor, ceil, max, min, log, ln, exp, sin, cos, tan, csc, sec, cot, arcsin, sinh, arsinh, etc...
* **Calculous:** Limit ($lim_{n\to\infty}$), Derivation ($\frac{d}{dx}(x^2+x)$), Integration ($\int xdx$), etc...
* **Linear Algebra:** Matrix, Determinant, Transpose, Inverse, Elementary Transformation, etc...
* **Other:** Binomial...
**NOTICE:** It will do some irreversible calculations when converting determinants, transposed matrixes and elementary transformations...
Installation
```
pip install latex2sympy2
```
**Requirements:** `sympy` and `antlr4-python3-runtime` packages.
Usage
Basic
In Python:
```python
from latex2sympy2 import latex2sympy, latex2latex
tex = r"\frac{d}{dx}(x^{2}+x)"
# Or you can use '\mathrm{d}' to replace 'd'
latex2sympy(tex)
# => "Derivative(x**2 + x, x)"
latex2latex(tex)
# => "2 x + 1"
```
Examples
|LaTeX|Converted SymPy|Calculated Latex|
|-----|-----|---------------|
|`x^{3}` $x^{3}$| `x**3`|`x^{3}` $x^{3}$|
|`\frac{d}{dx} tx` $\frac{d}{dx}tx$|`Derivative(x*t, x)`|`t` $t$|
|`\sum_{i = 1}^{n} i` $\sum_{i = 1}^{n} i$|`Sum(i, (i, 1, n))`|`\frac{n \left(n + 1\right)}{2}` $\frac{n \left(n + 1\right)}{2}$|
|`\int_{a}^{b} \frac{dt}{t}`|`Integral(1/t, (t, a, b))`|`-\log{(a)} + \log{(b)}` $-\log{(a)} + \log{(b)}$|
|`(2x^3 - x + z)|_{x=3}` $(2x^3 - x + z)\|_{x=3}$|`z + 51`| `z + 51` $z + 51$ |
If you want to read the math formula, you can click [GitNotes](https://notes.orangex4.cool/?git=github&github=OrangeX4/latex2sympy).
Matrix
Determinant
``` python
from latex2sympy2 import latex2sympy
tex = r"\begin{vmatrix} x & 0 & 0 \\ 0 & x & 0 \\ 0 & 0 & x \end{vmatrix}"
latex2sympy(tex)
# => "x^{3}"
```
Transpose
``` python
from latex2sympy2 import latex2sympy
tex = r"\begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{pmatrix}^T"
# Or you can use "\begin{pmatrix}1&2&3\\4&5&6\\7&8&9\end{pmatrix}'"
latex2sympy(tex)
# => "Matrix([[1, 4, 7], [2, 5, 8], [3, 6, 9]])"
```
Elementary Transformation
``` python
from latex2sympy2 import latex2sympy
matrix = r'''
\begin{pmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9 \\
\end{pmatrix}
'''
# Scale the row with grammar "\xrightarrow{kr_n}"
tex = matrix + r'\xrightarrow{3r_1}'
latex2sympy(tex)
# => "Matrix([[3, 6, 9], [4, 5, 6], [7, 8, 9]])"
# Swap the cols with grammar "\xrightarrow{c_1<=>c_2}"
# Of course, you can use "\leftrightarrow" to replace "<=>"
tex = matrix + r'\xrightarrow{c_1<=>c_2}'
latex2sympy(tex)
# => "Matrix([[2, 1, 3], [5, 4, 6], [8, 7, 9]])"
# Scale the second row and add it to the first row
# with grammar "\xrightarrow{r_1+kr_2}"
tex = matrix + r'\xrightarrow{r_1+kr_2}'
latex2sympy(tex)
# => "Matrix([[4*k + 1, 5*k + 2, 6*k + 3], [4, 5, 6], [7, 8, 9]])"
# You can compose the transform with comma ","
# and grammar "\xrightarrow[4r_3]{2r_1, 3r_2}"
# Remember the priority of "{}" is higher than "[]"
tex = matrix + r'\xrightarrow[4r_3]{2r_1, 3r_2}'
latex2sympy(tex)
# => "Matrix([[2, 4, 6], [12, 15, 18], [28, 32, 36]])"
```
Variances
``` python
from latex2sympy2 import latex2sympy, variances, var, set_variances
# Assign x a value of 1
latex2sympy(r"x = 1")
# Calculate x + y
latex2sympy(r"x + y")
# => "y + 1"
# Get all variances
print(variances)
# => "{x: 1}"
# Get variance of "x"
print(var["x"])
# => "1"
# Reset all variances
set_variances({})
latex2sympy(r"x + y")
# => "x + y"
```
Contributing
If you want to add a new grammar, you can fork the code from [OrangeX4/latex2sympy](https://github.com/OrangeX4/latex2sympy).
* To modify parser grammar, view the existing structure in `PS.g4`.
* To modify the action associated with each grammar, look into `latex2sympy.py`.
Contributors are welcome! Feel free to open a pull request or an issue.
-r requirements.txt
# Development
pip-tools
pytest
pytest-cov
pycodestyle
autopep8
-e .
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile dev-requirements.in
#
# via -r dev-requirements.in
antlr4-python3-runtime==4.11.1
# via
# -r requirements.txt
# latex2sympy2
atomicwrites==1.3.0
# via pytest
attrs==19.3.0
# via pytest
autopep8==1.4.4
# via -r dev-requirements.in
click==7.0
# via pip-tools
coverage==4.5.4
# via pytest-cov
more-itertools==7.2.0
# via pytest
mpmath==1.3.0
# via
# -r requirements.txt
# sympy
packaging==19.2
# via pytest
pip-tools==4.2.0
# via -r dev-requirements.in
pluggy==0.13.0
# via pytest
py==1.8.0
# via pytest
pycodestyle==2.5.0
# via
# -r dev-requirements.in
# autopep8
pyparsing==2.4.4
# via packaging
pytest==5.2.2
# via
# -r dev-requirements.in
# pytest-cov
pytest-cov==2.8.1
# via -r dev-requirements.in
six==1.13.0
# via
# packaging
# pip-tools
sympy==1.12
# via
# -r requirements.txt
# latex2sympy2
wcwidth==0.1.7
# via pytest
# THIS MUST BE MAINTAINED AS-IS
-e .
\ No newline at end of file
T__0=1
T__1=2
T__2=3
T__3=4
T__4=5
T__5=6
T__6=7
T__7=8
T__8=9
T__9=10
T__10=11
T__11=12
T__12=13
T__13=14
T__14=15
T__15=16
T__16=17
T__17=18
T__18=19
T__19=20
T__20=21
T__21=22
T__22=23
T__23=24
T__24=25
T__25=26
T__26=27
T__27=28
T__28=29
T__29=30
T__30=31
T__31=32
T__32=33
T__33=34
T__34=35
T__35=36
T__36=37
T__37=38
T__38=39
T__39=40
T__40=41
T__41=42
T__42=43
T__43=44
WS=45
DOLLAR_SIGN=46
ADD=47
SUB=48
MUL=49
DIV=50
L_PAREN=51
R_PAREN=52
L_GROUP=53
R_GROUP=54
L_BRACE=55
R_BRACE=56
L_BRACE_VISUAL=57
R_BRACE_VISUAL=58
L_BRACE_CMD=59
R_BRACE_CMD=60
L_BRACKET=61
R_BRACKET=62
L_BRACK=63
R_BRACK=64
BAR=65
L_VERT=66
R_VERT=67
VERT=68
NORM=69
L_FLOOR=70
R_FLOOR=71
LL_CORNER=72
LR_CORNER=73
L_CEIL=74
R_CEIL=75
UL_CORNER=76
UR_CORNER=77
L_LEFT=78
R_RIGHT=79
ML_LEFT=80
MR_RIGHT=81
FUNC_LIM=82
LIM_APPROACH_SYM=83
FUNC_INT=84
FUNC_SUM=85
FUNC_PROD=86
FUNC_LOG=87
FUNC_LN=88
FUNC_EXP=89
FUNC_SIN=90
FUNC_COS=91
FUNC_TAN=92
FUNC_CSC=93
FUNC_SEC=94
FUNC_COT=95
FUNC_ARCSIN=96
FUNC_ARCCOS=97
FUNC_ARCTAN=98
FUNC_ARCCSC=99
FUNC_ARCSEC=100
FUNC_ARCCOT=101
FUNC_SINH=102
FUNC_COSH=103
FUNC_TANH=104
FUNC_ARSINH=105
FUNC_ARCOSH=106
FUNC_ARTANH=107
FUNC_ARCSINH=108
FUNC_ARCCOSH=109
FUNC_ARCTANH=110
FUNC_ARSINH_NAME=111
FUNC_ARCSINH_NAME=112
FUNC_ARCOSH_NAME=113
FUNC_ARCCOSH_NAME=114
FUNC_ARTANH_NAME=115
FUNC_ARCTANH_NAME=116
FUNC_GCD_NAME=117
FUNC_LCM_NAME=118
FUNC_FLOOR_NAME=119
FUNC_CEIL_NAME=120
FUNC_SQRT=121
FUNC_GCD=122
FUNC_LCM=123
FUNC_FLOOR=124
FUNC_CEIL=125
FUNC_MAX=126
FUNC_MIN=127
FUNC_DET=128
FUNC_EYE_NAME=129
FUNC_ZEROS_NAME=130
FUNC_ONES_NAME=131
FUNC_COLS_NAME=132
FUNC_ROWS_NAME=133
FUNC_DIAG_NAME=134
FUNC_NORM_NAME=135
FUNC_RANK_NAME=136
FUNC_TRACE_NAME=137
FUNC_RREF_NAME=138
FUNC_HSTACK_NAME=139
FUNC_VSTACK_NAME=140
FUNC_ORTHOGONALIZE_NAME=141
FUNC_NULLSPACE_NAME=142
FUNC_DIAGONALIZE_NAME=143
FUNC_EIGENVALS_NAME=144
FUNC_EIGENVECTORS_NAME=145
FUNC_SVD_NAME=146
CMD_TIMES=147
CMD_CDOT=148
CMD_DIV=149
CMD_FRAC=150
CMD_BINOM=151
CMD_CHOOSE=152
CMD_MOD=153
CMD_MATHIT=154
CMD_OPERATORNAME=155
MATRIX_TYPE_MATRIX=156
MATRIX_TYPE_PMATRIX=157
MATRIX_TYPE_BMATRIX=158
MATRIX_TYPE_DET=159
MATRIX_TYPES=160
CMD_MATRIX_START=161
CMD_MATRIX_END=162
CMD_DET_START=163
CMD_DET_END=164
MATRIX_DEL_COL=165
MATRIX_DEL_ROW=166
UNDERSCORE=167
CARET=168
COLON=169
SEMICOLON=170
COMMA=171
PERIOD=172
DIFFERENTIAL=173
EXP_E=174
E_NOTATION_E=175
LETTER_NO_E=176
MATRIX_XRIGHTARROW=177
TRANSFORM_EXCHANGE=178
NUMBER=179
E_NOTATION=180
IN=181
ASSIGNMENT=182
EQUAL=183
LT=184
LTE=185
GT=186
GTE=187
UNEQUAL=188
BANG=189
PERCENT_NUMBER=190
GREEK_CMD=191
OTHER_SYMBOL_CMD=192
SYMBOL=193
VARIABLE=194
'\\acute'=1
'\\bar'=2
'\\overline'=3
'\\breve'=4
'\\check'=5
'\\widecheck'=6
'\\dot'=7
'\\ddot'=8
'\\grave'=9
'\\hat'=10
'\\tilde'=11
'\\widetilde'=12
'\\vec'=13
'\\overrightarrow'=14
'\\bm'=15
'\\boldsymbol'=16
'\\text'=17
'\\textit'=18
'\\mathbb'=19
'\\mathbin'=20
'\\mathbf'=21
'\\mathcal'=22
'\\mathclap'=23
'\\mathclose'=24
'\\mathellipsis'=25
'\\mathfrak'=26
'\\mathinner'=27
'\\mathnormal'=28
'\\mathop'=29
'\\mathopen'=30
'\\mathord'=31
'\\mathpunct'=32
'\\mathrel'=33
'\\mathring'=34
'\\mathrlap'=35
'\\mathrm'=36
'\\mathscr'=37
'\\mathsf'=38
'\\mathsterling'=39
'\\mathtt'=40
'^T'=41
'^{T}'=42
'^{\\top}'=43
'\''=44
'\\$'=46
'+'=47
'-'=48
'*'=49
'('=51
')'=52
'\\lgroup'=53
'\\rgroup'=54
'{'=55
'}'=56
'\\{'=57
'\\}'=58
'\\lbrace'=59
'\\rbrace'=60
'['=61
']'=62
'\\lbrack'=63
'\\rbrack'=64
'|'=65
'\\lvert'=66
'\\rvert'=67
'\\vert'=68
'\\|'=69
'\\lfloor'=70
'\\rfloor'=71
'\\llcorner'=72
'\\lrcorner'=73
'\\lceil'=74
'\\rceil'=75
'\\ulcorner'=76
'\\urcorner'=77
'\\left'=78
'\\right'=79
'\\mleft'=80
'\\mright'=81
'\\lim'=82
'\\int'=84
'\\sum'=85
'\\prod'=86
'\\log'=87
'\\ln'=88
'\\exp'=89
'\\sin'=90
'\\cos'=91
'\\tan'=92
'\\csc'=93
'\\sec'=94
'\\cot'=95
'\\arcsin'=96
'\\arccos'=97
'\\arctan'=98
'\\arccsc'=99
'\\arcsec'=100
'\\arccot'=101
'\\sinh'=102
'\\cosh'=103
'\\tanh'=104
'\\arsinh'=105
'\\arcosh'=106
'\\artanh'=107
'\\arcsinh'=108
'\\arccosh'=109
'\\arctanh'=110
'arsinh'=111
'arcsinh'=112
'arcosh'=113
'arccosh'=114
'artanh'=115
'arctanh'=116
'gcd'=117
'lcm'=118
'floor'=119
'ceil'=120
'\\sqrt'=121
'\\gcd'=122
'\\lcm'=123
'\\floor'=124
'\\ceil'=125
'\\max'=126
'\\min'=127
'\\det'=128
'eye'=129
'zeros'=130
'ones'=131
'cols'=132
'rows'=133
'diag'=134
'norm'=135
'rank'=136
'rref'=138
'hstack'=139
'vstack'=140
'nullspace'=142
'\\times'=147
'\\cdot'=148
'\\div'=149
'\\frac'=150
'\\choose'=152
'\\mod'=153
'\\mathit'=154
'\\operatorname'=155
'matrix'=156
'pmatrix'=157
'bmatrix'=158
'vmatrix'=159
'&'=165
'\\\\'=166
'_'=167
'^'=168
':'=169
';'=170
','=171
'.'=172
'E'=175
'\\in'=181
'='=182
'<'=184
'>'=186
'!'=189
This source diff could not be displayed because it is too large. You can view the blob instead.
T__0=1
T__1=2
T__2=3
T__3=4
T__4=5
T__5=6
T__6=7
T__7=8
T__8=9
T__9=10
T__10=11
T__11=12
T__12=13
T__13=14
T__14=15
T__15=16
T__16=17
T__17=18
T__18=19
T__19=20
T__20=21
T__21=22
T__22=23
T__23=24
T__24=25
T__25=26
T__26=27
T__27=28
T__28=29
T__29=30
T__30=31
T__31=32
T__32=33
T__33=34
T__34=35
T__35=36
T__36=37
T__37=38
T__38=39
T__39=40
T__40=41
T__41=42
T__42=43
T__43=44
WS=45
DOLLAR_SIGN=46
ADD=47
SUB=48
MUL=49
DIV=50
L_PAREN=51
R_PAREN=52
L_GROUP=53
R_GROUP=54
L_BRACE=55
R_BRACE=56
L_BRACE_VISUAL=57
R_BRACE_VISUAL=58
L_BRACE_CMD=59
R_BRACE_CMD=60
L_BRACKET=61
R_BRACKET=62
L_BRACK=63
R_BRACK=64
BAR=65
L_VERT=66
R_VERT=67
VERT=68
NORM=69
L_FLOOR=70
R_FLOOR=71
LL_CORNER=72
LR_CORNER=73
L_CEIL=74
R_CEIL=75
UL_CORNER=76
UR_CORNER=77
L_LEFT=78
R_RIGHT=79
ML_LEFT=80
MR_RIGHT=81
FUNC_LIM=82
LIM_APPROACH_SYM=83
FUNC_INT=84
FUNC_SUM=85
FUNC_PROD=86
FUNC_LOG=87
FUNC_LN=88
FUNC_EXP=89
FUNC_SIN=90
FUNC_COS=91
FUNC_TAN=92
FUNC_CSC=93
FUNC_SEC=94
FUNC_COT=95
FUNC_ARCSIN=96
FUNC_ARCCOS=97
FUNC_ARCTAN=98
FUNC_ARCCSC=99
FUNC_ARCSEC=100
FUNC_ARCCOT=101
FUNC_SINH=102
FUNC_COSH=103
FUNC_TANH=104
FUNC_ARSINH=105
FUNC_ARCOSH=106
FUNC_ARTANH=107
FUNC_ARCSINH=108
FUNC_ARCCOSH=109
FUNC_ARCTANH=110
FUNC_ARSINH_NAME=111
FUNC_ARCSINH_NAME=112
FUNC_ARCOSH_NAME=113
FUNC_ARCCOSH_NAME=114
FUNC_ARTANH_NAME=115
FUNC_ARCTANH_NAME=116
FUNC_GCD_NAME=117
FUNC_LCM_NAME=118
FUNC_FLOOR_NAME=119
FUNC_CEIL_NAME=120
FUNC_SQRT=121
FUNC_GCD=122
FUNC_LCM=123
FUNC_FLOOR=124
FUNC_CEIL=125
FUNC_MAX=126
FUNC_MIN=127
FUNC_DET=128
FUNC_EYE_NAME=129
FUNC_ZEROS_NAME=130
FUNC_ONES_NAME=131
FUNC_COLS_NAME=132
FUNC_ROWS_NAME=133
FUNC_DIAG_NAME=134
FUNC_NORM_NAME=135
FUNC_RANK_NAME=136
FUNC_TRACE_NAME=137
FUNC_RREF_NAME=138
FUNC_HSTACK_NAME=139
FUNC_VSTACK_NAME=140
FUNC_ORTHOGONALIZE_NAME=141
FUNC_NULLSPACE_NAME=142
FUNC_DIAGONALIZE_NAME=143
FUNC_EIGENVALS_NAME=144
FUNC_EIGENVECTORS_NAME=145
FUNC_SVD_NAME=146
CMD_TIMES=147
CMD_CDOT=148
CMD_DIV=149
CMD_FRAC=150
CMD_BINOM=151
CMD_CHOOSE=152
CMD_MOD=153
CMD_MATHIT=154
CMD_OPERATORNAME=155
MATRIX_TYPE_MATRIX=156
MATRIX_TYPE_PMATRIX=157
MATRIX_TYPE_BMATRIX=158
MATRIX_TYPE_DET=159
MATRIX_TYPES=160
CMD_MATRIX_START=161
CMD_MATRIX_END=162
CMD_DET_START=163
CMD_DET_END=164
MATRIX_DEL_COL=165
MATRIX_DEL_ROW=166
UNDERSCORE=167
CARET=168
COLON=169
SEMICOLON=170
COMMA=171
PERIOD=172
DIFFERENTIAL=173
EXP_E=174
E_NOTATION_E=175
LETTER_NO_E=176
MATRIX_XRIGHTARROW=177
TRANSFORM_EXCHANGE=178
NUMBER=179
E_NOTATION=180
IN=181
ASSIGNMENT=182
EQUAL=183
LT=184
LTE=185
GT=186
GTE=187
UNEQUAL=188
BANG=189
PERCENT_NUMBER=190
GREEK_CMD=191
OTHER_SYMBOL_CMD=192
SYMBOL=193
VARIABLE=194
'\\acute'=1
'\\bar'=2
'\\overline'=3
'\\breve'=4
'\\check'=5
'\\widecheck'=6
'\\dot'=7
'\\ddot'=8
'\\grave'=9
'\\hat'=10
'\\tilde'=11
'\\widetilde'=12
'\\vec'=13
'\\overrightarrow'=14
'\\bm'=15
'\\boldsymbol'=16
'\\text'=17
'\\textit'=18
'\\mathbb'=19
'\\mathbin'=20
'\\mathbf'=21
'\\mathcal'=22
'\\mathclap'=23
'\\mathclose'=24
'\\mathellipsis'=25
'\\mathfrak'=26
'\\mathinner'=27
'\\mathnormal'=28
'\\mathop'=29
'\\mathopen'=30
'\\mathord'=31
'\\mathpunct'=32
'\\mathrel'=33
'\\mathring'=34
'\\mathrlap'=35
'\\mathrm'=36
'\\mathscr'=37
'\\mathsf'=38
'\\mathsterling'=39
'\\mathtt'=40
'^T'=41
'^{T}'=42
'^{\\top}'=43
'\''=44
'\\$'=46
'+'=47
'-'=48
'*'=49
'('=51
')'=52
'\\lgroup'=53
'\\rgroup'=54
'{'=55
'}'=56
'\\{'=57
'\\}'=58
'\\lbrace'=59
'\\rbrace'=60
'['=61
']'=62
'\\lbrack'=63
'\\rbrack'=64
'|'=65
'\\lvert'=66
'\\rvert'=67
'\\vert'=68
'\\|'=69
'\\lfloor'=70
'\\rfloor'=71
'\\llcorner'=72
'\\lrcorner'=73
'\\lceil'=74
'\\rceil'=75
'\\ulcorner'=76
'\\urcorner'=77
'\\left'=78
'\\right'=79
'\\mleft'=80
'\\mright'=81
'\\lim'=82
'\\int'=84
'\\sum'=85
'\\prod'=86
'\\log'=87
'\\ln'=88
'\\exp'=89
'\\sin'=90
'\\cos'=91
'\\tan'=92
'\\csc'=93
'\\sec'=94
'\\cot'=95
'\\arcsin'=96
'\\arccos'=97
'\\arctan'=98
'\\arccsc'=99
'\\arcsec'=100
'\\arccot'=101
'\\sinh'=102
'\\cosh'=103
'\\tanh'=104
'\\arsinh'=105
'\\arcosh'=106
'\\artanh'=107
'\\arcsinh'=108
'\\arccosh'=109
'\\arctanh'=110
'arsinh'=111
'arcsinh'=112
'arcosh'=113
'arccosh'=114
'artanh'=115
'arctanh'=116
'gcd'=117
'lcm'=118
'floor'=119
'ceil'=120
'\\sqrt'=121
'\\gcd'=122
'\\lcm'=123
'\\floor'=124
'\\ceil'=125
'\\max'=126
'\\min'=127
'\\det'=128
'eye'=129
'zeros'=130
'ones'=131
'cols'=132
'rows'=133
'diag'=134
'norm'=135
'rank'=136
'rref'=138
'hstack'=139
'vstack'=140
'nullspace'=142
'\\times'=147
'\\cdot'=148
'\\div'=149
'\\frac'=150
'\\choose'=152
'\\mod'=153
'\\mathit'=154
'\\operatorname'=155
'matrix'=156
'pmatrix'=157
'bmatrix'=158
'vmatrix'=159
'&'=165
'\\\\'=166
'_'=167
'^'=168
':'=169
';'=170
','=171
'.'=172
'E'=175
'\\in'=181
'='=182
'<'=184
'>'=186
'!'=189
This source diff could not be displayed because it is too large. You can view the blob instead.
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile requirements.in
#
antlr4-python3-runtime==4.11.1
# via -r requirements.in
mpmath==1.3.0
# via sympy
sympy==1.12
# via -r requirements.in
from latex2sympy import process_sympy
import sys
sys.path.append("..")
# latex = "2\\begin{pmatrix}1&1&1\\\\0&1&1\\\\0&0&1\\end{pmatrix}\\begin{pmatrix}1&1&1\\\\0&1&1\\\\0&0&1\\end{pmatrix}"
latex = "\\frac{a^{2} \\left(3 \\pi - 4 \\sin{\\left(\\pi \\right)} + \\frac{\\sin{\\left(2 \\pi \\right)}}{2}\\right)}{2}"
math = process_sympy(latex)
print(type(math))
print("latex: %s to math: %s" % (latex, math))
from latex2sympy import process_sympy
import sys
sys.path.append("..")
latex = "\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "[\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}]"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "\\left\\{\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}\\right\\}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
from latex2sympy import process_sympy
from sympy import *
import sys
sys.path.append("..")
theta = Symbol('theta', real=True)
latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "\\begin{matrix}1&2\\\\3&4\\\\5&6\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "\\begin{matrix}1&2&3\\\\4&5&6\\\\7&8&9\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "\\begin{matrix}x^1&x^2&x^3\\\\y^1&y^2&y^3\\\\z^1&z^2&z^3\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "\\begin{matrix}x\\\\y\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "2\\cdot\\begin{matrix}x\\\\y\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "2\\cdot\\begin{matrix}x\\\\y\\end{matrix} + \\begin{matrix}2\\\\3\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "-2\\begin{matrix}1&2\\\\3&4\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "2\\cdot\\theta\\begin{matrix}x\\\\y\\end{matrix} + \\begin{matrix}2\\\\3\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
latex = "\\theta\\begin{matrix}1\\\\3\\end{matrix} - \\begin{matrix}-1\\\\2\\end{matrix}"
math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, math))
from latex2sympy import process_sympy
from sympy import *
import sys
import hashlib
import time
sys.path.append("..")
M = Matrix([[1, 2], [3, 4]])
v = Matrix([1, 2])
# sub settings
sub_settings_symbols = {}
sub_settings_symbols[Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), real=True)] = M
sub_settings_symbols[Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), real=True)] = v
# one parameters
latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}\\cdot[!v!]"
equation_sympy_check = MatMul(M, Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), real=True))
equation_sympy_subs_check = MatMul(M, v)
# placeholders
equation_sympy = process_sympy(latex)
print('latex = %s' % latex)
print('equation_sympy = %s' % equation_sympy)
print('equation_sympy_check = %s' % equation_sympy_check)
print('equation_sympy = %s' % (srepr(equation_sympy)))
equation_sympy_subs = equation_sympy.subs(sub_settings_symbols, evaluate=False)
print('equation_sympy_subs = %s' % equation_sympy_subs)
print('equation_sympy_subs_check = %s' % equation_sympy_subs_check)
# two parameters
# sub settings
print('')
print('============== Two Parameters -> M*v = Matrix*Vector =============')
sub_settings_symbols = {}
sub_settings_symbols[Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), commutative=False)] = M
sub_settings_symbols[Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), commutative=False)] = v
latex = "[!M!]\\cdot[!v!]"
math_check = Mul(Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), commutative=False), Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), commutative=False))
# placeholders
equation_sympy = process_sympy(latex)
print(latex)
print(math_check)
print(equation_sympy)
print(srepr(equation_sympy))
# performance
t0 = time.time()
# process_sympy and substitute at the same time
# Only needed for linalg input
placeholder_values = {'M': M, 'v': v}
equation_sympy_subs = process_sympy(latex, variable_values=placeholder_values)
t1 = time.time()
print('equation with substituted placeholders = %s' % (str(equation_sympy_subs)))
print('time to process to sympy with placeholders = %s s' % (t1 - t0))
print('')
print('============== Two Parameters -> M*v = Matrix*Vector =============')
from sympy import *
from latex2sympy import process_sympy
# latex = '\\variable{a}^{\\variable{b}}'
# variables = {'a': process_sympy('658.95998'), 'b': process_sympy('185083.8060')}
# c_ans_expr = process_sympy(latex, variables)
# print(c_ans_expr)
# print(srepr(c_ans_expr))
# c_ans = c_ans_expr.doit(deep=False).evalf(chop=True)
# print(c_ans)
# print(srepr(c_ans))
# numeric_responses = ['1', '1.0', '-1', '-1.0', '.5', '-.5', '3x10^3', '3E3', '3,000x10^{-3}', '0.5E-1', '\\frac{1}{3}', '(5\\times 3)^3', '\\sin(1)']
# for latex in numeric_responses:
# parsed = process_sympy(latex)
# print('latex: ', latex)
# print('sympy: ', parsed)
# print('is_number: ', parsed.is_number)
# print('is_Number: ', parsed.is_Number)
# print('srepr: ', srepr(parsed))
# print('-----------------------------------------------------')
from sympy import *
from latex2sympy import process_sympy
#
# Equality Testing
#
answer_sets = [
{
'correct_answer': '(x-y)(x+2y)',
'student_answers': [
'x^2+xy-2y^2',
'(x-y)(x+2y)',
'(x+2y)(x-y)',
'(2\\times y+x)(-y+x)',
'(y\\cdot 2+x)(-y+x)'
]
},
{
'correct_answer': '2\\pi \\variable{r}^2',
'student_answers': [
'2\\pi \\variable{r}^2',
'\\pi 2\\variable{r}^2',
'2\\times \\pi \\times \\variable{r}^2',
'2\\pi \\variable{r} \\times \\variable{r}'
]
},
{
'correct_answer': '2x - 3y',
'student_answers': [
'-3y + 2x'
]
},
{
'correct_answer': 'x\\times x',
'student_answers': [
'x\\times x',
'x\\cdot x',
'x^2',
'(\\sqrt{x})^{4}'
]
},
{
'correct_answer': '23e^{-1\\times \\sqrt{t^2}}',
'student_answers': [
'23e^{-t}'
]
},
{
'correct_answer': 'a=x^2+1',
'student_answers': [
'x^2+1=a'
]
}
]
for answer_set in answer_sets:
correct_answer = answer_set['correct_answer']
correct_answer_parsed = process_sympy(answer_set['correct_answer'])
for student_answer in answer_set['student_answers']:
student_answer_parsed = process_sympy(student_answer)
print('correct_answer (c): ', correct_answer, correct_answer_parsed)
print('student_answer (a): ', student_answer, student_answer_parsed)
print('')
print('Expression Tree (srepr(c) == srepr(a)) =>', srepr(correct_answer_parsed) == srepr(student_answer_parsed))
print('srepr(c) =>', srepr(correct_answer_parsed))
print('srepr(a) =>', srepr(student_answer_parsed))
print('')
# print('Structural (c == a) =>', correct_answer_parsed == student_answer_parsed)
print('Symbolic (simplify(c - s) == 0) =>', simplify(correct_answer_parsed - student_answer_parsed) == 0)
print('simplified =>', simplify(correct_answer_parsed - student_answer_parsed))
print('')
print('Numeric Substitution (c.equals(s)) =>', correct_answer_parsed.equals(student_answer_parsed))
print('-----------------------------------------------------')
from sympy import *
import sys
sys.path.append("..")
# # x^2\cdot \left(3\cdot \tan \left([!a!]\cdot x+[!c!]\right)+[!a!]\cdot x\left(\sec \left([!a!]\cdot x+[!c!]\right)\right)^2\right)
# latex1 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(\\sec \\left(2\\cdot x+5\\right)\\right)^2\\right)"
# math1 = process_sympy(latex1)
# print("latex: %s to math: %s" %(latex1,math1))
#
# latex2 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(\\sec \\left(2\\cdot x+5\\right)^2\\right)\\right)"
# math2 = process_sympy(latex2)
# print("latex: %s to math: %s" %(latex2,math2))
#
# latex3 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(1+\\tan \\left(2\\cdot x+5\\right)^2\\right)\\right)"
# math3 = process_sympy(latex3)
# print("latex: %s to math: %s" %(latex3,math3))
#
# print(simplify(math1 - math2))
# print(simplify(math1 - math3))
#
# latex1 = "\\sec^2(2\\cdot x+5)"
# math1 = process_sympy(latex1)
# print("latex: %s to math: %s" %(latex1,math1))
#
# latex2 = "1+\\tan^2(2\\cdot x+5)"
# math2 = process_sympy(latex2)
# print("latex: %s to math: %s" %(latex2,math2))
# print(simplify(math1 - math2))
x = Symbol('x', real=True)
y = Symbol('y', real=True)
# BUG: 1 + tan^2(x+1) should be == sec^2(x+1) but isnt
lhs = (1 + (tan(x + 1))**2)
rhs = (sec(x + 1))**2
eq = lhs - rhs
print(simplify(lhs))
print(simplify(rhs))
print(simplify(eq))
print(simplify(lhs) == simplify(rhs))
# 1 + tan^2(x) == sec^2(x) but isnt
lhs = (1 + (tan(x))**2)
rhs = (sec(x))**2
eq = lhs - rhs
print(simplify(lhs))
print(simplify(rhs))
print(simplify(eq))
print(simplify(lhs) == simplify(rhs))
import numpy as np
from sympy import *
import sys
sys.path.append("..")
# row column matrix = vector
v = [1, 2, 3]
# single column matrix = vector
m = Matrix([1, 2, 3])
print(m[:, 0])
# a three row and 2 column matrix
m = Matrix([[1, 2], [3, 4], [5, 6]])
print(m[:, 0])
# determinant of lin indp system != 0
m = Matrix([[1, 1], [1, 2]])
print(m.det())
# determinant of lin dep system = 0
m = Matrix([[1, 1], [2, 2]])
print(m.det())
# determinant of lin dep system = 0
x = Symbol('x')
y = Symbol('y')
m = Matrix([[x, y], [x, y]])
print(m.det())
# Reduced Row-Echelon Form
_, ind = m.rref()
print(len(ind))
# determinant of lin dep system != 0
m = Matrix([[x, y], [y, x]])
print(m.det())
# Reduced Row-Echelon Form
_, ind = m.rref()
print(len(ind))
# determinant of lin dep system != 0
# Reduced Row-Echelon Form
m = Matrix([[x, x, y], [y, y, y]])
_, ind = m.rref()
# Reduced Row-Echelon Form
print(len(ind))
#==================#
#===== Numpy ======#
#==================#
# http://kitchingroup.cheme.cmu.edu/blog/2013/03/01/Determining-linear-independence-of-a-set-of-vectors/
# Lin Indp of set of numerical vectors
TOLERANCE = 1e-14
v1 = [6, 0, 3, 1, 4, 2]
v2 = [0, -1, 2, 7, 0, 5]
v3 = [12, 3, 0, -19, 8, -11]
A = np.row_stack([v1, v2, v3])
U, s, V = np.linalg.svd(A)
print(s)
print(np.sum(s > TOLERANCE))
v1 = [1, 1]
v2 = [4, 4]
A = np.row_stack([v1, v2])
U, s, V = np.linalg.svd(A)
print(s)
print(np.sum(s > TOLERANCE))
latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}"
# math = process_sympy(latex)
print("latex: %s to math: %s" % (latex, 1))
#!/bin/sh
# Get relative path of the root directory of the project
rdir=`git rev-parse --git-dir`
rel_path="$(dirname "$rdir")"
# Change to that path and run the file
cd $rel_path
java -jar antlr-4.11.1-complete.jar PS.g4 -o gen
#!/bin/sh
pytest --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-config=.coveragerc --cov=latex2sympy tests
\ No newline at end of file
#!/bin/sh
# Get relative path of the root directory of the project
rdir=`git rev-parse --git-dir`
rel_path="$(dirname "$rdir")"
# Change to that path and run the file
cd $rel_path
# Activate virtual environment
echo "activating venv..."
if test -f .env/bin/activate
then source .env/bin/activate && echo "venv activate (bin)"
elif test -f .env/Scripts/activate
then source .env/Scripts/activate && echo "venv activated (Scripts)"
else exit 1
fi
# Run unit test coverage
echo "starting coverage..."
if pytest --doctest-modules --cov-report=html --cov-config=.coveragerc --cov=latex2sympy tests
then echo "coverage finished"
else exit 1
fi
#!/bin/sh
# Get relative path of the root directory of the project
rdir=`git rev-parse --git-dir`
rel_path="$(dirname "$rdir")"
# Change to that path and run the file
cd $rel_path
echo "pre-commit hook started..."
# Activate virtual environment
echo "activating venv..."
if test -f .env/bin/activate
then source .env/bin/activate && echo "venv activated."
elif test -f .env/Scripts/activate
then source .env/Scripts/activate && echo "venv activated."
else exit 1
fi
# Run auto formatting on all staged python files, then add those changes
echo "auto-formatting code..."
if autopep8 --in-place `git diff --name-status --cached | grep '.py' | awk 'match($1, "A|M"){print $2}'` && git add `git diff --name-status --cached | grep '.py' | awk 'match($1, "A|M"){print $2}'`
then echo "code was auto-formatted."
else echo "no code was auto-formatted."
fi
exit 0
#!/bin/sh
# Get relative path of the root directory of the project
rdir=`git rev-parse --git-dir`
rel_path="$(dirname "$rdir")"
# Change to that path and run the file
cd $rel_path
echo "pre-push hook started..."
# Activate virtual environment
echo "activating venv..."
if test -f .env/bin/activate
then source .env/bin/activate && echo "venv activated."
elif test -f .env/Scripts/activate
then source .env/Scripts/activate && echo "venv activated."
else exit 1
fi
# Run unit tests
echo "starting tests..."
# if pytest tests
# then echo "tests finished."
# else exit 1
# fi
exit 0
rm ./dist/*
python3 setup.py bdist_wheel
twine upload dist/*
#!/bin/sh
cp scripts/pre-push .git/hooks/
cp scripts/pre-commit .git/hooks/
\ No newline at end of file
#!/bin/sh
# Get relative path of the root directory of the project
rdir=`git rev-parse --git-dir`
rel_path="$(dirname "$rdir")"
# Change to that path and run the file
cd $rel_path
echo "creating venv..."
if test -d .env
then echo "venv exists"
else python3 -m venv .env && echo "venv created"
fi
echo ''
# Activate virtual environment
echo "activating venv..."
if test -f .env/bin/activate
then source .env/bin/activate && echo "venv activate (bin)"
elif test -f .env/Scripts/activate
then source .env/Scripts/activate && echo "venv activated (Scripts)"
else exit 1
fi
echo ''
echo "installing requirements..."
if pip install -r dev-requirements.txt
then echo "requirements installed"
else exit 1
fi
echo ''
echo "compiling parser..."
sh scripts/compile.sh
echo "parser compiled"
echo ''
echo "setup git hooks..."
sh scripts/setup-hooks.sh
echo "git hooks setup"
exit 0
#!/bin/sh
# Get relative path of the root directory of the project
rdir=`git rev-parse --git-dir`
rel_path="$(dirname "$rdir")"
# Change to that path and run the file
cd $rel_path
# Activate virtual environment
echo "activating venv..."
if test -f .env/bin/activate
then source .env/bin/activate && echo "venv activate (bin)"
elif test -f .env/Scripts/activate
then source .env/Scripts/activate && echo "venv activated (Scripts)"
else exit 1
fi
echo ''
echo "compiling parser..."
sh scripts/compile.sh
echo "parser compiled"
echo ''
# Run unit tests
echo "starting tests..."
if pytest tests
then echo "tests finished"
else exit 1
fi
exit 0
[pycodestyle]
max-line-length = 120
ignore = E501
from setuptools import setup, find_packages
from codecs import open
from os import path
here = path.abspath(path.dirname(__file__))
setup(
name="latex2sympy2",
version="1.9.0",
description='Convert latex to sympy with ANTLR and support Matrix, Linear Algebra and CAS functions.',
long_description_content_type='text/markdown',
long_description=open(path.join(here, "README.md"), encoding='utf-8').read(),
# The project's main homepage.
url='https://github.com/ZubinGou/latex2sympy',
# Author details
author='ZubinGou',
author_email='zebgou@gmail.com',
# Choose your license
license='MIT',
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Topic :: Education',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Software Development :: Compilers',
'Topic :: Text Processing :: Markup :: LaTeX',
'Topic :: Text Processing :: Markup :: Markdown',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
packages=find_packages(exclude=('tests')),
py_modules=['asciimath_printer', 'latex2sympy2'],
install_requires=[
'sympy>=1.4',
'antlr4-python3-runtime==4.11.1'
],
)
from .context import assert_equal, get_simple_examples
import pytest
from sympy import Abs
examples = get_simple_examples(Abs)
delimiter_pairs = {
'|': '|',
'\\vert': '\\vert',
'\\lvert': '\\rvert'
}
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_abs(input, output, symbolically):
for left, right in delimiter_pairs.items():
assert_equal("{left}{input}{right}".format(left=left, right=right, input=input), output, symbolically=symbolically)
assert_equal("\\left{left}{input}\\right{right}".format(left=left, right=right, input=input), output, symbolically=symbolically)
assert_equal("\\mleft{left}{input}\\mright{right}".format(left=left, right=right, input=input), output, symbolically=symbolically)
from .context import assert_equal, process_sympy
import pytest
def pytest_generate_tests(metafunc):
metafunc.parametrize('s', metafunc.cls.BAD_STRINGS)
class TestAllBad(object):
# These bad latex strings should raise an exception when parsed
BAD_STRINGS = [
"(",
")",
# "a / b /",
"\\frac{d}{dx}",
"(\\frac{d}{dx})"
"\\sqrt{}",
"\\sqrt",
"{",
"}",
# "1.1.1",
"\\mathit{TEST}"
"\\frac{2}{}",
"\\frac{}{2}",
"\\int",
# "1 +",
# "a +",
"!",
"!0",
"_",
"^",
# "a // b",
# "a \\cdot \\cdot b",
# "a \\div \\div b",
"a\\mod \\begin{matrix}b\\end{matrix}"
"|",
"||x|",
"\\lfloor x",
"\\lfloor a \\rceil",
"\\operatorname{floor}(12.3, 123.4)",
"()",
"((((((((((((((((()))))))))))))))))",
"-",
"\\frac{d}{dx} + \\frac{d}{dt}",
# "f()",
# "f(,",
# "f(x,,y)",
# "f(x,y,",
"\\sin^x",
"\\cos^2",
# "\\cos 1 \\cos",
# "\\gcd(3)",
# "\\lcm(2)",
"@", "#", "$", "%", "&", "*",
"\\",
"~",
"\\frac{(2 + x}{1 - x)}",
"\\lim_{\\pi \\to 3} a",
# because mix of COMMA and SEMICOLON
"\\left\\{\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix};\\begin{pmatrix}1\\\\1\\\\1\\end{pmatrix}\\right\\}",
# percentages without numbers before-hand
"a\\%",
"\\%100",
# dollar signs without numbers after
"\\$"
]
def test_bad_string(self, s):
with pytest.raises(Exception):
process_sympy(s)
from .context import assert_equal
import pytest
from sympy import Symbol, Integer, Pow
# label, text, symbol_text
symbols = [
('letter', 'x', 'x'),
('greek letter', '\\lambda', 'lambda'),
('greek letter w/ space', '\\alpha ', 'alpha'),
('accented letter', '\\overline{x}', 'xbar')
]
subscripts = [
('2'),
('{23}'),
('i'),
('{ij}'),
('{i,j}'),
('{good}'),
('{x^2}')
]
examples = []
for symbol in symbols:
for subscript in subscripts:
examples.append(tuple(list(symbol) + [subscript]))
@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_supexpr(label, text, symbol_text, subscript):
assert_equal(text + '^2', Pow(Symbol(symbol_text, real=True), Integer(2)))
@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_subexpr(label, text, symbol_text, subscript):
assert_equal(text + '_' + subscript, Symbol(symbol_text + '_' + subscript, real=True))
@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_subexpr_before_supexpr(label, text, symbol_text, subscript):
assert_equal(text + '_' + subscript + '^2', Pow(Symbol(symbol_text + '_' + subscript, real=True), Integer(2)))
@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_subexpr_before_supexpr_with_braces(label, text, symbol_text, subscript):
wrapped_subscript = subscript if '{' in subscript else '{' + subscript + '}'
assert_equal(text + '_' + wrapped_subscript + '^{2}', Pow(Symbol(symbol_text + '_' + subscript, real=True), Integer(2)))
@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_supexpr_before_subexpr(label, text, symbol_text, subscript):
assert_equal(text + '^2_' + subscript, Pow(Symbol(symbol_text + '_' + subscript, real=True), Integer(2)))
@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_supexpr_before_subexpr_with_braces(label, text, symbol_text, subscript):
wrapped_subscript = subscript if '{' in subscript else '{' + subscript + '}'
assert_equal(text + '^{2}_' + wrapped_subscript, Pow(Symbol(symbol_text + '_' + subscript, real=True), Integer(2)))
from .context import assert_equal, _Add, _Mul, _Pow
import pytest
from sympy import binomial, Symbol
x = Symbol('x', real=True)
y = Symbol('y', real=True)
theta = Symbol('theta', real=True)
gamma = Symbol('gamma', real=True)
def test_binomial_numeric():
assert_equal("\\binom{16}{2}", binomial(16, 2))
def test_binomial_symbols():
assert_equal("\\binom{x}{y}", binomial(x, y))
def test_binomial_greek_symbols():
assert_equal("\\binom{\\theta}{\\gamma}", binomial(theta, gamma))
def test_binomial_expr():
assert_equal("\\binom{16+2}{\\frac{4}{2}}", binomial(_Add(16, 2), _Mul(4, _Pow(2, -1)), evaluate=False))
def test_choose_numeric():
assert_equal("\\choose{16}{2}", binomial(16, 2))
def test_choose_symbols():
assert_equal("\\choose{x}{y}", binomial(x, y))
def test_choose_greek_symbols():
assert_equal("\\choose{\\theta}{\\gamma}", binomial(theta, gamma))
from .context import assert_equal, get_simple_examples
import pytest
from sympy import ceiling
examples = get_simple_examples(ceiling)
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_ceil_func(input, output, symbolically):
assert_equal("\\ceil({input})".format(input=input), output, symbolically=symbolically)
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_ceil_operatorname(input, output, symbolically):
assert_equal("\\operatorname{{ceil}}({input})".format(input=input), output, symbolically=symbolically)
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_ceil_cmd(input, output, symbolically):
assert_equal("\\lceil {input}\\rceil".format(input=input), output, symbolically=symbolically)
assert_equal("\\left\\lceil {input}\\right\\rceil".format(input=input), output, symbolically=symbolically)
assert_equal("\\mleft\\lceil {input}\\mright\\rceil".format(input=input), output, symbolically=symbolically)
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_ceil_corners(input, output, symbolically):
assert_equal("\\ulcorner {input}\\urcorner".format(input=input), output, symbolically=symbolically)
assert_equal("\\left\\ulcorner {input}\\right\\urcorner".format(input=input), output, symbolically=symbolically)
assert_equal("\\mleft\\ulcorner {input}\\mright\\urcorner".format(input=input), output, symbolically=symbolically)
from .context import assert_equal
import pytest
from sympy import Sum, I, Symbol, Integer
a = Symbol('a', real=True)
b = Symbol('b', real=True)
i = Symbol('i', real=True)
n = Symbol('n', real=True)
x = Symbol('x', real=True)
def test_complex():
assert_equal("a+Ib", a + I * b)
def test_complex_e():
assert_equal("e^{I\\pi}", Integer(-1))
def test_complex_sum():
assert_equal("\\sum_{i=0}^{n} i \\cdot x", Sum(i * x, (i, 0, n)))
from sympy import simplify, srepr, Add, Mul, Pow, Rational, pi, sqrt, Symbol
from latex2sympy.latex2sympy2 import latex2sympy as process_sympy
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
x = Symbol('x', real=True)
y = Symbol('y', real=True)
# shorthand definitions
def _Add(a, b):
return Add(a, b, evaluate=False)
def _Mul(a, b):
return Mul(a, b, evaluate=False)
def _Pow(a, b):
return Pow(a, b, evaluate=False)
def get_simple_examples(func):
'''
Returns an array of tuples, containing the string `input`, sympy `output` using the provided sympy `func`, and `symbolically` boolean
for calling `compare`.
'''
return [
("1.1", func(1.1), False),
("6.9", func(6.9), False),
("3.5", func(3.5), False),
("8", func(8), False),
("0", func(0), False),
("290348E32", func(Rational('290348E32')), False),
("1237.293894239480234", func(Rational('1237.293894239480234')), False),
("8623.4592104E-2", func(Rational('8623.4592104E-2')), False),
("\\pi ", func(pi), False),
("\\sqrt{100}", func(sqrt(100)), False),
("12,123.4", func(Rational('12123.4')), False),
("-9.4", func(-9.4), False),
("-35.9825", func(-35.9825), False),
("-\\sqrt{5}", func(-sqrt(5)), False),
("-324E-3", func(Rational('-324E-3')), False),
("-0.23", func(-0.23), False),
("\\frac{1}{2}", func(Rational('1/2')), False),
("\\frac{6}{2}", func(Rational('6/2')), False),
("\\frac{9}{5}", func(Rational('9/5')), False),
("\\frac{-42}{6}", func(Rational('-42/6')), False),
("-\\frac{325}{3}", func(Rational('-325/3')), False),
("\\frac{\\pi }{2}", func(pi / 2), False),
("(1+6)/3", func(Rational(1 + 6, 3)), False),
("1+6/3", func(1 + Rational('6/3')), False),
("7*4/5", func(7 * 4 / 5), False),
("15-2.3", func(15 - Rational('2.3')), False),
("x", func(x), True),
("x + y", func(x + y), True),
("\\frac{9x}{4}", func(9 * x / 4), True),
("y\\pi", func(y * pi), True),
("2y-y-y", func(2 * y - y - y), True)
]
def compare(actual, expected, symbolically=False):
if symbolically:
assert simplify(actual - expected) == 0
else:
actual_exp_tree = srepr(actual)
expected_exp_tree = srepr(expected)
try:
assert actual_exp_tree == expected_exp_tree
except Exception:
if isinstance(actual, int) or isinstance(actual, float) or actual.is_number and isinstance(expected, int) or isinstance(expected, float) or expected.is_number:
assert actual == expected or actual - expected == 0 or simplify(actual - expected) == 0
else:
print('expected_exp_tree = ', expected_exp_tree)
print('actual exp tree = ', actual_exp_tree)
raise
def assert_equal(latex, expr, variable_values={}, symbolically=False):
parsed = process_sympy(latex, variable_values)
compare(parsed, expr, symbolically)
from .context import assert_equal
import pytest
from sympy import exp, sin, Symbol, E
x = Symbol('x', real=True)
y = Symbol('y', real=True)
def test_exp_letter():
assert_equal("e", E)
assert_equal("e", exp(1))
def test_exp_func():
assert_equal("\\exp(3)", exp(3))
def test_exp_func_no_delim():
assert_equal("\\exp3", exp(3))
def test_exp_command_symbol():
assert_equal("\\exponentialE", E)
assert_equal("\\exponentialE", exp(1))
def test_exp_command_symbol_expression():
assert_equal("\\exponentialE^{3}", exp(3))
def test_exp_command_symbol_multiplied():
'''
\\exponentialE is NOT a function, so using the following notation equates to multiplication
'''
assert_equal("\\exponentialE (3)", E * 3)
assert_equal("\\exponentialE \\left( 3\\right)", E * 3)
assert_equal("\\exponentialE \\times 3", E * 3)
def test_exp_numeric():
assert_equal("e^3", exp(3))
def test_exp_symbol():
assert_equal("e^x", exp(x))
def test_exp_symbol_expr():
assert_equal("e^{x+y}", exp(x + y))
def test_exp_symbol_expr_group():
assert_equal("e^{(x+y)}", exp(x + y))
def test_exp_expr():
assert_equal("\\sin(x)*e^x", sin(x) * exp(x))
from .context import assert_equal, get_simple_examples
import pytest
from sympy import floor
examples = get_simple_examples(floor)
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_floor_func(input, output, symbolically):
assert_equal("\\floor({input})".format(input=input), output, symbolically=symbolically)
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_floor_operatorname(input, output, symbolically):
assert_equal("\\operatorname{{floor}}({input})".format(input=input), output, symbolically=symbolically)
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_floor_cmd(input, output, symbolically):
assert_equal("\\lfloor {input}\\rfloor".format(input=input), output, symbolically=symbolically)
assert_equal("\\left\\lfloor {input}\\right\\rfloor".format(input=input), output, symbolically=symbolically)
assert_equal("\\mleft\\lfloor {input}\\mright\\rfloor".format(input=input), output, symbolically=symbolically)
@pytest.mark.parametrize('input, output, symbolically', examples)
def test_floor_corners(input, output, symbolically):
assert_equal("\\llcorner {input}\\lrcorner".format(input=input), output, symbolically=symbolically)
assert_equal("\\left\\llcorner {input}\\right\\lrcorner".format(input=input), output, symbolically=symbolically)
assert_equal("\\mleft\\llcorner {input}\\mright\\lrcorner".format(input=input), output, symbolically=symbolically)
from .context import assert_equal
import pytest
from sympy import Symbol
epsilon_upper = Symbol('char"000190', real=True)
epsilon_lower = Symbol('epsilon', real=True)
varepsilon = Symbol('varepsilon', real=True)
def test_greek_epsilon():
assert_equal("\\epsilon", epsilon_lower)
def test_greek_epsilon_upper():
assert_equal('\\char"000190', epsilon_upper)
def test_greek_varepsilon():
assert_equal('\\varepsilon', varepsilon)
from .context import assert_equal, _Pow, _Add, _Mul
import pytest
from sympy import Integral, sin, Symbol, Mul, Integer, Pow
from latex2sympy.latex2sympy2 import latex2sympy as process_sympy
a = Symbol('a', real=True)
b = Symbol('b', real=True)
x = Symbol('x', real=True)
theta = Symbol('theta', real=True)
func_arg_examples = [
('\\int ', 'x dx', Integral(x, x)),
('\\sin', '\\theta ', sin(theta))
]
example_groups = [
('1+2', '3-4', _Mul(_Add(1, 2), _Add(3, _Mul(-1, 4))))
]
modifiable_delimiter_pairs = {
'(': ')',
'\\lgroup': '\\rgroup',
'\\{': '\\}',
'\\lbrace': '\\rbrace',
'[': ']',
'\\lbrack': '\\rbrack',
}
@pytest.mark.parametrize('func, args, output', func_arg_examples)
def test_func_arg_groupings(func, args, output):
# none
assert_equal("{func} {args}".format(func=func, args=args), output)
# normal brace (not modifiable)
assert_equal("{func}{{{args}}}".format(func=func, args=args), output)
# rest of delimiters, with modifications
for left, right in modifiable_delimiter_pairs.items():
assert_equal("{func}{left}{args}{right}".format(left=left, right=right, func=func, args=args), output)
assert_equal("{func}\\left{left}{args}\\right{right}".format(left=left, right=right, func=func, args=args), output)
assert_equal("{func}\\mleft{left}{args}\\mright{right}".format(left=left, right=right, func=func, args=args), output)
@pytest.mark.parametrize('group1, group2, output', example_groups)
def test_delimiter_groupings(group1, group2, output):
# normal brace (not modifiable)
assert_equal("{{{group1}}}{{{group2}}}".format(group1=group1, group2=group2), output)
# rest of delimiters, with modifications
for left, right in modifiable_delimiter_pairs.items():
assert_equal("{left}{group1}{right}{left}{group2}{right}".format(left=left, right=right, group1=group1, group2=group2), output)
assert_equal("\\left{left}{group1}\\right{right}\\left{left}{group2}\\right{right}".format(left=left, right=right, group1=group1, group2=group2), output)
assert_equal("\\mleft{left}{group1}\\mright{right}\\mleft{left}{group2}\\mright{right}".format(left=left, right=right, group1=group1, group2=group2), output)
from .context import assert_equal
import pytest
from sympy import Symbol, Rational, UnevaluatedExpr, lcm, ilcm, sqrt, pi
x = Symbol('x', real=True)
y = Symbol('y', real=True)
z = Symbol('z', real=True)
def test_lcm_usual():
assert_equal("\\lcm(6, 4)", lcm(6, 4))
assert_equal("\\lcm(4, 6)", lcm(4, 6))
assert_equal("\\lcm(2, 2)", lcm(2, 2))
assert_equal("\\lcm(0, 21)", UnevaluatedExpr(lcm(0, 21)))
assert_equal("\\lcm(21, 0)", UnevaluatedExpr(lcm(21, 0)))
assert_equal("\\lcm(0, 0)", UnevaluatedExpr(lcm(0, 0)))
assert_equal("\\lcm(9, 21)", lcm(9, 21))
assert_equal("\\lcm(6128, 24)", lcm(6128, 24))
assert_equal("\\lcm(24, 6128)", lcm(24, 6128))
assert_equal("\\lcm(1E20, 1000000)", lcm(Rational('1E20'), 1000000))
assert_equal("\\lcm(128*10^32, 1)", lcm(Rational('128E32'), 1))
assert_equal("\\operatorname{lcm}(6, 4)", lcm(6, 4))
assert_equal("\\operatorname{lcm}(4, 6)", lcm(4, 6))
assert_equal("\\operatorname{lcm}(2, 2)", lcm(2, 2))
assert_equal("\\operatorname{lcm}(0, 21)", UnevaluatedExpr(lcm(0, 21)))
assert_equal("\\operatorname{lcm}(21, 0)", UnevaluatedExpr(lcm(21, 0)))
assert_equal("\\operatorname{lcm}(0, 0)", UnevaluatedExpr(lcm(0, 0)))
assert_equal("\\operatorname{lcm}(9, 21)", lcm(9, 21))
assert_equal("\\operatorname{lcm}(6128, 24)", lcm(6128, 24))
assert_equal("\\operatorname{lcm}(24, 6128)", lcm(24, 6128))
assert_equal("\\operatorname{lcm}(1E20, 1000000)", lcm(Rational('1E20'), 1000000))
assert_equal("\\operatorname{lcm}(128*10^32, 1)", lcm(Rational('128E32'), 1))
def test_lcm_negative():
assert_equal("\\lcm(-12, 4)", lcm(-12, 4))
assert_equal("\\lcm(219, -9)", lcm(219, -9))
assert_equal("\\lcm(-8, -12)", lcm(-8, -12))
assert_equal("\\lcm(-5, -5)", lcm(-5, -5))
assert_equal("\\lcm(-1, 182033)", lcm(-1, 182033))
assert_equal("\\lcm(25, -30)", lcm(25, -30))
assert_equal("\\lcm(243, -2.9543127E21)", lcm(243, Rational('-2.9543127E21')))
assert_equal("\\operatorname{lcm}(-12, 4)", lcm(-12, 4))
assert_equal("\\operatorname{lcm}(219, -9)", lcm(219, -9))
assert_equal("\\operatorname{lcm}(-8, -12)", lcm(-8, -12))
assert_equal("\\operatorname{lcm}(-5, -5)", lcm(-5, -5))
assert_equal("\\operatorname{lcm}(-1, 182033)", lcm(-1, 182033))
assert_equal("\\operatorname{lcm}(25, -30)", lcm(25, -30))
assert_equal("\\operatorname{lcm}(243, -2.9543127E21)", lcm(243, Rational('-2.9543127E21')))
def test_lcm_float():
assert_equal("\\lcm(2.4, 3.6)", lcm(Rational('2.4'), Rational('3.6')))
assert_equal("\\lcm(3.6, 2.4)", lcm(Rational('3.6'), Rational('2.4')))
assert_equal("\\lcm(\\pi, 3)", lcm(pi, 3))
assert_equal("\\lcm(618, 1.5)", lcm(618, Rational('1.5')))
assert_equal("\\lcm(-1.5, 618)", lcm(Rational('-1.5'), 618))
assert_equal("\\lcm(0.42, 2)", lcm(Rational('0.42'), 2))
assert_equal("\\lcm(1.43E-13, 21)", lcm(Rational('1.43E-13'), 21))
assert_equal("\\lcm(21, -143E-13)", lcm(21, Rational('-143E-13')))
assert_equal("\\lcm(9.80655, 9.80655)", lcm(Rational('9.80655'), Rational('9.80655')))
assert_equal("\\lcm(0.0000923423, -8341.234802909)", lcm(Rational('0.0000923423'), Rational('-8341.234802909')))
assert_equal("\\lcm(\\sqrt{5}, \\sqrt{2})", lcm(sqrt(5), sqrt(2)))
assert_equal("\\operatorname{lcm}(2.4, 3.6)", lcm(Rational('2.4'), Rational('3.6')))
assert_equal("\\operatorname{lcm}(3.6, 2.4)", lcm(Rational('3.6'), Rational('2.4')))
assert_equal("\\operatorname{lcm}(\\pi, 3)", lcm(pi, 3))
assert_equal("\\operatorname{lcm}(618, 1.5)", lcm(618, Rational('1.5')))
assert_equal("\\operatorname{lcm}(-1.5, 618)", lcm(Rational('-1.5'), 618))
assert_equal("\\operatorname{lcm}(0.42, 2)", lcm(Rational('0.42'), 2))
assert_equal("\\operatorname{lcm}(1.43E-13, 21)", lcm(Rational('1.43E-13'), 21))
assert_equal("\\operatorname{lcm}(21, -143E-13)", lcm(21, Rational('-143E-13')))
assert_equal("\\operatorname{lcm}(9.80655, 9.80655)", lcm(Rational('9.80655'), Rational('9.80655')))
assert_equal("\\operatorname{lcm}(0.0000923423, -8341.234802909)", lcm(Rational('0.0000923423'), Rational('-8341.234802909')))
assert_equal("\\operatorname{lcm}(\\sqrt{5}, \\sqrt{2})", lcm(sqrt(5), sqrt(2)))
def test_lcm_fraction():
assert_equal("\\lcm(1/2, 3)", lcm(Rational('1/2'), 3))
assert_equal("\\lcm(3, 1/2)", lcm(3, Rational('1/2')))
assert_equal("\\lcm(6/2, 3)", lcm(Rational('6/2'), 3))
assert_equal("\\lcm(1/10, 1/10)", lcm(Rational('1/10'), Rational('1/10')))
assert_equal("\\lcm(42, 42/6)", lcm(42, Rational('42/6')))
assert_equal("\\lcm(10000000/10, 10000)", lcm(Rational('10000000/10'), 10000))
assert_equal("\\operatorname{lcm}(1/2, 3)", lcm(Rational('1/2'), 3))
assert_equal("\\operatorname{lcm}(3, 1/2)", lcm(3, Rational('1/2')))
assert_equal("\\operatorname{lcm}(6/2, 3)", lcm(Rational('6/2'), 3))
assert_equal("\\operatorname{lcm}(1/10, 1/10)", lcm(Rational('1/10'), Rational('1/10')))
assert_equal("\\operatorname{lcm}(42, 42/6)", lcm(42, Rational('42/6')))
assert_equal("\\operatorname{lcm}(10000000/10, 10000)", lcm(Rational('10000000/10'), 10000))
def test_lcm_expr():
assert_equal("\\lcm(1+1, 8)", lcm(1 + 1, 8))
assert_equal("920*\\lcm(9, 12*4/2)", 920 * lcm(9, 12 * Rational('4/2')))
assert_equal("\\lcm(32-128, 10)*22", lcm(32 - 128, 10) * 22)
assert_equal("\\sqrt{\\lcm(1.25E24, 1E12)}", sqrt(lcm(Rational('1.25E24'), Rational('1E12'))))
assert_equal("\\lcm(92.0, 000+2)", lcm(Rational('92.0'), 000 + 2))
assert_equal("\\operatorname{lcm}(1+1, 8)", lcm(1 + 1, 8))
assert_equal("920*\\operatorname{lcm}(9, 12*4/2)", 920 * lcm(9, 12 * Rational('4/2')))
assert_equal("\\operatorname{lcm}(32-128, 10)*22", lcm(32 - 128, 10) * 22)
assert_equal("\\sqrt{\\operatorname{lcm}(1.25E24, 1E12)}", sqrt(lcm(Rational('1.25E24'), Rational('1E12'))))
assert_equal("\\operatorname{lcm}(92.0, 000+2)", lcm(Rational('92.0'), 000 + 2))
def test_lcm_symbol():
assert_equal("\\lcm(x, y)", lcm(x, y), symbolically=True)
assert_equal("\\lcm(y, -x)", lcm(y, -x), symbolically=True)
assert_equal("\\lcm(2y, x)", lcm(2 * y, x), symbolically=True)
assert_equal("\\lcm(125, 50x)", lcm(125, 50 * x), symbolically=True)
assert_equal("\\lcm(x + 76, \\sqrt{x} * 4)", lcm(x + 76, sqrt(x) * 4), symbolically=True)
assert_equal("\\lcm(y, y)", lcm(y, y), symbolically=True)
assert_equal("y + \\lcm(0.4x, 8/3) / 2", y + lcm(Rational('0.4') * x, Rational('8/3')) / 2, symbolically=True)
assert_equal("6.673E-11 * (\\lcm(8.85418782E-12, 9x) + 4) / 8y", Rational('6.673E-11') * (lcm(Rational('8.85418782E-12'), 9 * x) + 4) / (8 * y), symbolically=True)
assert_equal("\\operatorname{lcm}(x, y)", lcm(x, y), symbolically=True)
assert_equal("\\operatorname{lcm}(y, -x)", lcm(y, -x), symbolically=True)
assert_equal("\\operatorname{lcm}(2y, x)", lcm(2 * y, x), symbolically=True)
assert_equal("\\operatorname{lcm}(125, 50x)", lcm(125, 50 * x), symbolically=True)
assert_equal("\\operatorname{lcm}(x + 76, \\sqrt{x} * 4)", lcm(x + 76, sqrt(x) * 4), symbolically=True)
assert_equal("\\operatorname{lcm}(y, y)", lcm(y, y), symbolically=True)
assert_equal("y + \\operatorname{lcm}(0.4x, 8/3) / 2", y + lcm(Rational('0.4') * x, Rational('8/3')) / 2, symbolically=True)
assert_equal("6.673E-11 * (\\operatorname{lcm}(8.85418782E-12, 9x) + 4) / 8y", Rational('6.673E-11') * (lcm(Rational('8.85418782E-12'), 9 * x) + 4) / (8 * y), symbolically=True)
def test_multiple_parameters():
assert_equal("\\lcm(830,450)", lcm(830, 450))
assert_equal("\\lcm(6,321,429)", ilcm(6, 321, 429))
assert_equal("\\lcm(14,2324)", lcm(14, 2324))
assert_equal("\\lcm(3, 6, 2)", ilcm(3, 6, 2))
assert_equal("\\lcm(8, 9, 21)", ilcm(8, 9, 21))
assert_equal("\\lcm(144, 2988, 37116)", ilcm(144, 2988, 37116))
assert_equal("\\lcm(144,2988,37116,18,72)", ilcm(144, 2988, 37116, 18, 72))
assert_equal("\\lcm(144, 2988, 37116, 18, 72, 12, 6)", ilcm(144, 2988, 37116, 18, 72, 12, 6))
assert_equal("\\lcm(32)", lcm(32, 32))
assert_equal("\\lcm(-8, 4, -2)", lcm(-8, lcm(4, -2)))
assert_equal("\\lcm(x, y, z)", lcm(x, lcm(y, z)), symbolically=True)
assert_equal("\\lcm(6*4, 48, 3)", ilcm(6 * 4, 48, 3))
assert_equal("\\lcm(2.4, 3.6, 0.6)", lcm(Rational('2.4'), lcm(Rational('3.6'), Rational('0.6'))))
assert_equal("\\lcm(\\sqrt{3}, \\sqrt{2},\\sqrt{100})", lcm(sqrt(3), lcm(sqrt(2), sqrt(100))))
assert_equal("\\lcm(1E12, 1E6, 1E3, 10)", ilcm(Rational('1E12'), Rational('1E6'), Rational('1E3'), 10))
assert_equal("\\operatorname{lcm}(830,450)", lcm(830, 450))
assert_equal("\\operatorname{lcm}(6,321,429)", ilcm(6, 321, 429))
assert_equal("\\operatorname{lcm}(14,2324)", lcm(14, 2324))
assert_equal("\\operatorname{lcm}(3, 6, 2)", ilcm(3, 6, 2))
assert_equal("\\operatorname{lcm}(8, 9, 21)", ilcm(8, 9, 21))
assert_equal("\\operatorname{lcm}(144, 2988, 37116)", ilcm(144, 2988, 37116))
assert_equal("\\operatorname{lcm}(144,2988,37116,18,72)", ilcm(144, 2988, 37116, 18, 72))
assert_equal("\\operatorname{lcm}(144, 2988, 37116, 18, 72, 12, 6)", ilcm(144, 2988, 37116, 18, 72, 12, 6))
assert_equal("\\operatorname{lcm}(32)", lcm(32, 32))
assert_equal("\\operatorname{lcm}(-8, 4, -2)", lcm(-8, lcm(4, -2)))
assert_equal("\\operatorname{lcm}(x, y, z)", lcm(x, lcm(y, z)), symbolically=True)
assert_equal("\\operatorname{lcm}(6*4,48, 3)", ilcm(6 * 4, 48, 3))
assert_equal("\\operatorname{lcm}(2.4, 3.6,0.6)", lcm(Rational('2.4'), lcm(Rational('3.6'), Rational('0.6'))))
assert_equal("\\operatorname{lcm}(\\sqrt{3}, \\sqrt{2},\\sqrt{100})", lcm(sqrt(3), lcm(sqrt(2), sqrt(100))))
assert_equal("\\operatorname{lcm}(1E12,1E6, 1E3, 10)", ilcm(Rational('1E12'), Rational('1E6'), Rational('1E3'), 10))
from .context import assert_equal
import pytest
from sympy import sin, Symbol
x = Symbol('x', real=True)
def test_left_right_cdot():
assert_equal("\\sin\\left(x\\right)\\cdot x", sin(x) * x)
from .context import assert_equal
import pytest
from sympy import MatMul, Matrix
def test_linalg_placeholder():
assert_equal("\\begin{pmatrix}1&2\\\\3&4\\end{pmatrix}\\cdot\\variable{v}", MatMul(Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), {'v': Matrix([1, 2])})
def test_linalg_placeholder_multiple():
assert_equal("\\variable{M}\\cdot\\variable{v}", MatMul(Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), {'M': Matrix([[1, 2], [3, 4]]), 'v': Matrix([1, 2])})
def test_linalg_placeholder_multiple_mul():
assert_equal("\\begin{pmatrix}3&-1\\end{pmatrix}\\cdot\\variable{M}\\cdot\\variable{v}", MatMul(Matrix([[3, -1]]), Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), {'M': Matrix([[1, 2], [3, 4]]), 'v': Matrix([1, 2])})
from .context import assert_equal
import pytest
from sympy import Symbol, Rational, Float, Max, sqrt, exp, pi, nsimplify
x = Symbol('x', real=True)
y = Symbol('y', real=True)
z = Symbol('z', real=True)
def test_max_usual():
assert_equal("\\max(1, 5)", Max(1, 5))
assert_equal("\\max(12, 4)", Max(12, 4))
assert_equal("\\max(109, 120)", Max(109, 120))
assert_equal("\\max(3, 3)", Max(3, 3))
assert_equal("\\max(0, 0)", Max(0, 0))
assert_equal("\\max(1)", Max(1))
assert_equal("\\max(1092198374, 290348E32)", Max(1092198374, Rational('290348E32')))
assert_equal("\\max(5, 2, 17, 4)", Max(5, 2, 17, 4))
def test_max_negative():
assert_equal("\\max(-9, 4)", Max(-9, 4))
assert_equal("\\max(4, -9)", Max(4, -9))
assert_equal("\\max(-7)", Max(-7))
assert_equal("\\max(-2, -2)", Max(-2, -2))
assert_equal("\\max(-324E-3, -58)", Max(Rational('-324E-3'), -58))
assert_equal("\\max(-1, 0, 1, -37, 42)", Max(-1, 0, 1, -37, 42))
def test_max_float():
assert_equal("\\max(\\pi, 3)", Max(pi, 3))
assert_equal("\\max(1234.56789, 1234.5678901)", Max(Rational('1234.56789'), Rational('1234.5678901')))
assert_equal("\\max(12.4, 9.5)", Max(12.4, 9.5))
assert_equal("\\max(6, 6.2)", Max(6, 6.2))
assert_equal("\\max(-98.7)", Max(-98.7))
assert_equal("\\max(7.1, 9)", Max(7.1, 9))
assert_equal("\\max(-21E-12, 0.00005)", Max(nsimplify(Rational('-21E-12')), Rational('0.00005')), symbolically=True)
assert_equal("\\max(\\sqrt{3}, 0, 1)", Max(sqrt(3), 0, 1))
def test_max_fraction():
assert_equal("\\max(1/2, 1/4)", Max(Rational('1/2'), Rational('1/4')))
assert_equal("\\max(6/2, 3)", Max(Rational('6/2'), 3))
assert_equal("\\max(2/4, 1/2)", Max(Rational('2/4'), Rational('1/2')))
assert_equal("\\max(-12/5, 6.4)", Max(Rational('-12/5'), Rational('6.4')))
assert_equal("\\max(1/10)", Max(Rational('1/10')))
assert_equal("\\max(1.5, \\pi/2)", Max(Rational('1.5'), pi / 2, evaluate=False))
assert_equal("\\max(-4/3, -2/1, 0/9, -3)", Max(Rational('-4/3'), Rational('-2/1'), Rational('0/9'), -3))
def test_max_expr():
assert_equal("\\max((1+6)/3, 7)", Max(Rational(1 + 6, 3), 7))
assert_equal("\\max(58*9)", Max(58 * 9))
assert_equal("\\max(1+6/3, -5)", Max(1 + Rational('6/3'), -5))
assert_equal("\\max(7*4/5, 092) * 2", Max(7 * 4 / 5, 92) * 2)
assert_equal("38+\\max(13, 15-2.3)", 38 + Max(13, 15 - Rational('2.3')))
assert_equal("\\sqrt{\\max(99.9999999999999, 100)}", sqrt(Max(Rational('99.9999999999999'), 100)))
assert_equal("\\max(274/(5+2), \\exp(12.4), 1.4E2)", Max(Rational(274, 5 + 2), exp(Rational('12.4')), Rational('1.4E2')))
def test_max_symbol():
assert_equal("\\max(x)", Max(x), symbolically=True)
assert_equal("\\max(x, y)", Max(x, y), symbolically=True)
assert_equal("\\max(y, x)", Max(y, x), symbolically=True)
assert_equal("\\max(x+y, y+x)", Max(x + y, y + x), symbolically=True)
assert_equal("\\max(9x/4, z)", Max(9 * x / 4, z), symbolically=True)
assert_equal("\\max(y\\pi, 9)", Max(y * pi, 9), symbolically=True)
assert_equal("\\max(2y-y, y + 1)", Max(2 * y - y, y + 1), symbolically=True)
assert_equal("\\max(z, y, x)", Max(z, y, x), symbolically=True)
def test_max_multiarg():
assert_equal("\\max(1,2)", Max(1, 2))
assert_equal("\\max(9,876,543)", Max(9, 876, 543))
assert_equal("\\max(x, y,z)", Max(x, y, z), symbolically=True)
assert_equal("\\max(5.8,7.4, 2.2,-10)", Max(Rational('5.8'), Rational('7.4'), Rational('2.2'), -10))
assert_equal("\\max(\\pi,12E2,84,\\sqrt{5},12/5)", Max(pi, Rational('12E2'), 84, sqrt(5), Rational('12/5')))
assert_equal("\\max(823,51)", Max(823, 51))
assert_equal("\\max(72*4,23, 9)", Max(72 * 4, 23, 9))
from .context import assert_equal
import pytest
from sympy import Symbol, Rational, Float, Min, sqrt, exp, pi, nsimplify
x = Symbol('x', real=True)
y = Symbol('y', real=True)
z = Symbol('z', real=True)
def test_min_usual():
assert_equal("\\min(1, 5)", Min(1, 5))
assert_equal("\\min(12, 4)", Min(12, 4))
assert_equal("\\min(109, 120)", Min(109, 120))
assert_equal("\\min(3, 3)", Min(3, 3))
assert_equal("\\min(0, 0)", Min(0, 0))
assert_equal("\\min(1)", Min(1))
assert_equal("\\min(1092198374, 290348E32)", Min(1092198374, Rational('290348E32')))
assert_equal("\\min(5, 2, 17, 4)", Min(5, 2, 17, 4))
def test_min_negative():
assert_equal("\\min(-9, 4)", Min(-9, 4))
assert_equal("\\min(4, -9)", Min(4, -9))
assert_equal("\\min(-7)", Min(-7))
assert_equal("\\min(-2, -2)", Min(-2, -2))
assert_equal("\\min(-324E-3, -58)", Min(Rational('-324E-3'), -58))
assert_equal("\\min(-1, 0, 1, -37, 42)", Min(-1, 0, 1, -37, 42))
def test_min_float():
assert_equal("\\min(\\pi, 3)", Min(pi, 3))
assert_equal("\\min(1234.56789, 1234.5678901)", Min(Rational('1234.56789'), Rational('1234.5678901')))
assert_equal("\\min(12.4, 9.5)", Min(12.4, 9.5))
assert_equal("\\min(6, 6.2)", Min(6, 6.2))
assert_equal("\\min(-98.7)", Min(-98.7))
assert_equal("\\min(7.1, 9)", Min(7.1, 9))
assert_equal("\\min(-21E-12, 0.00005)", Min(nsimplify(Rational('-21E-12')), Rational('0.00005')), symbolically=True)
assert_equal("\\min(\\sqrt{3}, 0, 1)", Min(sqrt(3), 0, 1))
def test_min_fraction():
assert_equal("\\min(1/2, 1/4)", Min(Rational('1/2'), Rational('1/4')))
assert_equal("\\min(6/2, 3)", Min(Rational('6/2'), 3))
assert_equal("\\min(2/4, 1/2)", Min(Rational('2/4'), Rational('1/2')))
assert_equal("\\min(-12/5, 6.4)", Min(Rational('-12/5'), Rational('6.4')))
assert_equal("\\min(1/10)", Min(Rational('1/10')))
assert_equal("\\min(1.5, \\pi/2)", Min(Rational('1.5'), pi / 2, evaluate=False))
assert_equal("\\min(-4/3, -2/1, 0/9, -3)", Min(Rational('-4/3'), Rational('-2/1'), Rational('0/9'), -3))
def test_min_expr():
assert_equal("\\min((1+6)/3, 7)", Min(Rational(1 + 6, 3), 7))
assert_equal("\\min(58*9)", Min(58 * 9))
assert_equal("\\min(1+6/3, -5)", Min(1 + Rational('6/3'), -5))
assert_equal("\\min(7*4/5, 092) * 2", Min(7 * 4 / 5, 92) * 2)
assert_equal("38+\\min(13, 15-2.3)", 38 + Min(13, 15 - Rational('2.3')))
assert_equal("\\sqrt{\\min(99.9999999999999, 100)}", sqrt(Min(Rational('99.9999999999999'), 100)))
assert_equal("\\min(274/(5+2), \\exp(12.4), 1.4E2)", Min(Rational(274, 5 + 2), exp(Rational('12.4')), Rational('1.4E2')))
def test_min_symbol():
assert_equal("\\min(x)", Min(x), symbolically=True)
assert_equal("\\min(x, y)", Min(x, y), symbolically=True)
assert_equal("\\min(y, x)", Min(y, x), symbolically=True)
assert_equal("\\min(x+y, y+x)", Min(x + y, y + x), symbolically=True)
assert_equal("\\min(9x/4, z)", Min(9 * x / 4, z), symbolically=True)
assert_equal("\\min(y\\pi, 9)", Min(y * pi, 9), symbolically=True)
assert_equal("\\min(2y-y, y + 1)", Min(2 * y - y, y + 1), symbolically=True)
assert_equal("\\min(z, y, x)", Min(z, y, x), symbolically=True)
def test_min_multiarg():
assert_equal("\\min(1,2)", Min(1, 2))
assert_equal("\\min(9,876,543)", Min(9, 876, 543))
assert_equal("\\min(x, y,z)", Min(x, y, z), symbolically=True)
assert_equal("\\min(5.8,7.4, 2.2,-10)", Min(Rational('5.8'), Rational('7.4'), Rational('2.2'), -10))
assert_equal("\\min(\\pi,12E2,84,\\sqrt{5},12/5)", Min(pi, Rational('12E2'), 84, sqrt(5), Rational('12/5')))
assert_equal("\\min(823,51)", Min(823, 51))
assert_equal("\\min(72*4,23, 9)", Min(72 * 4, 23, 9))
from .context import assert_equal
import pytest
from sympy import Symbol, Rational, Mod, sqrt, nsimplify, pi, GoldenRatio
from sympy.physics.units import hbar
x = Symbol('x', real=True)
y = Symbol('y', real=True)
def test_mod_usual():
assert_equal("128\\mod 3", Mod(128, 3))
assert_equal("7\\mod 128", Mod(7, 128))
assert_equal("5\\mod 10", Mod(5, 10))
assert_equal("5\\mod 5", Mod(5, 5))
assert_equal("3\\mod 2", Mod(3, 2))
assert_equal("0 \\mod 6", Mod(0, 6))
assert_equal("6109\\mod 28", Mod(6109, 28))
assert_equal("4000000000\\mod 28791", Mod(4000000000, 28791))
assert_equal("128*10^300\\mod 876123", Mod(Rational('128E300'), 876123))
assert_equal("876,123\\mod 128E300)", Mod(876123, Rational('128E300')))
def test_mod_negative():
assert_equal("-1\\mod 2", Mod(-1, 2))
assert_equal("-3\\mod 3", Mod(-3, 3))
assert_equal("-12\\mod -12", Mod(-12, -12))
assert_equal("-128\\mod 4", Mod(-128, 4))
assert_equal("9\\mod -213", Mod(9, -213))
assert_equal("123123\\mod -541", Mod(123123, -541))
assert_equal("-123123\\mod 541", Mod(-123123, 541))
assert_equal("-97E34\\mod 7", Mod(Rational('-97E34'), 7))
def test_mod_fraction():
assert_equal("1/2\\mod 3", Mod(Rational(1, 2), 3))
assert_equal("6/2\\mod 3", Mod(Rational(6, 2), 3))
assert_equal("-14/2\\mod 5", Mod(Rational(-14, 2), 5))
assert_equal("123\\mod (42/6)", Mod(123, Rational(42, 6)))
assert_equal("431\\mod (2/123)", Mod(431, Rational(2, 123)))
assert_equal("5/5\\mod (5/5)", Mod(Rational(5, 5), Rational(5, 5)))
assert_equal("849/-21\\mod (092/2)", Mod(Rational(849, -21), Rational(92, 2)))
assert_equal("13*10^9\\mod (21/-2)", Mod(13E9, Rational(21, -2)))
def test_mod_float():
assert_equal("0.41\\mod 2", Mod(Rational('0.41'), 2))
assert_equal("143E-13\\mod 21", Mod(Rational('143E-13'), 21))
assert_equal("-9.80665\\mod 9.80665", Mod(-9.80665, 9.80665))
assert_equal("0.0000923423\\mod -8341.234802909", nsimplify(Mod(0.0000923423, -8341.234802909)))
assert_equal("\\sqrt{5}\\mod \\sqrt{2}", Mod(sqrt(5), sqrt(2)))
assert_equal("987\\mod \\pi", Mod(987, pi))
assert_equal("\\pi\\mod ((1+\\sqrt{5})/2)", Mod(pi, nsimplify(GoldenRatio)), symbolically=True)
assert_equal("1234\\mod 1E-29", Mod(1234, Rational('1E-29'), evaluate=False))
def test_mod_expr():
assert_equal("1+1\\mod 2", 1 + Mod(1, 2))
assert_equal("876123\\mod 128\\times 10^300", Mod(876123, 128) * 1E300)
assert_equal("141\\mod 9/3", Rational(Mod(141, 9) / 3))
assert_equal("872 / (12\\mod 9 * 4) * 2", Rational(2 * 872, (Mod(12, 9) * 4)))
assert_equal("1E-32 * (1E29\\mod 74)", Rational('1E-32') * Mod(Rational('1E29'), 74))
assert_equal("299,792,458\\mod 9.81", Mod(299792458, Rational('9.81')))
def test_mod_symbol():
assert_equal("x\\mod y", Mod(x, y))
assert_equal("2x\\mod y", Mod(2 * x, y))
assert_equal("y + 3\\mod 2 / 4", y + Rational(Mod(3, 2), 4), symbolically=True)
assert_equal("0.5x * 2 + \\sqrt{x}\\mod 8y", 0.5 * x * 2 + Mod(sqrt(x), 8 * y), symbolically=True)
assert_equal("6.673E-11 * ((8.85418782E-12\\mod 9x) + 4) / 2y", Rational('6.673E-11') * (Mod(Rational('8.85418782E-12'), 9 * x) + 4) / (2 * y), symbolically=True)
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment