improve die parser: allow chained expressions
This commit is contained in:
parent
19f0909907
commit
739316f903
6 changed files with 87 additions and 47 deletions
|
@ -18,7 +18,7 @@ class YamlConfigSettingsSource(PydanticBaseSettingsSource):
|
||||||
at the project's root.
|
at the project's root.
|
||||||
|
|
||||||
Here we happen to choose to use the `env_file_encoding` from Config
|
Here we happen to choose to use the `env_file_encoding` from Config
|
||||||
when reading `config.json`
|
when reading `config.yaml`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
|
@ -59,11 +59,19 @@ class YamlConfigSettingsSource(PydanticBaseSettingsSource):
|
||||||
|
|
||||||
|
|
||||||
class DiscordSettings(BaseModel):
|
class DiscordSettings(BaseModel):
|
||||||
|
"""
|
||||||
|
Holds all the settings needed to configure the bot for Discord usage.
|
||||||
|
"""
|
||||||
|
|
||||||
token: str = Field()
|
token: str = Field()
|
||||||
command_prefix: str = Field(default=".")
|
command_prefix: str = Field(default=".")
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
"""
|
||||||
|
Settings class for the bot
|
||||||
|
"""
|
||||||
|
|
||||||
discord: DiscordSettings
|
discord: DiscordSettings
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -16,11 +16,13 @@ class DiceCog(commands.Cog):
|
||||||
- 2d20 will roll a two d20 dies and multiply the result by two.
|
- 2d20 will roll a two d20 dies and multiply the result by two.
|
||||||
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5.
|
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5.
|
||||||
"""
|
"""
|
||||||
|
if dice_expression == "":
|
||||||
|
return
|
||||||
if dice_expression == "0/0": # easter eggs
|
if dice_expression == "0/0": # easter eggs
|
||||||
return await ctx.send("What do you expect me to do, destroy the universe?")
|
return await ctx.send("What do you expect me to do, destroy the universe?")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
roll_result = DiceRoller.roll(dice_expression)
|
roll_result = DiceRoller.roll_simple(dice_expression)
|
||||||
await ctx.send(f"You rolled: {roll_result}")
|
await ctx.send(f"You rolled: {roll_result}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
await ctx.send(f"Roll failed: {e}")
|
await ctx.send(f"Roll failed: {e}")
|
||||||
|
|
|
@ -1,8 +1,31 @@
|
||||||
|
import dataclasses
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from src.dice.parser import DieParser
|
from src.dice.parser import DieParser
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class DieRollResult:
|
||||||
|
"""
|
||||||
|
DieRoll is the result of a die roll.
|
||||||
|
"""
|
||||||
|
|
||||||
|
result: int
|
||||||
|
modifier: int
|
||||||
|
rolls: typing.List[int]
|
||||||
|
type: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class DieExpressionResult:
|
||||||
|
"""
|
||||||
|
DiceResult is the result of a dice roll expression.
|
||||||
|
"""
|
||||||
|
|
||||||
|
total: int
|
||||||
|
dies: typing.List[DieRollResult]
|
||||||
|
|
||||||
|
|
||||||
class DiceRoller:
|
class DiceRoller:
|
||||||
"""
|
"""
|
||||||
DiceRoller is a simple class that allows you to roll dices.
|
DiceRoller is a simple class that allows you to roll dices.
|
||||||
|
@ -10,45 +33,28 @@ class DiceRoller:
|
||||||
A die can be rolled using the following expression:
|
A die can be rolled using the following expression:
|
||||||
- 1d20 will roll a 20-faceted die and output the result a random number between 1 and 20.
|
- 1d20 will roll a 20-faceted die and output the result a random number between 1 and 20.
|
||||||
- 1d100 will roll a 100 faceted die.
|
- 1d100 will roll a 100 faceted die.
|
||||||
- 2d20 will roll a two d20 dies and multiply the result by two.
|
- 2d20 will roll two d20 dies and multiply the result by two.
|
||||||
- 2d20+5 will roll a two d20 dies and multiply the result by two and ads 5.
|
- 2d20+5 will roll two d20 dies add them together then add 5 to the result.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_parser = DieParser.create()
|
_parser = DieParser.create()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def roll(expression: str, *, advantage: typing.Optional[bool] = None) -> int:
|
def roll_simple(expression: str) -> int:
|
||||||
"""
|
"""
|
||||||
Roll die and return the result.
|
Roll die and return the result.
|
||||||
:param expression: The die expression.
|
:param expression: The die expression.
|
||||||
:param advantage: Optionally, rolls a die with advantage or disadvantage.
|
|
||||||
:return: The die result.
|
:return: The die result.
|
||||||
"""
|
"""
|
||||||
if advantage is None:
|
result = DiceRoller._parser.parse(expression)
|
||||||
return DiceRoller._parser.parse(expression)
|
return result.get("total")
|
||||||
elif advantage is True:
|
|
||||||
return DiceRoller.roll_with_advantage(expression)
|
|
||||||
elif advantage is False:
|
|
||||||
return DiceRoller.roll_with_disadvantage(expression)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def roll_with_advantage(expression: str) -> int:
|
def roll(expression: str) -> DieExpressionResult:
|
||||||
"""
|
"""
|
||||||
Roll two dies and return the highest result.
|
Roll die and return the DiceResult.
|
||||||
:param expression: The die expression.
|
:param expression: The die expression.
|
||||||
:return: The die result.
|
:return: The die result.
|
||||||
"""
|
"""
|
||||||
one = DiceRoller._parser.parse(expression)
|
result = DiceRoller._parser.parse(expression)
|
||||||
two = DiceRoller._parser.parse(expression)
|
return DieExpressionResult(**result)
|
||||||
return max(one, two)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def roll_with_disadvantage(expression: str) -> int:
|
|
||||||
"""
|
|
||||||
Roll two dies and return the lowest result.
|
|
||||||
:param expression: The die expression.
|
|
||||||
:return: The die result.
|
|
||||||
"""
|
|
||||||
one = DiceRoller._parser.parse(expression)
|
|
||||||
two = DiceRoller._parser.parse(expression)
|
|
||||||
return min(one, two)
|
|
||||||
|
|
|
@ -8,12 +8,12 @@ DIE_GRAMMAR = """
|
||||||
@@grammar::Die
|
@@grammar::Die
|
||||||
@@whitespace :: None
|
@@whitespace :: None
|
||||||
|
|
||||||
start = die:die $;
|
start = die:die ~ {op:operator die:die} $;
|
||||||
|
|
||||||
die = [number_of_dies:number] die_type:die_type die_number:number [modifier:die_modifier];
|
die = [number_of_dies:number] die_type:die_type die_number:number [modifier:die_modifier];
|
||||||
die_modifier = op:operator modifier:number;
|
die_modifier = op:operator modifier:number;
|
||||||
|
|
||||||
operator = '+' | '-';
|
operator = '+' | '-' | 'adv' | 'dis';
|
||||||
|
|
||||||
die_type = 'd' | 'zd';
|
die_type = 'd' | 'zd';
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ class DieParser:
|
||||||
def create() -> "DieParser":
|
def create() -> "DieParser":
|
||||||
return DieParser()
|
return DieParser()
|
||||||
|
|
||||||
def parse(self, expression: str) -> int:
|
def parse(self, expression: str) -> dict:
|
||||||
"""
|
"""
|
||||||
Parses the die expression and returns the result.
|
Parses the die expression and returns the result.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
|
import copy
|
||||||
import random
|
import random
|
||||||
|
from collections import deque
|
||||||
from tatsu.ast import AST
|
from tatsu.ast import AST
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,7 +10,32 @@ class DieSemantics:
|
||||||
return int(ast)
|
return int(ast)
|
||||||
|
|
||||||
def start(self, ast):
|
def start(self, ast):
|
||||||
return ast.get("die").get("result")
|
die = ast.get("die")
|
||||||
|
if isinstance(die, dict):
|
||||||
|
return {"total": die.get("result"), "dies": [die]}
|
||||||
|
elif isinstance(die, list):
|
||||||
|
return_value = {"total": 0, "dies": copy.deepcopy(die)}
|
||||||
|
operators = deque(ast.get("op", []))
|
||||||
|
|
||||||
|
die_results = deque(map(lambda x: x.get("result"), die))
|
||||||
|
# Note: we may need to use a dequeue, the ops are quite inefficient.
|
||||||
|
while len(die_results) != 1:
|
||||||
|
left = die_results.popleft()
|
||||||
|
right = die_results.popleft()
|
||||||
|
operator = operators.popleft()
|
||||||
|
total = 0
|
||||||
|
if operator == "+":
|
||||||
|
total = left + right
|
||||||
|
if operator == "-":
|
||||||
|
total = left - right
|
||||||
|
if operator == "adv":
|
||||||
|
total = max(left, right)
|
||||||
|
if operator == "dis":
|
||||||
|
total = min(left, right)
|
||||||
|
die_results.appendleft(total)
|
||||||
|
|
||||||
|
return_value["total"] = die_results.pop()
|
||||||
|
return return_value
|
||||||
|
|
||||||
def die(self, ast):
|
def die(self, ast):
|
||||||
if not isinstance(ast, AST):
|
if not isinstance(ast, AST):
|
||||||
|
@ -34,8 +61,8 @@ class DieSemantics:
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"result": max(sum(rolls) + die_modifier, minimum_value_for_die),
|
"result": max(sum(rolls) + die_modifier, minimum_value_for_die),
|
||||||
"die_type": die_type,
|
"type": die_type,
|
||||||
"roll_history": rolls,
|
"rolls": rolls,
|
||||||
"modifier": die_modifier,
|
"modifier": die_modifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,10 +48,10 @@ def dice_roller():
|
||||||
("1d 4 +0", 1, 4),
|
("1d 4 +0", 1, 4),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_die_roller_die_roll(expression, range_min, range_max, dice_roller):
|
def test_die_roller_die_roll_simple(expression, range_min, range_max, dice_roller):
|
||||||
# let the dies roll...
|
# let the dies roll...
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
result = dice_roller.roll(expression)
|
result = dice_roller.roll_simple(expression)
|
||||||
assert range_min <= result <= range_max
|
assert range_min <= result <= range_max
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,10 +95,10 @@ def test_die_roller_die_roll(expression, range_min, range_max, dice_roller):
|
||||||
("1zd 4 +0", 0, 4),
|
("1zd 4 +0", 0, 4),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_die_roller_zero_die_roll(expression, range_min, range_max, dice_roller):
|
def test_die_roller_zero_die_roll_simple(expression, range_min, range_max, dice_roller):
|
||||||
# let the dies roll...
|
# let the dies roll...
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
result = dice_roller.roll(expression)
|
result = dice_roller.roll_simple(expression)
|
||||||
assert range_min <= result <= range_max
|
assert range_min <= result <= range_max
|
||||||
|
|
||||||
|
|
||||||
|
@ -119,14 +119,11 @@ def test_die_roller_zero_die_roll(expression, range_min, range_max, dice_roller)
|
||||||
)
|
)
|
||||||
def test_die_roller_die_parsing_fail(expression, dice_roller):
|
def test_die_roller_die_parsing_fail(expression, dice_roller):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
dice_roller.roll(expression)
|
dice_roller.roll_simple(expression)
|
||||||
|
|
||||||
|
|
||||||
def test_die_roller_roll_with_advantage(dice_roller):
|
def test_die_roller_roll(dice_roller):
|
||||||
assert 1 <= dice_roller.roll_with_advantage("d20") <= 20
|
for i in range(100):
|
||||||
assert 1 <= dice_roller.roll("d20", advantage=True) <= 20
|
result = dice_roller.roll("d20 + d20 adv d20+5 dis d12+3")
|
||||||
|
assert 1 <= result.total <= 15
|
||||||
|
assert len(result.dies) == 4
|
||||||
def test_die_roller_roll_with_disadvantage(dice_roller):
|
|
||||||
assert 1 <= dice_roller.roll_with_advantage("d20") <= 20
|
|
||||||
assert 1 <= dice_roller.roll("d20", advantage=False) <= 20
|
|
||||||
|
|
Loading…
Reference in a new issue