Skip to content

Commit

Permalink
Make temperature a nullable value so that it can be set to 0
Browse files Browse the repository at this point in the history
  • Loading branch information
gburt committed Jan 17, 2025
1 parent 2a0ff5a commit 8ea9b6f
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 8 deletions.
4 changes: 2 additions & 2 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ type AudioRequest struct {
Reader io.Reader

Prompt string
Temperature float32
Language string // Only for transcription.
Temperature float32 // defaults to 0, so fine to not be a pointer
Language string // Only for transcription.
Format AudioResponseFormat
TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription.
}
Expand Down
2 changes: 1 addition & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ type ChatCompletionRequest struct {
// MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion,
// including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Expand Down
22 changes: 19 additions & 3 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(2),
Temperature: openai.NewFloat(2),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
Expand All @@ -170,7 +170,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
Temperature: openai.NewFloat(1),
TopP: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
Expand All @@ -188,7 +188,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
Temperature: openai.NewFloat(1),
TopP: float32(1),
N: 2,
},
Expand Down Expand Up @@ -259,6 +259,22 @@ func TestChatRequestOmitEmpty(t *testing.T) {
}
}

func TestChatRequestOmitEmptyWithZeroTemp(t *testing.T) {
data, err := json.Marshal(openai.ChatCompletionRequest{
// We set model b/c it's required, so omitempty doesn't make sense
Model: "gpt-4",
Temperature: openai.NewFloat(0),
})
checks.NoError(t, err)

// messages is also required so isn't omitted
// but the zero-value for temp is not excluded, b/c that's a valid value to set the temp to!
const expected = `{"model":"gpt-4","messages":null,"temperature":0}`
if string(data) != expected {
t.Errorf("expected JSON with all empty fields to be %v but was %v", expected, string(data))
}
}

func TestChatCompletionsWithStream(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
Expand Down
10 changes: 10 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,13 @@ type PromptTokensDetails struct {
AudioTokens int `json:"audio_tokens"`
CachedTokens int `json:"cached_tokens"`
}

// NewFloat returns a pointer to a float, useful for setting the temperature on some APIs

Check failure on line 26 in common.go

View workflow job for this annotation

GitHub Actions / Sanity check

Comment should end in a period (godot)
func NewFloat(v float32) *float32 {
return &v
}

// NewInt returns a pointer to an int, useful for setting the seed and other nullable parameters

Check failure on line 31 in common.go

View workflow job for this annotation

GitHub Actions / Sanity check

Comment should end in a period (godot)
func NewInt(v int) *int {
return &v
}
4 changes: 2 additions & 2 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func validateRequestForO1Models(request ChatCompletionRequest) error {
}

// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
if request.Temperature > 0 && request.Temperature != 1 {
if request.Temperature != nil && *request.Temperature != 1 {
return ErrO1BetaLimitationsOther
}
if request.TopP > 0 && request.TopP != 1 {
Expand Down Expand Up @@ -263,7 +263,7 @@ type CompletionRequest struct {
Stop []string `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
Suffix string `json:"suffix,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
User string `json:"user,omitempty"`
}
Expand Down

0 comments on commit 8ea9b6f

Please sign in to comment.